# Byte-Pair Encoding

Byte Pair Encoding (BPE) is a subword tokenization algorithm widely used in Natural Language Processing (NLP), particularly in language models such as GPT and RoBERTa (BERT uses the closely related WordPiece algorithm instead). It addresses out-of-vocabulary words and oversized vocabularies by representing words as sequences of smaller, frequently occurring units called subwords.

How BPE Works:
1. Initialization: The process begins by considering each unique character in the training text as an initial token in the vocabulary.

2. Iterative Merging: The core of BPE is repeatedly finding and merging the most frequent adjacent pair of tokens in the training data.
The algorithm identifies the pair of characters or subwords that appear most frequently next to each other.
This most frequent pair is then merged into a new, single token, and this new token is added to the vocabulary.
All occurrences of the merged pair in the training data are replaced with the new token.

3. Vocabulary Expansion: Merging continues for a predefined number of iterations or until the vocabulary reaches a target size. Each merge adds exactly one new, longer subword unit to the vocabulary.

4. Tokenization of New Text: To tokenize new text, the merges learned during training are applied to it in the same order in which they were learned. This lets BPE handle unseen words by breaking them into subword units that are already in the vocabulary, as long as their individual characters were seen during training.
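The training steps above can be condensed into one function. This is a minimal sketch under a few assumptions: words arrive already split on whitespace together with their counts, ties between equally frequent pairs go to the pair seen first, and `learn_bpe` and the toy word counts are illustrative, not part of the notebook below.

```python
from collections import Counter

END_OF_WORD = "</w>"

def learn_bpe(word_counts, num_merges):
    """Sketch of BPE training: returns the merged pairs in the order they were learned."""
    # Step 1: every word becomes a tuple of characters plus an end-of-word marker.
    splits = {tuple(word) + (END_OF_WORD,): count for word, count in word_counts.items()}
    merges = []
    for _ in range(num_merges):
        # Step 2a: count adjacent pairs, weighted by how often each word occurs.
        pair_counts = Counter()
        for symbols, count in splits.items():
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += count
        if not pair_counts:
            break  # every word is already a single symbol
        # Step 2b: pick the most frequent pair (max keeps the first one seen on ties).
        best = max(pair_counts, key=pair_counts.get)
        merges.append(best)
        # Step 2c: replace every occurrence of that pair with one merged token.
        new_splits = {}
        for symbols, count in splits.items():
            out, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    out.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            key = tuple(out)
            new_splits[key] = new_splits.get(key, 0) + count
        splits = new_splits
    return merges

print(learn_bpe({"low": 5, "lower": 2, "newest": 6, "widest": 3}, 5))
# [('e', 's'), ('es', 't'), ('est', '</w>'), ('l', 'o'), ('lo', 'w')]
```

The suffix `est</w>` is learned before `low` because `es`, `est` and `est</w>` each occur 9 times (in "newest" and "widest"), while `lo` occurs only 7 times.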

Key Advantages of BPE:

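1. Handling Out-of-Vocabulary (OOV) words: BPE can represent rare or unseen words by decomposing them into smaller, known subword units, so an unknown token is needed only for characters that never appeared in the training data. Byte-level BPE, as used by GPT-2 and RoBERTa, removes even that case by starting from the 256 possible byte values.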

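2. Reduced Vocabulary Size: The final vocabulary size is the number of base symbols plus the number of merges, so it is fixed by a hyperparameter and can be kept far smaller than a word-level vocabulary, which needs an entry for every distinct word form. A smaller vocabulary means smaller embedding and output layers.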

3. Balance between Word and Character Level: Frequent words typically end up as single tokens while rare words are split into a few pieces, so sequences stay much shorter than with character-level tokenization, without the unbounded vocabulary of word-level tokenization.

In [1]:
corpus = [
    "This is the first document.",
    "This document is the second document.",
    "And this is the third one.",
    "Is this the first document?",
]

In [2]:
print("Training Corpus:")
for doc in corpus:
    print(doc)

Training Corpus:
This is the first document.
This document is the second document.
And this is the third one.
Is this the first document?


In [4]:
# Initialize vocabulary with unique characters
unique_chars = set()
for doc in corpus:
    for char in doc:
        unique_chars.add(char)

unique_chars

{' ',
 '.',
 '?',
 'A',
 'I',
 'T',
 'c',
 'd',
 'e',
 'f',
 'h',
 'i',
 'm',
 'n',
 'o',
 'r',
 's',
 't',
 'u'}

In [5]:
vocab = list(unique_chars)
vocab.sort()
vocab

[' ',
 '.',
 '?',
 'A',
 'I',
 'T',
 'c',
 'd',
 'e',
 'f',
 'h',
 'i',
 'm',
 'n',
 'o',
 'r',
 's',
 't',
 'u']
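The two cells above can be collapsed into a single expression. A minimal alternative sketch: joining the corpus into one string and taking its `set` yields the same unique characters, and `sorted` already returns a new list, so no separate `.sort()` call is needed.

```python
corpus = [
    "This is the first document.",
    "This document is the second document.",
    "And this is the third one.",
    "Is this the first document?",
]

# set() over one long string keeps each distinct character once; sorted() returns a list.
vocab = sorted(set("".join(corpus)))
print(len(vocab))  # 19
print(vocab[:5])   # [' ', '.', '?', 'A', 'I']
```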

In [6]:
end_of_word = "</w>"
vocab.append(end_of_word)

print("Initial Vocabulary:")
print(vocab)
print(f"Vocabulary Size: {len(vocab)}")


Initial Vocabulary:
[' ', '.', '?', 'A', 'I', 'T', 'c', 'd', 'e', 'f', 'h', 'i', 'm', 'n', 'o', 'r', 's', 't', 'u', '</w>']
Vocabulary Size: 20


In [7]:
# Pre-tokenize the corpus: split each document into words, then each word into characters.
# For simplicity we split on single spaces (punctuation stays attached to words) and append the end-of-word token.
word_splits = {}
for doc in corpus:
    words = doc.split(' ')
    for word in words:
        if word:  # skip empty strings produced by consecutive spaces
            char_list = list(word) + [end_of_word]
            # Tuples are immutable and therefore hashable, so they can be used as dictionary keys.
            word_tuple = tuple(char_list)
            if word_tuple not in word_splits:
                word_splits[word_tuple] = 0
            word_splits[word_tuple] += 1  # count how often each word (as a character sequence) occurs

print("\nPre-tokenized Word Frequencies:")
print(word_splits)


Pre-tokenized Word Frequencies:
{('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 3, ('t', 'h', 'e', '</w>'): 4, ('f', 'i', 'r', 's', 't', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '</w>'): 1, ('s', 'e', 'c', 'o', 'n', 'd', '</w>'): 1, ('A', 'n', 'd', '</w>'): 1, ('t', 'h', 'i', 's', '</w>'): 2, ('t', 'h', 'i', 'r', 'd', '</w>'): 1, ('o', 'n', 'e', '.', '</w>'): 1, ('I', 's', '</w>'): 1, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '?', '</w>'): 1}
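The same frequency table can be built with `collections.Counter`. A minimal sketch: `Counter` is a `dict` subclass, so the result can be passed to the helpers below unchanged, and its keys keep first-seen order, which matters because ties in the merge loop are broken by dictionary order. One deliberate difference from the cell above: `str.split()` with no argument splits on runs of whitespace and never yields empty strings, so the `if word:` guard is unnecessary.

```python
from collections import Counter

corpus = [
    "This is the first document.",
    "This document is the second document.",
    "And this is the third one.",
    "Is this the first document?",
]
end_of_word = "</w>"

# Each word becomes a character tuple plus the end-of-word marker; Counter tallies repeats.
word_splits = Counter(
    tuple(word) + (end_of_word,)
    for doc in corpus
    for word in doc.split()
)

print(word_splits[("t", "h", "e", end_of_word)])  # 4
print(len(word_splits))  # 13
```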


## Helper Functions

In [9]:
import collections

In [13]:
def get_pair_stats(splits):
    """Counts the frequency of adjacent symbol pairs across all word splits."""
    # defaultdict(int) behaves like a dict, but a missing key is created on first access
    # with the default value int() == 0, so `pair_counts[pair] += freq` never raises KeyError.
    pair_counts = collections.defaultdict(int)
    for word_tuple, freq in splits.items():
        symbols = list(word_tuple)
        for i in range(len(symbols) - 1):
            pair = (symbols[i], symbols[i + 1])
            # A word that occurs `freq` times contributes `freq` occurrences of each of its pairs.
            pair_counts[pair] += freq

    return pair_counts

In [26]:
get_pair_stats(splits=word_splits)

defaultdict(int,
            {('T', 'h'): 2,
             ('h', 'i'): 5,
             ('i', 's'): 7,
             ('s', '</w>'): 8,
             ('t', 'h'): 7,
             ('h', 'e'): 4,
             ('e', '</w>'): 4,
             ('f', 'i'): 2,
             ('i', 'r'): 3,
             ('r', 's'): 2,
             ('s', 't'): 2,
             ('t', '</w>'): 3,
             ('d', 'o'): 4,
             ('o', 'c'): 4,
             ('c', 'u'): 4,
             ('u', 'm'): 4,
             ('m', 'e'): 4,
             ('e', 'n'): 4,
             ('n', 't'): 4,
             ('t', '.'): 2,
             ('.', '</w>'): 3,
             ('s', 'e'): 1,
             ('e', 'c'): 1,
             ('c', 'o'): 1,
             ('o', 'n'): 2,
             ('n', 'd'): 2,
             ('d', '</w>'): 3,
             ('A', 'n'): 1,
             ('r', 'd'): 1,
             ('n', 'e'): 1,
             ('e', '.'): 1,
             ('I', 's'): 1,
             ('t', '?'): 1,
             ('?', '</w>'): 1})
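The same counts can be produced more compactly by zipping each symbol tuple with itself shifted by one position. A minimal sketch: `get_pair_stats_zip` is a hypothetical name, and `Counter` stands in for `defaultdict(int)` (both return 0 for a missing key, but a plain lookup on a `Counter` does not insert the key).

```python
from collections import Counter

def get_pair_stats_zip(splits):
    """Counts adjacent symbol pairs, weighted by word frequency (hypothetical helper)."""
    pair_counts = Counter()
    for symbols, freq in splits.items():
        # zip(symbols, symbols[1:]) yields (s0, s1), (s1, s2), ..., i.e. every adjacent pair.
        for pair in zip(symbols, symbols[1:]):
            pair_counts[pair] += freq
    return pair_counts

sample = {("T", "h", "i", "s", "</w>"): 2, ("i", "s", "</w>"): 3}
stats = get_pair_stats_zip(sample)
print(stats[("i", "s")])     # 5
print(stats[("s", "</w>")])  # 5
print(stats[("T", "h")])     # 2
```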

In [47]:
def merge_pair(pair_to_merge, splits):
    """Merges the specified pair in the word splits."""
    new_splits = {}
    (first, second) = pair_to_merge
    merged_token = first + second
    for word_tuple, freq in splits.items():
        symbols = list(word_tuple)
        new_symbols = []
        i = 0
        while i < len(symbols):
            # If the current and next symbol match the pair to merge
            if i < len(symbols) - 1 and symbols[i] == first and symbols[i+1] == second:
                new_symbols.append(merged_token)
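                i += 2  # both symbols of the pair are consumed
            else:
                new_symbols.append(symbols[i])
                i += 1
        new_key = tuple(new_symbols)
        # Accumulate rather than overwrite, in case two words ever end up with the same split.
        new_splits[new_key] = new_splits.get(new_key, 0) + freq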
    return new_splits
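Because `merge_pair` scans each word left to right and jumps two positions after a merge, occurrences of a pair never overlap: in a run like `a a a`, only the first two symbols are merged. The single-word helper below (a hypothetical `merge_symbols`, written only for illustration) isolates that inner loop so the behavior can be checked directly.

```python
def merge_symbols(symbols, pair):
    """Merges non-overlapping occurrences of `pair` in one symbol tuple, left to right."""
    first, second = pair
    merged, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and symbols[i] == first and symbols[i + 1] == second:
            merged.append(first + second)
            i += 2  # both symbols of the pair are consumed
        else:
            merged.append(symbols[i])
            i += 1
    return tuple(merged)

print(merge_symbols(("t", "h", "e", "</w>"), ("t", "h")))  # ('th', 'e', '</w>')
print(merge_symbols(("a", "a", "a", "</w>"), ("a", "a")))  # ('aa', 'a', '</w>')
```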

## Iterative BPE Merging Loop

Now we perform the core BPE training. We'll loop for a fixed number of merges (`num_merges`). In each iteration:
1. Calculate the frequencies of all adjacent pairs in the current word representations using `get_pair_stats`.
2. Find the pair with the highest frequency (`best_pair`).
3. Merge this `best_pair` across all word representations using `merge_pair`.
4. Add the newly formed token (concatenation of `best_pair`) to our vocabulary (`vocab`).
5. Store the merge rule (mapping the pair to the new token) in the `merges` dictionary.

We'll add print statements to observe the state at each step of the loop.


In [48]:
# --- BPE Training Loop Initialization ---
num_merges = 15
# Stores merge rules, e.g., {('a', 'b'): 'ab'}
# Example: {('T', 'h'): 'Th'}
merges = {}
# Initial word splits: {('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 3, ...}
current_splits = word_splits.copy() # Start with initial word splits

print("\n--- Starting BPE Merges ---")
print(f"Initial Splits: {current_splits}")
print("-" * 30)

for i in range(num_merges):
    print(f"\nMerge Iteration {i+1}/{num_merges}")

    # 1. Calculate Pair Frequencies
    pair_stats = get_pair_stats(current_splits)
    if not pair_stats:
        print("No more pairs to merge.")
        break
    # Optional: Print top 5 pairs for inspection
    sorted_pairs = sorted(pair_stats.items(), key=lambda item: item[1], reverse=True)
    print(f"Top 5 Pair Frequencies: {sorted_pairs[:5]}")

    # 2. Find Best Pair
    # Iterating over a dict yields its keys, so 'max' walks over the pairs in 'pair_stats'.
    # 'key=pair_stats.get' makes 'max' compare each pair's frequency (value), not the pair itself.
    # If several pairs tie, 'max' returns the first one in the dictionary's insertion order.
    best_pair = max(pair_stats, key=pair_stats.get)
    best_freq = pair_stats[best_pair]
    print(f"Found Best Pair: {best_pair} with Frequency: {best_freq}")

    # 3. Merge the Best Pair
    current_splits = merge_pair(best_pair, current_splits)
    new_token = best_pair[0] + best_pair[1]
    print(f"Merging {best_pair} into '{new_token}'")
    print(f"Splits after merge: {current_splits}")

    # 4. Update Vocabulary
    vocab.append(new_token)
    print(f"Updated Vocabulary: {vocab}")

    # 5. Store Merge Rule
    merges[best_pair] = new_token
    print(f"Updated Merges: {merges}")

    print("-" * 30)


--- Starting BPE Merges ---
Initial Splits: {('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 3, ('t', 'h', 'e', '</w>'): 4, ('f', 'i', 'r', 's', 't', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '</w>'): 1, ('s', 'e', 'c', 'o', 'n', 'd', '</w>'): 1, ('A', 'n', 'd', '</w>'): 1, ('t', 'h', 'i', 's', '</w>'): 2, ('t', 'h', 'i', 'r', 'd', '</w>'): 1, ('o', 'n', 'e', '.', '</w>'): 1, ('I', 's', '</w>'): 1, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '?', '</w>'): 1}
------------------------------

Merge Iteration 1/15
Top 5 Pair Frequencies: [(('s', '</w>'), 8), (('i', 's'), 7), (('t', 'h'), 7), (('h', 'i'), 5), (('h', 'e'), 4)]
Found Best Pair: ('s', '</w>') with Frequency: 8
Merging ('s', '</w>') into 's</w>'
Splits after merge: {('T', 'h', 'i', 's</w>'): 2, ('i', 's</w>'): 3, ('t', 'h', 'e', '</w>'): 4, ('f', 'i', 'r', 's', 't', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '</w>'): 2, ('d', 'o', 'c', 'u', '
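Iteration 1 already shows two runners-up tied at 7, ('i', 's') and ('t', 'h'), and after the first merge ('i', 's</w>') and ('t', 'h') tie for the top spot. Since `max` returns the first maximal item it encounters, ties are broken by dictionary insertion order, that is, by where a pair first appears in the corpus. The snippet below uses the leading counts going into iteration 2 (derived from the splits after the first merge, other pairs omitted) and contrasts this with one possible explicit rule; the lexicographic tie-break is a hypothetical alternative, not what this notebook does.

```python
# Leading pair counts going into merge iteration 2, in the order the pairs were first seen.
pair_stats = {("T", "h"): 2, ("h", "i"): 5, ("i", "s</w>"): 7, ("t", "h"): 7}

# max() returns the first maximal key in iteration (insertion) order.
best_pair = max(pair_stats, key=pair_stats.get)
print(best_pair)  # ('i', 's</w>')

# Hypothetical alternative: break ties by the pair itself, so the result no longer
# depends on dictionary order (here the lexicographically largest pair wins).
best_pair_lex = max(pair_stats, key=lambda pair: (pair_stats[pair], pair))
print(best_pair_lex)  # ('t', 'h')
```

The first result matches the second learned merge, `('i', 's</w>') -> 'is</w>'`, in the final output.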

## Review Final Results

After the loop finishes, we can examine the final state:
- The learned merge rules (`merges`).
- The final representation of words after merges (`current_splits`).
- The complete vocabulary (`vocab`) containing initial characters and learned subword tokens.


In [49]:
print("\n--- BPE Merges Complete ---")
print(f"Final Vocabulary Size: {len(vocab)}")
print("\nLearned Merges (Pair -> New Token):")
# Pretty print merges
for pair, token in merges.items():
    print(f"{pair} -> '{token}'")

print("\nFinal Word Splits after all merges:")
print(current_splits)

print("\nFinal Vocabulary (sorted):")
# Sort for consistent viewing
final_vocab_sorted = sorted(list(set(vocab))) # Use set to remove potential duplicates if any step introduced them
print(final_vocab_sorted)



--- BPE Merges Complete ---
Final Vocabulary Size: 35

Learned Merges (Pair -> New Token):
('s', '</w>') -> 's</w>'
('i', 's</w>') -> 'is</w>'
('t', 'h') -> 'th'
('th', 'e') -> 'the'
('the', '</w>') -> 'the</w>'
('d', 'o') -> 'do'
('do', 'c') -> 'doc'
('doc', 'u') -> 'docu'
('docu', 'm') -> 'docum'
('docum', 'e') -> 'docume'
('docume', 'n') -> 'documen'
('documen', 't') -> 'document'
('i', 'r') -> 'ir'
('.', '</w>') -> '.</w>'
('d', '</w>') -> 'd</w>'

Final Word Splits after all merges:
{('T', 'h', 'is</w>'): 2, ('is</w>',): 3, ('the</w>',): 4, ('f', 'ir', 's', 't', '</w>'): 2, ('document', '.</w>'): 2, ('document', '</w>'): 1, ('s', 'e', 'c', 'o', 'n', 'd</w>'): 1, ('A', 'n', 'd</w>'): 1, ('th', 'is</w>'): 2, ('th', 'ir', 'd</w>'): 1, ('o', 'n', 'e', '.</w>'): 1, ('I', 's</w>'): 1, ('document', '?', '</w>'): 1}

Final Vocabulary (sorted):
[' ', '.', '.</w>', '</w>', '?', 'A', 'I', 'T', 'c', 'd', 'd</w>', 'do', 'doc', 'docu', 'docum', 'docume', 'documen', 'document', 'e', 'f', 'h', '
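To tokenize new text, the learned merges are replayed on each word in the order they were learned. The sketch below copies the 15 merges from the "Learned Merges" output above; `encode_word` is a hypothetical helper, and whitespace pre-tokenization is assumed as during training. Note that a character never seen in training (say `x`) would simply pass through as its own symbol here, although it is not in the vocabulary; a real tokenizer would map it to an unknown token or, with byte-level BPE, to byte tokens.

```python
# The 15 merges printed in the "Learned Merges" output above, in learned order.
learned_merges = [
    ("s", "</w>"), ("i", "s</w>"), ("t", "h"), ("th", "e"), ("the", "</w>"),
    ("d", "o"), ("do", "c"), ("doc", "u"), ("docu", "m"), ("docum", "e"),
    ("docume", "n"), ("documen", "t"), ("i", "r"), (".", "</w>"), ("d", "</w>"),
]

def encode_word(word, merges, end_of_word="</w>"):
    """Tokenizes one word by applying each learned merge, in order, to its characters."""
    symbols = list(word) + [end_of_word]
    for first, second in merges:
        merged, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and symbols[i] == first and symbols[i + 1] == second:
                merged.append(first + second)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols

for word in ["this", "documents", "these"]:
    print(word, "->", encode_word(word, learned_merges))
# this -> ['th', 'is</w>']
# documents -> ['document', 's</w>']
# these -> ['the', 's', 'e', '</w>']
```

"documents" and "these" never occur in the training corpus, yet both are split into units that are already in the vocabulary.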