In [1]:
import collections
import re

# Toy corpus used throughout the notebook to illustrate byte-pair encoding.
test_text = (
    "this is a test. we are testing BPE tokenization. "
    "BPE is really cool."
)


In [2]:
def get_vocab(text):
    """Build the initial BPE vocabulary from whitespace-separated ``text``.

    Each word is rewritten as its characters joined by single spaces (one
    symbol per character) and mapped to the number of times it occurs.

    Args:
        text: Raw input string; split on any run of whitespace.

    Returns:
        A ``collections.defaultdict(int)`` mapping ``"c h a r s"`` -> count.
    """
    counts = collections.defaultdict(int)
    for spaced in (' '.join(token) for token in text.split()):
        counts[spaced] += 1
    return counts

# Character-level starting vocabulary: each word split into single-char symbols.
test_vocab = get_vocab(text=test_text)
test_vocab

defaultdict(int,
            {'t h i s': 1,
             'i s': 2,
             'a': 1,
             't e s t .': 1,
             'w e': 1,
             'a r e': 1,
             't e s t i n g': 1,
             'B P E': 2,
             't o k e n i z a t i o n .': 1,
             'r e a l l y': 1,
             'c o o l .': 1})

In [3]:
def get_stats(vocab):
    """Count adjacent symbol pairs across a segmented vocabulary.

    Args:
        vocab: Mapping of space-separated symbol sequences to word frequency.

    Returns:
        A ``collections.defaultdict(int)`` mapping ``(left, right)`` symbol
        pairs to their frequency-weighted occurrence count, with keys in
        first-seen order.
    """
    counts = collections.defaultdict(int)
    for segmented, count in vocab.items():
        symbols = segmented.split()
        # zip with the one-step-shifted list yields each adjacent pair once.
        for left, right in zip(symbols, symbols[1:]):
            counts[left, right] += count
    return counts

# Adjacent-pair frequencies for the initial character-level vocabulary.
pairs = get_stats(vocab=test_vocab)
pairs

defaultdict(int,
            {('t', 'h'): 1,
             ('h', 'i'): 1,
             ('i', 's'): 3,
             ('t', 'e'): 2,
             ('e', 's'): 2,
             ('s', 't'): 2,
             ('t', '.'): 1,
             ('w', 'e'): 1,
             ('a', 'r'): 1,
             ('r', 'e'): 2,
             ('t', 'i'): 2,
             ('i', 'n'): 1,
             ('n', 'g'): 1,
             ('B', 'P'): 2,
             ('P', 'E'): 2,
             ('t', 'o'): 1,
             ('o', 'k'): 1,
             ('k', 'e'): 1,
             ('e', 'n'): 1,
             ('n', 'i'): 1,
             ('i', 'z'): 1,
             ('z', 'a'): 1,
             ('a', 't'): 1,
             ('i', 'o'): 1,
             ('o', 'n'): 1,
             ('n', '.'): 1,
             ('e', 'a'): 1,
             ('a', 'l'): 1,
             ('l', 'l'): 1,
             ('l', 'y'): 1,
             ('c', 'o'): 1,
             ('o', 'o'): 1,
             ('o', 'l'): 1,
             ('l', '.'): 1})

In [6]:
# Most frequent adjacent pair; on ties max() keeps the first pair encountered.
best = max(pairs, key=lambda pair: pairs[pair])
(best, pairs[best])

(('i', 's'), 3)

In [None]:
def merge_vocab(pair, v_in):
    """Return a new vocabulary with every whole-symbol occurrence of ``pair`` fused.

    Args:
        pair: Two adjacent symbols ``(left, right)``; each occurrence of
            ``"left right"`` as whole symbols becomes ``"leftright"``.
        v_in: Mapping of space-separated symbol sequences to frequency.
            It is not modified.

    Returns:
        A new ``dict`` mapping the re-segmented sequences to frequency. If two
        input entries end up with the same segmentation their counts are summed.
    """
    v_out = {}
    bigram = re.escape(' '.join(pair))
    # Match the pair only as whole symbols (no non-space char directly before
    # or after), so ('s', 't') does not fire inside 'es t'.
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    merged = ''.join(pair)
    for word, freq in v_in.items():
        # A function replacement inserts `merged` literally; a plain string
        # would have its backslashes parsed as escapes / group references.
        w_out = p.sub(lambda _match: merged, word)
        # Accumulate rather than overwrite in case two keys collapse together.
        v_out[w_out] = v_out.get(w_out, 0) + freq
    return v_out

