## **Training**

In [1]:
from tqdm import tqdm
import re
import json
from collections import defaultdict
from transformers import AutoTokenizer

# Load the pretrained GPT-2 tokenizer; the training cells below call its
# backend pre-tokenizer to split raw text into words.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

## **Loading the Dataset**
* Add the Kaggle dataset **Machine Translation Data Set** by Aadish Joshi to the notebook's input data; the next cell reads it from `/kaggle/input/machine-translation-data-set/`.

In [6]:
# Path of the raw text file inside the attached Kaggle dataset; change it
# here when running outside Kaggle.
CORPUS_PATH = "/kaggle/input/machine-translation-data-set/enlish_data.txt"

# Read the whole file into a single string. There is no placeholder value:
# if the file cannot be read the error propagates instead of leaving an
# empty corpus for the following cells to process.
with open(CORPUS_PATH, 'r', encoding='utf-8') as f:
    corpus = f.read()

## **Fixing the apostrophe and space error in the dataset**

In [None]:
def fix_apostrophe_space(input_string):
    """Drop the stray whitespace that follows an apostrophe in contractions.

    A match is an apostrophe that is preceded by a word character, followed
    by exactly one whitespace character, and then by one of the letters
    t, s, l, v, r or m. Only the whitespace is removed, so "don' t"
    becomes "don't". All other text is returned unchanged.

    Args:
        input_string: Text to clean.

    Returns:
        A new string with every matching apostrophe-space pair collapsed.
    """
    return re.sub(r"(?<=\w)'\s(?=[tslvrm])", "'", input_string)

# Clean the whole corpus in one pass, then print its first 5000 characters
# so the effect of the fix can be inspected by eye.
corpus = fix_apostrophe_space(corpus)
print(corpus[:5000])

In [None]:
# Break the corpus into chunks at every "." and report how many there are.
# `corpus` is rebound from a string to a list here, so the isinstance guard
# keeps the cell re-runnable: without it a second execution would call
# .split on a list and raise AttributeError.
if isinstance(corpus, str):
    corpus = corpus.split(".")
print(len(corpus))

In [None]:
# Count how many times each pre-tokenized word appears across all corpus
# chunks. The pre-tokenizer returns (word, offsets) pairs; only the word is
# needed for the frequency table.
word_freqs = defaultdict(int)

pre_tokenizer = tokenizer.backend_tokenizer.pre_tokenizer
for text in corpus:
    for word, _offsets in pre_tokenizer.pre_tokenize_str(text):
        word_freqs[word] += 1

In [None]:
# Base alphabet: every distinct character that occurs in any counted word,
# in sorted order. A set comprehension deduplicates with constant-time
# membership checks instead of scanning a growing list for each character.
alphabet = sorted({letter for word in word_freqs for letter in word})

print(alphabet)

In [None]:
# Initial vocabulary: the end-of-text special token followed by a copy of
# the base alphabet (a new list, so `alphabet` itself is left untouched).
vocab = ["<|endoftext|>", *alphabet]
print(vocab)

In [None]:
# Before any merge, every word is represented as the list of its characters.
splits = {word: list(word) for word in word_freqs}

In [None]:
def compute_pair_freqs(splits):
    """Count adjacent symbol pairs over all words, weighted by word frequency.

    The words and their counts come from the module-level ``word_freqs``
    mapping; ``splits`` must map each of those words to its current list of
    symbols.

    Args:
        splits: Mapping of word -> list of symbols the word is split into.

    Returns:
        A ``defaultdict(int)`` mapping ``(left, right)`` symbol tuples to the
        summed frequency of the words in which the pair occurs side by side.
    """
    pair_freqs = defaultdict(int)
    for word, freq in word_freqs.items():
        symbols = splits[word]
        # Pairing the list with itself shifted by one walks every adjacent
        # pair; a word made of a single symbol yields nothing.
        for left, right in zip(symbols, symbols[1:]):
            pair_freqs[(left, right)] += freq
    return pair_freqs

In [None]:
# Seed the merge table by hand with the pair ("Ġ", "t") and register the
# resulting merged token in the vocabulary.
merges = {}
merges[("Ġ", "t")] = "Ġt"
vocab += ["Ġt"]

In [None]:
def merge_pair(a, b, splits):
    """Replace every adjacent ``a``, ``b`` symbol pair with the symbol ``a + b``.

    Walks each word of the module-level ``word_freqs`` mapping, scans its
    symbol list from left to right and fuses each occurrence of ``a``
    immediately followed by ``b``. ``splits`` is updated in place and also
    returned.

    Args:
        a: Left symbol of the pair.
        b: Right symbol of the pair.
        splits: Mapping of word -> list of symbols; entries are overwritten.

    Returns:
        The same ``splits`` mapping that was passed in.
    """
    merged = a + b
    for word in word_freqs:
        symbols = splits[word]
        pos = 0
        # The cursor only advances when no merge happened at `pos`, so the
        # freshly merged symbol is compared again against what follows it.
        while pos + 1 < len(symbols):
            if (symbols[pos], symbols[pos + 1]) == (a, b):
                symbols = [*symbols[:pos], merged, *symbols[pos + 2:]]
            else:
                pos += 1
        splits[word] = symbols
    return splits

