In [4]:
def most_common_pair(nums):
    """Return the most frequent adjacent pair in ``nums``.

    Adjacent pairs are formed as ``zip(nums, nums[1:])``. The result is the
    pair ranked first by ``Counter.most_common``, or ``None`` when ``nums``
    yields no pairs at all (fewer than two items).
    """
    from collections import Counter

    pair_counts = Counter(zip(nums, nums[1:]))
    if not pair_counts:
        return None
    best_pair, _count = pair_counts.most_common(1)[0]
    return best_pair

def merge_pair(nums, pair, idx):
    """Replace each left-to-right, non-overlapping occurrence of ``pair``.

    Scans ``nums`` once; whenever an item equals ``pair[0]`` and the item
    after it equals ``pair[1]``, both are replaced by the single value
    ``idx``. All other items are copied unchanged. Returns a new list.
    """
    merged = []
    size = len(nums)
    skip_next = False
    for pos, current in enumerate(nums):
        if skip_next:
            # This item was the second half of a pair merged on the previous step.
            skip_next = False
            continue
        if pos + 1 < size and current == pair[0] and nums[pos + 1] == pair[1]:
            merged.append(idx)
            skip_next = True
        else:
            merged.append(current)
    return merged

def get_tokenizer(text, vocab_size, start_idx=256):
    """Learn pair merges over the code points of ``text``.

    The text is turned into a list of ``ord`` values, then the most common
    adjacent pair is repeatedly replaced by a fresh id. New ids are handed
    out consecutively starting at ``start_idx``.

    The number of merge rounds attempted is ``start_idx - vocab_size`` (so
    ``vocab_size=start_idx - k`` asks for ``k`` merges); learning stops early
    when no adjacent pair is left.

    Returns a tuple ``(merges, itos, nums)``:
      * ``merges`` maps each merged pair to the id that replaced it,
      * ``itos`` maps every id (character or merge) to its string,
      * ``nums`` is ``text`` encoded with all learned merges applied.
    """
    nums = [ord(char) for char in text]
    itos = {code: chr(code) for code in nums}
    merges = {}
    next_idx = start_idx
    for _ in range(start_idx - vocab_size):
        pair = most_common_pair(nums)
        if pair is None:
            # Nothing left to merge (fewer than two tokens remain).
            break
        left, right = pair
        merges[pair] = next_idx
        itos[next_idx] = itos[left] + itos[right]
        nums = merge_pair(nums, pair, next_idx)
        next_idx += 1
    return merges, itos, nums

# Demo: rerun the learner with 0..49 requested merges and show the resulting
# tokenization of the sentence after every fifth merge count.
text = 'the quick brown fox jumped over the angry dog'
for n_merges in range(50):
    # get_tokenizer runs (start_idx - vocab_size) merges, hence 256 - n_merges.
    merges, itos, nums = get_tokenizer(text, 256 - n_merges)
    if n_merges % 5:
        continue
    print(f'after {n_merges} merges: {",".join([itos[n] for n in nums])}')

after 0 merges: t,h,e, ,q,u,i,c,k, ,b,r,o,w,n, ,f,o,x, ,j,u,m,p,e,d, ,o,v,e,r, ,t,h,e, ,a,n,g,r,y, ,d,o,g
after 5 merges: the qu,i,c,k, ,b,r,o,w,n, ,f,o,x, ,j,u,m,p,e,d, ,o,v,e,r, ,the ,a,n,g,r,y, ,d,o,g
after 10 merges: the quick b,r,o,w,n, ,f,o,x, ,j,u,m,p,e,d, ,o,v,e,r, ,the ,a,n,g,r,y, ,d,o,g
after 15 merges: the quick brown ,f,o,x, ,j,u,m,p,e,d, ,o,v,e,r, ,the ,a,n,g,r,y, ,d,o,g
after 20 merges: the quick brown fox j,u,m,p,e,d, ,o,v,e,r, ,the ,a,n,g,r,y, ,d,o,g
after 25 merges: the quick brown fox jumped, ,o,v,e,r, ,the ,a,n,g,r,y, ,d,o,g
after 30 merges: the quick brown fox jumped over, ,the ,a,n,g,r,y, ,d,o,g
after 35 merges: the quick brown fox jumped over the ang,r,y, ,d,o,g
after 40 merges: the quick brown fox jumped over the angry do,g
after 45 merges: the quick brown fox jumped over the angry dog


