# WordPiece algorithm

Implement the WordPiece tokenization algorithm explained in the NPTEL video:
[Lec 09 | Tokenization Strategies](https://www.youtube.com/watch?v=PcxUVCmvsAM&list=PLp6ek2hDcoNDDRINFiWGDlPKUwW-g1Hjk&index=10)

## Part 1: exploring individual functions

In [235]:
train_text = """Predictions for food delivery in next 3 - 6 months - 

1. Commissions will go down for restaurants - they will be based on distance and proximity 

2. Platform fee will increase to 20-30 rs 

3. Gold / black membership prices will go up 

4. Delivery fee load on customers will increase majorly, riders will start getting paid more 

5. Ad earnings of aggregators will see a major downfall 

6. Many digital first chains will either shut down or get acquired 

7. 10 mins food delivery will slow down massively, people will go back to Zomato and Swiggy 

8. New food delivery hyperlocal players will emerge for tier 2 and 3 cities 

9. Physical brands will expand thru cloud kitchens thinking that's where nirvana is, only to be surprised later 

10. Cloud kitchen brands will go physical and consolidate their stores 

In the end, Indian aggregators will copy the door dash model with high delivery fee per order and Avg order values will increase by atleast 30 percent 

Quick commerce on the other hand will keep burning cash and will start consolidating in many zones 

Good days ahead for the food brands which survived so far !"""

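### 1st Vocab
- Individual characters (used at the start of a word)
- Middle characters (the same characters with a `##` prefix, used inside a word)
- Unknown token (`<unk>`)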

In [236]:
set(train_text)

{'\n',
 ' ',
 '!',
 "'",
 ',',
 '-',
 '.',
 '/',
 '0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 'A',
 'C',
 'D',
 'G',
 'I',
 'M',
 'N',
 'P',
 'Q',
 'S',
 'Z',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z'}

In [237]:
# Sort the character set so the vocabulary order (and therefore every token
# id derived from it) is reproducible; plain set iteration order for strings
# can change from one interpreter run to the next.
individual_chars = sorted(set(train_text))
# The same characters with the "##" prefix used for non-initial positions.
middle_chars = ["##" + c for c in individual_chars]
unk_token = ["<unk>"]
len(individual_chars), len(middle_chars), len(unk_token)

(55, 55, 1)

In [238]:
# Starting vocabulary: word-initial characters, then "##"-prefixed
# continuation characters, then the unknown token.
vocab = individual_chars + middle_chars + unk_token
len(vocab)

111

### Create the word list

In [239]:
import re
# str.split() with no arguments splits on any run of whitespace (spaces,
# newlines, tabs) and ignores leading/trailing whitespace, so neither a regex
# nor an explicit strip is needed, and an empty text yields [] instead of [''].
train_words = train_text.split()
train_words[:10]

['Predictions',
 'for',
 'food',
 'delivery',
 'in',
 'next',
 '3',
 '-',
 '6',
 'months']

In [240]:
len(train_words)

197

### Prepare this for WordPiece Learning

In [241]:
# Break every word into characters; each character after the first gets the
# "##" continuation prefix.
train_words_modified = [
    [ch if pos == 0 else "##" + ch for pos, ch in enumerate(word)]
    for word in train_words
]

train_words_modified[:5]

[['P', '##r', '##e', '##d', '##i', '##c', '##t', '##i', '##o', '##n', '##s'],
 ['f', '##o', '##r'],
 ['f', '##o', '##o', '##d'],
 ['d', '##e', '##l', '##i', '##v', '##e', '##r', '##y'],
 ['i', '##n']]

In [242]:
len(train_words)

197

### Now get one iteration of the train loop
- Get stats: Unigram and Bigram
- Get score for each bigram: `count(bigram) / (count(first token) * count(second token))`
- Take bigram with max score
- Add merged bigram to the vocab
- Finally modify the `train_words_modified` with this merged bigram

In [212]:
def get_stats(train_words):
    """Count token (unigram) and adjacent-pair (bigram) frequencies.

    Args:
        train_words: iterable of words, each given as a list of token strings.

    Returns:
        A tuple ``(unigram_stats, bigram_stats)`` of plain dicts. The first
        maps each token to its number of occurrences, the second maps each
        ``(left, right)`` tuple of adjacent tokens to its number of
        occurrences. Keys appear in first-seen order.
    """
    unigram_stats, bigram_stats = {}, {}

    for word_list in train_words:
        # An empty word has nothing to count; skipping it avoids the
        # IndexError that indexing its last element would raise.
        if not word_list:
            continue

        for token in word_list:
            unigram_stats[token] = unigram_stats.get(token, 0) + 1

        # Pair each token with its right-hand neighbour.
        for pair in zip(word_list, word_list[1:]):
            bigram_stats[pair] = bigram_stats.get(pair, 0) + 1

    return unigram_stats, bigram_stats

In [213]:
unigram_stats, bigram_stats = get_stats(train_words_modified)

In [214]:
unigram_stats

{'I': 6,
 '##t': 179,
 "##'": 9,
 '##s': 116,
 'w': 64,
 '##o': 186,
 '##r': 182,
 '##h': 123,
 'n': 12,
 '##i': 122,
 '##n': 184,
 '##g': 42,
 'h': 28,
 '##e': 316,
 't': 76,
 '##a': 138,
 'a': 57,
 '##w': 21,
 'v': 6,
 'o': 46,
 '##f': 46,
 'W': 6,
 '##d': 77,
 '##P': 5,
 '##c': 64,
 '##l': 84,
 '##m': 76,
 '##:': 4,
 'B': 4,
 '##-': 10,
 '##u': 103,
 '##p': 43,
 '##.': 44,
 'b': 24,
 'c': 33,
 'g': 8,
 'i': 41,
 's': 49,
 '"': 1,
 '##G': 1,
 '##v': 25,
 '##b': 32,
 'd': 12,
 '##k': 7,
 'D': 2,
 '##,': 69,
 '##z': 1,
 'p': 15,
 'r': 7,
 'm': 5,
 '##"': 1,
 'T': 17,
 '##y': 35,
 'e': 16,
 'L': 2,
 '##E': 2,
 'f': 23,
 '##F': 1,
 '##x': 8,
 '##R': 1,
 '##T': 1,
 'S': 4,
 'u': 5,
 '##q': 1,
 'F': 5,
 'J': 1,
 'C': 2,
 'K': 2,
 'y': 1,
 '(': 13,
 '##)': 13,
 'V': 1,
 'k': 11,
 '=': 3,
 '4': 2,
 'j': 2,
 'N': 1,
 'E': 3,
 '{': 1,
 '#': 18,
 '###': 23,
 '##}': 1,
 'M': 1,
 '1': 8,
 '##1': 8,
 '##3': 6,
 '3': 3,
 '##+': 1,
 '##=': 1,
 '##6': 4,
 'O': 1,
 'l': 4,
 '>': 1,
 '##0': 2,
 '2': 2,

In [215]:
bigram_stats

{('I', '##t'): 4,
 ('##t', "##'"): 4,
 ("##'", '##s'): 4,
 ('w', '##o'): 22,
 ('##o', '##r'): 54,
 ('##r', '##t'): 10,
 ('##t', '##h'): 21,
 ('n', '##o'): 6,
 ('##o', '##t'): 12,
 ('##t', '##i'): 22,
 ('##i', '##n'): 28,
 ('##n', '##g'): 21,
 ('h', '##e'): 1,
 ('##e', '##r'): 46,
 ('##r', '##e'): 27,
 ('t', '##h'): 57,
 ('##h', '##a'): 10,
 ('##a', '##t'): 27,
 ('##h', '##e'): 52,
 ('a', '##r'): 4,
 ('t', '##w'): 3,
 ('##w', '##o'): 3,
 ('v', '##e'): 1,
 ('##r', '##s'): 8,
 ('##s', '##i'): 4,
 ('##i', '##o'): 9,
 ('##o', '##n'): 16,
 ('##n', '##s'): 9,
 ('o', '##f'): 25,
 ('W', '##o'): 4,
 ('##r', '##d'): 24,
 ('##d', '##P'): 4,
 ('##P', '##i'): 4,
 ('##i', '##e'): 12,
 ('##e', '##c'): 19,
 ('##c', '##e'): 18,
 ('a', '##l'): 11,
 ('##l', '##g'): 8,
 ('##g', '##o'): 8,
 ('##r', '##i'): 19,
 ('##i', '##t'): 16,
 ('##h', '##m'): 8,
 ('##m', '##:'): 2,
 ('B', '##o'): 1,
 ('##t', '##t'): 2,
 ('##t', '##o'): 6,
 ('##o', '##m'): 14,
 ('##m', '##-'): 2,
 ('##-', '##u'): 2,
 ('##u', '##p'): 2,


In [216]:
def get_score(bigram, unigram_stats, bigram_stats):
    """Return the merge score of a bigram.

    The score is the bigram's count divided by the product of the counts of
    its two member tokens.

    Args:
        bigram: ``(first, second)`` pair of tokens.
        unigram_stats: mapping from token to count.
        bigram_stats: mapping from bigram to count.
    """
    pair_count = bigram_stats[bigram]
    first, second = bigram[0], bigram[1]
    return pair_count / (unigram_stats[first] * unigram_stats[second])

In [217]:
# Score every bigram that was observed in the training words.
bigram_score_stats = {
    pair: get_score(pair, unigram_stats, bigram_stats) for pair in bigram_stats
}

bigram_score_stats

{('I', '##t'): 0.0037243947858473,
 ('##t', "##'"): 0.002482929857231533,
 ("##'", '##s'): 0.0038314176245210726,
 ('w', '##o'): 0.0018481182795698926,
 ('##o', '##r'): 0.0015951790145338532,
 ('##r', '##t'): 0.00030695561421818405,
 ('##t', '##h'): 0.0009538084207657719,
 ('n', '##o'): 0.002688172043010753,
 ('##o', '##t'): 0.0003604253018561903,
 ('##t', '##i'): 0.001007418261745581,
 ('##i', '##n'): 0.0012473271560940842,
 ('##n', '##g'): 0.002717391304347826,
 ('h', '##e'): 0.00011301989150090416,
 ('##e', '##r'): 0.000799833078314091,
 ('##r', '##e'): 0.00046946724161914033,
 ('t', '##h'): 0.006097560975609756,
 ('##h', '##a'): 0.0005891363261458702,
 ('##a', '##t'): 0.0010930289045421424,
 ('##h', '##e'): 0.0013378614798806217,
 ('a', '##r'): 0.000385579332947754,
 ('t', '##w'): 0.0018796992481203006,
 ('##w', '##o'): 0.0007680491551459293,
 ('v', '##e'): 0.0005274261603375527,
 ('##r', '##s'): 0.00037893141341417203,
 ('##s', '##i'): 0.0002826455624646693,
 ('##i', '##o'): 0.000

In [220]:
# Pick the bigram with the highest score; on ties max() keeps the first one
# it encounters.
best_bigram = max(bigram_score_stats, key=bigram_score_stats.get)
best_bigram

('"', '##G')

In [221]:
def merge(train_words, best_bigram):
    """Fuse every occurrence of ``best_bigram`` into a single token.

    The fused token is the first token followed by the second token without
    its two leading characters (the "##" prefix). It is printed and appended
    to the module-level ``vocab`` list. Each word in which a merge happened
    is printed together with its index.

    Args:
        train_words: list of words, each a list of token strings. The inner
            lists are modified in place.
        best_bigram: ``(left, right)`` pair of adjacent tokens to fuse.

    Returns:
        A new outer list holding the same (now modified) inner lists.
    """
    left, right = best_bigram[0], best_bigram[1]
    joined_word = left + right[2:]
    print(joined_word)
    vocab.append(joined_word)

    merged_words = []
    for index, pieces in enumerate(train_words):
        merges_done = 0
        pos = 0
        while pos + 1 < len(pieces):
            if (pieces[pos], pieces[pos + 1]) != (left, right):
                pos += 1
                continue
            # Replace the pair with the fused token; pos stays put so the
            # fused token is compared against its new right neighbour next.
            pieces[pos:pos + 2] = [joined_word]
            merges_done += 1

        if merges_done > 0:
            print(pieces, index)

        merged_words.append(pieces)

    return merged_words

In [222]:
train_words_modified = merge(train_words_modified, best_bigram)

"G
['"G', '##i', '##v', '##e', '##n'] 23


In [223]:
train_words_modified[23]

['"G', '##i', '##v', '##e', '##n']

In [224]:
vocab[-1]

'"G'

### Now run some more merges

In [243]:
num_merges = 10

for step in range(num_merges):
    unigram_stats, bigram_stats = get_stats(train_words_modified)
    # Stop early once no adjacent token pairs are left to merge; max() below
    # would raise ValueError on an empty dict.
    if not bigram_stats:
        break
    bigram_score_stats = {
        pair: get_score(pair, unigram_stats, bigram_stats) for pair in bigram_stats
    }
    best_bigram = max(bigram_score_stats, key=bigram_score_stats.get)
    # merge() itself appends the fused token to vocab, so nothing is appended
    # here: joining the raw pair would add a malformed extra entry that still
    # contains the inner "##" prefix (e.g. "##-##3").
    train_words_modified = merge(train_words_modified, best_bigram)
    print(f"Best Bigram: {best_bigram}")
    print("*" * 20)

##-3
['2', '##0', '##-3', '##0'] 33
Best Bigram: ('##-', '##3')
********************
##0-3
['2', '##0-3', '##0'] 33
Best Bigram: ('##0', '##-3')
********************
20-3
['20-3', '##0'] 33
Best Bigram: ('2', '##0-3')
********************
20-30
['20-30'] 33
Best Bigram: ('20-3', '##0')
********************
10
['10'] 82
['10', '##.'] 130
Best Bigram: ('1', '##0')
********************
30
['30'] 166
Best Bigram: ('3', '##0')
********************
Sw
['Sw', '##i', '##g', '##g', '##y'] 97
Best Bigram: ('S', '##w')
********************
up
['up'] 43
Best Bigram: ('u', '##p')
********************
1.
['1.'] 11
Best Bigram: ('1', '##.')
********************
4.
['4.'] 44
Best Bigram: ('4', '##.')
********************


### Now retrain on a larger text and encode test text

In [348]:
train_text = """Language model pre-training has been shown to
be effective for improving many natural language
processing tasks (Dai and Le, 2015; Peters et al.,
2018a; Radford et al., 2018; Howard and Ruder,
2018). These include sentence-level tasks such as
natural language inference (Bowman et al., 2015;
Williams et al., 2018) and paraphrasing (Dolan
and Brockett, 2005), which aim to predict the relationships between sentences by analyzing them
holistically, as well as token-level tasks such as
named entity recognition and question answering,
where models are required to produce fine-grained
output at the token level (Tjong Kim Sang and
De Meulder, 2003; Rajpurkar et al., 2016).
There are two existing strategies for applying pre-trained language representations to downstream tasks: feature-based and fine-tuning. The
feature-based approach, such as ELMo (Peters
et al., 2018a), uses task-specific architectures that
include the pre-trained representations as additional features. The fine-tuning approach, such as
the Generative Pre-trained Transformer (OpenAI
GPT) (Radford et al., 2018), introduces minimal
task-specific parameters, and is trained on the
downstream tasks by simply fine-tuning all pretrained parameters. The two approaches share the
same objective function during pre-training, where
they use unidirectional language models to learn
general language representations.
We argue that current techniques restrict the
power of the pre-trained representations, especially for the fine-tuning approaches. The major limitation is that standard language models are
unidirectional, and this limits the choice of architectures that can be used during pre-training. For
example, in OpenAI GPT, the authors use a left-toright architecture, where every token can only attend to previous tokens in the self-attention layers
of the Transformer (Vaswani et al., 2017). Such restrictions are sub-optimal for sentence-level tasks,
and could be very harmful when applying finetuning based approaches to token-level tasks such
as question answering, where it is crucial to incorporate context from both directions.
In this paper, we improve the fine-tuning based
approaches by proposing BERT: Bidirectional
Encoder Representations from Transformers.
BERT alleviates the previously mentioned unidirectionality constraint by using a “masked language model” (MLM) pre-training objective, inspired by the Cloze task (Taylor, 1953). The
masked language model randomly masks some of
the tokens from the input, and the objective is to
predict the original vocabulary id of the masked
"""

In [349]:
# Rebuild the character vocabulary from the new training text.
# NOTE(review): set iteration order is not stable across interpreter runs, so
# token ids derived from this list can differ from one run to the next.
individual_chars = list(set(train_text))
# The same characters with the "##" prefix used for non-initial positions.
middle_chars = ["##" + c for c in individual_chars]
unk_token = ["<unk>"]
len(individual_chars), len(middle_chars), len(unk_token)

(65, 65, 1)

In [350]:
# Starting vocabulary: word-initial characters, then "##"-prefixed
# continuation characters, then the unknown token.
vocab = individual_chars + middle_chars + unk_token
len(vocab)

131

In [351]:
import re
# str.split() with no arguments splits on any run of whitespace (spaces,
# newlines, tabs) and ignores leading/trailing whitespace, so neither a regex
# nor an explicit strip is needed, and an empty text yields [] instead of [''].
train_words = train_text.split()
train_words[:10]

['Language',
 'model',
 'pre-training',
 'has',
 'been',
 'shown',
 'to',
 'be',
 'effective',
 'for']

In [352]:
# Break every word into characters; each character after the first gets the
# "##" continuation prefix.
train_words_modified = [
    [ch if pos == 0 else "##" + ch for pos, ch in enumerate(word)]
    for word in train_words
]

train_words_modified[:5]

[['L', '##a', '##n', '##g', '##u', '##a', '##g', '##e'],
 ['m', '##o', '##d', '##e', '##l'],
 ['p',
  '##r',
  '##e',
  '##-',
  '##t',
  '##r',
  '##a',
  '##i',
  '##n',
  '##i',
  '##n',
  '##g'],
 ['h', '##a', '##s'],
 ['b', '##e', '##e', '##n']]

In [353]:
len(train_words_modified)

364

In [354]:
num_merges = 100

for step in range(num_merges):
    unigram_stats, bigram_stats = get_stats(train_words_modified)
    # Stop early once no adjacent token pairs are left to merge; max() below
    # would raise ValueError on an empty dict.
    if not bigram_stats:
        break
    bigram_score_stats = {
        pair: get_score(pair, unigram_stats, bigram_stats) for pair in bigram_stats
    }
    # Highest-scoring pair; merge() also records the fused token in vocab.
    best_bigram = max(bigram_score_stats, key=bigram_score_stats.get)
    train_words_modified = merge(train_words_modified, best_bigram)
    print(f"Best Bigram: {best_bigram}")
    print("*" * 20)

19
['19', '##5', '##3', '##)', '##.'] 337
Best Bigram: ('1', '##9')
********************
##AI
['(', '##O', '##p', '##e', '##n', '##AI'] 149
['O', '##p', '##e', '##n', '##AI'] 236
Best Bigram: ('##A', '##I')
********************
##LM
['E', '##LM', '##o'] 124
['(', '##M', '##LM', '##)'] 328
Best Bigram: ('##L', '##M')
********************
##MLM
['(', '##MLM', '##)'] 328
Best Bigram: ('##M', '##LM')
********************
ELM
['ELM', '##o'] 124
Best Bigram: ('E', '##LM')
********************
##ER
['B', '##ER', '##T', '##:'] 309
['B', '##ER', '##T'] 315
Best Bigram: ('##E', '##R')
********************
BER
['BER', '##T', '##:'] 309
['BER', '##T'] 315
Best Bigram: ('B', '##ER')
********************
195
['195', '##3', '##)', '##.'] 337
Best Bigram: ('19', '##5')
********************
1953
['1953', '##)', '##.'] 337
Best Bigram: ('195', '##3')
********************
GP
['GP', '##T', '##)'] 150
['GP', '##T', '##,'] 237
Best Bigram: ('G', '##P')
********************
##3;
['2', '##0', '##0', '##3;'] 9

In [355]:
len(vocab)

231

In [356]:
# A token's id is its position in vocab; build the lookup in both directions.
id2token = dict(enumerate(vocab))
token2id = {token: idx for idx, token in id2token.items()}

In [415]:
def encode(text):
    """Encode text into token ids with greedy longest-match-first lookup.

    The text is split on single spaces. Each word is consumed left to right:
    at every position the longest remaining substring found in the
    module-level ``token2id`` mapping is emitted (substrings that do not
    start the word are looked up with a "##" prefix). If no substring
    matches, the id of "<unk>" is emitted and one character is skipped.

    Args:
        text: string to encode.

    Returns:
        List of integer token ids.

    Raises:
        KeyError: if an unknown character is met and "<unk>" is not in
            ``token2id``.
    """
    encoded_words = []

    for word in text.split(' '):
        start = 0
        length = len(word)

        while start < length:
            # Try the longest candidate first and shrink from the right.
            for end in range(length, start, -1):
                piece = word[start:end]
                if start > 0:
                    piece = "##" + piece
                # Look the piece up directly in the id mapping: one dict
                # lookup instead of a list scan, and no KeyError when a token
                # is listed in vocab but missing from the mapping.
                token_id = token2id.get(piece)
                if token_id is not None:
                    encoded_words.append(token_id)
                    start = end
                    break
            else:
                # Not even a single character matched at this position.
                encoded_words.append(token2id["<unk>"])
                start += 1

    return encoded_words

def decode(encoded_words):
    """Turn a sequence of token ids back into text.

    Ids are mapped to tokens through the module-level ``id2token`` mapping.
    A token that starts with "##" continues the current word and is appended
    without its prefix; any other token starts a new word and is preceded by
    a space. Leading and trailing spaces of the result are removed.

    Args:
        encoded_words: iterable of integer token ids.

    Returns:
        The decoded string.
    """
    pieces = []

    for token_id in encoded_words:
        token = id2token[token_id]
        if token.startswith("##"):
            # Drop only the two-character prefix. str.strip("##") would
            # remove every leading and trailing '#', losing literal '#'
            # characters that belong to the text.
            pieces.append(token[2:])
        else:
            pieces.append(" " + token)

    return "".join(pieces).strip(' ')

In [416]:
test_text = """
hi
bye"""

In [417]:
test_text

'\nhi\nbye'

In [418]:
encoded_text = encode(test_text)
# encoded_text

In [419]:
len(test_text)

7

In [420]:
encoded_text

[40, 119, 66, 105, 124, 104, 69]

In [421]:
[id2token[i] for i in encoded_text]   

['\n', '##h', '##i', '##\n', '##b', '##y', '##e']

In [422]:
decode(encoded_text)

'\nhi\nbye'

## Part 2: Now wrap everything in a class

In [442]:
class WordPiece:
    """Minimal WordPiece tokenizer.

    Tokens that start a word are stored as-is; tokens that continue a word
    carry a "##" prefix. Training is driven by the caller: build the initial
    vocabulary with ``init_vocab``, then repeatedly call ``get_stats``,
    ``get_score`` and ``merge``. ``encode`` / ``decode`` convert between text
    and token ids.
    """

    def __init__(self):
        # Token strings; a token's id is its index in this list.
        self.vocab = []
        # id -> token and token -> id lookups, built by get_mapping().
        self.id2token = {}
        self.token2id = {}

    def init_vocab(self, text):
        """Reset the vocabulary to the characters of ``text``.

        The vocabulary holds every distinct character, the same characters
        with a "##" prefix, and finally the "<unk>" token.
        """
        # Sorted so that token ids are reproducible; plain set iteration
        # order for strings can change between interpreter runs.
        individual_chars = sorted(set(text))
        middle_chars = ["##" + c for c in individual_chars]
        unk_token = ["<unk>"]

        self.vocab = individual_chars + middle_chars + unk_token

    def get_stats(self, train_words):
        """Count token (unigram) and adjacent-pair (bigram) frequencies.

        Args:
            train_words: iterable of words, each a list of token strings.

        Returns:
            ``(unigram_stats, bigram_stats)``: dicts mapping a token, or a
            ``(left, right)`` tuple of adjacent tokens, to its count.
        """
        unigram_stats, bigram_stats = {}, {}

        for word_list in train_words:
            # An empty word has nothing to count; skipping it avoids the
            # IndexError that indexing its last element would raise.
            if not word_list:
                continue

            for token in word_list:
                unigram_stats[token] = unigram_stats.get(token, 0) + 1

            for pair in zip(word_list, word_list[1:]):
                bigram_stats[pair] = bigram_stats.get(pair, 0) + 1

        return unigram_stats, bigram_stats

    def get_score(self, bigram, unigram_stats, bigram_stats):
        """Return count(bigram) / (count(first token) * count(second token))."""
        unigram1, unigram2 = bigram[0], bigram[1]

        score = bigram_stats[bigram] / (unigram_stats[unigram1] * unigram_stats[unigram2])
        return score

    def merge(self, train_words, best_bigram):
        """Fuse every occurrence of ``best_bigram`` and record the new token.

        The fused token is the first token followed by the second token
        without its two leading characters (the "##" prefix); it is appended
        to ``self.vocab``.

        Args:
            train_words: list of words, each a list of token strings. The
                inner lists are modified in place.
            best_bigram: ``(left, right)`` pair of adjacent tokens to fuse.

        Returns:
            A new outer list holding the same (now modified) inner lists.
        """
        left, right = best_bigram[0], best_bigram[1]
        joined_word = left + right[2:]
        self.vocab.append(joined_word)
        train_words_modified = []

        for word_list in train_words:
            i = 0
            while i < (len(word_list) - 1):
                if word_list[i] == left and word_list[i + 1] == right:
                    # i is not advanced, so the fused token is compared with
                    # its new right neighbour on the next pass.
                    word_list[i:i + 2] = [joined_word]
                else:
                    i += 1

            train_words_modified.append(word_list)

        return train_words_modified

    def _ensure_mapping(self):
        """Rebuild the id lookups if the vocabulary size changed since they were built."""
        if len(self.id2token) != len(self.vocab):
            self.get_mapping()

    def encode(self, text):
        """Encode text into token ids with greedy longest-match-first lookup.

        The text is split on single spaces. Each word is consumed left to
        right: at every position the longest remaining substring found in the
        vocabulary is emitted (substrings that do not start the word are
        looked up with a "##" prefix). If nothing matches, the id of "<unk>"
        is emitted and one character is skipped. When at least one token was
        produced, the ratio of input characters (spaces excluded) to tokens
        is printed.

        Args:
            text: string to encode.

        Returns:
            List of integer token ids.

        Raises:
            KeyError: if an unknown character is met and "<unk>" is not in
                the vocabulary.
        """
        self._ensure_mapping()
        token2id = self.token2id
        original_length = 0
        encoded_words = []

        for word in text.split(' '):
            i = 0
            N = len(word)
            original_length += N

            while i < N:
                # Try the longest candidate first and shrink from the right.
                for end in range(N, i, -1):
                    cur_word = word[i:end]
                    if i > 0:
                        cur_word = "##" + cur_word
                    # One dict lookup instead of a list scan over the vocab.
                    token_id = token2id.get(cur_word)
                    if token_id is not None:
                        encoded_words.append(token_id)
                        i = end
                        break
                else:
                    # Not even a single character matched at this position.
                    encoded_words.append(token2id["<unk>"])
                    i += 1

        encoded_length = len(encoded_words)
        # Nothing was encoded (e.g. empty text): there is no ratio to report,
        # and dividing would raise ZeroDivisionError.
        if encoded_length:
            compression = original_length / encoded_length
            print(f"Compression: {compression} X")
        return encoded_words

    def decode(self, encoded_words):
        """Turn a sequence of token ids back into text.

        A token that starts with "##" continues the current word and is
        appended without its prefix; any other token starts a new word and
        is preceded by a space. Leading and trailing spaces of the result
        are removed.
        """
        self._ensure_mapping()
        pieces = []

        for token_id in encoded_words:
            word = self.id2token[token_id]
            if word.startswith("##"):
                # Drop only the two-character prefix. str.strip("##") would
                # remove every leading and trailing '#', losing literal '#'
                # characters that belong to the text.
                pieces.append(word[2:])
            else:
                pieces.append(" " + word)

        return "".join(pieces).strip(' ')

    def get_mapping(self):
        """Rebuild ``id2token`` and ``token2id`` from the current vocabulary."""
        # Clear first so entries from an earlier vocabulary cannot linger;
        # clearing in place keeps existing references to the dicts valid.
        self.id2token.clear()
        self.token2id.clear()
        for i, token in enumerate(self.vocab):
            self.id2token[i] = token
            self.token2id[token] = i

In [443]:
wordpiece_tokenizer = WordPiece()

In [444]:
train_text = """The models in our experiments are word-based, character-based, mixed word-character-based or several
wordpiece models with varying vocabulary sizes.
For the word model, we selected the most frequent 212K source words as the source vocabulary and the
most popular 80k target words as the target vocabulary. Words not in the source vocabulary or the target
vocabulary (unknown words) are converted into special <first_char>_UNK_<last_char> symbols. Note, in
this case, there is more than one UNK (e.g., our production word models have roughly 5000 different UNKs
in this case). We then use the attention mechanism to copy a corresponding word from the source to replace
these unknown words during decoding [37].
The mixed word-character model is similar to the word model, except the out-of-vocabulary (OOV) words
are converted into sequences of characters with special delimiters around them as described in section 4.2 in
more detail. In our experiments, the vocabulary size for the mixed word-character model is 32K. For the pure
character model, we simply split all words into constituent characters, resulting typically in a few hundred
basic characters (including special symbols appearing in the data). For the wordpiece models, we train 3
different models with vocabulary sizes of 8K, 16K, and 32K.
Table 4 summarizes our results on the WMT En→Fr dataset. In this table, we also compare against other
strong baselines without model ensembling. As can be seen from the table, “WPM-32K”, a wordpiece model
with a shared source and target vocabulary of 32K wordpieces, performs well on this dataset and achieves the
best quality as well as the fastest inference speed.
The pure character model (char input, char output) works surprisingly well on this task, not much worse
than the best wordpiece models in BLEU score. However, these models are rather slow to train and slow to
use as the sequences are much longer.
Our best model, WPM-32K, achieves a BLEU score of 38.95. Note that this BLEU score represents the
averaged score of 8 models we trained. The maximum BLEU score of the 8 models is higher at 39.37. We
point out that our models are completely self-contained, as opposed to previous models reported in [ 45],
which depend on some external alignment models to achieve their best results. Also note that all our test set
numbers were achieved by picking an optimal model on the development set which was then used to decode
the test set.
Note that the timing numbers for this section are obtained on CPUs, not TPUs. We use here the same
CPU machine as described above, and run the decoder with a batchsize of 16 sentences in parallel and a
maximum of 4 concurrent hypotheses at any time per sentence. The time per sentence is the total decoding
time divided by the number of respective sentences in the test set."""

In [445]:
wordpiece_tokenizer.init_vocab(train_text)

In [446]:
len(set(train_text))

67

In [447]:
len(wordpiece_tokenizer.vocab)

135

In [448]:
def preprocess_text(text):
    """Split text into words and each word into WordPiece start tokens.

    Words are separated by any run of whitespace. Within a word the first
    character is kept as-is and every later character gets a "##" prefix.

    Args:
        text: raw training text.

    Returns:
        List of words, each a list of single-character tokens. Empty or
        whitespace-only text gives an empty list.
    """
    # str.split() with no arguments splits on whitespace runs and never
    # yields empty strings, so blank input cannot produce an empty word.
    return [
        [ch if pos == 0 else "##" + ch for pos, ch in enumerate(word)]
        for word in text.split()
    ]

In [449]:
train_words = preprocess_text(train_text)
train_words[:5]

[['T', '##h', '##e'],
 ['m', '##o', '##d', '##e', '##l', '##s'],
 ['i', '##n'],
 ['o', '##u', '##r'],
 ['e', '##x', '##p', '##e', '##r', '##i', '##m', '##e', '##n', '##t', '##s']]

In [450]:
num_merges = 200

for i in range(num_merges):
    unigram_stats, bigram_stats = wordpiece_tokenizer.get_stats(train_words)
    # Stop early once no adjacent token pairs are left to merge; max() below
    # would raise ValueError on an empty dict.
    if not bigram_stats:
        break
    bigram_score_stats = {
        pair: wordpiece_tokenizer.get_score(pair, unigram_stats, bigram_stats)
        for pair in bigram_stats
    }
    # Highest-scoring pair; merge() also records the fused token in the
    # tokenizer's vocabulary.
    best_bigram = max(bigram_score_stats, key=bigram_score_stats.get)
    train_words = wordpiece_tokenizer.merge(train_words, best_bigram)
    print(f"Iteration {i + 1}, Best Bigram: {best_bigram}")

Iteration 1, Best Bigram: ('2', '##1')
Iteration 2, Best Bigram: ('##→', '##F')
Iteration 3, Best Bigram: ('“', '##W')
Iteration 4, Best Bigram: ('##O', '##V')
Iteration 5, Best Bigram: ('##O', '##OV')
Iteration 6, Best Bigram: ('1', '##6')
Iteration 7, Best Bigram: ('U', '##N')
Iteration 8, Best Bigram: ('##M', '##T')
Iteration 9, Best Bigram: ('##_', '##<')
Iteration 10, Best Bigram: ('5', '##0')
Iteration 11, Best Bigram: ('50', '##0')
Iteration 12, Best Bigram: ('500', '##0')
Iteration 13, Best Bigram: ('8', '##0')
Iteration 14, Best Bigram: ('##3', '##7')
Iteration 15, Best Bigram: ('[', '##37')
Iteration 16, Best Bigram: ('[37', '##]')
Iteration 17, Best Bigram: ('##5', '##]')
Iteration 18, Best Bigram: ('##9', '##5')
Iteration 19, Best Bigram: ('B', '##L')
Iteration 20, Best Bigram: ('BL', '##E')
Iteration 21, Best Bigram: ('4', '##5]')
Iteration 22, Best Bigram: ('(', '##OOV')
Iteration 23, Best Bigram: ('(OOV', '##)')
Iteration 24, Best Bigram: ('“W', '##P')
Iteration 25, Best

In [451]:
len(wordpiece_tokenizer.vocab)

335

In [454]:
wordpiece_tokenizer.get_mapping()

In [461]:
text = """some of which we disagree with, see the table caption"""

In [462]:
encoded_text = wordpiece_tokenizer.encode(text)

Compression: 1.0731707317073171 X


In [463]:
text == wordpiece_tokenizer.decode(encoded_text)

True

In [464]:
wordpiece_tokenizer.decode(encoded_text)

'some of which we disagree with, see the table caption'