# Byte Pair encoding

In [1]:
from collections import defaultdict
import re

In [2]:
def get_vocab(data):
    vocab = defaultdict(int)
    for word in data:
        vocab[' '.join(word) + ' </w>'] += 1
    return vocab

In [3]:
def get_stats(vocab):
    pairs = defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs

In [4]:
def merge_vocab(pair, v_in):
    v_out = {}
    bigram = ' '.join(pair)
    replacement = ''.join(pair)
    for word in v_in:
        w_out = word.replace(bigram, replacement)
        v_out[w_out] = v_in[word]
    return v_out

In [5]:
def byte_pair_encoding(data, num_merges):
    vocab = get_vocab(data)
    for i in range(num_merges):
        pairs = get_stats(vocab)
        if not pairs:
            break
        best_pair = max(pairs, key=pairs.get)
        vocab = merge_vocab(best_pair, vocab)
    return vocab

In [6]:
def tokenize_sentence(sentence, vocab):
    sentence = list(sentence) + ['</w>']
    tokens = []
    while sentence:
        found = False
        for i in range(len(sentence) - 1, 0, -1):
            pair = ' '.join(sentence[i - 1:i + 1])
            if pair in vocab:
                tokens.append(pair)
                sentence = sentence[:i - 1] + [pair] + sentence[i + 1:]
                found = True
                break
        if not found:
            tokens.append(sentence[0])
            sentence = sentence[1:]
    return tokens

In [8]:
data = ["low", "lowest", "newer", "wider", "new"]
vocab = byte_pair_encoding(data, 5)
print("Vocabulary:", vocab)
sentence = "lower"
tokens = tokenize_sentence(sentence, vocab)
print("Tokenized Sentence:", tokens)

Vocabulary: {'low </w>': 1, 'low e s t </w>': 1, 'new er </w>': 1, 'w i d er </w>': 1, 'new </w>': 1}
Tokenized Sentence: ['l', 'o', 'w', 'e', 'r', '</w>']
