# BPE skeleton (exercise)

Reference: [../bpe.ipynb](../bpe.ipynb). Implement the TODOs in the next cell.

**How BPE works (TLDR):**
- **merge(ids, pair, idx)**: replace all occurrences of `pair` in `ids` with `idx`. Returns list of ints.
- **get_stats(ids)**: count frequencies of adjacent pairs. Returns dict (pair -> count).
- **BPE.train(text, vocab_size)**: iterative loop — call get_stats, pick the best pair, merge it, then update merges and vocab.
- **BPE.encode(text)**: apply merges in order until no more merges can be applied. Returns list of token ids.
- **BPE.decode(ids)**: map ids to bytes via vocab, concatenate, decode to string. Returns str.

Merges must be applied in the same order as in training.

In [None]:
from typing import List, Dict, Tuple

def merge(ids: List[int], pair: Tuple[int, int], idx: int) -> List[int]:
    """TODO (exercise): replace every adjacent occurrence of ``pair`` with ``idx``.

    Intended contract: wherever ``(ids[i], ids[i+1]) == pair``, the two
    elements are replaced by the single value ``idx``; the result is
    returned as a new list of ints.

    Example:
        ids=[1, 2, 3, 1, 2, 4], pair=(1, 2), idx=7  ->  [7, 3, 7, 4]

    Raises:
        NotImplementedError: always, until the exercise is completed.
    """
    raise NotImplementedError("TODO: implement merge")

def get_stats(ids: List[int]) -> Dict[Tuple[int, int], int]:
    """TODO (exercise): count how often each adjacent pair occurs in ``ids``.

    Intended contract: return a dict mapping every pair
    ``(ids[i], ids[i+1])`` to the number of times it appears.

    Raises:
        NotImplementedError: always, until the exercise is completed.
    """
    raise NotImplementedError("TODO: implement get_stats")

class Tokenizer:
    """Base tokenizer holding a merge table and an id -> bytes vocabulary.

    ``train``, ``encode`` and ``decode`` are placeholders that raise
    ``NotImplementedError``; only the vocabulary construction is
    implemented here.
    """

    def __init__(self):
        # Merge table: (int, int) pair -> merged token id.
        self.merges = {}
        self.pattern = ""
        # The merge table is still empty at this point, so the initial
        # vocabulary contains only the 256 single-byte tokens.
        self.vocab = self._build_vocab()

    def train(self):
        """Placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def encode(self):
        """Placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def decode(self):
        """Placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def _build_vocab(self):
        """Return a dict mapping every token id to the bytes it stands for.

        Ids 0-255 map to their own single byte. Each merged id maps to the
        concatenation of the byte strings of the two ids it was built from.
        """
        table = {}
        for byte_value in range(256):
            table[byte_value] = bytes([byte_value])
        # Merges are visited in the dict's iteration order; both halves of a
        # pair must already be in the table when that merge is reached.
        for pair, token_id in self.merges.items():
            left, right = pair
            table[token_id] = table[left] + table[right]
        return table

class BPE(Tokenizer):
    """Byte-pair-encoding tokenizer skeleton.

    The three methods below are the exercise TODOs; each raises
    ``NotImplementedError`` until implemented.
    """

    def __init__(self):
        super().__init__()

    def train(self, text: str, vocab_size: int):
        """TODO (exercise): learn merges from ``text``.

        Suggested approach: start from ``ids = list(text.encode('utf-8'))``
        and run ``vocab_size - 256`` merge rounds. In each round call
        ``get_stats(ids)``, pick the best pair, and apply
        ``merge(ids, best_pair, new_idx)``, updating merges and vocab.
        Finish by setting ``self.merges`` and ``self.vocab`` (re-building
        the vocab from the merges).
        """
        raise NotImplementedError("TODO: implement BPE.train")

    def encode(self, text: str) -> List[int]:
        """TODO (exercise): turn ``text`` into a list of token ids.

        Suggested approach: start from ``ids = list(text.encode('utf-8'))``
        and apply merges in order until no more can be applied.
        """
        raise NotImplementedError("TODO: implement BPE.encode")

    def decode(self, ids: List[int]) -> str:
        """TODO (exercise): turn token ``ids`` back into a string.

        Suggested approach: map each id to bytes via ``self.vocab``,
        concatenate the pieces, and decode the result as UTF-8.
        """
        raise NotImplementedError("TODO: implement BPE.decode")

In [None]:
# Test: after implementing the TODOs above, run this cell.
# Round-trip check: train on one short sentence, encode it, decode it back,
# and confirm the decoded string equals the original.
text = "The quick brown fox jumps over the lazy dog."
print(f"Original text: {text}")

# vocab_size=300 requests 300 - 256 = 44 merges (see the BPE.train TODO).
# NOTE(review): the text above is only 44 bytes, and every merge shortens the
# id list by at least one, so at most 43 merges are possible. train() must
# stop early once no adjacent pairs remain, or it will fail on an empty
# pair table.
tokenizer = BPE()
tokenizer.train(text, vocab_size=300)

# Encode and decode with the freshly trained merges.
encoded = tokenizer.encode(text)
decoded = tokenizer.decode(encoded)
print(f"Encoded: {encoded}")
print(f"Decoded: {decoded}")
# Prints "Match: True" when the round trip is lossless.
print(f"Match: {text == decoded}")