# Byte Pair Encoding
- Start with a small vocabulary (for ease of use, let's assume only ASCII characters).
- "Learn" new vocabulary by scanning through the corpus, using how frequently each sub-sequence appears to decide which new vocabulary entries are most important.

For simplicity, let's limit our vocabulary to uppercase and lowercase ASCII letters plus digits. We will first purge the corpus of all other symbols while preserving spaces; note that this is not a realistic scenario.

In [1]:
# Imports
import numpy as np
import pyarrow.parquet as pq
import pandas as pd
import os
import re

In [2]:
# Load corpus, read all lines into memory, clean, then extract all characters, whitespace symbols
from pathlib import Path
from typing import List


# Raw WikiText-103 training shards. The path is relative to the current working
# directory, so the notebook must be launched from the expected location.
WIKITEXT_PATH = "../data/Salesforce/wikitext/wikitext-103-raw-v1"
TRAIN_SET = [
    "train-00000-of-00002.parquet",
    "train-00001-of-00002.parquet",
]
TRAIN_SET_PATHS = [os.path.join(WIKITEXT_PATH, shard) for shard in TRAIN_SET]
# Read both shards as one dataset and materialize it as a single DataFrame.
corpus = (
    pq.ParquetDataset(TRAIN_SET_PATHS)
    .read()
    .to_pandas()
)

# Every character that is NOT an ASCII letter or digit is replaced by a space, so
# punctuation, symbols and whitespace (tabs, newlines, ...) all become separators.
clear_ws_pattern = re.compile(r"[^a-zA-Z0-9]")

def clean_corpus(txt: str) -> str:
    """Reduce *txt* to single-space-separated runs of ASCII letters and digits.

    Non-alphanumeric characters become spaces; runs of spaces are then collapsed
    to one, and leading/trailing spaces disappear (``str.split()`` with no
    argument already discards them).
    """
    pieces = clear_ws_pattern.sub(' ', txt).split()
    return ' '.join(pieces)

# Apply the cleaning to every raw line, keeping the result in a new column.
corpus["text_clean"] = corpus["text"].map(clean_corpus)

# Character -> occurrence count ("importance"). Seeded with one example entry;
# the counting pass simply increments it like any other key.
char_count_dict = dict(a=0)

def count_chars(txt: str) -> None:
    """Add every character of *txt* to the global ``char_count_dict`` tally.

    A "bag of characters": each occurrence adds one, and a character seen for
    the first time is inserted with its count (so insertion order is
    first-seen order).
    """
    global char_count_dict
    for ch in txt:
        char_count_dict[ch] = char_count_dict.get(ch, 0) + 1

# Okay, now the "cleaning" is done; we can start counting characters and words.
# Count character frequencies over the corpus (not _really_ needed, just for viz).
for txt in corpus["text_clean"]:
    count_chars(txt)

# Remove all empty rows (lines that were blank or contained only
# non-alphanumeric symbols, which clean_corpus reduces to "").
corpus = corpus[corpus["text_clean"].str.len() > 0]

# Spaces are word separators, not vocabulary symbols, so drop them from the counts.
# pop(..., None) rather than `del`: if no cleaned line contains a space there is
# no ' ' entry, and `del` would raise KeyError.
char_count_dict.pop(' ', None)
# Sort into descending order of count -- purely for visualization, not needed.
char_count_dict = dict(sorted(char_count_dict.items(), key=lambda x: x[1], reverse=True))

# The distinct characters observed in the corpus form our initial vocabulary.
# Deriving it from the counts (instead of hard-coding a-zA-Z0-9) keeps this step
# valid if the cleaning is ever relaxed to admit characters beyond alphanumerics.
char_set = set(char_count_dict.keys())

# Now we need to generate our initial word set, same as char_count_dict
# (basically bag-of-words)
word_count_dict = {}

def count_words(txt: str) -> None:
    """Add every space-delimited word of *txt* to the global ``word_count_dict`` tally.

    Splitting is on a literal single space, so consecutive spaces (or an empty
    *txt*) contribute an empty-string "word".
    """
    global word_count_dict
    for w in txt.split(' '):
        word_count_dict[w] = word_count_dict.get(w, 0) + 1

# Tally word frequencies over every remaining (non-empty) cleaned line.
for txt in corpus["text_clean"]:
    count_words(txt)

# NOTE: len(word_count_dict) is the number of *distinct* words, not the total
# number of words in the corpus, despite the "corpus size" label.
print("corpus size:", len(word_count_dict))
print("initial vocab size:", len(char_count_dict))

# Should be 62: 26 lowercase + 26 uppercase + 10 digits (a-zA-Z0-9), assuming
# every one of them occurs somewhere in the corpus.
print("initial vocab:", char_set)

corpus size: 575352
initial vocab size: 62
initial vocab: {'x', 't', 'p', 'G', 'W', 'B', 'R', 'J', 'T', '0', 'L', 'e', '6', 'A', 'E', '3', '5', 'N', 'r', '9', 'S', 'w', 'q', 'v', '4', 'C', 'Y', '2', 'X', 'K', 'Z', 'V', 'd', 'j', 'l', 'i', 'c', 'o', 's', 'n', 'H', 'z', 'b', 'I', 'k', '7', 'D', 'F', 'g', 'P', '1', 'u', 'f', '8', 'O', 'U', 'y', 'a', 'M', 'h', 'm', 'Q'}


Now we apply the "merging" strategy: repeatedly merge the most frequent adjacent pair of symbols in the corpus into a new symbol, and replace every occurrence of that pair with the new symbol.

As an example: `m,a,t`, may become `m,at` if `at` is popular enough.

To give an example, suppose we have the words `cat`, `sat`, `hat`, `brat`, `broken`. The common character pairs are `(a, t)` (4 occurrences) and `(b, r)` (2 occurrences),
both of which appear more than once. When we apply a "merge", we merge the most popular character pair — here `(a, t)`, which becomes the new symbol `at`.

The stopping criterion is typically that the vocabulary reaches a desired size (e.g. 1000), but I'd also like to try a stopping criterion based on the frequency of each pair (e.g. if our next best pair `(x, y)` accounts for less than 1% of all adjacent pairs in the corpus, we stop).

In [3]:
def calc_split_pairs(splits: dict[str, list[str]], word_counts: dict[str, int]) -> tuple[tuple[str, str], int]:
    """Return the most frequent adjacent token pair across all words.

    Every pair inside a word is weighted by that word's corpus count, so the
    result reflects corpus frequency rather than the number of distinct words.
    The tally is rebuilt from scratch on each call, so it must be re-run after
    ``splits`` changes.

    Returns:
        ``(max_pair, max_count)``. Ties go to the pair encountered first while
        walking ``splits``. When no word has two or more tokens (or no pair has
        a positive count) the result is ``(None, 0)``.
    """
    pair_totals = {}
    for word, tokens in splits.items():
        if len(tokens) < 2:
            continue  # a lone token has no neighbor to pair with
        weight = word_counts[word]
        for adjacent in zip(tokens, tokens[1:]):
            pair_totals[adjacent] = pair_totals.get(adjacent, 0) + weight

    best_pair, best_count = max(
        pair_totals.items(), key=lambda item: item[1], default=(None, 0)
    )
    if best_count <= 0:
        return (None, 0)
    return (best_pair, best_count)

def merge_splits(splits: dict[str, list[str]], pair: tuple[str, str]) -> None:
    """Merge every adjacent occurrence of *pair*, in place, in every word's split.

    Matches are consumed left to right without overlap, e.g. merging
    ``("a", "a")`` turns ``["a", "a", "a"]`` into ``["aa", "a"]``. The token
    lists held by ``splits`` are mutated directly; nothing is returned.
    """
    # TODO: is there a more efficient query structure? An inverse map from
    # pairs -> words would avoid scanning every word on each merge.
    left, right = pair[0], pair[1]
    joined = left + right
    for tokens in splits.values():
        pos = 0
        # len(tokens) shrinks as merges happen, so re-check it every step.
        while pos < len(tokens) - 1:
            if tokens[pos] == left and tokens[pos + 1] == right:
                # replace the two neighbors with the single merged token
                tokens[pos:pos + 2] = [joined]
            pos += 1

So the general strategy is:
- From splits, generate a pair -> count table.
- Find the top-ranked pair (`max_pair`).
- Go back and update splits, replacing each occurrence of that pair with the merged token (e.g. `(t, h)` becomes `th`).
- Loop back to the top and find the next `max_pair`; repeat until the stopping criterion is met.

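Let's first try the target-vocabulary-size criterion. We will run until the vocabulary reaches 100 tokens (starting from 62 characters, that is 38 new merges, as long as every merge produces a token not already in the vocabulary).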

In [4]:
# splits are responsible for keeping track of how each word is split
# as an example: cat can be ["c", "a", "t"] or ["c", "at"], splits should represent
# the current state of our vocabulary. So if we add the "at" token to our vocab, the splits
# also need to get updated to reflect that, our character pair in turn is updated each time
# split is updated, so that after sorting we have the new largest pair

# intuitively one can see that ("c", "a"), will likely be more general than ("c", "at"), so BPE
# will likely "learn" shorter pairs first (but not always, depending on corpus).
# see: https://huggingface.co/learn/nlp-course/en/chapter6/5 for more info

splits = {word: list(word) for word in word_count_dict}

# Begin runs here
TARGET_VOCAB_SIZE = 100

vocab_set = char_set.copy()
it = 0
while len(vocab_set) < TARGET_VOCAB_SIZE:
    it += 1
    (max_pair, max_count) = calc_split_pairs(splits=splits, word_counts=word_count_dict)
    if max_pair is None:
        # calc_split_pairs returns (None, 0) once every word is a single token:
        # nothing is left to merge, so stop instead of failing on max_pair[0].
        break
    print(it, "max_pair, max_count", max_pair, max_count)
    # then merge relevant splits so we can calculate again.
    # (If the merged string is already in vocab_set, the set does not grow this
    # iteration and the loop simply continues with the next merge.)
    vocab_set.add(max_pair[0] + max_pair[1])
    merge_splits(splits=splits, pair=max_pair)

1 max_pair, max_count ('t', 'h') 9702736
2 max_pair, max_count ('i', 'n') 7911135
3 max_pair, max_count ('e', 'r') 6860409
4 max_pair, max_count ('a', 'n') 6479598
5 max_pair, max_count ('th', 'e') 6222698
6 max_pair, max_count ('o', 'n') 5550016
7 max_pair, max_count ('e', 'd') 4747943
8 max_pair, max_count ('r', 'e') 4108949
9 max_pair, max_count ('a', 't') 4096933
10 max_pair, max_count ('e', 'n') 3748855
11 max_pair, max_count ('o', 'r') 3699380
12 max_pair, max_count ('a', 'l') 3399189
13 max_pair, max_count ('s', 't') 3390231
14 max_pair, max_count ('a', 'r') 3330730
15 max_pair, max_count ('an', 'd') 3044037
16 max_pair, max_count ('a', 's') 3000753
17 max_pair, max_count ('o', 'f') 2985590
18 max_pair, max_count ('in', 'g') 2665518
19 max_pair, max_count ('e', 's') 2607832
20 max_pair, max_count ('t', 'o') 2587730
21 max_pair, max_count ('i', 's') 2566371
22 max_pair, max_count ('i', 't') 2385249
23 max_pair, max_count ('o', 'u') 2225909
24 max_pair, max_count ('i', 'c') 214961

The resulting vocabulary has 38 new elements (100 − 62). All that is left is to assign a unique index to each element. Imagine that the vocabulary is small enough that each element can be encoded in 1 byte. This vocabulary would then essentially "compress" the words it represents (e.g. instead of `t, h, e` taking 3 bytes, `the` could take one byte, or `t, he` two).

In [5]:
# Show the final vocabulary: the initial characters plus every learned merge.
print(f"resulting vocab set: {vocab_set}")

resulting vocab set: {'t', 'G', '5', 'N', 'r', 'the', 'q', 'v', 'X', 'V', 'ic', 'an', 're', 's', 'se', 'in', 'le', 'en', 'x', 'was', 'am', 'B', 'R', 'to', 'L', 'A', '9', 'w', 'ent', '4', 'j', 'ed', 'i', 'c', 'o', 'z', '7', 'ou', 'g', '1', 'u', 'f', 'O', 'U', 'er', 'on', 'm', 'p', 'ro', 'J', 'om', 'T', '6', 'E', 'al', 'as', 'ar', 'C', '2', 'K', 'Z', 'd', 'it', 'n', 'b', 'be', 'D', 'P', 'ly', '8', 'of', 'ion', 'M', 'W', 'ac', 'ing', '0', 'or', 'es', 'e', '3', 'S', 'he', 'Y', 'and', 'at', 'il', 'l', 'is', 'ad', 'H', 'I', 'k', 'F', 'st', 'y', 'a', 'th', 'h', 'Q'}


Now let's try an alternative method, where we keep adding merges until the frequency of the best pair falls below some threshold. To do this we write a variant of `calc_split_pairs` that returns an additional metric.

We define the expectation as $E = \frac{1}{\text{max\_count} / \text{total\_count}} = \frac{\text{total\_count}}{\text{max\_count}}$, where `total_count` is the total (occurrence-weighted) number of adjacent pairs in the corpus. If the expectation is 200, we should expect to see the pairing in roughly 1 out of every 200 pairs; an expectation of 1 means every pair is this pairing. So we keep merging as long as the expectation stays below some threshold `N`.

We have not fundamentally changed anything about the algorithm, since expectation in this sense is still a surrogate for popularity.

In [9]:
# Fresh per-word splits for this run. The "_freq" suffix keeps them separate from
# the previous run's `splits`, which merge_splits has already mutated in place.
splits_freq = {word: list(word) for word in word_count_dict}

# Begin runs here
N = 100  # capture every pairing that appears at least once every 100 pairings (in expectation)

vocab_set_freq = set(char_set)

def calc_split_pairs_freq(splits: dict[str, list[str]], word_counts: dict[str, int]) -> tuple[tuple[str, str], int, float]:
    """Find the most frequent adjacent token pair and how dominant it is.

    Like ``calc_split_pairs``, but also returns an "expectation": on average,
    one in every ``expectation`` adjacent pairs in the corpus is ``max_pair``.
    The tally is rebuilt from scratch, so call again after ``splits`` changes.

    Args:
        splits: word -> its current list of tokens.
        word_counts: word -> number of occurrences; must have an entry for every
            word in ``splits`` that has two or more tokens.

    Returns:
        ``(max_pair, max_count, expectation)`` with
        ``expectation = total_count / max_count``, where ``total_count`` is the
        occurrence-weighted number of adjacent pairs. Ties go to the pair seen
        first. When no pair has a positive count (e.g. every word is already a
        single token) the result is ``(None, 0, inf)`` instead of raising
        ZeroDivisionError, so a loop that stops once the expectation exceeds a
        threshold terminates cleanly.
    """
    split_pair_count = {}
    total_count = 0
    for word, split in splits.items():
        if len(split) < 2:
            continue  # a lone token has nothing to pair with
        word_count = word_counts[word]
        for pair in zip(split, split[1:]):
            split_pair_count[pair] = split_pair_count.get(pair, 0) + word_count
            # every pair occurrence also counts toward the denominator
            total_count += word_count

    max_pair = None
    max_count = 0
    for pair, count in split_pair_count.items():
        if count > max_count:
            max_pair = pair
            max_count = count

    if max_count == 0:
        return (None, 0, float("inf"))
    # total/max equals the original 1 / (max/total) with one fewer rounding step.
    return (max_pair, max_count, total_count / max_count)


# Keep merging until the best remaining pair is rarer than once every N pairs.
# Note: each iteration rescans every word twice (count, then merge), so this is
# slow on a large corpus.
expectation = 0
it = 0
while True:
    it += 1

    (max_pair, max_count, expectation) = calc_split_pairs_freq(splits=splits_freq, word_counts=word_count_dict)

    # Single stopping test: nothing left to merge, or the best pair is too rare.
    # Testing only here (not also in the `while` condition) means a pair whose
    # expectation is exactly N is merged AND the search continues; expectation is
    # not monotonic across iterations, so later pairs may still qualify.
    if max_pair is None or expectation > N:
        break

    vocab_set_freq.add(max_pair[0] + max_pair[1])
    print(it, "max_pair, expectation", max_pair, expectation)
    merge_splits(splits=splits_freq, pair=max_pair)

1 max_pair, expectation ('t', 'h') 33.9099562226572
2 max_pair, expectation ('i', 'n') 40.36293363720882
3 max_pair, expectation ('e', 'r') 45.391678834308564
4 max_pair, expectation ('a', 'n') 47.00061222933892
5 max_pair, expectation ('th', 'e') 47.899717293045555
6 max_pair, expectation ('o', 'n') 52.584132550248505
7 max_pair, expectation ('e', 'd') 60.29827253612775
8 max_pair, expectation ('r', 'e') 68.5199105659379
9 max_pair, expectation ('a', 't') 67.71794144546665
10 max_pair, expectation ('e', 'n') 72.91264559445484
11 max_pair, expectation ('o', 'r') 72.87439543923577
12 max_pair, expectation ('a', 'l') 78.2218055542072
13 max_pair, expectation ('s', 't') 77.4258485631215
14 max_pair, expectation ('a', 'r') 77.79113917969934
15 max_pair, expectation ('an', 'd') 84.02346981984779
16 max_pair, expectation ('a', 's') 84.2210318543379
17 max_pair, expectation ('o', 'f') 83.64368885211968
18 max_pair, expectation ('in', 'g') 92.56743754872412
19 max_pair, expectation ('e', 's') 

KeyboardInterrupt: 