## Byte Pair Encoding (BPE)

这里给出一个字符级（ASCII-level）的BPE实现；在实际应用中，通常使用byte-level的BPE，即在字符的UTF-8字节序列上进行合并。

In [None]:
import re
from collections import defaultdict, Counter

class BPETokenizer:
    """Minimal character-level Byte Pair Encoding (BPE) tokenizer.

    Training repeatedly merges the most frequent adjacent symbol pair in the
    corpus vocabulary; encoding replays the learned merges in the order they
    were learned.

    Attributes:
        vocab_size: Upper bound on the number of merge operations performed
            by ``train`` (despite the name, it is a merge count, not the
            size of the final symbol set).
        bpe_codes: Maps each learned pair ``(left, right)`` to its merge rank
            (0 = learned first, applied first when encoding).
        vocab: Maps each space-separated symbol sequence (ending in the
            ``</w>`` end-of-word marker) to its frequency in the corpus.
    """

    def __init__(self, vocab_size=100):
        self.vocab_size = vocab_size
        self.bpe_codes = {}
        self.vocab = {}

    def get_vocab(self, corpus):
        """Build the initial character-level vocabulary from ``corpus``.

        Each corpus entry is split on whitespace, so an entry may hold one
        word or a whole sentence; empty entries contribute nothing. Every
        word becomes its characters joined by single spaces, followed by the
        ``</w>`` marker so that word boundaries stay distinguishable.

        Returns:
            A ``defaultdict(int)`` mapping symbol sequence -> frequency.
        """
        vocab = defaultdict(int)
        for entry in corpus:
            # Splitting here keeps whitespace out of the symbol sequences.
            # Otherwise get_stats() would count pairs that straddle a word
            # boundary, which merge_vocab() can never actually merge.
            for word in entry.split():
                vocab[" ".join(word) + " </w>"] += 1
        return vocab

    def get_stats(self, vocab):
        """Count how often each adjacent symbol pair occurs in ``vocab``.

        Each occurrence is weighted by the frequency of the word it is in.

        Returns:
            A ``defaultdict(int)`` mapping ``(left, right)`` -> frequency.
        """
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for left, right in zip(symbols, symbols[1:]):
                pairs[(left, right)] += freq
        return pairs

    def merge_vocab(self, pair, vocab):
        """Return a copy of ``vocab`` with every occurrence of ``pair`` merged.

        Args:
            pair: Two-tuple of adjacent symbols to fuse into one symbol.
            vocab: Mapping of symbol sequence -> frequency.

        Returns:
            A new dict; ``vocab`` itself is not modified.
        """
        merged_symbol = "".join(pair)
        # The lookarounds anchor the match on whole symbols, so a pair such
        # as ("e", "s") cannot match inside longer symbols like "ne st".
        pair_re = re.compile(
            r"(?<!\S)" + re.escape(" ".join(pair)) + r"(?!\S)"
        )
        new_vocab = {}
        for word, freq in vocab.items():
            # A callable replacement inserts the merged symbol literally;
            # a plain string would have backslashes treated as escapes.
            new_word = pair_re.sub(lambda _match: merged_symbol, word)
            # Accumulate rather than overwrite, so no counts are lost if
            # two sequences ever collapse to the same string.
            new_vocab[new_word] = new_vocab.get(new_word, 0) + freq
        return new_vocab

    def train(self, corpus):
        """Learn up to ``self.vocab_size`` merges from ``corpus``.

        Any previously learned merges are discarded first, so the tokenizer
        can be retrained without stale pairs or colliding ranks. Training
        stops early once no adjacent pairs remain.
        """
        self.bpe_codes = {}
        self.vocab = self.get_vocab(corpus)
        for rank in range(self.vocab_size):
            pairs = self.get_stats(self.vocab)
            if not pairs:
                break
            # Ties are resolved in favour of the pair seen first.
            best = max(pairs, key=pairs.get)
            self.bpe_codes[best] = rank
            self.vocab = self.merge_vocab(best, self.vocab)

    @staticmethod
    def _merge_pair(symbols, pair):
        """Return ``symbols`` with each left-to-right occurrence of ``pair`` fused."""
        merged = []
        i = 0
        last = len(symbols) - 1
        while i < len(symbols):
            if i < last and (symbols[i], symbols[i + 1]) == pair:
                merged.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        return merged

    def encode_word(self, word):
        """Encode a single word into a list of BPE symbols.

        The word is split into characters plus ``</w>``; then the learned
        pair with the lowest rank that is present is merged, repeatedly,
        until no learned pair remains.
        """
        symbols = list(word) + ["</w>"]
        while len(symbols) > 1:
            known = [
                pair
                for pair in zip(symbols, symbols[1:])
                if pair in self.bpe_codes
            ]
            if not known:
                break
            best_pair = min(known, key=self.bpe_codes.get)
            symbols = self._merge_pair(symbols, best_pair)
        return symbols

    def encode(self, sentence):
        """Encode a whitespace-separated sentence.

        Returns:
            A list with one list of BPE symbols per word.
        """
        return [self.encode_word(word) for word in sentence.split()]

# Example usage: train on a tiny corpus, then show the learned merges and
# the encoding of a short sentence.
if __name__ == "__main__":
    training_words = ["low", "lower", "newest", "widest"]

    bpe = BPETokenizer(vocab_size=10)
    bpe.train(training_words)

    # Learned merge pairs, in the order they were learned.
    print("=== BPE Codes ===")
    for merge_pair in bpe.bpe_codes:
        print(merge_pair)

    print("\n=== Encoding ===")
    encoded_sentence = bpe.encode("lowest newest")
    print(encoded_sentence)

=== BPE Codes ===
('l', 'o')
('lo', 'w')
('e', 's')
('es', 't')
('est', '</w>')
('low', '</w>')
('low', 'e')
('lowe', 'r')
('lower', '</w>')
('n', 'e')

=== Encoding ===
[['low', 'est</w>'], ['ne', 'w', 'est</w>']]
