In [17]:
# bpe_trainer.py
# Added incremental (delta) updates for pair frequencies, written with ChatGPT's help.

from __future__ import annotations

import json
import os
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor
from typing import Iterable, List, Tuple, Dict, BinaryIO
import heapq

import regex as re  # supports \p{L}, \p{N}, etc.

# ---------- Shared pattern ----------
# Pre-tokenization regex. Alternatives are tried left to right:
#   '(?:[sdmt]|ll|ve|re)   apostrophe suffixes: 's 'd 'm 't 'll 've 're
#    ?\p{L}+               a run of letters, with an optional leading space
#    ?\p{N}+               a run of digits, with an optional leading space
#    ?[^\s\p{L}\p{N}]+     a run of other non-space characters, optional leading space
#   \s+(?!\S)              whitespace that is not followed by a non-space character
#   \s+                    any remaining whitespace
# Requires the third-party `regex` module (imported as `re`) for \p{...} classes.
PAT = re.compile(
    r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)

class BPETrainer:
    """
    Owns corpus scanning, pre-tokenization, stats, and merge learning.
    Trains byte-level BPE and can export (vocab, merges).
    Uses delta updates for pair frequencies (fast).
    """

    def __init__(self, byte_level: bool = True, reference_mode: bool = True):
        self.PAT = PAT
        self.special_tokens: List[str] = ["<|endoftext|>"]
        self.byte_level = byte_level

        # Training state
        self.word_counts: Counter[bytes] = Counter()           # token (bytes) -> count
        self.segmented_words: Dict[bytes, List[bytes]] = {}    # token -> list of byte symbols
        self.vocab_index: Dict[bytes, int] = {}                # symbol -> stable id (non-special)
        self.merges: List[Tuple[bytes, bytes]] = []

        # Pair-tracking (delta update machinery)
        self.pair_freq: Counter[Tuple[bytes, bytes]] = Counter()
        self.pair_occ: Dict[Tuple[bytes, bytes], set[bytes]] = defaultdict(set)
        self.pair_heap: List[Tuple[int, Tuple[bytes, bytes]]] = []  # (-freq, pair)

        # Dirty flag (only for initial build)
        self._pairs_dirty = True

        # Reference mode (for testing)
        # at no point I tested correcness of my implementation against the original BPETrainer
        # this is for debugging
        self.reference_mode = reference_mode

    # ---------- Public pipeline (serial pretokenization) ----------

    def pretokenize(self, corpus: Iterable[str]) -> "BPETrainer":
        """
        Serial pretokenization: build word_counts (bytes -> freq) skipping special tokens.
        """
        wc: Counter[bytes] = Counter()
        specials_s = set(self.special_tokens)

        for line in corpus:
            line = line.rstrip("\n")
            # remove *all* occurrences of each special, not just full-line matches
            for s in specials_s:
                line = line.replace(s, " ")

            for m in self.PAT.finditer(line):
                tok = m.group(0)
                if not tok:
                    continue
                wc[tok.encode("utf-8")] += 1

        self._init_state_from_counts(wc)
        return self

    def compute_pair_stats(self) -> "BPETrainer":
        """
        One-time build of pair frequencies, occurrences, and heap from current segmentations.
        """
        import heapq

        self.pair_freq.clear()
        self.pair_occ.clear()
        self.pair_heap.clear()

        for w, freq in self.word_counts.items():
            seq = self.segmented_words[w]
            for a, b in zip(seq, seq[1:]):
                p = (a, b)
                self.pair_freq[p] += freq
                self.pair_occ[p].add(w)

        for p, f in self.pair_freq.items():
            if f > 0:
                heapq.heappush(self.pair_heap, (-f, p))

        self._pairs_dirty = False
        return self

    def fit_to_vocab_size(
        self, vocab_size: int, special_tokens: List[str], progress: bool = True
    ) -> "BPETrainer":
        """
        Greedy BPE loop using delta updates; stops when (256 base + merges) + specials = vocab_size.
        """

        target_non_special = max(0, vocab_size - len(special_tokens))

        if self._pairs_dirty or not self.pair_heap:
            self.compute_pair_stats()

        while len(self.vocab_index) < target_non_special and self.pair_heap:
            best_pair = None
            # Lazy-pop until top of heap matches current freq (skip stale)
            while self.pair_heap:
                negf, p = heapq.heappop(self.pair_heap)
                f = -negf
                if self.pair_freq.get(p, 0) == f and f > 0:
                    best_pair = p
                    break
            if best_pair is None:
                break

            self._apply_merge_delta(best_pair)
            self.merges.append(best_pair)
            print(f"Learned merge: {best_pair[0]} + {best_pair[1]})")

            if progress and (len(self.merges) % 1000 == 0):
                pass  # add logging if desired

        return self

    def export_vocab_and_merges(
        self, special_tokens: List[str], vocab_size: int
    ) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
        """
        Build final vocab (IDs -> bytes) with special tokens first, then symbols by stable index.
        Caps size at vocab_size.
        """
        vocab: Dict[int, bytes] = {i: s.encode("utf-8") for i, s in enumerate(special_tokens)}
        offset = len(special_tokens)

        for sym, idx in sorted(self.vocab_index.items(), key=lambda kv: kv[1]):
            tid = offset + idx
            if tid >= vocab_size:
                break
            vocab[tid] = sym

        return vocab, list(self.merges)

    # ---------- Private helpers ----------

    def _init_state_from_counts(self, wc: Counter[bytes]) -> None:
        """
        Initialize trainer state from precomputed word_counts (bytes -> freq).
        """
        self.word_counts = wc
        # Token -> sequence of single-byte symbols
        self.segmented_words = {w: [bytes([b]) for b in w] for w in wc}
        # Base 256-byte vocab
        base_symbols = [bytes([b]) for b in range(256)]
        self.vocab_index = {sym: i for i, sym in enumerate(base_symbols)}
        self._pairs_dirty = True

    def _apply_merge_delta(self, pair: Tuple[bytes, bytes]) -> None:
        """
        Merge (a,b)->c with incremental updates to pair_freq, pair_occ, and heap.
        Only words containing the pair are touched.
        """

        a, b = pair
        merged = a + b
        if merged not in self.vocab_index:
            self.vocab_index[merged] = len(self.vocab_index)

        affected = list(self.pair_occ.get(pair, set()))
        self.pair_occ[pair].clear()

        for w in affected:
            seq = self.segmented_words[w]
            if not seq:
                continue

            # Old pairs for this word
            old_pairs = list(zip(seq, seq[1:]))

            # Apply merge locally
            i = 0
            new_seq: List[bytes] = []
            changed = False
            while i < len(seq):
                if i < len(seq) - 1 and seq[i] == a and seq[i + 1] == b:
                    new_seq.append(merged)
                    i += 2
                    changed = True
                else:
                    new_seq.append(seq[i])
                    i += 1

            if not changed:
                continue  # stale occurrence

            # New pairs
            new_pairs = list(zip(new_seq, new_seq[1:]))
            freq_w = self.word_counts[w]

            # Decrement old contributions
            for p in old_pairs:
                self.pair_freq[p] -= freq_w
                if self.pair_freq[p] <= 0:
                    self.pair_freq[p] = 0
                    # occ cleanup is lazy

            # Increment new contributions, update occ and heap lazily
            for p in new_pairs:
                self.pair_freq[p] += freq_w
                if self.pair_freq[p] > 0:
                    self.pair_occ[p].add(w)
                    heapq.heappush(self.pair_heap, (-self.pair_freq[p], p))

            # Save updated segmentation
            self.segmented_words[w] = new_seq



# -------- Parallel pretokenization helpers (optional) --------

def find_chunk_boundaries(
    file: BinaryIO, desired_num_chunks: int, split_special_token: bytes
) -> list[int]:
    """
    Split a binary file into chunks whose interior boundaries sit on a delimiter.

    Boundaries start evenly spaced; each interior one is then moved forward to
    the start of the next occurrence of ``split_special_token`` (or to the end
    of the file if there is none). Duplicate boundaries are collapsed, so fewer
    than ``desired_num_chunks`` chunks may be returned.

    Args:
        file: Seekable binary file object.
        desired_num_chunks: Upper bound on the number of chunks; must be >= 1.
        split_special_token: Delimiter to align boundaries on.

    Returns:
        Sorted, de-duplicated byte offsets.

    Raises:
        ValueError: If ``desired_num_chunks`` is less than 1.
    """
    assert isinstance(split_special_token, bytes)
    if desired_num_chunks < 1:
        raise ValueError("desired_num_chunks must be at least 1")

    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = max(1, file_size // desired_num_chunks)
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 1 << 20  # 1 MiB probe for fewer syscalls
    # Bytes carried over between probes so a delimiter that straddles two
    # reads is still found.
    overlap = len(split_special_token) - 1

    for bi in range(1, len(chunk_boundaries) - 1):
        pos = chunk_boundaries[bi]  # file offset of the next unread byte
        file.seek(pos)
        carry = b""
        while True:
            buf = file.read(mini_chunk_size)
            if not buf:
                chunk_boundaries[bi] = file_size
                break
            window = carry + buf
            j = window.find(split_special_token)
            if j != -1:
                # window starts len(carry) bytes before pos.
                chunk_boundaries[bi] = pos - len(carry) + j
                break
            pos += len(buf)
            carry = window[-overlap:] if overlap > 0 else b""

    return sorted(set(chunk_boundaries))


def _count_slice_from_file(path: str, start: int, end: int, specials_s: set[str]) -> Counter[str]:
    """
    Worker: tally pretokens (as strings) found in bytes ``[start, end)`` of ``path``.

    The slice is decoded as UTF-8 with undecodable bytes dropped, each special
    token is replaced by a space, and every non-empty ``PAT`` match is counted.
    Keys stay as ``str``; the caller encodes them once.
    """
    with open(path, "rb") as handle:
        handle.seek(start)
        raw = handle.read(end - start)

    text = raw.decode("utf-8", errors="ignore")
    for special in specials_s:
        text = text.replace(special, " ")

    pieces = (match.group(0) for match in PAT.finditer(text))
    return Counter(piece for piece in pieces if piece)


def parallel_counts_from_boundaries(
    input_path: str, boundaries: list[int], special_tokens: List[str], max_workers: int
) -> Counter[bytes]:
    """
    Count pretokens for every ``[boundaries[i], boundaries[i+1])`` slice in
    worker processes and merge the per-slice tallies.

    Returns the merged counts keyed by the UTF-8 encoding of each pretoken.
    """
    special_set = set(special_tokens)
    spans = list(zip(boundaries[:-1], boundaries[1:]))

    merged_counts = Counter()
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        pending = [
            pool.submit(_count_slice_from_file, input_path, lo, hi, special_set)
            for lo, hi in spans
        ]
        # Collect in submission order.
        for job in pending:
            merged_counts.update(job.result())

    # Encode each distinct key once rather than once per occurrence.
    as_bytes = Counter()
    for token, count in merged_counts.items():
        as_bytes[token.encode("utf-8")] = count
    return as_bytes


# -------- Public training function (deliverable) --------

def train_bpe_tokenizer(
    input_path: str,
    vocab_size: int,
    special_tokens: List[str],
    *,
    num_processes: int | None = None,
    delimiter: bytes = b"<|endoftext|>",
    parallel: bool = False,
) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
    """
    Trains a byte-level BPE tokenizer and returns (vocab_id_to_bytes, merges).

    - Ensures base 256-byte alphabet + specials are included within vocab_size.
    - If parallel=True, uses chunked pretokenization; otherwise serial.
    """
    required_min = len(special_tokens) + 256
    if vocab_size < required_min:
        raise ValueError(
            f"vocab_size must be at least {required_min} "
            f"(= {len(special_tokens)} specials + 256 base bytes)"
        )

    trainer = BPETrainer(byte_level=True)
    trainer.special_tokens = list(special_tokens)

    if parallel:
        if num_processes is None:
            num_processes = os.cpu_count() or 4
        with open(input_path, "rb") as f:
            boundaries = find_chunk_boundaries(f, num_processes, delimiter)
        wc = parallel_counts_from_boundaries(input_path, boundaries, special_tokens, num_processes)
        trainer._init_state_from_counts(wc)
    else:
        with open(input_path, "r", encoding="utf-8") as f:
            trainer.pretokenize(f)

    trainer.compute_pair_stats().fit_to_vocab_size(vocab_size, special_tokens)
    return trainer.export_vocab_and_merges(special_tokens, vocab_size)


# ---- Example main (optional) ----
if __name__ == "__main__":
    # Example run on the serial path: with parallel=False, num_processes is
    # accepted but not used by train_bpe_tokenizer.
    vocab, merges = train_bpe_tokenizer(
        input_path="../data/corpus.en",
        vocab_size=300,
        special_tokens=["<|endoftext|>"],
        parallel=False,
        num_processes=4,
    )
    # vocab maps token id -> bytes; merges is the ordered list of learned pairs.
    print(f"Returned vocab size: {len(vocab)} | Merges learned: {len(merges)}")


Learned merge: b' ' + b't')
Learned merge: b' ' + b'a')
Learned merge: b'h' + b'e')
Learned merge: b'i' + b'n')
Learned merge: b' t' + b'he')
Learned merge: b'r' + b'e')
Learned merge: b' ' + b'o')
Learned merge: b' ' + b',')
Learned merge: b'e' + b'r')
Learned merge: b' ' + b's')
Learned merge: b'a' + b't')
Learned merge: b' ' + b'.')
Learned merge: b'n' + b'd')
Learned merge: b'i' + b's')
Learned merge: b'o' + b'r')
Learned merge: b' ' + b'w')
Learned merge: b' ' + b'c')
Learned merge: b'o' + b'n')
Learned merge: b' ' + b'b')
Learned merge: b' ' + b'f')
Learned merge: b'o' + b'u')
Learned merge: b'i' + b't')
Learned merge: b'e' + b'n')
Learned merge: b'e' + b's')
Learned merge: b' o' + b'f')
Learned merge: b' ' + b'p')
Learned merge: b'in' + b'g')
Learned merge: b' ' + b'in')
Learned merge: b'e' + b'd')
Learned merge: b'a' + b'l')
Learned merge: b' ' + b'm')
Learned merge: b' ' + b'd')
Learned merge: b' a' + b'nd')
Learned merge: b'a' + b'n')
Learned merge: b'a' + b'r')
Learned merge

In [19]:
# bpe_trainer.py
# Added incremental (delta) updates for pair frequencies, written with ChatGPT's help.

# bpe_trainer.py
from __future__ import annotations

import json
import os
from collections import Counter
from concurrent.futures import ProcessPoolExecutor
from typing import Iterable, List, Tuple, Dict, BinaryIO
from collections import defaultdict

import regex as re  # supports \p{L}, \p{N}, etc.

# ---------- Shared pattern ----------
# Pre-tokenization regex. Alternatives are tried left to right:
#   '(?:[sdmt]|ll|ve|re)   apostrophe suffixes: 's 'd 'm 't 'll 've 're
#    ?\p{L}+               a run of letters, with an optional leading space
#    ?\p{N}+               a run of digits, with an optional leading space
#    ?[^\s\p{L}\p{N}]+     a run of other non-space characters, optional leading space
#   \s+(?!\S)              whitespace that is not followed by a non-space character
#   \s+                    any remaining whitespace
# Requires the third-party `regex` module (imported as `re`) for \p{...} classes.
PAT = re.compile(
    r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)


class BPETrainer:
    """
    Byte-level BPE trainer with deterministic best-pair selection.

    Owns pre-tokenization, pair statistics, and merge learning, and can export
    ``(vocab, merges)``. Pair statistics are built once and then updated
    incrementally: each merge only revisits the words that contain the merged
    pair. The best pair is found by scanning ``pair_freq``; among pairs with
    the highest frequency the lexicographically greatest pair wins.
    """

    def __init__(self, byte_level: bool = True):
        self.PAT = PAT
        self.special_tokens: List[str] = ["<|endoftext|>"]
        self.byte_level = byte_level

        # Training state
        self.word_counts: Counter[bytes] = Counter()           # token (bytes) -> count
        self.segmented_words: Dict[bytes, List[bytes]] = {}    # token -> list of byte symbols
        self.vocab_index: Dict[bytes, int] = {}                # symbol -> stable id (non-special)
        self.merges: List[Tuple[bytes, bytes]] = []

        # Kept for compatibility; not read by any method of this class.
        self.pair_counts: Counter[Tuple[bytes, bytes]] = Counter()

        self.pair_freq: Counter[tuple[bytes, bytes]] = Counter()      # live pair frequencies
        self.pair_occ: dict[tuple[bytes, bytes], set[bytes]] = defaultdict(set)  # which words contain a pair

    # ---------- Public pipeline (serial pretokenization) ----------

    def pretokenize(self, corpus: Iterable[str]) -> "BPETrainer":
        """
        Serially pretokenize ``corpus`` and initialise the training state.

        Each line has its trailing newlines stripped and every occurrence of
        every special token replaced by a single space before ``PAT`` is
        applied, so special-token text does not contribute to merges. Matches
        are counted by their UTF-8 encoding.

        Returns ``self`` so calls can be chained.
        """
        wc: Counter[bytes] = Counter()
        specials_s = set(self.special_tokens)

        for line in corpus:
            line = line.rstrip("\n")
            for s in specials_s:
                line = line.replace(s, " ")

            for m in self.PAT.finditer(line):
                tok = m.group(0)
                if tok:
                    wc[tok.encode("utf-8")] += 1

        self._init_state_from_counts(wc)
        return self

    def compute_pair_stats(self) -> "BPETrainer":
        """Rebuild pair frequencies and occurrences from the current segmentations."""
        self.pair_freq.clear()
        self.pair_occ.clear()
        for w, freq in self.word_counts.items():
            seq = self.segmented_words[w]
            for p in zip(seq, seq[1:]):
                self.pair_freq[p] += freq
                self.pair_occ[p].add(w)
        return self

    def fit_to_vocab_size(self, vocab_size: int, special_tokens: list[str], progress: bool = True) -> "BPETrainer":
        """
        Greedily learn merges until the non-special vocabulary is large enough.

        Stops when ``len(vocab_index)`` reaches ``vocab_size - len(special_tokens)``
        or when no pair with a positive frequency remains.

        Args:
            vocab_size: Target total vocabulary size, specials included.
            special_tokens: Only its length is used here.
            progress: Accepted for API compatibility; currently has no effect.

        Returns ``self``.
        """
        target_non_special = max(0, vocab_size - len(special_tokens))

        # Build the initial pair stats once; merges keep them up to date.
        if not self.pair_freq:
            self.compute_pair_stats()

        while len(self.vocab_index) < target_non_special and self.pair_freq:
            # Single pass: highest frequency first, then the greater pair.
            best_pair, best_freq = max(
                self.pair_freq.items(), key=lambda kv: (kv[1], kv[0])
            )
            if best_freq <= 0:
                break

            self._apply_merge_delta_simple(best_pair)
            self.merges.append(best_pair)

        return self

    def export_vocab_and_merges(
        self, special_tokens: List[str], vocab_size: int
    ) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
        """
        Build the final ``id -> bytes`` vocabulary and return it with the merges.

        Special tokens take ids ``0..len(special_tokens)-1``; every other symbol
        gets ``len(special_tokens) + its vocab_index``. Symbols whose id would
        be ``>= vocab_size`` are dropped. The merges list is returned as a copy.
        """
        vocab: Dict[int, bytes] = {i: s.encode("utf-8") for i, s in enumerate(special_tokens)}
        offset = len(special_tokens)

        for sym, idx in sorted(self.vocab_index.items(), key=lambda kv: kv[1]):
            tid = offset + idx
            if tid >= vocab_size:
                break
            vocab[tid] = sym

        return vocab, list(self.merges)

    # ---------- Private helpers ----------

    def _init_state_from_counts(self, wc: Counter[bytes]) -> None:
        """
        Initialise trainer state from precomputed word counts (bytes -> freq).

        Every word is split into single-byte symbols and the vocabulary is
        reset to the 256 byte values (id == byte value). Pair statistics and
        learned merges from any previous corpus are discarded as well;
        otherwise the next fit would reuse stale pairs that refer to words no
        longer present in ``segmented_words``.
        """
        self.word_counts = wc
        self.segmented_words = {w: [bytes([b]) for b in w] for w in wc}
        self.vocab_index = {bytes([b]): b for b in range(256)}
        self.merges = []
        self.pair_freq = Counter()
        self.pair_occ = defaultdict(set)

    def _apply_merge_full(self, pair: Tuple[bytes, bytes]) -> None:
        """Apply merge (a,b)->a+b to all words without touching pair statistics."""
        a, b = pair
        merged = a + b
        if merged not in self.vocab_index:
            self.vocab_index[merged] = len(self.vocab_index)

        for w, seq in self.segmented_words.items():
            i = 0
            out: List[bytes] = []
            while i < len(seq):
                if i < len(seq) - 1 and seq[i] == a and seq[i + 1] == b:
                    out.append(merged)
                    i += 2
                else:
                    out.append(seq[i])
                    i += 1
            self.segmented_words[w] = out

    def _apply_merge_delta_simple(self, pair: tuple[bytes, bytes]) -> None:
        """
        Merge ``(a, b) -> a + b`` in the words that contain the pair.

        For each affected word the contributions of its old pairs are removed
        from ``pair_freq`` / ``pair_occ`` and those of its new pairs are added.
        Pairs whose frequency drops to zero are deleted from both tables.
        """
        a, b = pair
        merged = a + b
        if merged not in self.vocab_index:
            self.vocab_index[merged] = len(self.vocab_index)

        # The merged pair disappears entirely once all replacements are done.
        affected_words = list(self.pair_occ.pop(pair, ()))
        self.pair_freq.pop(pair, None)

        for w in affected_words:
            seq = self.segmented_words[w]

            # Apply the merge left to right within this word.
            new_seq: list[bytes] = []
            i = 0
            last = len(seq) - 1
            while i < len(seq):
                if i < last and seq[i] == a and seq[i + 1] == b:
                    new_seq.append(merged)
                    i += 2
                else:
                    new_seq.append(seq[i])
                    i += 1

            if len(new_seq) == len(seq):
                continue  # stale membership: nothing was merged in this word

            freq_w = self.word_counts[w]

            # Remove old contributions; w is re-added below for pairs it keeps.
            for p in zip(seq, seq[1:]):
                remaining = self.pair_freq[p] - freq_w
                if remaining > 0:
                    self.pair_freq[p] = remaining
                else:
                    self.pair_freq.pop(p, None)
                occ = self.pair_occ.get(p)
                if occ is not None:
                    occ.discard(w)
                    if not occ:
                        self.pair_occ.pop(p, None)

            # Add new contributions.
            for p in zip(new_seq, new_seq[1:]):
                self.pair_freq[p] += freq_w
                self.pair_occ.setdefault(p, set()).add(w)

            self.segmented_words[w] = new_seq
    
# -------- Parallel pretokenization helpers (optional) --------

def find_chunk_boundaries(
    file: BinaryIO, desired_num_chunks: int, split_special_token: bytes
) -> list[int]:
    """
    Split a binary file into chunks whose interior boundaries sit on a delimiter.

    Boundaries start evenly spaced; each interior one is then moved forward to
    the start of the next occurrence of ``split_special_token`` (or to the end
    of the file if there is none). Duplicate boundaries are collapsed, so fewer
    than ``desired_num_chunks`` chunks may be returned.

    Args:
        file: Seekable binary file object.
        desired_num_chunks: Upper bound on the number of chunks; must be >= 1.
        split_special_token: Delimiter to align boundaries on.

    Returns:
        Sorted, de-duplicated byte offsets.

    Raises:
        ValueError: If ``desired_num_chunks`` is less than 1.
    """
    assert isinstance(split_special_token, bytes)
    if desired_num_chunks < 1:
        raise ValueError("desired_num_chunks must be at least 1")

    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = max(1, file_size // desired_num_chunks)
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 1 << 20  # 1 MiB probe for fewer syscalls
    # Tail of the previous probe kept so that a delimiter split across two
    # reads is not missed.
    overlap = len(split_special_token) - 1

    for bi in range(1, len(chunk_boundaries) - 1):
        pos = chunk_boundaries[bi]  # file offset of the next unread byte
        file.seek(pos)
        carry = b""
        while True:
            buf = file.read(mini_chunk_size)
            if not buf:
                chunk_boundaries[bi] = file_size
                break
            window = carry + buf
            j = window.find(split_special_token)
            if j != -1:
                # window starts len(carry) bytes before pos.
                chunk_boundaries[bi] = pos - len(carry) + j
                break
            pos += len(buf)
            carry = window[-overlap:] if overlap > 0 else b""

    return sorted(set(chunk_boundaries))


def _count_slice_from_file(path: str, start: int, end: int, specials_s: set[str]) -> Counter[str]:
    """
    Worker: tally pretokens (as strings) found in bytes ``[start, end)`` of ``path``.

    The slice is decoded as UTF-8 with undecodable bytes dropped, each special
    token is replaced by a space so its fragments are never counted, and every
    non-empty ``PAT`` match is tallied. Keys stay as ``str``; the caller
    encodes them once.
    """
    with open(path, "rb") as handle:
        handle.seek(start)
        raw = handle.read(end - start)

    text = raw.decode("utf-8", errors="ignore")
    for special in specials_s:
        text = text.replace(special, " ")

    tally = Counter()
    tally.update(
        piece for piece in (match.group(0) for match in PAT.finditer(text)) if piece
    )
    return tally


def parallel_counts_from_boundaries(
    input_path: str, boundaries: list[int], special_tokens: List[str], max_workers: int
) -> Counter[bytes]:
    """
    Count pretokens for every ``[boundaries[i], boundaries[i+1])`` slice in
    worker processes and merge the per-slice tallies.

    Returns the merged counts keyed by the UTF-8 encoding of each pretoken.
    """
    special_set = set(special_tokens)
    spans = list(zip(boundaries[:-1], boundaries[1:]))

    merged_counts = Counter()
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        pending = [
            pool.submit(_count_slice_from_file, input_path, lo, hi, special_set)
            for lo, hi in spans
        ]
        # Collect in submission order.
        for job in pending:
            merged_counts.update(job.result())

    # Encode each distinct key once rather than once per occurrence.
    as_bytes = Counter()
    for token, count in merged_counts.items():
        as_bytes[token.encode("utf-8")] = count
    return as_bytes


# -------- Public training function (deliverable) --------

def train_bpe_tokenizer(
    input_path: str,
    vocab_size: int,
    special_tokens: List[str],
    *,
    num_processes: int | None = None,
    delimiter: bytes = b"<|endoftext|>",
    parallel: bool = False,
) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
    """
    Trains a byte-level BPE tokenizer and returns (vocab_id_to_bytes, merges).

    - Ensures base 256-byte alphabet + specials are included within vocab_size.
    - If parallel=True, uses chunked pretokenization; otherwise serial.
    """
    required_min = len(special_tokens) + 256
    if vocab_size < required_min:
        raise ValueError(
            f"vocab_size must be at least {required_min} "
            f"(= {len(special_tokens)} specials + 256 base bytes)"
        )

    trainer = BPETrainer(byte_level=True)
    trainer.special_tokens = list(special_tokens)

    if parallel:
        if num_processes is None:
            num_processes = os.cpu_count() or 4
        with open(input_path, "rb") as f:
            boundaries = find_chunk_boundaries(f, num_processes, delimiter)
        wc = parallel_counts_from_boundaries(input_path, boundaries, special_tokens, num_processes)
        trainer._init_state_from_counts(wc)
    else:
        with open(input_path, "r", encoding="utf-8") as f:
            trainer.pretokenize(f)

    trainer.fit_to_vocab_size(vocab_size, special_tokens)
    return trainer.export_vocab_and_merges(special_tokens, vocab_size)


# ---- Example main (optional) ----
if __name__ == "__main__":
    # Example run on the serial path: with parallel=False, num_processes is
    # accepted but not used by train_bpe_tokenizer.
    vocab, merges = train_bpe_tokenizer(
        input_path="../data/tinystories_sample_5M.txt",
        vocab_size=1000,
        special_tokens=["<|endoftext|>"],
        parallel=False,
        num_processes=4,
    )
    # vocab maps token id -> bytes; merges is the ordered list of learned pairs.
    print(f"Returned vocab size: {len(vocab)} | Merges learned: {len(merges)}")


Returned vocab size: 1000 | Merges learned: 743


In [22]:
# Every vocab entry except the special token itself.
vocabs_without_specials = list(
    filter(lambda entry: entry != b"<|endoftext|>", vocab.values())
)
# Report any remaining entry that contains the "<|" marker.
for word_bytes in vocabs_without_specials:
    if b"<|" not in word_bytes:
        continue
    print(f"Warning: Found special token in word: {word_bytes}")

In [15]:
word_bytes

b'oon'

In [16]:
vocabs_without_specials

[b'\x00',
 b'\x01',
 b'\x02',
 b'\x03',
 b'\x04',
 b'\x05',
 b'\x06',
 b'\x07',
 b'\x08',
 b'\t',
 b'\n',
 b'\x0b',
 b'\x0c',
 b'\r',
 b'\x0e',
 b'\x0f',
 b'\x10',
 b'\x11',
 b'\x12',
 b'\x13',
 b'\x14',
 b'\x15',
 b'\x16',
 b'\x17',
 b'\x18',
 b'\x19',
 b'\x1a',
 b'\x1b',
 b'\x1c',
 b'\x1d',
 b'\x1e',
 b'\x1f',
 b' ',
 b'!',
 b'"',
 b'#',
 b'$',
 b'%',
 b'&',
 b"'",
 b'(',
 b')',
 b'*',
 b'+',
 b',',
 b'-',
 b'.',
 b'/',
 b'0',
 b'1',
 b'2',
 b'3',
 b'4',
 b'5',
 b'6',
 b'7',
 b'8',
 b'9',
 b':',
 b';',
 b'<',
 b'=',
 b'>',
 b'?',
 b'@',
 b'A',
 b'B',
 b'C',
 b'D',
 b'E',
 b'F',
 b'G',
 b'H',
 b'I',
 b'J',
 b'K',
 b'L',
 b'M',
 b'N',
 b'O',
 b'P',
 b'Q',
 b'R',
 b'S',
 b'T',
 b'U',
 b'V',
 b'W',
 b'X',
 b'Y',
 b'Z',
 b'[',
 b'\\',
 b']',
 b'^',
 b'_',
 b'`',
 b'a',
 b'b',
 b'c',
 b'd',
 b'e',
 b'f',
 b'g',
 b'h',
 b'i',
 b'j',
 b'k',
 b'l',
 b'm',
 b'n',
 b'o',
 b'p',
 b'q',
 b'r',
 b's',
 b't',
 b'u',
 b'v',
 b'w',
 b'x',
 b'y',
 b'z',
 b'{',
 b'|',
 b'}',
 b'~',
 b'\x7f',
 b'\x80',


In [24]:
# input_path = FIXTURES_PATH / "tinystories_sample_5M.txt"
vocab, merges = train_bpe_tokenizer(
    input_path="../data/tinystories_sample_5M.txt",
    vocab_size=1000,
    special_tokens=["<|endoftext|>"],
)
# save vocab to disk for debugging
# with open("train-bpe-special-tokens-vocab.json", "w", encoding="utf-8") as f:
#     json.dump(vocab, f, ensure_ascii=False, indent=2)

# No vocab entry other than the special token itself may contain "<|".
vocabs_without_specials = list(
    filter(lambda entry: entry != b"<|endoftext|>", vocab.values())
)
for word_bytes in vocabs_without_specials:
    has_marker = b"<|" in word_bytes
    if has_marker:
        # Print the offending entry before the assertion stops the loop.
        print(f"Warning: Found special token in word: {word_bytes}")
    assert not has_marker

# snapshot.assert_match(
#     {
#         "vocab_keys": set(vocab.keys()),
#         "vocab_values": set(vocab.values()),
#         "merges": merges,
#     },
# )