# Byte Pair Encoding Tokenizer

Unicode is a mapping that assigns every character an integer, called the code point of that character. Using these code points directly as tokens is called **character-level tokenization**. For concreteness, Unicode currently encodes roughly 150,000 characters.

Since the number of characters across all human writing systems is large, this one-to-one tokenization leads to a huge vocabulary. One way to reduce the vocabulary size is to adopt an encoding and tokenize its bytes, which is called **byte-level tokenization**. For example, in UTF-8, the dominant encoding for text on the internet, every code point is represented by a sequence of at most four bytes. Each byte encodes an integer in [0, 256), hence the vocabulary size is 256. One can also adopt the UTF-16 encoding, whose 16-bit code units give a vocabulary of size $2^{16}$. Observe the tradeoff: a smaller vocabulary yields a longer tokenized text, and a larger vocabulary a shorter one.
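
To make the tradeoff concrete, the following sketch tokenizes one short string at three granularities and compares the sequence lengths. The sample string is made up for illustration, and `utf-16-le` is chosen only to avoid the byte-order mark that plain `utf-16` prepends.

```python
# Sample text (an assumption for this sketch): "naïve café" plus one emoji,
# written with escapes so the source file's own encoding does not matter.
text = "na\u00efve caf\u00e9 \U0001F642"

code_points = [ord(ch) for ch in text]        # one token per character
utf8_bytes = list(text.encode("utf-8"))       # one token per byte, values in [0, 256)

utf16_raw = text.encode("utf-16-le")          # little-endian, no byte-order mark
utf16_units = [
    int.from_bytes(utf16_raw[i:i + 2], "little")
    for i in range(0, len(utf16_raw), 2)
]                                             # one token per 16-bit unit, values in [0, 2**16)

print(len(code_points), len(utf16_units), len(utf8_bytes))  # 12 13 17
```

The code-point sequence is the shortest but draws on the largest vocabulary, while the UTF-8 sequence is the longest with only 256 possible values. Characters outside the Basic Multilingual Plane, such as the emoji here, take two UTF-16 code units.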

The idea of **subword tokenization** is to create tokens out of groups of characters within a word. One method for identifying such subwords is **byte-pair encoding** (BPE), which repeatedly mints a new token from the most frequently occurring pair of adjacent tokens, starting from single bytes. Thus frequently occurring sequences of characters become single units called tokens.

In [1]:
'😂'.isidentifier(), 'π'.isidentifier()

(False, True)

In [2]:
# single character
π = ord('π')  # code point
chr(π) # actual char

'π'

In [3]:
for n in range(65, 123):
    print(chr(n), end='')

ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz

In [4]:
s = '🙂'
utf8_s = s.encode('utf-8')
print(type(utf8_s), utf8_s, list(utf8_s), utf8_s.decode('utf-8'))   # observe four bytes

<class 'bytes'> b'\xf0\x9f\x99\x82' [240, 159, 153, 130] 🙂


In [5]:
import regex

# pre-tokenization pattern
PAT = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"

corpus = r"""cat? cat! cat, cat. I'm a byte-pair tokenizer3.1415926"""

itr = regex.finditer(PAT, corpus)

for item in itr:
    print(item.group(), end='|')

cat|?| cat|!| cat|,| cat|.| I|'m| a| byte|-pair| tokenizer|3|.|141|592|6|

In [6]:
corpus = r"""
All the world's a stage,
And all the men and women merely players;
They have their exits and their entrances;
And one man in his time plays many parts,
His acts being seven ages. At first the infant,
Mewling and puking in the nurse's arms;
And then the whining school-boy, with his satchel
And shining morning face, creeping like snail
Unwillingly to school. And then the lover,
Sighing like furnace, with a woeful ballad
Made to his mistress' eyebrow. Then a soldier,
Full of strange oaths, and bearded like the pard,
Jealous in honour, sudden and quick in quarrel,
Seeking the bubble reputation
Even in the cannon's mouth. And then the justice,
In fair round belly with good capon lin'd,
With eyes severe and beard of formal cut,
Full of wise saws and modern instances;
And so he plays his part. The sixth age shifts
Into the lean and slipper'd pantaloon,
With spectacles on nose and pouch on side;
His youthful hose, well sav'd, a world too wide
For his shrunk shank; and his big manly voice,
Turning again toward childish treble, pipes
And whistles in his sound. Last scene of all,
That ends this strange eventful history,
Is second childishness and mere oblivion;
Sans teeth, sans eyes, sans taste, sans everything.
"""

# Training a BPE Tokenizer

At the start, all 256 possible byte values are already tokens, so they form a subset of the final token vocabulary.

We need to pre-tokenize so that occurrences of the same word followed by different punctuation, such as `cat,`, `cat.`, and `cat!`, are treated alike.
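
One way to act on this is to count byte pairs inside each pre-token only, so that no merge can span the boundary between `cat` and the punctuation that follows it. The sketch below uses the standard-library `re` module with a deliberately simplified pattern; both `SIMPLE_PAT` and `count_pairs_pretokenized` are hypothetical names for this illustration, not the `PAT` defined earlier.

```python
import re
from collections import Counter

# Simplified stand-in for a pre-tokenization pattern (an assumption for this sketch):
# optional leading space + letters, or + digits, or + punctuation; otherwise whitespace.
SIMPLE_PAT = re.compile(r" ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+")

def count_pairs_pretokenized(text: str) -> Counter:
    """Count adjacent byte pairs without crossing pre-token boundaries."""
    # Frequency of each distinct pre-token, stored as a tuple of byte values.
    pretoken_counts = Counter(
        tuple(m.group().encode("utf-8")) for m in SIMPLE_PAT.finditer(text)
    )
    pair_counts = Counter()
    for pretoken, count in pretoken_counts.items():
        # Each pair inside a pre-token occurs once per occurrence of that pre-token.
        for pair in zip(pretoken, pretoken[1:]):
            pair_counts[pair] += count
    return pair_counts

pairs = count_pairs_pretokenized("cat? cat! cat, cat.")
print(pairs[(ord("c"), ord("a"))])        # 4
print((ord("t"), ord(",")) in pairs)      # False
```

All four occurrences of `cat` contribute to the same pairs, and pairs such as `t,` or `t.` are never counted, so they can never be merged.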



### References

1. [Python Unicode Documentation](https://docs.python.org/3/howto/unicode.html)
2. [GPT 2 tokenizer (paper)](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)

In [7]:
from collections import Counter

# initialize
vocab = list(range(256))

corpus_utf8 = list(corpus.encode('utf-8'))
byte_pairs = Counter(zip(corpus_utf8, corpus_utf8[1:]))

In [8]:
# # Approach 1
# s = ''
# for i in list(corpus.encode('utf-8')):
#     s += chr(i)
# print(s)

# # Approach 2
decoded_text = bytes(corpus_utf8).decode('utf-8')
print(decoded_text[:150])


All the world's a stage,
And all the men and women merely players;
They have their exits and their entrances;
And one man in his time plays many part


In [9]:
print(max(corpus_utf8), '≤', len(vocab))
print(len(corpus_utf8), 'chars', corpus_utf8[:10])
print(len(byte_pairs), 'pairs', 'Counter({(101, 32): 29, (115, 32): 29, (116, 104), ...)')

121 ≤ 256
1222 chars [10, 65, 108, 108, 32, 116, 104, 101, 32, 119]
313 pairs Counter({(101, 32): 29, (115, 32): 29, (116, 104), ...)


In [10]:
for (a, b), freq in byte_pairs.most_common(5):
    print(f'({chr(a)}|{chr(b)}) pair count: {freq}')

(e| ) pair count: 29
(s| ) pair count: 29
(t|h) pair count: 28
( |s) pair count: 28
(d| ) pair count: 27


In [11]:
from typing import List, Dict, Tuple

def merge_pair(original_corpus: List[int], p1: int, p2: int, replacement_token: int) -> List[int]:
    L = len(original_corpus)
    if L < 2:
        return original_corpus
    merged_corpus = []
    i = 0
    while i < L:
        # or use the replace() method if the corpus is a str
        if i < L-1 and original_corpus[i: i+2] == [p1, p2]:
            merged_corpus.append(replacement_token)
            i += 2
        else:
            merged_corpus.append(original_corpus[i])
            i += 1
    return merged_corpus


In [12]:
merge_pair([1, 2, 3, 4, 5], 1, 2, 'x')

['x', 3, 4, 5]

In [13]:
original_corpus = corpus_utf8
(p1, p2), freq = byte_pairs.most_common(1)[0]
replacement_token = vocab[-1] + 1
vocab.append(replacement_token)

merged_corpus = merge_pair(original_corpus, p1, p2, replacement_token)
print(max(merged_corpus), merged_corpus[:10])

256 [10, 65, 108, 108, 32, 116, 104, 256, 119, 111]


In [14]:
assert (p1, p2) not in list(zip(merged_corpus, merged_corpus[1:]))
print(f'{len(original_corpus)} = {len(merged_corpus)} + {freq} ({len(original_corpus) == len(merged_corpus) + freq})')

1222 = 1193 + 29 (True)
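
The length identity checked above holds because the merged pair cannot overlap with itself. For a pair made of two identical bytes, counting with `zip` includes overlapping occurrences, while a left-to-right merge only consumes non-overlapping ones. A minimal sketch of the discrepancy, using a hypothetical helper `merge_left_to_right` that mirrors the left-to-right scan of `merge_pair`:

```python
from collections import Counter

def merge_left_to_right(tokens, pair, new_token):
    """Replace non-overlapping occurrences of `pair`, scanning left to right."""
    out, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            out.append(new_token)
            i += 2
        else:
            out.append(tokens[i])
            i += 1
    return out

tokens = list(b"aaa")                                   # [97, 97, 97]
counted = Counter(zip(tokens, tokens[1:]))[(97, 97)]    # overlapping count
merged = merge_left_to_right(tokens, (97, 97), 256)
replaced = len(tokens) - len(merged)                    # merges actually performed

print(counted, replaced)  # 2 1
```

An assertion of the form `len(merged) + freq == len(original)` is therefore only safe when the corpus has no run of three or more identical tokens that gets merged.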


In [15]:
# naive, and without pretokenization
def train_bpe(num_steps: int, corpus: str):

    initial_vocab_size = 256
    vocab = list(range(initial_vocab_size))
    corpus = list(corpus.encode('utf-8'))
    vocab_table = {}  # maps each new token to the (left, right) pair it replaced
    
    for step in range(num_steps):

        # find most common byte pair
        pairs = Counter(zip(corpus, corpus[1:]))
        (p1, p2), freq = pairs.most_common(1)[0]

        if freq == 1:
            raise Exception('Frequency 1 met')

        # handle ties deterministically: among equally frequent pairs, take the largest
        tied_pairs = [pair for pair, f in pairs.items() if f == freq]
        p1, p2 = max(tied_pairs)

        # record newly created token
        replacement_token = vocab[-1]+1
        vocab.append(replacement_token)
        vocab_table[replacement_token] = (p1, p2)

        new_corpus = merge_pair(corpus, p1, p2, replacement_token)
        assert len(new_corpus) + freq == len(corpus)
        corpus = new_corpus
        # print(f'At step {step} merged {freq} instances of {(p1, p2)} as new token {replacement_token}')

    return vocab_table, corpus

vocab_size = 300
initial_vocab_size = 256
num_steps = vocab_size - initial_vocab_size
vocab, tokenized_corpus = train_bpe(num_steps, corpus)

In [16]:
# tokens representing merges
# in addition to the initial tokens 0-255
list(vocab.items())[:10]

[(256, (115, 32)),
 (257, (101, 32)),
 (258, (116, 104)),
 (259, (100, 32)),
 (260, (97, 110)),
 (261, (105, 110)),
 (262, (32, 115)),
 (263, (44, 10)),
 (264, (110, 32)),
 (265, (261, 103))]

In [17]:
compression_ratio = len(corpus) / len(tokenized_corpus)
compression_ratio = round(compression_ratio, 2)
compression_ratio

1.57

In [18]:
token_2_str = {i: chr(i) for i in range(initial_vocab_size)}
for token, (a, b) in vocab.items():
    token_2_str[token] = token_2_str[a] + token_2_str[b] # addition of str

# merges
limit = initial_vocab_size + 10
{token: string for (token, string) in token_2_str.items() if initial_vocab_size <= token <= limit}

{256: 's ',
 257: 'e ',
 258: 'th',
 259: 'd ',
 260: 'an',
 261: 'in',
 262: ' s',
 263: ',\n',
 264: 'n ',
 265: 'ing',
 266: 'hi'}

# Encoding and Decoding

In [19]:
def decode(tokens: List[int], token_to_string: Dict[int, str]) -> str:
    s = ''
    for t in tokens:
        s += token_to_string[t]
    return s

In [20]:
type(tokenized_corpus), len(tokenized_corpus), max(tokenized_corpus)

(list, 776, 299)

In [21]:
decoded_text = decode(tokenized_corpus, token_2_str) 
assert decoded_text == corpus
print(decoded_text[:150])


All the world's a stage,
And all the men and women merely players;
They have their exits and their entrances;
And one man in his time plays many part
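
Note that `chr(i)` maps a byte value to the code point with the same number, which agrees with UTF-8 only for ASCII (values 0 to 127). The round trip above succeeds because the corpus is pure ASCII. For text with other characters, one option is to keep each token as `bytes` and decode once at the end. The names `build_token_bytes` and `decode_bytes` below are hypothetical, and the merge table is assumed to have the same `{token: (left, right)}` shape that `train_bpe` returns, listed in creation order.

```python
from typing import Dict, List, Tuple

def build_token_bytes(merges: Dict[int, Tuple[int, int]]) -> Dict[int, bytes]:
    """Map every token id to the byte string it stands for."""
    token_bytes = {i: bytes([i]) for i in range(256)}
    # Assumes merges are listed in creation order, so both halves already exist.
    for token, (left, right) in merges.items():
        token_bytes[token] = token_bytes[left] + token_bytes[right]
    return token_bytes

def decode_bytes(tokens: List[int], token_bytes: Dict[int, bytes]) -> str:
    """Concatenate the token bytes, then decode as UTF-8 in one step."""
    raw = b"".join(token_bytes[t] for t in tokens)
    # errors="replace" guards against token sequences that split a character.
    return raw.decode("utf-8", errors="replace")

# Toy merge table (made up for illustration): 256 stands for the two bytes of "π".
toy_merges = {256: (0xCF, 0x80)}
table = build_token_bytes(toy_merges)

print(decode_bytes([256, ord("r")], table))                    # πr
print("".join(chr(b) for b in "\u03c0".encode("utf-8")) == "\u03c0")  # False
```

The second line shows why the `chr`-based table does not round-trip non-ASCII text: the two bytes of `π` become two unrelated characters.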


In [22]:
def encode(corpus: str, pairs_to_token: Dict[Tuple[int, int], int]) -> List[int]:
    corpus_utf8 = list(corpus.encode('utf-8'))

    if len(corpus_utf8) == 1:
        return corpus_utf8
    
    pairs = list(zip(corpus_utf8, corpus_utf8[1:]))
    working_pairs = []

    while True:
        merge_exists = False
        i = 0
        while i < len(pairs):
            
            p = pairs[i]
            
            if p in pairs_to_token:
                token = pairs_to_token[p]
                working_pairs.append(token)
                merge_exists = True # some merging occurred
                if i + 1 == len(pairs) - 1:
                    _, b = pairs[i+1]
                    working_pairs.append(b)
                    break
                i += 2
            else:
                if i == len(pairs)-1:
                    working_pairs += p
                    i += 2
                else:
                    working_pairs.append(p[0])
                    i += 1

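        if not merge_exists or len(working_pairs) == 1:
            break
        # re-pair the merged sequence and sweep again; a single remaining
        # pair must still be checked, since it may itself be mergeable
        pairs = list(zip(working_pairs, working_pairs[1:]))
        working_pairs = []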
    return working_pairs

txt = 'hello world'
pairs_2_token = {pair: token for (token, pair) in vocab.items()}
for i in encode(txt, pairs_2_token):
    print(i, token_2_str[i])


104 h
101 e
271 ll
111 o
277  w
283 or
299 ld
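
The `encode` above merges whichever known pair it meets first during a left-to-right sweep. An alternative is to replay the merges in the order they were learned, always applying the earliest-learned merge that is present. The sketch below is one way to do that, not necessarily the intended design of this notebook; `encode_by_rank` is a hypothetical name, and it assumes that smaller token ids correspond to earlier merges, as is the case for the table built by `train_bpe`.

```python
from typing import Dict, List, Tuple

def encode_by_rank(text: str, pairs_to_token: Dict[Tuple[int, int], int]) -> List[int]:
    """Encode by repeatedly applying the earliest-learned merge that is present."""
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        candidates = [p for p in zip(tokens, tokens[1:]) if p in pairs_to_token]
        if not candidates:
            break
        # Assumption: a smaller token id means the merge was learned earlier.
        pair = min(candidates, key=lambda p: pairs_to_token[p])
        new_token = pairs_to_token[pair]
        merged, i = [], 0
        while i < len(tokens):
            if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
                merged.append(new_token)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens

# Toy table (made up for illustration): "in" -> 256, then ("in", "g") -> 257.
toy = {(105, 110): 256, (256, 103): 257}
print(encode_by_rank("sing", toy))  # [115, 257]
```

Each pass shortens the sequence by at least one token, so the loop terminates. The two strategies can produce different tokenizations when an earlier-learned merge overlaps a later-learned one.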
