#### Tokenization: Byte Pair Encoding

Based on Andrej Karpathy's YouTube tutorial.

In [52]:
from collections import defaultdict
import pprint as pp

`Tokenization into bytes`: We can use UTF-8 encoding to convert our string of Unicode characters into a sequence of bytes. Then we could define each byte as a separate `token`.

In [53]:
# sample string of Unicode characters
s = 'café'
# convert to bytes using UTF-8 encoding
b = s.encode('utf8')
print(f"Original string: {s}, UTF-8 encoding: {b}, size of encoding: {len(b)} bytes")
# show each character and its UTF-8 byte representation
for c in s:
    print(f"{c} -> {c.encode('utf8')} --> {list(c.encode('utf8'))}, num bytes: {len(c.encode('utf8'))}")

# convert each of the 5 bytes in the utf-8 encoding of the sample string to its corresponding integer value (0-255)
byte_values = list(b)
print(f"\n UTF-8 encoding of '{s}' converted to a list of integers: {byte_values}")


Original string: café, UTF-8 encoding: b'caf\xc3\xa9', size of encoding: 5 bytes
c -> b'c' --> [99], num bytes: 1
a -> b'a' --> [97], num bytes: 1
f -> b'f' --> [102], num bytes: 1
é -> b'\xc3\xa9' --> [195, 169], num bytes: 2

 UTF-8 encoding of 'café' converted to a list of integers: [99, 97, 102, 195, 169]


Note that UTF-8 is a variable-length encoding: a single character takes between 1 and 4 bytes. The first 3 characters `c`, `a` and `f` are each represented by a single byte, while the accented character `é` is represented by 2 bytes.
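To see the full 1 to 4 byte range, here is a small sketch that encodes one character from each length class. The sample characters are my own picks for illustration and do not come from the tutorial.

```python
# One character per UTF-8 length class (escapes used to avoid font/encoding issues)
samples = [
    "a",           # U+0061, ASCII
    "\u00e9",      # U+00E9, e with acute accent
    "\u20ac",      # U+20AC, euro sign
    "\U0001F604",  # U+1F604, smiling face emoji
]

# Map each character to the number of bytes in its UTF-8 encoding
byte_lengths = {ch: len(ch.encode("utf-8")) for ch in samples}

for ch, n in byte_lengths.items():
    print(f"{ch!r} (U+{ord(ch):04X}) -> {n} byte(s)")
```

The byte count grows with the code point: ASCII stays at 1 byte, while characters outside the Basic Multilingual Plane (such as emoji) need 4.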

In [54]:
# longer sample text (taken from https://www.reedbeta.com/blog/programmers-intro-to-unicode/)
text = "Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to “support Unicode” in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don’t blame programmers for still finding the whole thing mysterious even 30 years after Unicode’s inception."

# encode text into utf-8 byte sequence
tokens = text.encode('utf-8') # byte stream
# convert bytes to integers
tokens = list(tokens) # integer tokens

print(f"Original text: {text} \nlength of text: {len(text)} characters \nUTF-8 encoded bytes (each byte converted to an integer): {tokens} \nlength of encoding: {len(tokens)} bytes")


Original text: Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to “support Unicode” in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don’t blame programmers for still finding the whole thing mysterious even 30 years after Unicode’s inception. 
length of text: 532 characters 
UTF-8 encoded bytes (each byte converted to an integer): [239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131, 239, 189, 143, 239, 189, 132, 239, 189, 133, 33, 32, 240, 159, 133, 164, 240, 159, 133, 157, 240, 159, 133, 152, 240, 159, 133, 146, 240, 159, 133, 158, 240, 159, 133, 147, 240, 159, 133, 148, 226, 128, 189, 32, 240, 159, 135, 186, 226, 128, 140, 240, 159, 135, 179, 226, 128, 140, 240, 159, 135, 174,

In this simple tokenization scheme, since each byte is represented as an integer value in the range 0-255, we effectively have a vocabulary size of 256.

In [55]:
vocab = list(range(256)) # 0-255
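A useful property of this 256-entry byte vocabulary is that any string can be tokenized and then recovered exactly. The sketch below checks that round trip on a short sample; the variable names `sample`, `byte_tokens` and `restored` are illustrative and not part of the notebook.

```python
# Short sample containing ASCII, an accented letter and an emoji
sample = "caf\u00e9 \U0001F604"

# Tokenize: UTF-8 bytes, each viewed as an integer in 0-255
byte_tokens = list(sample.encode("utf-8"))

# Every token falls inside the base vocabulary of 256 byte values
in_vocab = all(0 <= t <= 255 for t in byte_tokens)

# Detokenize: pack the integers back into bytes and decode as UTF-8
restored = bytes(byte_tokens).decode("utf-8")

print(byte_tokens)
print(in_vocab, restored == sample)
```

The cost of this simplicity is sequence length: every non-ASCII character expands into several tokens, which is the motivation for merging frequent pairs.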

We will now implement the `Byte-Pair Encoding` algorithm, which builds a new vocabulary by iteratively merging the most frequently co-occurring pair of adjacent tokens into a single new token.

First, let's implement a function for finding the most commonly occurring pair of adjacent tokens and a function for merging that pair and augmenting the vocabulary.

In [56]:
def most_common_pair(tokens):
    """
    Given a list of integer tokens, return the most common pair of adjacent tokens
    """
    # count every pair of adjacent tokens
    pair_count = defaultdict(int)
    for pair in zip(tokens, tokens[1:]):
        pair_count[pair] += 1

    # get the most common pair (on a tie, the pair seen first wins)
    pair = max(pair_count, key=pair_count.get)

    return pair
    

def merge_pair(tokens, pair, vocab):
    """
    Given a list of integer tokens and a pair of adjacent tokens, replace every
    occurrence of the pair with a single new token and add that token to vocab
    """

    # create a new token id for the merged pair (the next unused integer)
    new_token = len(vocab)
    vocab.append(new_token)
    # replace all occurrences of the pair in the list of tokens with the new token
    updated_tokens = []
    i = 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i+1]) == pair:
            updated_tokens.append(new_token)
            i += 2
        else:
            updated_tokens.append(tokens[i])
            i += 1
    return updated_tokens
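To make one merge step easy to follow by hand, here is a self-contained sketch on a tiny toy string. It assumes the same left-to-right, non-overlapping replacement as `merge_pair` above; the helper names `count_pairs` and `merge` and the toy string are illustrative choices, not taken from the notebook.

```python
from collections import Counter


def count_pairs(ids):
    """Count how often each pair of adjacent tokens occurs."""
    return Counter(zip(ids, ids[1:]))


def merge(ids, pair, new_id):
    """Replace each non-overlapping occurrence of `pair`, left to right, with `new_id`."""
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out


# Toy input: pure ASCII, so each character is exactly one byte token
toy = list("aaabdaaabac".encode("utf-8"))

counts = count_pairs(toy)
top_pair, top_count = counts.most_common(1)[0]

# 256 is the first id after the base byte vocabulary (0-255)
merged = merge(toy, top_pair, 256)

print(f"most common pair: {top_pair} (seen {top_count} times)")
print(f"{len(toy)} tokens -> {len(merged)} tokens: {merged}")
```

Here the pair `(97, 97)` ("aa") occurs 4 times when overlapping positions are counted, but only 2 non-overlapping occurrences are replaced, so the sequence shrinks from 11 tokens to 9.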

In [57]:
# find most common pair
p = most_common_pair(tokens)
print(f"The most common pair of integers in the text is: {p}, which corresponds to the characters: '{chr(p[0])}' and '{chr(p[1])}'")

# merge the most common pair into a single new token
tokens = merge_pair(tokens, p, vocab)
print(f"The new token list after merging the most common pair is: {tokens}")
print(f"Length of new token list: {len(tokens)}")

The most common pair of integers in the text is: (101, 32), which corresponds to the characters: 'e' and ' '
The new token list after merging the most common pair is: [239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131, 239, 189, 143, 239, 189, 132, 239, 189, 133, 33, 32, 240, 159, 133, 164, 240, 159, 133, 157, 240, 159, 133, 152, 240, 159, 133, 146, 240, 159, 133, 158, 240, 159, 133, 147, 240, 159, 133, 148, 226, 128, 189, 32, 240, 159, 135, 186, 226, 128, 140, 240, 159, 135, 179, 226, 128, 140, 240, 159, 135, 174, 226, 128, 140, 240, 159, 135, 168, 226, 128, 140, 240, 159, 135, 180, 226, 128, 140, 240, 159, 135, 169, 226, 128, 140, 240, 159, 135, 170, 33, 32, 240, 159, 152, 132, 32, 84, 104, 256, 118, 101, 114, 121, 32, 110, 97, 109, 256, 115, 116, 114, 105, 107, 101, 115, 32, 102, 101, 97, 114, 32, 97, 110, 100, 32, 97, 119, 256, 105, 110, 116, 111, 32, 116, 104, 256, 104, 101, 97, 114, 116, 115, 32, 111, 102, 32, 112, 114, 111, 103, 114, 97, 109, 109, 101, 114, 115, 32, 119

For convenience, we can combine finding the most common pair and merging it into a single operation.

In [59]:
def merge_most_common_pair(tokens, vocab, verbose=False):
    """
    Given a list of integers, find the most common pair and merge it into a single integer
    """
    pair = most_common_pair(tokens)
    updated_tokens = merge_pair(tokens, pair, vocab)
    if verbose:
        print(f"Merged pair {pair} into new token {vocab[-1]}")
        print(f"New token list: {updated_tokens}")
        print(f"Length of new token list: {len(updated_tokens)}")
    return updated_tokens
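The `vocab` list above records only the new token ids, not which pair each id stands for, so a merged token sequence cannot be turned back into text from `vocab` alone. As a minimal sketch of one way to handle this, the code below repeats the merge step in a loop and keeps a `merges` dictionary from pair to new id, which is enough to decode. The names `train_bpe`, `decode` and `merges` are hypothetical and not part of the notebook.

```python
from collections import Counter


def train_bpe(byte_ids, num_merges):
    """Run up to `num_merges` merge steps; return the final tokens and the merge table."""
    ids = list(byte_ids)
    merges = {}  # (left_id, right_id) -> new_id, in the order the merges were learned
    for step in range(num_merges):
        counts = Counter(zip(ids, ids[1:]))
        if not counts:  # fewer than two tokens left, nothing to merge
            break
        pair = max(counts, key=counts.get)
        new_id = 256 + step  # ids 0-255 are reserved for raw bytes
        merges[pair] = new_id
        # replace non-overlapping occurrences of `pair`, scanning left to right
        out, i = [], 0
        while i < len(ids):
            if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
                out.append(new_id)
                i += 2
            else:
                out.append(ids[i])
                i += 1
        ids = out
    return ids, merges


def decode(ids, merges):
    """Expand every token back to its bytes, then decode the bytes as UTF-8."""
    table = {i: bytes([i]) for i in range(256)}
    # dicts preserve insertion order, so earlier merges are defined before later ones use them
    for (left, right), new_id in merges.items():
        table[new_id] = table[left] + table[right]
    return b"".join(table[i] for i in ids).decode("utf-8")


sample = "the theme of the thesis"
ids, merges = train_bpe(sample.encode("utf-8"), num_merges=3)

print(f"{len(sample.encode('utf-8'))} bytes -> {len(ids)} tokens")
print(decode(ids, merges) == sample)
```

Keeping the merge table in learning order matters because a later merge may refer to a token created by an earlier one.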

In [65]:
tokens = merge_most_common_pair(tokens, vocab, verbose=True)

Merged pair (97, 110) into new token 262
New token list: [239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131, 239, 189, 143, 239, 189, 132, 239, 189, 133, 33, 32, 258, 133, 164, 258, 133, 157, 258, 133, 152, 258, 133, 146, 258, 133, 158, 258, 133, 147, 258, 133, 148, 259, 189, 32, 258, 135, 186, 259, 140, 258, 135, 179, 259, 140, 258, 135, 174, 259, 140, 258, 135, 168, 259, 140, 258, 135, 180, 259, 140, 258, 135, 169, 259, 140, 258, 135, 170, 33, 32, 258, 152, 132, 32, 84, 104, 256, 118, 101, 114, 121, 32, 110, 97, 109, 256, 115, 116, 114, 105, 107, 101, 261, 102, 101, 97, 114, 32, 262, 100, 32, 97, 119, 256, 260, 116, 111, 32, 116, 104, 256, 104, 101, 97, 114, 116, 261, 111, 102, 32, 112, 114, 111, 103, 114, 97, 109, 109, 101, 114, 261, 119, 111, 114, 108, 100, 119, 105, 100, 101, 46, 32, 87, 256, 97, 108, 108, 32, 107, 110, 111, 119, 32, 119, 256, 111, 117, 103, 104, 116, 32, 116, 111, 32, 259, 156, 115, 117, 112, 112, 111, 114, 116, 32, 85, 110, 105, 99, 111, 100, 101, 259, 