# Chapter 13: Subword Segmentation
Tokenizers using the unigram language model from Kudo (2018): https://arxiv.org/pdf/1804.10959.pdf

Programs from the book: [_Python for Natural Language Processing_](https://link.springer.com/book/9783031575488)

__Author__: Pierre Nugues

## Modules

In [1]:
import os
import regex as re
from collections import Counter
from math import log
from tqdm import tqdm
import functools

## Corpus Files

We read the files and store the whole corpus in a single string.

In [2]:
PATH = '../datasets/'

In [3]:
CORPUS = 'HOMER'  # 'DICKENS'

In [4]:
# Each corpus lives in its own subfolder of PATH
if CORPUS == 'HOMER':
    folder = PATH + 'classics/'
elif CORPUS == 'DICKENS':
    folder = PATH + 'dickens/'

In [5]:
def get_files(dir, suffix):
    """
    Lists the entries of a folder whose names end with a given suffix
    :param dir: path of the folder to scan
    :param suffix: the ending the names must have, for instance 'txt'
    :return: the list of matching names, as returned by os.listdir()
    """
    return [name for name in os.listdir(dir) if name.endswith(suffix)]

In [6]:
# Homer: the two epics; Dickens: every file of the folder ending with 'txt'
if CORPUS == 'HOMER':
    files = ['iliad.txt', 'odyssey.txt']
elif CORPUS == 'DICKENS':
    files = get_files(folder, 'txt')
files

['iliad.txt', 'odyssey.txt']

In [7]:
files = [folder + file for file in files]
files

['../datasets/classics/iliad.txt', '../datasets/classics/odyssey.txt']

In [8]:
# Read every file, strip its surrounding whitespace, and concatenate the
# contents, each one preceded by a space
chunks = []
for file in files:
    with open(file, encoding='utf8') as f:
        chunks.append(' ' + f.read().strip())
text = ''.join(chunks)

In [9]:
text[:100]

' BOOK I\n\nSing, O goddess, the anger of Achilles son of Peleus, that brought\ncountless ills upon the '

## Pretokenization

We pretokenize the text using the spaces as delimiters.
BERT simply uses `split()` for this step: https://github.com/google-research/bert/blob/master/tokenization.py#L300-L359. Here we use a regex that also separates the punctuation signs from the words.

In [10]:
pattern = r'\p{P}|[^\s\p{P}]+'

In [11]:
# Each word with its (start, end) character offsets in the text
words = [(found.group(), found.span())
         for found in re.finditer(pattern, text)]

In [12]:
text.split()[:8]

['BOOK', 'I', 'Sing,', 'O', 'goddess,', 'the', 'anger', 'of']

In [13]:
words[:8]

[('BOOK', (1, 5)),
 ('I', (6, 7)),
 ('Sing', (9, 13)),
 (',', (13, 14)),
 ('O', (15, 16)),
 ('goddess', (17, 24)),
 (',', (24, 25)),
 ('the', (26, 29))]

In [14]:
def pretokenize(pattern, text):
    """
    Extracts the tokens of a text with a regular expression
    :param pattern: the regular expression describing a token
    :param text: the string to split
    :return: the list of all the non-overlapping matches, in text order
    """
    token_regex = re.compile(pattern)
    return token_regex.findall(text)

In [15]:
words = pretokenize(pattern, text)

In [16]:
word_cnts = Counter(words)

In [17]:
word_cnts.most_common(5)

[(',', 19920), ('the', 15258), ('and', 11467), ('of', 8640), ('.', 8108)]

In [18]:
word_cnts['her']

1145

## Initial BPE

In [19]:
class BPE():
    """
    Byte-pair encoding (BPE) learner and tokenizer.

    fit() learns at most merge_cnt merge operations from a text;
    encode() and tokenize() then replay these merges, in the order they
    were learned, on new words.
    """

    def __init__(self, merge_cnt=200):
        """
        :param merge_cnt: the maximal number of merge operations to learn
        """
        self.merge_cnt = merge_cnt
        # A word is either one punctuation sign or a run of characters
        # that are neither whitespace nor punctuation
        self.pattern = r'\p{P}|[^\s\p{P}]+'

    def pretokenize(self, text):
        """
        Splits the text into words and marks the start of each word with '▁'
        :param text: the string to split
        :return: the list of words, each one prefixed with '▁'
        """
        words = re.findall(self.pattern, text)
        return ['▁' + word for word in words]

    def _bpe_init(self, text):
        """
        Creates the initial state: each distinct word with its frequency
        and its segmentation into single characters, and the vocabulary
        of characters
        :param text: the training text
        """
        words = self.pretokenize(text)
        word_cnts = Counter(words)
        self.words_bpe = {
            word: {'freq': freq,
                   'swords': list(word)}
            for word, freq in word_cnts.items()}
        # Sorted so that the vocabulary order is the same from run to run
        self.vocab = sorted(
            {char for word_dict in self.words_bpe.values()
             for char in word_dict['swords']})

    def _count_bigrams(self):
        """
        Counts the pairs of adjacent subwords over the current
        segmentations, weighting each pair by the frequency of its word.
        The counts are stored in self.pair_cnts with tuples as keys
        """
        self.pair_cnts = Counter()
        for word_dict in self.words_bpe.values():
            swords = tuple(word_dict['swords'])
            freq = word_dict['freq']
            for i in range(len(swords) - 1):
                self.pair_cnts[swords[i:i + 2]] += freq

    def _merge_pair(self, pair, swords):
        """
        Replaces, from left to right, the occurrences of a pair of
        adjacent subwords with their concatenation
        :param pair: a list of two subwords
        :param swords: the list of subwords of a word
        :return: a new list of subwords with the pair merged
        """
        pair_str = ''.join(pair)
        i = 0
        temp = []
        while i < len(swords) - 1:
            if pair == swords[i:i + 2]:
                temp += [pair_str]
                i += 2
            else:
                temp += [swords[i]]
                i += 1
        # The last subword is still pending when it was not part of a merge
        if i == len(swords) - 1:
            temp += [swords[i]]
        swords = temp
        return swords

    def fit(self, text):
        """
        Learns the merge operations from a text. At each step, the most
        frequent pair of adjacent subwords is merged in all the words.
        The learning stops after merge_cnt merges or as soon as there is
        no pair left to merge
        :param text: the training text
        """
        self._bpe_init(text)

        self.merge_ops = []
        for _ in range(self.merge_cnt):
            self._count_bigrams()
            if not self.pair_cnts:
                # Every word is a single subword (or the text is empty):
                # max() below would raise a ValueError on an empty counter
                break
            self.best_pair = max(self.pair_cnts,
                                 key=self.pair_cnts.get)
            merge_op = list(self.best_pair)
            self.merge_ops.append(merge_op)
            for word_dict in self.words_bpe.values():
                word_dict['swords'] = self._merge_pair(
                    merge_op,
                    word_dict['swords'])
        self._build_vocab()

    def _build_vocab(self):
        """
        Appends the subwords created by the merge operations to the
        vocabulary of characters
        """
        self.vocab += [''.join(merge_op) for merge_op in self.merge_ops]

    def encode(self, word):
        """
        Segments a word by applying the learned merges in order
        :param word: the word to segment, given as a string
        :return: the list of subwords
        """
        swords = list(word)
        for op in self.merge_ops:
            swords = self._merge_pair(op, swords)
        return swords

    def tokenize(self, text):
        """
        Tokenizes a text into subwords
        :param text: the string to tokenize
        :return: the list of the subwords of all the words of the text
        """
        tokenized_text = []
        # Each distinct word is encoded only once
        cache = {}
        words = self.pretokenize(text)
        for word in words:
            if word not in cache:
                cache[word] = self.encode(word)
            subwords = cache[word]
            tokenized_text += subwords
        return tokenized_text

In [20]:
bpe = BPE()
bpe.fit(text)

In [21]:
bpe.words_bpe

{'▁BOOK': {'freq': 48, 'swords': ['▁', 'B', 'O', 'O', 'K']},
 '▁I': {'freq': 3194, 'swords': ['▁I']},
 '▁Sing': {'freq': 2, 'swords': ['▁S', 'ing']},
 '▁,': {'freq': 19920, 'swords': ['▁,']},
 '▁O': {'freq': 77, 'swords': ['▁', 'O']},
 '▁goddess': {'freq': 112, 'swords': ['▁go', 'd', 'd', 'es', 's']},
 '▁the': {'freq': 15258, 'swords': ['▁the']},
 '▁anger': {'freq': 74, 'swords': ['▁an', 'g', 'er']},
 '▁of': {'freq': 8640, 'swords': ['▁of']},
 '▁Achilles': {'freq': 440, 'swords': ['▁Ach', 'ill', 'es']},
 '▁son': {'freq': 1246, 'swords': ['▁son']},
 '▁Peleus': {'freq': 145, 'swords': ['▁P', 'e', 'le', 'us']},
 '▁that': {'freq': 2558, 'swords': ['▁that']},
 '▁brought': {'freq': 208, 'swords': ['▁br', 'ou', 'ght']},
 '▁countless': {'freq': 8, 'swords': ['▁c', 'ou', 'n', 't', 'l', 'es', 's']},
 '▁ills': {'freq': 5, 'swords': ['▁', 'ill', 's']},
 '▁upon': {'freq': 792, 'swords': ['▁up', 'on']},
 '▁Achaeans': {'freq': 601, 'swords': ['▁Ach', 'ae', 'ans']},
 '▁.': {'freq': 8108, 'swords': ['▁

In [22]:
bpe.merge_ops[:5]

[['▁', 't'], ['h', 'e'], ['▁', 'a'], ['▁t', 'he'], ['▁', 's']]

## Segmentation with BPE: Initial Distribution of the Language Model

In [23]:
bpe.encode('therefore')

['t', 'he', 're', 'f', 'ore']

In [24]:
bpe.pretokenize('Sit careless in the shade!')

['▁Sit', '▁careless', '▁in', '▁the', '▁shade', '▁!']

In [25]:
words = bpe.tokenize('Sit careless in the shade!')
words

['▁S',
 'it',
 '▁c',
 'a',
 're',
 'l',
 'es',
 's',
 '▁in',
 '▁the',
 '▁sh',
 'ad',
 'e',
 '▁',
 '!']

## Initial Distribution

In [26]:
tokens = bpe.tokenize(text)

The negative log likelihood

In [27]:
def calc_nll(tokens):
    """
    Estimates a unigram distribution from a sequence of tokens
    :param tokens: the sequence of tokens
    :return: a dictionary mapping each distinct token to its negative
        log likelihood, -log(count / total number of tokens). The
        dictionary is empty when tokens is empty
    """
    token_cnts = Counter(tokens)
    # sum() rather than Counter.total(), which only exists from Python 3.10
    total_cnt = sum(token_cnts.values())
    uni_probs = {token: -log(cnt/total_cnt) for
                 token, cnt in token_cnts.items()}
    return uni_probs

In [28]:
uni_probs_bpe = calc_nll(tokens)
uni_probs_bpe['her']

5.880767285455809

In [29]:
class Unigram():
    """
    Tokenizer based on a unigram language model. This first version only
    stores the model and pretokenizes a text.
    """

    def __init__(self, uni_probs):
        """
        :param uni_probs: a dictionary mapping the subwords to their
            negative log likelihoods
        """
        self.uni_probs = uni_probs
        # A word is either one punctuation sign or a run of characters
        # that are neither whitespace nor punctuation
        self.pattern = r'\p{P}|[^\s\p{P}]+'

    def pretokenize(self, text):
        """
        Splits the text into words and marks the start of each word with '▁'
        :param text: the string to split
        :return: the list of words, each one prefixed with '▁'
        """
        return ['▁' + word for word in re.findall(self.pattern, text)]

In [30]:
unigram = Unigram(uni_probs_bpe)

In [31]:
unigram.uni_probs['▁t']

5.105016576360653

## Segmentation with the Language Model: Brute Force

Here we see how to split a word in all possible ways

In [32]:
class Unigram():
    """
    Tokenizer based on a unigram language model. This version adds the
    splitting of a word at given split points.
    """

    def __init__(self, uni_probs):
        """
        :param uni_probs: a dictionary mapping the subwords to their
            negative log likelihoods
        """
        self.uni_probs = uni_probs
        # A word is either one punctuation sign or a run of characters
        # that are neither whitespace nor punctuation
        self.pattern = r'\p{P}|[^\s\p{P}]+'

    def pretokenize(self, text):
        """
        Splits the text into words and marks the start of each word with '▁'
        :param text: the string to split
        :return: the list of words, each one prefixed with '▁'
        """
        return ['▁' + word for word in re.findall(self.pattern, text)]

    @staticmethod
    def split_word(string, splitpoints):
        """
        Cuts a string at the positions flagged in a bit string
        :param string: the string to split
        :param splitpoints: a string of '0' and '1'. A '1' at index k
            (starting from 0) cuts the string after its character of index k
        :return: the list of substrings
        """
        # Start offsets of the pieces: 0, then every flagged position
        cuts = [0]
        cuts.extend(pos for pos, flag in enumerate(splitpoints, start=1)
                    if flag == '1')
        pieces = [string[begin:end] for begin, end in zip(cuts, cuts[1:])]
        # What follows the last cut, possibly the whole string
        pieces.append(string[cuts[-1]:])
        return pieces

In [33]:
Unigram.split_word('there', '0110')

['th', 'e', 're']

In [34]:
word = '▁there'

In [35]:
# A word of n characters has n - 1 possible split points, hence
# 2**(n - 1) segmentations, one per bit string of that length
cnt_sp = len(word) - 1
candidates = []
for i in range(2**cnt_sp):
    splitpoints = format(i, f'0{cnt_sp}b')
    candidates.append(Unigram.split_word(word, splitpoints))
candidates

[['▁there'],
 ['▁ther', 'e'],
 ['▁the', 're'],
 ['▁the', 'r', 'e'],
 ['▁th', 'ere'],
 ['▁th', 'er', 'e'],
 ['▁th', 'e', 're'],
 ['▁th', 'e', 'r', 'e'],
 ['▁t', 'here'],
 ['▁t', 'her', 'e'],
 ['▁t', 'he', 're'],
 ['▁t', 'he', 'r', 'e'],
 ['▁t', 'h', 'ere'],
 ['▁t', 'h', 'er', 'e'],
 ['▁t', 'h', 'e', 're'],
 ['▁t', 'h', 'e', 'r', 'e'],
 ['▁', 'there'],
 ['▁', 'ther', 'e'],
 ['▁', 'the', 're'],
 ['▁', 'the', 'r', 'e'],
 ['▁', 'th', 'ere'],
 ['▁', 'th', 'er', 'e'],
 ['▁', 'th', 'e', 're'],
 ['▁', 'th', 'e', 'r', 'e'],
 ['▁', 't', 'here'],
 ['▁', 't', 'her', 'e'],
 ['▁', 't', 'he', 're'],
 ['▁', 't', 'he', 'r', 'e'],
 ['▁', 't', 'h', 'ere'],
 ['▁', 't', 'h', 'er', 'e'],
 ['▁', 't', 'h', 'e', 're'],
 ['▁', 't', 'h', 'e', 'r', 'e']]

We compute all the probabilities and we keep the min negative log likelihood (NLL)

In [36]:
# Score each candidate with the sum of the NLLs of its subwords (1000 for
# an unknown subword) and keep the candidate with the lowest score
min(((cand, sum(uni_probs_bpe.get(sword, 1000) for sword in cand))
     for cand in candidates),
    key=lambda scored: scored[1])

(['▁there'], 6.563964135162586)

In [37]:
class Unigram():
    """
    Tokenizer based on a unigram language model. This version segments a
    word by brute force.

    uni_probs maps the known subwords to their negative log likelihoods
    (NLL). A subword missing from uni_probs is given a NLL of 1000.
    """

    def __init__(self, uni_probs):
        """
        :param uni_probs: a dictionary mapping the subwords to their
            negative log likelihoods
        """
        self.uni_probs = uni_probs
        # A word is either one punctuation sign or a run of characters
        # that are neither whitespace nor punctuation
        self.pattern = r'\p{P}|[^\s\p{P}]+'

    def pretokenize(self, text):
        """
        Splits the text into words and marks the start of each word with '▁'
        :param text: the string to split
        :return: the list of words, each one prefixed with '▁'
        """
        words = re.findall(self.pattern, text)
        words = list(map(lambda x: '▁' + x, words))
        return words

    @staticmethod
    def split_word(string, splitpoints):
        """
        Cuts a string at the positions flagged in a bit string
        :param string: the string to split
        :param splitpoints: a string of '0' and '1'. A '1' at index k
            (starting from 0) cuts the string after its character of index k
        :return: the list of substrings
        """
        subwords = []
        prev_sp = 0
        for i, sp in enumerate(splitpoints, start=1):
            if sp == '1':
                subword = string[prev_sp:i]
                prev_sp = i
                subwords.append(subword)
        subword = string[prev_sp:]
        subwords.append(subword)
        return subwords

    def encode(self, word):
        """
        Segments a word by brute force: all the possible segmentations
        are scored and the one with the lowest NLL is kept
        :param word: the word to segment
        :return: a pair (list of subwords, NLL of the segmentation). A word
            with more than 20 split points is segmented into characters
        """
        if not word:
            return [], 0.0
        cnt_sp = len(word) - 1
        if cnt_sp > 20:
            # Too many candidates (2**cnt_sp). We fall back to the single
            # characters, but still return a (subwords, NLL) pair
            chars = list(word)
            return chars, sum(self.uni_probs.get(char, 1000)
                              for char in chars)
        # The candidates are generated lazily, in the order of the bit
        # strings, so that they are never all in memory at the same time
        candidates = (self.split_word(word, f'{i:0{cnt_sp}b}')
                      for i in range(2**cnt_sp))
        return min(
            ((cand, sum(self.uni_probs.get(sword, 1000) for sword in cand))
             for cand in candidates),
            key=lambda x: x[1])

In [38]:
unigram = Unigram(uni_probs_bpe)

In [39]:
unigram.encode('there')

(['ther', 'e'], 10.395533624170854)

In [40]:
unigram.encode('▁there')

(['▁there'], 6.563964135162586)

In [41]:
unigram.encode('▁thexrexx')

(['▁the', 'x', 're', 'x', 'x'], 28.163451667610353)

In [42]:
unigram.encode('▁with')

(['▁with'], 5.3599913308366505)

We add tokenize

In [43]:
class Unigram():
    """
    Tokenizer based on a unigram language model. This version segments
    the words by brute force and tokenizes a whole text.

    uni_probs maps the known subwords to their negative log likelihoods
    (NLL). A subword missing from uni_probs is given a NLL of 1000.
    """

    def __init__(self, uni_probs):
        """
        :param uni_probs: a dictionary mapping the subwords to their
            negative log likelihoods
        """
        self.uni_probs = uni_probs
        # A word is either one punctuation sign or a run of characters
        # that are neither whitespace nor punctuation
        self.pattern = r'\p{P}|[^\s\p{P}]+'

    def pretokenize(self, text):
        """
        Splits the text into words and marks the start of each word with '▁'
        :param text: the string to split
        :return: the list of words, each one prefixed with '▁'
        """
        words = re.findall(self.pattern, text)
        words = list(map(lambda x: '▁' + x, words))
        return words

    @staticmethod
    def split_word(string, splitpoints):
        """
        Cuts a string at the positions flagged in a bit string
        :param string: the string to split
        :param splitpoints: a string of '0' and '1'. A '1' at index k
            (starting from 0) cuts the string after its character of index k
        :return: the list of substrings
        """
        subwords = []
        prev_sp = 0
        for i, sp in enumerate(splitpoints, start=1):
            if sp == '1':
                subword = string[prev_sp:i]
                prev_sp = i
                subwords.append(subword)
        subword = string[prev_sp:]
        subwords.append(subword)
        return subwords

    def encode(self, word):
        """
        Segments a word by brute force: all the possible segmentations
        are scored and the one with the lowest NLL is kept
        :param word: the word to segment
        :return: a pair (list of subwords, NLL of the segmentation). A word
            with more than 20 split points is segmented into characters
        """
        if not word:
            return [], 0.0
        cnt_sp = len(word) - 1
        if cnt_sp > 20:
            # Too many candidates (2**cnt_sp). We fall back to the single
            # characters, but still return the (subwords, NLL) pair that
            # tokenize() unpacks
            chars = list(word)
            return chars, sum(self.uni_probs.get(char, 1000)
                              for char in chars)
        # The candidates are generated lazily, in the order of the bit
        # strings, so that they are never all in memory at the same time
        candidates = (self.split_word(word, f'{i:0{cnt_sp}b}')
                      for i in range(2**cnt_sp))
        return min(
            ((cand, sum(self.uni_probs.get(sword, 1000) for sword in cand))
             for cand in candidates),
            key=lambda x: x[1])

    def tokenize(self, text):
        """
        Tokenizes a text into subwords, showing a progress bar
        :param text: the string to tokenize
        :return: a pair (list of the subwords of the text, sum of the NLLs
            of the segmented words)
        """
        # Each distinct word is encoded only once
        cache = {}
        nll_text = 0.0
        tokenized_text = []
        words = self.pretokenize(text)
        for word in tqdm(words):
            if word not in cache:
                cache[word] = self.encode(word)
            subwords, nll = cache[word]
            tokenized_text += subwords
            nll_text += nll
        return tokenized_text, nll_text

In [44]:
unigram = Unigram(uni_probs_bpe)
unigram.encode('there')

(['ther', 'e'], 10.395533624170854)

In [45]:
unigram.tokenize('Sit careless in the shade')

100%|██████████| 5/5 [00:00<00:00, 10911.30it/s]


(['▁S',
  'it',
  '▁c',
  'a',
  're',
  'le',
  's',
  's',
  '▁in',
  '▁the',
  '▁sh',
  'ad',
  'e'],
 61.740722053383436)

We tokenize the text with the brute-force algorithm. It crashes half-way with the Dickens corpus if there is no length limit.

In [46]:
unigram = Unigram(uni_probs_bpe)
tokens_bf = unigram.tokenize(text)[0]

100%|██████████| 312886/312886 [00:05<00:00, 55101.10it/s] 


In [47]:
tokens_bf

['▁',
 'B',
 'O',
 'O',
 'K',
 '▁I',
 '▁S',
 'ing',
 '▁,',
 '▁',
 'O',
 '▁go',
 'd',
 'd',
 'es',
 's',
 '▁,',
 '▁the',
 '▁an',
 'g',
 'er',
 '▁of',
 '▁Ach',
 'ill',
 'es',
 '▁son',
 '▁of',
 '▁P',
 'e',
 'le',
 'us',
 '▁,',
 '▁that',
 '▁br',
 'ough',
 't',
 '▁c',
 'ou',
 'n',
 't',
 'le',
 's',
 's',
 '▁',
 'ill',
 's',
 '▁up',
 'on',
 '▁the',
 '▁Ach',
 'ae',
 'ans',
 '▁.',
 '▁M',
 'an',
 'y',
 '▁a',
 '▁br',
 'a',
 've',
 '▁s',
 'ou',
 'l',
 '▁d',
 'id',
 '▁it',
 '▁se',
 'nd',
 '▁h',
 'ur',
 'r',
 'y',
 'ing',
 '▁d',
 'ow',
 'n',
 '▁to',
 '▁H',
 'ad',
 'es',
 '▁,',
 '▁and',
 '▁man',
 'y',
 '▁a',
 '▁her',
 'o',
 '▁d',
 'id',
 '▁it',
 '▁y',
 'ie',
 'ld',
 '▁a',
 '▁p',
 're',
 'y',
 '▁to',
 '▁do',
 'g',
 's',
 '▁and',
 '▁',
 'v',
 'u',
 'l',
 't',
 'ur',
 'es',
 '▁,',
 '▁for',
 '▁so',
 '▁were',
 '▁the',
 '▁c',
 'ou',
 'n',
 'sel',
 's',
 '▁of',
 '▁',
 'J',
 'ove',
 '▁f',
 'u',
 'l',
 'f',
 'ill',
 'ed',
 '▁from',
 '▁the',
 '▁d',
 'ay',
 '▁on',
 '▁whi',
 'ch',
 '▁the',
 '▁son',
 '▁of',
 '▁

## Segmentation with the Language Model: Viterbi

In [48]:
uni_probs_bpe['▁t'], uni_probs_bpe['h'], uni_probs_bpe['▁th']

(5.105016576360653, 5.245661261179482, 5.56798617894901)

In [49]:
class Unigram():
    """
    Tokenizer based on a unigram language model. This version segments
    the words with the Viterbi algorithm; the brute-force segmentation is
    kept as encode_bf().

    uni_probs maps the known subwords to their negative log likelihoods
    (NLL). A subword missing from uni_probs is given a NLL of 1000.
    """

    def __init__(self, uni_probs):
        """
        :param uni_probs: a dictionary mapping the subwords to their
            negative log likelihoods
        """
        self.uni_probs = uni_probs
        # A word is either one punctuation sign or a run of characters
        # that are neither whitespace nor punctuation
        self.pattern = r'\p{P}|[^\s\p{P}]+'

    def pretokenize(self, text):
        """
        Splits the text into words and marks the start of each word with '▁'
        :param text: the string to split
        :return: the list of words, each one prefixed with '▁'
        """
        words = re.findall(self.pattern, text)
        words = list(map(lambda x: '▁' + x, words))
        return words

    @staticmethod
    def split_word(string, splitpoints):
        """
        Cuts a string at the positions flagged in a bit string
        :param string: the string to split
        :param splitpoints: a string of '0' and '1'. A '1' at index k
            (starting from 0) cuts the string after its character of index k
        :return: the list of substrings
        """
        subwords = []
        prev_sp = 0
        for i, sp in enumerate(splitpoints, start=1):
            if sp == '1':
                subword = string[prev_sp:i]
                prev_sp = i
                subwords.append(subword)
        subword = string[prev_sp:]
        subwords.append(subword)
        return subwords

    def encode_bf(self, word):
        """
        Segments a word by brute force: all the possible segmentations
        are scored and the one with the lowest NLL is kept
        :param word: the word to segment
        :return: a pair (list of subwords, NLL of the segmentation). A word
            with more than 20 split points is segmented into characters
        """
        if not word:
            return [], 0.0
        cnt_sp = len(word) - 1
        if cnt_sp > 20:
            # Too many candidates (2**cnt_sp). We fall back to the single
            # characters, but still return a (subwords, NLL) pair
            chars = list(word)
            return chars, sum(self.uni_probs.get(char, 1000)
                              for char in chars)
        # The candidates are generated lazily, in the order of the bit
        # strings, so that they are never all in memory at the same time
        candidates = (self.split_word(word, f'{i:0{cnt_sp}b}')
                      for i in range(2**cnt_sp))
        return min(
            ((cand, sum(self.uni_probs.get(sword, 1000) for sword in cand))
             for cand in candidates),
            key=lambda x: x[1])

    def encode(self, word):
        """
        Segments a word with the Viterbi algorithm
        :param word: the word to segment
        :return: a pair (list of subwords, NLL of the segmentation)
        """
        n = len(word)
        if n == 0:
            # Nothing to segment; the backtrace needs at least one entry
            return [], 0.0
        # swords[i - 1] is the last subword of the best segmentation of
        # the prefix word[:i]. Initially, this is the prefix itself
        swords = [word[:i] for i in range(1, n + 1)]
        # min_nlls[i - 1] is the NLL of this best segmentation
        min_nlls = [self.uni_probs.get(sword, 1000.0) for sword in swords]

        for i in range(2, n + 1):
            for j in range(1, i):
                # Best segmentation of word[:j] followed by word[j:i]
                sword = word[j:i]
                nll = self.uni_probs.get(sword, 1000.0) + min_nlls[j - 1]
                if min_nlls[i - 1] > nll:
                    min_nlls[i - 1] = nll
                    swords[i - 1] = sword
        # backtrace: we take the last subword, skip the entries of the
        # positions it covers, and repeat until the start of the word
        final_swords = [swords.pop()]
        while True:
            for i in range(len(final_swords[-1]) - 1):
                swords.pop()
            if swords:
                final_swords += [swords.pop()]
            else:
                break
        return final_swords[::-1], min_nlls[-1]

    def tokenize(self, text):
        """
        Tokenizes a text into subwords, showing a progress bar
        :param text: the string to tokenize
        :return: a pair (list of the subwords of the text, sum of the NLLs
            of the segmented words)
        """
        # Each distinct word is encoded only once
        cache = {}
        nll_text = 0.0
        tokenized_text = []
        words = self.pretokenize(text)
        for word in tqdm(words):
            if word not in cache:
                cache[word] = self.encode(word)
            subwords, nll = cache[word]
            tokenized_text += subwords
            nll_text += nll
        return tokenized_text, nll_text

In [50]:
unigram = Unigram(uni_probs_bpe)

In [51]:
unigram.encode('there')

(['ther', 'e'], 10.395533624170854)

In [52]:
unigram.tokenize('Sit careless in the shade')

100%|██████████| 5/5 [00:00<00:00, 79137.81it/s]


(['▁S',
  'it',
  '▁c',
  'a',
  're',
  'le',
  's',
  's',
  '▁in',
  '▁the',
  '▁sh',
  'ad',
  'e'],
 61.740722053383436)

In [53]:
unigram.encode('▁thexrexfore')

(['▁the', 'x', 're', 'x', 'f', 'ore'], 32.56230540341623)

In [54]:
unigram.encode('▁with')

(['▁with'], 5.3599913308366505)

In [55]:
tokens_viterbi = unigram.tokenize(text)[0]

100%|██████████| 312886/312886 [00:00<00:00, 1601948.94it/s]


We compare the tokenizations of Viterbi and brute force. They should be equal as long as every word, including its `▁` prefix, has at most 21 characters, which is the limit of 20 split points of the brute-force algorithm.

In [56]:
tokens_viterbi == tokens_bf

True

## Norvig's Method
See reference here: https://github.com/norvig/pytudes/blob/main/ipynb/How%20to%20Do%20Things%20with%20Words.ipynb

In [57]:
class Unigram():
    """
    Tokenizer based on a unigram language model. This version segments
    the words with Norvig's memoized recursion; the brute-force and
    Viterbi segmentations are kept as encode_bf() and encode_viterbi().

    uni_probs maps the known subwords to their negative log likelihoods
    (NLL). A subword missing from uni_probs is given a NLL of 1000.
    """

    def __init__(self, uni_probs):
        """
        :param uni_probs: a dictionary mapping the subwords to their
            negative log likelihoods
        """
        self.uni_probs = uni_probs
        # A word is either one punctuation sign or a run of characters
        # that are neither whitespace nor punctuation
        self.pattern = r'\p{P}|[^\s\p{P}]+'

    def pretokenize(self, text):
        """
        Splits the text into words and marks the start of each word with '▁'
        :param text: the string to split
        :return: the list of words, each one prefixed with '▁'
        """
        words = re.findall(self.pattern, text)
        words = list(map(lambda x: '▁' + x, words))
        return words

    @staticmethod
    def split_word(string, splitpoints):
        """
        Cuts a string at the positions flagged in a bit string
        :param string: the string to split
        :param splitpoints: a string of '0' and '1'. A '1' at index k
            (starting from 0) cuts the string after its character of index k
        :return: the list of substrings
        """
        subwords = []
        prev_sp = 0
        for i, sp in enumerate(splitpoints, start=1):
            if sp == '1':
                subword = string[prev_sp:i]
                prev_sp = i
                subwords.append(subword)
        subword = string[prev_sp:]
        subwords.append(subword)
        return subwords

    def encode_bf(self, word):
        """
        Segments a word by brute force: all the possible segmentations
        are scored and the one with the lowest NLL is kept
        :param word: the word to segment
        :return: a pair (list of subwords, NLL of the segmentation). A word
            with more than 20 split points is segmented into characters
        """
        if not word:
            return [], 0.0
        cnt_sp = len(word) - 1
        if cnt_sp > 20:
            # Too many candidates (2**cnt_sp). We fall back to the single
            # characters, but still return a (subwords, NLL) pair
            chars = list(word)
            return chars, sum(self.uni_probs.get(char, 1000)
                              for char in chars)
        # The candidates are generated lazily, in the order of the bit
        # strings, so that they are never all in memory at the same time
        candidates = (self.split_word(word, f'{i:0{cnt_sp}b}')
                      for i in range(2**cnt_sp))
        return min(
            ((cand, sum(self.uni_probs.get(sword, 1000) for sword in cand))
             for cand in candidates),
            key=lambda x: x[1])

    def encode_viterbi(self, word):
        """
        Segments a word with the Viterbi algorithm
        :param word: the word to segment
        :return: a pair (list of subwords, NLL of the segmentation)
        """
        n = len(word)
        if n == 0:
            # Nothing to segment; the backtrace needs at least one entry
            return [], 0.0
        # swords[i - 1] is the last subword of the best segmentation of
        # the prefix word[:i]. Initially, this is the prefix itself
        swords = [word[:i] for i in range(1, n + 1)]
        # min_nlls[i - 1] is the NLL of this best segmentation
        min_nlls = [self.uni_probs.get(sword, 1000.0) for sword in swords]

        for i in range(2, n + 1):
            for j in range(1, i):
                # Best segmentation of word[:j] followed by word[j:i]
                sword = word[j:i]
                nll = self.uni_probs.get(sword, 1000.0) + min_nlls[j - 1]
                if min_nlls[i - 1] > nll:
                    min_nlls[i - 1] = nll
                    swords[i - 1] = sword
        # backtrace: we take the last subword, skip the entries of the
        # positions it covers, and repeat until the start of the word
        final_swords = [swords.pop()]
        while True:
            for i in range(len(final_swords[-1]) - 1):
                swords.pop()
            if swords:
                final_swords += [swords.pop()]
            else:
                break
        return final_swords[::-1], min_nlls[-1]

    def encode(self, char_seq):
        """
        Segments a word with Norvig's method: the word is split into a
        first subword and a rest that is segmented recursively, the
        results for the rests being memoized
        :param char_seq: the word to segment
        :return: a pair (list of subwords, NLL of the segmentation)
        """
        # The memoized function is defined inside the method: its cache is
        # created anew at each call and its only argument is a string.
        # functools.lru_cache(maxsize=2**10) is an alternative to
        # functools.cache, which is available from Python 3.9
        @functools.cache
        def _tokenize_lm(char_seq):
            if not char_seq:
                return [], 0.0
            # All the ways to cut char_seq into a nonempty first subword
            # and the rest of the string
            splits = [(char_seq[:i + 1], char_seq[i + 1:])
                      for i in range(len(char_seq))]
            candidates = []
            for first, rest in splits:
                first_prob = self.uni_probs.get(first, 1000.0)
                rest, rest_prob = _tokenize_lm(rest)
                candidates.append(([first] + rest, first_prob + rest_prob))
            return min(candidates, key=lambda x: x[1])

        return _tokenize_lm(char_seq)

    def tokenize(self, text):
        """
        Tokenizes a text into subwords, showing a progress bar
        :param text: the string to tokenize
        :return: a pair (list of the subwords of the text, sum of the NLLs
            of the segmented words)
        """
        # Each distinct word is encoded only once
        cache = {}
        nll_text = 0.0
        tokenized_text = []
        words = self.pretokenize(text)
        for word in tqdm(words):
            if word not in cache:
                cache[word] = self.encode(word)
            subwords, nll = cache[word]
            tokenized_text += subwords
            nll_text += nll
        return tokenized_text, nll_text

In [58]:
unigram = Unigram(uni_probs_bpe)

In [59]:
unigram.encode('therefore')

(['ther', 'e', 'f', 'ore'], 21.390774681844547)

In [60]:
unigram.encode('withx')

(['w', 'ith', 'x'], 19.49583823140657)

We tokenize with the BPE distribution all the words in the text

In [61]:
tokens_norvig = unigram.tokenize(text)[0]

100%|██████████| 312886/312886 [00:00<00:00, 1116244.91it/s]


We check the tokenization of the two techniques. This should be true.

In [62]:
tokens_norvig == tokens_viterbi

True

## Expectation-Maximization

In [63]:
uni_probs_2 = calc_nll(tokens_viterbi)

The distribution has changed

In [64]:
uni_probs_2

{'▁': 3.8357429341950673,
 'B': 7.31475192396911,
 'O': 6.840086281617563,
 'K': 8.12685104739042,
 '▁I': 5.049053820782126,
 '▁S': 6.556044548096757,
 'ing': 4.537065237152543,
 '▁,': 3.474395588412318,
 '▁go': 5.599019352885387,
 'd': 5.278581342706259,
 'es': 5.244110673756736,
 's': 3.6623958893982778,
 '▁the': 3.6736672184002113,
 '▁an': 5.969596001513639,
 'g': 4.920900500461321,
 'er': 4.778240501828107,
 '▁of': 4.243335817778279,
 '▁Ach': 6.3808599966179465,
 'ill': 5.860711574316832,
 '▁son': 6.086998707800207,
 '▁P': 6.08017739894947,
 'e': 3.8552415512188336,
 'le': 4.9411510847611195,
 'us': 4.715877051543648,
 '▁that': 5.5265032833911185,
 '▁br': 6.304000991092335,
 'ough': 6.220041317972064,
 't': 4.126528021945629,
 '▁c': 4.557614910816325,
 'ou': 5.845005862908656,
 'n': 4.452017140197274,
 '▁up': 6.022716893120213,
 'on': 4.940063537073721,
 'ae': 6.5100717280979525,
 'ans': 5.949113357727698,
 '▁.': 4.373268611979016,
 '▁M': 6.073402305283108,
 'an': 5.095446860351836

In [65]:
uni_probs_2['▁with']

5.35623795964243

In [66]:
uni_probs_bpe['▁with']

5.3599913308366505

In [67]:
class Unigram():
    """
    Tokenizer based on a unigram language model. The words are segmented
    with the Viterbi algorithm; the brute-force segmentation and Norvig's
    method are kept as encode_bf() and encode_norvig().

    uni_probs maps the known subwords to their negative log likelihoods
    (NLL). A subword missing from uni_probs is given a NLL of 1000.
    """

    def __init__(self, uni_probs):
        """
        :param uni_probs: a dictionary mapping the subwords to their
            negative log likelihoods
        """
        self.uni_probs = uni_probs
        # A word is either one punctuation sign or a run of characters
        # that are neither whitespace nor punctuation
        self.pattern = r'\p{P}|[^\s\p{P}]+'

    def pretokenize(self, text):
        """
        Splits the text into words and marks the start of each word with '▁'
        :param text: the string to split
        :return: the list of words, each one prefixed with '▁'
        """
        words = re.findall(self.pattern, text)
        words = list(map(lambda x: '▁' + x, words))
        return words

    @staticmethod
    def split_word(string, splitpoints):
        """
        Cuts a string at the positions flagged in a bit string
        :param string: the string to split
        :param splitpoints: a string of '0' and '1'. A '1' at index k
            (starting from 0) cuts the string after its character of index k
        :return: the list of substrings
        """
        subwords = []
        prev_sp = 0
        for i, sp in enumerate(splitpoints, start=1):
            if sp == '1':
                subword = string[prev_sp:i]
                prev_sp = i
                subwords.append(subword)
        subword = string[prev_sp:]
        subwords.append(subword)
        return subwords

    def encode_bf(self, word):
        """
        Segments a word by brute force: all the possible segmentations
        are scored and the one with the lowest NLL is kept
        :param word: the word to segment
        :return: a pair (list of subwords, NLL of the segmentation). A word
            with more than 20 split points is segmented into characters
        """
        if not word:
            return [], 0.0
        cnt_sp = len(word) - 1
        if cnt_sp > 20:
            # Too many candidates (2**cnt_sp). We fall back to the single
            # characters, but still return a (subwords, NLL) pair
            chars = list(word)
            return chars, sum(self.uni_probs.get(char, 1000)
                              for char in chars)
        # The candidates are generated lazily, in the order of the bit
        # strings, so that they are never all in memory at the same time
        candidates = (self.split_word(word, f'{i:0{cnt_sp}b}')
                      for i in range(2**cnt_sp))
        return min(
            ((cand, sum(self.uni_probs.get(sword, 1000) for sword in cand))
             for cand in candidates),
            key=lambda x: x[1])

    def encode(self, word):
        """
        Segments a word with the Viterbi algorithm
        :param word: the word to segment
        :return: a pair (list of subwords, NLL of the segmentation)
        """
        n = len(word)
        if n == 0:
            # Nothing to segment; the backtrace needs at least one entry
            return [], 0.0
        # swords[i - 1] is the last subword of the best segmentation of
        # the prefix word[:i]. Initially, this is the prefix itself
        swords = [word[:i] for i in range(1, n + 1)]
        # min_nlls[i - 1] is the NLL of this best segmentation
        min_nlls = [self.uni_probs.get(sword, 1000.0) for sword in swords]

        for i in range(2, n + 1):
            for j in range(1, i):
                # Best segmentation of word[:j] followed by word[j:i]
                sword = word[j:i]
                nll = self.uni_probs.get(sword, 1000.0) + min_nlls[j - 1]
                if min_nlls[i - 1] > nll:
                    min_nlls[i - 1] = nll
                    swords[i - 1] = sword
        # backtrace: we take the last subword, skip the entries of the
        # positions it covers, and repeat until the start of the word
        final_swords = [swords.pop()]
        while True:
            for i in range(len(final_swords[-1]) - 1):
                swords.pop()
            if swords:
                final_swords += [swords.pop()]
            else:
                break
        return final_swords[::-1], min_nlls[-1]

    def encode_norvig(self, char_seq):
        """
        Segments a word with Norvig's method: the word is split into a
        first subword and a rest that is segmented recursively, the
        results for the rests being memoized
        :param char_seq: the word to segment
        :return: a pair (list of subwords, NLL of the segmentation)
        """
        # The memoized function is defined inside the method: its cache is
        # created anew at each call and its only argument is a string.
        # functools.lru_cache(maxsize=2**10) is an alternative to
        # functools.cache, which is available from Python 3.9
        @functools.cache
        def _tokenize_lm(char_seq):
            if not char_seq:
                return [], 0.0
            # All the ways to cut char_seq into a nonempty first subword
            # and the rest of the string
            splits = [(char_seq[:i + 1], char_seq[i + 1:])
                      for i in range(len(char_seq))]
            candidates = []
            for first, rest in splits:
                first_prob = self.uni_probs.get(first, 1000.0)
                rest, rest_prob = _tokenize_lm(rest)
                candidates.append(([first] + rest, first_prob + rest_prob))
            return min(candidates, key=lambda x: x[1])

        return _tokenize_lm(char_seq)

    def tokenize(self, text):
        """
        Tokenizes a text into subwords
        :param text: the string to tokenize
        :return: a pair (list of the subwords of the text, sum of the NLLs
            of the segmented words)
        """
        # Each distinct word is encoded only once
        cache = {}
        nll_text = 0.0
        tokenized_text = []
        words = self.pretokenize(text)
        for word in words:
            if word not in cache:
                cache[word] = self.encode(word)
            subwords, nll = cache[word]
            tokenized_text += subwords
            nll_text += nll
        return tokenized_text, nll_text

The likelihood from BPE

In [68]:
def em(text, uni_probs_old):
    """
    Runs one re-estimation step: the text is segmented with the current
    model and a new distribution is estimated from the resulting subwords
    :param text: the training text
    :param uni_probs_old: the current distribution, a dictionary mapping
        the subwords to their negative log likelihoods
    :return: the new distribution, as computed by calc_nll()
    """
    model = Unigram(uni_probs_old)
    # Each distinct word is segmented only once
    segmentations = {}
    corpus_swords = []
    for word in model.pretokenize(text):
        try:
            pieces = segmentations[word]
        except KeyError:
            pieces = segmentations[word] = model.encode(word)[0]
        corpus_swords.extend(pieces)
    return calc_nll(corpus_swords)

In [69]:
# We start from the BPE distribution, run five re-estimation steps, and
# follow the NLLs of a few frequent subwords
probes = ['▁,', '▁the', 's', 'e', '▁and']
uni_probs = dict(uni_probs_bpe)
print([uni_probs.get(probe) for probe in probes])
for _ in range(5):
    uni_probs = em(text, uni_probs)
    print([uni_probs.get(probe) for probe in probes])

[3.476829698201476, 3.6761013281893695, 3.6892379073687205, 3.8798880662790753, 4.029080605301761]
[3.474395588412318, 3.6736672184002113, 3.6623958893982778, 3.8552415512188336, 4.026646495512603]
[3.4745775359512794, 3.6738491659391728, 3.662759605546301, 3.83355001105575, 4.026828443051564]
[3.474705035010758, 3.6739766649986514, 3.6668943959428217, 3.832743501353862, 4.026955942111043]
[3.4747128088178476, 3.673984438805741, 3.6665980242443585, 3.832392273182586, 4.026963715918133]
[3.4747128088178476, 3.673984438805741, 3.6665980242443585, 3.832392273182586, 4.026963715918133]


In [70]:
len(uni_probs_bpe)

255

In [71]:
len(uni_probs)

254

In [72]:
uni_probs['▁the']

3.673984438805741

## The Final Tokenization with Stable Estimates

In [73]:
unigram = Unigram(uni_probs)

In [74]:
unigram.encode('txxhere')

(['t', 'x', 'x', 'her', 'e'], 27.95207566991922)

In [75]:
unigram.tokenize('Sit carexxless in the shade')

(['▁S',
  'it',
  '▁c',
  'a',
  're',
  'x',
  'x',
  'le',
  's',
  's',
  '▁in',
  '▁the',
  '▁sh',
  'ad',
  'e'],
 74.93529156047747)

In [76]:
unigram.tokenize('here')

(['▁he', 're'], 9.512558170433628)

In [77]:
unigram.tokenize('Therefore')

(['▁The', 're', 'f', 'ore'], 22.015487824561927)

In [78]:
unigram.tokenize('Sit careless in the shade')

(['▁S',
  'it',
  '▁c',
  'a',
  're',
  'le',
  's',
  's',
  '▁in',
  '▁the',
  '▁sh',
  'ad',
  'e'],
 61.7467506955091)

In [79]:
swords_text, nll = unigram.tokenize(text)
swords_text, nll

(['▁',
  'B',
  'O',
  'O',
  'K',
  '▁I',
  '▁S',
  'ing',
  '▁,',
  '▁',
  'O',
  '▁go',
  'd',
  'd',
  'es',
  's',
  '▁,',
  '▁the',
  '▁an',
  'g',
  'er',
  '▁of',
  '▁Ach',
  'ill',
  'es',
  '▁son',
  '▁of',
  '▁P',
  'e',
  'le',
  'us',
  '▁,',
  '▁that',
  '▁br',
  'ough',
  't',
  '▁c',
  'ou',
  'n',
  't',
  'le',
  's',
  's',
  '▁',
  'ill',
  's',
  '▁up',
  'on',
  '▁the',
  '▁Ach',
  'ae',
  'ans',
  '▁.',
  '▁M',
  'an',
  'y',
  '▁a',
  '▁br',
  'a',
  've',
  '▁s',
  'ou',
  'l',
  '▁d',
  'id',
  '▁it',
  '▁se',
  'nd',
  '▁h',
  'ur',
  'r',
  'y',
  'ing',
  '▁d',
  'ow',
  'n',
  '▁to',
  '▁H',
  'ad',
  'es',
  '▁,',
  '▁and',
  '▁man',
  'y',
  '▁a',
  '▁her',
  'o',
  '▁d',
  'id',
  '▁it',
  '▁y',
  'ie',
  'ld',
  '▁a',
  '▁p',
  're',
  'y',
  '▁to',
  '▁do',
  'g',
  's',
  '▁and',
  '▁',
  'v',
  'u',
  'l',
  't',
  'ur',
  'es',
  '▁,',
  '▁for',
  '▁so',
  '▁were',
  '▁the',
  '▁c',
  'ou',
  'n',
  'sel',
  's',
  '▁of',
  '▁',
  'J',
  'ove',
  '

## Final Vocabulary

In [80]:
len(uni_probs)

254

In [81]:
# For each subword, except the single characters and the pairs '▁' +
# character, we remove it from the model, tokenize the text again, and
# record the NLL of the text without this subword
swords_excl_one = []
for token in tqdm(uni_probs):
    is_char = len(token) == 1
    is_start_char = len(token) == 2 and token[0] == '▁'
    if is_char or is_start_char:
        continue
    sword_excl = (token, uni_probs[token])
    uni_prob_temp = {sword: nll for sword, nll in uni_probs.items()
                     if sword != token}
    unigram = Unigram(uni_prob_temp)
    nll_text = unigram.tokenize(text)[1]
    swords_excl_one.append([sword_excl, nll_text])
swords_excl_one

100%|██████████| 254/254 [00:31<00:00,  7.98it/s]


[[('ing', 4.537382457558073), 3306107.5517700408],
 [('▁go', 5.548547607736448), 3271188.935935537],
 [('es', 5.233293879348584), 3271601.596745802],
 [('▁the', 3.673984438805741), 3355981.5949805058],
 [('▁an', 5.969913221919169), 3268895.6580680325],
 [('er', 4.781520687364294), 3280204.038174003],
 [('▁of', 4.243653038183809), 3325396.117282533],
 [('▁Ach', 6.381177217023477), 3269851.5430730497],
 [('ill', 5.861028794722362), 3273985.120758836],
 [('▁son', 6.087315928205737), 3270410.2896017404],
 [('le', 4.916961489712882), 3279223.692795845],
 [('us', 4.71116870032866), 3283792.1910348255],
 [('▁that', 5.526820503796648), 3278226.1662566448],
 [('▁br', 6.304318211497865), 3268232.7728315843],
 [('ough', 6.220358538377594), 3272541.782391051],
 [('ou', 5.845323083314186), 3267246.4764114157],
 [('▁up', 6.0230341135257435), 3275177.2513445555],
 [('on', 4.965698565463541), 3279495.999228436],
 [('ae', 6.510388948503483), 3266185.403442203],
 [('ans', 5.949430578133228), 3269748.624

In [82]:
sorted(swords_excl_one, key=lambda x: x[1])[:10]

[[('ip', 7.9103605349308275), 3265385.8417472467],
 [('he', 7.772073519076736), 3265408.9012203133],
 [('▁fr', 7.3150691443746405), 3265530.8823400135],
 [('ith', 7.931774629434644), 3265554.1070309104],
 [('her', 6.797722770908213), 3265627.4887778256],
 [('gh', 7.721703159687786), 3265777.7169258883],
 [('ght', 7.578134589191065), 3265844.3613296044],
 [('ae', 6.510388948503483), 3266185.403442203],
 [('ot', 6.692083742506628), 3266193.1468482157],
 [('hen', 7.452613920312621), 3266195.8088417994]]