In [1]:
import json
import re
from typing import Dict, List, Tuple, Iterable, Iterator

import regex

In [11]:
class Tokenizer:
    def __init__(
        self,
        vocab: Dict[int, bytes],
        merges: List[Tuple[bytes, bytes]],
        special_tokens: List[str] | None = None,
    ):
        self.vocab = vocab
        self.vocab_rev = {v: k for k, v in vocab.items()}
        self.special_tokens = special_tokens or []

        # 编译正则，与训练一致（前导空格）
        self.PAT = regex.compile(
            r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        )

        # 预建 merge rank 表
        self.bpe_ranks = {pair: i for i, pair in enumerate(merges)}

    # --------------------------------------------------
    @classmethod
    def from_files(
        cls,
        vocab_path: str,
        merges_path: str,
        special_tokens: List[str] | None = None,
    ) -> "Tokenizer":
        with open(vocab_path, "r", encoding="utf-8") as f:
            vocab_data = json.load(f)
        vocab = {int(i): v.encode("utf-8") for i, v in vocab_data.items()}

        merges = []
        with open(merges_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                # 忽略注释和非法行
                if not line or line.startswith("#"):
                    continue
                parts = line.split()
                if len(parts) != 2:
                    continue  # 跳过不合规的行
                a, b = parts
                merges.append((a.encode(), b.encode()))

        return cls(vocab, merges, special_tokens)


    # --------------------------------------------------
    def get_pairs(self, tokens: List[bytes]):
        """取相邻 token pair"""
        return {(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)}

    # --------------------------------------------------
    def bpe(self, token: bytes) -> List[bytes]:
        """执行 Byte-Pair Encoding"""
        word = [bytes([b]) for b in token]  # byte-level 初始分词
        pairs = self.get_pairs(word)

        while pairs:
            # 找到当前可合并的最小 rank
            min_pair = min(
                pairs, key=lambda p: self.bpe_ranks.get(p, float("inf"))
            )
            if min_pair not in self.bpe_ranks:
                break

            new_word = []
            i = 0
            while i < len(word):
                if (
                    i < len(word) - 1
                    and word[i] == min_pair[0]
                    and word[i + 1] == min_pair[1]
                ):
                    new_word.append(word[i] + word[i + 1])
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = new_word
            pairs = self.get_pairs(word)

        return word

    # --------------------------------------------------
    def encode(self, text: str) -> List[int]:
        """Encode text → token ids"""
        ids = []
        for match in self.PAT.finditer(text):
            token = match.group(0).encode("utf-8")
            for t in self.bpe(token):
                if t in self.vocab_rev:
                    ids.append(self.vocab_rev[t])
        return ids

    # --------------------------------------------------
    def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
        for line in iterable:
            for tid in self.encode(line):
                yield tid

    # --------------------------------------------------
    def decode(self, ids: List[int]) -> str:
        byte_stream = b"".join(
            self.vocab.get(i, b"\xef\xbf\xbd") for i in ids
        )
        return byte_stream.decode("utf-8", errors="replace")


In [12]:
# Load the tokenizer from the saved vocabulary and merges files
# (no special tokens are passed).
toks = Tokenizer.from_files(
    vocab_path="./bpe_model_hybrid/vocab.json",
    merges_path="./bpe_model_hybrid/merges.txt",
)

# Smoke test: encode a short sentence, show the ids, then decode them back.
sample = "Hello world! It's a test."
ids = toks.encode(sample)
print(ids)
print(toks.decode(ids))

[72, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100, 33, 32, 73, 116, 39, 115, 32, 97, 32, 116, 101, 115, 116, 46]
Hello world! It's a test.
