# Tokenisation

In this notebook. we are going to build a basic tokenizer, which will allow us to train a tokenizer and then test it, through encoding and decoding. We will be recreating a variant BPE algorithm, which has been widely used in OpenAI GPT models and other LLMs.

**Note:** This is an implementation obtained from https://github.com/karpathy/minbpe - the legend Andrej Karpathy!

In [1]:
# Import Libraries
from abc import abstractmethod, ABC

## Base Tokenizer

In [2]:
class BaseTokenizer(ABC):
    
    def __init__(self):
        # default: vocab size of 256 (all bytes), no merges, no patterns
        self.merges = {} # (int, int) -> int
        self.pattern = "" # str
        self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
        self.vocab = self._build_vocab() # int -> bytes

    def train(self, text, vocab_size, verbose=False):
        # Tokenizer can train a vocabulary of size vocab_size from text
        raise NotImplementedError

    def encode(self, text):
        # Tokenizer can encode a string into a list of integers
        raise NotImplementedError

    def decode(self, ids):
        # Tokenizer can decode a list of integers into a string
        raise NotImplementedError

    def _build_vocab(self):
        # vocab is simply and deterministically derived from merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

In [3]:
def get_stat(ids):
    # Given a list of units, find the 2-gram tuple pairs and number of occurrences
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

In [4]:
def merge(pair, ids, idx):
    # Given ids, we take the most occurring tuple pair and merge them
    temp = []
    # Loop through existing ids
    i = 0
    while i < len(ids):
        # If the ids match the maximum pair occurrence then replace it with new token index and move onto next pair
        # Make sure that last token is not cut off
        if i < len(ids) - 1 and pair == (ids[i], ids[i+1]):
            temp.append(idx)
            i += 2
        # Else, add the first idx of the token and shift
        else:
            temp.append(ids[i])
            i+=1
    return temp

## Basic Tokenizer

The tokenizer used in GPT is a modification of the original BPE algorithm. 
- The original BPE algorithm, developed for compressing data, merges the most frequent pair of bytes and replaces them with a new, unused byte, requiring a lookup table for decoding. 
- In contrast, the modified version used in large language models merges the most frequent pairs of characters (not bytes) into longer tokens without substituting them with unseen characters, thus eliminating the need for a lookup table.

In [5]:
class BasicTokenizer(BaseTokenizer):

    def __init__(self):
        super().__init__()
        
    def train(self, text, vocab_size):
        # Check vocab size is larger than 256 (as UTF-8 has this minimum conversion size)
        assert vocab_size >= 256
        # Number of additional merges 
        num_merges = vocab_size - 255
        # Change text to raw bytes
        text_bytes = text.encode("utf-8")
        # Create list of integers in range 0 to 255
        ids = list(text_bytes)
        # Ieratively merge the most common pairs to create new tokens
        merges = {} 
        # Store new tokens in vocab list
        vocab = {idx: bytes([idx]) for idx in range(256)} 
        # Loop through number of merges
        for i in range(num_merges):
            print(f'Iteration {i+1} - text as ids: \n{ids}')
            # Obtain the pair occurrences
            pairs = get_stat(ids)
            # Find the pair with the maximum number of occurrences
            try:
                pair = max(pairs, key=lambda x: pairs[x])
            except ValueError as v:
                print('Number of preferred merges exceeds number of available merges!')
                break
            # Increment this to be a new token
            idx = 256 + i
            # Update the ids through merge
            ids = merge(pair, ids, idx)
            # Save the merged pair
            merges[pair] = idx
            print(f'Pair merged: {pair}\n')
            # Update the vocab list
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
        # Save merges
        self.merges = merges
        # Save pairs
        self.vocab = vocab
        
    def encode(self, text):
        # Change text to raw bytes        
        text_bytes = text.encode("utf-8") 
        # Create list of integers in range 0 to 255
        ids = list(text_bytes) 
        # Make sure text is at least length 2, otherwise there is no merge and we just return the byte
        counter = 0
        while len(ids) >= 2:
            print(f'Iteration {counter+1} - text as ids: \n{ids}')
            counter += 1
            # Get the occurrences in the text
            pairs = get_stat(ids)
            # Find the pair to merge from the merges dictionary - start from the recently added tokens and scale up
            pair = min(pairs, key=lambda x: self.merges.get(x, float("inf")))
            # We must make sure we match the minimum pair to the keys in the merges otherwise we could get issues
            if pair not in self.merges:
                print(f'{pair} pair not in merge - merge process finished!')
                break
            # Else merge the lowest pair
            print(f'Pair merged: {pair}\n')
            idx = self.merges[pair]
            ids = merge(pair, ids, idx)
        return ids
        
    def decode(self, ids):
        # Given the ids, obtain the bytes
        text_bytes = b''.join(self.vocab[idx] for idx in ids)
        # Decode the bytes
        return text_bytes.decode('utf-8', errors='replace')

