# GPT Tokenizer from Scratch

In this notebook, I build a GPT-style tokenizer from scratch, following Andrej Karpathy's tutorial on the topic.

In [2]:
text = "Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to “support Unicode” in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don’t blame programmers for still finding the whole thing mysterious, even 30 years after Unicode’s inception."

tokens = text.encode('utf-8')
tokens = list(map(int, tokens))
len(text), len(tokens)

(533, 616)

### Detour: Unicode and UTF

One key thing to keep in mind about Unicode and UTF-8 is that Unicode only defines a mapping from characters to integers (called code points). It does not define any specific binary representation for a character.
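As a small illustration of that mapping, Python's built-in `ord` and `chr` convert between a character and its code point. The sample characters below are my own picks, not part of the original notebook.

```python
# Unicode assigns every character an integer code point; no bytes are involved yet.
samples = ["A", "é", "€", "😉"]

code_points = {ch: ord(ch) for ch in samples}
for ch, cp in code_points.items():
    # Code points are conventionally written as U+ followed by hex digits.
    print(ch, cp, f"U+{cp:04X}")

# chr is the inverse of ord.
round_trip_ok = all(chr(cp) == ch for ch, cp in code_points.items())
```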

UTF-8, on the other hand, takes the mapping provided by Unicode and defines a binary representation for it. A naive fixed-width scheme would need about 32 bits for every character, which is extremely wasteful for text that is mostly ASCII.

So UTF-8 specifies a variable-length encoding: each character is represented as a sequence of one to four bytes. That in turn requires a way to mark where each character starts and how many bytes it spans.
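To get a feel for the difference, the sketch below compares the encoded size of the same string under UTF-8 and under the fixed-width UTF-32. The sample string is an assumption chosen for illustration.

```python
sample = "hello"

# UTF-8: one byte per ASCII character.
utf8_len = len(sample.encode("utf-8"))

# UTF-32: four bytes per character. The "-le" variant is used so that
# no byte-order mark is prepended to the output.
utf32_len = len(sample.encode("utf-32-le"))

print(utf8_len, utf32_len)
```

For ASCII-heavy text the fixed-width encoding is four times larger, which is the waste the variable-length scheme avoids.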


If the character is ASCII (code point 0 to 127), its UTF-8 representation is a single byte whose leading bit is 0: `0xxxxxxx`

For other characters, the first byte specifies how many bytes are in the sequence, counting itself.

For example, consider the following UTF-8 encoding.

```
1110xxxx 10xxxxxx 10xxxxxx
```



Here, the initial byte has three ones followed by a zero, which means this character takes 3 bytes in total. The `x` bits carry the code point itself, and the `10` prefix on each subsequent byte marks it as a *continuation* byte.
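The three-byte layout can be reproduced by hand with a few bit operations. This is only a sketch: `encode_three_byte` is a hypothetical helper written for this note, and it ignores details such as the surrogate range that a real encoder must reject.

```python
def encode_three_byte(code_point: int) -> bytes:
    """Pack a code point in U+0800..U+FFFF into 1110xxxx 10xxxxxx 10xxxxxx."""
    if not 0x0800 <= code_point <= 0xFFFF:
        raise ValueError("code point does not use the three-byte form")
    first = 0b11100000 | (code_point >> 12)                # top 4 bits
    second = 0b10000000 | ((code_point >> 6) & 0b111111)   # middle 6 bits
    third = 0b10000000 | (code_point & 0b111111)           # low 6 bits
    return bytes([first, second, third])

euro = "€"  # U+20AC, which needs three bytes in UTF-8
manual = encode_three_byte(ord(euro))
builtin = euro.encode("utf-8")

print([f"{b:08b}" for b in manual])
```

The hand-packed bytes should match what `str.encode` produces for the same character.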

UTF-8 uses 8-bit values in its encoding. There are also 16-bit and 32-bit encodings (UTF-16 and UTF-32), but they are used less often than UTF-8, partly because they are not backward compatible with ASCII.

UTF-8 uses the following rules:

- If the code point is < 128, it's represented by the corresponding byte value. That is, the encoded result is a single byte.

- If the code point is >= 128, it's turned into a sequence of two, three, or four bytes, where each byte of the sequence is between 128 and 255. That is, you need a sequence (for example, a list of integers) to hold the result.
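The two rules above mean the first byte alone tells a decoder how long the sequence is. A minimal sketch, assuming a hypothetical helper `sequence_length` that only inspects the leading bits:

```python
def sequence_length(first_byte: int) -> int:
    """Return how many bytes a UTF-8 sequence has, judged by its first byte."""
    if first_byte < 0b10000000:      # 0xxxxxxx: ASCII, one byte
        return 1
    if first_byte >> 5 == 0b110:     # 110xxxxx: two bytes
        return 2
    if first_byte >> 4 == 0b1110:    # 1110xxxx: three bytes
        return 3
    if first_byte >> 3 == 0b11110:   # 11110xxx: four bytes
        return 4
    # 10xxxxxx is a continuation byte and cannot start a sequence.
    raise ValueError("not a valid leading byte")

for ch in ["A", "é", "€", "😉"]:
    encoded = ch.encode("utf-8")
    print(ch, list(encoded), sequence_length(encoded[0]))
```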

Think about what this means, though. Because UTF-8 fixes this particular format, some byte sequences do not correspond to any character. For example, take the byte `10000000`. It starts with `10`, so it is a continuation byte, but there is no leading byte in front of it. On its own, UTF-8 has no way of decoding it!

So what should it do? In Python, `str.encode` and `bytes.decode` take an optional `errors` parameter. With `errors='replace'`, any byte that cannot be decoded is replaced with `�` (U+FFFD, the replacement character).
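For comparison, here is how the same invalid input behaves under three of the standard error handlers. The byte string is the one used elsewhere in this notebook; the variable names are my own.

```python
bad = b"\x80abc"  # starts with a lone continuation byte

# Default handler ("strict"): decoding raises an exception.
try:
    bad.decode("utf-8")
    strict_raised = False
except UnicodeDecodeError as exc:
    strict_raised = True
    print("strict failed:", exc.reason)

# "replace" substitutes U+FFFD for the undecodable byte.
replaced = bad.decode("utf-8", errors="replace")

# "ignore" silently drops it.
ignored = bad.decode("utf-8", errors="ignore")

print(repr(replaced), repr(ignored))
```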

I referred to the following video and articles to learn a bit more about Unicode and UTF-8.

- [Computerphile on Unicode & UTF-8](https://youtu.be/MijmeoH9LT4?si=9ZJIAK8xHRQbrEsz)
- [UTF-8 Wikipedia](https://en.wikipedia.org/wiki/UTF-8)
- [Unicode Python HOWTO](https://docs.python.org/3/howto/unicode.html)

In [4]:
# Code point of '😉': U+1F609
# Its UTF-8 bytes in hex: 'f09f9889'; as integers: [240, 159, 152, 137]
# In hex, two characters represent one byte: F0 in hex = 240, 9F in hex = 159, etc.
s = '😉'

# Encode as a sequence of bytes (four here, since the code point is >= 128)
utf_enc = s.encode('utf-8')
utf_enc_hex = utf_enc.hex()
utf_enc_list = list(utf_enc)

# Encode an ASCII character: the result is a single byte
utf_single_byte = 'A'.encode('utf-8')

# Decode bytes that are not valid UTF-8, replacing the bad byte with U+FFFD
invalid_enc = b'\x80abc'.decode('utf-8', 'replace')

invalid_enc, utf_single_byte, utf_enc_list

('�abc', b'A', [240, 159, 152, 137])

## Byte Pair Encoding

What BPE does is loosely similar to Huffman coding. You iterate over the text and find which pair of bytes occurs most frequently, then merge every occurrence of that pair into a new token. Here, a byte pair means two consecutive bytes.

The function that counts these pairs is usually called `get_stats`, and that's what we call it here too.

In [5]:
from collections import Counter

def get_stats(ids):
    counts = Counter(zip(ids, ids[1:]))
    return counts

counts = get_stats(tokens)
top_pair = counts.most_common(1)[0][0]
top_pair

(101, 32)

The way to interpret this is that the most common pair of bytes in this sequence is 101 followed by 32, and it occurs 20 times. We can find the corresponding characters with Python's `chr` function: they are `e` followed by a space.
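The same counting and lookup can be tried on a tiny input. The string below is a made-up sample, chosen so that the same pair comes out on top.

```python
from collections import Counter

# Toy input (an assumption for illustration): "e" followed by a space occurs three times.
sample_ids = list("he she we ".encode("utf-8"))

# Count every pair of adjacent ids, exactly as get_stats does.
pair_counts = Counter(zip(sample_ids, sample_ids[1:]))
top_pair, top_count = pair_counts.most_common(1)[0]

# chr maps each id back to a character.
decoded_pair = [chr(i) for i in top_pair]
print(top_pair, top_count, decoded_pair)
```

Note that `chr` gives the right character here only because both ids are below 128. For a byte that is part of a multi-byte sequence, the ids have to be joined and decoded as UTF-8 instead.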

Now we can write a merge function that replaces every occurrence of `(101, 32)` with a new token. Notice that even though some characters span multiple bytes, every integer in our token list is still a single byte value: none of the numbers are > 255. Characters like emojis become several numbers, one after the other, but each of them is still within the range `[0, 255]`. Thus the first free id for a merged token is 256, and we assign it to the pair: the number 256 represents (101, 32).

In [6]:
def merge(ids, pair, new_idx):
    new_ids = [ ]
    i = 0

    while i < len(ids):
        # replace all instances of the pair with the new_idx
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            new_ids.append(new_idx)
            i += 2
        # append all other tokens as is
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids

In [7]:
tokens_after_one_merge = merge(tokens, top_pair, new_idx=256)
len(tokens_after_one_merge)

596

As we can see, the sequence shrank from 616 to 596 tokens: each of the 20 occurrences of the pair collapsed two tokens into one.
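The effect is easier to follow on a toy list. The sketch below restates `merge` compactly so that it runs on its own; the input list and the new id 99 are arbitrary values picked for illustration.

```python
def merge(ids, pair, new_idx):
    """Replace every non-overlapping occurrence of pair in ids with new_idx."""
    new_ids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            new_ids.append(new_idx)
            i += 2  # skip both members of the pair
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids

toy = [5, 6, 6, 7, 9, 1]
merged_toy = merge(toy, (6, 7), 99)
print(merged_toy)  # [5, 6, 99, 9, 1]
```

Only the `6, 7` at positions 2 and 3 is replaced. The first `6` is followed by another `6`, so it is copied through unchanged.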

This was one merge. If we repeat this iteratively, each merge adds one new token to the vocabulary and shortens the sequence further.

How many times should you do the merge operation? That's a hyperparameter, constrained by hardware among other things. A larger vocabulary means greater storage and compute requirements, while a smaller vocabulary means longer token sequences for the same text.
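To make the trade-off concrete, the sketch below trains on a small repetitive string with different numbers of merges and records the resulting sequence length. The sample text, the `train` helper, and the merge counts are all assumptions made for this illustration.

```python
from collections import Counter

def get_stats(ids):
    return Counter(zip(ids, ids[1:]))

def merge(ids, pair, new_idx):
    new_ids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            new_ids.append(new_idx)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids

def train(ids, num_merges):
    """Apply num_merges greedy merges and return the shortened id list."""
    ids = list(ids)
    for step in range(num_merges):
        stats = get_stats(ids)
        if not stats:  # fewer than two tokens left, nothing to merge
            break
        pair = max(stats, key=stats.get)
        ids = merge(ids, pair, 256 + step)
    return ids

sample = list(("the cat sat on the mat. " * 20).encode("utf-8"))

# Map vocabulary size -> sequence length after that many merges.
lengths = {256 + n: len(train(sample, n)) for n in (0, 4, 8, 16)}
for vocab_size, seq_len in lengths.items():
    print(vocab_size, seq_len)
```

Each extra merge grows the vocabulary by one entry and shrinks the sequence, so the right stopping point depends on which of the two costs matters more.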

GPT-4 uses around 100k tokens in the vocabulary.

In [8]:
desired_vocab_size = 276
num_of_merges = desired_vocab_size - 256 # because we already have 256 tokens in our vocab
ids = list(tokens) # make a copy of the original tokens list

merges = { } #  (int, int) -> int
for i in range(num_of_merges):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)
    new_idx = 256 + i

    print(f"Merging {pair} into a new token {new_idx}")
    ids = merge(ids, pair, new_idx=new_idx)
    merges[pair] = new_idx

Merging (101, 32) into a new token 256
Merging (240, 159) into a new token 257
Merging (226, 128) into a new token 258
Merging (105, 110) into a new token 259
Merging (115, 32) into a new token 260
Merging (97, 110) into a new token 261
Merging (116, 104) into a new token 262
Merging (257, 133) into a new token 263
Merging (257, 135) into a new token 264
Merging (97, 114) into a new token 265
Merging (239, 189) into a new token 266
Merging (258, 140) into a new token 267
Merging (267, 264) into a new token 268
Merging (101, 114) into a new token 269
Merging (111, 114) into a new token 270
Merging (116, 32) into a new token 271
Merging (259, 103) into a new token 272
Merging (115, 116) into a new token 273
Merging (261, 100) into a new token 274
Merging (32, 262) into a new token 275


Notice how merged tokens can be merged even further.

If you think about it, you are creating a binary forest of merges. Each merge combines two tokens, so you get two children and a new parent token, and you build this from the leaves up. Unlike Huffman coding, the tokens do not all end up in a single tree; you end up with several binary trees.
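One of those trees can be walked back down to its leaves. The sketch below uses a subset of the merges printed in the log above; `expand` is a hypothetical helper written for this note.

```python
# Subset of the merges from the training log: pair -> new token id.
merges = {
    (240, 159): 257,
    (226, 128): 258,
    (257, 135): 264,
    (258, 140): 267,
    (267, 264): 268,
}

# Invert the table so each merged token points at its two children.
children = {idx: pair for pair, idx in merges.items()}

def expand(token):
    """Recursively expand a token into the raw bytes (leaves) beneath it."""
    if token not in children:  # ids below 256 are leaves
        return [token]
    left, right = children[token]
    return expand(left) + expand(right)

leaves = expand(268)
print(leaves)  # [226, 128, 140, 240, 159, 135]
```

Token 268 is the root of a tree with six byte leaves, built out of four smaller merges.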



**Compression Ratio:**

Initially, the vocabulary contains only the 256 possible byte values. Tokenizing this way gives a sequence whose length equals the string length (or slightly more, because some characters encode to multiple bytes). After merging, tokenizing the same text gives a sequence that is shorter than the original byte sequence.

This reduction is measured by compression ratio:
$$\dfrac{\text{len}(\text{tokens})}{\text{len}(\text{new tokens})}$$


The more merges you perform, the greater this compression ratio becomes.
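As a worked example, plugging in the lengths measured earlier in this notebook (616 tokens before any merge, 596 after the first merge) gives:

```python
original_len = 616  # raw UTF-8 byte tokens
merged_len = 596    # after merging (101, 32) into token 256

ratio = original_len / merged_len
print(f"compression ratio: {ratio:.3f}")
```

A single merge yields a ratio only slightly above 1; it grows as more merges are applied.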

The tokenizer is a **completely** separate stage from the large language model. Typically it is a preprocessing stage with its own training data and its own training step. Once it is trained on some corpus, you can use it to encode text into tokens for the LLM and to decode the LLM's tokens back into text.

Typically, you run the tokenizer on all the raw text data that you have gathered to train your LLM. Once you have the tokens, you can store them on disk, discard the raw text, and work with tokens from then on.

Other considerations when training the tokenizer include which languages, encodings, etc. you want your language model to support.

## Encoding and Decoding

Now that we have a training algorithm that gives us the merges, we want to encode text into tokens and decode tokens back into text.

### Decoding

In the decoding stage, we accept a sequence of integers in the range $ [0, \text{vocabsize} - 1] $ and produce the corresponding text.

In [14]:
# Construct an intermediate variable for integer to bytes mapping
vocab = {idx: bytes([idx]) for idx in range(256)}

# Add merged pairs to the integer-bytes mapping
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1] # Concatenate bytes of the pairs

def decode(ids):
    # Get the bytes representation of the idx
    tokens = b"".join(vocab[idx] for idx in ids)

    # Decode the bytes into a string
    text = tokens.decode('utf-8', 'replace')
    return text

decode([97, 128, 98])

'a�b'

If your LLM predicts a bad sequence of tokens, the resulting bytes might not be valid UTF-8. That's why we pass `replace`: without it, decoding would raise an error instead of returning text.
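A typical way this happens is a model emitting only part of a multi-byte character. The sketch below simulates that by cutting the UTF-8 bytes of an emoji in half; the variable names are my own.

```python
wink_bytes = "😉".encode("utf-8")  # four bytes: [240, 159, 152, 137]
truncated = wink_bytes[:2]         # pretend the model stopped halfway

# Strict decoding refuses the incomplete sequence.
try:
    truncated.decode("utf-8")
    strict_failed = False
except UnicodeDecodeError:
    strict_failed = True

# With errors="replace" we still get a string back, with U+FFFD in place of the bad bytes.
lenient = truncated.decode("utf-8", errors="replace")
print(strict_failed, repr(lenient))
```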

### Encoding

Given a text, we want to convert it into a list of integers. But remember that some byte pairs are now merged, so we need to apply the merges to the text in the same order in which they were learned when we trained the tokenizer.

That means that at each step we pick, among the pairs present in the text, the one with the lowest index in the `merges` dictionary (because `merges` maps `pair -> idx`, and earlier merges received lower indices).

In [16]:
def encode(text:str):
    tokens = list(text.encode('utf-8'))

    while len(tokens) >= 2: # number of tokens needs to be at least 2 if it needs to be considered as a pair
        stats = get_stats(tokens)
        pair = min(stats, key=lambda p:merges.get(p, float('inf')))
        
        if pair not in merges:
            break # nothing to merge
        
        # If something to merge, replace old tokens
        new_idx = merges[pair]
        tokens = merge(tokens, pair, new_idx)

    return tokens
encode("hello world"), decode(encode("hello world"))

([104, 101, 108, 108, 111, 32, 119, 270, 108, 100], 'hello world')
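As a sanity check, encoding followed by decoding should return the original text. The sketch below is self-contained: it copies the merge table from the training log above and restates `merge`, `encode`, and `decode` compactly. The extra test strings are my own.

```python
# Merge table copied from the training log: pair -> new token id.
merges = {
    (101, 32): 256, (240, 159): 257, (226, 128): 258, (105, 110): 259,
    (115, 32): 260, (97, 110): 261, (116, 104): 262, (257, 133): 263,
    (257, 135): 264, (97, 114): 265, (239, 189): 266, (258, 140): 267,
    (267, 264): 268, (101, 114): 269, (111, 114): 270, (116, 32): 271,
    (259, 103): 272, (115, 116): 273, (261, 100): 274, (32, 262): 275,
}

# Token id -> bytes; merged tokens concatenate their children's bytes.
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]

def merge(ids, pair, new_idx):
    new_ids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            new_ids.append(new_idx)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids

def encode(text):
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        # Pick the pair that was learned earliest (lowest merge index).
        pair = min(set(zip(ids, ids[1:])), key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # nothing left to merge
        ids = merge(ids, pair, merges[pair])
    return ids

def decode(ids):
    return b"".join(vocab[i] for i in ids).decode("utf-8", errors="replace")

hello_ids = encode("hello world")
print(hello_ids)

round_trips = {text: decode(encode(text)) for text in ["hello world", "the end and", "naïve 😄"]}
```

The round trip holds for any valid string. The reverse direction is not guaranteed: an arbitrary list of ids may decode to text that encodes back to a different list.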