In [1]:
import re
from collections import Counter, defaultdict

In [2]:
# Toy training corpus: five short English sentences that share a handful of
# word stems, used below to build the subword vocabulary.
sentences = [
    "The boy hugs the cat.",
    "The boys are hugging the dogs.",
    "The dogs are chasing the cats.",
    "The dog and the cat sit quietly.",
    "The boy is sitting on the dog."
]

In [3]:
def preprocess(sent):
    """Normalize a sentence and split it into word tokens.

    The text is lower-cased, every character that is neither a word
    character nor whitespace is removed, and the result is split on
    runs of whitespace.

    Args:
        sent: The raw sentence string.

    Returns:
        A list of lower-case word strings (empty if nothing remains).
    """
    without_punctuation = re.sub(r'[^\w\s]', '', sent.lower())
    return without_punctuation.split()

# One list of word tokens per training sentence.
tokenized_corpus = list(map(preprocess, sentences))

In [None]:
def word_to_chars(word):
    """Split a word into its characters and append the ``</w>`` end marker.

    The trailing ``</w>`` symbol marks the word boundary so that later
    merges cannot span two different words.

    Args:
        word: The word to split.

    Returns:
        A list of single-character strings followed by ``"</w>"``.
    """
    return [*word, "</w>"]

In [5]:
# Count every word of the corpus, keyed by its space-separated symbol
# sequence (individual characters followed by the "</w>" marker).
vocab = Counter(
    " ".join(word_to_chars(word))
    for sentence in tokenized_corpus
    for word in sentence
)

In [None]:
def pair_freq(vocab):
    """Count how often each adjacent pair of symbols occurs.

    Args:
        vocab: Mapping from a space-separated symbol string to the
            frequency of that word.

    Returns:
        A ``defaultdict(int)`` mapping ``(left, right)`` symbol tuples to
        the summed frequency of every word in which the two symbols are
        adjacent, counted once per occurrence.
    """
    counts = defaultdict(int)
    for entry, occurrences in vocab.items():
        tokens = entry.split()
        # Pair each symbol with its right-hand neighbour.
        for left, right in zip(tokens, tokens[1:]):
            counts[left, right] += occurrences
    return counts

In [None]:
def merge_vocab(pair, vocab_in):
    """Return a new vocabulary in which every occurrence of ``pair`` is fused.

    Args:
        pair: Two symbols to merge, e.g. ``('t', 'h')``.
        vocab_in: Mapping from a space-separated symbol string to the
            frequency of that word. It is not modified.

    Returns:
        A new ``dict`` keyed by the rewritten symbol strings. If two input
        entries become identical after the merge, their frequencies are
        summed rather than one overwriting the other.
    """
    merged = ''.join(pair)
    bigram = re.escape(' '.join(pair))
    # The lookarounds require whitespace (or the string edge) on both sides,
    # so only whole symbols match, never the tail or head of a longer symbol.
    pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    vocab_out = {}
    for word, freq in vocab_in.items():
        # A callable replacement is inserted verbatim; a plain replacement
        # string would have its backslashes interpreted as escapes or group
        # references, breaking symbols that contain "\".
        w_out = pattern.sub(lambda _match: merged, word)
        vocab_out[w_out] = vocab_out.get(w_out, 0) + freq
    return vocab_out

In [None]:
# Learn up to `num_merges` merge rules: on each round fuse the most frequent
# adjacent symbol pair, stopping early once no pairs are left.
num_merges = 20
for step in range(1, num_merges + 1):
    candidate_pairs = pair_freq(vocab)
    if len(candidate_pairs) == 0:
        break
    # Ties resolve to the pair that was inserted first.
    best = max(candidate_pairs, key=lambda pair: candidate_pairs[pair])
    vocab = merge_vocab(best, vocab)
    print(f"Merge {step}: {best}")


Merge 1: ('e', '</w>')
Merge 2: ('t', 'h')
Merge 3: ('th', 'e</w>')
Merge 4: ('s', '</w>')
Merge 5: ('g', '</w>')
Merge 6: ('d', 'o')
Merge 7: ('b', 'o')
Merge 8: ('bo', 'y')
Merge 9: ('g', 's</w>')
Merge 10: ('c', 'a')
Merge 11: ('ca', 't')
Merge 12: ('i', 'n')
Merge 13: ('in', 'g</w>')
Merge 14: ('boy', '</w>')
Merge 15: ('h', 'u')
Merge 16: ('cat', '</w>')
Merge 17: ('a', 'r')
Merge 18: ('ar', 'e</w>')
Merge 19: ('do', 'gs</w>')
Merge 20: ('do', 'g</w>')


In [9]:
# Collect every distinct symbol that still appears in the merged vocabulary.
final_vocab = {
    token
    for entry in vocab
    for token in entry.split()
}
print("\nFinal WordPiece Vocabulary:\n", final_vocab)


Final WordPiece Vocabulary:
 {'are</w>', 's', 'a', 'hu', 'ing</w>', 'cat</w>', 'i', 's</w>', 'boy</w>', 'h', 't', 'o', 'the</w>', 'e', 'dogs</w>', 'q', 'boy', 'u', '</w>', 'l', 'c', 'cat', 'y', 'dog</w>', 'gs</w>', 'd', 'n', 'g'}


In [10]:
def wordpiece_tokenize(word, vocab):
    """Segment a word into vocabulary pieces by greedy longest match.

    The word is split into characters with a trailing ``</w>`` marker, then
    scanned left to right; at each position the longest run of symbols whose
    concatenation is in ``vocab`` is emitted as one token.

    Args:
        word: The word to segment.
        vocab: Container of known pieces (anything supporting ``in``).

    Returns:
        The list of matched pieces, or ``["[UNK]"]`` when some position
        cannot be matched by any piece. The whole word is reported as
        unknown in that case, rather than returning a partial segmentation
        that silently drops the unmatched remainder.
    """
    symbols = list(word) + ["</w>"]
    tokens = []
    start = 0
    while start < len(symbols):
        # Try the longest candidate first and shrink it one symbol at a time.
        for end in range(len(symbols), start, -1):
            piece = "".join(symbols[start:end])
            if piece in vocab:
                tokens.append(piece)
                start = end
                break
        else:
            # Nothing starting at this position is known.
            return ["[UNK]"]
    return tokens

In [12]:
# Tokenize an unseen sentence with the learned vocabulary.
test_sentence = "The cat is chasing the dog quietly."
test_tokens = preprocess(test_sentence)

# Flatten the per-word pieces into a single token stream.
output_tokens = [
    piece
    for word in test_tokens
    for piece in wordpiece_tokenize(word, final_vocab)
]

print("\nTokenization of test sentence:\n", output_tokens)


Tokenization of test sentence:
 ['the</w>', 'cat</w>', 'i', 's</w>', 'c', 'h', 'a', 's', 'ing</w>', 'the</w>', 'dog</w>', 'q', 'u', 'i', 'e', 't', 'l', 'y', '</w>']