In [None]:
# Apply the hand-written ("Ġ", "t") merge to every word and print the
# resulting split of one example word.
# NOTE(review): the lookup assumes "Ġtemple" is one of the words found in the
# corpus; it raises KeyError otherwise — confirm against the dataset.
splits = merge_pair("Ġ", "t", splits)
print(splits["Ġtemple"])

## **Byte-Pair Encoding**

In [None]:
# Upper bound on the number of merge iterations; every completed iteration
# learns one merge rule and appends one new token to `vocab`.
vocab_size = 22000

for i in tqdm(range(vocab_size)):
    pair_freqs = compute_pair_freqs(splits)
    if not pair_freqs:
        # Every word has collapsed into a single symbol, so there is nothing
        # left to merge. Stop here rather than unpacking an empty "best pair"
        # into merge_pair, which would raise TypeError.
        break
    # max() returns the first key holding the highest count, i.e. ties are
    # resolved in favour of the pair seen first.
    best_pair = max(pair_freqs, key=pair_freqs.get)
    merged_token = best_pair[0] + best_pair[1]
    splits = merge_pair(*best_pair, splits)
    merges[best_pair] = merged_token
    vocab.append(merged_token)

In [None]:
# Give every vocabulary token its position in `vocab` as an integer id.
token2index = dict(zip(vocab, range(len(vocab))))
# JSON object keys have to be strings, so each (left, right) tuple key is
# stored as its str() representation.
merges_str = {str(pair): merged for pair, merged in merges.items()}

## **Saving the files**

In [None]:
# Serialise the token -> id table and store it in vocab.json.
# The already-serialised text is encoded as JSON a second time, so the file
# holds one JSON string whose content is the indented JSON document; the
# loading cell in the Inference section decodes it in two steps.
json_file = json.dumps(token2index, indent=4)
with open("vocab.json", "w") as outfile:
    outfile.write(json.dumps(json_file))

In [None]:
# Serialise the merge table (string keys) and store it in merges.json.
# As with the vocabulary, the serialised text is JSON-encoded once more, so
# the file contains a single JSON string wrapping the indented document.
merges_json = json.dumps(merges_str, indent=4)
with open("merges.json", "w") as outfile:
    outfile.write(json.dumps(merges_json))

## **Inference**

In [None]:
import ast
import json
from transformers import AutoTokenizer

# Load the pretrained GPT-2 tokenizer again; tokenize() below uses its
# pre-tokenizer. Presumably repeated so that the Inference section can be run
# without executing the Training cells first — confirm.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

## **Loading trained vocab and merges**

In [11]:
# Load the trained vocabulary and merge table.
# The saving cells write each file as a JSON string that wraps the real JSON
# document, so a second decoding step is needed. That step uses json.loads
# (a real JSON parser) rather than evaluating the text as a Python literal,
# and it is skipped when a file already holds a plain JSON object.
# The `with` blocks make sure both file handles are closed.
with open("vocab.json") as f:
    vocab = json.load(f)
if isinstance(vocab, str):
    vocab = json.loads(vocab)

with open("merges.json") as f:
    merges_bpe = json.load(f)
if isinstance(merges_bpe, str):
    merges_bpe = json.loads(merges_bpe)

In [None]:
# The merge table was saved with each (left, right) tuple key turned into its
# string representation; parse every key back into a tuple.
merges = {
    ast.literal_eval(pair_repr): merged
    for pair_repr, merged in merges_bpe.items()
}

In [None]:
def tokenize(text):
    """Tokenize ``text`` with the learned BPE merge rules.

    The text is first split into words by the pre-tokenizer of the
    module-level ``tokenizer``. Each word starts as a list of single
    characters, and every rule of the module-level ``merges`` mapping is then
    applied to all words, in the mapping's iteration order.

    Args:
        text: The string to tokenize.

    Returns:
        A flat list of token strings covering all words in order; an empty
        list when the pre-tokenizer yields no words.
    """
    # Use the public accessor, as the training cells do, instead of reaching
    # into the private ``_tokenizer`` attribute.
    pre_tokenize_result = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
    splits = [list(word) for word, _offsets in pre_tokenize_result]
    for (left, right), merged in merges.items():
        for idx, split in enumerate(splits):
            i = 0
            # Stay on the same index after a merge so the new symbol is
            # compared with its new right-hand neighbour.
            while i < len(split) - 1:
                if split[i] == left and split[i + 1] == right:
                    split = split[:i] + [merged] + split[i + 2 :]
                else:
                    i += 1
            splits[idx] = split

    # Flatten in a single pass; sum(splits, []) would copy the growing
    # result list once per word.
    return [token for split in splits for token in split]