
In [None]:
!pip install transformers



# WordPiece tokenization

WordPiece is the tokenization algorithm Google developed to pretrain BERT. It has since been reused in quite a few Transformer models based on BERT, such as DistilBERT, MobileBERT, Funnel Transformers, and MPNET. It’s very similar to BPE in terms of the training, but the actual tokenization is done differently.

# Implementing WordPiece

Now let’s take a look at an implementation of the WordPiece algorithm. As with BPE, this is just pedagogical, and you won’t be able to use it on a big corpus.

We will use the same corpus as in the BPE example:

In [None]:
corpus = [
    "This is the Hugging Face Course.",
    "This chapter is about tokenization.",
    "This section shows several tokenizer algorithms.",
    "Hopefully, you will be able to understand how they are trained and generate tokens.",
]

First, we need to pre-tokenize the corpus into words. Since we are replicating a WordPiece tokenizer (like BERT), we will use the bert-base-cased tokenizer for the pre-tokenization:

In [None]:
from transformers import AutoTokenizer

tokenizer_bert = AutoTokenizer.from_pretrained("bert-base-cased")

Then we compute the frequencies of each word in the corpus as we do the pre-tokenization:

In [None]:
from collections import defaultdict

word_freqs = defaultdict(int)
for text in corpus:
  words_with_offsets = tokenizer_bert.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
  new_words = [word for word, offset in words_with_offsets]
  for word in new_words:
    word_freqs[word] += 1

word_freqs

defaultdict(int,
            {'This': 3,
             'is': 2,
             'the': 1,
             'Hugging': 1,
             'Face': 1,
             'Course': 1,
             '.': 4,
             'chapter': 1,
             'about': 1,
             'tokenization': 1,
             'section': 1,
             'shows': 1,
             'several': 1,
             'tokenizer': 1,
             'algorithms': 1,
             'Hopefully': 1,
             ',': 1,
             'you': 1,
             'will': 1,
             'be': 1,
             'able': 1,
             'to': 1,
             'understand': 1,
             'how': 1,
             'they': 1,
             'are': 1,
             'trained': 1,
             'and': 1,
             'generate': 1,
             'tokens': 1})
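If you want to experiment without downloading `bert-base-cased`, the same counts can be approximated with the standard library. The sketch below assumes that splitting on whitespace and treating every punctuation character as its own word (the regex `\w+|[^\w\s]`) is close enough to BERT’s pre-tokenizer for this corpus. It is not the library’s implementation and may differ on other inputs. The helper `simple_pre_tokenize` is hypothetical; like `pre_tokenize_str`, it returns `(word, (start, end))` pairs.

```python
import re
from collections import Counter

corpus = [
    "This is the Hugging Face Course.",
    "This chapter is about tokenization.",
    "This section shows several tokenizer algorithms.",
    "Hopefully, you will be able to understand how they are trained and generate tokens.",
]

# Hypothetical stand-in for BERT's pre-tokenizer: a run of word characters,
# or any single character that is neither a word character nor whitespace
WORD_OR_PUNCT = re.compile(r"\w+|[^\w\s]")


def simple_pre_tokenize(text):
    """Return (word, (start, end)) pairs, mimicking the shape of pre_tokenize_str."""
    return [(m.group(), m.span()) for m in WORD_OR_PUNCT.finditer(text)]


# Counter is a dict, so words keep their first-seen order, as in word_freqs above
word_freqs = Counter(word for text in corpus for word, _ in simple_pre_tokenize(text))

print(word_freqs.most_common(3))  # [('.', 4), ('This', 3), ('is', 2)]
```

On this corpus it produces the same 30 words with the same counts as the `defaultdict` version.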

As we saw before, the alphabet is the unique set composed of all the first letters of words and all the other letters that appear in words, prefixed by ##:

In [None]:
alphabet = []

for word in word_freqs.keys():
  if word[0] not in alphabet:
    alphabet.append(word[0])
  for letter in word[1:]:
    if f"##{letter}" not in alphabet:
      alphabet.append(f"##{letter}")

alphabet.sort()
alphabet
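Because the alphabet is a set, the same computation can also be written as a set comprehension followed by `sorted`. The sketch below hardcodes the 30 words counted above so that it runs on its own. After sorting, the `##` entries come first, because `#` precedes punctuation and letters in ASCII order.

```python
words = [
    "This", "is", "the", "Hugging", "Face", "Course", ".", "chapter", "about",
    "tokenization", "section", "shows", "several", "tokenizer", "algorithms",
    "Hopefully", ",", "you", "will", "be", "able", "to", "understand", "how",
    "they", "are", "trained", "and", "generate", "tokens",
]

# First characters are kept as-is; every later character gets the "##" prefix
alphabet = sorted(
    {word[0] for word in words} | {f"##{c}" for word in words for c in word[1:]}
)

print(len(alphabet))  # 40
print(alphabet[:3])   # ['##a', '##b', '##c']
```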

We also add the special tokens used by the model at the beginning of that vocabulary. In the case of BERT, it’s the list ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]:

In [None]:
vocabs = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + alphabet.copy()

Next, we need to split each word into letters, prefixing every letter except the first with ##:

In [None]:
splits = {word: [c if i == 0 else f"##{c}" for i, c in enumerate(word)] for word in word_freqs.keys()}
splits

{'This': ['T', '##h', '##i', '##s'],
 'is': ['i', '##s'],
 'the': ['t', '##h', '##e'],
 'Hugging': ['H', '##u', '##g', '##g', '##i', '##n', '##g'],
 'Face': ['F', '##a', '##c', '##e'],
 'Course': ['C', '##o', '##u', '##r', '##s', '##e'],
 '.': ['.'],
 'chapter': ['c', '##h', '##a', '##p', '##t', '##e', '##r'],
 'about': ['a', '##b', '##o', '##u', '##t'],
 'tokenization': ['t',
  '##o',
  '##k',
  '##e',
  '##n',
  '##i',
  '##z',
  '##a',
  '##t',
  '##i',
  '##o',
  '##n'],
 'section': ['s', '##e', '##c', '##t', '##i', '##o', '##n'],
 'shows': ['s', '##h', '##o', '##w', '##s'],
 'several': ['s', '##e', '##v', '##e', '##r', '##a', '##l'],
 'tokenizer': ['t', '##o', '##k', '##e', '##n', '##i', '##z', '##e', '##r'],
 'algorithms': ['a',
  '##l',
  '##g',
  '##o',
  '##r',
  '##i',
  '##t',
  '##h',
  '##m',
  '##s'],
 'Hopefully': ['H', '##o', '##p', '##e', '##f', '##u', '##l', '##l', '##y'],
 ',': [','],
 'you': ['y', '##o', '##u'],
 'will': ['w', '##i', '##l', '##l'],
 'be': ['b', '##e'],
 ...}

Now that we are ready for training, let’s write a function that computes the score of each pair. We’ll need to use this at each step of the training:

In [None]:
def compute_pair_scores(splits):
  # Count each token and each adjacent pair of tokens,
  # weighting every word by its frequency in the corpus (global word_freqs)
  letter_freqs = defaultdict(int)
  pair_freqs = defaultdict(int)
  for word, freq in word_freqs.items():
    split = splits[word]
    if len(split) == 1:
      letter_freqs[split[0]] += freq
      continue
    for i in range(len(split) - 1):
      pair = (split[i], split[i + 1])
      pair_freqs[pair] += freq
      letter_freqs[split[i]] += freq
    letter_freqs[split[-1]] += freq

  # score = freq(pair) / (freq(first element) * freq(second element))
  scores = {
      pair: freq / (letter_freqs[pair[0]] * letter_freqs[pair[1]]) for pair, freq in pair_freqs.items()
  }

  return scores

In [None]:
pair_scores = compute_pair_scores(splits)
for i, key in enumerate(pair_scores.keys()):
    print(f"{key}: {pair_scores[key]}")
    if i >= 5:
        break

('T', '##h'): 0.125
('##h', '##i'): 0.03409090909090909
('##i', '##s'): 0.02727272727272727
('i', '##s'): 0.1
('t', '##h'): 0.03571428571428571
('##h', '##e'): 0.011904761904761904
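The score computed by `compute_pair_scores` divides how often a pair occurs by the product of how often each of its two elements occurs, with every count weighted by word frequency. We can check the first printed value by hand: `('T', '##h')` occurs 3 times (the three occurrences of This), `T` occurs 3 times, and `##h` occurs 8 times (This ×3, the, chapter, shows, algorithms, they):

$$
\text{score}(a, b) = \frac{\text{freq}(a, b)}{\text{freq}(a) \times \text{freq}(b)},
\qquad
\text{score}(\texttt{T}, \texttt{\#\#h}) = \frac{3}{3 \times 8} = 0.125
$$

Because the denominator grows with the frequency of each element, a pair whose parts are common elsewhere in the corpus gets a low score, even when the pair itself occurs several times.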


In [None]:
best_pair = ""
max_score = None

for pair, score in pair_scores.items():
  if max_score is None or score > max_score:
    max_score = score
    best_pair = pair
best_pair, max_score

(('a', '##b'), 0.2)
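The same selection can be written with the built-in `max` and a `key` function. Like the loop with a strict `>`, `max` returns the first maximal key it encounters, and that matters here: `('##f', '##u')` (from Hopefully) also scores 1 / (1 × 5) = 0.2. `('a', '##b')` wins because about comes before Hopefully in `word_freqs`, which places that pair earlier in the scores dictionary. The toy dictionary below copies a few of the scores discussed above, so the sketch runs on its own.

```python
# A few scores from this corpus, in the order compute_pair_scores produces them;
# ('a', '##b') and ('##f', '##u') tie at 0.2
pair_scores = {
    ("T", "##h"): 0.125,
    ("i", "##s"): 0.1,
    ("a", "##b"): 0.2,
    ("##f", "##u"): 0.2,
}

# max() returns the first key with the largest score, just like the loop with ">"
best_pair = max(pair_scores, key=pair_scores.get)
print(best_pair, pair_scores[best_pair])  # ('a', '##b') 0.2
```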

So the first merge to learn is ('a', '##b') -> 'ab', and we add 'ab' to the vocabulary:

In [None]:
vocabs.append("ab")

To continue, we need to apply that merge in our splits dictionary. Let’s write a function for this:

In [None]:
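def merge_pair(a, b, splits):
  # Replace every adjacent occurrence of (a, b) with a single merged token
  for word in word_freqs:
    split = splits[word]
    if len(split) == 1:
      continue
    i = 0
    while i < len(split) - 1:
      if split[i] == a and split[i + 1] == b:
        # Drop the "##" of the second token: ("a", "##b") -> "ab", ("##f", "##u") -> "##fu"
        merge = a + b[2:] if b.startswith("##") else a + b
        split = split[:i] + [merge] + split[i + 2:]
      else:
        i += 1
    splits[word] = split
  return splits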

In [None]:
splits = merge_pair("a", "##b", splits)
splits["about"]

['ab', '##o', '##u', '##t']
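`merge_pair` mutates the shared `splits` dictionary and reads the global `word_freqs`. If you prefer a pure function that can be tested in isolation, a per-word version might look like the sketch below. `merge_in_split` is a hypothetical helper, not part of the course code. It builds a new list in one left-to-right pass instead of re-slicing the list after every merge.

```python
def merge_in_split(split, a, b):
    """Return a new list where every adjacent (a, b) in split is replaced by the merged token."""
    # Same naming rule as merge_pair: drop the "##" of the second token if it has one
    merged = a + b[2:] if b.startswith("##") else a + b
    out = []
    i = 0
    while i < len(split):
        if i < len(split) - 1 and split[i] == a and split[i + 1] == b:
            out.append(merged)
            i += 2  # skip both tokens that were just merged
        else:
            out.append(split[i])
            i += 1
    return out


print(merge_in_split(["a", "##b", "##o", "##u", "##t"], "a", "##b"))  # ['ab', '##o', '##u', '##t']
```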

Now we have everything we need to loop until we have learned all the merges we want. Let’s aim for a vocab size of 70:

In [None]:
vocab_size = 70

while len(vocabs) < vocab_size:
  scores = compute_pair_scores(splits)
  best_pair, max_score = "", None
  for pair, score in scores.items():
    if max_score is None or max_score < score:
      max_score = score
      best_pair = pair
  splits = merge_pair(*best_pair, splits)
  new_token = (
      best_pair[0] + best_pair[1][2:] if best_pair[1].startswith("##") else best_pair[0] + best_pair[1]
  )
  vocabs.append(new_token)

vocabs

['[PAD]',
 '[UNK]',
 '[CLS]',
 '[SEP]',
 '[MASK]',
 '##a',
 '##b',
 '##c',
 '##d',
 '##e',
 '##f',
 '##g',
 '##h',
 '##i',
 '##k',
 '##l',
 '##m',
 '##n',
 '##o',
 '##p',
 '##r',
 '##s',
 '##t',
 '##u',
 '##v',
 '##w',
 '##y',
 '##z',
 ',',
 '.',
 'C',
 'F',
 'H',
 'T',
 'a',
 'b',
 'c',
 'g',
 'h',
 'i',
 's',
 't',
 'u',
 'w',
 'y',
 'ab',
 'Fa',
 'Fac',
 '##ct',
 '##ful',
 '##full',
 '##fully',
 'Th',
 'ch',
 '##hm',
 'cha',
 'chap',
 'chapt',
 '##thm',
 'Hu',
 'Hug',
 'Hugg',
 'sh',
 'th',
 'is',
 '##thms',
 '##za',
 '##zat',
 '##ut',
 '##ta']

As we can see, compared to BPE, this tokenizer learns parts of words as tokens a bit faster.
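For reuse, the training cells can be collected into a single function. The sketch below (`train_wordpiece` is a hypothetical name) follows the same scoring and merge rules as the cells above, with one addition: it stops early when every word has become a single token. At that point the notebook loop would crash trying to unpack an empty `best_pair`. The word counts are copied from the `word_freqs` output so that the block runs on its own.

```python
from collections import defaultdict

# Word counts from the word_freqs output above, in the same order
word_freqs = {
    "This": 3, "is": 2, "the": 1, "Hugging": 1, "Face": 1, "Course": 1, ".": 4,
    "chapter": 1, "about": 1, "tokenization": 1, "section": 1, "shows": 1,
    "several": 1, "tokenizer": 1, "algorithms": 1, "Hopefully": 1, ",": 1,
    "you": 1, "will": 1, "be": 1, "able": 1, "to": 1, "understand": 1, "how": 1,
    "they": 1, "are": 1, "trained": 1, "and": 1, "generate": 1, "tokens": 1,
}

SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]


def train_wordpiece(word_freqs, vocab_size, special_tokens=SPECIAL_TOKENS):
    """Learn WordPiece merges until the vocabulary reaches vocab_size or no pairs are left."""
    # Character-level splits and the initial alphabet, as in the cells above
    splits = {w: [c if i == 0 else f"##{c}" for i, c in enumerate(w)] for w in word_freqs}
    alphabet = sorted({tok for split in splits.values() for tok in split})
    vocab = list(special_tokens) + alphabet

    while len(vocab) < vocab_size:
        # Token and adjacent-pair counts, weighted by word frequency
        token_freqs = defaultdict(int)
        pair_freqs = defaultdict(int)
        for word, freq in word_freqs.items():
            split = splits[word]
            for tok in split:
                token_freqs[tok] += freq
            for pair in zip(split, split[1:]):
                pair_freqs[pair] += freq
        if not pair_freqs:
            break  # every word is already a single token, nothing left to merge
        scores = {
            pair: freq / (token_freqs[pair[0]] * token_freqs[pair[1]])
            for pair, freq in pair_freqs.items()
        }
        a, b = max(scores, key=scores.get)  # ties go to the first pair seen
        new_token = a + b[2:] if b.startswith("##") else a + b
        # Apply the merge in place to every word's split
        for split in splits.values():
            i = 0
            while i < len(split) - 1:
                if split[i] == a and split[i + 1] == b:
                    split[i : i + 2] = [new_token]
                else:
                    i += 1
        vocab.append(new_token)
    return vocab, splits


vocab, final_splits = train_wordpiece(word_freqs, vocab_size=46)
print(len(vocab), vocab[-1])  # 46 ab
```

Recounting every pair on each iteration keeps the sketch close to the cells above. It is fine for a toy corpus, but it is also why this approach does not scale to a big one.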

To tokenize a new text, we pre-tokenize it, split it, then apply the tokenization algorithm on each word. That is, we look for the longest subword in the vocabulary that starts at the beginning of the first word and split it off. We then repeat the process on the remaining part, and so on for the rest of that word and for the following words in the text:

In [None]:
def encode_word(word):
  # Greedy longest-match-first: repeatedly take the longest prefix found in the vocabulary
  tokens = []
  while len(word) > 0:
    i = len(word)
    while i > 0 and word[:i] not in vocabs:
      i -= 1
    if i == 0:
      # No prefix of what is left is in the vocabulary, so the whole word is unknown
      return ["[UNK]"]
    tokens.append(word[:i])
    word = word[i:]
    if len(word) > 0:
      # The remainder continues the same word, so it gets the ## prefix
      word = f"##{word}"
  return tokens

In [None]:
print(encode_word("Hugging"))
print(encode_word("HOgging"))

['Hugg', '##i', '##n', '##g']
['[UNK]']


Now, let’s write a function that tokenizes a text:

In [None]:
def tokenize(text):
  pre_tokenized_result = tokenizer_bert.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
  pre_tokenized_text = [word for word, offset in pre_tokenized_result]
  encoded_words = [encode_word(word) for word in pre_tokenized_text]
  # Flatten the per-word token lists into a single list
  return sum(encoded_words, [])

In [None]:
tokenize("This is the Hugging Face course!")

['Th',
 '##i',
 '##s',
 'is',
 'th',
 '##e',
 'Hugg',
 '##i',
 '##n',
 '##g',
 'Fac',
 '##e',
 'c',
 '##o',
 '##u',
 '##r',
 '##s',
 '##e',
 '[UNK]']
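Two details of `encode_word` become costly as the vocabulary grows. First, `word[:i] not in vocabs` scans a Python list, so each lookup is linear in the vocabulary size. Second, the function reads the global `vocabs`. The variant below is a sketch; the name `encode_word_fast` and its `vocab` and `unk_token` parameters are our own additions. It takes the vocabulary as an argument, which can be a `set` for constant-time average lookups, and it walks indices instead of rebuilding a `##`-prefixed remainder. On the vocabulary trained above it should give the same tokens as `encode_word`. Unlike that cell, it never tries a bare `#` or `##` as a candidate.

```python
def encode_word_fast(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first WordPiece encoding of one pre-tokenized word."""
    tokens = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Shrink the candidate from the right until it is in the vocabulary
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # pieces after the first carry the ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk_token]  # nothing matched, so the whole word is unknown
        tokens.append(piece)
        start = end
    return tokens


# A small set-based vocabulary containing the pieces used below
vocab = {"H", "Hugg", "Th", "##i", "##n", "##g", "##s"}

print(encode_word_fast("Hugging", vocab))  # ['Hugg', '##i', '##n', '##g']
print(encode_word_fast("HOgging", vocab))  # ['[UNK]']
```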

That’s it for the WordPiece algorithm! Now let’s take a look at Unigram.