# BPE（Byte-Pair Encoding）

![](img/BPE_algorithm.png)

算法步骤
1. 初始化时，把句子分成字符级 token；初始化词表V
2. 统计所有相邻 token 对出现频率；
3. 找出频率最高的 token 对，将它们合并为一个新 token；加入词表V
4. 重复合并操作若干轮，生成合并规则；
5. 用这些规则将输入文本编码为“子词序列”；

In [4]:
# Encode the sample string as UTF-8 and keep the result as a mutable byte buffer.
text = "This is some textabcdABCD"
byte_ary = bytearray(text.encode("utf-8"))
print(byte_ary)

bytearray(b'This is some textabcdABCD')


In [5]:
# Unpack the byte buffer into a list of integers, one value per byte.
ids = [*byte_ary]
print(ids)

[84, 104, 105, 115, 32, 105, 115, 32, 115, 111, 109, 101, 32, 116, 101, 120, 116, 97, 98, 99, 100, 65, 66, 67, 68]


However, the downside of this approach is that it creates one ID for each byte (one per character for ASCII text) — that’s a lot of IDs for a short text!

1 byte -- values in [0, 256), i.e. 0–255

In [59]:

from collections import defaultdict
class BPETokenizer:
    """Minimal word-level byte-pair-encoding (BPE) tokenizer.

    A word is represented as a string of symbols separated by single spaces
    and terminated by the end-of-word marker ``</w>`` (e.g. ``'l o w </w>'``).
    Training repeatedly merges the most frequent adjacent symbol pair and
    records the order of the merges; ``encode`` replays those merges on a
    new word.
    """

    def __init__(self, vocab_size):
        """Create an untrained tokenizer.

        Args:
            vocab_size: Maximum number of merge operations ``fit`` performs
                (one merge per training iteration).
        """
        self.vocab_size = vocab_size
        # Learned merge rules: (left, right) symbol pair -> rank (0 = learned first).
        self.bpe_codes = {}
        # Working corpus: segmented word -> frequency. The segmentation changes
        # after every merge, e.g. 'l o w </w>' -> 'lo w </w>'.
        self.vocab = {}

    def get_vocab(self, corpus):
        """Count whitespace-separated words and split each into characters.

        Args:
            corpus: Iterable of sentences (strings).

        Returns:
            A ``defaultdict(int)`` mapping each word, written as
            space-separated characters followed by ``' </w>'``, to the number
            of times it occurs in ``corpus``.
            Example: ``['low low']`` -> ``{'l o w </w>': 2}``.
        """
        vocab = defaultdict(int)  # value: frequency
        for sentence in corpus:
            for word in sentence.strip().split():
                vocab[' '.join(word) + " </w>"] += 1
        return vocab

    def get_stats(self, vocab):
        """Count adjacent symbol pairs, weighted by word frequency.

        Args:
            vocab: Mapping of segmented word -> frequency.

        Returns:
            A ``defaultdict(int)`` mapping ``(left, right)`` symbol tuples to
            their total frequency across all words.
        """
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split(' ')
            for left, right in zip(symbols, symbols[1:]):
                pairs[left, right] += freq
        return pairs

    @staticmethod
    def _merge_symbols(symbols, pair):
        """Return a copy of ``symbols`` with each adjacent ``pair`` fused."""
        left, right = pair
        fused = left + right
        merged = []
        i = 0
        last = len(symbols) - 1
        while i <= last:
            if i < last and symbols[i] == left and symbols[i + 1] == right:
                merged.append(fused)
                i += 2  # consume both halves of the pair
            else:
                merged.append(symbols[i])
                i += 1
        return merged

    def merge_vocab(self, pair, vocab_in):
        """Apply one merge rule to every word of a vocabulary.

        The merge is done on whole symbols rather than with a substring
        replace, so a pair such as ``('e', 'r')`` never matches across a
        symbol boundary (``'he r'`` is left untouched instead of becoming
        ``'her'``).

        Args:
            pair: ``(left, right)`` tuple of symbols to fuse.
            vocab_in: Mapping of segmented word -> frequency.

        Returns:
            A new plain ``dict`` with the pair merged in every word
            (e.g. ``'l o w </w>'`` -> ``'lo w </w>'``); frequencies are kept,
            and summed if two entries end up with the same segmentation.
        """
        vocab_out = {}
        for word, freq in vocab_in.items():
            new_word = ' '.join(self._merge_symbols(word.split(' '), pair))
            vocab_out[new_word] = vocab_out.get(new_word, 0) + freq
        return vocab_out

    def fit(self, corpus):
        """Learn up to ``vocab_size`` merge rules from ``corpus``.

        Any previously learned rules are discarded first: training always
        restarts from single characters, so ranks carried over from an
        earlier call would no longer reflect the order of the merges.

        Args:
            corpus: Iterable of sentences (strings).
        """
        self.vocab = self.get_vocab(corpus)
        self.bpe_codes = {}
        for _ in range(self.vocab_size):
            pairs = self.get_stats(self.vocab)
            if not pairs:
                break  # every word is already a single symbol
            # Most frequent pair; ties go to the pair encountered first.
            best_pair = max(pairs, key=pairs.get)
            self.vocab = self.merge_vocab(best_pair, self.vocab)
            self.bpe_codes[best_pair] = len(self.bpe_codes)

    def encode(self, word):
        """Segment a single word using the learned merge rules.

        Args:
            word: The word to tokenize (without the ``</w>`` marker).

        Returns:
            List of sub-word symbols; the end-of-word marker ``</w>`` is
            either its own last element or fused into the last symbol.
        """
        symbols = (' '.join(word) + ' </w>').split(' ')
        while True:
            candidates = [
                p for p in zip(symbols, symbols[1:]) if p in self.bpe_codes
            ]
            if not candidates:
                break
            # Replay merges in the order they were learned (lowest rank first).
            prior = min(candidates, key=self.bpe_codes.__getitem__)
            symbols = self._merge_symbols(symbols, prior)
        return ' '.join(symbols).split()

    def decode(self, tokens):
        """Join symbols back into text, turning each ``</w>`` into a space.

        Args:
            tokens: Sequence of symbols, e.g.
                ``['low</w>', 'low', 'l', 'y</w>']``.

        Returns:
            The reconstructed text; a final ``</w>`` yields a trailing space.
        """
        return ''.join(tokens).replace('</w>', ' ')

        
# Demo: learn five merges on a toy corpus, then inspect the tokenizer.
corpus = [
    'low low low lowly lower newer newer',
    'happy dog happy cat',
]
Tokenizer = BPETokenizer(vocab_size=5)
Tokenizer.fit(corpus)

# Segmented words with their frequencies, and the merge rules with their ranks.
print("vocab:", Tokenizer.vocab)
print('bpe_code:', Tokenizer.bpe_codes)

# Symbols back to text, then a word that was not in the corpus to symbols.
print(Tokenizer.decode(['low</w>', 'low', 'l', 'y</w>']))
print(Tokenizer.encode('hilowest'))

vocab: {'low</w>': 3, 'low l y</w>': 1, 'low er </w>': 1, 'n e w er </w>': 2, 'h a p p y</w>': 2, 'd o g </w>': 1, 'c a t </w>': 1}
bpe_code: {('l', 'o'): 0, ('lo', 'w'): 1, ('low', '</w>'): 2, ('y', '</w>'): 3, ('e', 'r'): 4}
low lowly 
['h', 'i', 'low', 'e', 's', 't', '</w>']
