<a href="https://colab.research.google.com/github/zedware/notebook/blob/master/Byte_Pair_Encoding_(BPE)_Algorithm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import re
import collections

def get_word_counts(text):
    """
    Pre-tokenize ``text`` and count how often each pre-token occurs.

    The text is split into maximal runs of word characters (letters, digits,
    underscore) and maximal runs of punctuation, i.e. characters that are
    neither whitespace nor word characters (so "!!" is one pre-token, not
    two). Whitespace only separates pre-tokens and is dropped.

    Returns:
        collections.Counter: pre-token -> frequency, in first-seen order.
    """
    pre_tokens = re.findall(r"\w+|[^\s\w]+", text)
    frequencies = collections.Counter(pre_tokens)
    return frequencies

def initialize_splits(word_counts):
    """
    Build the starting segmentation: each word maps to its list of characters.

    Only the keys of ``word_counts`` are used; the counts are ignored here.
    Example: {"low": 3} -> {"low": ['l', 'o', 'w']}
    """
    splits = {}
    for word in word_counts:
        splits[word] = [ch for ch in word]
    return splits

def get_pair_frequencies(splits, word_counts):
    """
    Count adjacent token pairs across the corpus, weighted by word frequency.

    For every word in ``word_counts``, each neighbouring pair in its current
    segmentation ``splits[word]`` contributes that word's count.

    Returns:
        collections.defaultdict(int): ``(left, right)`` token tuple -> total
        count, in the order pairs were first encountered. This order matters
        because ``max()`` in train_bpe returns the first of several equally
        frequent pairs.
    """
    counts = collections.defaultdict(int)
    for word, weight in word_counts.items():
        segmentation = splits[word]
        # zip against the shifted list yields each adjacent (left, right) pair.
        for left, right in zip(segmentation, segmentation[1:]):
            counts[(left, right)] += weight
    return counts

def merge_pair(best_pair, splits):
    """
    Return a new splits dict with every occurrence of ``best_pair`` merged.

    Occurrences are merged left to right without overlap. With
    best_pair = ('l', 'o'), ['l', 'o', 'w'] becomes ['lo', 'w']; with
    best_pair = ('a', 'a'), ['a', 'a', 'a'] becomes ['aa', 'a'].
    The input ``splits`` and its token lists are not modified.
    """
    merged = "".join(best_pair)

    def _merge_tokens(tokens):
        # Scan once, collapsing each matching adjacent pair into ``merged``.
        out = []
        pos = 0
        end = len(tokens)
        while pos < end:
            if pos + 1 < end and (tokens[pos], tokens[pos + 1]) == best_pair:
                out.append(merged)
                pos += 2  # both tokens of the pair are consumed
                continue
            out.append(tokens[pos])
            pos += 1
        return out

    return {word: _merge_tokens(tokens) for word, tokens in splits.items()}

def train_bpe(text, vocab_size, *, verbose=True):
    """
    Trains a BPE tokenizer from a text corpus.

    The corpus is pre-tokenized into words and each word starts as a list of
    characters. Each iteration recounts adjacent token pairs (weighted by word
    frequency), merges the most frequent pair everywhere, and records the rule.
    Ties are resolved by ``max()``, i.e. in favour of the pair first
    encountered while counting.

    Args:
        text (str): The input text corpus.
        vocab_size (int): The target final vocabulary size. Training performs
            at most ``vocab_size - <number of distinct characters>`` merges.
            If ``vocab_size`` is smaller than the number of distinct
            characters, no merges happen and the returned vocabulary is larger
            than ``vocab_size``. If the corpus runs out of adjacent pairs
            first, training stops early with a smaller vocabulary.
        verbose (bool): Keyword-only. When True (the default), print each
            learned merge and an early-stop notice.

    Returns:
        tuple: (merges, vocab)
            merges (dict): Merge rules in the order they were learned, e.g.
                {('l', 'o'): 'lo'}. Insertion order is the merge priority
                used by encode.
            vocab (set): The final vocabulary of tokens.
    """

    # --- 1. Pre-tokenization and Initialization ---

    word_counts = get_word_counts(text)

    # Initially every word is split into its individual characters.
    splits = initialize_splits(word_counts)

    # The base vocabulary is every distinct character in the corpus
    # (iterating a string yields its characters).
    vocab = set()
    for word in word_counts:
        vocab.update(word)

    # --- 2. Iterative Merging ---

    # Non-positive when vocab_size <= alphabet size; range() is then empty.
    num_merges = vocab_size - len(vocab)

    # Python 3.7+ dicts keep insertion order, which preserves merge priority
    # for encoding.
    merges = {}

    for i in range(num_merges):
        pair_freqs = get_pair_frequencies(splits, word_counts)

        # Every word is a single token: nothing left to merge.
        if not pair_freqs:
            if verbose:
                print("No more pairs to merge. Stopping early.")
            break

        best_pair = max(pair_freqs, key=pair_freqs.get)
        new_token = "".join(best_pair)

        if verbose:
            print(f"Merge {i+1}/{num_merges}: {best_pair} -> {new_token} (freq: {pair_freqs[best_pair]})")

        # Apply this merge to every word's current segmentation.
        splits = merge_pair(best_pair, splits)

        merges[best_pair] = new_token
        vocab.add(new_token)

    return merges, vocab

