# Byte Pair Encoding

This is a simple pure-Python implementation of byte pair encoding (BPE).
The implementation follows the [Wikipedia description](https://en.wikipedia.org/wiki/Byte_pair_encoding).

This implementation is not very useful in practical applications, because its performance is poor.

In [21]:
from tqdm.notebook import tqdm
from collections import defaultdict
import numpy as np


def _count_pairs(tokens):
    """Count adjacent token pairs, keeping pairs in first-seen order."""
    counts = defaultdict(int)
    for pair in zip(tokens, tokens[1:]):
        counts[pair] += 1
    return counts


def _merge_pair(tokens, pair, new_code):
    """Return a new list where each left-to-right occurrence of ``pair`` is replaced by ``new_code``."""
    merged = []
    last = len(tokens) - 1
    i = 0
    while i < len(tokens):
        if i < last and (tokens[i], tokens[i + 1]) == pair:
            merged.append(new_code)
            i += 2  # consume both halves of the pair
        else:
            merged.append(tokens[i])
            i += 1
    return merged


def bpe_train(text, max_iter, max_unique_tokens=1e6, *, show_progress=True):
    """Learn byte-pair-encoding merges on ``text`` and return the encoded text.

    Every distinct character of ``text`` starts as its own token (codes are
    assigned in sorted character order). Each round replaces the most frequent
    adjacent token pair with a new token; ties go to the pair seen first.

    Args:
        text: String to train on.
        max_iter: Maximum number of merge rounds.
        max_unique_tokens: Stop once the encoded text contains more than this
            many distinct tokens (checked after each merge).
        show_progress: Wrap the rounds in a ``tqdm`` progress bar. Pass
            ``False`` to run silently, e.g. outside a notebook.

    Returns:
        A tuple ``(enc_text, enc_token_map, dec_token_map)``: the list of
        token codes for ``text``, a ``str -> code`` map, and a ``code -> str``
        map. Both maps contain only the tokens that still occur in
        ``enc_text``, and hold native ``str`` / ``int`` values.
    """
    # Plain Python containers keep every key and value a native str/int.
    alphabet = sorted(set(text))
    enc_token_map = {ch: code for code, ch in enumerate(alphabet)}
    dec_token_map = dict(enumerate(alphabet))
    enc_text = [enc_token_map[ch] for ch in text]

    rounds = range(max_iter)
    if show_progress:
        rounds = tqdm(rounds)

    for _ in rounds:
        pair_counts = _count_pairs(enc_text)
        if not pair_counts:
            # Fewer than two tokens left: nothing more can be merged.
            break
        best_pair = max(pair_counts, key=pair_counts.get)
        # dec_token_map has exactly one entry per code issued so far, so its
        # size is always the next unused code.
        new_code = len(dec_token_map)
        dec_token_map[new_code] = (
            dec_token_map[best_pair[0]] + dec_token_map[best_pair[1]]
        )
        enc_text = _merge_pair(enc_text, best_pair, new_code)
        if len(set(enc_text)) > max_unique_tokens:
            break

    # Keep only the tokens that survive in the final encoded text.
    used_codes = sorted(set(enc_text))
    dec_token_map = {code: dec_token_map[code] for code in used_codes}
    enc_token_map = {chars: code for code, chars in dec_token_map.items()}

    return enc_text, enc_token_map, dec_token_map


# Train on a short sample string with three merge rounds, then display the
# encoded token list together with the encode and decode maps.
bpe, enc_map, dec_map = bpe_train('aaabdaaabac', max_iter=3)
bpe, enc_map, dec_map

  0%|          | 0/3 [00:00<?, ?it/s]

([6, 3, 6, 0, 2],
 {'a': 0, 'c': 2, 'd': 3, 'aaab': 6},
 {0: 'a', 2: 'c', 3: 'd', 6: 'aaab'})

In [24]:
def bpe_decode(tokens, dec_map, sep=''):
    """Turn a sequence of token codes back into text.

    Each code in ``tokens`` is looked up in ``dec_map`` and the resulting
    strings are joined with ``sep`` (nothing by default). A code missing from
    ``dec_map`` raises ``KeyError``.
    """
    pieces = (dec_map[code] for code in tokens)
    return sep.join(pieces)

# Decode the trained token sequence twice: joined directly, and with '|'
# between tokens so the token boundaries are visible.
bpe_decode(bpe, dec_map), bpe_decode(bpe, dec_map, sep='|')

('aaabdaaabac', 'aaab|d|aaab|a|c')

In [23]:
def bpe_encode(s, enc_map):
    """Encode ``s`` by repeatedly taking the longest matching token.

    At each position the longest key of ``enc_map`` that is a prefix of the
    remaining text is chosen, its code is emitted, and the scan continues
    after it.

    Args:
        s: String to encode.
        enc_map: Mapping from token string to token code.

    Returns:
        The list of token codes; empty when ``s`` is empty.

    Raises:
        ValueError: If no token in ``enc_map`` matches at some position.
    """
    longest = max(map(len, enc_map), default=0)
    out = []
    pos = 0
    while pos < len(s):
        # Try candidate lengths longest-first, so the first hit is the
        # longest matching prefix.
        for length in range(min(longest, len(s) - pos), 0, -1):
            piece = s[pos:pos + length]
            if piece in enc_map:
                out.append(enc_map[piece])
                pos += length
                break
        else:
            # Raised explicitly (not asserted) so the check survives ``-O``.
            raise ValueError(
                f"no token in enc_map matches the input at position {pos}: "
                f"{s[pos:pos + 10]!r}"
            )
    return out

# Re-encode the sample string using the token map learned during training.
bpe_encode('aaabdaaabac', enc_map)

[6, 3, 6, 0, 2]