# Tokenization
- Converting strings to integers so that language models can read them

In [2]:
# Read the corpus with an explicit encoding (it contains non-ASCII characters)
# and let the context manager close the file handle deterministically.
with open("../data/taylorswift.txt", "r", encoding="utf-8") as text_file:
    train_text = text_file.read()
len(train_text)

185561

In [3]:
all_unique_chars = ''.join(sorted(list(set(train_text))))
all_unique_chars

'\t\n !"#$&\'()+,-./0123456789:;?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz£®áéíñö–—•™'

In [4]:
stoi = {s : i for i, s in enumerate(all_unique_chars)}
itos = {i : s for i, s in enumerate(all_unique_chars)}

In [5]:
def encoder(s):
    """Map every character of `s` to its integer id using the `stoi` table."""
    # `ch` instead of `chr` so the builtin chr() is not shadowed
    return [stoi[ch] for ch in s]


def decoder(d):
    """Map a sequence of integer ids back to a string using the `itos` table."""
    return "".join(itos[i] for i in d)

In [6]:
decoder(encoder("Hi there"))

'Hi there'

# Building BPE tokenizer from scratch

In [7]:
from typing import List, Union
import json

# Unicode code points
- Unicode is a worldwide standard that maps every character to an integer (its code point)
- ASCII is a subset of Unicode code points
- Only one integer for one character

In [11]:
[ord(x) for x in train_text[:10]] #--> uni code codepoints for each character

[67, 111, 112, 121, 32, 112, 97, 115, 116, 101]

In [12]:
ord('😄')

128516

In [13]:
chr(128516)

'😄'

## Drawbacks
- The standard is still evolving: new characters keep being added with each release
- Its code space has room for more than 1.1M code points
- A vocabulary that large would make the embedding table too huge to train
- Hence not recommended to use directly for tokenization
- Ref: https://en.wikipedia.org/wiki/Unicode

# UTF Encodings

- There are multiple UTF encodings, such as utf-8, utf-16 and utf-32.
- We will discuss utf-8, which is widely used in tokenizers for LLMs
- Converts each code point into a sequence of raw bytes
- A single code point can take up to 4 bytes
- 1 byte = 8 bits (a number ranging from 0-255)
- Common ASCII characters like a-zA-Z, 0-9 and some symbols take 1 byte
- Characters like 'é' take 2 bytes (basic Arabic letters also take 2 bytes)
- Characters of some scripts take 3 bytes, like Hindi (Devanagari), Chinese etc
- Most emojis take 4 bytes to store in the machine
- Ref: https://en.wikipedia.org/wiki/UTF-8

In [14]:
list('a'.encode('utf-8'))

[97]

In [15]:
list('é'.encode('utf-8'))

[195, 169]

In [16]:
list(('न').encode('utf-8'))

[224, 164, 168]

In [17]:
list(('😄').encode('utf-8'))

[240, 159, 152, 132]

In [18]:
# Encoding the whole word
print(list(('नमस्ते, आप कैसे हैं').encode('utf-8')))

[224, 164, 168, 224, 164, 174, 224, 164, 184, 224, 165, 141, 224, 164, 164, 224, 165, 135, 44, 32, 224, 164, 134, 224, 164, 170, 32, 224, 164, 149, 224, 165, 136, 224, 164, 184, 224, 165, 135, 32, 224, 164, 185, 224, 165, 136, 224, 164, 130]


In [19]:
# Decoding
b = list(('नमस्ते, आप कैसे हैं').encode('utf-8'))
bytes(b).decode('utf-8')

'नमस्ते, आप कैसे हैं'

## Drawbacks
- Individual bytes do not carry any semantic meaning on their own
- The vocabulary has only 256 tokens (0 - 255), so token sequences become very long
- Not every sequence of bytes is a valid UTF-8 encoding

# Byte Pair Encodings
- https://en.wikipedia.org/wiki/Byte-pair_encoding

In [21]:
ids = train_text.encode('utf-8')
print("Length of string chars:", len(train_text))
print("Lenght of utf-8 bytes:", len(ids))

Length of string chars: 185561
Lenght of utf-8 bytes: 185768


In [22]:
def get_stats(ids: List) -> dict:
    """
    Count how often each pair of neighbouring elements occurs in `ids`.

    Returns a dict mapping (left, right) -> number of occurrences, with keys
    in order of first appearance. Inputs shorter than 2 give an empty dict.
    """
    pair_counts = {}
    for pos in range(len(ids) - 1):
        key = (ids[pos], ids[pos + 1])
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1
    return pair_counts