def encode(text, merges):
    """
    Encodes a new string using the learned merge rules.

    The text is pre-tokenized with the same regex used in training. Each word
    is then split into characters and merged greedily. At every step the
    adjacent pair whose rule was learned earliest (lowest position in
    ``merges``) is merged, at its leftmost occurrence, until no rule applies.
    Characters that never appeared in training are never part of a rule, so
    they stay single-character tokens.

    Args:
        text (str): The text to encode.
        merges (dict): Merge rules mapping (left, right) token pairs to the
            merged token, in priority (insertion) order, as returned by
            train_bpe.

    Returns:
        list: A list of tokens.
    """
    # Rank of each rule (0 = learned first), computed once so every priority
    # lookup below is O(1) instead of a linear scan over all rules.
    ranks = {pair: rank for rank, pair in enumerate(merges)}

    final_tokens = []

    for word in re.findall(r"\w+|[^\s\w]+", text):
        tokens = list(word)

        while len(tokens) > 1:
            # (rank, position) for every adjacent pair that has a rule. min()
            # picks the lowest rank and, for equal rank (same pair occurring
            # several times), the leftmost position.
            candidates = [
                (ranks[pair], i)
                for i, pair in enumerate(zip(tokens, tokens[1:]))
                if pair in ranks
            ]
            if not candidates:
                break

            _, i = min(candidates)
            tokens[i:i + 2] = [merges[(tokens[i], tokens[i + 1])]]

        final_tokens.extend(tokens)

    return final_tokens

# --- Main execution ---
if __name__ == "__main__":

    # 1. Define a small sample corpus built around the classic
    #    low/lower/lowest and new/newer/newest BPE example.
    corpus = (
        "low low low lower lowest\n"
        "new newer newest\n"
        "slow slower slowest"
    )

    # 2. Train the tokenizer with a deliberately small vocabulary.
    # The corpus has exactly 8 distinct characters (l, o, w, e, r, s, t, n),
    # so a target vocabulary of 20 means 20 - 8 = 12 merges.
    target_vocab_size = 20

    print("--- Training BPE ---")
    merges, vocab = train_bpe(corpus, target_vocab_size)
    print("--------------------")

    # 3. Show the results
    print(f"\nFinal Vocabulary ({len(vocab)} tokens):")
    # Sort for readability
    print(sorted(vocab))

    print("\nLearned Merge Rules (in order of priority):")
    print(merges)

    # 4. Test the encoder
    print("\n--- Encoding Examples ---")

    # A word seen in training: should collapse to a single learned token.
    text1 = "low"
    print(f"'{text1}' -> {encode(text1, merges)}")

    # 'a' and 'h' never occur in the corpus, so no merge rule involves them
    # and they pass through as single-character tokens.
    text5 = "aloha"
    print(f"'{text5}' -> {encode(text5, merges)}")


--- Training BPE ---
Merge 1/12: ('l', 'o') -> lo (freq: 8)
Merge 2/12: ('lo', 'w') -> low (freq: 8)
Merge 3/12: ('low', 'e') -> lowe (freq: 4)
Merge 4/12: ('s', 't') -> st (freq: 3)
Merge 5/12: ('n', 'e') -> ne (freq: 3)
Merge 6/12: ('ne', 'w') -> new (freq: 3)
Merge 7/12: ('lowe', 'r') -> lower (freq: 2)
Merge 8/12: ('lowe', 'st') -> lowest (freq: 2)
Merge 9/12: ('new', 'e') -> newe (freq: 2)
Merge 10/12: ('newe', 'r') -> newer (freq: 1)
Merge 11/12: ('newe', 'st') -> newest (freq: 1)
Merge 12/12: ('s', 'low') -> slow (freq: 1)
--------------------

Final Vocabulary (20 tokens):
['e', 'l', 'lo', 'low', 'lowe', 'lower', 'lowest', 'n', 'ne', 'new', 'newe', 'newer', 'newest', 'o', 'r', 's', 'slow', 'st', 't', 'w']

Learned Merge Rules (in order of priority):
{('l', 'o'): 'lo', ('lo', 'w'): 'low', ('low', 'e'): 'lowe', ('s', 't'): 'st', ('n', 'e'): 'ne', ('ne', 'w'): 'new', ('lowe', 'r'): 'lower', ('lowe', 'st'): 'lowest', ('new', 'e'): 'newe', ('newe', 'r'): 'newer', ('newe', 'st'): 'ne