In [10]:
from collections import defaultdict
from transformers import AutoTokenizer

# Pre-trained BERT tokenizer; this script only uses its pre-tokenizer.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# Toy training corpus, lower-cased and split on whitespace.
string = "low low low low low lowest lowest newer newer newer newer newer newer wider wider wider new new"
string = string.lower()
words = string.split()

# Count how often each pre-token occurs in the corpus.
word_freqs = defaultdict(int)
for text in words:
    for word, _offset in tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text):
        word_freqs[word] += 1

print(word_freqs)

# Initial alphabet: the first character of each word as-is, every later
# character with the "##" continuation prefix. A set removes duplicates and
# sorted() gives the same ordering the list-and-sort approach produced.
alphabet = sorted(
    {word[0] for word in word_freqs}
    | {f"##{letter}" for word in word_freqs for letter in word[1:]}
)

# Vocabulary starts as the special tokens followed by the alphabet
# (unpacking builds a new list, so `alphabet` itself is not aliased).
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", *alphabet]

# Character-level split of every word: "low" -> ["l", "##o", "##w"].
splits = {
    word: [word[0]] + [f"##{letter}" for letter in word[1:]]
    for word in word_freqs
}

def compute_pair_scores(splits):
    """Score every adjacent token pair found in the current splits.

    For each word in the module-level ``word_freqs``, the word's frequency is
    added to the count of each token in ``splits[word]`` and to the count of
    each adjacent token pair.

    Args:
        splits: Mapping from word to its current list of tokens.

    Returns:
        Dict mapping ``(first, second)`` to
        ``pair_count / (count(first) * count(second))``.
    """
    token_counts = defaultdict(int)
    pair_counts = defaultdict(int)

    for word, count in word_freqs.items():
        pieces = splits[word]
        # Each adjacent pair contributes to the pair count and to the count
        # of its left-hand token; a single-token word yields no pairs here.
        for left, right in zip(pieces, pieces[1:]):
            token_counts[left] += count
            pair_counts[(left, right)] += count
        # The final token is never the left side of a pair, so count it here.
        token_counts[pieces[-1]] += count

    scores = {}
    for (left, right), count in pair_counts.items():
        scores[(left, right)] = count / (token_counts[left] * token_counts[right])
    return scores

# Show the starting alphabet that the merges build on.
print(alphabet)

# Score the pairs of the character-level splits and preview the first six.
pair_scores = compute_pair_scores(splits)
for key in list(pair_scores)[:6]:
    print(f"{key}: {pair_scores[key]}")

# Pick the highest-scoring pair; on a tie the pair seen first wins.
best_pair, max_score = "", None
if pair_scores:
    best_pair = max(pair_scores, key=pair_scores.get)
    max_score = pair_scores[best_pair]

print(best_pair, max_score)

def merge_pair(a, b, splits):
    """Merge every adjacent occurrence of tokens ``a``, ``b`` in the splits.

    Only words present in the module-level ``word_freqs`` are visited.

    Args:
        a: Left token of the pair.
        b: Right token of the pair; a leading "##" is dropped when joining.
        splits: Mapping from word to its token list; updated in place.

    Returns:
        The same ``splits`` mapping, after the merge.
    """
    merged = a + (b[2:] if b.startswith("##") else b)
    target = [a, b]

    for word in word_freqs:
        if len(splits[word]) == 1:
            # A single token has no neighbour to merge with.
            continue
        pieces = list(splits[word])
        pos = 0
        while pos < len(pieces) - 1:
            if pieces[pos:pos + 2] == target:
                # Collapse the two tokens into one and re-check this position.
                pieces[pos:pos + 2] = [merged]
            else:
                pos += 1
        splits[word] = pieces
    return splits

# Apply the first merge selected above: add the joined token to the
# vocabulary (dropping the "##" of the right-hand part) and update the splits.
vocab.append(
    best_pair[0]
    + (best_pair[1][2:] if best_pair[1].startswith("##") else best_pair[1])
)
splits = merge_pair(*best_pair, splits)

# Keep merging the highest-scoring pair until the vocabulary reaches the
# target size.
vocab_size = 25
while len(vocab) < vocab_size:
    scores = compute_pair_scores(splits)
    if not scores:
        # Every word is already a single token, so there is nothing left to
        # merge. Stop here instead of calling merge_pair with no pair.
        break
    # max() returns the first maximal entry, so ties go to the pair seen first.
    best_pair = max(scores, key=scores.get)
    max_score = scores[best_pair]
    print(best_pair, max_score)
    splits = merge_pair(*best_pair, splits)
    new_token = best_pair[0] + (
        best_pair[1][2:] if best_pair[1].startswith("##") else best_pair[1]
    )
    vocab.append(new_token)

print(vocab)

def encode_word(word):
    """Split ``word`` into tokens by greedy longest-prefix matching.

    Prefixes are looked up in the module-level ``vocab``. After each match the
    remainder is prefixed with "##" before matching continues.

    Args:
        word: The string to encode.

    Returns:
        The list of matched tokens, or ``["[UNK]"]`` for the whole word as
        soon as no prefix of the remaining text is in the vocabulary.
    """
    tokens = []
    remaining = word
    while remaining:
        # Longest prefix present in the vocabulary; 0 means no match at all.
        end = next(
            (k for k in range(len(remaining), 0, -1) if remaining[:k] in vocab),
            0,
        )
        if end == 0:
            return ["[UNK]"]
        tokens.append(remaining[:end])
        remaining = remaining[end:]
        if remaining:
            # What is left continues the word, so mark it as a continuation.
            remaining = "##" + remaining
    return tokens

def tokenize(text):
    """Tokenize ``text`` with the vocabulary trained in this script.

    The text is lower-cased, split by the pre-tokenizer of the module-level
    ``tokenizer``, and each resulting word is encoded with ``encode_word``.

    Args:
        text: The string to tokenize.

    Returns:
        A flat list with the tokens of every pre-tokenized word, in order.
    """
    text = text.lower()
    # Use the public `backend_tokenizer` accessor, as the corpus-counting code
    # does, rather than the private `_tokenizer` attribute.
    pre_tokenize_result = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
    # Flatten in one pass; sum(lists, []) re-copies the accumulator each step.
    return [
        token
        for word, _offset in pre_tokenize_result
        for token in encode_word(word)
    ]

# Try the trained tokenizer on two sample words.
for sample in ("newer", "lower"):
    print(tokenize(sample))

defaultdict(<class 'int'>, {'low': 5, 'lowest': 2, 'newer': 6, 'wider': 3, 'new': 2})
['##d', '##e', '##i', '##o', '##r', '##s', '##t', '##w', 'l', 'n', 'w']
('l', '##o'): 0.14285714285714285
('##o', '##w'): 0.06666666666666667
('##w', '##e'): 0.028070175438596492
('##e', '##s'): 0.05263157894736842
('##s', '##t'): 0.5
('n', '##e'): 0.05263157894736842
('##s', '##t') 0.5
('w', '##i') 0.3333333333333333
('wi', '##d') 0.3333333333333333
('l', '##o') 0.14285714285714285
('lo', '##w') 0.06666666666666667
('##e', '##st') 0.05263157894736842
('low', '##est') 0.14285714285714285
('n', '##e') 0.058823529411764705
('ne', '##w') 0.125
['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '##d', '##e', '##i', '##o', '##r', '##s', '##t', '##w', 'l', 'n', 'w', '##st', 'wi', 'wid', 'lo', 'low', '##est', 'lowest', 'ne', 'new']
['new', '##e', '##r']
['low', '##e', '##r']