In [23]:
stats_dict = get_stats(ids)

In [24]:
max(stats_dict, key=stats_dict.get), min(stats_dict, key=stats_dict.get)

((101, 32), (10, 45))

In [25]:
stats_dict[(101, 32)]

2981

In [26]:
def merge(ids, pair, new_idx):
    """Return `ids` with each left-to-right, non-overlapping occurrence of
    `pair` replaced by the single token `new_idx`.

    A length-1 input is handed back unchanged (the same object); any other
    input produces a freshly built list.
    """
    if len(ids) == 1:
        return ids
    merged = []
    pos = 0
    last = len(ids) - 1
    while pos <= last:
        # No room for a pair at the last position, or the pair does not match here:
        # copy the current token through and advance by one.
        if pos == last or ids[pos] != pair[0] or ids[pos + 1] != pair[1]:
            merged.append(ids[pos])
            pos += 1
            continue
        # Matched: emit the merged token and skip both members of the pair.
        merged.append(new_idx)
        pos += 2
    return merged

In [27]:
ids = [1, 2, 1, 2, 3, 1, 2]
merge(ids, (1, 2), 11)

[11, 11, 3, 11]

In [28]:
n_vocab = 1000
new_merges = n_vocab - 256  # ids 0..255 are taken by the raw byte values
train_raw_bytes = list(train_text.encode('utf-8'))
print(f"Length of raw bytes: {len(train_raw_bytes)}")
merge_dict = {}  # new token id -> the pair of token ids it replaces, in the order learned
for i in range(new_merges):
    stats_dict = get_stats(train_raw_bytes)
    if not stats_dict:
        # Fewer than two tokens are left, so there is nothing to merge;
        # max() on an empty dict would raise ValueError.
        break
    top_pair = max(stats_dict, key=stats_dict.get)
    new_token = 256 + i
    train_raw_bytes = merge(train_raw_bytes, top_pair, new_token)
    merge_dict[new_token] = top_pair
print(f"Length of merged bytes: {len(train_raw_bytes)}")

Length of raw bytes: 185768
Length of merged bytes: 58300


In [29]:
# Encoding
s = "Hi there, how are you? नमस्ते, आप कैसे हैं"
def encode(s):
    """Turn the string `s` into a list of BPE token ids.

    The text is first converted to its UTF-8 byte values, then every learned
    merge in the global `merge_dict` is replayed over it. The lengths before
    and after merging are printed.
    """
    token_ids = list(s.encode('utf-8'))
    print(f"Length of raw bytes: {len(token_ids)}")
    # dicts iterate in insertion order, so merges are replayed in the order they were added
    for new_token, pair in merge_dict.items():
        token_ids = merge(token_ids, pair, new_token)
    print(f"Length of final merged bytes: {len(token_ids)}")
    return token_ids

In [30]:
s_raw_bytes = encode(s)

Length of raw bytes: 72
Length of final merged bytes: 61


In [31]:
# Decode
# Start from the 256 single-byte tokens, then add each merged token as the
# concatenation of the byte strings of the two tokens it replaces.
vocab = {byte_value: bytes([byte_value]) for byte_value in range(256)}
for new_token, (left, right) in merge_dict.items():
    vocab[new_token] = vocab[left] + vocab[right]
def decode(ids):
    """Turn a sequence of token ids back into text.

    Each id is expanded to its bytes through `vocab`; the joined bytes are
    decoded as UTF-8, with undecodable bytes replaced rather than raising.
    """
    raw = b"".join(vocab[token] for token in ids)
    return raw.decode('utf-8', errors="replace")

In [32]:
decode(s_raw_bytes)

'Hi there, how are you? नमस्ते, आप कैसे हैं'

In [33]:
decode(encode("a"))

Length of raw bytes: 1
Length of final merged bytes: 1


'a'

In [34]:
# need to save merge_dict and vocab objects
with open("merge_dict.json", 'w') as file:
    json.dump(merge_dict, file)

# Final code for BPE tokenizer