# One merge step, stored under a separate name so `test_vocab` stays unmerged.
_test_vocab = merge_vocab(pair=best, v_in=test_vocab)
_test_vocab

{'t h is': 1,
 'is': 2,
 'a': 1,
 't e s t .': 1,
 'w e': 1,
 'a r e': 1,
 't e s t i n g': 1,
 'B P E': 2,
 't o k e n i z a t i o n .': 1,
 'r e a l l y': 1,
 'c o o l .': 1}

In [8]:
# Apply `num_merges` greedy merges starting from the *unmerged* vocabulary.
# Results go to a new name so this cell is idempotent: re-running it does not
# stack extra merges on top of `test_vocab`, and it leaves the `pairs` / `best`
# variables from the earlier cells untouched.
num_merges = 10
merged_vocab = test_vocab  # merge_vocab returns a new dict, so test_vocab is never mutated
for _ in range(num_merges):
    step_pairs = get_stats(merged_vocab)
    if not step_pairs:
        break  # every word is already a single symbol
    step_best = max(step_pairs, key=step_pairs.get)
    merged_vocab = merge_vocab(step_best, merged_vocab)

merged_vocab

{'this': 1,
 'is': 2,
 'a': 1,
 'test.': 1,
 'w e': 1,
 'a re': 1,
 'test i n g': 1,
 'BPE': 2,
 't o k e n i z a t i o n .': 1,
 're a l l y': 1,
 'c o o l .': 1}

In [9]:
def merge_vocab(pair, v_in):
    """Return a new vocabulary with every whole-symbol occurrence of ``pair`` fused.

    Args:
        pair: Two adjacent symbols ``(left, right)``; each occurrence of
            ``"left right"`` as whole symbols becomes ``"leftright"``.
        v_in: Mapping of space-separated symbol sequences to frequency.
            It is not modified.

    Returns:
        A new ``dict`` mapping the re-segmented sequences to frequency. If two
        input entries end up with the same segmentation their counts are summed.
    """
    v_out = {}
    bigram = re.escape(' '.join(pair))
    # Match the pair only as whole symbols (no non-space char directly before
    # or after), so ('s', 't') does not fire inside 'es t'.
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    merged = ''.join(pair)
    for word, freq in v_in.items():
        # A function replacement inserts `merged` literally; a plain string
        # would have its backslashes parsed as escapes / group references.
        w_out = p.sub(lambda _match: merged, word)
        # Accumulate rather than overwrite in case two keys collapse together.
        v_out[w_out] = v_out.get(w_out, 0) + freq
    return v_out


def bpe_tokenization(text, num_merges):
    """Learn up to ``num_merges`` greedy BPE merges over ``text``.

    Starting from the character-level vocabulary of ``text``, repeatedly fuses
    the most frequent adjacent symbol pair (first-seen pair on ties). Stops
    early once no adjacent pairs remain.

    Args:
        text: Raw input string; words are separated by whitespace.
        num_merges: Maximum number of merge steps to perform.

    Returns:
        Mapping of the final space-separated segmentations to word frequency.
    """
    segmented = get_vocab(text)
    for _ in range(num_merges):
        stats = get_stats(segmented)
        if not stats:
            break
        top_pair = max(stats, key=stats.get)
        segmented = merge_vocab(top_pair, segmented)
    return segmented


In [10]:
bpe_tokenization(test_text, num_merges=10)

{'this': 1,
 'is': 2,
 'a': 1,
 'test.': 1,
 'w e': 1,
 'a re': 1,
 'test i n g': 1,
 'BPE': 2,
 't o k e n i z a t i o n .': 1,
 're a l l y': 1,
 'c o o l .': 1}