# BPE Training Project
## Complete Implementation in English

This notebook consolidates all project files:
1. bpetraining.py
2. Tokenizer.py
3. problem4.py

## bpetraining.py

In [None]:
"""
BPE training implementation.

Further-optimized version:
- Builds initial pair -> total count and pair -> set(word_indices)
- On each merge, only touch words that contain the chosen pair:
  * remove their old pair contributions
  * apply the merge to that word
  * add new pair contributions for that word
- Avoids rebuilding global pair counts from scratch each iteration.
- Keeps same API as original: train_bpe(input_path, vocab_size, special_tokens)
"""

from collections import Counter, defaultdict
from pathlib import Path
import regex as re
from typing import Dict, List, Tuple

# GPT-2 pre-tokenization regex (as required by the spec). Alternatives, tried in order:
# English contraction suffixes ('s 'd 'm 't 'll 've 're); an optional leading space
# plus a run of letters; an optional leading space plus a run of digits; an optional
# leading space plus a run of other non-space symbols; whitespace that is not
# followed by a non-space character; and finally any remaining whitespace.
PAT = r"'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"

# Compiled once at import time. `re` here is the third-party `regex` module
# (see `import regex as re` above); the stdlib `re` has no \p{L}/\p{N} classes.
_token_re = re.compile(PAT)

def _split_on_specials(text: str, special_tokens: List[str]) -> List[str]:
    """Split text into chunks that do not cross special token boundaries.
    The special tokens themselves are kept as chunks."""
    if not special_tokens:
        return [text]
    esc = [re.escape(st) for st in special_tokens]
    splitter = re.compile("(" + "|".join(esc) + ")")
    parts = splitter.split(text)
    return [p for p in parts if p != ""]

def _pretokenize_chunk(chunk: str, special_tokens_set: set) -> List[Tuple[bytes, ...]]:
    """Pre-tokenize a chunk. Protected special tokens are returned as single tokens.
    Otherwise apply GPT-2 regex and return tuple-of-single-byte-bytes symbols for each token."""
    if chunk in special_tokens_set:
        return [(chunk.encode("utf-8"),)]

    out = []
    for m in _token_re.finditer(chunk):
        tok = m.group(0)
        if not tok:
            continue
        b = tok.encode("utf-8")
        symbols = tuple(bytes([bb]) for bb in b)
        if symbols:
            out.append(symbols)
    return out

def _merge_word(word: List[bytes], a: bytes, b: bytes, merged: bytes) -> List[bytes]:
    """Return a copy of *word* with each adjacent (a, b) replaced by *merged*.

    Scans left to right without overlaps, so e.g. (x, x) applied to [x, x, x]
    yields [xx, x].
    """
    out: List[bytes] = []
    i = 0
    n = len(word)
    while i < n:
        if i + 1 < n and word[i] == a and word[i + 1] == b:
            out.append(merged)
            i += 2
        else:
            out.append(word[i])
            i += 1
    return out