In [27]:
class Tokenizer:
    """Minimal byte-level Byte Pair Encoding (BPE) tokenizer.

    Token ids 0-255 stand for the raw byte values. Every id from 256 upwards
    stands for a learned merge of two earlier ids. Call `train` or
    `from_pretrained` before using `encode` / `decode`.
    """

    def __init__(self):
        # new token id -> pair of token ids it replaces, in the order the merges were learned
        self.merge_dict = None
        # token id -> the raw bytes that token expands to
        self.vocab = None

    def _get_stats(self, ids: List, stats_dict = None) -> dict:
        """
        Given a list of integers, return a dictionary to give the count of the pairs coming consecutively

        Parameters:
        ids: sequence of token ids
        stats_dict: optional existing dict of counts; when given it is updated in place and returned
        """
        stats_dict = {} if stats_dict is None else stats_dict
        for pair in zip(ids, ids[1:]):
            stats_dict[pair] = stats_dict.get(pair, 0) + 1
        return stats_dict

    def _merge(self, ids, pair, new_idx):
        """Replace the pair at all the places with the new index

        Occurrences are matched left to right without overlap. A length-1
        input is returned unchanged; otherwise a new list is returned.
        """
        if len(ids) == 1:
            return ids
        new_ids = []
        i = 0
        while i < len(ids):
            if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
                new_ids.append(new_idx)
                i += 2
            else:
                new_ids.append(ids[i])
                i += 1
        return new_ids

    def _build_vocab(self):
        """Function builds the vocab dict mapping each token to its raw bytes"""
        self.vocab = {byte_value : bytes([byte_value]) for byte_value in range(256)}
        # Relies on merge_dict order: both members of a pair are already in
        # vocab by the time the token that merges them is reached.
        for idx, pair in self.merge_dict.items():
            self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
        print(f"Vocabulary has been built internally of length {len(self.vocab)}, ready to encode and decode")

    def train(self, train_text: str, n_vocab : int, merge_dict_name = "merge_dict") -> None:
        """Function will take a train_text on which the BPE tokenizer will get trained
        Parameters:
        train_text: single python string
        n_vocab: size of the vocabulary to be built; values of 256 or less produce no merges
        merge_dict_name: path (without the ".json" suffix) where the learned merges are saved

        Training stops early when no pair is left to merge, so the final
        vocabulary can be smaller than n_vocab for short texts.
        Sets self.merge_dict and self.vocab; returns nothing.
        """
        new_merges = n_vocab - 256
        train_raw_bytes = list(train_text.encode('utf-8'))
        print(f"Length of raw bytes: {len(train_raw_bytes)}")
        merge_dict = {}
        for offset in range(new_merges):
            stats_dict = self._get_stats(train_raw_bytes)
            if not stats_dict:
                # Fewer than two tokens remain; max() on an empty dict would raise ValueError.
                print(f"No more pairs to merge, stopped after {len(merge_dict)} merges")
                break
            top_pair = max(stats_dict, key = stats_dict.get)
            new_token = 256 + offset
            train_raw_bytes = self._merge(train_raw_bytes, top_pair, new_token)
            merge_dict[new_token] = top_pair
        print(f"Length of merged bytes: {len(train_raw_bytes)}")
        # need to save merge_dict
        merge_dict_path = f"{merge_dict_name}.json"
        with open(merge_dict_path, 'w') as file:
            json.dump(merge_dict, file)
        print(f"Merge dict has been save on the path: {merge_dict_path}")
        self.merge_dict = merge_dict
        self._build_vocab()

    def from_pretrained(self, merge_dict_path):
        """Load merges previously saved by `train` from a JSON file and rebuild the vocab.

        Parameters:
        merge_dict_path: path to the saved JSON file
        """
        with open(merge_dict_path, "r") as file:
            self.merge_dict = json.load(file)
        # When loading a saved json, all the object's keys get converted into strings.
        # The pairs come back as 2-element lists, because JSON has no tuple type.
        self.merge_dict = {int(key) : val for key, val in self.merge_dict.items()}
        self._build_vocab()

    def encode(self, s : str) -> List[int]:
        """Function takes a single string and encodes it into a list of token ids

        Raises:
        ValueError: if the tokenizer has neither been trained nor loaded
        """
        if self.merge_dict is None:
            # Same failure style as decode(), instead of an opaque TypeError from iterating None.
            raise ValueError("Merges have not been learned or loaded, please call train or from_pretrained first and then call encode")
        s_raw_bytes = list(s.encode('utf-8'))
        print(f"Length of raw bytes: {len(s_raw_bytes)}")
        # Replay the merges in the order they were learned (dicts keep insertion order).
        for idx, pair in self.merge_dict.items():
            s_raw_bytes = self._merge(s_raw_bytes, pair, idx)
        print(f"Length of final merged bytes: {len(s_raw_bytes)}")
        return s_raw_bytes

    def decode(self, ids : List[int]) -> Union[str, List[str]]:
        """Function will take a sequence of token ids and decode it back to a single string

        Bytes that do not form valid UTF-8 are replaced instead of raising.

        Raises:
        ValueError: if the vocab has not been built yet
        """
        if self.vocab is not None:
            bts = b"".join([self.vocab[i] for i in ids])
            bts = bts.decode('utf-8', errors = "replace")
            return bts
        else:
            raise ValueError("Vocab has not been built, please built it first and then call decode")

