In [1]:
import regex as re
from dataclasses import dataclass
from collections import defaultdict

from transformers import GPT2Tokenizer

# References

1. [YT. Stanford CS336 (2025) Overview and Tokenization](https://www.youtube.com/watch?v=msHyYioAyNE&list=PLoROMvodv4rOY23Y0BoGoBGgQ1zmU_MT_&index=3)
2. [Git. Stanford CS336 (2025) Assignment 1 - Basics](https://github.com/stanford-cs336/assignment1-basics/blob/main/cs336_spring2025_assignment1_basics.pdf)

# 1. Overview

## 1.1. GPT-2 Tokenization

In [2]:
def get_compression_ratio(string: str, indices: list[int]) -> float:
    """Given `string` that has been tokenized into `indices`, calculate
    how many bytes are represented by a token.

    Bytes are counted in the UTF-8 encoding of `string`, so multi-byte
    characters (emoji, CJK, ...) contribute several bytes each.

    Args:
        string: the original text.
        indices: the token ids produced for `string`.

    Returns:
        Number of UTF-8 bytes divided by number of tokens.

    Raises:
        ZeroDivisionError: if `indices` is empty.
    """
    num_bytes = len(string.encode("utf-8"))
    num_tokens = len(indices)
    return num_bytes / num_tokens

In [3]:
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained("gpt2")

In [4]:
text = "Hello, 🌍! 你好!"

In [5]:
# tokenize
indices = tokenizer_gpt2.encode(text)
indices

[15496, 11, 12520, 234, 235, 0, 220, 19526, 254, 25001, 121, 0]

In [6]:
# reconstruct
reconstructed_string = tokenizer_gpt2.decode(indices)
reconstructed_string

'Hello, 🌍! 你好!'

In [7]:
# compression ratio
get_compression_ratio(text, indices)

1.6666666666666667

## 1.2. Character-Based Tokenization

In [8]:
ord("a")

97

In [9]:
ord("🌍")

127757

In [10]:
chr(97)

'a'

In [11]:
chr(127757)

'🌍'

In [12]:
class CharacterTokenizer:
    """Tokenizer whose token ids are Unicode code points, one id per character."""

    def encode(self, string: str) -> list[int]:
        # Each character becomes its code point.
        return [ord(ch) for ch in string]

    def decode(self, indices: list[int]) -> str:
        # Inverse of `encode`: map every code point back to its character.
        return "".join(chr(code_point) for code_point in indices)

In [13]:
tokenizer_char = CharacterTokenizer()

In [14]:
indices = tokenizer_char.encode(text)
indices

[72, 101, 108, 108, 111, 44, 32, 127757, 33, 32, 20320, 22909, 33]

In [15]:
reconstructed_string = tokenizer_char.decode(indices)
reconstructed_string

'Hello, 🌍! 你好!'

In [16]:
get_compression_ratio(text, indices)

1.5384615384615385

## 1.3. Byte-Based Tokenization

In [17]:
bytes("a", encoding="utf-8")

b'a'

In [18]:
bytes("🌍", encoding="utf-8")

b'\xf0\x9f\x8c\x8d'

In [19]:
class ByteTokenizer:
    """Tokenizer whose token ids are the UTF-8 bytes of the text (values 0-255)."""

    def encode(self, string: str) -> list[int]:
        # Iterating over a `bytes` object already yields ints.
        return [byte for byte in string.encode("utf-8")]

    def decode(self, indices: list[int]) -> str:
        # Pack the ids back into a bytes object, then decode it as UTF-8.
        return bytes(indices).decode("utf-8")

In [20]:
tokenizer_byte = ByteTokenizer()

In [21]:
indices = tokenizer_byte.encode(text)
indices

[72,
 101,
 108,
 108,
 111,
 44,
 32,
 240,
 159,
 140,
 141,
 33,
 32,
 228,
 189,
 160,
 229,
 165,
 189,
 33]

In [22]:
reconstructed_string = tokenizer_byte.decode(indices)
reconstructed_string

'Hello, 🌍! 你好!'

In [23]:
get_compression_ratio(text, indices)

1.0

## 1.4. Word-Based Tokenization

In [24]:
text = "I'll say supercalifragilisticexpialidocious!"

In [25]:
segments = re.findall(r"\w+|.", text)
segments

['I', "'", 'll', ' ', 'say', ' ', 'supercalifragilisticexpialidocious', '!']

In [26]:
# GPT-2 pre-tokenization pattern, as published in tiktoken:
# https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py#L23
# The pieces below are concatenated into one pattern; alternatives are tried left to right.
GPT2_TOKENIZER_REGEX = (
    r"""'(?:[sdmt]|ll|ve|re)"""   # contraction suffixes: 's 'd 'm 't 'll 've 're
    r"""| ?\p{L}+"""              # a run of letters, with at most one leading space
    r"""| ?\p{N}+"""              # a run of numeric characters, with at most one leading space
    r"""| ?[^\s\p{L}\p{N}]+"""    # a run of other non-space symbols, with at most one leading space
    r"""|\s+(?!\S)"""             # whitespace not directly followed by a non-space char
    r"""|\s+"""                   # any remaining whitespace
)

In [47]:
segments = re.findall(GPT2_TOKENIZER_REGEX, text)
segments

['I', "'ll", ' say', ' supercalifragilisticexpialidocious', '!']

## 1.5. Byte Pair Encoding (BPE)

In [28]:
@dataclass(frozen=True)
class BPETokenizerParams:
    """Immutable bundle of everything a BPETokenizer needs.

    `vocab` maps each token id to the bytes it stands for; `merges` maps an
    adjacent (left_id, right_id) pair to the id that replaces it.
    """

    vocab: dict[int, bytes]
    merges: dict[tuple[int, int], int]


def merge(indices: list[int], pair: tuple[int, int], new_index: int) -> list[int]:
    """Return a new list in which every occurrence of `pair` in `indices` is
    collapsed into `new_index`, scanning left to right without overlaps."""
    out: list[int] = []
    pos = 0
    n = len(indices)
    while pos < n:
        # A pair can only start here if a right neighbour exists.
        at_pair = pos + 1 < n and indices[pos] == pair[0] and indices[pos + 1] == pair[1]
        out.append(new_index if at_pair else indices[pos])
        pos += 2 if at_pair else 1
    return out


class BPETokenizer:
    """Byte-level BPE tokenizer driven by a precomputed vocabulary and merge table."""

    def __init__(self, params: BPETokenizerParams):
        self.params = params

    def encode(self, string: str) -> list[int]:
        """Encode `string` to UTF-8 byte ids, then replay every merge in table order."""
        token_ids = list(string.encode("utf-8"))
        # Note: this is a very slow implementation - every merge rescans the
        # whole sequence, including merges whose pair never occurs in `string`.
        for merged_pair, merged_id in self.params.merges.items():
            token_ids = merge(token_ids, merged_pair, merged_id)
        return token_ids

    def decode(self, indices: list[int]) -> str:
        """Concatenate the byte strings for `indices` and decode them as UTF-8."""
        pieces = [self.params.vocab.get(idx) for idx in indices]
        return b"".join(pieces).decode("utf-8")


def train_bpe(string: str, num_merges: int) -> BPETokenizerParams:
    """Train byte-level BPE on `string`, learning at most `num_merges` merges.

    Each round counts adjacent token pairs, merges the most frequent pair into
    a new id (256, 257, ...), and records it in both `merges` and `vocab`.
    Ties go to the pair that first occurs earliest in the current sequence,
    because `max` returns the first maximal key in dict insertion order.

    Training stops early once no adjacent pair is left, so the result may
    contain fewer than `num_merges` merges.
    """
    # Start with the list of bytes of string.
    indices = list(map(int, string.encode("utf-8")))
    merges: dict[tuple[int, int], int] = {}  # index1, index2 => merged index
    vocab: dict[int, bytes] = {x: bytes([x]) for x in range(256)}  # index -> bytes
    for i in range(num_merges):
        # Count the number of occurrences of each pair of tokens
        counts = defaultdict(int)
        for index1, index2 in zip(indices, indices[1:]):  # For each adjacent pair
            counts[(index1, index2)] += 1
        if not counts:
            # Fewer than two tokens remain: nothing left to merge. Without this
            # guard `max` raises ValueError on the empty mapping.
            break
        # Find the most common pair.
        pair = max(counts, key=counts.get)
        index1, index2 = pair
        # Merge that pair.
        new_index = 256 + i
        merges[pair] = new_index
        vocab[new_index] = vocab[index1] + vocab[index2]
        indices = merge(indices, pair, new_index)
    return BPETokenizerParams(vocab=vocab, merges=merges)

In [29]:
# training the tokenizer
string = "the cat in the hat"
params = train_bpe(string, num_merges=3)

In [30]:
tokenizer_bpe_valid = BPETokenizer(params)

In [31]:
text = "the quick brown fox"

In [32]:
indices = tokenizer_bpe_valid.encode(text)
indices

[258, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120]

In [33]:
reconstructed_string = tokenizer_bpe_valid.decode(indices)
reconstructed_string

'the quick brown fox'

# 2. BPE Implementation from Scratch

CS336 Assignment 1

Goals:

1) `encode()` currently loops over all merges. Only loop over merges that matter.
2) Detect and preserve special tokens (e.g., `<|endoftext|>`).
3) Use pre-tokenization (e.g., the GPT-2 tokenizer regex).
4) Try to make the implementation as fast as possible.