### Example

In [6]:
train_text = 'Hi there! What are you doing? Do you know what the weather is like today? If you do, where would you go?'

In [7]:
tk = BasicTokenizer()

In [8]:
tk.train(train_text, vocab_size=258)

Iteration 1 - text as ids: 
[72, 105, 32, 116, 104, 101, 114, 101, 33, 32, 87, 104, 97, 116, 32, 97, 114, 101, 32, 121, 111, 117, 32, 100, 111, 105, 110, 103, 63, 32, 68, 111, 32, 121, 111, 117, 32, 107, 110, 111, 119, 32, 119, 104, 97, 116, 32, 116, 104, 101, 32, 119, 101, 97, 116, 104, 101, 114, 32, 105, 115, 32, 108, 105, 107, 101, 32, 116, 111, 100, 97, 121, 63, 32, 73, 102, 32, 121, 111, 117, 32, 100, 111, 44, 32, 119, 104, 101, 114, 101, 32, 119, 111, 117, 108, 100, 32, 121, 111, 117, 32, 103, 111, 63]
Pair merged: (111, 117)

Iteration 2 - text as ids: 
[72, 105, 32, 116, 104, 101, 114, 101, 33, 32, 87, 104, 97, 116, 32, 97, 114, 101, 32, 121, 256, 32, 100, 111, 105, 110, 103, 63, 32, 68, 111, 32, 121, 256, 32, 107, 110, 111, 119, 32, 119, 104, 97, 116, 32, 116, 104, 101, 32, 119, 101, 97, 116, 104, 101, 114, 32, 105, 115, 32, 108, 105, 107, 101, 32, 116, 111, 100, 97, 121, 63, 32, 73, 102, 32, 121, 256, 32, 100, 111, 44, 32, 119, 104, 101, 114, 101, 32, 119, 256, 108, 100, 32, 

In [9]:
tk.merges

{(111, 117): 256, (104, 101): 257, (32, 121): 258}

In [10]:
tk.vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'

In [11]:
test_text = 'Hi there! You look amazing today. You should go out!'Neural Language Model

In [12]:
test_ids = tk.encode(test_text)

Iteration 1 - text as ids: 
[72, 105, 32, 116, 104, 101, 114, 101, 33, 32, 89, 111, 117, 32, 108, 111, 111, 107, 32, 97, 109, 97, 122, 105, 110, 103, 32, 116, 111, 100, 97, 121, 46, 32, 89, 111, 117, 32, 115, 104, 111, 117, 108, 100, 32, 103, 111, 32, 111, 117, 116, 33]
Pair merged: (111, 117)

Iteration 2 - text as ids: 
[72, 105, 32, 116, 104, 101, 114, 101, 33, 32, 89, 256, 32, 108, 111, 111, 107, 32, 97, 109, 97, 122, 105, 110, 103, 32, 116, 111, 100, 97, 121, 46, 32, 89, 256, 32, 115, 104, 256, 108, 100, 32, 103, 111, 32, 256, 116, 33]
Pair merged: (104, 101)

Iteration 3 - text as ids: 
[72, 105, 32, 116, 257, 114, 101, 33, 32, 89, 256, 32, 108, 111, 111, 107, 32, 97, 109, 97, 122, 105, 110, 103, 32, 116, 111, 100, 97, 121, 46, 32, 89, 256, 32, 115, 104, 256, 108, 100, 32, 103, 111, 32, 256, 116, 33]
(72, 105) pair not in merge - merge process finished!


In [13]:
test_decoded_text = tk.decode(test_ids)

In [14]:
assert test_decoded_text == test_text, 'Text through encoding and decoding via tokenisation is NOT the same as the original text!'

Hooray!