In [1]:
from typing import List, Dict, Tuple

'''
I learned how BPE works by reading Karpathy's lovely minbpe (https://github.com/karpathy/minbpe), so all credit to him!
This implementation takes a similar approach, but is more bare-bones if you just want to understand how the fundamentals 
work, no frills. 

How BPE works and is implemented, TLDR: 
    - We start with raw bytes (0-255) as our base tokens and iteratively merge the most frequent adjacent pairs
    - Key functions: merge() combines token pairs into new tokens, get_stats() tracks pair frequencies
    - BPE class inherits from Tokenizer and learns merges during training, applies them greedily during encoding
    - Simple but slow O(ML^2) encoding - we scan the sequence repeatedly looking for mergeable pairs
    - All merges get stored in self.merges dict mapping (token1,token2)->new_token_id for easy lookup

Key subtleties: 
    - Production BPE implementations would use a trie for faster lookup so it's O(ML) complexity
    - self.vocab and self.merges are sufficient statistics for encode and decode. train() exists only to 
    populate these two for use at inference time. 
    - self.vocab is directly used to decode in a trivial way using lookup 
    - self.merges lets us replay, on new unseen text, the same merging process we used during training to
    construct our vocabulary, so it specifies our "recipe" for tokenizing.
    - It's crucial we merge in the same order as in training, hence why self.merges is ordered; 
    if not, we could be trying to merge (a, b) when there are no instances of a or b yet in the 
    token list we're constructing.
'''


"\nI learned BPE works by reading Karpathy's (https://github.com/karpathy/minbpe) so all credit to him!\n\nHow BPE works and is implemented, TLDR: \n    - We \n"

In [2]:
def merge(ids: List[int], pair: Tuple[int, int], idx: int) -> List[int]: 
    """Replace every occurrence of `pair` in `ids` with the single token `idx`.

    This is the core BPE operation. Example: with ids = [1, 2, 3, 1, 2, 4],
    pair = (1, 2) and idx = 7 the result is [7, 3, 7, 4].

    The scan runs left to right and matches never overlap, so
    ids = [1, 1, 1] with pair = (1, 1) yields [idx, 1].

    Args:
        ids: sequence of token IDs; it is not modified.
        pair: the two adjacent token IDs to look for. It is compared against
            a tuple, so it must itself be a tuple to ever match.
        idx: token ID written in place of each matched pair.

    Returns:
        A new list with the replacements applied.
    """
    out = []
    last = len(ids) - 1
    pos = 0
    while pos <= last:
        # A pair needs a right-hand neighbour, so the final element can only be copied.
        if pos < last and (ids[pos], ids[pos + 1]) == pair:
            out.append(idx)
            pos += 2  # skip both members of the pair we just consumed
            continue
        out.append(ids[pos])
        pos += 1
    return out

def get_stats(ids: List[int]) -> Dict[Tuple[int, int], int]:
    """Count how often each pair of adjacent tokens occurs in `ids`.

    Training uses these counts to decide which pair to merge next.

    Args:
        ids: sequence of token IDs.

    Returns:
        Dict mapping (left, right) -> number of occurrences. Keys appear in
        order of first occurrence; sequences shorter than two tokens give {}.
    """
    counts = {}
    # Pair every token with its right-hand neighbour.
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        if key in counts:
            counts[key] += 1
        else:
            counts[key] = 1
    return counts

class Tokenizer: 
    """Base class holding the state every tokenizer here shares.

    Attributes:
        merges: dict mapping (token_id, token_id) -> new token_id. These are
            the rules learned during training and replayed at inference.
        pattern: a string, initialised empty; nothing in this class reads it.
        vocab: dict mapping token_id -> the bytes that token stands for.

    Subclasses are expected to override train(), encode() and decode();
    the versions here only raise NotImplementedError.
    """

    def __init__(self): 
        self.merges = {} # (int, int) -> int, these are our "rules" we learn during training and use during inference 
        self.pattern = "" # str  
        self.vocab = self._build_vocab() # takes merges -> vocab using the 256 single bytes as init

    def train(self):
        """Learn merges/vocab from data; must be provided by a subclass."""
        raise NotImplementedError

    def encode(self): 
        """Turn text into token IDs; must be provided by a subclass."""
        raise NotImplementedError

    def decode(self): 
        """Turn token IDs back into text; must be provided by a subclass."""
        raise NotImplementedError

    def _build_vocab(self) -> Dict[int, bytes]:
        """Derive the token_id -> bytes vocabulary from self.merges.

        Starts with the 256 single-byte tokens, then adds one entry per merge
        rule by concatenating the bytes of the two merged tokens.

        Returns:
            A new dict; self.vocab is not touched by this method.
        """
        vocab = {idx: bytes([idx]) for idx in range(256)}
        # Iterate over items(): iterating the dict directly yields only the key
        # tuples, and unpacking those as ((i1, i2), idx) raises a TypeError.
        # Dict insertion order matters here: a rule may refer to a token created
        # by an earlier rule, which must already be present in `vocab`.
        for (i1, i2), idx in self.merges.items():
            vocab[idx] = vocab[i1] + vocab[i2]

        return vocab 
        