def train_bpe(input_path: str, vocab_size: int, special_tokens: List[str]):
    """Train a byte-level BPE tokenizer on a UTF-8 text file.

    Args:
      input_path: path to the training corpus (read fully into memory as UTF-8).
      vocab_size: target total vocabulary size, counting the special tokens and
        the 256 single-byte tokens. Training stops earlier if no pairs remain;
        a target smaller than the initial vocabulary performs no merges.
      special_tokens: strings that are never split or merged. Each gets one vocab
        id ahead of the byte tokens; duplicates are ignored. None means [].

    Returns:
      vocab: dict[int, bytes]  -- mapping token id -> byte sequence (as bytes)
      merges: list[tuple[bytes, bytes]] -- list of merges in creation order

    Raises:
      FileNotFoundError: if input_path does not exist.
    """
    path = Path(input_path)
    if not path.exists():
        raise FileNotFoundError(f"input file not found: {input_path}")
    text = path.read_text(encoding="utf-8")

    # Keep first-seen order so ids stay stable, but give each special token one id.
    special_tokens = list(dict.fromkeys(special_tokens or []))
    special_set = set(special_tokens)

    # Count pre-token "words" (tuples of byte symbols). Special-token chunks
    # become single-symbol words, so they never contribute pairs.
    word_freq: Counter = Counter()
    for chunk in _split_on_specials(text, special_tokens):
        word_freq.update(_pretokenize_chunk(chunk, special_set))

    # Vocab layout: special tokens, then the 256 single bytes, then merges in order.
    vocab_list: List[bytes] = [st.encode("utf-8") for st in special_tokens]
    vocab_list.extend(bytes([i]) for i in range(256))
    merges: List[Tuple[bytes, bytes]] = []

    # Mutable copy of each distinct word, aligned with its corpus frequency.
    words: List[List[bytes]] = [list(w) for w in word_freq]
    freqs: List[int] = list(word_freq.values())

    # pair -> frequency-weighted count, and pair -> indices of words containing it.
    pair_counts: Dict[Tuple[bytes, bytes], int] = defaultdict(int)
    pair_to_word_indices: Dict[Tuple[bytes, bytes], set] = defaultdict(set)

    def add_contributions(idx: int) -> None:
        """Add every adjacent pair of words[idx] (weighted by its freq) to the tables."""
        word = words[idx]
        freq = freqs[idx]
        for pair in zip(word, word[1:]):
            pair_counts[pair] += freq
            pair_to_word_indices[pair].add(idx)

    def remove_contributions(idx: int) -> None:
        """Undo add_contributions(idx), deleting entries that drop to zero."""
        word = words[idx]
        freq = freqs[idx]
        for pair in zip(word, word[1:]):
            # .get() avoids inserting into the defaultdict.
            remaining = pair_counts.get(pair, 0) - freq
            if remaining > 0:
                pair_counts[pair] = remaining
            else:
                pair_counts.pop(pair, None)
            indices = pair_to_word_indices.get(pair)
            if indices:
                indices.discard(idx)
                if not indices:
                    del pair_to_word_indices[pair]

    for idx in range(len(words)):
        add_contributions(idx)

    # Main loop: merge until the target vocab size is reached or no pairs remain.
    while len(vocab_list) < vocab_size and pair_counts:
        # Highest count wins; ties go to the lexicographically greater pair.
        best_pair, best_freq = max(pair_counts.items(), key=lambda item: (item[1], item[0]))
        if best_freq <= 0:
            break

        a, b = best_pair
        new_symbol = a + b
        merges.append(best_pair)
        vocab_list.append(new_symbol)

        # Only words containing the pair change. Iterate over a copy, because the
        # index sets are mutated while those words are rewritten.
        for idx in list(pair_to_word_indices.get(best_pair, ())):
            remove_contributions(idx)
            words[idx] = _merge_word(words[idx], a, b, new_symbol)
            # Always re-add. Even if a word unexpectedly did not contain the pair,
            # it keeps its pair counts instead of silently losing them.
            add_contributions(idx)

        # The merged pair can no longer occur; drop any leftover bookkeeping.
        pair_counts.pop(best_pair, None)
        pair_to_word_indices.pop(best_pair, None)

    vocab: Dict[int, bytes] = dict(enumerate(vocab_list))
    return vocab, merges

## Tokenizer.py

In [None]:
import re
import regex 
from typing import Dict, List, Tuple, Any, Union
import json
import pickle
import os


