In [1]:
import copy

### **"Much glory awaits someone who can delete the need for tokenization" -- (Andrej Karpathy)**

The tokenizer is a completely separate, independent module from the LLM. It has its own training dataset of text (which could be different from that of the LLM), on which the vocabulary is trained using the Byte Pair Encoding (BPE) algorithm. It then translates back and forth between raw text and sequences of tokens. The LLM only ever sees the tokens and never deals directly with any text.

<div align="center">
  <img src="../assets/tokenizer-llm-diagram.jpg" width="500"/>
</div>

# 1. Strings in Python

According to Python's documentation, "strings are immutable *sequences* of *Unicode code points*". The function `ord()` returns the Unicode code point of a character, and `chr()` returns the character for a given code point. Unicode text is processed and stored as binary data *using one of several encodings*: `UTF-8`, `UTF-16`, `UTF-32`, among others. Of these, `UTF-8` is the most widely used, in part due to its backwards compatibility with ASCII. The method `encode()` converts a string into binary data, and `decode()` converts binary data back into a string.
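
As a minimal sketch of these four functions working together, the snippet below round-trips a short string through code points and through UTF-8 bytes. The sample string and variable names are chosen only for illustration.

```python
sample = 'héllo'  # illustrative string with one non-ASCII character

# str -> code points -> str
code_points = [ord(ch) for ch in sample]
restored = ''.join(chr(cp) for cp in code_points)

# str -> UTF-8 bytes -> str
raw = sample.encode('utf-8')
decoded = raw.decode('utf-8')

print(code_points)            # [104, 233, 108, 108, 111]
print(len(sample), len(raw))  # 5 6 ('é' takes two bytes in UTF-8)
print(restored == sample, decoded == sample)
```

The string has 5 code points but 6 bytes, because `é` (U+00E9) falls outside the ASCII range.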

`UTF-8` means *Unicode Transformation Format - 8 bit* and supports all valid Unicode code points using a *variable-width encoding* of one to four one-byte code units. Code points with lower numerical values, which tend to occur more frequently, are encoded using fewer bytes. In the following table, the characters `u` to `z` are replaced by the bits of the code point, from the positions U+uvwxyz:

<div align="center">
  <img src="../assets/utf8-encoding.jpg" width="700"/>
</div>

Examples:
- U+0041 (‘A’) → 01000001 → 01000001 (same as ASCII)
- U+00A9 (‘©’) → 00010101001 → 11000010 10101001
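
The variable-width layout can be checked directly in Python. The helper below, `utf8_bits`, is a hypothetical name introduced here for illustration; it prints the bit pattern of each UTF-8 byte for characters that need one, two, three and four bytes.

```python
def utf8_bits(ch: str) -> str:
    """Returns the UTF-8 bytes of a single character as space-separated 8-bit strings."""
    return ' '.join(f'{byte:08b}' for byte in ch.encode('utf-8'))

# one-, two-, three- and four-byte examples ('\U0001F600' is an emoji)
for ch in ['A', '©', '€', '\U0001F600']:
    n_bytes = len(ch.encode('utf-8'))
    print(f'U+{ord(ch):04X}: {n_bytes} byte(s) -> {utf8_bits(ch)}')
```

Single-byte characters start with `0`, while multi-byte characters start with `110`, `1110` or `11110` and continue with bytes of the form `10xxxxxx`.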

Because `UTF-8` represents text as a stream of bytes, using the raw bytes as tokens caps the vocabulary at 256 possible tokens. This means tiny embedding tables, offset by very long sequences of tokens, which can be a hindrance to context length in transformer-based neural networks, where each token needs to attend to all other tokens in the sequence.

In [2]:
unicode_enc = [ord(x) for x in '안녕하세요']
unicode_enc

[50504, 45397, 54616, 49464, 50836]

In [3]:
utf8_enc = '안녕하세요'.encode('utf-8')
utf8_enc, list(utf8_enc)

(b'\xec\x95\x88\xeb\x85\x95\xed\x95\x98\xec\x84\xb8\xec\x9a\x94',
 [236, 149, 136, 235, 133, 149, 237, 149, 152, 236, 132, 184, 236, 154, 148])

In [4]:
print('Unicode length: ', len(unicode_enc))
print('UTF-8 length: ', len(utf8_enc))

Unicode length:  5
UTF-8 length:  15


# 2. Byte Pair Encoding (BPE)

This algorithm was first described in 1994, by Philip Gage, for encoding strings of text into smaller strings by creating and using a translation table. It builds "tokens" (units of recognition) that match varying amounts of source text, from single characters (including single digits or single punctuation marks) to whole words (even long compound words).

Suppose the data to be encoded is:

```
aaabdaaabac
```

The byte pair "aa" occurs most often, so it is merged into a single token:

```
ZabdZabac
Z = aa
```

The process is repeated with byte pair "ab", replacing it with Y:

```
ZYdZYac
Y = ab
Z = aa
```

Finally, the byte pair "ZY" is merged into a single token X:

```
XdXac
X = ZY
Y = ab
Z = aa
```

The data cannot be compressed further because no pair of symbols occurs more than once. We started with a sequence of 11 tokens over a vocabulary of 4 symbols, and ended with a sequence of 5 tokens over a vocabulary of 7 symbols (the original 4 plus X, Y and Z).
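
The walkthrough above can be replayed with a few lines of Python. This is only a sketch: `merge_pair` is a hypothetical helper, and the three merges are applied in the order given in the text rather than discovered by counting, since `ab` ties with `Za` at the second step and a frequency-based choice could pick either.

```python
def merge_pair(symbols: list, pair: tuple, new_symbol: str) -> list:
    """Replaces every non-overlapping occurrence of `pair` with `new_symbol`, scanning left to right."""
    merged, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            merged.append(new_symbol)
            i += 2
        else:
            merged.append(symbols[i])
            i += 1
    return merged

symbols = list('aaabdaaabac')
vocabulary = set(symbols)  # {'a', 'b', 'c', 'd'}

# merges taken from the worked example, in order
for pair, new_symbol in [(('a', 'a'), 'Z'), (('a', 'b'), 'Y'), (('Z', 'Y'), 'X')]:
    symbols = merge_pair(symbols, pair, new_symbol)
    vocabulary.add(new_symbol)
    print(''.join(symbols))

print(len(symbols), len(vocabulary))  # 5 7
```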

In [5]:
with open('../data/unicode.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print('Number of characters in the text: ', len(text))

Number of characters in the text:  1414


In [6]:
# NOTE: unicode text encoded in utf-8 has up to 4 bytes per character
tokens = list(map(int, text.encode('utf-8')))
print('Number of single tokens in the text: ', len(tokens))

Number of single tokens in the text:  2058


In [7]:
unique_tokens = set(tokens)
print('Number of unique tokens in the text: ', len(unique_tokens))
print('Max token: ', max(unique_tokens))

Number of unique tokens in the text:  105
Max token:  240


In [8]:
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

stats = get_stats(tokens)
print('Number of unique bigrams: ', len(stats))
print('Most common bigrams: ', sorted(stats.items(), key=lambda x: x[1], reverse=True)[:5])
print('Most common bigram in text: ', (chr(101), chr(32)))

Number of unique bigrams:  617
Most common bigrams:  [((101, 32), 24), ((204, 173), 18), ((205, 153), 18), ((204, 178), 18), ((115, 32), 17)]
Most common bigram in text:  ('e', ' ')


In [9]:
# merging the most common pair
top_pair = max(stats, key=stats.get)
top_pair

(101, 32)

In [10]:
def merge(tokens: list, pair: tuple[int, int], new_token: int) -> list:
    """Merges the most common pair in the given list of tokens into a single token."""
    new_tokens = []
    i = 0
    while i < len(tokens):
        if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
            new_tokens.append(new_token)
            i += 2
        else:
            new_tokens.append(tokens[i])
            i += 1
    return new_tokens

print('Example tokens before merging: ', ex_tokens := [5, 6, 6, 7, 9, 1])
print('Example tokens after merging: ', merge(ex_tokens, (6, 6), 10))

Example tokens before merging:  [5, 6, 6, 7, 9, 1]
Example tokens after merging:  [5, 10, 7, 9, 1]


In [11]:
merged_tokens = merge(tokens, top_pair, max(unique_tokens) + 1)
print('Number of tokens before merging: ', len(tokens))
print('Number of tokens after merging: ', len(merged_tokens))
print('Number of unique tokens before merging: ', len(unique_tokens))
print('Number of unique tokens after merging: ', len(set(merged_tokens)))
print('Max token before merging: ', max(unique_tokens))
print('Max token after merging: ', max(set(merged_tokens)))

Number of tokens before merging:  2058
Number of tokens after merging:  2034
Number of unique tokens before merging:  105
Number of unique tokens after merging:  106
Max token before merging:  240
Max token after merging:  241


### 2.1. Training the tokenizer

In [12]:
vocab_size = 276                 # desired number of unique tokens in vocabulary
max_tokens_per_byte = 2 ** 8     # encoding string into utf-8 converts characters into bytes
num_merges = vocab_size - max_tokens_per_byte
trainable_tokens = copy.deepcopy(tokens)

In [13]:
# `bpe_forest` is an inverted tree that stores merges: (int, int) -> int
bpe_forest = {}
for i in range(num_merges):
    stats = get_stats(trainable_tokens)
    top_pair = max(stats, key=stats.get)
    new_token = max_tokens_per_byte + i
    print(f'Merging pair {top_pair} into new token {new_token}')
    trainable_tokens = merge(trainable_tokens, top_pair, new_token)
    bpe_forest[top_pair] = new_token

Merging pair (101, 32) into new token 256
Merging pair (204, 173) into new token 257
Merging pair (205, 153) into new token 258
Merging pair (204, 178) into new token 259
Merging pair (115, 32) into new token 260
Merging pair (204, 171) into new token 261
Merging pair (204, 177) into new token 262
Merging pair (240, 159) into new token 263
Merging pair (205, 136) into new token 264
Merging pair (204, 185) into new token 265
Merging pair (226, 128) into new token 266
Merging pair (105, 110) into new token 267
Merging pair (205, 150) into new token 268
Merging pair (204, 187) into new token 269
Merging pair (205, 135) into new token 270
Merging pair (204, 188) into new token 271
Merging pair (204, 164) into new token 272
Merging pair (204, 166) into new token 273
Merging pair (97, 110) into new token 274
Merging pair (204, 176) into new token 275


In [14]:
print('Number of unique tokens before BPE: ', len(unique_tokens))
print('Number of unique tokens after BPE: ', len(set(trainable_tokens)))
# NOTE: this ratio compares unique-token counts; the sequence-level compression ratio would be len(tokens) / len(trainable_tokens)
print(f'Unique-token ratio: {len(set(trainable_tokens)) / len(unique_tokens):.2f}X')

Number of unique tokens before BPE:  105
Number of unique tokens after BPE:  113
Unique-token ratio: 1.08X


### 2.2. Decoding tokens into strings

UTF-8 follows a specific schema for the bytes of each character, which is used to encode and decode strings. Per this schema, a multi-byte character must follow certain rules as to how each byte is structured (see section 1 above), so not every byte sequence is valid UTF-8. To avoid raising an error on invalid input, `bytes.decode()` accepts an `errors` argument; setting it to `'replace'` substitutes the Unicode replacement character `�` (U+FFFD) for any bytes that cannot be decoded.
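
A small sketch of the difference between the default strict mode and `errors='replace'`, using a lone continuation byte (128, i.e. `10000000`) that cannot start a valid UTF-8 character. The variable names are illustrative.

```python
invalid_bytes = bytes([128])  # a continuation byte with no leading byte

# default (strict) mode raises on invalid input
try:
    invalid_bytes.decode('utf-8')
    strict_failed = False
except UnicodeDecodeError:
    strict_failed = True

# 'replace' mode substitutes U+FFFD instead of raising
replaced = invalid_bytes.decode('utf-8', errors='replace')

print(strict_failed)           # True
print(replaced == '\ufffd')    # True
```

This matters for a byte-level tokenizer because an arbitrary token sequence, such as one sampled from a model, is not guaranteed to form valid UTF-8.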

In [15]:
vocab = {i: bytes([i]) for i in range(max_tokens_per_byte)}
for (i, j), new_token in bpe_forest.items():
    vocab[new_token] = vocab[i] + vocab[j]
print('Number of tokens in the vocabulary: ', len(vocab))

Number of tokens in the vocabulary:  276


In [16]:
def decode(tokens: list[int]) -> str:
    """Decodes the given list of tokens using the given vocabulary."""
    binary = b''.join(vocab[token] for token in tokens)
    return binary.decode('utf-8', errors='replace')

print('Decoded token: ', decode([97]))

Decoded token:  a


In [17]:
decode([128])

'�'

### 2.3. Encoding strings into tokens

In [18]:
text = 'hello software engineering'
print('Length of the text: ', len(text))

Length of the text:  26


In [19]:
def encode(text: str) -> list[int]:
    """Encodes the given text using the given vocabulary."""
    tokens = list(text.encode('utf-8'))

    while len(tokens) > 1:
        stats = get_stats(tokens)
        pair = min(stats, key=lambda p: bpe_forest.get(p, float('inf')))
        if pair not in bpe_forest:
            break
          
        new_token = bpe_forest[pair]
        tokens = merge(tokens, pair, new_token)

    return tokens

print('Encoded tokens: ', encode(text))
print('Length of the encoded tokens: ', len(encode(text)))

Encoded tokens:  [104, 101, 108, 108, 111, 32, 115, 111, 102, 116, 119, 97, 114, 256, 101, 110, 103, 267, 101, 101, 114, 267, 103]
Length of the encoded tokens:  23


In [20]:
print('Decoded text: ', decode(encode(text)))

Decoded text:  hello software engineering


# 3. Regex patterns to force splits across categories

This section is based on the following excerpt from the GPT-2 paper: "We observed BPE including many versions of common words like `dog` since they occur in many variations such as `dog.`, `dog!` and `dog?`. This results in a sub-optimal allocation of limited vocabulary slots and model capacity. To avoid this, we prevent BPE from merging across character categories for any byte sequence". 

In order to prevent BPE from merging across character categories, regex patterns are used to force splits across categories and then tokenization can be performed on the resulting splits. In the end, the results of that processing are concatenated back together. This way, byte-pair merges can only happen within the same category.
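
The effect of splitting first can be sketched with the standard library alone. The pattern below is a simplified stand-in for the GPT-2 pattern and is an assumption of this example: it drops the contraction rules and approximates `\p{L}` and `\p{N}` with character classes that the built-in `re` module supports. `chunked_pair_counts` is likewise a hypothetical helper.

```python
import re as stdlib_re

# simplified approximation of the GPT-2 split pattern (letters | digits | punctuation | whitespace)
simple_pattern = stdlib_re.compile(r" ?[^\W\d_]+| ?\d+| ?(?:[^\s\w]|_)+|\s+(?!\S)|\s+")

def chunked_pair_counts(text: str) -> dict:
    """Counts adjacent byte pairs within each chunk only, never across chunk boundaries."""
    counts = {}
    for chunk in simple_pattern.findall(text):
        ids = list(chunk.encode('utf-8'))
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
    return counts

sample = 'dog. dog! dog?'
chunks = simple_pattern.findall(sample)
counts = chunked_pair_counts(sample)

print(chunks)                          # ['dog', '.', ' dog', '!', ' dog', '?']
print(counts[(ord('d'), ord('o'))])    # 3
print((ord('g'), ord('.')) in counts)  # False
```

Because the pair `('g', '.')` is never counted, BPE training on these chunks cannot learn a `dog.` token, while `dog` itself can still be merged.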

In [21]:
import regex as re

pattern = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

In [22]:
print(re.findall(pattern, "It's aren't they're they've I'm I'll He'd        Hello123 World!?!?"))

['It', "'s", ' aren', "'t", ' they', "'re", ' they', "'ve", ' I', "'m", ' I', "'ll", ' He', "'d", '       ', ' Hello', '123', ' World', '!?!?']


### 3.1. `Tiktoken` library intro

In [23]:
import tiktoken

In [24]:
text = '    hello world!!!'

# GPT-2 (does not merge spaces)
gpt2_encoding = tiktoken.get_encoding('gpt2')
print('GPT-2 encoding: ', gpt2_encoding.encode(text))

# GPT-4 (merges spaces)
gpt4_encoding = tiktoken.get_encoding('cl100k_base')
print('GPT-4 encoding: ', gpt4_encoding.encode(text))

GPT-2 encoding:  [220, 220, 220, 23748, 995, 10185]
GPT-4 encoding:  [262, 24748, 1917, 12340]


# 4. GPT-2 `encoder.py` walkthrough

References: 
- Code repository: https://github.com/openai/gpt-2/blob/master/src/encoder.py
- Vocabulary: https://openaipublic.blob.core.windows.net/gpt-2/models/1558M/encoder.json
- BPE merges: https://openaipublic.blob.core.windows.net/gpt-2/models/1558M/vocab.bpe

In [30]:
import os
import json
import regex as re

In [31]:
with open('../data/encoder.json', 'r', encoding='utf-8') as f:
    vocab = json.load(f)

print('Number of tokens in the vocab: ', len(vocab))

Number of tokens in the vocab:  50257


In [32]:
with open('../data/vocab.bpe', 'r', encoding='utf-8') as f:
    bpe_data = f.read()
    
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
print('Number of BPE merges: ', len(bpe_merges))
print('First 10 BPE merges: ', bpe_merges[:10])

Number of BPE merges:  50000
First 10 BPE merges:  [('Ġ', 't'), ('Ġ', 'a'), ('h', 'e'), ('i', 'n'), ('r', 'e'), ('o', 'n'), ('Ġt', 'he'), ('e', 'r'), ('Ġ', 's'), ('a', 't')]


In [37]:
# adapted code from original at https://github.com/openai/gpt-2/blob/master/src/encoder.py
def bytes_to_unicode():
    """Maps each of the 256 byte values to a printable Unicode character (e.g. the space byte becomes 'Ġ')."""
    bs = list(range(ord('!'), ord('~') + 1)) + list(range(ord('¡'), ord('¬') + 1)) + list(range(ord('®'), ord('ÿ') + 1))
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            # non-printable bytes are shifted to code points 256 and above
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    return dict(zip(bs, (chr(c) for c in cs)))

class Tokenizer:
    def __init__(self, vocab, bpe_merges, errors='replace'):
        self.vocab = vocab
        self.decoder = {v: k for k, v in self.vocab.items()}
        self.errors = errors # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {} # memoizes the BPE result of each pre-token
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    @staticmethod
    def get_pairs(word):
        """Returns the set of adjacent symbol pairs in a word (a tuple of symbols)."""
        pairs = set()
        prev_char = word[0]
        for char in word[1:]:
            pairs.add((prev_char, char))
            prev_char = char
        return pairs

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = self.get_pairs(word)

        if not pairs:
            return token

        while True:
            # pick the pair that was learned earliest (lowest rank)
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = self.get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.vocab[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

In [38]:
gpt2_tokenizer = Tokenizer(vocab=vocab, bpe_merges=bpe_merges)

In [39]:
gpt2_tokenizer.encode('hello world')

[31373, 995]

# 17. special tokens, tiktoken handling of, GPT-2/GPT-4 differences
# 18. minbpe exercise time! write your own GPT-4 tokenizer
# 19. sentencepiece library intro, used to train Llama 2 vocabulary
# 20. how to set vocabulary set? revisiting gpt.py transformer
# 21. training new tokens, example of prompt compression
# 22. multimodal [image, video, audio] tokenization with vector quantization
# 23. revisiting and explaining the quirks of LLM tokenization
# 24. final recommendations

# Sources

1. [Ground truth - Let's build the GPT Tokenizer, by Andrej Karpathy](https://www.youtube.com/watch?v=zduSFxRajkE&t=38s)
2. [A programmer's introduction to Unicode, by Nathan Reed](https://www.reedbeta.com/blog/programmers-intro-to-unicode)
3. [Language models are unsupervised multitask learners [GPT-2 paper], by Alec Radford; Dario Amodei; Ilya Sutskever; et al.](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)