# BPE Tokenizer

## utils

In [27]:
def get_stats(ids, counts=None):
    """Count how often each pair of adjacent elements occurs in ``ids``.

    Args:
        ids: Sliceable sequence of token ids.
        counts: Optional dict to accumulate into. It is updated in place;
            when omitted, a new dict is created.

    Returns:
        The dict mapping ``(left, right)`` pairs to their occurrence count
        (the same object as ``counts`` when one was passed in).
    """
    if counts is None:
        counts = {}
    # Pair every element with its successor; the last element has none.
    for left, right in zip(ids, ids[1:]):
        bigram = (left, right)
        counts[bigram] = 1 + counts.get(bigram, 0)
    return counts
# Demo: count the adjacent pairs of a small example sequence.
print('get stats')
example = [1, 2, 3, 1, 2] # token id sequence
counts = get_stats(example)
print(counts) # how often each adjacent token pair occurs

In [28]:
def merge(ids, pair, idx):
    """Return a copy of ``ids`` with each occurrence of ``pair`` replaced by ``idx``.

    The sequence is scanned left to right and matches do not overlap: once
    two adjacent elements have been replaced, scanning resumes after them.

    Args:
        ids: Indexable sequence of token ids; it is not modified.
        pair: Two-element sequence ``(first, second)`` to look for.
        idx: Token id written in place of each matched pair.

    Returns:
        A new list with the replacements applied.
    """
    total = len(ids)
    merged = []
    pos = 0
    while pos < total:
        has_next = pos < total - 1
        if ids[pos] == pair[0] and has_next and ids[pos + 1] == pair[1]:
            # Both neighbours match the pair: emit the merged id, skip both.
            merged.append(idx)
            pos += 2
            continue
        merged.append(ids[pos])
        pos += 1
    return merged
# Demo: replace every adjacent (1, 2) pair in the sequence with the new id 4.
ids=[1, 2, 3, 1, 2]
pair=(1, 2)
newids = merge(ids, pair, 4)
print(newids)

In [29]:
# Demo of max() with a key function: the key decides which record "wins",
# here the one with the largest 'age' value.
people = [
    {'name': name, 'age': age}
    for name, age in (('Alice', 30), ('Bob', 25), ('Charlie', 35))
]
oldest = max(people, key=lambda entry: entry['age'])
print(oldest)

## Dummy text

In [30]:
# Small multi-line sample text for the demo. The literal is kept verbatim
# (including its spelling and trailing spaces) because it is runtime data:
# changing a single character changes the byte ids derived from it.
text = '''   
Large Language Models is all you need,
what can i say, manba out. 
Attention is All you need.
Vision Transformers, 
Generative Pretrained Transformers,
Reinforcement leraning from human feedback
chain of thought is basic resoning tool.
LLMs can evaluate NLP results.
Richard Sutton Refinforcement Learning Introduction edition 2.
encoder-only
'''

text_bytes = text.encode("utf-8") # raw bytes
ids = list(text_bytes) # list of integers in range 0..255
print(len(ids)) # number of UTF-8 bytes in the sample text

In [31]:
def build_vocab(merges=None, base_size=8):
    """Build a token-id -> bytes table from single bytes plus merge rules.

    Called without arguments it returns the toy table of ids 0..7, each
    mapped to its single-byte value.

    Args:
        merges: Optional dict mapping ``(left_id, right_id)`` to the id of
            the merged token. Entries are applied in dict insertion order,
            so both ids of a pair must already be in the table (base ids
            or earlier merges). Defaults to no merges.
        base_size: Number of single-byte base tokens (ids
            ``0..base_size - 1``). Must not exceed 256 because each base
            token is built with ``bytes([idx])``.

    Returns:
        Dict mapping every token id to the byte string it represents.
    """
    merges = {} if merges is None else merges
    vocab = {idx: bytes([idx]) for idx in range(base_size)}
    for (p0, p1), idx in merges.items():
        # A merged token's bytes are the concatenation of its two parts.
        vocab[idx] = vocab[p0] + vocab[p1]
    return vocab

# Show the toy table: ids 0..7, each mapped to its single-byte value.
print(build_vocab())

## BPE tokenizer implementation

In [32]:
INITIAL_VOCAB_SIZE = 256  # one base token per possible byte value


