# Byte Pair Encoding Tokenizer

Unicode is a standard that assigns every character an integer called its **code point**. Mapping each character to its code point is called **character-level tokenization**. For concreteness, Unicode currently encodes on the order of 150,000 characters.

Because the number of characters across all writing systems is large, this one-to-one tokenization leads to a huge vocabulary. One way to reduce the vocabulary size is to adopt an encoding and tokenize its bytes, which is called **byte-level tokenization**. For example, in UTF-8, the dominant encoding for text on the internet, every code point is represented by a sequence of at most four bytes. Each byte is an integer in [0, 256), so the token vocabulary has size 256. One can also adopt the UTF-16 encoding and treat each 16-bit code unit as a token, which gives a vocabulary of size $2^{16}$. Observe the tradeoff: a smaller vocabulary produces a longer token sequence for the same text.
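To make the tradeoff concrete, the following sketch tokenizes one short string in three ways: by code point, by UTF-16 code unit, and by UTF-8 byte. The sample string and the variable names are illustrative assumptions, not part of the original notebook.

```python
# Sample text written with escapes so the code points are unambiguous:
# U+00EF (i with diaeresis), U+00E9 (e with acute), U+1F642 (smiling face).
text = "na\u00efve caf\u00e9 \U0001F642"

# Character-level: one token per code point (largest vocabulary).
code_points = [ord(ch) for ch in text]

# UTF-16: one token per 16-bit code unit (vocabulary of size 2**16).
raw16 = text.encode("utf-16-le")  # little-endian, no byte-order mark
utf16_units = [int.from_bytes(raw16[i:i + 2], "little")
               for i in range(0, len(raw16), 2)]

# UTF-8: one token per byte (vocabulary of size 256).
utf8_bytes = list(text.encode("utf-8"))

print(len(code_points), len(utf16_units), len(utf8_bytes))
```

The smaller the vocabulary, the longer the sequence: the emoji is one code point, two UTF-16 code units (a surrogate pair), and four UTF-8 bytes.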

The idea of **subword tokenization** is to form tokens from groups of characters within a word. One method for choosing which subwords become tokens is **byte-pair encoding** (BPE), which repeatedly mints a new token from the most frequently occurring pair of adjacent tokens (initially, pairs of bytes). Frequently occurring character sequences thus end up represented as single tokens.

# Training a BPE Tokenizer

At the start we know that all 256 possible byte values will be part of the final token vocabulary, so training begins with these 256 tokens.

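We need to pre-tokenize, that is, split the text into words and punctuation before counting pairs, so that occurrences such as `cat,`, `cat.` and `cat!` are treated similarly: merges then never cross the boundary between `cat` and the punctuation that follows it.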
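As a rough illustration of why this matters, the sketch below counts byte pairs with and without a pre-tokenization step. It uses only the standard-library `re` module, and `SIMPLE_PAT` is a deliberately simplified, ASCII-only pattern assumed for this example; `count_pairs` is likewise a hypothetical helper.

```python
import re
from collections import Counter

# Simplified, ASCII-only pre-tokenization pattern (an assumption for this sketch):
# optional leading space + letters, or + digits, or + punctuation, else whitespace.
SIMPLE_PAT = re.compile(r" ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+")

def count_pairs(text: str, pretokenize: bool) -> Counter:
    """Count adjacent byte pairs, optionally only within pre-tokens."""
    chunks = SIMPLE_PAT.findall(text) if pretokenize else [text]
    pairs = Counter()
    for chunk in chunks:
        ids = list(chunk.encode("utf-8"))
        pairs.update(zip(ids, ids[1:]))  # pairs never span two chunks
    return pairs

sample = "cat, cat. cat! cat"
naive = count_pairs(sample, pretokenize=False)
split = count_pairs(sample, pretokenize=True)

t, comma = ord("t"), ord(",")
print(naive[(t, comma)], split[(t, comma)])  # the pair ('t', ',') vanishes once split
print(split[(ord("c"), ord("a"))])           # pairs inside 'cat' are all still counted
```

With pre-tokenization, the pair `('t', ',')` can never be merged, so `cat` is learned independently of whatever punctuation follows it.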



### References

1. [Python Unicode Documentation](https://docs.python.org/3/howto/unicode.html)
2. [GPT-2 tokenizer (paper)](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)

In [1]:
'😂'.isidentifier(), 'π'.isidentifier()

(False, True)

In [2]:
# single character
π = ord('π')  # code point
chr(π) # actual char

'π'

In [3]:
for n in range(65, 123):
    print(chr(n), end='')

ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz

In [4]:
s = '🙂'
utf8_s = s.encode('utf-8')
print(type(utf8_s), utf8_s, list(utf8_s), utf8_s.decode('utf-8'))   # observe four bytes

<class 'bytes'> b'\xf0\x9f\x99\x82' [240, 159, 153, 130] 🙂


In [5]:
import regex

# pre-tokenization pattern
PAT = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"

corpus = r"""cat? cat! cat, cat. I'm a byte-pair tokenizer3.1415926"""

itr = regex.finditer(PAT, corpus)

for item in itr:
    print(item.group(), end='|')

cat|?| cat|!| cat|,| cat|.| I|'m| a| byte|-pair| tokenizer|3|.|141|592|6|

In [6]:
corpus = r"""
All the world's a stage,
And all the men and women merely players;
They have their exits and their entrances;
And one man in his time plays many parts,
His acts being seven ages. At first the infant,
Mewling and puking in the nurse's arms;
And then the whining school-boy, with his satchel
And shining morning face, creeping like snail
Unwillingly to school. And then the lover,
Sighing like furnace, with a woeful ballad
Made to his mistress' eyebrow. Then a soldier,
Full of strange oaths, and bearded like the pard,
Jealous in honour, sudden and quick in quarrel,
Seeking the bubble reputation
Even in the cannon's mouth. And then the justice,
In fair round belly with good capon lin'd,
With eyes severe and beard of formal cut,
Full of wise saws and modern instances;
And so he plays his part. The sixth age shifts
Into the lean and slipper'd pantaloon,
With spectacles on nose and pouch on side;
His youthful hose, well sav'd, a world too wide
For his shrunk shank; and his big manly voice,
Turning again toward childish treble, pipes
And whistles in his sound. Last scene of all,
That ends this strange eventful history,
Is second childishness and mere oblivion;
Sans teeth, sans eyes, sans taste, sans everything.
"""

In [7]:
from collections import Counter

# initialize
vocab = list(range(256))

corpus_utf8 = list(corpus.encode('utf-8'))
byte_pairs = Counter(zip(corpus_utf8, corpus_utf8[1:]))

In [8]:
print(max(corpus_utf8), '≤', len(vocab))
print(len(corpus_utf8), 'chars', corpus_utf8)
print(len(byte_pairs), 'pairs', byte_pairs)

121 ≤ 256
1222 chars [10, 65, 108, 108, 32, 116, 104, 101, 32, 119, 111, 114, 108, 100, 39, 115, 32, 97, 32, 115, 116, 97, 103, 101, 44, 10, 65, 110, 100, 32, 97, 108, 108, 32, 116, 104, 101, 32, 109, 101, 110, 32, 97, 110, 100, 32, 119, 111, 109, 101, 110, 32, 109, 101, 114, 101, 108, 121, 32, 112, 108, 97, 121, 101, 114, 115, 59, 10, 84, 104, 101, 121, 32, 104, 97, 118, 101, 32, 116, 104, 101, 105, 114, 32, 101, 120, 105, 116, 115, 32, 97, 110, 100, 32, 116, 104, 101, 105, 114, 32, 101, 110, 116, 114, 97, 110, 99, 101, 115, 59, 10, 65, 110, 100, 32, 111, 110, 101, 32, 109, 97, 110, 32, 105, 110, 32, 104, 105, 115, 32, 116, 105, 109, 101, 32, 112, 108, 97, 121, 115, 32, 109, 97, 110, 121, 32, 112, 97, 114, 116, 115, 44, 10, 72, 105, 115, 32, 97, 99, 116, 115, 32, 98, 101, 105, 110, 103, 32, 115, 101, 118, 101, 110, 32, 97, 103, 101, 115, 46, 32, 65, 116, 32, 102, 105, 114, 115, 116, 32, 116, 104, 101, 32, 105, 110, 102, 97, 110, 116, 44, 10, 77, 101, 119, 108, 105, 110, 103, 32, 97,

In [9]:
for (a, b), freq in byte_pairs.most_common(10):
    print(f'({chr(a)}|{chr(b)}) pair count: {freq}')

(e| ) pair count: 29
(s| ) pair count: 29
(t|h) pair count: 28
( |s) pair count: 28
(d| ) pair count: 27
(a|n) pair count: 27
( |t) pair count: 25
(i|n) pair count: 24
(n|d) pair count: 23
(n| ) pair count: 23


In [10]:
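from typing import List

def merge_pair(original_corpus: List[int], p1: int, p2: int, replacement_token: int) -> List[int]:
    """Replace each non-overlapping occurrence of (p1, p2), scanning left to right."""
    L = len(original_corpus)
    if L < 2:
        return list(original_corpus)  # nothing to merge; return a copy
    merged_corpus = []
    i = 0
    while i < L:
        # a pair can start at i only if a next element exists
        if i < L-1 and original_corpus[i: i+2] == [p1, p2]:
            merged_corpus.append(replacement_token)
            i += 2  # skip both members of the merged pair
        else:
            merged_corpus.append(original_corpus[i])
            i += 1
    return merged_corpus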


In [11]:
merge_pair([1, 2, 3, 4, 5], 1, 2, 'x')

['x', 3, 4, 5]

In [12]:
original_corpus = corpus_utf8
(p1, p2), freq = byte_pairs.most_common(1)[0]
replacement_token = vocab[-1] + 1
vocab.append(replacement_token)

merged_corpus = merge_pair(original_corpus, p1, p2, replacement_token)
print(max(merged_corpus), merged_corpus)

256 [10, 65, 108, 108, 32, 116, 104, 256, 119, 111, 114, 108, 100, 39, 115, 32, 97, 32, 115, 116, 97, 103, 101, 44, 10, 65, 110, 100, 32, 97, 108, 108, 32, 116, 104, 256, 109, 101, 110, 32, 97, 110, 100, 32, 119, 111, 109, 101, 110, 32, 109, 101, 114, 101, 108, 121, 32, 112, 108, 97, 121, 101, 114, 115, 59, 10, 84, 104, 101, 121, 32, 104, 97, 118, 256, 116, 104, 101, 105, 114, 32, 101, 120, 105, 116, 115, 32, 97, 110, 100, 32, 116, 104, 101, 105, 114, 32, 101, 110, 116, 114, 97, 110, 99, 101, 115, 59, 10, 65, 110, 100, 32, 111, 110, 256, 109, 97, 110, 32, 105, 110, 32, 104, 105, 115, 32, 116, 105, 109, 256, 112, 108, 97, 121, 115, 32, 109, 97, 110, 121, 32, 112, 97, 114, 116, 115, 44, 10, 72, 105, 115, 32, 97, 99, 116, 115, 32, 98, 101, 105, 110, 103, 32, 115, 101, 118, 101, 110, 32, 97, 103, 101, 115, 46, 32, 65, 116, 32, 102, 105, 114, 115, 116, 32, 116, 104, 256, 105, 110, 102, 97, 110, 116, 44, 10, 77, 101, 119, 108, 105, 110, 103, 32, 97, 110, 100, 32, 112, 117, 107, 105, 110, 103

In [13]:
assert (p1, p2) not in list(zip(merged_corpus, merged_corpus[1:]))
print(f'{len(original_corpus)} = {len(merged_corpus)} + {freq} ({len(original_corpus) == len(merged_corpus) + freq})')

1222 = 1193 + 29 (True)
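The length identity checked above holds here because the merged pair `(101, 32)` consists of two different bytes, so its occurrences cannot overlap. For a pair of identical tokens, the counted occurrences can overlap and the identity may fail. The sketch below shows this on a tiny input; `merge` is a compact helper assumed for this example, written to behave like the left-to-right replacement in `merge_pair`.

```python
from collections import Counter

def merge(ids, pair, new_id):
    """Left-to-right, non-overlapping replacement of `pair` by `new_id`."""
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

ids = list(b"aaa")                              # [97, 97, 97]
counted = Counter(zip(ids, ids[1:]))[(97, 97)]  # overlapping occurrences are both counted
merged = merge(ids, (97, 97), 256)              # only one replacement is possible
print(counted, len(ids) - len(merged), merged)
```

The pair count is therefore only an upper bound on the number of merges actually performed.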


In [14]:
# naive, and without pretokenization
def train_bpe(num_steps: int, corpus: str):
    vocab = list(range(256))
    corpus = list(corpus.encode('utf-8'))
    for step in range(num_steps):
        pairs = Counter(zip(corpus, corpus[1:]))
        (p1, p2), freq = pairs.most_common(1)[0]
        print(f'At step {step} merging {freq} instances of {(p1, p2)}')
        replacement_token = vocab[-1]+1
        vocab.append(replacement_token)
        corpus = merge_pair(corpus, p1, p2, replacement_token)
        
    return vocab, corpus

vocab, tokenized_corpus = train_bpe(20, corpus)
len(vocab)

At step 0 merging 29 instances of (101, 32)
At step 1 merging 29 instances of (115, 32)
At step 2 merging 28 instances of (116, 104)
At step 3 merging 27 instances of (100, 32)
At step 4 merging 27 instances of (97, 110)
At step 5 merging 24 instances of (105, 110)
At step 6 merging 15 instances of (32, 115)
At step 7 merging 14 instances of (44, 10)
At step 8 merging 13 instances of (101, 110)
At step 9 merging 12 instances of (104, 105)
At step 10 merging 12 instances of (261, 103)
At step 11 merging 11 instances of (258, 256)
At step 12 merging 11 instances of (260, 259)
At step 13 merging 10 instances of (32, 267)
At step 14 merging 10 instances of (110, 259)
At step 15 merging 9 instances of (108, 108)
At step 16 merging 9 instances of (101, 114)
At step 17 merging 9 instances of (111, 110)
At step 18 merging 9 instances of (115, 116)
At step 19 merging 8 instances of (65, 270)


276
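The `train_bpe` function above returns token ids only, so the tokenized corpus cannot be turned back into text. One possible extension, sketched here under assumed names (`train_bpe_with_merges`, `decode`, and the sample text are not from the original notebook), keeps a table from each token id to the bytes it stands for and records the merges in the order they were learned.

```python
from collections import Counter

def train_bpe_with_merges(text: str, num_merges: int):
    """Naive BPE (no pre-tokenization) that also records what each token means."""
    ids = list(text.encode("utf-8"))
    vocab = {i: bytes([i]) for i in range(256)}  # token id -> byte string
    merges = {}                                  # (left id, right id) -> new id
    for _ in range(num_merges):
        pairs = Counter(zip(ids, ids[1:]))
        if not pairs:                            # fewer than two tokens left
            break
        (p1, p2), _freq = pairs.most_common(1)[0]
        new_id = len(vocab)
        merges[(p1, p2)] = new_id
        vocab[new_id] = vocab[p1] + vocab[p2]
        # Replace non-overlapping occurrences of the pair, left to right.
        out, i = [], 0
        while i < len(ids):
            if i + 1 < len(ids) and ids[i] == p1 and ids[i + 1] == p2:
                out.append(new_id)
                i += 2
            else:
                out.append(ids[i])
                i += 1
        ids = out
    return vocab, merges, ids

def decode(ids, vocab) -> str:
    """Concatenate the bytes of every token and decode them as UTF-8."""
    return b"".join(vocab[i] for i in ids).decode("utf-8")

text = "the cat and the hat and the bat"
vocab, merges, ids = train_bpe_with_merges(text, 5)
print(len(vocab), len(ids), decode(ids, vocab) == text)
```

Keeping `merges` in insertion order also gives what is needed to tokenize new text later, by replaying the merges in the order they were learned.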