## Basic Byte-Pair-Encoding (BPE)

In [31]:
class BasicBPE:

    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size
        self.tokens_map = None
        self.merges = dict()

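    def train(self, text):
        # Start from the raw UTF-8 bytes, one integer token (0..255) per byte.
        tokens = list(text.encode('utf-8'))

        while len(self.merges) + 256 < self.vocab_size:
            tokens_stats = compute_pair_of_tokens_stats(tokens)
            if not tokens_stats:
                break # fewer than two tokens left, nothing to merge
            most_common_pair = max(tokens_stats, key=tokens_stats.get)
            # Ids 0..255 are the raw byte values. The `+ 1` makes merged ids
            # start at 257, so id 256 is never assigned.
            new_token_id = len(self.merges) + 1 + 256
            self.merges[most_common_pair] = new_token_id

            tokens = merge_pair(tokens, pair=most_common_pair, idx=new_token_id)

        # Map every token id to its bytes. Merges are replayed in insertion
        # order, so both halves of a pair are already in the map.
        self.tokens_map = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            self.tokens_map[idx] = self.tokens_map[p0] + self.tokens_map[p1]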
    
    def encode(self, text):
        tokens = list(text.encode('utf-8'))
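        while len(tokens) >= 2:
            stats = compute_pair_of_tokens_stats(tokens)
            # Pick the pair that was learned earliest (lowest merge id);
            # pairs that were never merged rank as infinity.
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break # nothing else can be merged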
            idx = self.merges[pair]
            tokens = merge_pair(tokens, pair, idx)
        return tokens

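    def decode(self, tokens):
        # Requires train() to have run so that tokens_map is populated.
        text = b"".join(self.tokens_map[x] for x in tokens)
        # A token sequence need not form valid UTF-8, so invalid bytes
        # become U+FFFD instead of raising UnicodeDecodeError.
        text = text.decode("utf-8", errors="replace")
        return text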


def compute_pair_of_tokens_stats(tokens):
    # Count how often each pair of adjacent tokens occurs.
    info = {}
    for pair in zip(tokens, tokens[1:]):
        info[pair] = info.get(pair, 0) + 1
    return info

def merge_pair(tokens, pair, idx):
    new_list = []
    i = 0
    total_tokens = len(tokens)
    while i < total_tokens:
        if i < total_tokens - 1 and tokens[i] == pair[0] and tokens[i+1] == pair[1]:
            new_list.append(idx)
            i += 2
        else:
            new_list.append(tokens[i])
            i += 1
    return new_list

text = "The Tokenizer is a necessary and pervasive component of Large Language Models (LLMs), where it translates between strings and tokens (text chunks). Tokenizers are a completely separate stage of the LLM pipeline: they have their own training sets, training algorithms (Byte Pair Encoding), and after training implement two fundamental functions: encode() from strings to tokens, and decode() back from tokens to strings. In this lecture we build from scratch the Tokenizer used in the GPT series from OpenAI. In the process, we will see that a lot of weird behaviors and problems of LLMs actually trace back to tokenization. We'll go through a number of these issues, discuss why tokenization is at fault, and why someone out there ideally finds a way to delete this stage entirely."

bpe = BasicBPE(vocab_size=300)
bpe.train(text)
decoding_example = bpe.decode([270,260])
print("Decoded example: ", decoding_example)
encoding_example = bpe.encode(text[:5])
print("Encoded example: ", encoding_example)

Decoded example:  omen
Encoded example:  [84, 104, 258, 84]
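
In the encoded example, `"The T"` is five bytes but only four tokens: id 258 stands for the merged byte pair `"e "`. The sketch below isolates the rule that `encode` follows, namely that among the pairs present it always applies the merge that was learned first. The `merges` table, `merge`, and `encode_ids` are hypothetical stand-ins written for this illustration; they are not the trained tokenizer from the cell above.

```python
# Hypothetical merge table: "aa" -> 256 was learned first, then (256, "a") -> 257.
merges = {(97, 97): 256, (256, 97): 257}

def merge(ids, pair, new_id):
    """Replace every non-overlapping occurrence of `pair` with `new_id`."""
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def encode_ids(text):
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        pairs = set(zip(ids, ids[1:]))
        # Lowest merge id wins; unknown pairs rank as infinity.
        pair = min(pairs, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        ids = merge(ids, pair, merges[pair])
    return ids

print(encode_ids("aaab"))  # [257, 98]
print(encode_ids("ab"))    # [97, 98], no learned merge applies
```

For `"aaab"` the first pass rewrites `[97, 97, 97, 98]` to `[256, 97, 98]`, and only then can the later merge `(256, 97)` fire. This is why merges have to be replayed in the order they were learned.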


In [29]:
print(bpe.decode(bpe.encode("hello world")))

hello world
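
The round trip succeeds because the tokens produced by `encode` always concatenate back to the original bytes. An arbitrary list of token ids carries no such guarantee, which is presumably why `decode` passes `errors="replace"`. A minimal sketch of the difference, using a hypothetical `byte_table` that contains only the 256 raw byte tokens and a hypothetical helper `decode_ids`:

```python
# Hypothetical token table with only the raw byte tokens (no merges).
byte_table = {i: bytes([i]) for i in range(256)}

def decode_ids(ids, errors="replace"):
    raw = b"".join(byte_table[i] for i in ids)
    return raw.decode("utf-8", errors=errors)

euro = list("€".encode("utf-8"))      # three bytes: [226, 130, 172]
print(decode_ids(euro))                # €
print(repr(decode_ids(euro[1:2])))     # '\ufffd', the Unicode replacement character

try:
    decode_ids(euro[1:2], errors="strict")
except UnicodeDecodeError as exc:
    print("strict decoding fails:", exc.reason)
```

A lone continuation byte is not valid UTF-8, so strict decoding raises, while `errors="replace"` substitutes U+FFFD and keeps going.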


In [30]:
tokens_level_utf8 = text.encode("utf-8")
tokens_level_bpe = bpe.encode(text)
print(f"Number of tokens using UTF-8: {len(tokens_level_utf8)}")
print(f"Number of tokens using BPE: {len(tokens_level_bpe)}")
print(f"Compression ratio: {len(tokens_level_utf8) / len(tokens_level_bpe):.2f}x")

Number of tokens using UTF-8: 781
Number of tokens using BPE: 484
Compression ratio: 1.61x
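
The compression comes from each merge replacing two tokens with one wherever the chosen pair occurs. A toy trace makes the shrinkage visible merge by merge. The helpers `count_pairs` and `merge` and the ids 256 to 258 are assumptions made for this sketch; they mirror the functions above but are redefined so the snippet runs on its own.

```python
def count_pairs(ids):
    """Count each pair of adjacent token ids."""
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, new_id):
    """Replace every non-overlapping occurrence of `pair` with `new_id`."""
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

toy_ids = list("aaabdaaabac".encode("utf-8"))
history = [len(toy_ids)]
for new_id in (256, 257, 258):  # illustrative ids, one per merge
    stats = count_pairs(toy_ids)
    pair = max(stats, key=stats.get)  # most frequent adjacent pair
    toy_ids = merge(toy_ids, pair, new_id)
    history.append(len(toy_ids))

print(history)  # [11, 9, 7, 5]
print(toy_ids)  # [258, 100, 258, 97, 99]
print(f"{history[0] / history[-1]:.2f}x")  # 2.20x
```

Three merges take the 11-byte string down to 5 tokens. Real text is less repetitive, so the gain per merge is smaller, as the 1.61x figure above suggests.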


# Regex Byte-Pair-Encoding (RegexBPE)