# Byte Pair Encoding (BPE) Trainer
- 文本预分词 | special token 处理  
- 统计相邻 pair 频率 | 更新增量 
- 输出 vocab.json 与 merges.txt  

### 问题抽象
现有 $n$ 组( $n = 2 \times 10^6 $ , 每组长度约 $120$ ~ $150$ )变长字符串 $str_i$ ($len(str_i) \leq 20,~i \in [0, \, n)$)，总长度 $\sum_{i = 0}^n len(str_i) \leq 150 \times 10 \times 2 \times 10^6 = 3 \times 10^9$，

$vocab$ 为 $token ~ id$ 对 $Unicode ~ byte(s)$ 的映射集合（也就是 $token$ 集合），初始化为 $[0, 256)$ 对应其十六进制数所代表的 $Unicode ~ byte$ (可以理解为 $ASCII ~ plus$) 

**持续执行如下操作：**

对每个字符串 ***所有最长 $token$*** 进行两两结合统计频率：

<aside>

初始状态每个字符是一个 $token$

`word` → `wo` : 1  `or` : 1  `rd` : 1

> 什么是***最长 $token$*** ？如 `newest` ，`est` 已分配 $token ~ id$，我们仅对 `ne` `ew` `west` 处理
> 
</aside>

将频率最高的一个 $pair$ （若并列则取最高字典序）分配新的 $token ~ id$ 

比如上述 `word` 则新分配 `279` → `wo`

**截止状态：**

最终整个序列已经分配了 $token ~ id$（几乎不可能） 或 $token ~ id$ 的数目 = $vocab\_size$

$token ~ id \leq vocab\_size = 10^4$

---

#### 输入

二维字符串数组

#### 输出

$vocab$ : 从 $token ~ id$ 到 $bytes$ 的映射

$merges$ : 产生的 合并 $pair$（按创建顺序排列）

In [7]:
import os, json
import regex as re
from collections import Counter
from typing import List, Dict, Tuple

In [8]:
def pretokenize(text: str, pattern, special_pattern, special_lookup):
    tokens = []
    if special_pattern is not None: # 处理 <|endoftxt|>
        parts = re.split(special_pattern, text)
        matches = re.findall(special_pattern, text)
        for i, part in enumerate(parts):
            if part:
                tokens.extend([m.group(0) for m in re.finditer(pattern, part)])
            if i < len(matches):
                tokens.append(matches[i])
    else:
        tokens = [m.group(0) for m in re.finditer(pattern, text)]

    # 转成字节序列表 bytes([b]) 生成单个字节对象 | bytes(b) 生成整个字节序列
    corpus = [[bytes([b]) for b in token.encode("utf-8")] for token in tokens]
    return corpus, len(corpus)

def count_pairs(corpus, special_set):
    pairs = Counter()
    for word in corpus: # 词内合并
        for a, b in zip(word, word[1:]): # 两两配对
            if a not in special_set and b not in special_set:
                pairs[(a, b)] += 1
    return pairs

def decrement_pair(pair_counts, pair): # 删除操作，防止内存爆掉
    if pair_counts[pair] > 1:
        pair_counts[pair] -= 1
    else:
        del pair_counts[pair] # 释放 counts[b'x']为 0 的空间，提高排序效率 | 减少内存堆积

def apply_merge(corpus, pair_counts, merge_pair, special_set):
    a, b = merge_pair
    merged = a + b
    for word in corpus: # 寻找连续 a-b pair 对，替换成 merged
        i = 0
        while i < len(word) - 1: # 滑动窗口
            if word[i] == a and word[i + 1] == b:
                left = word[i - 1] if i > 0 else None
                right = word[i + 2] if i + 2 < len(word) else None
                if left and left not in special_set:
                    decrement_pair(pair_counts, (left, a))
                    pair_counts[(left, merged)] += 1
                if right and right not in special_set:
                    decrement_pair(pair_counts, (b, right))
                    pair_counts[(merged, right)] += 1
                word[i : i + 2] = [merged] # 把两个元素替换为一个
            i += 1
    if merge_pair in pair_counts: # merge_pair = [a, b]
        del pair_counts[merge_pair] # 释放内存

