In [None]:
from transformers import AutoTokenizer
from collections import defaultdict

# Only GPT-2's pre-tokenizer is used below (it splits raw text into words with
# offsets); the full tokenizer object is otherwise unused.
# NOTE(review): from_pretrained may need network access on first run.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
pre_tokenizer = tokenizer.backend_tokenizer.pre_tokenizer

# Toy training corpus.
corpus = [
    "This is the Hugging Face Course.",
    "This chapter is about tokenization.",
    "This section shows several tokenizer algorithms.",
    "Hopefully, you will be able to understand how they are trained and generate tokens.",
]


def compute_word_freq(corpus, pre_tokenizer):
    """Count how often each pre-tokenized word occurs in ``corpus``.

    Args:
        corpus: iterable of raw text strings.
        pre_tokenizer: object exposing ``pre_tokenize_str(text)``, which
            returns a list of ``(word, offsets)`` pairs.

    Returns:
        defaultdict(int) mapping word -> occurrence count, keys in
        first-seen order.
    """
    counts = defaultdict(int)
    for sentence in corpus:
        for token, _span in pre_tokenizer.pre_tokenize_str(sentence):
            counts[token] += 1
    return counts
        
# Word frequencies for the (original-case) sample corpus.
word2freq = compute_word_freq(corpus, pre_tokenizer)
print(word2freq)

In [None]:
def split_words(words):
    """Split every word into single-character tokens.

    The first character is kept as-is; each following character gets a
    ``"##"`` prefix marking it as a continuation of the word, e.g.
    ``"hug" -> ["h", "##u", "##g"]``.

    Args:
        words: iterable of word strings.

    Returns:
        dict mapping each word to its list of character tokens.
    """
    return {
        word: [ch if pos == 0 else f'##{ch}' for pos, ch in enumerate(word)]
        for word in words
    }
    
# Character-level split of every distinct word.
word2split = split_words(word2freq.keys())
print(word2split)

In [None]:
def get_alphabet_from_words(words):
    """Build the initial alphabet by separating words into raw characters.

    No ``"##"`` continuation tokens are produced here.

    Args:
        words: iterable of word strings.

    Returns:
        Sorted list of the distinct characters appearing in ``words``.
    """
    return sorted({ch for word in words for ch in word})

def get_alphabet_from_splits(splits):
    """Build the initial alphabet from the words' character splits.

    Unlike ``get_alphabet_from_words``, this keeps the split tokens verbatim,
    so both word-initial (``"h"``) and continuation (``"##h"``) tokens appear.

    Args:
        splits: iterable of token lists.

    Returns:
        Sorted list of the distinct tokens.
    """
    symbols = set()
    for tokens in splits:
        symbols.update(tokens)
    return sorted(symbols)

# Alphabet built from the raw words (no "##" continuation tokens).
alphabet = get_alphabet_from_words(word2freq.keys())
print(alphabet)

In [None]:
# Initial vocabulary: a special token (presumably end-of-text) followed by the alphabet.
vocab = ['<eot>'] + list(alphabet)

In [None]:
def compute_pair_freq(word2split, word2freq):
    """Count adjacent token pairs across all words, weighted by word frequency.

    Args:
        word2split: dict mapping word -> current list of tokens.
        word2freq: mapping word -> occurrence count.

    Returns:
        defaultdict(int) mapping ``(left, right)`` token pair -> weighted count.
    """
    pair_counts = defaultdict(int)
    for word, tokens in word2split.items():
        if len(word) == 1:
            # A one-character word cannot contain a pair.
            continue
        weight = word2freq[word]
        for left, right in zip(tokens, tokens[1:]):
            pair_counts[(left, right)] += weight
    return pair_counts

# Pair frequencies for the current splits.
pair2freq = compute_pair_freq(word2split, word2freq)
print(pair2freq)

In [None]:
def find_most_frequent_pair(pair2freq):
    """Find the pair with the highest frequency.

    Ties go to the pair that comes first in ``pair2freq``'s iteration order.

    Args:
        pair2freq: non-empty mapping ``(left, right)`` pair -> frequency.

    Returns:
        Tuple ``(pair, freq)`` for the most frequent pair.

    Raises:
        ValueError: if ``pair2freq`` is empty (nothing left to merge).
    """
    if not pair2freq:
        # Explicit exception instead of ``assert``: asserts are stripped under
        # ``python -O``, which previously made this return ``(None, -1)``.
        raise ValueError("pair2freq is empty: no pairs left to merge")
    # max() returns the first maximal item, matching the original
    # strict-greater-than scan.
    return max(pair2freq.items(), key=lambda item: item[1])

# First merge candidate and its weighted count.
pair, freq = find_most_frequent_pair(pair2freq)

pair, freq

In [None]:
def merge_pair(pair):
    """Concatenate a pair of tokens into one token.

    The ``"##"`` continuation marker on the second token is dropped, so the
    result keeps only the first token's marker (if any), e.g.
    ``("h", "##u") -> "hu"`` and ``("##u", "##g") -> "##ug"``.

    Args:
        pair: ``(left, right)`` token pair.

    Returns:
        The merged token string.
    """
    first, second = pair[0], pair[1]
    # Strip the marker only when present; slicing unconditionally would
    # silently drop two real characters from an unmarked token.
    if second.startswith("##"):
        second = second[2:]
    return first + second

# Token produced by merging the most frequent pair.
merge_pair(pair)

In [None]:
merge_rule = {}


def update_splits(pair, word2split, new_byte=None):
    """Apply one merge rule to every word's split.

    Every adjacent occurrence of ``pair`` is replaced by the merged token,
    scanning left to right. One-character words are skipped.

    Note: ``word2split`` is modified in place, and the same dict is returned.

    Args:
        pair: ``(left, right)`` tuple of tokens to merge.
        word2split: dict mapping word -> list of tokens.
        new_byte: merged token; defaults to ``merge_pair(pair)``.

    Returns:
        ``word2split`` (the same object that was passed in).
    """
    merged = merge_pair(pair) if new_byte is None else new_byte
    for word, tokens in word2split.items():
        if len(word) == 1:
            continue
        pos = 0
        while pos < len(tokens) - 1:
            if (tokens[pos], tokens[pos + 1]) != pair:
                pos += 1
                continue
            # Build a new list rather than mutating the stored one.
            tokens = tokens[:pos] + [merged] + tokens[pos + 2:]
        word2split[word] = tokens
    return word2split

# Record the first merge rule and apply it to every word's split.
merge_rule[pair] = merge_pair(pair)
word2split = update_splits(pair, word2split)

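## Putting it all together

The cells below lowercase the corpus, rebuild the word frequencies, splits and alphabet from it, and then repeatedly merge the most frequent token pair until the vocabulary reaches `vocab_size`.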

In [None]:
# Train on a lowercased copy (tokenize() lowercases its input too). A new name
# keeps this cell idempotent and leaves the original corpus intact.
corpus_lower = [sentence.lower() for sentence in corpus]

# Target total vocabulary size: special token + alphabet + learned merges.
vocab_size = 50
word2freq = compute_word_freq(corpus_lower, pre_tokenizer)
word2split = split_words(word2freq.keys())
# Alphabet from the splits, so it holds both word-initial ("h") and
# continuation ("##h") tokens. (The former get_alphabet_from_words result was
# immediately overwritten here, so that dead call is dropped.)
alphabet = get_alphabet_from_splits(word2split.values())
vocab = ['<eot>'] + list(alphabet)
merge_rule = {}

In [None]:
while len(vocab) < vocab_size:
    # Count adjacent token pairs, weighted by word frequency.
    pair2freq = compute_pair_freq(word2split, word2freq)
    if not pair2freq:
        # Every word is already a single token: nothing left to merge, so stop
        # early instead of crashing on an empty pair table.
        break
    pair, freq = find_most_frequent_pair(pair2freq)
    new_token = merge_pair(pair)
    # Merge rules are kept in learning order so tokenize() can replay them.
    merge_rule[pair] = new_token
    vocab.append(new_token)
    # Apply the new rule to every word's split.
    word2split = update_splits(pair, word2split, new_token)

In [None]:
# Learned merge rules, in the order they were learned: (left, right) -> merged token.
merge_rule

In [None]:
def tokenize(text, pre_tokenizer, merge_rule):
    """Tokenize ``text`` by replaying the learned merge rules.

    This mirrors the training loop:
    1. Lowercase the text.
    2. Pre-tokenize it into words.
    3. Split each word into characters with ``"##"`` continuation markers.
    4. Apply every merge rule in the order it was learned.

    Args:
        text: raw input string.
        pre_tokenizer: object exposing ``pre_tokenize_str(text)``, which
            returns ``(word, offsets)`` pairs.
        merge_rule: dict mapping ``(left, right)`` token pairs to the merged
            token, in learning order.

    Returns:
        Flat list of tokens for all words, in input order. Characters never
        seen in training are emitted as-is (there is no unknown token).
    """
    text = text.lower()
    words = [word for word, _offsets in pre_tokenizer.pre_tokenize_str(text)]
    word2split = split_words(words)
    for pair, new_byte in merge_rule.items():
        word2split = update_splits(pair, word2split, new_byte)
    # Flatten with a comprehension: ``sum(lists, [])`` copies the accumulator
    # on every addition, which is quadratic in the number of tokens.
    return [token for word in words for token in word2split[word]]

# Example: tokenize a new sentence with the rules learned above.
tokenize("This is not a token.", pre_tokenizer, merge_rule)