In [28]:
tokenizer = Tokenizer()

In [29]:
# Read with an explicit encoding and close the handle via the context manager.
with open("./minbpe/tests/taylorswift.txt", "r", encoding="utf-8") as text_file:
    taylor_swift_train_data = text_file.read()
print(len(taylor_swift_train_data))
print(taylor_swift_train_data[:100])

185561
Copy paste of the Wikipedia article on Taylor Swift, as of Feb 16, 2024.
---

Main menu

WikipediaTh


In [30]:
%%time
tokenizer.train(taylor_swift_train_data, 2000, "merge_dict_taylor_swift")

Length of raw bytes: 185768
Length of merged bytes: 45722
Merge dict has been save on the path: merge_dict_taylor_swift.json
Vocabulary has been built internally of length 2000, ready to encode and decode
CPU times: user 12.9 s, sys: 3.05 ms, total: 12.9 s
Wall time: 12.5 s


In [31]:
tokenizer.from_pretrained("./merge_dict_taylor_swift.json")

Vocabulary has been built internally of length 2000, ready to encode and decode


In [32]:
tokens = tokenizer.encode("Hi there, how are yopu. I havnt seen you for days")

Length of raw bytes: 49
Length of final merged bytes: 25


In [33]:
tokenizer.decode(tokens)

'Hi there, how are yopu. I havnt seen you for days'

# See the merges

In [34]:
for n in range(270, 500):
    # list(...) copies the pair whether it is a tuple (straight from train)
    # or a list (loaded from JSON); tuples have no .copy() method.
    b_final = list(tokenizer.merge_dict[n])
    # Keep expanding merged ids until only raw byte values (< 256) remain.
    while any(x >= 256 for x in b_final):
        b_final_new = []
        for b in b_final:
            if b >= 256:
                b_final_new.extend(tokenizer.merge_dict[b])
            else:
                b_final_new.append(b)
        b_final = b_final_new
    print(f"{[bytes([b]).decode('utf-8', errors='replace') for b in b_final]} --> {tokenizer.vocab[n]}")

['a', 'n'] --> b'an'
['a', 'r'] --> b'ar'
['e', 'r', ' '] --> b'er '
['y', ' '] --> b'y '
['a', 'l'] --> b'al'
['t', 'h', 'e', ' '] --> b'the '
['v', 'e', 'd', ' '] --> b'ved '
['w', 'i'] --> b'wi'
['e', 'r'] --> b'er'
['o', 'n', ' '] --> b'on '
['w', 'i', 'f'] --> b'wif'
['R', 'e'] --> b'Re'
['S', 'w', 'i', 'f'] --> b'Swif'
['o', 'r', ' '] --> b'or '
['c', 'h'] --> b'ch'
[',', ' ', '2', '0', '1'] --> b', 201'
['o', 'm'] --> b'om'
['b', 'e', 'r', ' '] --> b'ber '
[' ', 't', 'h', 'e', ' '] --> b' the '
['a', 'y'] --> b'ay'
['e', 'n'] --> b'en'
['o', 'r'] --> b'or'
['a', 'l', ' '] --> b'al '
['e', 'm'] --> b'em'
['.', '\n'] --> b'.\n'
['r', 'i', 'e'] --> b'rie'
['i', 'n', 'g'] --> b'ing'
[',', ' ', '2', '0', '2'] --> b', 202'
['t', 'i'] --> b'ti'
['a', 'y', 'l'] --> b'ayl'
['"', '.', ' '] --> b'". '
['l', 'l'] --> b'll'
['T', 'a', 'y', 'l'] --> b'Tayl'
['t', 'r', 'i', 'e'] --> b'trie'
['.', '\n', ' '] --> b'.\n '
['t', 'o'] --> b'to'
['.', ' ', 'R', 'e'] --> b'. Re'
['.', ' ', 'R', 'e', 