##### A simple, from-scratch implementation of the Byte Pair Encoding (BPE) algorithm.
##### The goal is to understand the core idea of the algorithm, so the implementation is deliberately basic and limited.

In [1]:
# import
import itertools
from collections import Counter
from collections import defaultdict

# Implement BPE from Scratch
#### step 1 - break down text into characters
#### step 2 - calculate the frequency of bigrams
#### step 3 - merge the most frequent pair everywhere in the vocabulary (repeat steps 2-3 for each merge)

In [2]:
# step 1 - break down text into characters

def init_vocab(corpus):
    """Turn raw text into the initial BPE vocabulary of character sequences.

    Every whitespace-separated word becomes one entry whose characters are
    joined by single spaces and followed by the end-of-word marker, e.g.
    ``"walk"`` -> ``"w a l k </w>"``. Repeated words yield repeated entries.
    """
    return [" ".join(word) + " </w>" for word in corpus.split()]

# input_text = "Walker walked a long walk at someunknownbeach"
# print(init_vocab(input_text))

In [3]:
# step 2 - calculate the frequency of bigrams

def bigram_frequency(vocab):
    """Return the most frequent adjacent symbol pair in *vocab* and its count.

    Each entry of *vocab* is a space-separated symbol sequence whose last
    symbol is the end-of-word marker (``</w>`` when built by ``init_vocab``).
    The final pair of every entry, the one that touches that last symbol, is
    deliberately not counted, so merges never absorb the end-of-word marker.

    Ties are broken in favour of the pair that was encountered first.

    Returns:
        (pair, count): the winning pair as a 2-tuple of symbols and how many
        times it occurs.

    Raises:
        ValueError: if no entry has a countable pair, i.e. there is nothing
        left to merge.
    """
    bigrams_freq = Counter()
    for word in vocab:
        symbols = word.split()
        # Drop the last symbol so pairs involving the end-of-word marker are skipped.
        body = symbols[:-1]
        bigrams_freq.update(zip(body, body[1:]))
    if not bigrams_freq:
        raise ValueError("no adjacent symbol pairs left to merge")
    # most_common keeps first-encountered order among equal counts.
    merged_frequent_pair, max_frequency = bigrams_freq.most_common(1)[0]
    return merged_frequent_pair, max_frequency
  

In [4]:
# step 3 - merge the most common pair to the dictionary

def merge_common_pair_update_vocab(vocab, merged_frequent_pair):
    """Merge every occurrence of *merged_frequent_pair* into a single symbol.

    Each entry of *vocab* is a space-separated symbol string. Occurrences are
    merged left to right without overlap (``a a a`` with ``('a', 'a')``
    becomes ``aa a``). The last symbol of each entry (the ``</w>`` marker
    when built by ``init_vocab``) is never merged.

    Args:
        vocab: list of space-separated symbol strings; updated in place.
        merged_frequent_pair: the two symbols to join.

    Returns:
        The same *vocab* list object, after updating.
    """
    pair = tuple(merged_frequent_pair)
    first, second = pair[0], pair[1]
    merged_symbol = first + second
    for cnt, word in enumerate(vocab):
        symbols = word.split()
        last = len(symbols) - 1
        merged = []
        i = 0
        while i < len(symbols):
            # Only merge when the right-hand symbol is not the final symbol.
            if i + 1 < last and symbols[i] == first and symbols[i + 1] == second:
                merged.append(merged_symbol)
                i += 2  # skip both halves so matches cannot overlap
            else:
                merged.append(symbols[i])
                i += 1
        vocab[cnt] = ' '.join(merged)
    return vocab


### BPE training orchestrator

In [5]:
def bpe_training(input_text="Walker walked a long walk at someunknownbeach", merges=5):
    """Learn up to *merges* BPE merge rules from *input_text*, printing progress.

    Args:
        input_text: text to train on; split on whitespace into words.
        merges: maximum number of merge steps to perform.

    Returns:
        (corpus, learned_merges): the vocabulary after the last merge (one
        space-separated symbol string per word occurrence) and the list of
        merged pairs in the order they were learned.

    Training stops early, without error, once no mergeable pair remains.
    """
    corpus1 = init_vocab(input_text)
    print("initial corpus = ", corpus1)
    learned_merges = []
    for i in range(merges):
        try:
            mst_freq_pair, _max_freq = bigram_frequency(corpus1)
        except ValueError:
            # bigram_frequency raises ValueError when no pair is left to count.
            break
        corpus1 = merge_common_pair_update_vocab(corpus1, mst_freq_pair)
        print(f"merge: {i} ; updated corpus : {corpus1} \n")
        learned_merges.append(mst_freq_pair)
    print(f"final vocab = {corpus1}")
    print(f"learned merges: {learned_merges}")
    return corpus1, learned_merges
    

# Use the learned merges to tokenise a new sentence

In [6]:
def apply_bpe(test_word, merges):
    """Tokenise *test_word* with learned BPE merges, applied in learned order.

    The text is split into single characters (spaces included, so they
    become ``' '`` tokens) and ``'</w>'`` is appended. Then, repeatedly, the
    adjacent pair that was learned earliest is merged everywhere, left to
    right and without overlap. This reproduces the order used during
    training. It stops when no adjacent pair is a learned merge.

    Args:
        test_word: the text to tokenise.
        merges: learned pairs, ordered from first to last learned.

    Returns:
        The list of resulting tokens, ending with ``'</w>'``.
    """
    # Earlier-learned merges get a lower rank and therefore higher priority.
    ranks = {}
    for rank, pair in enumerate(merges):
        ranks.setdefault(tuple(pair), rank)

    test_tokens = list(test_word) + ['</w>']
    while len(test_tokens) > 1:
        adjacent = zip(test_tokens, test_tokens[1:])
        best = min((p for p in adjacent if p in ranks), key=ranks.__getitem__, default=None)
        if best is None:
            break
        merged = []
        i = 0
        while i < len(test_tokens):
            if i + 1 < len(test_tokens) and (test_tokens[i], test_tokens[i + 1]) == best:
                merged.append(test_tokens[i] + test_tokens[i + 1])
                i += 2
            else:
                merged.append(test_tokens[i])
                i += 1
        test_tokens = merged
    return test_tokens


In [7]:
# Train on the sample sentence built into bpe_training; keep the final corpus and the ordered merge rules.
final_corpus, learned_merges = bpe_training()

initial corpus =  ['W a l k e r </w>', 'w a l k e d </w>', 'a </w>', 'l o n g </w>', 'w a l k </w>', 'a t </w>', 's o m e u n k n o w n b e a c h </w>']
merge: 0 ; updated corpus : ['W al k e r </w>', 'w al k e d </w>', 'a </w>', 'l o n g </w>', 'w al k </w>', 'a t </w>', 's o m e u n k n o w n b e a c h </w>'] 

merge: 1 ; updated corpus : ['W alk e r </w>', 'w alk e d </w>', 'a </w>', 'l o n g </w>', 'w alk </w>', 'a t </w>', 's o m e u n k n o w n b e a c h </w>'] 

merge: 2 ; updated corpus : ['W alke r </w>', 'w alke d </w>', 'a </w>', 'l o n g </w>', 'w alk </w>', 'a t </w>', 's o m e u n k n o w n b e a c h </w>'] 

merge: 3 ; updated corpus : ['Walke r </w>', 'w alke d </w>', 'a </w>', 'l o n g </w>', 'w alk </w>', 'a t </w>', 's o m e u n k n o w n b e a c h </w>'] 

merge: 4 ; updated corpus : ['Walker </w>', 'w alke d </w>', 'a </w>', 'l o n g </w>', 'w alk </w>', 'a t </w>', 's o m e u n k n o w n b e a c h </w>'] 

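final vocab = ['Walker </w>', 'w alke d </w>', 'a </w>', 'l o n g </w>', 'w alk </w>', 'a t </w>', 's o m e u n k n o w n b e a c h </w>']
learned merges: [('a', 'l'), ('al', 'k'), ('alk', 'e'), ('W', 'alke'), ('Walke', 'r')]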

In [8]:
# Tokenise a word that appears in the training sentence.
print(apply_bpe("walk", learned_merges))

['w', 'alk', '</w>']


In [9]:
# Tokenise a string that does not appear in the training sentence.
print(apply_bpe("alongwalk", learned_merges))

['al', 'o', 'n', 'g', 'w', 'alk', '</w>']


In [10]:
# apply_bpe splits the whole string into characters, so each space becomes its own ' ' token.
print(apply_bpe("a long walk", learned_merges))

['a', ' ', 'l', 'o', 'n', 'g', ' ', 'w', 'alk', '</w>']


# tokens to token IDs - encode (these are integer IDs, not learned embedding vectors)
#### assign a token ID to every symbol in the corpus, plus some additional special tokens

In [11]:
def create_corpus_token_ids(corpus):
    """Assign integer IDs to every distinct symbol in *corpus*, plus special tokens.

    Symbols get sequential IDs starting at 1, in order of first appearance.
    The IDs reserved for the special tokens (400, 5000, 6000) are skipped, so
    no two tokens share an ID. The special tokens ``<UNK>`` (400),
    ``<SPACE>`` (5000) and ``</w>`` (6000) are always present, even for an
    empty corpus. ``</w>`` may consume a sequential ID while the corpus is
    scanned before being overwritten with 6000; this keeps the IDs
    of the other symbols unchanged from earlier runs.

    Args:
        corpus: iterable of space-separated symbol strings.

    Returns:
        dict mapping token -> integer ID.

    The corpus and each entry are printed while processing.
    """
    special_ids = {"<UNK>": 400, "<SPACE>": 5000, "</w>": 6000}
    reserved_ids = set(special_ids.values())
    print(corpus)
    token_to_id = {}
    token_id = 1
    for token in corpus:
        print(token)
        for tkn in token.split():
            if tkn in token_to_id:
                continue
            # Never hand out an ID that belongs to a special token.
            while token_id in reserved_ids:
                token_id += 1
            token_to_id[tkn] = token_id
            token_id += 1
    # Added once, after scanning, so they exist even when corpus is empty.
    token_to_id.update(special_ids)
    return token_to_id

In [12]:
# Build the token -> ID lookup from the trained corpus and show it.
corpus_token_numbers = create_corpus_token_ids(final_corpus)
print(corpus_token_numbers)

['Walker </w>', 'w alke d </w>', 'a </w>', 'l o n g </w>', 'w alk </w>', 'a t </w>', 's o m e u n k n o w n b e a c h </w>']
Walker </w>
w alke d </w>
a </w>
l o n g </w>
w alk </w>
a t </w>
s o m e u n k n o w n b e a c h </w>
{'Walker': 1, '</w>': 6000, '<UNK>': 400, '<SPACE>': 5000, 'w': 3, 'alke': 4, 'd': 5, 'a': 6, 'l': 7, 'o': 8, 'n': 9, 'g': 10, 'alk': 11, 't': 12, 's': 13, 'm': 14, 'e': 15, 'u': 16, 'k': 17, 'b': 18, 'c': 19, 'h': 20}


In [13]:
# Encode test tokens as token IDs (spaces map to <SPACE>, unseen tokens to <UNK>)
def encode(test_tkns, corpus_token_ids):
    """Translate a list of BPE tokens into their integer token IDs.

    A token found in *corpus_token_ids* maps to its own ID. A bare space
    ``" "`` that is not in the mapping maps to the ``"<SPACE>"`` ID. Any other
    unseen token maps to the ``"<UNK>"`` ID. The fallback keys are looked up
    only when needed. The input tokens are printed first.
    """
    print(test_tkns)
    token_ids = []
    for token in test_tkns:
        if token in corpus_token_ids:
            token_ids.append(corpus_token_ids[token])
            continue
        fallback_key = "<SPACE>" if token == " " else "<UNK>"
        token_ids.append(corpus_token_ids[fallback_key])
    return token_ids


In [14]:
# Encode the three tokenised samples shown in the earlier cells into token IDs.
for _sample_tokens in (
    ['w', 'alk', '</w>'],
    ['al', 'o', 'n', 'g', 'w', 'alk', '</w>'],
    ['a', ' ', 'l', 'o', 'n', 'g', ' ', 'w', 'alk', '</w>'],
):
    print(encode(_sample_tokens, corpus_token_numbers))
del _sample_tokens

['w', 'alk', '</w>']
[3, 11, 6000]
['al', 'o', 'n', 'g', 'w', 'alk', '</w>']
[400, 8, 9, 10, 3, 11, 6000]
['a', ' ', 'l', 'o', 'n', 'g', ' ', 'w', 'alk', '</w>']
[6, 5000, 7, 8, 9, 10, 5000, 3, 11, 6000]


# token IDs to tokens - decode

In [15]:
def decode(embeddings, corpus_token_ids):
    """Map token IDs back to tokens and rebuild the text they spell.

    Args:
        embeddings: sequence of integer token IDs (as produced by ``encode``).
        corpus_token_ids: dict mapping token -> ID.

    Returns:
        (orig_tokens, orig_word): the decoded tokens, with ``<SPACE>`` rendered
        as ``' '``, and their concatenation. A single trailing ``'</w>'`` is
        dropped from the concatenation.

    IDs that are not in the mapping are skipped.
    """
    # Invert once instead of scanning the whole mapping for every ID.
    id_to_token = {}
    for key, value in corpus_token_ids.items():
        id_to_token.setdefault(value, key)

    orig_tokens = []
    for em in embeddings:
        if em not in id_to_token:
            continue
        key = id_to_token[em]
        orig_tokens.append(' ' if key == "<SPACE>" else key)

    # Only strip the last token when it really is the end-of-word marker.
    word_tokens = orig_tokens[:-1] if orig_tokens and orig_tokens[-1] == "</w>" else orig_tokens
    orig_word = "".join(word_tokens)
    return orig_tokens, orig_word
        

In [16]:
# Decode each encoded sample back into its tokens and text.
for _sample_ids in (
    [3, 11, 6000],
    [400, 8, 9, 10, 3, 11, 6000],
    [6, 5000, 7, 8, 9, 10, 5000, 3, 11, 6000],
):
    decoded_tokens, decoded_word = decode(_sample_ids, corpus_token_numbers)
    print("decoded tokens : ", decoded_tokens, "\ndecoded word : ", decoded_word)
del _sample_ids

decoded tokens :  ['w', 'alk', '</w>'] 
decoded word :  walk
decoded tokens :  ['<UNK>', 'o', 'n', 'g', 'w', 'alk', '</w>'] 
decoded word :  <UNK>ongwalk
decoded tokens :  ['a', ' ', 'l', 'o', 'n', 'g', ' ', 'w', 'alk', '</w>'] 
decoded word :  a long walk