class OptimizedBPE:
    """Byte-level BPE tokenizer that applies a fixed vocabulary and merge list.

    Text is pre-tokenized with the GPT-2 regex, and configured special tokens are
    kept intact. Merges are applied to each pre-token with the same result as
    running them in ``merges`` list order.
    """

    # Upper bound on memoized pre-tokens, so encoding an unbounded stream of
    # distinct pre-tokens cannot grow memory without limit.
    _max_cache_size = 1 << 16

    def __init__(self, vocab_size: int = 1000,
                 special_tokens: Union[List[str], None] = None,
                 lowercase: bool = False):
        """Create an empty tokenizer; call initialize_from_existing() before use.

        Args:
            vocab_size: stored on the instance; not otherwise read by this class.
            special_tokens: strings that encode() never splits. None means [].
            lowercase: stored on the instance; not currently applied anywhere.
        """
        self.vocab_size = vocab_size
        self.lowercase = lowercase
        self.special_tokens = special_tokens or []
        self.gpt2_pattern = regex.compile(
            r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        )

        self.vocab: Dict[int, bytes] = {}  # id -> bytes
        self.merges: List[Tuple[bytes, bytes]] = []
        self.id_to_token: Dict[int, bytes] = {}
        self.token_to_id: Dict[bytes, int] = {}
        self.str_vocab: Dict[str, int] = {}
        self.id_to_str_token: Dict[int, str] = {}
        self.special_token_bytes: Dict[str, bytes] = {}
        self.special_token_ids: Dict[str, int] = {}
        # pre-token string -> token ids produced by _apply_bpe_to_pretoken
        self._cache: Dict[str, List[int]] = {}
        # merge pair -> ascending positions of that pair within self.merges
        self._merge_ranks: Dict[Tuple[bytes, bytes], List[int]] = {}

    def initialize_from_existing(self, vocab: Dict[int, bytes], merges: List[Tuple[bytes, bytes]]):
        """Load a vocabulary and merge list, and register any missing special tokens.

        The inputs are copied, so registering special tokens never mutates the
        caller's ``vocab``. A special token whose bytes are absent from ``vocab``
        is appended with id ``max(existing ids) + 1`` (0 if empty). Any state from
        a previous initialization (special-token ids, encode cache) is discarded.
        """
        # vocab and id_to_token deliberately share one dict (as before), but it is
        # a private copy rather than the caller's object.
        self.vocab = dict(vocab)
        self.id_to_token = self.vocab
        self.merges = list(merges)
        self.token_to_id = {v: k for k, v in self.vocab.items()}
        self.str_vocab = {}
        self.id_to_str_token = {}
        for token_id, token_bytes in self.vocab.items():
            try:
                token_str = token_bytes.decode('utf-8')
            except UnicodeDecodeError:
                # Partial UTF-8 sequences get a hex name for display purposes.
                token_str = token_bytes.hex()
            self.str_vocab[token_str] = token_id
            self.id_to_str_token[token_id] = token_str

        self.special_token_bytes = {}
        self.special_token_ids = {}
        for token in self.special_tokens:
            token_bytes = token.encode('utf-8')
            self.special_token_bytes[token] = token_bytes

            if token_bytes in self.token_to_id:
                self.special_token_ids[token] = self.token_to_id[token_bytes]
            else:
                new_id = max(self.vocab.keys()) + 1 if self.vocab else 0
                self.vocab[new_id] = token_bytes
                self.id_to_token[new_id] = token_bytes
                self.token_to_id[token_bytes] = new_id
                self.special_token_ids[token] = new_id
                self.str_vocab[token] = new_id
                self.id_to_str_token[new_id] = token

        # Index merges by pair so _apply_bpe_to_pretoken can jump straight to the
        # next applicable merge instead of scanning the whole list.
        self._merge_ranks = {}
        for rank, pair in enumerate(self.merges):
            self._merge_ranks.setdefault(tuple(pair), []).append(rank)
        self._cache = {}

    def encode(self, text: str) -> List[int]:
        """Encode *text* into a list of token ids.

        Occurrences of special tokens (longest first) become their registered ids.
        All other text is split with the GPT-2 pattern and BPE-merged per pre-token.
        """
        if not text:
            return []

        # Empty special tokens are skipped: an empty alternative matches everywhere.
        specials = sorted((t for t in self.special_tokens if t), key=len, reverse=True)
        if specials:
            special_pattern = re.compile("(" + "|".join(re.escape(t) for t in specials) + ")")
            parts = special_pattern.split(text)
        else:
            parts = [text]

        all_ids: List[int] = []
        for part in parts:
            if not part:
                continue
            if part in self.special_token_ids:
                all_ids.append(self.special_token_ids[part])
            else:
                for pretoken in self.gpt2_pattern.findall(part):
                    all_ids.extend(self._apply_bpe_to_pretoken(pretoken))
        return all_ids

    def tokenize(self, text: str) -> List[str]:
        """Encode *text* and return a printable string for each resulting token."""
        tokens = []
        for token_id in self.encode(text):
            if token_id in self.id_to_str_token:
                tokens.append(self.id_to_str_token[token_id])
            elif token_id in self.id_to_token:
                token_bytes = self.id_to_token[token_id]
                try:
                    tokens.append(token_bytes.decode('utf-8'))
                except UnicodeDecodeError:
                    tokens.append(token_bytes.hex())
            else:
                tokens.append(f"[UNK:{token_id}]")
        return tokens

    def _apply_bpe_to_pretoken(self, pretoken: str) -> List[int]:
        """Return the token ids for a single pre-token.

        The result equals applying every merge of ``self.merges`` in list order,
        each one left to right over the current symbols. Instead of looping over
        all merges, each step jumps to the earliest not-yet-passed merge whose pair
        is currently adjacent. Merges skipped this way would have been no-ops, so
        the result is unchanged.

        Symbols missing from the vocabulary are dropped (best effort).
        Results are memoized per pre-token string, up to ``_max_cache_size`` entries.
        """
        if not pretoken:
            return []
        cached = self._cache.get(pretoken)
        if cached is not None:
            return list(cached)

        tokens = [bytes([b]) for b in pretoken.encode('utf-8')]
        last_rank = -1  # position in self.merges of the merge applied most recently
        while len(tokens) > 1:
            best_rank = None
            for pair in zip(tokens, tokens[1:]):
                ranks = self._merge_ranks.get(pair)
                if not ranks:
                    continue
                # Earliest occurrence of this pair after the last applied merge
                # (lists are ascending; a pair rarely appears more than once).
                rank = next((r for r in ranks if r > last_rank), None)
                if rank is not None and (best_rank is None or rank < best_rank):
                    best_rank = rank
            if best_rank is None:
                break

            a, b = self.merges[best_rank]
            merged = a + b
            new_tokens: List[bytes] = []
            i = 0
            while i < len(tokens):
                if i < len(tokens) - 1 and tokens[i] == a and tokens[i + 1] == b:
                    new_tokens.append(merged)
                    i += 2
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            tokens = new_tokens
            last_rank = best_rank

        ids = [self.token_to_id[token] for token in tokens if token in self.token_to_id]
        if len(self._cache) < self._max_cache_size:
            self._cache[pretoken] = ids
        return list(ids)

    def decode(self, ids: List[int]) -> str:
        """Decode token ids back to text.

        *ids* may be nested one level deep (the shape encode_iterable returns).
        Unknown ids are skipped, and invalid UTF-8 is replaced with U+FFFD.
        """
        if not ids:
            return ""

        flat_ids: List[int] = []
        for item in ids:
            if isinstance(item, list):
                flat_ids.extend(item)
            else:
                flat_ids.append(item)

        result_bytes = b"".join(
            self.id_to_token[token_id] for token_id in flat_ids if token_id in self.id_to_token
        )
        return result_bytes.decode('utf-8', errors='replace')

    def encode_iterable(self, texts: Any) -> List[List[int]]:
        """Encode a file-like object or an iterable of strings.

        A file-like object (anything with ``read``) is read in full and encoded as
        one text, so newlines and blank lines are preserved. For any other iterable,
        each non-empty string is encoded separately and the ids are concatenated.
        Returns a single-element list that wraps all ids; decode() accepts this shape.
        """
        if hasattr(texts, 'read'):
            text_list = [texts.read()]
        else:
            text_list = texts

        all_ids: List[int] = []
        for text in text_list:
            if text:
                all_ids.extend(self.encode(text))
        return [all_ids]

    def save(self, path: str):
        """Pickle the vocab, merges and special tokens to *path*."""
        data = {
            'vocab': self.vocab,
            'merges': self.merges,
            'special_tokens': self.special_tokens,
        }
        with open(path, 'wb') as f:
            pickle.dump(data, f)

    @classmethod
    def load(cls, path: str) -> 'OptimizedBPE':
        """Rebuild a tokenizer from a file written by save().

        SECURITY: pickle.load can execute arbitrary code while unpickling; only
        load files you created or otherwise trust.
        """
        with open(path, 'rb') as f:
            data = pickle.load(f)

        bpe = cls(
            vocab_size=len(data['vocab']),
            special_tokens=data['special_tokens']
        )
        bpe.initialize_from_existing(data['vocab'], data['merges'])
        return bpe


