In [None]:
import os
import numpy as np

# Exercise 2: BPE Tokenizer (3p)

Next, we move on to language modeling, a common use case for autoregressive models. A crucial part of language modeling is the tokenizer, and Byte Pair Encoding (BPE) is widely adopted in state-of-the-art language models. The central idea of BPE is to repeatedly find the most frequent pair of adjacent tokens (initially raw bytes) and replace every occurrence of that pair with a new, previously unused token. Your task is to implement a simplified version of the BPE tokenizer using the Shakespeare dataset.

First, we download and preview the Shakespeare dataset:

In [None]:
# Download the Tiny Shakespeare corpus once; later runs reuse the local copy.
# urllib is used instead of the IPython `!wget` magic so the cell works without
# a wget binary (absent on stock macOS/Windows) and outside IPython, and so a
# failed download raises here instead of surfacing later at open().
import urllib.request

DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/refs/heads/master/data/tinyshakespeare/input.txt"

# Check if the file already exists
if not os.path.exists("input.txt"):
    urllib.request.urlretrieve(DATA_URL, "input.txt")
else:
    print("input.txt already exists. Skipping download.\n")

# Preview dataset. Read explicitly as UTF-8: the text is re-encoded to UTF-8
# bytes later, and the platform default encoding is not guaranteed to be UTF-8.
with open("input.txt", "r", encoding="utf-8") as file:
    content = file.read()
print(content[:250])

We will use the first 5000 characters of the dataset for training. Let's encode this text as UTF-8 bytes and represent each byte as an integer in the range 0–255:

In [None]:
# Training corpus: the first 5000 characters of the dataset.
training_text = content[:5000]
# Iterating over a bytes object yields ints (0-255), so list() produces the
# raw byte values directly.
training_data = list(training_text.encode("utf-8"))
training_data[:10]

Although we could train a language model directly on these byte-level tokens, BPE improves efficiency by merging frequent byte pairs into new tokens, which shortens the token sequences the model has to process. This approach is standard practice for training large language models. Your task is to implement two functions needed for BPE: get_stats and merge.

In [None]:
def get_stats(token_ids):
    """
    Counts occurrences of adjacent token pairs.

    Each pair is inserted into the result the first time it is seen, so the
    dict's iteration order is the order of first occurrence. The training
    loop relies on this for deterministic tie-breaking in
    ``max(stats, key=stats.get)``.

    Args:
        token_ids (list of int): Input token IDs.

    Returns:
        dict: Dictionary with pairs as keys and their frequencies as values.
        Empty when fewer than two tokens are given.

    Example:
        Input: [1, 2, 3, 1, 2]
        Output: {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    """
    counts = {}
    # zip pairs each token with its right-hand neighbour; it yields nothing
    # for empty or single-token input, so those return {}.
    for pair in zip(token_ids, token_ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

In [None]:
# Test cases for get_stats: (input tokens, expected pair counts, failure message)
for tokens, expected_stats, message in (
    (
        [1, 2, 3, 1, 2],
        {(1, 2): 2, (2, 3): 1, (3, 1): 1},
        "get_stats returned incorrect solution",
    ),
    # Overlapping repeats: the run 5, 5, 5 contributes two (5, 5) pairs.
    (
        [5, 5, 5, 6, 5, 5],
        {(5, 5): 3, (5, 6): 1, (6, 5): 1},
        "get_stats failed on repeated elements",
    ),
):
    assert get_stats(tokens) == expected_stats, message

print("Tests pass - success!")

In [None]:
def merge(token_ids, pair, idx):
    """
    Merges occurrences of a specific token pair into a new token.

    The input is scanned left to right and matches do not overlap: after a
    match, both tokens are consumed. For example, merging (4, 4) in
    [4, 4, 4] gives [idx, 4]. The input list is not modified; a new list is
    returned.

    Args:
        token_ids (list of int): Input token IDs.
        pair (tuple of int): Pair of tokens to merge.
        idx (int): New token ID for merged pair.

    Returns:
        list of int: Token IDs after merging.

    Example:
        Input: token_ids = [1, 2, 3, 1, 2], pair = (1, 2), idx = 99
        Output: [99, 3, 99]
    """
    first, second = pair
    merged = []
    n = len(token_ids)
    i = 0
    while i < n:
        if i + 1 < n and token_ids[i] == first and token_ids[i + 1] == second:
            merged.append(idx)
            i += 2  # skip both halves so overlapping matches are not reused
        else:
            merged.append(token_ids[i])
            i += 1
    return merged

In [None]:
# Test cases for merge: (tokens, pair to merge, new id, expected output, failure message)
for tokens, target_pair, new_id, expected_merged, message in (
    (
        [1, 2, 3, 1, 2], (1, 2), 99, [99, 3, 99],
        "merge failed, the returned solution is incorrect",
    ),
    # 4, 4, 4, 4 must collapse into two merged tokens, not overlapping ones.
    (
        [4, 4, 4, 4, 5, 4, 4], (4, 4), 77, [77, 77, 5, 77],
        "merge failed on repeated merges",
    ),
):
    merged = merge(tokens, pair=target_pair, idx=new_id)
    assert merged == expected_merged, message

print("Tests pass - success!")

Once both functions are implemented, we use them to learn the tokenizer's merge table and vocabulary:

In [None]:
vocab_size = 300
num_merges = vocab_size - 256  # ids 0..255 are reserved for the raw byte values
ids = training_data
merges = {}  # (token_a, token_b) -> new token id, in the order merges were learned
# Performing token merges
for i in range(num_merges):
    stats = get_stats(ids)
    if not stats:
        # Fewer than two tokens remain, so there is nothing left to merge.
        break
    # max() returns the first maximal key, so ties are broken by the
    # iteration order of `stats`.
    pair = max(stats, key=stats.get)
    idx = 256 + i
    if i < 5:
        print(f"merging {pair} into a new token {idx}")
    elif i == 5:
        print("Going quiet...\n")
    ids = merge(ids, pair, idx)
    merges[pair] = idx

# Create vocabulary: a merged token's bytes are the concatenation of its two
# parts. `merges` is iterated in insertion (learning) order, so both parts
# already exist in `vocab` when a new token is built.
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]