class BasicTokenizer:
    """Minimal byte-level BPE tokenizer.

    Token ids below ``INITIAL_VOCAB_SIZE`` stand for single bytes; every
    learned merge gets the next free id. Training relies on the
    module-level helpers ``get_stats`` and ``merge``.
    """

    def __init__(self):
        # (int, int) -> int: pair of token ids -> id of the merged token
        self.merges = {}
        # int -> bytes: token id -> byte string it stands for
        self.vocab = self.build_vocab()

    def build_vocab(self):
        """Return a fresh id -> bytes table derived from ``self.merges``."""
        vocab = {idx: bytes([idx]) for idx in range(INITIAL_VOCAB_SIZE)}
        # Relies on dict insertion order: a merge may reference the id
        # produced by an earlier merge, which must already be in the table.
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        return vocab

    def train(self, text, vocab_size, verbose=False):
        """Learn BPE merges from ``text`` and store them on the instance.

        Args:
            text: Training string; it is encoded to UTF-8 bytes first.
            vocab_size: Target vocabulary size. At most
                ``vocab_size - INITIAL_VOCAB_SIZE`` merges are learned.
            verbose: When true, print one line per learned merge.

        Raises:
            AssertionError: If ``vocab_size`` is below ``INITIAL_VOCAB_SIZE``.

        Previously learned ``merges`` and ``vocab`` are always replaced,
        even when no merge is learned. Training stops early once the
        sequence has no adjacent pair left to merge.
        """
        assert vocab_size >= INITIAL_VOCAB_SIZE
        num_merges = vocab_size - INITIAL_VOCAB_SIZE

        ids = list(text.encode("utf-8"))  # byte values in range 0..255

        merges = {}  # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(INITIAL_VOCAB_SIZE)}  # int -> bytes
        for i in range(num_merges):
            stats = get_stats(ids)
            if not stats:
                # Fewer than two tokens remain, so there is nothing to merge;
                # max() on an empty dict would raise ValueError.
                break
            # Most frequent adjacent pair; on ties the first one seen wins.
            pair = max(stats, key=stats.get)
            idx = INITIAL_VOCAB_SIZE + i
            ids = merge(ids, pair, idx)
            merges[pair] = idx

            # Existing entries are never removed; the table only grows, e.g.
            # 'tr', 'tran', 'transf'. The new token's bytes are the
            # concatenation of its two parts: 'te' + 'st' -> 'test'.
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            if verbose:
                print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} "
                      f"({vocab[idx]}) had {stats[pair]} occurrences")

        # Assigned after the loop so both tables are reset even when the
        # loop body never runs.
        self.merges = merges  # pair -> merged id, needed for encoding
        self.vocab = vocab    # id -> bytes, needed for decoding

# Train on the sample text; vocab_size 266 requests 266 - 256 = 10 merges.
bpe = BasicTokenizer()
bpe.train(text, vocab_size = 266)
# Print the byte strings of the learned tokens (ids 256..265).
for i in range(256,266,1):
    print(bpe.vocab[i])

# Learned merge table: (id, id) pair -> id of the merged token.
print(bpe.merges)

## BPE Encode 

In [33]:
# encoder
# utf-8 token ids
text = 'i love transfromers'
text_bytes = text.encode("utf-8") # raw bytes
# First turn the text into byte-level token ids, then merge those raw ids
# according to the merges table to obtain the final token ids.

# bpe token ids
ids = list(text_bytes) # list of integers in range 0..255
while len(ids) >= 2:
    stats = get_stats(ids)
    # Worked example with ids = (2, 3, 4, 5):
    #   stats holds the adjacent pairs of the current sequence:
    #       (2, 3), (3, 4), (4, 5)
    #   suppose the merges table contains (3, 4) -> 268 and (4, 5) -> 289:
    #       bpe.merges.get((3, 4)) = 268
    #       bpe.merges.get((4, 5)) = 289
    #       (2, 3) is not in the table, so its key is inf
    #   min() therefore picks (3, 4). A smaller merge id means the pair was
    #   learned earlier in training, so it has to be applied first.
    pair = min(stats, key=lambda p: bpe.merges.get(p, float("inf"))) 
    print(pair)
    print(bpe.vocab[pair[0]], bpe.vocab[pair[1]])
    # If even the best candidate is unknown, no pair can be merged any more.
    if pair not in bpe.merges:
        break 
    idx = bpe.merges[pair] # (3, 4) -> 268
    ids = merge(ids, pair, idx) # (2, 3, 4, 5) -> (2, 268, 5)
print(ids)


## BPE Decode

In [34]:

# Decode: look up the bytes of every token id, concatenate them, and
# interpret the result as UTF-8 (undecodable bytes become U+FFFD).
text_bytes = b"".join(bpe.vocab[token_id] for token_id in ids)
decode_text = str(text_bytes, encoding="utf-8", errors="replace")
print(decode_text)