# Byte Pair Encoding
- Start with a small vocabulary (for ease of use, let's assume only ASCII characters)
- "Learn" new vocabulary by scanning through the corpus, using the frequency with which each sub-sequence appears to decide which new vocabulary entries are most important

For simplicity, let's limit our vocabulary to only uppercase and lowercase ASCII letters, plus digits. We will first purge the corpus of all other symbols while preserving spaces. Note that this is not a realistic scenario.

In [4]:
# Imports
import os
import re
from collections import Counter

import numpy as np
import pandas as pd
import pyarrow.parquet as pq

In [5]:
# Load corpus, read all lines into memory, clean, then extract all characters, whitespace symbols
from pathlib import Path
from typing import List


# Relative path of the directory holding the wikitext-103-raw-v1 parquet shards
WIKITEXT_PATH = "../data/Salesforce/wikitext/wikitext-103-raw-v1"
# File names of the two training shards
TRAIN_SET = ["train-00000-of-00002.parquet", "train-00001-of-00002.parquet"]
TRAIN_SET_PATHS = [os.path.join(WIKITEXT_PATH, shard) for shard in TRAIN_SET]

# Read both shards as one dataset and convert the result to a pandas DataFrame
corpus = (
    pq.ParquetDataset(TRAIN_SET_PATHS)
    .read()
    .to_pandas()
)

# Anything that is not an ASCII letter or digit is treated as a separator
# (despite its name, this pattern matches far more than whitespace).
clear_ws_pattern = re.compile(r"[^a-zA-Z0-9]")


def clean_corpus(txt: str) -> str:
    """Reduce ``txt`` to ASCII alphanumeric words separated by single spaces.

    Every character outside ``[a-zA-Z0-9]`` is replaced by a space; runs of
    spaces are then collapsed and leading/trailing spaces are removed.
    """
    # split() with no argument already discards leading, trailing and repeated spaces
    words = clear_ws_pattern.sub(' ', txt).split()
    return ' '.join(words)

# Store a cleaned copy of every row's text in a new "text_clean" column
corpus["text_clean"] = corpus["text"].map(clean_corpus)

# Running tally of how often each character occurs (its "importance").
# It starts out empty so that only characters actually seen end up as keys.
char_count_dict: dict[str, int] = {}


def count_chars(txt: str) -> None:
    """Add the character frequencies of ``txt`` to ``char_count_dict``.

    Every character of ``txt`` is counted, spaces included (a bag of
    characters). The module-level dict is updated in place; nothing is
    returned.
    """
    # Counter tallies the whole string in one call; the per-row totals are then
    # folded into the running tally, preserving first-seen key order.
    for char, occurrences in Counter(txt).items():
        char_count_dict[char] = char_count_dict.get(char, 0) + occurrences

# The cleaning is done, so we can start applying character and word counts.
# Count characters over the whole corpus (not _really_ needed, just for viz)
for txt in corpus["text_clean"]:
    count_chars(txt)

# Remove all rows that are empty after cleaning
corpus = corpus[corpus["text_clean"].str.len() > 0]

# The space only separates words and is not a vocabulary symbol, so drop its count.
# pop() with a default avoids a KeyError when the corpus contains no spaces at all.
char_count_dict.pop(' ', None)

# Sort into descending order of count.
# Note: the sorting is purely for visualization, it is not needed
char_count_dict = dict(sorted(char_count_dict.items(), key=lambda x: x[1], reverse=True))

# This forms our initial character set.
# Note: deriving it from the corpus (instead of hard-coding it) is a leftover
# from when the character set originally extended beyond just alphanumerics.
char_set = set(char_count_dict)

# Our initial word set: word -> number of occurrences (basically bag-of-words),
# built the same way as char_count_dict.
word_count_dict: dict[str, int] = {}


def count_words(txt: str) -> None:
    """Add the word frequencies of ``txt`` to ``word_count_dict``.

    Words are separated by whitespace. Empty tokens caused by an empty
    ``txt`` or by leading, trailing or repeated spaces are ignored, so the
    empty string is never recorded as a word. The module-level dict is
    updated in place; nothing is returned.
    """
    # split() without an argument drops empty tokens, unlike split(' ')
    for word in txt.split():
        word_count_dict[word] = word_count_dict.get(word, 0) + 1

# Tally word frequencies over every row of the cleaned corpus
for txt in corpus["text_clean"]:
    count_words(txt)

# Note: "corpus size" printed here is the number of distinct words
print("corpus size:", len(word_count_dict))
print("initial vocab size:", len(char_count_dict))

# Should be at most 62 symbols (a-z, A-Z, 0-9: 26 + 26 + 10)
print("initial vocab:", char_set)

corpus size: 575352
initial vocab size: 62
initial vocab: {'g', 'T', 'x', 'i', 'N', 'z', 'v', 'Z', 'm', 'B', 'u', 'E', 'W', 'p', 'o', '3', 't', '9', '7', 'b', 'q', '1', '8', 'a', 'j', 'F', '5', '4', 'X', 'Y', '2', 'J', '0', 'l', 'C', 'w', 'd', 'S', 'y', 'R', 'Q', 'A', 'c', 'D', 'H', 'k', 'U', 'M', '6', 'O', 'G', 'e', 'P', 'r', 's', 'I', 'V', 'L', 'h', 'f', 'n', 'K'}


Now we apply the "merging" strategy: repeatedly merge the most frequent pair of adjacent characters (or previously merged tokens) in the corpus into a new token, and then replace every occurrence of that pair with the new token.

As an example: `m,a,t`, may become `m,at` if `at` is popular enough.

To give an example, suppose we have example words: `cat`, `sat`, `hat`, `brat`, `broken`. We can see that common character pairs are `(a, t)`, `(b, r)`
both of which appear more than once. When we apply a "merge", we will merge the most popular character pairs.

The stopping criterion is typically that the vocabulary size reaches a desired limit (e.g. 1000), but I'd also like to try a stopping criterion based on the frequency of each pair (e.g. if the next best pair `(x, y)` has a frequency below 1% of the total corpus, we stop).

In [6]:
def calc_split_pairs(splits: dict[str, list[str]], word_counts: dict[str, int]) -> tuple[tuple[str, str], int]:
    """Find the most frequent adjacent token pair across all word splits.

    Every adjacent pair inside a word's split is weighted by that word's
    count in ``word_counts``. The table is rebuilt from scratch, so this must
    be called again each time ``splits`` changes.

    Returns ``(pair, count)`` for the highest-weighted pair; on a tie the
    pair encountered first wins. When no pair has a positive weight (for
    example every word is already a single token) the result is ``(None, 0)``.
    """
    pair_weights = {}
    for word, tokens in splits.items():
        # zip against the list shifted by one to walk neighbouring tokens
        for neighbours in zip(tokens, tokens[1:]):
            pair_weights[neighbours] = pair_weights.get(neighbours, 0) + word_counts[word]

    # max() keeps the first of several equal maxima, i.e. the earliest pair seen
    best = max(pair_weights.items(), key=lambda entry: entry[1], default=(None, 0))
    if best[1] > 0:
        return best
    return (None, 0)

def merge_splits(splits: dict[str, list[str]], pair: tuple[str, str]) -> None:
    """Fuse every adjacent occurrence of ``pair`` into one token, in place.

    Each split list in ``splits`` is scanned left to right. Wherever
    ``pair[0]`` is immediately followed by ``pair[1]`` the two entries are
    replaced by their concatenation. The lists are mutated directly and
    nothing is returned.
    """
    # TODO: is there a more efficient query structure? perhaps an inverse map
    # from pairs -> words?
    left, right = pair[0], pair[1]
    fused = left + right
    for tokens in splits.values():
        pos = 0
        # the list shrinks whenever we merge, so the bound is re-read each step
        while pos < len(tokens) - 1:
            if tokens[pos] == left and tokens[pos + 1] == right:
                # collapse the two matching entries into the fused token
                tokens[pos:pos + 2] = [fused]
            pos += 1

So the general strategy is:
- From splits, generate a pair -> count table
- Find the top ranked pair (max_pair)
- Go back and update splits, replacing each occurrence of the pair with the merged token (ex. `(t, h)` becomes `th`)
- Loop back to top, and find next max_pair, do this until criteria is met.

Let's first try the target vocabulary criterion. We will run until we hit 100 vocabulary entries (starting from the 62 single characters, that is roughly 40 new merges).

In [None]:
# splits are responsible for keeping track of how each word is split
# as an example: cat can be ["c", "a", "t"] or ["c", "at"], splits should represent
# the current state of our vocabulary. So if we add the "at" token to our vocab, the splits
# also need to get updated to reflect that. The pair counts are in turn recomputed each time
# splits is updated, so that we always pick the new most frequent pair

# intuitively one can see that ("c", "a"), will likely be more general than ("c", "at"), so BPE
# will likely "learn" shorter pairs first (but not always, depending on corpus).
# see: https://huggingface.co/learn/nlp-course/en/chapter6/5 for more info

# Every word starts out split into its individual characters
splits = {word: list(word) for word in word_count_dict}

# Begin runs here
TARGET_VOCAB_SIZE = 100

vocab_set = char_set.copy()
it = 0
while len(vocab_set) < TARGET_VOCAB_SIZE:
    (max_pair, max_count) = calc_split_pairs(splits=splits, word_counts=word_count_dict)
    if max_pair is None:
        # calc_split_pairs found no pair to merge (e.g. every word is already a
        # single token); stop instead of failing on max_pair[0] below
        print("no more pairs to merge, stopping at vocab size", len(vocab_set))
        break
    it += 1
    print(it, "max_pair, max_count", max_pair, max_count)
    # then merge relevant splits so we can calculate again
    vocab_set.add(max_pair[0] + max_pair[1])
    merge_splits(splits=splits, pair=max_pair)

1 max_pair, max_count ('ar', 'k') 138393
2 max_pair, max_count ('A', 'n') 138087
3 max_pair, max_count ('c', 'ent') 137841
4 max_pair, max_count ('ab', 'le') 136648
5 max_pair, max_count ('ad', 'e') 136536
6 max_pair, max_count ('m', 'on') 136099
7 max_pair, max_count ('the', 'y') 135793
8 max_pair, max_count ('in', 'cl') 134373
9 max_pair, max_count ('u', 'e') 134268
10 max_pair, max_count ('it', 'e') 134267
11 max_pair, max_count ('an', 'g') 134086
12 max_pair, max_count ('a', 'il') 133774
13 max_pair, max_count ('U', 'n') 132786
14 max_pair, max_count ('wor', 'k') 132091
15 max_pair, max_count ('e', 'ver') 131766
16 max_pair, max_count ('c', 'ed') 131612
17 max_pair, max_count ('f', 'f') 131229
18 max_pair, max_count ('a', 'ir') 130917
19 max_pair, max_count ('c', 'r') 130712
20 max_pair, max_count ('s', 'ion') 130606
21 max_pair, max_count ('cor', 'd') 130382
22 max_pair, max_count ('tim', 'e') 130098


KeyboardInterrupt: 