# Problems for BPE

## Problem 1: Understanding Unicode
- (a) What Unicode character does chr(0) return?  
    Deliverable: A one-sentence response.  
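    `chr(0)` returns the null character U+0000 (NUL), whose `repr` is `'\x00'`; it is a control character, not a space.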
- (b) How does this character’s string representation (__repr__()) differ from its printed representation?  
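    Deliverable: A one-sentence response.  
    Its `__repr__()` is the quoted escape sequence `'\x00'`, while printing it writes the raw NUL character, which is invisible (the notebook shows it as a blank).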
- (c) What happens when this character occurs in text? It may be helpful to play around with the following in your Python interpreter and see if it matches your expectations: 

    Deliverable: A one-sentence response.
    ```python
    >>> chr(0) 
    >>> print(chr(0)) 
    >>> "this is a test" + chr(0) + "string"
    >>> print("this is a test" + chr(0) + "string")
    ```

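    Inside a string it is kept as an ordinary character: the `repr` of the string shows `\x00`, and printing the string emits the invisible NUL character, which the notebook renders as a blank between "test" and "string".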

In [1]:
chr(0)

'\x00'

In [5]:
print(chr(0))

 


In [4]:
chr(0).__repr__()

"'\\x00'"

In [6]:
"this is a test" + chr(0) + "string"

'this is a test\x00string'

In [7]:
print("this is a test" + chr(0) + "string")

this is a test string


## Problem 2: Unicode Encodings
- (a) What are some reasons to prefer training our tokenizer on UTF-8 encoded bytes, rather than UTF-16 or UTF-32? It may be helpful to compare the output of these encodings for various input strings.
    - Vocabulary size: a byte-level tokenizer starts from only 256 base tokens, and UTF-8 uses them without padding, whereas UTF-16 and UTF-32 fill ASCII-heavy text with `0x00` bytes that early merges have to absorb.
    - Efficiency: ASCII characters take 1 byte in UTF-8, versus 2 in UTF-16 and 4 in UTF-32, so ASCII-heavy text yields much shorter byte sequences.
    - Simplicity in tokenization: UTF-8 is a single byte stream with no endianness variants, and it has no surrogate pairs as UTF-16 does.
    - Robustness: any input string maps to bytes that are already in the vocabulary, so noisy or unexpected inputs never produce out-of-vocabulary tokens.
    - Compatibility: UTF-8’s ubiquity makes it easier to integrate with various data sources and systems.
- (b) Consider the following (incorrect) function, which is intended to decode a UTF-8 byte string into a Unicode string. Why is this function incorrect? Provide an example of an input byte string that yields incorrect results.  
    ```python
    def decode_utf8_bytes_to_str_wrong(bytestring: bytes): 
        return "".join([bytes([b]).decode("utf-8") for b in bytestring])  
    >>> decode_utf8_bytes_to_str_wrong("hello".encode("utf-8")) 
    'hello'  
    ``` 
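    A single character can be encoded as multiple bytes, and those bytes cannot be decoded one at a time. For example, `"蔡".encode("utf-8")` is `b'\xe8\x94\xa1'`, and decoding the lone byte `b'\xe8'` raises `UnicodeDecodeError`.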
- (c) Give a two byte sequence that does not decode to any Unicode character(s).
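    `b'\xe8\x94'`. The lead byte `0xE8` announces a three-byte sequence, so it must be followed by two continuation bytes; with only one, the sequence is incomplete and does not decode.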
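
To compare the encodings from part (a) concretely, here is a minimal sketch. The helper name `encoded_sizes` is made up for this note, and the little-endian codec names are used so that no byte-order mark is counted:

```python
def encoded_sizes(text: str) -> dict[str, int]:
    """Byte length of `text` under each encoding (LE variants skip the BOM)."""
    return {enc: len(text.encode(enc)) for enc in ("utf-8", "utf-16-le", "utf-32-le")}

for sample in ["hello", "蔡", "hello 蔡"]:
    print(repr(sample), encoded_sizes(sample))

# ASCII text in UTF-16 is half zero bytes
zero_padded = list("hi".encode("utf-16-le"))
print(zero_padded)  # [104, 0, 105, 0]
```

For ASCII input the UTF-16 and UTF-32 sequences are 2x and 4x longer than UTF-8, and the extra bytes are all zeros.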

In [11]:
def decode_utf8_bytes_to_str_wrong(bytestring: bytes): return "".join([bytes([b]).decode("utf-8") for b in bytestring])

print("蔡".encode("utf-8"))
b'\xe8\x94'

b'\xe8\x94\xa1'


b'\xe8\x94'
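
The answers to (b) and (c) can be checked directly. In this sketch, `decode_utf8_bytes_to_str_wrong` is the function from the problem statement, while `is_valid_utf8` is a hypothetical helper added only for the check:

```python
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    # function from the problem statement: decodes each byte in isolation
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])

def is_valid_utf8(data: bytes) -> bool:  # hypothetical helper
    try:
        data.decode("utf-8")
        return True
    except UnicodeDecodeError:
        return False

encoded = "蔡".encode("utf-8")
print(encoded)                         # b'\xe8\x94\xa1'
print(encoded.decode("utf-8"))         # 蔡

try:
    decode_utf8_bytes_to_str_wrong(encoded)
    per_byte_ok = True
except UnicodeDecodeError:
    per_byte_ok = False
print(per_byte_ok)                     # False: a lone b'\xe8' is not valid UTF-8

print(is_valid_utf8(b"\xe8\x94"))      # False: truncated three-byte sequence
print(is_valid_utf8(b"\xe8\x94\xa1"))  # True
```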

## Problem 3: BPE Tokenizer Training
Deliverable: Write a function that, given a path to an input text file, trains a (byte-level) BPE tokenizer. Your BPE training function should handle (at least) the following input parameters:

- input_path: str Path to a text file with BPE tokenizer training data.  

- vocab_size: int A non-negative integer that defines the maximum final vocabulary size (including the initial byte vocabulary, vocabulary items produced from merging, and any special tokens).  

- special_tokens: list[str] A list of strings to add to the vocabulary. These special tokens do not otherwise affect BPE training.  

Your BPE training function should return the resulting vocabulary and merges:  

- vocab: dict[int, bytes] The tokenizer vocabulary, a mapping from int (token ID in the vocabulary) to bytes (token bytes).  

- merges: list[tuple[bytes, bytes]] A list of BPE merges produced from training. Each list item is a tuple of bytes (<token1>, <token2>), representing that <token1> was merged with <token2>. The merges should be ordered by order of creation.  

To test your BPE training function against our provided tests, you will first need to implement the test adapter at [adapters.run_train_bpe]. Then, run pytest tests/test_train_bpe.py. Your implementation should be able to pass all tests.

In [2]:
import regex as re 

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""


To simplify the BPE code, we build a doubly linked list with one node per byte token, so that merging two adjacent tokens only touches their immediate neighbours.

In [14]:
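class bilist:
    """Node of a doubly linked list holding one token of a word."""
    def __init__(self, val, id, previous=None, next=None):
        self.val = val        # token bytes stored in this node
        self.id = id          # id of the word this node belongs to
        self.prev = previous  # left neighbour, or None at the start of the word
        self.next = next      # right neighbour, or None at the end of the word

    def merge(self):
        """Absorb the right neighbour into this node and relink the list."""
        if self.next is None:
            raise ValueError("Cannot merge last element")
        self.val = self.val + self.next.val  # concatenate the two tokens
        self.next = self.next.next           # skip over the absorbed node
        if self.next:
            self.next.prev = self            # repair the back pointer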
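
To see what `merge` does to the links, here is a small usage sketch. The class is copied in compact form so the snippet runs on its own; `build_chain` and `values` are hypothetical helpers that are not part of the notebook:

```python
class bilist:  # compact copy of the class above
    def __init__(self, val, id, previous=None, next=None):
        self.val, self.id, self.prev, self.next = val, id, previous, next

    def merge(self):
        if self.next is None:
            raise ValueError("Cannot merge last element")
        self.val = self.val + self.next.val
        self.next = self.next.next
        if self.next:
            self.next.prev = self

def build_chain(word: str, word_id: int) -> bilist:  # hypothetical helper
    """Link the UTF-8 bytes of `word` into a chain and return its head."""
    chars = [bytes([b]) for b in word.encode("utf-8")]
    head = node = bilist(chars[0], word_id)
    for char in chars[1:]:
        node.next = bilist(char, word_id, node)
        node = node.next
    return head

def values(head: bilist) -> list[bytes]:  # hypothetical helper
    """Collect the token bytes from `head` to the end of the chain."""
    out = []
    node = head
    while node is not None:
        out.append(node.val)
        node = node.next
    return out

head = build_chain("low", 0)
print(values(head))            # [b'l', b'o', b'w']
head.merge()                   # merge b'l' with b'o'
print(values(head))            # [b'lo', b'w']
print(head.next.prev is head)  # True: the back pointer was repaired
```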

Now we can write a prototype of the BPE algorithm. We start with a trivial case: whitespace-split words and a fixed number of merges.

In [21]:
from collections import defaultdict

text = "low low low low low lower lower widest widest widest newest newest newest newest newest newest" # text to be tokenized
prim_tokens = text.split() # split text into tokens
print(prim_tokens) # print tokens

merges = [] # list of merges, tuple[bytes, bytes]
vocab = {} # int to  bytes dictionary [int: bytes]

id_to_token = {} # id to token dictionary
token_to_id = {} # token to id dictionary
frequency_tokens = {} # frequency of tokens table
frequency_pairs = {}  # frequency of pairs table
positions_pairs = defaultdict(list) # positions of pairs table, position is a bilist object


# create token to id, id to token dictionaries, frequency_tokens table, and positions_pairs table
for prim_token in prim_tokens:
    chars = [bytes([b]) for b in prim_token.encode("utf-8")] # get bytes of token
    if prim_token in token_to_id:
        frequency_tokens[token_to_id[prim_token]] += 1 # increment frequency of token
    else:
        id = len(token_to_id) # get new id
        token_to_id[prim_token] = id # add token to id dictionary
        id_to_token[id] = prim_token # add id to token dictionary
        frequency_tokens[id] = 1 # set frequency of token to 1

        bi_char = bilist(chars[0], id) # create first bilist object
        for char in chars[1:]:
            new_bi_char = bilist(char, id, bi_char) # create new bilist object
            bi_char.next = new_bi_char # link new bilist object to previous one
            pairs = (bi_char.val, new_bi_char.val) # create pair
            positions_pairs[pairs].append(bi_char) # add position to pair
            bi_char = new_bi_char # move to next bilist object

# update the frequency_pairs table
for pairs, positions in positions_pairs.items(): 
    for position in positions:
        frequency_pairs[pairs] = frequency_pairs.get(pairs, 0) + frequency_tokens[position.id]  # increment frequency of pair

# print(frequency_pairs)

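# find the pair with the highest frequency and merge it
for _ in range(6):

    # highest frequency wins; ties go to the lexicographically greater pair
    max_pair = max(frequency_pairs, key=lambda x:(frequency_pairs[x], x))
    # print(max_pair) # print pair
    merges.append(max_pair) # add pair to merges list
    positions = list(positions_pairs[max_pair]) # iterate over a copy: the table is edited below

    for position in positions:
        # skip stale entries, e.g. a node already absorbed by an overlapping merge ("aaa")
        if position.next is None or (position.val, position.next.val) != max_pair:
            continue

        val1 = position.val # get value of position
        val2 = position.next.val # get value of next position
        temp_pos = position.next # node that the merge will absorb
        position.merge() # merge position with next position
        temp_pos.prev = temp_pos.next = None # detach the absorbed node so it is never merged again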

        if position.prev: # if there is a previous position
            prev_pair = (position.prev.val, val1)
            frequency_pairs[prev_pair] -= frequency_tokens[position.id] # decrement frequency of pair   
            positions_pairs[prev_pair].remove(position.prev) # remove previous position from pair
            new_prev_pair = (position.prev.val, position.val) # create new pair
            frequency_pairs[new_prev_pair] = frequency_pairs.get(new_prev_pair, 0) + frequency_tokens[position.id] # increment frequency of new pair
            positions_pairs[new_prev_pair].append(position.prev) # add previous position to new pair

        if position.next: # if there is a next position
            next_pair = (val2, position.next.val)
            frequency_pairs[next_pair] -= frequency_tokens[position.id] # decrement frequency of pair
            positions_pairs[next_pair].remove(temp_pos) # remove next position from pair
            new_next_pair = (position.val, position.next.val) # create new pair
            frequency_pairs[new_next_pair] = frequency_pairs.get(new_next_pair, 0) + frequency_tokens[position.id] # increment frequency of new pair
            positions_pairs[new_next_pair].append(position) # add position to new pair  

    del positions_pairs[max_pair] # delete pair from positions_pairs
    del frequency_pairs[max_pair] # delete pair from frequency_pairs

    # print(frequency_pairs)
print(merges)

['low', 'low', 'low', 'low', 'low', 'lower', 'lower', 'widest', 'widest', 'widest', 'newest', 'newest', 'newest', 'newest', 'newest', 'newest']
[(b's', b't'), (b'e', b'st'), (b'o', b'w'), (b'l', b'ow'), (b'w', b'est'), (b'n', b'e')]
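
As a cross-check on the merges printed above, here is a deliberately naive sketch that recounts every pair from scratch on each iteration. It is not the linked-list approach of the prototype, and the function name `naive_bpe_merges` is made up for this note; it only shares the same tie-breaking rule (highest count, then the lexicographically greater pair):

```python
from collections import Counter

def naive_bpe_merges(text: str, num_merges: int) -> list[tuple[bytes, bytes]]:
    # word -> count, each word stored as a tuple of single-byte tokens
    words = Counter(
        tuple(bytes([b]) for b in word.encode("utf-8")) for word in text.split()
    )
    merges = []
    for _ in range(num_merges):
        # recount all adjacent pairs, weighted by word frequency
        pair_counts = Counter()
        for word, count in words.items():
            for pair in zip(word, word[1:]):
                pair_counts[pair] += count
        if not pair_counts:
            break  # nothing left to merge
        best = max(pair_counts, key=lambda p: (pair_counts[p], p))
        merges.append(best)

        # rewrite every word with the chosen pair merged, left to right
        merged_words = Counter()
        for word, count in words.items():
            out, i = [], 0
            while i < len(word):
                if i + 1 < len(word) and (word[i], word[i + 1]) == best:
                    out.append(word[i] + word[i + 1])
                    i += 2
                else:
                    out.append(word[i])
                    i += 1
            merged_words[tuple(out)] += count
        words = merged_words
    return merges

text = "low low low low low lower lower widest widest widest newest newest newest newest newest newest"
print(naive_bpe_merges(text, 6))
# [(b's', b't'), (b'e', b'st'), (b'o', b'w'), (b'l', b'ow'), (b'w', b'est'), (b'n', b'e')]
```

The naive version rescans all words for every merge, which is fine for a toy corpus but is exactly the cost the position table in the prototype avoids.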


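So for now, the prototype works on this toy example. What remains is to generalize it to arbitrary text: read the file at `input_path`, pre-tokenize with `PAT` instead of `str.split`, and stop once the vocabulary reaches `vocab_size`. The general case also has to handle `special_tokens`, which are added to the vocabulary but must not otherwise affect training.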
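
One possible way to keep special tokens out of the merge statistics is to cut the text at every special token before pre-tokenizing, so that no pair is ever counted across or inside one. The sketch below uses only the standard `re` module; the function name `split_on_special_tokens` and the example token `<|endoftext|>` are assumptions for illustration, not part of the assignment code above:

```python
import re

def split_on_special_tokens(text: str, special_tokens: list[str]) -> list[str]:
    """Return the pieces of `text` between special tokens (tokens are dropped)."""
    if not special_tokens:
        return [text]
    # escape each token so characters like '|' are matched literally;
    # try longer tokens first in case one token is a prefix of another
    ordered = sorted(special_tokens, key=len, reverse=True)
    pattern = "|".join(re.escape(tok) for tok in ordered)
    return [chunk for chunk in re.split(pattern, text) if chunk]

chunks = split_on_special_tokens(
    "low lower<|endoftext|>widest newest", ["<|endoftext|>"]
)
print(chunks)  # ['low lower', 'widest newest']
```

Each chunk can then be pre-tokenized and counted independently, and the special tokens themselves are added to `vocab` as whole entries.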