# WordPiece Algorithm Implementation

#### HADDOU Younes

This notebook presents an implementation of the WordPiece tokenization algorithm used by BERT.

The goal is to write a `WordPiece(text)` function that takes as input a string representing the corpus and returns the token vocabulary learned with the WordPiece algorithm.


The algorithm can be defined as follows:

- Pre-tokenize the corpus into words with a trivial whitespace split.
- Build the initial vocabulary V0 from the characters of those words.
- As a first iteration, represent each word as the list of its characters, with every non-initial character prefixed by `##` (apple: `[a, ##p, ##p, ##l, ##e]`).
- Repeatedly merge the adjacent pair with the highest score, where `score = freq_of_pair / (freq_first_element * freq_second_element)`.
- Stop when the vocabulary reaches the target size (`size_vocab = size_V0 + number of merged tokens`).
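Written out, the merge score from the list above is the following, where $a$ and $b$ are two adjacent parts and $\text{freq}(a, b)$ is the number of times $a$ is immediately followed by $b$ in the corpus (each word weighted by its frequency):

$$
\text{score}(a, b) = \frac{\text{freq}(a, b)}{\text{freq}(a) \times \text{freq}(b)}
$$

Dividing by the individual frequencies favours pairs whose parts rarely appear apart. For example, in the sample corpus below, `o` and `##f` each occur only in `of`, so the pair scores $1 / (1 \times 1) = 1$, the highest score in the `calculate_score` output, while a pair of frequent parts such as `('##r', '##i')` scores much lower.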

```python
from collections import Counter
```

```python
text = 'Trivial word pre-tokenization of the corpus'

def basic_tokenize(text):
    """
    Split the text on whitespace and count how often each word occurs
    """
    tokens = text.split()
    return Counter(tokens)

count = basic_tokenize(text)
count.keys()
```

```
dict_keys(['Trivial', 'word', 'pre-tokenization', 'of', 'the', 'corpus'])
```

```python
def vocab0(count):
    """
    Generate the initial vocabulary V0: the first character of each word as-is,
    every other character prefixed with "##"
    """
    v0 = []
    for word in count.keys():
        if word[0] not in v0:
            v0.append(word[0])
        for letter in word[1:]:
            if f"##{letter}" not in v0:
                v0.append(f"##{letter}")
    v0.sort()
    return v0

vocab0(count)
```


```
['##-',
 '##a',
 '##d',
 '##e',
 '##f',
 '##h',
 '##i',
 '##k',
 '##l',
 '##n',
 '##o',
 '##p',
 '##r',
 '##s',
 '##t',
 '##u',
 '##v',
 '##z',
 'T',
 'c',
 'o',
 'p',
 't',
 'w']
```

```python
def get_words_parts(count):
    """
    Associate each word with its list of parts
    (characters, "##"-prefixed after the first one)
    """
    words_parts = {}
    for word in count.keys():
        parts = []
        for i, char in enumerate(word):
            parts.append(char if i == 0 else f"##{char}")
        words_parts[word] = parts
    return words_parts

words_parts = get_words_parts(count)
words_parts
```

```
{'Trivial': ['T', '##r', '##i', '##v', '##i', '##a', '##l'],
 'word': ['w', '##o', '##r', '##d'],
 'pre-tokenization': ['p',
  '##r',
  '##e',
  '##-',
  '##t',
  '##o',
  '##k',
  '##e',
  '##n',
  '##i',
  '##z',
  '##a',
  '##t',
  '##i',
  '##o',
  '##n'],
 'of': ['o', '##f'],
 'the': ['t', '##h', '##e'],
 'corpus': ['c', '##o', '##r', '##p', '##u', '##s']}
```

```python
def calculate_score(parts, count):
    """
    Calculate the score of every adjacent pair:
    score = freq(pair) / (freq(first) * freq(second))
    """
    letter_freqs = {}
    pair_freqs = {}
    for word, freq in count.items():
        part = parts[word]
        # A word reduced to a single part has no pairs: only count the part itself
        if len(part) == 1:
            letter_freqs[part[0]] = letter_freqs.get(part[0], 0) + freq
            continue
        for i in range(len(part) - 1):
            pair = (part[i], part[i + 1])
            # Both members are counted for every pair, so a part in the middle
            # of a word is counted twice (once per neighbouring pair) while the
            # first and last parts of the word are counted once
            letter_freqs[pair[0]] = letter_freqs.get(pair[0], 0) + freq
            letter_freqs[pair[1]] = letter_freqs.get(pair[1], 0) + freq
            pair_freqs[pair] = pair_freqs.get(pair, 0) + freq

    scores = {
        pair: freq / (letter_freqs[pair[0]] * letter_freqs[pair[1]])
        for pair, freq in pair_freqs.items()
    }
    return scores

scores = calculate_score(words_parts, count)
scores
```

```
{('T', '##r'): 0.125,
 ('##r', '##i'): 0.015625,
 ('##i', '##v'): 0.0625,
 ('##v', '##i'): 0.0625,
 ('##i', '##a'): 0.03125,
 ('##a', '##l'): 0.25,
 ('w', '##o'): 0.125,
 ('##o', '##r'): 0.03125,
 ('##r', '##d'): 0.125,
 ('p', '##r'): 0.125,
 ('##r', '##e'): 0.025,
 ('##e', '##-'): 0.1,
 ('##-', '##t'): 0.125,
 ('##t', '##o'): 0.03125,
 ('##o', '##k'): 0.0625,
 ('##k', '##e'): 0.1,
 ('##e', '##n'): 0.06666666666666667,
 ('##n', '##i'): 0.041666666666666664,
 ('##i', '##z'): 0.0625,
 ('##z', '##a'): 0.125,
 ('##a', '##t'): 0.0625,
 ('##t', '##i'): 0.03125,
 ('##i', '##o'): 0.015625,
 ('##o', '##n'): 0.041666666666666664,
 ('o', '##f'): 1.0,
 ('t', '##h'): 0.5,
 ('##h', '##e'): 0.1,
 ('c', '##o'): 0.125,
 ('##r', '##p'): 0.0625,
 ('##p', '##u'): 0.25,
 ('##u', '##s'): 0.5}
```
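The `letter_freqs` loop in `calculate_score` counts each part once per adjacent pair, so a part in the middle of a word is counted twice, and the scores printed above are therefore not exactly `freq_of_pair / (freq_first_element * freq_second_element)` with plain occurrence counts. Below is a minimal sketch of a variant that counts every occurrence exactly once; `calculate_score_once`, `sample_count` and `sample_parts` are hypothetical names, and the function takes the same `parts` and `count` arguments as `calculate_score`.

```python
from collections import Counter


def calculate_score_once(parts, count):
    """
    Score every adjacent pair as freq(pair) / (freq(first) * freq(second)),
    where each occurrence of a part is counted exactly once
    """
    letter_freqs = Counter()
    pair_freqs = Counter()
    for word, freq in count.items():
        part = parts[word]
        # Every occurrence of every part contributes once, weighted by word frequency
        for p in part:
            letter_freqs[p] += freq
        # Every adjacent pair contributes once, weighted by word frequency
        for a, b in zip(part, part[1:]):
            pair_freqs[(a, b)] += freq
    return {
        pair: freq / (letter_freqs[pair[0]] * letter_freqs[pair[1]])
        for pair, freq in pair_freqs.items()
    }


# Same sample corpus as above, rebuilt here so the block runs on its own
sample_count = Counter('Trivial word pre-tokenization of the corpus'.split())
sample_parts = {
    w: [c if i == 0 else f"##{c}" for i, c in enumerate(w)] for w in sample_count
}
scores_once = calculate_score_once(sample_parts, sample_count)
print(scores_once[('T', '##r')])    # 1 / (1 * 4) = 0.25
print(scores_once[('##a', '##l')])  # 1 / (2 * 1) = 0.5
```

With plain counts, `('T', '##r')` scores 0.25 instead of 0.125, because `##r` occurs 4 times in the corpus rather than the 8 counted by `calculate_score`. Since the change does not scale every score by the same factor, it can alter which pair is merged first, and therefore the final vocabulary.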

```python
def simplify_parts(pair, words_parts):
    """
    Merge every occurrence of `pair` in the parts lists (in place)
    """
    lc, rc = pair
    # The merged token keeps the left part's prefix and drops the right part's "##"
    simplified = lc + rc[2:] if rc.startswith("##") else lc + rc
    for part in words_parts.values():
        i = 0
        # A while loop re-reads the length after each merge, because merging
        # shrinks the list (a fixed range() could index past its end)
        while i < len(part) - 1:
            if part[i] == lc and part[i + 1] == rc:
                part[i:i + 2] = [simplified]
            i += 1
    return words_parts

simplify_parts(('w', '##o'), words_parts)
words_parts['word']
```

```
['wo', '##r', '##d']
```

```python
VOCAB_SIZE = 57

def WordPiece(text) -> list:
    """
    Aggregate all the functions above.
    VOCAB_SIZE sets the target size of the vocabulary, and therefore its granularity
    """
    count = basic_tokenize(text)
    v0 = vocab0(count)
    words_parts = get_words_parts(count)
    while len(v0) < VOCAB_SIZE:
        scores = calculate_score(words_parts, count)
        # No pair is left once every word is a single token,
        # which happens when VOCAB_SIZE is too high for the corpus
        if not scores:
            print(f"VOCAB_SIZE is too high, maximum VOCAB_SIZE is: {len(v0)}")
            break
        best_pair = max(scores, key=scores.get)
        lc, rc = best_pair
        words_parts = simplify_parts(best_pair, words_parts)
        # Same merge rule as in simplify_parts: drop the right part's "##" prefix
        v0.append(lc + rc[2:] if rc.startswith("##") else lc + rc)
    v0.sort(reverse=True)
    return v0


WordPiece("Autumn leaves rustled in the gentle breeze as I walked through the peaceful forest. The golden hues of the trees created a mesmerizing landscape, and the sound of a distant stream added to the tranquility of the scene. I stopped to take a deep breath, savoring the crisp, fresh air before continuing my hike.")
```

```
['w',
 't',
 's',
 'rustl',
 'rust',
 'rus',
 'ru',
 'r',
 'p',
 'of',
 'o',
 'my',
 'm',
 'l',
 'i',
 'h',
 'g',
 'f',
 'd',
 'c',
 'b',
 'a',
 'Th',
 'T',
 'I',
 'Autum',
 'Autu',
 'Aut',
 'Au',
 'A',
 '##z',
 '##y',
 '##v',
 '##u',
 '##ty',
 '##t',
 '##s',
 '##r',
 '##q',
 '##p',
 '##o',
 '##n',
 '##m',
 '##lity',
 '##l',
 '##k',
 '##ity',
 '##i',
 '##h',
 '##g',
 '##f',
 '##e',
 '##d',
 '##c',
 '##a',
 '##.',
 '##,']
```

## Remarks

The pre-tokenization used here (a plain whitespace split) is very simple and does not take punctuation or other special characters into account, so some tokens are inaccurate: `forest.` and `breath,` are treated as single words, which is why `##.` and `##,` appear in the vocabulary above. An NLP pre-tokenization model would solve this problem; I did not use external tools such as spaCy in order to stay within the scope of the exercise.
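Without an external library, a regular expression already separates punctuation from words. Below is a minimal sketch of a hypothetical `punct_tokenize` that could replace `basic_tokenize`; it assumes that a word is a run of word characters (`\w+`) and that every other non-space character is a token of its own, which also splits hyphenated words such as `pre-tokenization`.

```python
import re
from collections import Counter


def punct_tokenize(text):
    """
    Count words, treating each punctuation character as a separate token
    """
    # \w+ matches a run of letters, digits or underscores;
    # [^\w\s] matches one character that is neither a word character nor whitespace
    tokens = re.findall(r"\w+|[^\w\s]", text)
    return Counter(tokens)


counts = punct_tokenize("Peaceful forest. Take a deep breath, savoring the crisp, fresh air.")
print(counts)
```

Because each punctuation mark becomes a word of its own, `vocab0` would add it to V0 as a first character (for example `.`) rather than as `##.`.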

Finally, to tokenize a text with the learned vocabulary, we process each word separately and take the longest subword present in the generated vocabulary.
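A minimal sketch of that step, assuming greedy longest-match-first splitting: starting at the beginning of the word, take the longest substring present in the vocabulary (with the `##` prefix for every piece after the first), then continue from where it ends. The names `tokenize_word` and `tokenize`, and the `[UNK]` fallback for words that cannot be split, are assumptions (the fallback follows BERT's unknown-token convention), not part of the notebook. The sample vocabulary is a handful of tokens taken from the output above.

```python
def tokenize_word(word, vocab, unk="[UNK]"):
    """
    Split one word into the longest subwords found in vocab, left to right
    """
    tokens = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Shrink the candidate from the right until it is in the vocabulary
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        # No prefix of the remainder is known: the whole word is unknown
        if piece is None:
            return [unk]
        tokens.append(piece)
        start = end
    return tokens


def tokenize(text, vocab):
    """
    Tokenize a text by splitting on whitespace and tokenizing each word
    """
    vocab = set(vocab)  # set membership is O(1)
    return [tok for word in text.split() for tok in tokenize_word(word, vocab)]


vocab = ["Au", "Aut", "Autum", "A", "##n", "##m", "##u", "##t", "of", "o", "##f"]
print(tokenize("Autumn of", vocab))  # ['Autum', '##n', 'of']
```

Greedy matching is cheap and needs no scores at inference time, but it can miss a valid split: with a vocabulary `{a, ab, ##bc}`, the word `abc` starts with `ab`, finds no `##c`, and falls back to `[UNK]`, although `a` + `##bc` would have worked.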