# Illustrating the vocabulary
print("Example tokens that were added")
# Sample from the ids actually learned (derived from vocab_size, not a
# hard-coded 300), and never request more samples than there are.
added_ids = np.arange(256, 256 + len(merges))
indices = sorted(np.random.choice(added_ids, size=min(5, len(added_ids)), replace=False))
for idx in indices:
    # A merged token can hold only part of a multi-byte UTF-8 character, so
    # decode leniently instead of risking UnicodeDecodeError.
    print(f"{idx}: {vocab[idx].decode('utf-8', errors='replace')}")

Finally, implement the encode and decode functions to test your tokenizer:

In [None]:
def encode(text):
    """
    Encodes text into token IDs using the merge-based encoding.

    The function starts by converting the input string into UTF-8 bytes.
    Then, using a merge table, it repeatedly merges known byte-pairs into
    new token IDs. This continues until no more merges can be applied.

    Merges are replayed in the order they were learned: at every step, the
    adjacent pair with the smallest learned id is merged. Later merges may be
    built from the output of earlier ones, so this reproduces the training-time
    tokenization. Uses the global ``merges`` table and ``merge`` function.

    Args:
        text (str): Input text.

    Returns:
        list of int: Encoded token IDs (empty for empty text).
    """
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        candidates = set(zip(ids, ids[1:]))
        # Pairs absent from the table map to +inf, so they are picked only when
        # no mergeable pair remains. Learned ids are unique, so there are no ties.
        pair = min(candidates, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # no adjacent pair can be merged any further
        ids = merge(ids, pair, merges[pair])
    return ids

In [None]:
# Each case: (slice of the dataset, expected token ids, message printed on success)
encode_cases = (
    (slice(0, 25), [292, 289, 32, 285, 66, 101, 102, 271, 256, 119, 256], "First test passes\n"),
    (
        slice(25, 50),
        [112, 114, 111, 99, 101, 297, 268, 266, 102, 117, 114, 257, 264, 259, 104, 101, 273],
        "Tests pass - success!",
    ),
)
for span, expected_ids, done_message in encode_cases:
    text = content[span]
    print(f"Encoding the following piece of text:\n---\n{text}\n---")
    otp = encode(text)
    assert otp == expected_ids, "The encode function is incorrect"
    print(done_message)

In [None]:
def decode(ids):
    """
    Decodes token IDs back to the original text.

    This function reconstructs the original text by looking up each token ID
    in the vocabulary (which stores byte strings), joining them together,
    and decoding the result from UTF-8.

    The bytes are joined before decoding, because one multi-byte UTF-8
    character may be split across several tokens. Byte sequences that are
    not valid UTF-8 (possible for arbitrary id lists) are replaced with
    U+FFFD instead of raising. Uses the global ``vocab`` mapping.

    Args:
        ids (list of int): Token IDs to decode.

    Returns:
        str: Decoded text.

    Raises:
        KeyError: If an id is not present in ``vocab``.
    """
    text_bytes = b"".join(vocab[idx] for idx in ids)
    return text_bytes.decode("utf-8", errors="replace")

In [None]:
# Each case: (token ids, text they must decode to)
for token_ids, expected_text in (
    ([292, 289, 32, 285, 66, 101, 102, 271, 256], "First Citizen:\nBefore "),
    (
        [112, 114, 111, 99, 101, 297, 268, 266, 102, 117, 114, 257, 264, 259, 104, 101, 273],
        "proceed any further, hear",
    ),
):
    otp = decode(token_ids)
    assert otp == expected_text, f"the decode function returns an incorrect solution {otp}"

print("Tests pass - success!")

In [None]:
# Round-trip checks: encoding then decoding must reproduce the input exactly,
# both on the training text and on text the tokenizer never saw (including
# non-ASCII characters, which span several UTF-8 bytes).
print("Running a bigger batch of data through the encoder-decoder to ensure it does not change")
assert decode(encode(training_text)) == training_text, "the encoder-decoder pair does not return the correct output, when the training data is encoded and decoded"

held_out_text = content[5000:6000] + " naïve café — 東京"
assert decode(encode(held_out_text)) == held_out_text, "the encoder-decoder pair does not return the correct output on unseen or non-ASCII text"
print("Tests pass - success!")