def test_gpt2_compatibility():
    """Demo: load a tiny hand-written vocab and merge list, then print how one
    sample sentence is tokenized, encoded to ids, and decoded back.

    Note that ids 0 and 50256 both map to b'<|endoftext|>'. Bytes with no vocab
    entry are dropped by the tokenizer, so the printed output is partial.
    """
    merges = [
        (b'H', b'e'),   # He
        (b'e', b'l'),   # el
        (b'l', b'l'),   # ll
        (b'll', b'o'),  # llo
    ]
    vocab = {
        0: b'<|endoftext|>',
        82: b's',
        198: b'Hello',
        2202: b'He',
        344: b'llo',
        4776: b',',
        612: b' how',
        3932: b' are',
        50256: b'<|endoftext|>',
    }

    tokenizer = OptimizedBPE(vocab_size=len(vocab), special_tokens=['<|endoftext|>'])
    tokenizer.initialize_from_existing(vocab, merges)

    sample = "Hello, how are you?"
    print(f"text: {sample}")
    print(f"bpe: {tokenizer.tokenize(sample)}")

    encoded = tokenizer.encode(sample)
    print(f"code: {encoded}")
    print(f"decode: {tokenizer.decode(encoded)}")


# Run the tokenizer demo when this module is executed as a script.
if __name__ == "__main__":
    test_gpt2_compatibility()