class BPE(Tokenizer): 
    """Minimal byte-level Byte Pair Encoding tokenizer.

    train() learns an ordered set of merge rules from a text, encode() replays
    those rules on arbitrary text, and decode() maps token IDs back to text.
    """

    def __init__(self): 
        super().__init__()

    def train(self, text: str, vocab_size: int):
        """Learn merge rules from `text` and store them with the matching vocab.

        Repeatedly finds the most frequent adjacent token pair and replaces it
        with a new token ID, starting at 256 and counting upwards.

        Args:
            text: training text; it is encoded to UTF-8 bytes first.
            vocab_size: target vocabulary size, must be >= 256. At most
                vocab_size - 256 merges are learned; training stops earlier if
                the sequence has no adjacent pairs left.

        Side effects:
            Replaces self.merges and self.vocab with the newly learned state.
        """
        BASE_VOCAB_SIZE = 256
        assert vocab_size >= BASE_VOCAB_SIZE
        num_merges = vocab_size-BASE_VOCAB_SIZE

        # Build fresh state instead of mutating self.vocab in place, so that
        # calling train() a second time does not keep tokens from the first run.
        merges = {} # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(BASE_VOCAB_SIZE)} # idx -> bytes

        # convert text to sequence of byte IDs - this is our starting point
        ids = list(text.encode("utf-8")) # list of ints in 0..255, one per byte

        for i in range(num_merges): 
            # find the most popular pair using get_stats
            freqs = get_stats(ids)
            if not freqs:  # fewer than two tokens left, nothing to merge
                break
            # on ties, max() keeps the pair that get_stats saw first
            best_pair = max(freqs, key=freqs.get)
            new_idx = BASE_VOCAB_SIZE+i
            # replace every occurrence of the pair in the working sequence
            ids = merge(ids, best_pair, new_idx) 
            # record the rule and the bytes the new token stands for
            merges[best_pair] = new_idx 
            vocab[new_idx] = vocab[best_pair[0]] + vocab[best_pair[1]]

        self.merges, self.vocab = merges, vocab

    def encode(self, text: str) -> List[int]:
        """Tokenize `text` by replaying the learned merges in training order.

        On every pass the pair present in the sequence whose rule was learned
        earliest (smallest new token ID) is merged everywhere it occurs. Picking
        merges by position instead of by training order can yield a different
        tokenization than the one training produced for the same text.

        Args:
            text: text to tokenize; it is encoded to UTF-8 bytes first.

        Returns:
            List of token IDs. Empty text gives an empty list.
        """
        # Start with raw bytes
        ids = list(text.encode("utf-8"))

        unknown = float("inf")  # rank for pairs that have no merge rule
        # Each pass is linear in len(ids) and shortens the sequence, so the
        # loop runs at most len(ids) - 1 times.
        while len(ids) >= 2:
            stats = get_stats(ids)
            # train() hands out new IDs in increasing order, so the smallest
            # merge value identifies the earliest-learned rule.
            pair = min(stats, key=lambda p: self.merges.get(p, unknown))
            if pair not in self.merges:
                break  # no remaining pair has a rule, we are done
            ids = merge(ids, pair, self.merges[pair])

        return ids

    def decode(self, ids: List[int]) -> str: 
        """Map token IDs back to text.

        Looks up the bytes of every token in self.vocab, concatenates them and
        decodes the result as UTF-8.

        Raises:
            KeyError: if an ID is not in self.vocab.
            UnicodeDecodeError: if the concatenated bytes are not valid UTF-8.
        """
        byte_strings = [self.vocab[i] for i in ids]
        return b''.join(byte_strings).decode('utf-8')


In [3]:
# Round-trip check: train on one sentence, then verify decode(encode(x)) == x
# for that sentence and for a sentence the tokenizer was not trained on.
text = "The quick brown fox jumps over the lazy dog."
print(f"Original text: {text}")

# create the tokenizer and train it; vocab_size=300 permits up to
# 300 - 256 = 44 merges on top of the 256 single-byte tokens
tokenizer = BPE()
tokenizer.train(text, vocab_size=300)

# encode the training text and decode it again
encoded = tokenizer.encode(text)
decoded = tokenizer.decode(encoded)

print(f"\nEncoded tokens: {encoded}")
print(f"Number of tokens: {len(encoded)}")
print(f"Decoded text: {decoded}")
print(f"\nDecoding matches original: {text == decoded}")

# now tokenize text we didn't train on: it starts with "A" instead of "The"
# and has fox and dog swapped
new_text = "A quick brown dog jumps over the lazy fox." 
print(f"\nTesting on new text: {new_text}")
encoded_new = tokenizer.encode(new_text)
decoded_new = tokenizer.decode(encoded_new)

# report the result; the last line prints True when the round trip is lossless
print(f"Encoded tokens: {encoded_new}")
print(f"Number of tokens: {len(encoded_new)}")
print(f"Decoded text: {decoded_new}")
print(f"Decoding matches original: {new_text == decoded_new}")


Original text: The quick brown fox jumps over the lazy dog.

Encoded tokens: [296]
Number of tokens: 1
Decoded text: The quick brown fox jumps over the lazy dog.

Decoding matches original: True

Testing on new text: A quick brown dog jumps over the lazy fox.
Encoded tokens: [65, 32, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 100, 111, 103, 32, 106, 117, 109, 112, 115, 32, 111, 118, 101, 114, 32, 116, 257, 108, 97, 122, 121, 32, 102, 111, 120, 46]
Number of tokens: 40
Decoded text: A quick brown dog jumps over the lazy fox.
Decoding matches original: True
