In [1]:
import os
import regex as re
from typing import Iterator, Any
from multiprocessing import Pool
from dataclasses import dataclass
from collections import defaultdict

from transformers import GPT2Tokenizer

import tokenization.bpe as custom_bpe
import tokenization.pretokenization as custom_pretok

from utils import timer
from aux.stanford_cs336.basics.pretokenization_example import find_chunk_boundaries

# References

1. [YT. Stanford CS336 (2025) Overview and Tokenization](https://www.youtube.com/watch?v=msHyYioAyNE&list=PLoROMvodv4rOY23Y0BoGoBGgQ1zmU_MT_&index=3)
2. [Git. Stanford CS336 (2025) Assignment 1 - Basics](https://github.com/stanford-cs336/assignment1-basics/blob/main/cs336_spring2025_assignment1_basics.pdf)

# 1. Overview

## 1.1. GPT-2 Tokenization

In [2]:
def get_compression_ratio(string: str, indices: list[int]) -> float:
    """Return the average number of UTF-8 bytes of `string` that each token
    in `indices` stands for (a higher value means better compression)."""
    byte_count = len(string.encode("utf-8"))
    return byte_count / len(indices)

In [3]:
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained("gpt2")

In [4]:
text = "Hello, 🌍! 你好!"

In [5]:
# tokenize
indices = tokenizer_gpt2.encode(text)
indices

[15496, 11, 12520, 234, 235, 0, 220, 19526, 254, 25001, 121, 0]

In [6]:
# reconstruct
reconstructed_string = tokenizer_gpt2.decode(indices)
reconstructed_string

'Hello, 🌍! 你好!'

In [7]:
# compression ratio
get_compression_ratio(text, indices)

1.6666666666666667

## 1.2. Character based tokenization

In [8]:
ord("a")

97

In [9]:
ord("🌍")

127757

In [10]:
chr(97)

'a'

In [11]:
chr(127757)

'🌍'

In [12]:
class CharacterTokenizer:
    """Tokenizer whose token ids are the Unicode code points of the text."""

    def encode(self, string: str) -> list[int]:
        """Map every character of `string` to its code point."""
        return [ord(char) for char in string]

    def decode(self, indices: list[int]) -> str:
        """Turn each code point back into a character and concatenate them."""
        return "".join(chr(code_point) for code_point in indices)

In [13]:
tokenizer_char = CharacterTokenizer()

In [14]:
indices = tokenizer_char.encode(text)
indices

[72, 101, 108, 108, 111, 44, 32, 127757, 33, 32, 20320, 22909, 33]

In [15]:
reconstructed_string = tokenizer_char.decode(indices)
reconstructed_string

'Hello, 🌍! 你好!'

In [16]:
get_compression_ratio(text, indices)

1.5384615384615385

## 1.3. Byte-Based Tokenization

In [17]:
bytes("a", encoding="utf-8")

b'a'

In [18]:
bytes("🌍", encoding="utf-8")

b'\xf0\x9f\x8c\x8d'

In [19]:
class ByteTokenizer:
    """Tokenizer whose token ids are the UTF-8 byte values of the text."""

    def encode(self, string: str) -> list[int]:
        """Encode `string` as UTF-8 and return its byte values."""
        # Iterating a bytes object already yields ints in range(256).
        return list(string.encode("utf-8"))

    def decode(self, indices: list[int]) -> str:
        """Pack the byte values back together and decode them as UTF-8."""
        return bytes(indices).decode("utf-8")

In [20]:
tokenizer_byte = ByteTokenizer()

In [21]:
indices = tokenizer_byte.encode(text)
indices

[72,
 101,
 108,
 108,
 111,
 44,
 32,
 240,
 159,
 140,
 141,
 33,
 32,
 228,
 189,
 160,
 229,
 165,
 189,
 33]

In [22]:
reconstructed_string = tokenizer_byte.decode(indices)
reconstructed_string

'Hello, 🌍! 你好!'

In [23]:
get_compression_ratio(text, indices)

1.0

## 1.4. Word-Based Tokenization

In [24]:
text = "I'll say supercalifragilisticexpialidocious!"

In [25]:
segments = re.findall(r"\w+|.", text)
segments

['I', "'", 'll', ' ', 'say', ' ', 'supercalifragilisticexpialidocious', '!']

In [26]:
# https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py#L23
GPT2_TOKENIZER_REGEX = \
    r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

In [27]:
segments = re.findall(GPT2_TOKENIZER_REGEX, text)
segments

['I', "'ll", ' say', ' supercalifragilisticexpialidocious', '!']

## 1.5. Byte Pair Encoding (BPE)

In [28]:
@dataclass(frozen=True)
class BPETokenizerParams:
    """All you need to specify a BPETokenizer.

    The dataclass is frozen, so the two fields cannot be rebound after
    construction (the dictionaries they reference remain mutable).
    """
    # Token id -> the byte string that the token stands for.
    vocab: dict[int, bytes]             # index -> bytes
    # Pair of token ids -> id of the token created by merging that pair.
    merges: dict[tuple[int, int], int]  # index1,index2 -> new_index


def merge(indices: list[int], pair: tuple[int, int], new_index: int) -> list[int]:
    """Build a copy of `indices` in which each left-to-right, non-overlapping
    occurrence of `pair` is collapsed into the single id `new_index`."""
    result = []
    skip_next = False  # True right after a merge: the second half is already consumed
    for pos, current in enumerate(indices):
        if skip_next:
            skip_next = False
            continue
        has_next = pos + 1 < len(indices)
        if has_next and current == pair[0] and indices[pos + 1] == pair[1]:
            result.append(new_index)
            skip_next = True
        else:
            result.append(current)
    return result


class BPETokenizer:
    """Byte-level BPE tokenizer driven by a fixed vocabulary and merge table."""

    def __init__(self, params: BPETokenizerParams):
        self.params = params

    def encode(self, string: str) -> list[int]:
        """Convert `string` to token ids by replaying every stored merge."""
        # Start from the raw UTF-8 byte values.
        token_ids = list(string.encode("utf-8"))
        # Note: this is a very slow implementation (one full pass per merge,
        # in the dict's insertion order).
        for pair, new_index in self.params.merges.items():
            token_ids = merge(token_ids, pair, new_index)
        return token_ids

    def decode(self, indices: list[int]) -> str:
        """Look up the bytes of every id, concatenate them and decode as UTF-8."""
        vocab = self.params.vocab
        raw = b"".join(vocab.get(index) for index in indices)
        return raw.decode("utf-8")


def train_bpe(string: str, num_merges: int) -> BPETokenizerParams:
    """Learn up to `num_merges` BPE merges from the UTF-8 bytes of `string`.

    Args:
        string: Training text.
        num_merges: Maximum number of merges to learn. Fewer are produced
            when the sequence runs out of adjacent pairs (e.g. an empty or
            one-byte string, or everything already merged into one token).

    Returns:
        BPETokenizerParams with the 256 single-byte entries plus one vocabulary
        entry per learned merge, and the merges keyed by id pair.
    """
    # Start with the list of bytes of string.
    indices = list(string.encode("utf-8"))
    merges: dict[tuple[int, int], int] = {}  # index1, index2 => merged index
    vocab: dict[int, bytes] = {x: bytes([x]) for x in range(256)}  # index -> bytes
    for i in range(num_merges):
        # Count the number of occurrences of each pair of adjacent tokens.
        counts = defaultdict(int)
        for index1, index2 in zip(indices, indices[1:]):
            counts[(index1, index2)] += 1
        if not counts:
            # Nothing left to merge; max() on an empty mapping would raise ValueError.
            break
        # Find the most common pair (ties: the pair seen first wins).
        pair = max(counts, key=counts.get)
        index1, index2 = pair
        # Merge that pair.
        new_index = 256 + i
        merges[pair] = new_index
        vocab[new_index] = vocab[index1] + vocab[index2]
        indices = merge(indices, pair, new_index)
    return BPETokenizerParams(vocab=vocab, merges=merges)

In [29]:
# training the tokenizer
string = "the cat in the hat"
params = train_bpe(string, num_merges=3)

In [30]:
tokenizer_bpe_valid = BPETokenizer(params)

In [31]:
text = "the quick brown fox"

In [32]:
indices = tokenizer_bpe_valid.encode(text)
indices

[258, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120]

In [33]:
reconstructed_string = tokenizer_bpe_valid.decode(indices)
reconstructed_string

'the quick brown fox'

# 2. BPE Implementation from Scratch

CS336 Assignment 1

Goals:

1) `encode()` currently loops over all merges. Only loop over merges that matter.
2) Detect and preserve special tokens (e.g., `<|endoftext|>`).
3) Use pre-tokenization (e.g., the GPT-2 tokenizer regex).
4) Try to make the implementation as fast as possible.