## problem4.py

In [None]:
import cProfile
import os
import pstats
import time
import tracemalloc

from bpetraining import train_bpe

def main():
    """Answer Question 4: train BPE on TinyStories, report results, then profile.

    (a) Trains once under tracemalloc and reports the wall time, peak traced
        Python memory, vocab/merge counts and the longest learned token.
    (b) Trains again under cProfile, prints the top functions, and saves the
        profile to bpe_performance.prof.
    Finally writes a summary to problem4_answer.txt. If the input file is
    missing, prints diagnostics and returns early.
    """
    print("=== Question 4(a): BPE Training ===")

    # Training parameters - file is in current directory
    input_file = "TinyStoriesV2-GPT4-valid.txt"
    vocab_size = 5000
    special_tokens = ["<|endoftext|>"]

    if not os.path.exists(input_file):
        print(f"Error: File does not exist: {input_file}")
        print("Current directory:", os.getcwd())
        print("Directory contents:", os.listdir("."))
        return

    file_size = os.path.getsize(input_file) / (1024 * 1024)
    print(f"Using file: {input_file}")
    print(f"File size: {file_size:.1f} MB")
    print(f"Target vocabulary size: {vocab_size}")
    print(f"Special tokens: {special_tokens}")

    # 1. Training with monitoring. perf_counter is the right clock for durations,
    # and tracemalloc is always stopped even if training raises.
    print("\nStarting training...")
    tracemalloc.start()
    try:
        start_time = time.perf_counter()
        vocab, merges = train_bpe(input_file, vocab_size, special_tokens)
        training_time = time.perf_counter() - start_time
        _, peak_mem = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()

    # 2. Result analysis
    print("\n=== Training completed! ===")
    print(f"Training time: {training_time:.2f} seconds")
    print(f"Peak memory usage: {peak_mem / 1024 / 1024:.2f} MB")
    print(f"Vocabulary size: {len(vocab)}")
    print(f"Number of merges: {len(merges)}")

    # train_bpe appends one vocab entry per merge after the special tokens and
    # the 256 single bytes, so learned tokens are the last len(merges) ids.
    # Deriving the boundary avoids hard-coding 257, which is only correct
    # for exactly one special token.
    first_merged_id = len(vocab) - len(merges)
    max_len = 0
    longest_tokens = []
    for token_id, token_bytes in vocab.items():
        if token_id < first_merged_id:
            continue
        token_len = len(token_bytes)
        if token_len > max_len:
            max_len = token_len
            longest_tokens = [(token_id, token_bytes)]
        elif token_len == max_len:
            longest_tokens.append((token_id, token_bytes))

    longest_str = None
    print(f"\nLongest token length: {max_len} bytes")
    if longest_tokens:
        token_id, token_bytes = longest_tokens[0]
        # errors="replace" never raises, so no try/except is needed.
        longest_str = token_bytes.decode('utf-8', errors='replace')
        print(f"Example longest token (ID={token_id}): {repr(longest_str)}")
        print(f"Hexadecimal: {token_bytes.hex()}")

    # 3. Performance analysis: a separate run, so profiler overhead does not
    # distort the timing and memory numbers above.
    print("\n=== Question 4(b): Performance Analysis ===")
    print("Running performance analysis...")

    profiler = cProfile.Profile()
    profiler.enable()
    try:
        train_bpe(input_file, vocab_size, special_tokens)
    finally:
        profiler.disable()

    stats = pstats.Stats(profiler)
    print("\nTop 10 functions by cumulative time:")
    stats.strip_dirs().sort_stats('cumulative').print_stats(10)

    print("\nTop 10 functions by internal time:")
    stats.sort_stats('tottime').print_stats(10)

    profiler.dump_stats("bpe_performance.prof")
    print("\nDetailed performance analysis saved to: bpe_performance.prof")
    print("Use 'snakeviz bpe_performance.prof' for visualization")

    # Generate answer file
    with open("problem4_answer.txt", "w", encoding="utf-8") as f:
        f.write("# Question 4 Answer\n\n")
        f.write("## (a) BPE Training Results\n")
        f.write(f"- Training file: {input_file} ({file_size:.1f}MB)\n")
        f.write(f"- Target vocab_size: {vocab_size}\n")
        f.write(f"- Actual vocab_size: {len(vocab)}\n")
        f.write(f"- Training time: {training_time:.2f} seconds\n")
        f.write(f"- Peak memory usage: {peak_mem / 1024 / 1024:.2f} MB\n")
        f.write(f"- Longest token length: {max_len} bytes\n")
        if longest_str is not None:
            f.write(f"- Longest token example: {repr(longest_str)}\n")

        f.write("\n## (b) Performance Analysis\n")
        f.write("According to profiler output, the most time-consuming parts are usually:\n")
        f.write("1. Merge operations in the main loop (while current_vocab_size < vocab_size and pair_counts:)\n")
        f.write("2. Loop for finding the best pair (for pair, cnt in pair_counts.items():)\n")
        f.write("3. Updating pair_counts and pair_to_word_indices dictionaries\n")
        f.write("4. Text preprocessing (_pretokenize_chunk function)\n")

    print("\nAnswer saved to: problem4_answer.txt")

# Run the Question 4 experiment when this module is executed as a script.
if __name__ == "__main__":
    main()