## 2-BPE | 预分词(Pre-tokenization) \[ 多进程并行 \]
BPE-No.1: 预分词(Pre-tokenization) 用正则把原始文本切分成初步的"词形片段"(pre-token)并统计每个片段的出现次数(标点、符号不会被删除，而是被切分成独立的片段)

In [None]:
import os
import regex as re 
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
from collections import Counter
from cs336_pretokenization_example import find_chunk_boundaries

In [None]:
# Absolute path of the training text file that every later cell reads.
# NOTE(review): machine-specific location; adjust when running elsewhere.
train_path = "/home/winbeau/Study/1-transformer/datasets/TinyStories/txt/train_with_eot.txt"
# Stop right here with a readable message if the file is missing, instead of
# failing later inside a worker process.
assert os.path.exists(train_path), "Not found train_with_eot.txt"

In [None]:
# Number of worker processes: one per CPU reported by cpu_count(), capped at 12.
# The same value is later used as the requested chunk count and the Pool size.
num_processes = min(12, cpu_count())
print(f"Using {num_processes} processes")

In [None]:
# Pre-tokenization pattern (requires the `regex` module for \p{...} classes).
# Alternatives, tried in order:
#   '(?:[sdmt]|ll|ve|re)   English contraction suffixes such as 's, 'll, 've
#    ?\p{L}+               a run of letters, with one optional leading space
#    ?\p{N}+               a run of digits, with one optional leading space
#    ?[^\s\p{L}\p{N}]+     a run of punctuation/symbols, optional leading space
#   \s+(?!\S)              whitespace that is not followed by a non-space char
#   \s+                    any remaining whitespace
# Punctuation and symbols are therefore kept as their own pieces, which limits
# their influence on the word pieces.
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
SPECIAL = "<|endoftext|>" # special end-of-text marker; matched literally, never through PAT

In [None]:
def process_chunk(start_end):
    """Count the pre-tokens found in one byte range of the training file.

    Args:
        start_end: ``(start, end)`` byte offsets into ``train_path``. The bytes
            in ``[start, end)`` are read and decoded as UTF-8; bytes that
            cannot be decoded are dropped.

    Returns:
        A ``Counter`` whose keys are UTF-8 encoded byte strings: every ``PAT``
        match in the text between ``SPECIAL`` markers, plus ``SPECIAL`` itself
        counted once per occurrence.
    """
    lo, hi = start_end
    tally = Counter()

    def count_segment(segment):
        # Tokenize ordinary text (no special marker inside) and record each piece.
        tally.update(
            match.group().encode("utf-8") for match in re.finditer(PAT, segment)
        )

    with open(train_path, "rb") as f:
        f.seek(lo)
        raw = f.read(hi - lo)
    text = raw.decode("utf-8", errors="ignore")

    # Walk the text marker by marker so SPECIAL is never split by the regex.
    cursor = 0
    hit = text.find(SPECIAL, cursor)
    while hit != -1:
        count_segment(text[cursor:hit])       # text before this marker
        tally[SPECIAL.encode("utf-8")] += 1   # the marker itself
        cursor = hit + len(SPECIAL)
        hit = text.find(SPECIAL, cursor)

    # Whatever follows the last marker (or the whole text if there was none).
    count_segment(text[cursor:])
    return tally

In [None]:
# Ask the helper for byte offsets at which the file can be cut; the special
# token is passed as the split marker and num_processes as the desired count.
with open(train_path, "rb") as f:
    boundaries = find_chunk_boundaries(f, num_processes, b"<|endoftext|>")

# Pair every offset with its successor: [b0, b1, b2] -> [(b0, b1), (b1, b2)].
# zip stops at the shorter input, so the last offset only ever appears as an end.
chunk_pairs = list(zip(boundaries, boundaries[1:]))
print(f"Found {len(chunk_pairs)} chunks")

In [None]:
# Run process_chunk on every (start, end) pair in a pool of worker processes.
with Pool(num_processes) as p:
    # tqdm only adds a progress bar; iterating p.imap(...) directly works as well.
    progress = tqdm(
        p.imap(process_chunk, chunk_pairs),
        total=len(chunk_pairs),
        desc="Pre-tokenization chunks",
        ncols=80,
    )
    counters = list(progress)

# Merge the per-chunk counters into a single frequency table.
total_counts = Counter()
for chunk_counter in counters:
    total_counts += chunk_counter
print(f"Total unique tokens: {len(total_counts)}")

In [None]:
# Report the ten most frequent pre-tokens, then the count of the special token.
print("Top 10 most common tokens:") # author's note: many top tokens carry a leading space, since PAT folds one preceding space into the piece
for token, freq in total_counts.most_common(10): 
    try: 
        # Keys are UTF-8 byte strings; decode them so repr shows readable text.
        print(f"{token.decode('utf-8', errors='ignore')!r} : {freq} ")
    except Exception: 
        # Fallback: print the raw key if it cannot be decoded or printed.
        # NOTE(review): with bytes keys and errors='ignore' the decode itself
        # should not raise, so this mainly guards against non-bytes keys.
        print(f"{token} : {freq}")
# Counter.get with a default of 0 reports zero when the marker never occurred.
print(f"<|endoftext|> freq: {total_counts.get(b'<|endoftext|>', 0)}")