# Problems for BPE

## Problem 1: Understanding Unicode
- (a) What Unicode character does chr(0) return?  
    Deliverable: A one-sentence response.  
    The null character, U+0000 (NUL), whose repr is '\x00'; it is not a space.
- (b) How does this character’s string representation (__repr__()) differ from its printed representation?  
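    Deliverable: A one-sentence response.  
    Its repr is the escaped form '\x00' (quotes included), whereas printing it emits the raw, non-printing NUL character, which most terminals render as nothing or as a blank.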
- (c) What happens when this character occurs in text? It may be helpful to play around with the following in your Python interpreter and see if it matches your expectations: 

    Deliverable: A one-sentence response.
    ```python
    >>> chr(0) 
    >>> print(chr(0)) 
    >>> "this is a test" + chr(0) + "string"
    >>> print("this is a test" + chr(0) + "string")
    ```

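    Inside a string the character is stored like any other and appears as \x00 in the repr; when the string is printed, it is emitted as an invisible control character (often rendered as a blank or as nothing), and the surrounding text is unchanged.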

In [1]:
chr(0)

'\x00'

In [5]:
print(chr(0))

 


In [4]:
chr(0).__repr__()

"'\\x00'"

In [6]:
"this is a test" + chr(0) + "string"

'this is a test\x00string'

In [7]:
print("this is a test" + chr(0) + "string")

this is a test string
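
A quick way to confirm that `chr(0)` is the null character rather than a space is to compare code points directly. The following is a minimal sketch; the variable names are illustrative and not part of the assignment.

```python
# chr(0) is U+0000 (NUL), a control character distinct from the space U+0020.
null_char = chr(0)
combined = "this is a test" + null_char + "string"

print(ord(null_char))    # code point of the character
print(null_char == " ")  # compares NUL against a real space
print(repr(null_char))   # escaped form shown by the interpreter
print(len(combined))     # NUL counts as one character inside the string
```

Printing `combined` may look as if it contains a space, but that is only how the terminal renders the control character.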


## Problem 2: Unicode Encodings
- (a) What are some reasons to prefer training our tokenizer on UTF-8 encoded bytes, rather than UTF-16 or UTF-32? It may be helpful to compare the output of these encodings for various input strings.
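    - Base vocabulary: a byte-level tokenizer starts from only 256 byte values, and with UTF-8 those bytes carry information, whereas in UTF-16 or UTF-32 many of the bytes for ASCII-heavy text are zero padding.
    - Efficiency: UTF-8 encodes each ASCII character in one byte, versus two in UTF-16 and four in UTF-32, so ASCII-heavy text is much more space- and compute-efficient.
    - Simplicity: UTF-8 has no byte-order variants, whereas UTF-16 and UTF-32 come in big- and little-endian forms, and UTF-16 additionally needs surrogate pairs for code points above U+FFFF.
    - Robustness: every byte value is in the base vocabulary, so a byte-level tokenizer trained on UTF-8 can represent noisy or unexpected input without out-of-vocabulary tokens.
    - Compatibility: UTF-8's ubiquity makes it easier to integrate with various data sources and systems.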
- (b) Consider the following (incorrect) function, which is intended to decode a UTF-8 byte string into a Unicode string. Why is this function incorrect? Provide an example of an input byte string that yields incorrect results.  
    ```python
    def decode_utf8_bytes_to_str_wrong(bytestring: bytes): 
        return "".join([bytes([b]).decode("utf-8") for b in bytestring])  
    >>> decode_utf8_bytes_to_str_wrong("hello".encode("utf-8")) 
    'hello'  
    ``` 
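    A single character can be encoded as multiple bytes, and decoding each byte on its own fails for such characters. For example, "蔡".encode("utf-8") is b'\xe8\x94\xa1', and decoding b'\xe8' alone raises a UnicodeDecodeError.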
- (c) Give a two byte sequence that does not decode to any Unicode character(s).
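    b'\xe8\x94'. The lead byte 0xE8 announces a three-byte sequence, so it must be followed by two continuation bytes (each of the form 10xxxxxx); with only one, the sequence is truncated and cannot be decoded.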
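
To make the size comparison from part (a) concrete, the sketch below encodes a few sample strings in each encoding. The sample strings are arbitrary choices, and the little-endian codec names are used only so that no byte-order mark is counted.

```python
# Compare encoded lengths in bytes for a few sample strings.
samples = ["hello", "héllo", "蔡", "hello 蔡"]
encodings = ("utf-8", "utf-16-le", "utf-32-le")

sizes = {}
for text in samples:
    sizes[text] = {enc: len(text.encode(enc)) for enc in encodings}
    print(repr(text), sizes[text])

# In UTF-32, an ASCII character is one meaningful byte plus three zero bytes.
zero_bytes_utf32 = "hello".encode("utf-32-le").count(0)
print(zero_bytes_utf32)
```

For ASCII-only text the UTF-16 and UTF-32 byte streams are dominated by zero bytes, which a byte-level BPE would have to spend merges on.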

In [11]:
def decode_utf8_bytes_to_str_wrong(bytestring: bytes): return "".join([bytes([b]).decode("utf-8") for b in bytestring])

print("蔡".encode("utf-8"))
b'\xe8\x94'

b'\xe8\x94\xa1'


b'\xe8\x94'
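
The failure modes in parts (b) and (c) can be reproduced directly. This is a small sketch: `decode_utf8_bytes_to_str` is a hypothetical corrected counterpart that decodes the whole byte string at once, and the flag variables exist only to record what happened.

```python
def decode_utf8_bytes_to_str_wrong(bytestring: bytes) -> str:
    # Decodes each byte in isolation, which breaks multi-byte characters.
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])

def decode_utf8_bytes_to_str(bytestring: bytes) -> str:
    # Decodes the byte string as a whole, so multi-byte sequences stay intact.
    return bytestring.decode("utf-8")

encoded = "蔡".encode("utf-8")
print(decode_utf8_bytes_to_str(encoded))

try:
    decode_utf8_bytes_to_str_wrong(encoded)
    wrong_failed = False
except UnicodeDecodeError as err:
    wrong_failed = True
    print("byte-by-byte decoding failed:", err)

try:
    b"\xe8\x94".decode("utf-8")  # truncated three-byte sequence
    truncated_failed = False
except UnicodeDecodeError:
    truncated_failed = True
```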

## Problem 3: BPE Tokenizer Training
Deliverable: Write a function that, given path to an input text file, trains a (byte-level) BPE tokenizer. Your BPE training function should handle (at least) the following input parameters:

- input_path: str Path to a text file with BPE tokenizer training data.  

- vocab_size: int A non-negative integer that defines the maximum final vocabulary size (including the initial byte vocabulary, vocabulary items produced from merging, and any special tokens).  

- special_tokens: list[str] A list of strings to add to the vocabulary. These special tokens do not otherwise affect BPE training.  

Your BPE training function should return the resulting vocabulary and merges:  

- vocab: dict[int, bytes] The tokenizer vocabulary, a mapping from int (token ID in the vocabulary) to bytes (token bytes).  

- merges: list[tuple[bytes, bytes]] A list of BPE merges produced from training. Each list item is a tuple of bytes (<token1>, <token2>), representing that <token1> was merged with <token2>. The merges should be ordered by order of creation.  

To test your BPE training function against our provided tests, you will first need to implement the test adapter at [adapters.run_train_bpe]. Then, run pytest tests/test_train_bpe.py. Your implementation should be able to pass all tests.

In [1]:
import regex as re 

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""


To simplify the BPE code, we build a doubly linked list over the bytes of each pre-token, so that merging two adjacent symbols only requires updating their immediate neighbors.

In [14]:
class bilist:
    def __init__(self, val, id, previous=None, next=None):
        self.val = val 
        self.id = id 
        self.prev = previous
        self.next = next 
    
    def merge(self):
        if self.next is None:
            raise ValueError("Cannot merge last element")
        else:
            self.val = self.val + self.next.val
            self.next = self.next.next
            if self.next:
                self.next.prev = self
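
To see what `merge` does to the links, here is a small self-contained sketch. `Node` is a hypothetical stand-in that mirrors the class above (without the `id` field), and `to_list` is an illustrative helper for inspecting the chain.

```python
class Node:
    """Hypothetical stand-in mirroring the doubly linked list node above."""

    def __init__(self, val, prev=None):
        self.val = val
        self.prev = prev
        self.next = None
        if prev is not None:
            prev.next = self  # link the previous node forward to this one

    def merge(self):
        # Absorb the next node's value and unlink it from the chain.
        if self.next is None:
            raise ValueError("Cannot merge last element")
        self.val = self.val + self.next.val
        self.next = self.next.next
        if self.next:
            self.next.prev = self

def to_list(head):
    # Walk the chain from head and collect the symbol values.
    out = []
    while head is not None:
        out.append(head.val)
        head = head.next
    return out

# Build l -> o -> w, then merge "o" with "w".
head = Node(b"l")
middle = Node(b"o", head)
tail = Node(b"w", middle)

middle.merge()
symbols = to_list(head)
print(symbols)
```

Only the merged node and its neighbors are touched, which is what lets the training loop update pair counts locally instead of rescanning every word.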

Now we can try to write the prototype for the BPE algorithm. We consider some trivial cases.

In [21]:
from collections import defaultdict

text = "low low low low low lower lower widest widest widest newest newest newest newest newest newest" # text to be tokenized
prim_tokens = text.split() # split text into tokens
print(prim_tokens) # print tokens

merges = [] # list of merges, tuple[bytes, bytes]
vocab = {} # int to  bytes dictionary [int: bytes]

id_to_token = {} # id to token dictionary
token_to_id = {} # token to id dictionary
frequency_tokens = {} # frequency of tokens table
frequency_pairs = {}  # frequency of pairs table
positions_pairs = defaultdict(list) # positions of pairs table, position is a bilist object


# create token to id, id to token dictionaries, frequency_tokens table, and positions_pairs table
for prim_token in prim_tokens:
    chars = [bytes([b]) for b in prim_token.encode("utf-8")] # get bytes of token
    if prim_token in token_to_id:
        frequency_tokens[token_to_id[prim_token]] += 1 # increment frequency of token
    else:
        id = len(token_to_id) # get new id
        token_to_id[prim_token] = id # add token to id dictionary
        id_to_token[id] = prim_token # add id to token dictionary
        frequency_tokens[id] = 1 # set frequency of token to 1

        bi_char = bilist(chars[0], id) # create first bilist object
        for char in chars[1:]:
            new_bi_char = bilist(char, id, bi_char) # create new bilist object
            bi_char.next = new_bi_char # link new bilist object to previous one
            pairs = (bi_char.val, new_bi_char.val) # create pair
            positions_pairs[pairs].append(bi_char) # add position to pair
            bi_char = new_bi_char # move to next bilist object

# update the frequency_pairs table
for pairs, positions in positions_pairs.items(): 
    for position in positions:
        frequency_pairs[pairs] = frequency_pairs.get(pairs, 0) + frequency_tokens[position.id]  # increment frequency of pair

# print(frequency_pairs)

# find the pair with the highest frequency and merge
for _ in range(6):
    
    max_pair = max(frequency_pairs, key=lambda x:(frequency_pairs[x], x)) # get pair with highest frequency
    # print(max_pair) # print pair
    merges.append(max_pair) # add pair to merges list
    positions = positions_pairs[max_pair] # get positions of pair
    
    for position in positions: 

        val1 = position.val # get value of position
        val2 = position.next.val # get value of next position
        temp_pos =  position.next # get next position
        position.merge() # merge position with next position

        if position.prev: # if there is a previous position
            prev_pair = (position.prev.val, val1)
            frequency_pairs[prev_pair] -= frequency_tokens[position.id] # decrement frequency of pair   
            positions_pairs[prev_pair].remove(position.prev) # remove previous position from pair
            new_prev_pair = (position.prev.val, position.val) # create new pair
            frequency_pairs[new_prev_pair] = frequency_pairs.get(new_prev_pair, 0) + frequency_tokens[position.id] # increment frequency of new pair
            positions_pairs[new_prev_pair].append(position.prev) # add previous position to new pair

        if position.next: # if there is a next position
            next_pair = (val2, position.next.val)
            frequency_pairs[next_pair] -= frequency_tokens[position.id] # decrement frequency of pair
            positions_pairs[next_pair].remove(temp_pos) # remove next position from pair
            new_next_pair = (position.val, position.next.val) # create new pair
            frequency_pairs[new_next_pair] = frequency_pairs.get(new_next_pair, 0) + frequency_tokens[position.id] # increment frequency of new pair
            positions_pairs[new_next_pair].append(position) # add position to new pair  

    del positions_pairs[max_pair] # delete pair from positions_pairs
    del frequency_pairs[max_pair] # delete pair from frequency_pairs

    # print(frequency_pairs)
print(merges)

['low', 'low', 'low', 'low', 'low', 'lower', 'lower', 'widest', 'widest', 'widest', 'newest', 'newest', 'newest', 'newest', 'newest', 'newest']
[(b's', b't'), (b'e', b'st'), (b'o', b'w'), (b'l', b'ow'), (b'w', b'est'), (b'n', b'e')]
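
As a cross-check for the prototype, a much simpler (and slower) version can recount all pairs after every merge. The sketch below is such a naive reference, written under the same tie-breaking rule (highest count, then lexicographically greatest pair); `naive_bpe_merges` is a hypothetical name and not the implementation above.

```python
from collections import Counter

def naive_bpe_merges(words, num_merges):
    # Represent each distinct word as a tuple of single-byte symbols with a count.
    word_counts = Counter(tuple(bytes([b]) for b in w.encode("utf-8")) for w in words)
    merges = []
    for _ in range(num_merges):
        # Recount every adjacent pair from scratch (simple, but slow on large corpora).
        pair_counts = Counter()
        for symbols, count in word_counts.items():
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += count
        if not pair_counts:
            break
        best = max(pair_counts, key=lambda p: (pair_counts[p], p))
        merges.append(best)
        # Rewrite every word with the chosen pair merged.
        new_counts = Counter()
        for symbols, count in word_counts.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_counts[tuple(merged)] += count
        word_counts = new_counts
    return merges

text = "low low low low low lower lower widest widest widest newest newest newest newest newest newest"
reference_merges = naive_bpe_merges(text.split(), 6)
print(reference_merges)
```

If the linked-list version and this reference disagree on a corpus, the incremental count updates are the first place to look.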


The prototype works on the toy corpus. What remains is to generalize it to arbitrary text files: pre-tokenize the text with the regex pattern and reserve vocabulary entries for the special tokens.
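
One way to make sure special tokens do not influence the merge statistics is to split the text on them before pre-tokenization. The following is a minimal sketch under that assumption; `split_on_special_tokens` is a hypothetical helper and is not part of the code in this notebook.

```python
import re

def split_on_special_tokens(text, special_tokens):
    # Without special tokens there is nothing to split on.
    if not special_tokens:
        return [text]
    # Escape the tokens (they may contain regex metacharacters such as "|")
    # and try longer tokens first so overlapping tokens match greedily.
    ordered = sorted(special_tokens, key=len, reverse=True)
    pattern = "|".join(re.escape(tok) for tok in ordered)
    # Drop empty chunks that appear when the text starts or ends with a special token.
    return [chunk for chunk in re.split(pattern, text) if chunk]

chunks = split_on_special_tokens(
    "Once upon a time.<|endoftext|>The end.", ["<|endoftext|>"]
)
print(chunks)
```

Each chunk can then be pre-tokenized independently, so no pair is ever counted across a document boundary.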

In [8]:
import regex as re
from collections import defaultdict
from tqdm import tqdm

class bilist:
    def __init__(self, val, id, previous=None, next=None):
        self.val = val 
        self.id = id 
        self.prev = previous
        self.next = next 
    
    def merge(self):
        if self.next is None:
            raise ValueError("Cannot merge last element")
        else:
            self.val = self.val + self.next.val
            self.next = self.next.next
            if self.next:
                self.next.prev = self


PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" # pattern for tokenization
def train_bpe(input_path:str, vocab_size:int, special_tokens:list[str]):
    """ 
    Train a BPE model on a text file.
    
    Args:

    input_path: str
        The path to the input text file.
    vocab_size: int
        The size of the vocabulary.
    special_tokens: list[str]
        A list of special tokens.
    
    Returns:
    vocab: dict[int, bytes]
        A dictionary mapping token ids to tokens.
    merges: list[tuple[bytes, bytes]]
        A list of merges.

    """
    print("Training BPE model")
    print("Tokenizing text")
    with open(input_path, "r", encoding="utf-8") as file:  # explicit encoding; the default is platform-dependent
        text = file.read()

    # Split text into lines for progress monitoring
    lines = text.splitlines()

    # Initialize an empty list to store tokens
    prim_tokens = []

    # Process each line and update the progress bar
    for line in tqdm(lines, desc="Processing text"):
        tokens = re.findall(PAT, line)
        prim_tokens.extend(tokens)

    # Initialize the vocab with 256 bytes and special tokens
    print("Initializing vocab")
    vocab = {i: bytes([i]) for i in range(256)}
    for i, token in enumerate(special_tokens):
        vocab[256+i] = token.encode("utf-8")
    
    merges = [] # list of merges, tuple[bytes, bytes]
    id_to_token = {} # id to token dictionary
    token_to_id = {} # token to id dictionary
    frequency_tokens = {} # frequency of tokens table
    frequency_pairs = {}  # frequency of pairs table
    positions_pairs = defaultdict(list) # positions of pairs table, position is a bilist object


    # create token to id, id to token dictionaries, frequency_tokens table, and positions_pairs table

    print("Reading tokens and creating bilist objects")
    for prim_token in tqdm(prim_tokens):
        chars = [bytes([b]) for b in prim_token.encode("utf-8")] # get bytes of token
        if prim_token in token_to_id:
            frequency_tokens[token_to_id[prim_token]] += 1 # increment frequency of token
        else:
            id = len(token_to_id) # get new id
            token_to_id[prim_token] = id # add token to id dictionary
            id_to_token[id] = prim_token # add id to token dictionary
            frequency_tokens[id] = 1 # set frequency of token to 1

            bi_char = bilist(chars[0], id) # create first bilist object
            for char in chars[1:]:
                new_bi_char = bilist(char, id, bi_char) # create new bilist object
                bi_char.next = new_bi_char # link new bilist object to previous one
                pairs = (bi_char.val, new_bi_char.val) # create pair
                positions_pairs[pairs].append(bi_char) # add position to pair
                bi_char = new_bi_char # move to next bilist object

    # update the frequency_pairs table
    print("Updating frequency pairs table")
    for pairs, positions in tqdm(positions_pairs.items()): 
        for position in positions:
            frequency_pairs[pairs] = frequency_pairs.get(pairs, 0) + frequency_tokens[position.id]  # increment frequency of pair

    # print(frequency_pairs)

    # find the pair with the highest frequency and merge
    # while len(vocab) < vocab_size:
    print("Finding merges")
    for _ in tqdm(range(vocab_size - len(vocab))):
            
        max_pair = max(frequency_pairs, key=lambda x:(frequency_pairs[x], x)) # get pair with highest frequency
        # print(max_pair) # print pair
        merges.append(max_pair) # add pair to merges list
        new_char = max_pair[0] + max_pair[1] # create new character
        vocab[len(vocab)] = new_char # add new character to vocab
        positions = positions_pairs[max_pair] # get positions of pair
        
        for position in positions: 

            val1 = position.val # get value of position
            val2 = position.next.val # get value of next position
            temp_pos =  position.next # get next position
            position.merge() # merge position with next position

            if position.prev: # if there is a previous position
                prev_pair = (position.prev.val, val1)
                frequency_pairs[prev_pair] -= frequency_tokens[position.id] # decrement frequency of pair   
                positions_pairs[prev_pair].remove(position.prev) # remove previous position from pair
                new_prev_pair = (position.prev.val, position.val) # create new pair
                frequency_pairs[new_prev_pair] = frequency_pairs.get(new_prev_pair, 0) + frequency_tokens[position.id] # increment frequency of new pair
                positions_pairs[new_prev_pair].append(position.prev) # add previous position to new pair

            if position.next: # if there is a next position
                next_pair = (val2, position.next.val)
                frequency_pairs[next_pair] -= frequency_tokens[position.id] # decrement frequency of pair
                positions_pairs[next_pair].remove(temp_pos) # remove next position from pair
                new_next_pair = (position.val, position.next.val) # create new pair
                frequency_pairs[new_next_pair] = frequency_pairs.get(new_next_pair, 0) + frequency_tokens[position.id] # increment frequency of new pair
                positions_pairs[new_next_pair].append(position) # add position to new pair  

        del positions_pairs[max_pair] # delete pair from positions_pairs
        del frequency_pairs[max_pair] # delete pair from frequency_pairs
    
    return vocab, merges

vocab, merges = train_bpe("./data/fixtures/corpus.en", 500, ["<|endoftext|>"])
print(vocab)
print(merges)

Training BPE model
Tokenizing text


Processing text: 100%|██████████| 1015/1015 [00:00<00:00, 101500.10it/s]


Initializing vocab
Reading tokens and creating bilist objects


100%|██████████| 26743/26743 [00:00<00:00, 122575.01it/s]


Updating frequency pairs table


100%|██████████| 1069/1069 [00:00<00:00, 213825.69it/s]


Finding merges


100%|██████████| 243/243 [00:00<00:00, 2585.12it/s]

{0: b'\x00', 1: b'\x01', 2: b'\x02', 3: b'\x03', 4: b'\x04', 5: b'\x05', 6: b'\x06', 7: b'\x07', 8: b'\x08', 9: b'\t', 10: b'\n', 11: b'\x0b', 12: b'\x0c', 13: b'\r', 14: b'\x0e', 15: b'\x0f', 16: b'\x10', 17: b'\x11', 18: b'\x12', 19: b'\x13', 20: b'\x14', 21: b'\x15', 22: b'\x16', 23: b'\x17', 24: b'\x18', 25: b'\x19', 26: b'\x1a', 27: b'\x1b', 28: b'\x1c', 29: b'\x1d', 30: b'\x1e', 31: b'\x1f', 32: b' ', 33: b'!', 34: b'"', 35: b'#', 36: b'$', 37: b'%', 38: b'&', 39: b"'", 40: b'(', 41: b')', 42: b'*', 43: b'+', 44: b',', 45: b'-', 46: b'.', 47: b'/', 48: b'0', 49: b'1', 50: b'2', 51: b'3', 52: b'4', 53: b'5', 54: b'6', 55: b'7', 56: b'8', 57: b'9', 58: b':', 59: b';', 60: b'<', 61: b'=', 62: b'>', 63: b'?', 64: b'@', 65: b'A', 66: b'B', 67: b'C', 68: b'D', 69: b'E', 70: b'F', 71: b'G', 72: b'H', 73: b'I', 74: b'J', 75: b'K', 76: b'L', 77: b'M', 78: b'N', 79: b'O', 80: b'P', 81: b'Q', 82: b'R', 83: b'S', 84: b'T', 85: b'U', 86: b'V', 87: b'W', 88: b'X', 89: b'Y', 90: b'Z', 91: b'[',




## Problem 4: BPE training on TinyStories 
- (a) Train a byte-level BPE tokenizer on the TinyStories dataset, using a maximum vocabulary size of 10,000. Make sure to add the TinyStories <|endoftext|> special token to the vocabulary. Serialize the resulting vocabulary and merges to disk for further inspection. How many hours and memory did training take? What is the longest token in the vocabulary? Does it make sense?  
Resource requirements: ≤ 30 minutes (no GPUs), ≤ 30GB RAM 

- (b) Profile your code. What part of the tokenizer training process takes the most time?  

In [9]:
import psutil 
import time 

process = psutil.Process()
mem_before = process.memory_info().rss/1024/1024
path_file = "./data/TinyStoriesV2-GPT4-train.txt"
vocab, merges = train_bpe(path_file, 10000, ["<|endoftext|>"])

mem_after = process.memory_info().rss/1024/1024
print(f"Memory before: {mem_before} MB")
print(f"Memory after: {mem_after} MB")

Training BPE model
Tokenizing text


Processing text: 100%|██████████| 15600063/15600063 [03:18<00:00, 78537.56it/s]


Initializing vocab
Reading tokens and creating bilist objects


100%|██████████| 529467880/529467880 [06:24<00:00, 1377433.38it/s]


Updating frequency pairs table


100%|██████████| 2098/2098 [00:00<00:00, 24976.23it/s]


Finding merges


100%|██████████| 9743/9743 [00:49<00:00, 198.12it/s]

Memory before: 8579.41015625 MB
Memory after: 1245.51953125 MB
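
For part (b), one option is the standard-library `cProfile` module. The sketch below profiles a toy workload; `toy_pair_count` is a hypothetical stand-in for a pair-counting step, not the training function itself, and the input is deliberately tiny.

```python
import cProfile
import io
import pstats

def toy_pair_count(words):
    # Count adjacent character pairs across all words.
    counts = {}
    for word in words:
        for pair in zip(word, word[1:]):
            counts[pair] = counts.get(pair, 0) + 1
    return counts

# Profile only the region of interest.
profiler = cProfile.Profile()
profiler.enable()
pair_counts = toy_pair_count(["newest", "widest", "lower"] * 10_000)
profiler.disable()

# Collect the report into a string, sorted by cumulative time.
stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream)
stats.sort_stats("cumulative").print_stats(5)
report = stream.getvalue()
print(report)
```

Wrapping the call to `train_bpe` the same way would attribute time to its individual steps; profiling adds overhead, so a smaller input file is advisable.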





In [10]:
# Find the longest token in the vocabulary

max_vocab = max(vocab.values(), key=len)
print(max_vocab)
print(len(max_vocab)) 

b' accomplishment'
15
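
Part (a) also asks for the vocabulary and merges to be serialized. A simple option is `pickle`, which preserves `bytes` keys and values without extra encoding. This is a sketch with a toy vocabulary; `save_tokenizer`, `load_tokenizer`, and the file name are hypothetical choices.

```python
import os
import pickle
import tempfile

def save_tokenizer(vocab, merges, path):
    # Store both objects in one file so they cannot get out of sync.
    with open(path, "wb") as f:
        pickle.dump({"vocab": vocab, "merges": merges}, f)

def load_tokenizer(path):
    with open(path, "rb") as f:
        data = pickle.load(f)
    return data["vocab"], data["merges"]

# Toy data standing in for the trained vocabulary and merges.
toy_vocab = {0: b"a", 1: b"b", 2: b"ab"}
toy_merges = [(b"a", b"b")]

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tokenizer.pkl")
    save_tokenizer(toy_vocab, toy_merges, path)
    loaded_vocab, loaded_merges = load_tokenizer(path)

longest = max(loaded_vocab.values(), key=len)
print(longest)
```

Pickle files should only be loaded from trusted sources; a text format would be easier to inspect by hand but needs an escaping scheme for arbitrary bytes.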


# Implementing the tokenizer

Deliverable: Implement a Tokenizer class that, given a vocabulary and a list of merges, encodes text into integer IDs and decodes integer IDs into text. Your tokenizer should also support user-provided special tokens (appending them to the vocabulary if they aren’t already there). We recommend the following interface:  

`def __init__(self, vocab, merges, special_tokens=None)`

Construct a tokenizer from a given vocabulary, list of merges, and (optionally) a list of special tokens. This function should accept the following parameters:  

`vocab: dict[int, bytes]`  
`merges: list[tuple[bytes, bytes]]`  
`special_tokens: list[str] | None = None`  

`def from_files(cls, vocab_filepath, merges_filepath, special_tokens=None)` 

Class method that constructs and returns a Tokenizer from a serialized vocabulary and list of merges (in the same format that your BPE training code output) and (optionally) a list of special tokens. This method should accept the following additional parameters:  

`vocab_filepath: str`  
`merges_filepath: str`  
`special_tokens: list[str] | None = None`  

`def encode(self, text: str) -> list[int]` 

Encode an input text into a sequence of token IDs.  

`def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]` 

Given an iterable of strings (e.g., a Python file handle), return a generator that lazily yields token IDs. This is required for memory-efficient tokenization of large files that we cannot directly load into memory.  

`def decode(self, ids: list[int]) -> str` 

Decode a sequence of token IDs into text.  

To test your Tokenizer against our provided tests, you will first need to implement the test adapter at [adapters.get_tokenizer]. Then, run pytest tests/test_tokenizer.py. Your implementation should be able to pass all tests.

In [21]:
class Tokenizer:
    def __init__(self, vocab: dict[int, bytes], merges: list[tuple[bytes, bytes]], special_tokens=None):
        # load vocab 
        self.vocab = {}
        self.vocab['int_to_byte'] = vocab 
        self.vocab['byte_to_int'] = {v:k for k,v in vocab.items()} 

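        # load merges: map each pair of token ids to the id of the merged token.
        # encode_chunk uses the merged id as the merge priority, which assumes that
        # merged tokens received increasing ids in merge order (as train_bpe assigns them).
        self.merges = {}
        for a, b in merges:
            id_pair = (self.vocab['byte_to_int'][a], self.vocab['byte_to_int'][b])
            self.merges[id_pair] = self.vocab['byte_to_int'][a+b]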
        
        # load special tokens
        self.special_tokens = {}
        if special_tokens:
            special_tokens = sorted(special_tokens, key=len, reverse=True)
            for token in special_tokens:
                token_bytes = token.encode("utf-8")
                if token_bytes not in self.vocab['byte_to_int']:
                    # compute the next free id once, so both maps agree on it
                    new_id = len(self.vocab['int_to_byte'])
                    self.vocab['int_to_byte'][new_id] = token_bytes
                    self.vocab['byte_to_int'][token_bytes] = new_id
                self.special_tokens[token] = self.vocab['byte_to_int'][token_bytes]

    
    @classmethod
    def from_files(cls, vocab_file_path, merges_file_path, special_tokens=None):
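        # NOTE: get_tokenizer_from_path is not defined in this section; it is assumed to load (vocab, merges) from the two files.
        vocab, merges = get_tokenizer_from_path(vocab_file_path, merges_file_path)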
        return cls(vocab, merges, special_tokens)

    def encode(self, text:str, progress_bar:bool=False)-> list[int]:
        """
        Encode a text into token ids.
        """
        if self.special_tokens:
            chunk_pattern = "(" + "|".join(re.escape(token) for token in self.special_tokens) + ")"
            split_chunks = re.split(chunk_pattern, text)
        else:
            split_chunks = [text]
        
        ids = [] 
        for chunk in tqdm(split_chunks, disable=not progress_bar, desc=f"Encoding {len(split_chunks)} chunks"):
            new_ids = self.encode_chunk(chunk)
            ids.extend(new_ids)
        return ids

    def encode_chunk(self, chunk:str)-> list[int]:
        """
        Encode a chunk of text into token ids.
        """
        if chunk in self.special_tokens:
            return [self.special_tokens[chunk]]
        else:
            tokens = re.findall(PAT, chunk)
            total_ids = []
            for token in tokens:
                token_bytes = token.encode("utf-8")
                token_ids = [self.vocab['byte_to_int'][bytes([byte])] for byte in token_bytes]
            
                while len(token_ids) > 1:
                    pairs = [(token_ids[i], token_ids[i+1]) for i in range(len(token_ids)-1)] # get all pairs of token_ids
                    high_priority_pair = min(pairs, key=lambda x: self.merges.get(x, float('inf'))) # get the pair with the highest merge priority

                    # We need to merge all instances of high_priority_pair in token_ids
                    if high_priority_pair in self.merges: # if the pair is in merges, we merge
                        new_token_id = self.merges[high_priority_pair]
                        new_token_ids = []
                        ind = 0
                        while ind < len(token_ids): 
                            if ind < len(token_ids) - 1 and (token_ids[ind], token_ids[ind+1]) == high_priority_pair:
                                new_token_ids.append(new_token_id)
                                ind += 2
                            else:
                                new_token_ids.append(token_ids[ind])
                                ind += 1
                        token_ids = new_token_ids
                    else: # if the pair is not in merges, we break
                        break
                total_ids.extend(token_ids)
            return total_ids # return the token ids
                


    def encode_iterable(self, texts):
        # lazily yield token ids, one input string at a time
        for text in texts:
            yield from self.encode(text)


    def decode(self, ids: list[int])-> str:
        text_bytes = b"".join([self.vocab['int_to_byte'][id] for id in ids])
        return text_bytes.decode("utf-8", errors="replace")
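
The `errors="replace"` argument in `decode` matters because a single token may hold only part of a multi-byte character, so an arbitrary id sequence need not be valid UTF-8. A minimal sketch of the difference, with illustrative variable names:

```python
# Take only the first two bytes of a three-byte character.
full = "蔡".encode("utf-8")
partial = full[:2]

# Strict decoding rejects the truncated sequence.
strict_failed = False
try:
    partial.decode("utf-8")
except UnicodeDecodeError:
    strict_failed = True

# Lenient decoding substitutes U+FFFD (the replacement character) instead of raising.
lenient = partial.decode("utf-8", errors="replace")
print(strict_failed, repr(lenient))
```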
    

Let's try a small example with the following vocabulary, merges, and input text:

```python
vocab = {0: b' ', 1: b'a', 2: b'c', 3: b'e', 4: b'h', 5: b't', 6: b'th', 7: b' c', 8: b' a', 9: b'the', 10: b' at'}
merges = [(b't', b'h'), (b' ', b'c'), (b' ', b'a'), (b'th', b'e'), (b' a', b't')]
text = 'the cat ate'
```


In [23]:
vocab = {0: b' ', 1: b'a', 2: b'c', 3: b'e', 4: b'h', 5: b't', 6: b'th', 7: b' c', 8: b' a', 9: b'the', 10: b' at'}
merges = [(b't', b'h'), (b' ', b'c'), (b' ', b'a'), (b'th', b'e'), (b' a', b't')]
text = 'the cat ate'
tokenizer = Tokenizer(vocab, merges)
# print(tokenizer.vocab['byte_to_int'])
ids = tokenizer.encode(text)
print(ids)
decoded_text = tokenizer.decode(ids)
print(decoded_text)

[9, 7, 1, 5, 10, 3]
the cat ate
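
The result can be traced by hand with a stripped-down version of the merge loop that ranks pairs by their position in the merge list. This is a sketch, not the class above: `apply_merges` is a hypothetical helper, and the pre-tokens are written out by hand on the assumption that the pattern splits the text into `the`, ` cat`, and ` ate`.

```python
def apply_merges(word: bytes, ranks: dict) -> list:
    # Start from single bytes and repeatedly apply the earliest-learned merge.
    symbols = [bytes([b]) for b in word]
    while len(symbols) > 1:
        pairs = list(zip(symbols, symbols[1:]))
        best = min(pairs, key=lambda p: ranks.get(p, float("inf")))
        if best not in ranks:
            break  # no applicable merge is left
        merged, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                merged.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols

vocab = {0: b' ', 1: b'a', 2: b'c', 3: b'e', 4: b'h', 5: b't',
         6: b'th', 7: b' c', 8: b' a', 9: b'the', 10: b' at'}
merges = [(b't', b'h'), (b' ', b'c'), (b' ', b'a'), (b'th', b'e'), (b' a', b't')]

ranks = {pair: rank for rank, pair in enumerate(merges)}  # lower rank = earlier merge
byte_to_id = {token: idx for idx, token in vocab.items()}

# Assumed pre-tokenization of 'the cat ate'.
pre_tokens = [b"the", b" cat", b" ate"]
ids = [byte_to_id[s] for word in pre_tokens for s in apply_merges(word, ranks)]
print(ids)
```

Ranking by position in the merge list avoids relying on how token ids happen to be numbered in the vocabulary.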