---

Problem (train_bpe): BPE Tokenizer Training (15 points)

**Deliverable**: Write a function that, given a path to an input text file, trains a (byte-level) BPE tokenizer. Your BPE training function should handle (at least) the following input parameters:

|Parameter|Typing|Functionality|
|:-|:-|:-|
| `input_path`|`str` (Path)| Path to a text file containing BPE tokenizer training data.|
| `vocab_size`|`int`| A positive integer defining the maximum final vocabulary size (includes initial byte vocabulary, merged items, and special tokens).|
|`special_tokens`|`list[str]`|List of strings to add to the vocabulary (these tokens don't affect BPE training).|

Your BPE training function should return the resulting vocabulary and merges:

| Parameter | Typing | Functionality |
|:-|:-|:-|
| `vocab` | `dict[int, bytes]` | The tokenizer vocabulary, a mapping from `int` (token ID in the vocabulary) to `bytes` (token bytes). |
| `merges` | `list[tuple[bytes, bytes]]` | A list of BPE merges produced from training. Each list item is a tuple of bytes `(<token1>, <token2>)`, representing that `<token1>` was merged with `<token2>`. The merges should be ordered by order of creation. |

To test your BPE training function against our provided tests, you will first need to implement the test adapter at `[adapters.run_train_bpe]`. Then, run `uv run pytest tests/test_train_bpe.py`. Your implementation should be able to pass all tests.

In [34]:
# sample_text = "the quick brown fox"
sample_text = "the cat in the hat"

In [35]:
# convert the string into a sequence of bytes using UTF-8 encoding
# each character may be 1+ bytes long
sample_text.encode("utf-8")[2]

101

In [36]:
# Iterating a bytes object yields ints, so list() gives the byte values directly.
text_bytes = list(sample_text.encode("utf-8"))

In [37]:
# Round trip: packing the integer byte values back into `bytes` and decoding
# them as UTF-8 reproduces the original text exactly.
assert bytes(text_bytes).decode("utf-8") == sample_text

In [38]:
@dataclass(frozen=True)
class BPETokenizerParams:
    """Frozen container for a trained BPE model.

    Fields:
        vocab:  token id -> the byte string that id decodes to.
        merges: (left id, right id) -> id of the token created by merging them.
    """

    vocab: dict[int, bytes]
    merges: dict[tuple[int, int], int]

In [39]:
def merge(indices: list[int], pair: tuple[int, int], new_index: int) -> list[int]:
    """Replace every non-overlapping occurrence of `pair` in `indices` with `new_index`.

    Occurrences are matched greedily from left to right, so `[a, a, a]` with
    pair `(a, a)` becomes `[new_index, a]`.

    Args:
        indices: token ids to scan; the list is not modified.
        pair: the adjacent (left, right) ids to replace.
        new_index: id emitted in place of each matched pair.

    Returns:
        A new list with the merges applied. Empty and single-element inputs
        come back unchanged (as a copy).
    """
    indices_merged = []
    pair_idx = 0
    n = len(indices)
    # `<` instead of `!=`: a merge of the final pair advances pair_idx past the
    # end, and the loop must stop there instead of indexing out of range.
    while pair_idx < n:
        # Only look ahead when a right neighbour exists.
        if pair_idx + 1 < n and (indices[pair_idx], indices[pair_idx + 1]) == pair:
            indices_merged.append(new_index)
            pair_idx += 2
        else:
            indices_merged.append(indices[pair_idx])
            pair_idx += 1
    return indices_merged


def train_bpe(text: str, num_merges: int) -> BPETokenizerParams:
    """Learn up to `num_merges` BPE merges from the UTF-8 bytes of `text`.

    Each round counts every adjacent pair of token ids and picks the most
    frequent one. Ties go to the pair first seen earliest, because `max` keeps
    the first maximal key in dict insertion order. The chosen pair gets the
    next free id (256, 257, ...), and the sequence is rewritten with `merge`.

    Args:
        text: training text.
        num_merges: maximum number of merges to learn.

    Returns:
        BPETokenizerParams. `vocab` holds the 256 single-byte tokens plus one
        entry per learned merge. `merges` maps each merged pair to its new id,
        in creation order. Training stops early, returning fewer than
        `num_merges` merges, once no adjacent pair is left.
    """
    indices = list(map(int, text.encode("utf-8")))
    merges = {}
    vocab = {x: bytes([x]) for x in range(256)}
    for i in range(num_merges):
        counts = {}
        for idx_1, idx_2 in zip(indices, indices[1:]):
            counts[(idx_1, idx_2)] = counts.get((idx_1, idx_2), 0) + 1
        if not counts:
            # Fewer than two tokens remain, so there is nothing to merge;
            # without this guard `max` raises ValueError on the empty dict.
            break
        pair = max(counts, key=counts.get)
        new_idx = 256 + i
        merges[pair] = new_idx
        indices = merge(indices, pair, new_idx)
        vocab[new_idx] = vocab[pair[0]] + vocab[pair[1]]
    return BPETokenizerParams(vocab, merges)

In [40]:
train_res = train_bpe(sample_text, num_merges=3)
# train_res

In [41]:
class BPETokenizerCustom:
    """Hand-written BPE tokenizer that replays a learned merge table over UTF-8 bytes."""

    def __init__(self, params: BPETokenizerParams):
        self.params = params

    def encode(self, string: str) -> list[int]:
        """Return the token ids for `string` after applying all merges in table order."""
        token_ids = [b for b in string.encode("utf-8")]
        for pair_to_merge, replacement_id in self.params.merges.items():
            token_ids = merge(token_ids, pair_to_merge, replacement_id)
        return token_ids

    def decode(self, indices: list[int]) -> str:
        """Map ids back to their byte strings and decode the result as UTF-8.

        An id missing from the vocabulary raises KeyError.
        """
        vocab = self.params.vocab
        return b"".join(vocab[token_id] for token_id in indices).decode("utf-8")

In [42]:
tokenizer_bpe_custom = BPETokenizerCustom(train_res)
_t_encoded = tokenizer_bpe_custom.encode(sample_text)
_t_encoded

[258, 99, 97, 116, 32, 105, 110, 32, 258, 104, 97, 116]

In [43]:
_t_encoded == tokenizer_bpe_valid.encode(sample_text)

True

In [44]:
tokenizer_bpe_custom.decode(_t_encoded)

'the cat in the hat'

In [52]:
# GPT-2 pre-tokenization pattern; identical to GPT2_TOKENIZER_REGEX from section 1.4
# (written as a regular string, so every backslash is doubled).
PAT = "'(?:[sdmt]|ll|ve|re)| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"

In [53]:
re.findall(PAT, "some text that i'll pre-tokenize")

['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize']