# Tokenizer
- 加载已训练好的词表 `vocab.json` 与合并规则 `merges.txt`
- `bpe` 函数：按已训练的合并规则（rank 越小越先合并）对新的字节串（bytes）进行切分
- 编码文本（string → token IDs）
- 解码 token（token IDs → string）

In [1]:
import json
import regex as re
from typing import Dict, List, Tuple, Iterable, Iterator

In [2]:
class Tokenizer:
    def __init__(
        self,
        vocab: Dict[int, bytes],
        merges: List[Tuple[bytes, bytes]],
        special_tokens: List[str] | None = None,
    ):
        self.vocab = vocab
        self.vocab_rev = {v: k for k, v in vocab.items()}
        self.special_tokens = special_tokens or []

        # 编译正则
        self.PAT = re.compile(r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

        # 预建 merge rank 表
        self.bpe_ranks = {pair: i for i, pair in enumerate(merges)}

    @classmethod
    def from_files(
        cls,
        vocab_path: str,
        merges_path: str,
        special_tokens: List[str] | None = None,
    ) -> "Tokenizer":
        with open(vocab_path, "r", encoding="utf-8") as f:
            vocab_data = json.load(f)
        vocab = {int(i): v.encode("utf-8") for i, v in vocab_data.items()}

        merges = []
        with open(merges_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                # 忽略注释和非法行
                if not line or line.startswith("#"):
                    continue
                parts = line.split()
                if len(parts) != 2:
                    continue  # 跳过不合规的行
                a, b = parts
                merges.append((a.encode(), b.encode()))

        return cls(vocab, merges, special_tokens)

    def get_pairs(self, tokens: List[bytes]): # 取相邻 token pair
        return {(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)}

    def bpe(self, token: bytes) -> List[bytes]: # 执行 Byte-Pair Encoding
        word = [bytes([b]) for b in token]  # byte-level 初始分词
        pairs = self.get_pairs(word)

        while pairs:# 找到当前可合并的最小 rank(最高频率)
            min_pair = min(
                pairs, key=lambda p: self.bpe_ranks.get(p, float("inf"))
            )
            if min_pair not in self.bpe_ranks:
                break

            new_word = []
            i = 0
            while i < len(word):
                if (
                    i < len(word) - 1
                    and word[i] == min_pair[0]
                    and word[i + 1] == min_pair[1]
                ):
                    new_word.append(word[i] + word[i + 1])
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = new_word
            pairs = self.get_pairs(word)

        return word
        
    def encode(self, text: str) -> List[int]:
        """Encode text → token ids"""
        ids = []
        for match in self.PAT.finditer(text):
            token = match.group(0).encode("utf-8")
            for t in self.bpe(token):
                if t in self.vocab_rev:
                    ids.append(self.vocab_rev[t])
        return ids

    # 流式编码
    def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
        for line in iterable: # 逐行读取
            for tid in self.encode(line):
                yield tid # 对每一行中的内容诸葛返回 token id

    def decode(self, ids: List[int]) -> str:
        byte_stream = b"".join(
            self.vocab.get(i, b"\xef\xbf\xbd") for i in ids
        )
        return byte_stream.decode("utf-8", errors="replace")

    def token_to_id(self, token: str) -> int: # 将特殊token或普通token字符串 → 对应ID
        b = token.encode("utf-8")
        return self.vocab_rev.get(b, -1)

    def id_to_token(self, idx: int) -> str: # 将ID → 对应token字符串
        return self.vocab.get(idx, b"<?>").decode("utf-8", errors="replace")

In [3]:
import os, json, numpy as np
from pathlib import Path
from tqdm import tqdm
from multiprocessing import Pool, cpu_count

def init_worker(vocab_path, merges_path, special_tokens):
    """Pool initializer: load a Tokenizer from disk once per worker process.

    The instance is stored in the module-level global ``_tokenizer`` so that
    ``encode_line`` can reuse it for every task handled by this process.
    """
    global _tokenizer
    loaded = Tokenizer.from_files(vocab_path, merges_path, special_tokens)
    _tokenizer = loaded

def encode_line(line: str):
    """Worker task: encode one line with this process's ``_tokenizer``.

    Lines containing only whitespace produce an empty list without touching
    the tokenizer.
    """
    if line.strip():
        return _tokenizer.encode(line)
    return []

In [4]:
def parallel_encode_file_streaming(
    input_txt_path: str,
    output_tokens_path: str,
    vocab_path: str = "./bpeModel/vocab.json",
    merges_path: str = "./bpeModel/merges.txt",
    special_tokens=("<|endoftext|>",),
    num_workers: int = max(1, cpu_count() - 4),
    chunk_size: int = 10000,  # number of lines read per batch
    dtype=np.uint16,
):
    """Tokenize a text file with a process pool and save the ids as a ``.npy`` array.

    The input is read in batches of ``chunk_size`` lines; whitespace-only
    lines are skipped. Each worker builds its own tokenizer via
    ``init_worker`` and encodes lines with ``encode_line``. Ids are kept in
    file order, concatenated once at the end, written to a temporary file and
    then atomically moved into place, so an interrupted run never leaves a
    half-written output file.

    Args:
        input_txt_path: UTF-8 text file to encode.
        output_tokens_path: destination file; ``.npy`` is appended if missing
            (same rule as ``np.save``). Parent directories are created.
        vocab_path: tokenizer vocab file, forwarded to ``init_worker``.
        merges_path: tokenizer merges file, forwarded to ``init_worker``.
        special_tokens: forwarded to the tokenizer; ``None`` means none.
        num_workers: number of worker processes.
        chunk_size: lines read per batch; must be >= 1.
        dtype: integer dtype of the saved array; every id must fit in it.

    Raises:
        ValueError: if ``chunk_size < 1`` or a token id does not fit ``dtype``.
    """
    from itertools import chain, islice

    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    print(f"Using {num_workers} processes to encode {input_txt_path}")
    print(f"Streaming mode: reading {chunk_size} lines per batch")

    input_txt_path = Path(input_txt_path)
    output_tokens_path = Path(output_tokens_path)
    os.makedirs(output_tokens_path.parent, exist_ok=True)

    info = np.iinfo(dtype)
    worker_special = None if special_tokens is None else list(special_tokens)

    chunks = []  # one array per batch; concatenated once at the end (linear cost)
    total_tokens = 0

    with Pool(
        processes=num_workers,
        initializer=init_worker,
        initargs=(vocab_path, merges_path, worker_special),
    ) as pool, open(input_txt_path, "r", encoding="utf-8") as f:

        while True:
            raw_lines = list(islice(f, chunk_size))
            # Stop only at real EOF: a batch made entirely of blank lines must
            # not end the loop, or the rest of the file would be dropped.
            if not raw_lines:
                break
            lines = [line for line in raw_lines if line.strip()]

            ids = list(chain.from_iterable(pool.imap(encode_line, lines, chunksize=128)))
            if ids:
                arr = np.asarray(ids, dtype=np.int64)
                # Guard against silent wrap-around when casting to a narrow dtype.
                if arr.min() < info.min or arr.max() > info.max:
                    raise ValueError(
                        f"token id out of range for {np.dtype(dtype).name}: "
                        f"[{arr.min()}, {arr.max()}]"
                    )
                chunks.append(arr.astype(dtype))
                total_tokens += int(arr.size)

            print(f"-> processed {total_tokens:,} tokens so far...")

    tokens = np.concatenate(chunks) if chunks else np.empty(0, dtype=dtype)
    chunks.clear()

    final_path = output_tokens_path
    if not str(final_path).endswith(".npy"):
        final_path = Path(f"{final_path}.npy")
    tmp_path = final_path.with_suffix(".tmp.npy")  # temp file protects the final output
    try:
        with open(tmp_path, "wb") as fh:
            np.save(fh, tokens)
        os.replace(tmp_path, final_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise

    print(f"Saved {total_tokens:,} tokens → {output_tokens_path}")

In [8]:
# parallel_encode_file_streaming(
#     input_txt_path="./datasets/TinyStories/train_with_eot.txt",
#     output_tokens_path="./datasets/tokens_train.npy",
# )

   Using 10 processes to encode ./datasets/TinyStories/train_with_eot.txt
   Streaming mode: reading 10000 lines per batch
  → processed 3,470,426 tokens so far...
  → processed 7,178,279 tokens so far...
  → processed 10,685,225 tokens so far...
  → processed 14,427,430 tokens so far...
  → processed 17,848,816 tokens so far...
  → processed 21,445,993 tokens so far...
  → processed 25,042,640 tokens so far...
  → processed 28,454,082 tokens so far...
  → processed 31,855,955 tokens so far...
  → processed 35,347,123 tokens so far...
  → processed 38,980,635 tokens so far...
  → processed 42,701,691 tokens so far...
  → processed 46,289,058 tokens so far...
  → processed 50,039,497 tokens so far...
  → processed 53,734,376 tokens so far...
  → processed 57,509,188 tokens so far...
  → processed 61,440,786 tokens so far...
  → processed 64,976,791 tokens so far...
  → processed 68,584,368 tokens so far...
  → processed 72,320,727 tokens so far...
  → processed 76,007,961 tokens so far.

In [5]:
# parallel_encode_file_streaming(
#     input_txt_path="./datasets/TinyStories/valid_with_eot.txt",
#     output_tokens_path="./datasets/tokens_valid.npy",
# )

   Using 12 processes to encode ./datasets/TinyStories/valid_with_eot.txt
   Streaming mode: reading 10000 lines per batch
  -> processed 3,298,494 tokens so far...
  -> processed 6,922,095 tokens so far...
  -> processed 7,670,696 tokens so far...
    Saved 7,670,696 tokens → datasets/tokens_valid.npy