In [5]:
class BPETokenizer:
    """Byte-level byte-pair-encoding tokenizer with a fit/transform interface.

    Text is encoded as UTF-8 bytes, which form the base tokens with ids
    ``0 .. max_char_idx - 1``. ``fit`` then repeatedly replaces the most
    frequent adjacent id pair with a new id, handing out ids consecutively
    from ``max_char_idx``.

    NOTE: the number of merges ``fit`` attempts is
    ``max_char_idx - vocab_size``; i.e. ``vocab_size=max_char_idx - k``
    requests ``k`` merges. This convention is kept unchanged so existing
    callers keep working.

    Attributes:
        vocab_size: value used (together with ``max_char_idx``) to derive the
            number of merges, see the note above.
        max_char_idx: number of single-byte base tokens and the first id
            given to a merged pair. Each base token is built with
            ``bytes([idx])``, so values above 256 are rejected by Python.
        vocab: maps every token id to the bytes it stands for.
        merged_pairs: maps each merged ``(left_id, right_id)`` pair to its id.
    """

    def __init__(self, vocab_size=100, max_char_idx=256):
        self.vocab_size = vocab_size
        self.max_char_idx = max_char_idx
        self._reset()

    def _reset(self):
        """Restore the untrained state: base byte vocabulary, no merges."""
        self.vocab = {idx: bytes([idx]) for idx in range(self.max_char_idx)}
        self.merged_pairs = {}

    def fit(self, X_text):
        """Learn merges from an iterable of strings.

        The strings are joined with newlines and treated as one training
        text. Any previously learned merges are discarded first, so calling
        ``fit`` again retrains from scratch instead of mixing merge tables.
        Learning stops early when no adjacent pair is left to merge.
        """
        self._reset()
        all_text = '\n'.join(X_text)
        # No merges exist yet, so the starting ids are just the raw bytes.
        nums = list(all_text.encode('utf-8'))
        n_merges = self.max_char_idx - self.vocab_size
        for i in range(n_merges):
            counts = self._get_pair_counts(nums)
            if not counts:
                # Fewer than two tokens remain; nothing more can be merged.
                break
            # On ties max() keeps the first pair encountered in the text.
            pair = max(counts, key=counts.get)
            pair_idx = self.max_char_idx + i
            nums = self._merge_pair(nums, pair, pair_idx)
            self.vocab[pair_idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
            self.merged_pairs[pair] = pair_idx

    def transform(self, X_text):
        """Encode each string in ``X_text`` into a list of token ids."""
        return [self._encode(text) for text in X_text]

    def inverse_transform(self, X_nums):
        """Decode each list of token ids in ``X_nums`` back into a string."""
        return [self._decode(nums) for nums in X_nums]

    def nums_to_tokens(self, X_nums):
        """Map each id to its token text, one list of strings per input list.

        Tokens are decoded individually, so a token holding only part of a
        multi-byte UTF-8 character shows up with replacement characters.
        """
        return [
            [self.vocab[idx].decode('utf-8', errors='replace') for idx in nums]
            for nums in X_nums
        ]

    def _merge_pair(self, nums, pair, pair_idx):
        """Return ``nums`` with each non-overlapping ``pair`` replaced by ``pair_idx``."""
        nums_merged = []
        i = 0
        while i < len(nums):
            if i < len(nums) - 1 and nums[i] == pair[0] and nums[i + 1] == pair[1]:
                nums_merged.append(pair_idx)
                i += 2
            else:
                nums_merged.append(nums[i])
                i += 1
        return nums_merged

    def _encode(self, text):
        """Encode one string by applying learned merges in the order learned."""
        nums = list(text.encode('utf-8'))
        while len(nums) >= 2:
            counts = self._get_pair_counts(nums)
            # Pick the present pair with the lowest merge id (earliest learned);
            # pairs that were never merged rank last via +inf.
            pair = min(counts, key=lambda pair: self.merged_pairs.get(pair, float('inf')))
            if pair not in self.merged_pairs:
                break
            pair_idx = self.merged_pairs[pair]
            nums = self._merge_pair(nums, pair, pair_idx)
        return nums

    def _decode(self, nums):
        """Concatenate the bytes of every id and decode them as UTF-8.

        Invalid byte sequences are replaced rather than raising.
        """
        tokens = b''.join(self.vocab[idx] for idx in nums)
        text = tokens.decode('utf-8', errors='replace')
        return text

    @staticmethod
    def _get_pair_counts(nums):
        """Count adjacent pairs in ``nums``; keys keep first-seen order."""
        counts = {}
        for pair in zip(nums, nums[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts

# Train a tokenizer on the demo sentence. fit() performs
# max_char_idx - vocab_size merges, so with the default max_char_idx of 256
# this vocab_size requests 10 merges.
text = 'the quick brown fox jumped over the angry dog'
tokenizer = BPETokenizer(vocab_size=256 - 10)
tokenizer.fit([text])

In [6]:
# Show the token ids produced for the sentence the tokenizer was fitted on.
print(tokenizer.transform([text]))

[[265, 114, 111, 119, 110, 32, 102, 111, 120, 32, 106, 117, 109, 112, 101, 100, 32, 111, 118, 101, 114, 32, 258, 97, 110, 103, 114, 121, 32, 100, 111, 103]]


In [7]:
# Same encoding, shown as comma-separated token strings instead of ids.
print(','.join(tokenizer.nums_to_tokens(tokenizer.transform([text]))[0]))

the quick b,r,o,w,n, ,f,o,x, ,j,u,m,p,e,d, ,o,v,e,r, ,the ,a,n,g,r,y, ,d,o,g


In [8]:
# Tokenize a sentence that was not passed to fit(), shown as token strings.
print(','.join(tokenizer.nums_to_tokens(tokenizer.transform(['hello to the world!']))[0]))

h,e,l,l,o, ,t,o, ,the ,w,o,r,l,d,!