You are free to use the starter code at the following link verbatim to obtain chunk boundaries, which you can then use to distribute work across your processes:

https://github.com/stanford-cs336/assignment1-basics/blob/main/cs336_basics/pretokenization_example.py

---

#### Problem (train_bpe): BPE Tokenizer Training (15 points)

**Deliverable**: Write a function that, given a path to an input text file, trains a (byte-level) BPE tokenizer. Your BPE training function should handle (at least) the following input parameters:

|Parameter|Typing|Functionality|
|:-|:-|:-|
| `input_path`|`str` (Path)| Path to a text file containing BPE tokenizer training data.|
| `vocab_size`|`int`| A positive integer defining the maximum final vocabulary size (includes initial byte vocabulary, merged items, and special tokens).|
|`special_tokens`|`list[str]`|List of strings to add to the vocabulary (these tokens don't affect BPE training).|

Your BPE training function should return the resulting vocabulary and merges:

| Parameter | Typing | Functionality |
|:-|:-|:-|
| `vocab` | `dict[int, bytes]` | The tokenizer vocabulary, a mapping from `int` (token ID in the vocabulary) to `bytes` (token bytes). |
| `merges` | `list[tuple[bytes, bytes]]` | A list of BPE merges produced from training. Each list item is a tuple of bytes `(<token1>, <token2>)`, representing that `<token1>` was merged with `<token2>`. The merges should be ordered by order of creation. |

To test your BPE training function against our provided tests, you will first need to implement the test adapter at `[adapters.run_train_bpe]`. Then, run `uv run pytest tests/test_train_bpe.py`. Your implementation should be able to pass all tests.

---

#### Problem (tokenizer): implementing the tokenizer (15 points)

## 2.1. Vocabulary initialization

In [34]:
SPECIAL_TOKENS = ['<|endoftext|>']

In [35]:
def init_vocabulary(special_tokens: list[str]) -> dict[int, bytes]:
    """Create the starting vocabulary: ids 0-255 are the single bytes, followed
    by one id per special token (UTF-8 encoded) in the order given."""
    byte_entries = {value: bytes([value]) for value in range(256)}
    special_entries = {
        256 + offset: token.encode("utf-8")
        for offset, token in enumerate(special_tokens)
    }
    return {**byte_entries, **special_entries}

In [36]:
vocabulary = init_vocabulary(SPECIAL_TOKENS)
len(vocabulary)

257

## 2.2. Pre-Tokenization

In [37]:
def remove_special_tokens(text: str, tokens: list[str]) -> str:
    """Return `text` with every occurrence of the given special tokens deleted.

    Args:
        text: Input text.
        tokens: Special-token strings to strip. Empty strings are ignored.

    Returns:
        The text without the special tokens; `text` unchanged when there is
        nothing to strip.
    """
    # Longest first, so a token that contains another one as a prefix is
    # removed as a whole instead of leaving its tail behind.
    candidates = sorted({tok for tok in tokens if tok}, key=lambda tok: (-len(tok), tok))
    if not candidates:
        # An empty alternation would compile to a pattern that matches the
        # empty string at every position.
        return text
    pattern = re.compile("|".join(map(re.escape, candidates)))
    return pattern.sub("", text)

In [38]:
def count_tokens(text: str, pattern: str) -> dict[bytes, int]:
    """Count the matches of `pattern` in `text`.

    Every match is keyed by the tuple of its UTF-8 byte values, and the
    result maps that tuple to the number of times the match occurred.
    """
    counts = {}
    for found in re.finditer(pattern, text):
        key = tuple(found.group().encode("utf-8"))
        if key in counts:
            counts[key] += 1
        else:
            counts[key] = 1
    return counts

In [39]:
def pretokenize(text: str, special_tokens: list[str], pretoken_pat: str) -> dict[tuple[bytes], int]:
    """Return pre-token counts of `text`, never letting a pre-token span a special token.

    The text is split at every special token and each segment is pre-tokenized
    separately. Simply deleting the special tokens would glue the neighbouring
    segments together, so a match of `pretoken_pat` could straddle the
    boundary between two documents.

    Args:
        text: Text chunk to process.
        special_tokens: Strings that delimit segments; they are dropped from the
            result. Empty strings are ignored.
        pretoken_pat: Regex whose matches are the pre-tokens.

    Returns:
        Frequency table mapping each pre-token, stored as the tuple of its
        UTF-8 byte values, to its number of occurrences.
    """
    # Longest first, so a delimiter containing another one is matched as a whole.
    delimiters = sorted({tok for tok in special_tokens if tok}, key=lambda tok: (-len(tok), tok))
    if delimiters:
        segments = re.split("|".join(map(re.escape, delimiters)), text)
    else:
        segments = [text]
    token_count = {}  # frequency table
    for segment in segments:
        for match in re.finditer(pretoken_pat, segment):
            token_bytes = tuple(match.group().encode("utf-8"))
            token_count[token_bytes] = token_count.get(token_bytes, 0) + 1
    return token_count

In [40]:
@dataclass
class PreTokenizerArgs:
    """Settings consumed by `pretokenize_file_parallel`."""
    n_proc: int                # chunk count requested from find_chunk_boundaries
    token_split: bytes         # delimiter handed to find_chunk_boundaries (the file is opened in binary mode)
    special_tokens: list[str]  # forwarded to `pretokenize` for every chunk
    pretoken_pat: str          # regex forwarded to `pretokenize` for every chunk

In [41]:
def pretokenize_file_parallel(filep: str, pt_args: PreTokenizerArgs) -> dict[tuple[bytes], int]:
    """Pre-tokenize a file in parallel and return the combined frequency table.

    The file is cut into chunks at the boundaries returned by
    `find_chunk_boundaries`, every chunk is passed to `pretokenize` in a
    worker process, and the per-chunk counts are summed.

    Args:
        filep: Path of the text file to read (opened in binary mode).
        pt_args: Chunking, worker-count and pre-tokenization settings.

    Returns:
        Mapping from pre-token key (as produced by `pretokenize`) to its total
        count over all chunks.
    """
    with open(filep, "rb") as file:
        bounds = find_chunk_boundaries(file, pt_args.n_proc, pt_args.token_split)
        # Create arguments for each chunk
        args = []
        for beg, end in zip(bounds[:-1], bounds[1:]):
            file.seek(beg)
            # Undecodable bytes are dropped rather than raising.
            chunk = file.read(end - beg).decode("utf-8", errors="ignore")
            args.append((chunk, pt_args.special_tokens, pt_args.pretoken_pat))
    # Process chunks in parallel; the pool size comes from the arguments
    # object, not from a notebook-level global.
    with Pool(processes=pt_args.n_proc) as pool:
        results = pool.starmap(pretokenize, args)
    # Reduce results
    pretoken_res = {}  # frequency table
    for chunk_res in results:
        for token_bytes, token_count in chunk_res.items():
            pretoken_res[token_bytes] = pretoken_res.get(token_bytes, 0) + token_count
    return pretoken_res

In [42]:
DATA_ROOT_P = '/mnt/data/DatasetsML/NLP/natural_language_corpus/tiny_stories'

TRAIN_P = os.path.join(DATA_ROOT_P, 'TinyStoriesV2-GPT4-train.txt')
VALID_P = os.path.join(DATA_ROOT_P, 'TinyStoriesV2-GPT4-valid.txt')

In [43]:
N_PROC = 8
TOKEN_SPLIT = "<|endoftext|>".encode("utf-8")
SPECIAL_TOKENS = ['<|endoftext|>']
PRETOKEN_PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

In [44]:
pre_tok_params = PreTokenizerArgs(
    N_PROC,
    TOKEN_SPLIT,
    SPECIAL_TOKENS,
    PRETOKEN_PAT,
)
_t_res_parallel = pretokenize_file_parallel(VALID_P, pre_tok_params)

In [45]:
# The 30 most frequent pre-tokens, with keys converted to bytes for readability.
_t_res_sub = dict(
    (bytes(token), count)
    for token, count in sorted(
        _t_res_parallel.items(), key=lambda item: item[1], reverse=True
    )[:30]
)

In [46]:
_t_res_sub

{b'.': 421616,
 b',': 235432,
 b' the': 211031,
 b' and': 196057,
 b' a': 152161,
 b' to': 150493,
 b'\n': 139288,
 b' was': 108019,
 b' They': 52425,
 b' it': 51670,
 b' He': 49241,
 b' "': 47784,
 b' The': 46977,
 b' said': 43900,
 b' day': 43230,
 b' with': 42981,
 b' her': 38925,
 b' his': 38766,
 b' in': 38658,
 b' She': 38040,
 b' Tim': 37647,
 b' big': 35022,
 b' he': 32790,
 b' they': 29903,
 b' had': 28997,
 b' you': 28401,
 b' not': 27019,
 b' happy': 25863,
 b' on': 25720,
 b' of': 25467}

## 2.3. BPE Merges

In [47]:
sample_text = """
low low low low low
lower lower widest widest widest
newest newest newest newest newest newest
""".replace('\n', ' ').strip()

In [48]:
def pretokenize_dummy(text: str) -> dict[tuple[bytes], int]:
    """Toy pre-tokenizer: split `text` on single spaces and count every piece,
    keyed by the tuple of its UTF-8 byte values."""
    counts = {}
    for word in text.split(' '):
        key = tuple(word.encode('utf-8'))
        counts.setdefault(key, 0)
        counts[key] += 1
    return counts

In [49]:
t_res_dummy = pretokenize_dummy(sample_text)
t_res_dummy

{(108, 111, 119): 5,
 (108, 111, 119, 101, 114): 2,
 (119, 105, 100, 101, 115, 116): 3,
 (110, 101, 119, 101, 115, 116): 6}

In [50]:
t_res_dummy_sub = {bytes(k): v for k, v in t_res_dummy.items()}
t_res_dummy_sub

{b'low': 5, b'lower': 2, b'widest': 3, b'newest': 6}

The idea of optimization:

1. For each pair, store the set of tokens that contain it (a pair-to-tokens map).
2. After a merge, update only the pair counts that overlap the merged pair inside the affected tokens.

Algorithm steps:

1. Precompute the initial pair counts (`pair_counts`).
2. Precompute the initial pair-to-tokens map (`pair_locations`).
3. In a loop:
    - find the best pair to merge (using `pair_counts`)
    - update `merges` and `vocab`
    - for each affected token (every token that contains the pair to merge, looked up in `pair_locations`):
        - merge the pair and obtain the new token
        - update the frequency table with the new token and remove the old one
        - update the pair counts for pairs that overlap the merged one
        - update the pair locations with the new token and remove the old one

In [51]:
# NOTE: this redefinition replaces the earlier class of the same name, whose
# `merges` field was a dict keyed by id pairs.
@dataclass(frozen=True)
class BPETokenizerParams:
    """All you need to specify a BPETokenizer."""
    vocab: dict[int, bytes]            # index -> bytes
    merges: list[tuple[bytes, bytes]]  # (token1 bytes, token2 bytes), in the order the merges were learned

In [52]:
def to_bytes(bytes_tuple: tuple[bytes]) -> Iterator[bytes]:
    """Lazily yield every element of `bytes_tuple` as a `bytes` object.

    An int element becomes the one-byte string holding that value; any other
    element is passed through `bytes()`.
    """
    return (
        bytes([item]) if isinstance(item, int) else bytes(item)
        for item in bytes_tuple
    )


def merge_bytes(bytes_tuple: tuple[bytes], sep=b'') -> bytes:
    """Normalise the elements with `to_bytes` and join them with `sep`."""
    pieces = to_bytes(bytes_tuple)
    return sep.join(pieces)

In [53]:
def print_debug_msg(msg: str):
    """Print `msg` under a 60-character '=' rule, with a blank line before the
    rule and another one after the message."""
    rule = '=' * 60
    print("", rule, msg, "", sep="\n")


def print_debug_structs(
    freq_table: dict[tuple[bytes], int],
    pair_counts: dict[tuple[bytes], int],
    pair_locations: dict[tuple[bytes], set[tuple[bytes]]],
):
    """Print the three BPE tracking structures in a human-readable form.

    Tokens are shown as tuples of `bytes` and pairs as `first|second`; every
    section is followed by a blank line.
    """
    # Frequency table: token -> count
    print('freq_table:')
    print({tuple(to_bytes(token)): count for token, count in freq_table.items()})
    print()
    # Pair counts: pair -> count
    print('pair_counts:')
    print({merge_bytes(pair, b"|"): count for pair, count in pair_counts.items()})
    print()
    # Pair locations: pair -> tokens that contain it
    print('pair_locations:')
    readable_locations = {}
    for pair, tokens in pair_locations.items():
        readable_locations[merge_bytes(pair, b"|")] = [tuple(to_bytes(token)) for token in tokens]
    print(str(readable_locations))
    print()

In [54]:
def train_bpe_optimized_debug(
    vocab_src: dict[int, bytes],
    freq_table_src: dict[tuple[bytes], int],
    num_merges: int,
    vocab_size: int = 10_000,
) -> BPETokenizerParams:
    """Optimized BPE training with incremental pair frequency updates.

    This is the implementation for debug purposes: it prints the tracking
    structures before every merge and every token rewrite as it happens.

    Args:
        vocab_src: Initial vocabulary (shallow-copied; the argument is not modified)
        freq_table_src: Pre-tokenized vocabulary with frequencies: {(b'l',b'o',b'w'): 5, ...}
            (shallow-copied; the argument is not modified)
        num_merges: Maximum number of merge operations to perform
        vocab_size: Training stops before a merge once the vocabulary already
            holds this many entries

    Returns:
        BPETokenizerParams consisting of learned vocab and list of merge operations
            in order they were learned
    """
    vocab = vocab_src.copy()
    freq_table = freq_table_src.copy()
    # Init data structures
    merges = []
    pair_counts = defaultdict(int)     # pair -> occurrences weighted by token frequency
    pair_locations = defaultdict(set)  # pair -> set of tokens that contain the pair
    # Precompute all initial pairs and their locations
    for token, freq in freq_table.items():
        for pair in zip(token, token[1:]):
            pair_counts[pair] += freq
            pair_locations[pair].add(token)

    for merge_idx in range(num_merges):
        
        if len(pair_counts) == 0 or len(vocab) >= vocab_size:
            # stop if all possible pairs were merged or met vocab_size threshold
            msgs = ['All possible pairs were merged', 'Met vocab_size threshold']
            msg = msgs[0] if len(pair_counts) == 0 else msgs[1]
            print(f'\n>>> {msg}')
            break
        
        print_debug_msg(f'>>> Merge ({merge_idx+1})')
        print_debug_structs(freq_table, pair_counts, pair_locations)
        
        # TODO: Could use heap instead of max
        # Highest count wins; ties go to the pair whose "first|second" byte
        # string compares greatest.
        pair_to_merge = max(pair_counts.items(), key=lambda p_cnt: (p_cnt[1], merge_bytes(p_cnt[0], b"|")))[0]

        print(f'pair_to_merge: {merge_bytes(pair_to_merge, b"|")}')
    
        # Update structures
        merges.append(tuple(to_bytes(pair_to_merge)))
        merged = merge_bytes(pair_to_merge)  # bytes of the new symbol
        vocab[len(vocab)] = merged           # the new symbol takes the next free id
        first, second = pair_to_merge
        print(f'merged_pair: {merged}')
    
        print()
        
        # Update frequency table and tracking structures
        # Iterate over a copy: pair_locations is modified inside the loop.
        affected_tokens = pair_locations[pair_to_merge].copy()
        pairs_to_remove = set()  # Non existing pairs after merge
        for old_token in affected_tokens:
            print(f'old_token: {tuple(to_bytes(old_token))}')
            # Merge and obtain new_token
            # (left-to-right scan; each non-overlapping occurrence of the pair
            # is replaced by `merged`)
            new_token = []
            pi = 0
            while pi < len(old_token):
                if pi < len(old_token)-1 and old_token[pi] == first and old_token[pi+1] == second:
                    new_token.append(merged)
                    pi += 2
                else:
                    new_token.append(old_token[pi])
                    pi += 1

            new_token = tuple(new_token)
            print(f'  new_token: {tuple(to_bytes(new_token))}')
            
            # Initialize frequency table for new token
            freq_table[new_token] = freq_table[old_token]
            # Remove old token from frequency table
            del freq_table[old_token]

            # Update pairs_count and pair_locations for new token
            # For every pair of the new token, `rem_pair` is the pair that sat
            # at the same place in the old token: its count is decreased and
            # the new pair's count is increased. For a pair the merge did not
            # touch both are the same pair, so the count is unchanged.
            for new_pair in zip(new_token, new_token[1:]):
                if new_pair[0] == merged and new_pair[1] == merged:
                    # merged pairs are joint: (re, re)
                    rem_pair = (second, first)
                elif new_pair[0] == merged or new_pair[1] == merged:
                    # pairs that intersect with merged: (re, a) | (a, re)
                    rem_pair = (second, new_pair[1]) if new_pair[0] == merged else (new_pair[0], first)
                else:
                    # old pairs that don't intersect with merged: (a, b)
                    rem_pair = new_pair
                pair_counts[rem_pair] -= freq_table[new_token]
                pair_counts[new_pair] += freq_table[new_token]
                if pair_counts[rem_pair] == 0:
                    # Deletion is deferred until all affected tokens are processed.
                    pairs_to_remove.add(rem_pair)
                pair_locations[new_pair].add(new_token)
                if old_token in pair_locations[new_pair]:
                    # basically "if old pair"
                    pair_locations[new_pair].remove(old_token)
                if old_token in pair_locations[rem_pair]:
                    pair_locations[rem_pair].remove(old_token)

        # Remove non existing pairs
        for rem_pair in pairs_to_remove:
            del pair_counts[rem_pair]
            del pair_locations[rem_pair]
        # The merged pair itself is dropped from both tracking structures.
        del pair_counts[pair_to_merge]
        del pair_locations[pair_to_merge]
    
    res = BPETokenizerParams(vocab, merges)
    return res

In [80]:
sample_text = """
low low low low low
lower lower widest widest widest
newest newest newest newest newest newest
ererer
""".replace('\n', ' ').strip()

t_res_dummy = pretokenize_dummy(sample_text)

# _t_bpe_res = train_bpe_optimized_debug(vocabulary, t_res_dummy, num_merges=20)
_t_bpe_res = custom_bpe.train_bpe_optimized(vocabulary, t_res_dummy, num_merges=20)


>>> All possible pairs were merged


In [81]:
for idx in range(256, 280):
    if idx in _t_bpe_res.vocab:
        print(f'{idx}: {_t_bpe_res.vocab[idx]}')

256: b'<|endoftext|>'
257: b'st'
258: b'est'
259: b'ow'
260: b'low'
261: b'west'
262: b'ne'
263: b'newest'
264: b'er'
265: b'wi'
266: b'wid'
267: b'widest'
268: b'lower'
269: b'erer'
270: b'ererer'


In [82]:
_t_bpe_res.merges

[(b's', b't'),
 (b'e', b'st'),
 (b'o', b'w'),
 (b'l', b'ow'),
 (b'w', b'est'),
 (b'n', b'e'),
 (b'ne', b'west'),
 (b'e', b'r'),
 (b'w', b'i'),
 (b'wi', b'd'),
 (b'wid', b'est'),
 (b'low', b'er'),
 (b'er', b'er'),
 (b'erer', b'er')]

## 2.4. Encoding-Decoding

In [83]:
def encoding_pretokenize(text: str, pat: str) -> list[tuple[bytes]]:
    """Split `text` into the matches of `pat` and return every match as a
    tuple of single-byte `bytes` objects, in order of appearance."""
    return [
        tuple(to_bytes(found.group().encode("utf-8")))
        for found in re.finditer(pat, text)
    ]

In [84]:
def merge(token_src: tuple[bytes], merges: list[tuple[bytes, bytes]]) -> tuple[bytes]:
    """Apply `merges`, in the given order, to a single pre-token.

    For every `(first, second)` entry each left-to-right, non-overlapping
    occurrence of the pair is replaced by the joined bytes. Processing stops
    early once the token consists of a single element.
    """
    current = token_src
    for first, second in merges:
        joined = merge_bytes(to_bytes((first, second)))
        rebuilt = []
        skip_next = False  # True right after a merge: the second half is already consumed
        for pos, piece in enumerate(current):
            if skip_next:
                skip_next = False
                continue
            if pos + 1 < len(current) and piece == first and current[pos + 1] == second:
                rebuilt.append(joined)
                skip_next = True
            else:
                rebuilt.append(piece)
        current = tuple(rebuilt)
        if len(current) == 1:
            break
    return current

In [85]:
class BPETokenizerCustom:
    """BPE tokenizer that pre-tokenizes with `PRETOKEN_PAT` and applies an
    ordered list of byte-pair merges to every pre-token."""

    def __init__(self, params: BPETokenizerParams):
        """Store `params` and build the reverse lookup needed by `encode`."""
        self.params = params
        # Token bytes -> token id, derived from this tokenizer's own vocabulary
        # so that encoding does not depend on any notebook-level global.
        self._vocab_inv = {token_bytes: idx for idx, token_bytes in params.vocab.items()}

    def encode(self, string: str) -> list[int]:
        """Return the token ids of `string`.

        Symbols that are missing from the vocabulary are reported as `None`
        in the result.
        """
        pre_tokens = encoding_pretokenize(string, PRETOKEN_PAT)
        encoding = []
        for pre_tok in pre_tokens:
            merged_bytes = merge(pre_tok, self.params.merges)
            encoding.extend(map(self._vocab_inv.get, merged_bytes))
        return encoding

    def decode(self, indices: list[int]) -> str:
        """Concatenate the bytes of every id and decode them as UTF-8.

        Invalid byte sequences are replaced with U+FFFD instead of raising.
        """
        merged_bytes = merge_bytes(map(self.params.vocab.get, indices))
        string = merged_bytes.decode("utf-8", errors='replace')
        return string

In [86]:
sample_vocab = {
    0: b' ', 1: b'a', 2: b'c', 3: b'e', 4: b'h', 5: b't',
    6: b'th', 7: b' c', 8: b' a', 9: b'the', 10: b' at',
}

sample_vocab_inv = {byte: idx for idx, byte in sample_vocab.items()}

sample_merges = [
    (b't', b'h'), (b' ', b'c'), (b' ', b'a'),
    (b'th', b'e'), (b' a', b't'),
]

In [87]:
sample_text = "the cat ate"

bpe_encode_params = BPETokenizerParams(vocab=sample_vocab, merges=sample_merges)
bpe_encode = BPETokenizerCustom(bpe_encode_params)

_t_enc = bpe_encode.encode(sample_text)
_t_enc

[9, 7, 1, 5, 10, 3]

In [88]:
_t_dec = bpe_encode.decode(_t_enc)
_t_dec

'the cat ate'

## 2.5. Assignment Format

To test your BPE training function against our provided tests, you will first need to implement the test adapter at `[adapters.run_train_bpe]`. Then, run `uv run pytest tests/test_train_bpe.py`. Your implementation should be able to pass all tests.

In [89]:
DATA_ROOT_P = '/mnt/data/DatasetsML/NLP/natural_language_corpus/tiny_stories'

TRAIN_P = os.path.join(DATA_ROOT_P, 'TinyStoriesV2-GPT4-train.txt')
VALID_P = os.path.join(DATA_ROOT_P, 'TinyStoriesV2-GPT4-valid.txt')

In [90]:
N_PROC = 8
TOKEN_SPLIT = "<|endoftext|>".encode("utf-8")
PRETOKEN_PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

N_MAX_MERGES = 10_000


@timer
def train_bpe(input_path: str, vocab_size: int, special_tokens: list[str]):
    """Train a byte-level BPE tokenizer on the text file at `input_path`.

    Args:
        input_path: Path to the training text file.
        vocab_size: Maximum final vocabulary size (initial bytes, special
            tokens and merged items together).
        special_tokens: Strings added to the vocabulary up front.

    Returns:
        The result of `custom_bpe.train_bpe_optimized`.
    """
    # 1. Vocabulary initialization
    vocabulary = init_vocabulary(special_tokens)
    # 2. Pre-tokenization of the requested file (not a hard-coded path)
    pre_tok_params = PreTokenizerArgs(N_PROC, TOKEN_SPLIT, special_tokens, PRETOKEN_PAT)
    pre_tok_res = pretokenize_file_parallel(input_path, pre_tok_params)
    # 3. Compute BPE merges / Train BPE
    # Never request more merges than fit into the vocabulary budget.
    num_merges = max(0, min(N_MAX_MERGES, vocab_size - len(vocabulary)))
    bpe_trained_res = custom_bpe.train_bpe_optimized(vocabulary, pre_tok_res, num_merges)
    return bpe_trained_res

In [94]:
N_MAX_VOCAB = 10_000
SPECIAL_TOKENS = ['<|endoftext|>']

N_MAX_MERGES = 1000

train_bpe_res = train_bpe(VALID_P, N_MAX_VOCAB, SPECIAL_TOKENS)

Function 'train_bpe' executed in 5.884 seconds


In [95]:
len(train_bpe_res.merges)

1000

# 3. Tests

## 3.1. BPE Training

In [75]:
sample_text = """
low low low low low
lower lower widest widest widest
newest newest newest newest newest newest
start restart false
""".replace('\n', ' ').strip()

t_pretok_dummy = pretokenize_dummy(sample_text)

# `train_bpe_optimized` is only available through the `custom_bpe` module.
_t_bpe_res = custom_bpe.train_bpe_optimized(vocabulary, t_pretok_dummy, num_merges=13)

In [76]:
for idx in range(256, 300):
    if idx in _t_bpe_res.vocab:
        print(f'{idx}: {_t_bpe_res.vocab[idx]}')

256: b'<|endoftext|>'
257: b'st'
258: b'est'
259: b'ow'
260: b'low'
261: b'west'
262: b'ne'
263: b'newest'
264: b'wi'
265: b'wid'
266: b'widest'
267: b'rt'
268: b'lowe'
269: b'lower'


In [77]:
_t_bpe_res.merges

[(b's', b't'),
 (b'e', b'st'),
 (b'o', b'w'),
 (b'l', b'ow'),
 (b'w', b'est'),
 (b'n', b'e'),
 (b'ne', b'west'),
 (b'w', b'i'),
 (b'wi', b'd'),
 (b'wid', b'est'),
 (b'r', b't'),
 (b'low', b'e'),
 (b'lowe', b'r')]

In [69]:
# use tokenization/sample_text to test tokenizer
target_text = "The new model features the lowest price and the widest selection of colors in its category."