In [9]:
def train_bpe(
    input_path: str,
    vocab_size: int,
    special_tokens: List[str],
    output_dir: str = "./bpeModel",
    PAT: str = r"'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
):
    os.makedirs(output_dir, exist_ok=True)
    pattern = re.compile(PAT, re.UNICODE)

    vocab = {i: bytes([i]) for i in range(256)} # 初始化 vocab
    for tok in special_tokens:
        vocab[len(vocab)] = tok.encode("utf-8") # len(vocab) 作下标添加 special token

    special_lookup = set(special_tokens) # str 匹配
    special_set = set(tok.encode("utf-8") for tok in special_tokens) # bytes 匹配
    special_pattern = ( # 正则匹配
        re.compile("|".join(re.escape(t) for t in sorted(special_tokens, key=len, reverse=True)))
        if special_tokens else None
    )

    with open(input_path, "r", encoding="utf-8") as f:
        text = f.read()

    corpus, n_stories = pretokenize(text, pattern, special_pattern, special_lookup)
    print(f"Loaded {n_stories} tokens")

    pair_counts = count_pairs(corpus, special_set)
    merges = []
    print(f"Initial unique pairs: {len(pair_counts)}")

    while len(vocab) < vocab_size and pair_counts:
        # item[0]: (a, b)  item[1]: freq
        (a, b), freq = max(pair_counts.items(), key=lambda item: (item[1], item[0])) 
        vocab[len(vocab)] = a + b # 分配新 token
        merges.append((a, b))
        apply_merge(corpus, pair_counts, (a, b), special_set)
        if len(merges) % 100 == 0:
            print(f"Step {len(merges)}, merged {a+b} freq={freq}")

    vocab_out = {k: v.decode("utf-8", errors="ignore") for k, v in vocab.items()}
    with open(os.path.join(output_dir, "vocab.json"), "w", encoding="utf-8") as f:
        json.dump(vocab_out, f, ensure_ascii=False, indent=2)
    with open(os.path.join(output_dir, "merges.txt"), "w", encoding="utf-8") as f:
        for a, b in merges:
            f.write(f"{a.decode('utf-8', 'ignore')} {b.decode('utf-8', 'ignore')}\n")

    print(f"BPE Training done: {len(vocab)} tokens, {len(merges)} merges.")
    return vocab, merges

In [10]:
vocab, merges = train_bpe(
    input_path="./datasets/TinyStories/valid.txt",
    vocab_size=10000,
    special_tokens=["<|endoftext|>"],
    output_dir="./bpeModel"
)

Loaded 4554143 tokens
Initial unique pairs: 1048
Step 100, merged b'ad' freq=28155
Step 200, merged b'fu' freq=12244
Step 300, merged b'ited' freq=6999
Step 400, merged b' outside' freq=4795
Step 500, merged b' sorry' freq=3499
Step 600, merged b'lly' freq=2692
Step 700, merged b'xt' freq=2078
Step 800, merged b' give' freq=1728
Step 900, merged b' curi' freq=1453
Step 1000, merged b' having' freq=1232
Step 1100, merged b' swim' freq=1063
Step 1200, merged b' filled' freq=929
Step 1300, merged b'ons' freq=792
Step 1400, merged b' treasure' freq=691
Step 1500, merged b'fort' freq=619
Step 1600, merged b' plac' freq=542
Step 1700, merged b' Wh' freq=490
Step 1800, merged b' Mittens' freq=441
Step 1900, merged b' comfort' freq=397
Step 2000, merged b' which' freq=357
Step 2100, merged b' teeth' freq=328
Step 2200, merged b' meant' freq=298
Step 2300, merged b' sold' freq=272
Step 2400, merged b' lit' freq=251
Step 2500, merged b'iff' freq=231
Step 2600, merged b' become' freq=218
Step 270