# Chapter 13: Subword Segmentation
## Subword Tokenizers: BPE
The implementation of BPE from the Gage and Sennrich papers.

A valuable complete implementation of the algorithm: https://github.com/karpathy/minGPT/blob/master/mingpt/bpe.py

Programs from the book: [_Python for Natural Language Processing_](https://link.springer.com/book/9783031575488)

__Author__: Pierre Nugues

In [1]:
import os
import regex as re
from collections import Counter

## Corpus Files

We read the files and we store the corpus in a string

In [2]:
PATH = '../datasets/'

In [3]:
CORPUS = 'HOMER'  # 'DICKENS'

In [4]:
# Folder holding the text files of the selected corpus
if CORPUS == 'DICKENS':
    folder = PATH + 'dickens/'
elif CORPUS == 'HOMER':
    folder = PATH + 'classics/'
else:
    # Fail here rather than with a NameError on `folder` in a later cell
    raise ValueError('Unknown corpus: {!r}'.format(CORPUS))

In [5]:
def get_files(dir, suffix):
    """
    Lists the entries of a folder whose names end with a given suffix
    :param dir: path of the folder to scan
    :param suffix: ending that the names must have, for instance 'txt'
    :return: the matching names, without the folder prefix,
             in the order produced by os.listdir()
    """
    return [name for name in os.listdir(dir) if name.endswith(suffix)]

In [6]:
# Names of the files making up the selected corpus
if CORPUS == 'DICKENS':
    # os.listdir() returns the names in arbitrary order: we sort them so
    # that the corpus is always concatenated in the same order
    files = sorted(get_files(folder, 'txt'))
elif CORPUS == 'HOMER':
    files = ['iliad.txt', 'odyssey.txt']
else:
    # Fail here rather than with a NameError on `files` in a later cell
    raise ValueError('Unknown corpus: {!r}'.format(CORPUS))
files

['iliad.txt', 'odyssey.txt']

In [7]:
files = [folder + file for file in files]
files

['../datasets/classics/iliad.txt', '../datasets/classics/odyssey.txt']

In [8]:
# Read every file, strip its surrounding whitespace, and concatenate the
# contents, each one preceded by a single space
chunks = []
for file in files:
    with open(file, encoding='utf8') as f:
        chunks.append(' ' + f.read().strip())
text = ''.join(chunks)

In [9]:
text[:100]

' BOOK I\n\nSing, O goddess, the anger of Achilles son of Peleus, that brought\ncountless ills upon the '

## Pretokenization

We pretokenize the text using the spaces as delimiters. First try the `'simple'` tokenization, where the word start is not visible, then the `'spaces'` where leading spaces will show as `'Ġ'`.

In [10]:
# Pretokenization scheme: 'simple', 'spaces', or 'karpathy'
pretokenization = 'simple'
if pretokenization == 'simple':
    # Runs of letters, runs of digits, or runs of other non-space characters
    pattern = r'\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+'
elif pretokenization == 'spaces':
    # Same as 'simple', with an optional leading space kept in the token
    pattern = r' ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+'
elif pretokenization == 'karpathy':
    # Adds English clitics and runs of whitespace to the 'spaces' pattern
    pattern = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
else:
    # Fail here rather than with a NameError on `pattern` in a later cell
    raise ValueError(
        'Unknown pretokenization: {!r}'.format(pretokenization))

In [11]:
# Each matched word paired with its (start, end) character offsets
words = [(m.group(), m.span())
         for m in re.finditer(pattern, text)]

In [12]:
words[:8]

[('BOOK', (1, 5)),
 ('I', (6, 7)),
 ('Sing', (9, 13)),
 (',', (13, 14)),
 ('O', (15, 16)),
 ('goddess', (17, 24)),
 (',', (24, 25)),
 ('the', (26, 29))]

In [13]:
def pretokenize_1(pattern, text):
    """
    Splits a text into the successive matches of a regular expression
    :param pattern: the regular expression describing one token
    :param text: the text to split
    :return: the list of matches, as returned by findall()
    """
    matcher = re.compile(pattern)
    return matcher.findall(text)

In [14]:
def pretokenize_2(pattern, text):
    """
    Splits a text into the matches of a regular expression and
    makes the leading spaces visible
    :param pattern: the regular expression describing one token
    :param text: the text to split
    :return: the list of matches, where a match starting with a space
             has this space replaced with 'Ġ'
    """
    marked = []
    for word in re.findall(pattern, text):
        if word[0] == ' ':
            word = 'Ġ' + word[1:]
        marked.append(word)
    return marked

In [15]:
# Only the 'simple' scheme has no leading spaces to mark with 'Ġ'
pretokenize = (pretokenize_1 if pretokenization == 'simple'
               else pretokenize_2)

In [16]:
words = pretokenize(pattern, text)

In [17]:
word_cnts = Counter(words)

In [18]:
word_cnts.most_common(5)

[(',', 19410), ('the', 15258), ('and', 11467), ('of', 8640), ('.', 6839)]

In [19]:
word_cnts['her']

1145

In [20]:
pat = r'\p{L}+|\p{N}+|\p{P}|[^\s\p{L}\p{N}\p{P}]+'

In [21]:
re.findall(pat, 'Wait... here!')

['Wait', '.', '.', '.', 'here', '!']

## Class with Pretokenization

In [22]:
class BPE:
    """
    First step of the BPE tokenizer: the pretokenization
    """

    def __init__(self):
        # A token is a run of letters, a run of digits,
        # or a run of other non-whitespace characters
        self.pattern = r'\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+'

    def pretokenize(self, text):
        """
        Splits a text into words
        :param text: the text to split
        :return: the list of the matches of self.pattern
        """
        words = re.findall(self.pattern, text)
        return words

In [23]:
bpe = BPE()
words = bpe.pretokenize(text)
word_cnts = Counter(words)

In [24]:
word_cnts.most_common(5)

[(',', 19410), ('the', 15258), ('and', 11467), ('of', 8640), ('.', 6839)]

## Initial Vocabulary
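We create a second dictionary, keyed by word, that stores the frequency of each word and its current segmentation into subword tokens (`swords`). At each iteration, the `swords` lists will be updated with the merged subtokens.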

In [25]:
class BPE:
    """
    BPE tokenizer: pretokenization and initial vocabulary
    """

    def __init__(self):
        # A token is a run of letters, a run of digits,
        # or a run of other non-whitespace characters
        self.pattern = r'\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+'

    def pretokenize(self, text):
        """
        Splits a text into words
        :param text: the text to split
        :return: the list of the matches of self.pattern
        """
        words = re.findall(self.pattern, text)
        return words

    def _bpe_init(self, text):
        """
        Creates the initial segmentation and vocabulary.
        Sets self.words_bpe, mapping each distinct word to its frequency
        and its list of characters, and self.vocab, the list of the
        distinct characters (in set iteration order)
        :param text: the training text
        """
        segmentation = {}
        for word, freq in Counter(self.pretokenize(text)).items():
            segmentation[word] = {'freq': freq, 'swords': list(word)}
        self.words_bpe = segmentation
        alphabet = {char
                    for entry in segmentation.values()
                    for char in entry['swords']}
        self.vocab = list(alphabet)

In [26]:
bpe = BPE()
bpe._bpe_init(text)

In [27]:
bpe.words_bpe['her']

{'freq': 1145, 'swords': ['h', 'e', 'r']}

## Counting Pairs

We count the bigrams (pairs of adjacent subwords) in the `swords` lists, weighting each pair by the frequency of the word it occurs in.

In [28]:
class BPE:
    """
    BPE tokenizer: pretokenization, initial vocabulary, and pair counts
    """

    def __init__(self):
        # A token is a run of letters, a run of digits,
        # or a run of other non-whitespace characters
        self.pattern = r'\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+'

    def pretokenize(self, text):
        """
        Splits a text into words
        :param text: the text to split
        :return: the list of the matches of self.pattern
        """
        words = re.findall(self.pattern, text)
        return words

    def _bpe_init(self, text):
        """
        Creates the initial segmentation and vocabulary.
        Sets self.words_bpe, mapping each distinct word to its frequency
        and its list of characters, and self.vocab, the list of the
        distinct characters (in set iteration order)
        :param text: the training text
        """
        segmentation = {}
        for word, freq in Counter(self.pretokenize(text)).items():
            segmentation[word] = {'freq': freq, 'swords': list(word)}
        self.words_bpe = segmentation
        alphabet = {char
                    for entry in segmentation.values()
                    for char in entry['swords']}
        self.vocab = list(alphabet)

    def _count_bigrams(self):
        """
        Counts the pairs of adjacent subwords.
        Sets self.pair_cnts, a Counter whose keys are 2-tuples of
        subwords and whose values are the sums of the frequencies of
        the words where the pairs occur
        """
        counts = Counter()
        for entry in self.words_bpe.values():
            tokens = entry['swords']
            weight = entry['freq']
            # zip() pairs each subword with its right neighbor
            for left, right in zip(tokens, tokens[1:]):
                counts[(left, right)] += weight
        self.pair_cnts = counts

In [29]:
bpe = BPE()
bpe._bpe_init(text)
bpe._count_bigrams()

In [30]:
max(bpe.pair_cnts, key=bpe.pair_cnts.get)

('h', 'e')

## Merging a Pair

We merge a pair in a sequence of subwords. The structure of the pair is a list as in: `['h', 'e']`. `swords` is also a list.

In [31]:
class BPE:
    """
    BPE tokenizer: pretokenization, initial vocabulary, pair counts,
    and merge of a pair
    """

    def __init__(self):
        # A token is a run of letters, a run of digits,
        # or a run of other non-whitespace characters
        self.pattern = r'\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+'

    def pretokenize(self, text):
        """
        Splits a text into words
        :param text: the text to split
        :return: the list of the matches of self.pattern
        """
        return re.findall(self.pattern, text)

    def _bpe_init(self, text):
        """
        Creates the initial segmentation and vocabulary.
        Sets self.words_bpe, mapping each distinct word to its frequency
        and its list of characters, and self.vocab, the list of the
        distinct characters (in set iteration order)
        :param text: the training text
        """
        word_cnts = Counter(self.pretokenize(text))
        self.words_bpe = {
            word: {'freq': freq,
                   'swords': list(word)}
            for word, freq in word_cnts.items()}
        self.vocab = list(
            {char for word_dict in self.words_bpe.values()
             for char in word_dict['swords']})

    def _count_bigrams(self):
        """
        Counts the pairs of adjacent subwords.
        Sets self.pair_cnts, a Counter whose keys are 2-tuples of
        subwords and whose values are the sums of the frequencies of
        the words where the pairs occur
        """
        self.pair_cnts = Counter()
        for word_dict in self.words_bpe.values():
            swords = word_dict['swords']
            freq = word_dict['freq']
            for left, right in zip(swords, swords[1:]):
                self.pair_cnts[(left, right)] += freq

    def _merge_pair(self, pair, swords):
        """
        Merges the occurrences of a pair in a sequence of subwords,
        scanning from left to right
        :param pair: the two adjacent subwords to merge, as a list or
                     as a tuple such as the keys of self.pair_cnts
        :param swords: the sequence of subwords
        :return: a new list where each occurrence of the pair is
                 replaced with the concatenation of its two members
        """
        # A tuple never equals a list slice: we normalize both arguments
        # so that a key of self.pair_cnts can be passed as is
        pair = list(pair)
        swords = list(swords)
        pair_str = ''.join(pair)
        merged = []
        last = len(swords) - 1
        i = 0
        while i < last:
            if swords[i:i + 2] == pair:
                merged.append(pair_str)
                i += 2
            else:
                merged.append(swords[i])
                i += 1
        # The last subword is still pending when it did not end a merge
        if i == last:
            merged.append(swords[i])
        return merged

In [32]:
bpe = BPE()
bpe._bpe_init(text)
bpe._count_bigrams()

In [33]:
bpe._merge_pair(['h', 'e'], ['t', 'h', 'e', 'y'])

['t', 'he', 'y']

## The Loop

In [34]:
class BPE:
    """
    BPE tokenizer: learning of the merge operations from a text
    """

    def __init__(self, merge_cnt=200):
        """
        :param merge_cnt: maximal number of merge operations to learn
        """
        # A token is a run of letters, a run of digits,
        # or a run of other non-whitespace characters
        self.pattern = r'\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+'
        self.merge_cnt = merge_cnt

    def pretokenize(self, text):
        """
        Splits a text into words
        :param text: the text to split
        :return: the list of the matches of self.pattern
        """
        return re.findall(self.pattern, text)

    def _bpe_init(self, text):
        """
        Creates the initial segmentation and vocabulary.
        Sets self.words_bpe, mapping each distinct word to its frequency
        and its list of characters, and self.vocab, the list of the
        distinct characters (in set iteration order)
        :param text: the training text
        """
        word_cnts = Counter(self.pretokenize(text))
        self.words_bpe = {
            word: {'freq': freq,
                   'swords': list(word)}
            for word, freq in word_cnts.items()}
        self.vocab = list(
            {char for word_dict in self.words_bpe.values()
             for char in word_dict['swords']})

    def _count_bigrams(self):
        """
        Counts the pairs of adjacent subwords.
        Sets self.pair_cnts, a Counter whose keys are 2-tuples of
        subwords and whose values are the sums of the frequencies of
        the words where the pairs occur
        """
        self.pair_cnts = Counter()
        for word_dict in self.words_bpe.values():
            swords = word_dict['swords']
            freq = word_dict['freq']
            for left, right in zip(swords, swords[1:]):
                self.pair_cnts[(left, right)] += freq

    def _merge_pair(self, pair, swords):
        """
        Merges the occurrences of a pair in a sequence of subwords,
        scanning from left to right
        :param pair: the two adjacent subwords to merge, as a list or
                     as a tuple such as the keys of self.pair_cnts
        :param swords: the sequence of subwords
        :return: a new list where each occurrence of the pair is
                 replaced with the concatenation of its two members
        """
        # A tuple never equals a list slice: we normalize both arguments
        # so that a key of self.pair_cnts can be passed as is
        pair = list(pair)
        swords = list(swords)
        pair_str = ''.join(pair)
        merged = []
        last = len(swords) - 1
        i = 0
        while i < last:
            if swords[i:i + 2] == pair:
                merged.append(pair_str)
                i += 2
            else:
                merged.append(swords[i])
                i += 1
        # The last subword is still pending when it did not end a merge
        if i == last:
            merged.append(swords[i])
        return merged

    def fit(self, text):
        """
        Learns at most self.merge_cnt merge operations from a text.
        Sets self.merge_ops, the list of merged pairs in learning
        order, and extends self.vocab with the merged subwords
        :param text: the training text
        """
        self._bpe_init(text)

        self.merge_ops = []
        for _ in range(self.merge_cnt):
            self._count_bigrams()
            if not self.pair_cnts:
                # Every word is reduced to one subword: max() would
                # raise a ValueError on the empty counter
                break
            self.best_pair = max(self.pair_cnts,
                                 key=self.pair_cnts.get)
            merge_op = list(self.best_pair)
            self.merge_ops.append(merge_op)
            for word_dict in self.words_bpe.values():
                word_dict['swords'] = self._merge_pair(
                    merge_op,
                    word_dict['swords'])
        self._build_vocab()

    def _build_vocab(self):
        """
        Appends the subwords created by the merge operations
        to self.vocab
        """
        self.vocab += [''.join(op) for op in self.merge_ops]

In [35]:
bpe = BPE()
bpe.fit(text)

In [36]:
bpe.merge_ops[:5]

[['h', 'e'], ['t', 'he'], ['a', 'n'], ['i', 'n'], ['o', 'u']]

In [37]:
len(bpe.vocab)

266

In [38]:
bpe.vocab[:10]

['v', 'r', 'W', 'T', 'Q', 'R', '[', '!', 'c', 'A']

## Encoding a Word

We apply the rules in the same order.

We can define a vocabulary consisting of the characters in the training corpus and the subwords. The characters outside this set will be mapped to `'UNK'`. Otherwise the initial vocabulary consists of all the Unicode characters. 

In [39]:
class BPE:
    """
    BPE tokenizer: learning of the merge operations from a text
    and encoding of a word
    """

    def __init__(self, merge_cnt=200):
        """
        :param merge_cnt: maximal number of merge operations to learn
        """
        # A token is a run of letters, a run of digits,
        # or a run of other non-whitespace characters
        self.pattern = r'\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+'
        self.merge_cnt = merge_cnt

    def pretokenize(self, text):
        """
        Splits a text into words
        :param text: the text to split
        :return: the list of the matches of self.pattern
        """
        return re.findall(self.pattern, text)

    def _bpe_init(self, text):
        """
        Creates the initial segmentation and vocabulary.
        Sets self.words_bpe, mapping each distinct word to its frequency
        and its list of characters, and self.vocab, the list of the
        distinct characters (in set iteration order)
        :param text: the training text
        """
        word_cnts = Counter(self.pretokenize(text))
        self.words_bpe = {
            word: {'freq': freq,
                   'swords': list(word)}
            for word, freq in word_cnts.items()}
        self.vocab = list(
            {char for word_dict in self.words_bpe.values()
             for char in word_dict['swords']})

    def _count_bigrams(self):
        """
        Counts the pairs of adjacent subwords.
        Sets self.pair_cnts, a Counter whose keys are 2-tuples of
        subwords and whose values are the sums of the frequencies of
        the words where the pairs occur
        """
        self.pair_cnts = Counter()
        for word_dict in self.words_bpe.values():
            swords = word_dict['swords']
            freq = word_dict['freq']
            for left, right in zip(swords, swords[1:]):
                self.pair_cnts[(left, right)] += freq

    def _merge_pair(self, pair, swords):
        """
        Merges the occurrences of a pair in a sequence of subwords,
        scanning from left to right
        :param pair: the two adjacent subwords to merge, as a list or
                     as a tuple such as the keys of self.pair_cnts
        :param swords: the sequence of subwords
        :return: a new list where each occurrence of the pair is
                 replaced with the concatenation of its two members
        """
        # A tuple never equals a list slice: we normalize both arguments
        # so that a key of self.pair_cnts can be passed as is
        pair = list(pair)
        swords = list(swords)
        pair_str = ''.join(pair)
        merged = []
        last = len(swords) - 1
        i = 0
        while i < last:
            if swords[i:i + 2] == pair:
                merged.append(pair_str)
                i += 2
            else:
                merged.append(swords[i])
                i += 1
        # The last subword is still pending when it did not end a merge
        if i == last:
            merged.append(swords[i])
        return merged

    def fit(self, text):
        """
        Learns at most self.merge_cnt merge operations from a text.
        Sets self.merge_ops, the list of merged pairs in learning
        order, and extends self.vocab with the merged subwords
        :param text: the training text
        """
        self._bpe_init(text)

        self.merge_ops = []
        for _ in range(self.merge_cnt):
            self._count_bigrams()
            if not self.pair_cnts:
                # Every word is reduced to one subword: max() would
                # raise a ValueError on the empty counter
                break
            self.best_pair = max(self.pair_cnts,
                                 key=self.pair_cnts.get)
            merge_op = list(self.best_pair)
            self.merge_ops.append(merge_op)
            for word_dict in self.words_bpe.values():
                word_dict['swords'] = self._merge_pair(
                    merge_op,
                    word_dict['swords'])
        self._build_vocab()

    def _build_vocab(self):
        """
        Appends the subwords created by the merge operations
        to self.vocab
        """
        self.vocab += [''.join(op) for op in self.merge_ops]

    def encode(self, word):
        """
        Segments a word into subwords
        :param word: the word to segment
        :return: the list of subwords obtained by applying the merge
                 operations to the characters of the word, in the
                 order they were learned
        """
        swords = list(word)
        for op in self.merge_ops:
            swords = self._merge_pair(op, swords)
        return swords

In [40]:
bpe = BPE()
bpe.fit(text)

In [41]:
bpe.encode('therefore')

['there', 'fore']

## Tokenizing a Whole Text
We can now write the complete subword tokenization function. We use a cache so that each distinct word is encoded only once.

In [42]:
class BPE:
    """
    BPE tokenizer: learning of the merge operations from a text
    and tokenization of a text
    """

    def __init__(self, merge_cnt=200):
        """
        :param merge_cnt: maximal number of merge operations to learn
        """
        # A token is a run of letters, a run of digits,
        # or a run of other non-whitespace characters
        self.pattern = r'\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+'
        self.merge_cnt = merge_cnt

    def pretokenize(self, text):
        """
        Splits a text into words
        :param text: the text to split
        :return: the list of the matches of self.pattern
        """
        return re.findall(self.pattern, text)

    def _bpe_init(self, text):
        """
        Creates the initial segmentation and vocabulary.
        Sets self.words_bpe, mapping each distinct word to its frequency
        and its list of characters, and self.vocab, the list of the
        distinct characters (in set iteration order)
        :param text: the training text
        """
        word_cnts = Counter(self.pretokenize(text))
        self.words_bpe = {
            word: {'freq': freq,
                   'swords': list(word)}
            for word, freq in word_cnts.items()}
        self.vocab = list(
            {char for word_dict in self.words_bpe.values()
             for char in word_dict['swords']})

    def _count_bigrams(self):
        """
        Counts the pairs of adjacent subwords.
        Sets self.pair_cnts, a Counter whose keys are 2-tuples of
        subwords and whose values are the sums of the frequencies of
        the words where the pairs occur
        """
        self.pair_cnts = Counter()
        for word_dict in self.words_bpe.values():
            swords = word_dict['swords']
            freq = word_dict['freq']
            for left, right in zip(swords, swords[1:]):
                self.pair_cnts[(left, right)] += freq

    def _merge_pair(self, pair, swords):
        """
        Merges the occurrences of a pair in a sequence of subwords,
        scanning from left to right
        :param pair: the two adjacent subwords to merge, as a list or
                     as a tuple such as the keys of self.pair_cnts
        :param swords: the sequence of subwords
        :return: a new list where each occurrence of the pair is
                 replaced with the concatenation of its two members
        """
        # A tuple never equals a list slice: we normalize both arguments
        # so that a key of self.pair_cnts can be passed as is
        pair = list(pair)
        swords = list(swords)
        pair_str = ''.join(pair)
        merged = []
        last = len(swords) - 1
        i = 0
        while i < last:
            if swords[i:i + 2] == pair:
                merged.append(pair_str)
                i += 2
            else:
                merged.append(swords[i])
                i += 1
        # The last subword is still pending when it did not end a merge
        if i == last:
            merged.append(swords[i])
        return merged

    def fit(self, text):
        """
        Learns at most self.merge_cnt merge operations from a text.
        Sets self.merge_ops, the list of merged pairs in learning
        order, and extends self.vocab with the merged subwords
        :param text: the training text
        """
        self._bpe_init(text)

        self.merge_ops = []
        for _ in range(self.merge_cnt):
            self._count_bigrams()
            if not self.pair_cnts:
                # Every word is reduced to one subword: max() would
                # raise a ValueError on the empty counter
                break
            self.best_pair = max(self.pair_cnts,
                                 key=self.pair_cnts.get)
            merge_op = list(self.best_pair)
            self.merge_ops.append(merge_op)
            for word_dict in self.words_bpe.values():
                word_dict['swords'] = self._merge_pair(
                    merge_op,
                    word_dict['swords'])
        self._build_vocab()

    def _build_vocab(self):
        """
        Appends the subwords created by the merge operations
        to self.vocab
        """
        self.vocab += [''.join(op) for op in self.merge_ops]

    def encode(self, word):
        """
        Segments a word into subwords
        :param word: the word to segment
        :return: the list of subwords obtained by applying the merge
                 operations to the characters of the word, in the
                 order they were learned
        """
        swords = list(word)
        for op in self.merge_ops:
            swords = self._merge_pair(op, swords)
        return swords

    def tokenize(self, text):
        """
        Segments a whole text into subwords
        :param text: the text to tokenize
        :return: the list of the subwords of all the words of the text
        """
        tokenized_text = []
        # Each distinct word is encoded once per call
        cache = {}
        for word in self.pretokenize(text):
            if word not in cache:
                cache[word] = self.encode(word)
            tokenized_text.extend(cache[word])
        return tokenized_text

In [43]:
bpe = BPE()
bpe.fit(text)

In [44]:
ecloges_str = """Sit careless in the shade"""

In [45]:
bpe.tokenize(ecloges_str)

['S', 'it', 'c', 'are', 'le', 's', 's', 'in', 'the', 's', 'had', 'e']

## Leading Whitespaces

In [46]:
pattern = r' ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+'

Karpathy's pretokenization pattern (see the minGPT link above), kept commented out for reference:

In [47]:
# pattern = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

In [48]:
class BPE:
    """
    BPE tokenizer: learning of the merge operations from a text
    and tokenization of a text, with optional visible leading spaces
    """

    def __init__(self, merge_cnt=200, leading_space=False):
        """
        :param merge_cnt: maximal number of merge operations to learn
        :param leading_space: if True, the space preceding a word is
                              kept in the word and shown as 'Ġ'
        """
        self.merge_cnt = merge_cnt
        self.leading_space = leading_space
        # A token is a run of letters, a run of digits, or a run of
        # other non-whitespace characters, optionally preceded by a space
        if leading_space:
            self.pattern = r' ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+'
        else:
            self.pattern = r'\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+'

    def pretokenize(self, text):
        """
        Splits a text into words
        :param text: the text to split
        :return: the list of the matches of self.pattern, where a
                 leading space is replaced with 'Ġ' when
                 self.leading_space is set
        """
        words = re.findall(self.pattern, text)
        if self.leading_space:
            words = ['Ġ' + word[1:] if word.startswith(' ') else word
                     for word in words]
        return words

    def _bpe_init(self, text):
        """
        Creates the initial segmentation and vocabulary.
        Sets self.words_bpe, mapping each distinct word to its frequency
        and its list of characters, and self.vocab, the list of the
        distinct characters (in set iteration order)
        :param text: the training text
        """
        word_cnts = Counter(self.pretokenize(text))
        self.words_bpe = {
            word: {'freq': freq,
                   'swords': list(word)}
            for word, freq in word_cnts.items()}
        self.vocab = list(
            {char for word_dict in self.words_bpe.values()
             for char in word_dict['swords']})

    def _count_bigrams(self):
        """
        Counts the pairs of adjacent subwords.
        Sets self.pair_cnts, a Counter whose keys are 2-tuples of
        subwords and whose values are the sums of the frequencies of
        the words where the pairs occur
        """
        self.pair_cnts = Counter()
        for word_dict in self.words_bpe.values():
            swords = word_dict['swords']
            freq = word_dict['freq']
            for left, right in zip(swords, swords[1:]):
                self.pair_cnts[(left, right)] += freq

    def _merge_pair(self, pair, swords):
        """
        Merges the occurrences of a pair in a sequence of subwords,
        scanning from left to right
        :param pair: the two adjacent subwords to merge, as a list or
                     as a tuple such as the keys of self.pair_cnts
        :param swords: the sequence of subwords
        :return: a new list where each occurrence of the pair is
                 replaced with the concatenation of its two members
        """
        # A tuple never equals a list slice: we normalize both arguments
        # so that a key of self.pair_cnts can be passed as is
        pair = list(pair)
        swords = list(swords)
        pair_str = ''.join(pair)
        merged = []
        last = len(swords) - 1
        i = 0
        while i < last:
            if swords[i:i + 2] == pair:
                merged.append(pair_str)
                i += 2
            else:
                merged.append(swords[i])
                i += 1
        # The last subword is still pending when it did not end a merge
        if i == last:
            merged.append(swords[i])
        return merged

    def fit(self, text):
        """
        Learns at most self.merge_cnt merge operations from a text.
        Sets self.merge_ops, the list of merged pairs in learning
        order, and extends self.vocab with the merged subwords
        :param text: the training text
        """
        self._bpe_init(text)

        self.merge_ops = []
        for _ in range(self.merge_cnt):
            self._count_bigrams()
            if not self.pair_cnts:
                # Every word is reduced to one subword: max() would
                # raise a ValueError on the empty counter
                break
            self.best_pair = max(self.pair_cnts,
                                 key=self.pair_cnts.get)
            merge_op = list(self.best_pair)
            self.merge_ops.append(merge_op)
            for word_dict in self.words_bpe.values():
                word_dict['swords'] = self._merge_pair(
                    merge_op,
                    word_dict['swords'])
        self._build_vocab()

    def _build_vocab(self):
        """
        Appends the subwords created by the merge operations
        to self.vocab
        """
        self.vocab += [''.join(op) for op in self.merge_ops]

    def encode(self, word):
        """
        Segments a word into subwords
        :param word: the word to segment
        :return: the list of subwords obtained by applying the merge
                 operations to the characters of the word, in the
                 order they were learned
        """
        swords = list(word)
        for op in self.merge_ops:
            swords = self._merge_pair(op, swords)
        return swords

    def tokenize(self, text):
        """
        Segments a whole text into subwords
        :param text: the text to tokenize
        :return: the list of the subwords of all the words of the text
        """
        tokenized_text = []
        # Each distinct word is encoded once per call
        cache = {}
        for word in self.pretokenize(text):
            if word not in cache:
                cache[word] = self.encode(word)
            tokenized_text.extend(cache[word])
        return tokenized_text

In [49]:
bpe = BPE(leading_space=True)
bpe.fit(text)

In [50]:
bpe.tokenize('Sit careless in the shade')

['S', 'it', 'Ġc', 'a', 're', 'l', 'ess', 'Ġin', 'Ġthe', 'Ġsh', 'ad', 'e']

In [51]:
bpe.tokenize(text)[-100:]

['Ġsp',
 'o',
 'ke',
 'ĠM',
 'in',
 'er',
 'v',
 'a',
 ',',
 'Ġand',
 'Ġ',
 'U',
 'ly',
 's',
 'ses',
 'Ġo',
 'b',
 'e',
 'y',
 'ed',
 'Ġher',
 'Ġg',
 'l',
 'ad',
 'ly',
 '.',
 'ĠT',
 'hen',
 'ĠM',
 'in',
 'er',
 'v',
 'a',
 'Ġas',
 's',
 'u',
 'm',
 'ed',
 'the',
 'Ġfor',
 'm',
 'Ġand',
 'Ġ',
 'v',
 'o',
 'i',
 'ce',
 'Ġof',
 'ĠM',
 'ent',
 'or',
 ',',
 'Ġand',
 'Ġp',
 're',
 's',
 'ent',
 'ly',
 'Ġm',
 'ad',
 'e',
 'Ġa',
 'Ġc',
 'o',
 'ven',
 'an',
 't',
 'Ġof',
 'Ġp',
 'e',
 'a',
 'ce',
 'b',
 'et',
 'w',
 'e',
 'en',
 'Ġthe',
 'Ġt',
 'w',
 'o',
 'Ġc',
 'on',
 't',
 'e',
 'nd',
 'ing',
 'Ġp',
 'ar',
 't',
 'i',
 'es',
 '.',
 'T',
 'H',
 'E',
 'Ġ',
 'E',
 'N',
 'D']

## Byte-level tokenization

In [52]:
# Byte values that are later remapped to characters from code point 256 on:
# 0-32 (the ASCII control characters and the space), 127-159, and 173
# NOTE(review): GPT-2's bytes_to_unicode() also treats byte 160 (no-break
# space) as non-printable, which gives 68 remapped bytes instead of 67 and
# shifts the character assigned to 173; range(127, 160) stops at 159 -- confirm
control_chars1 = range(0, 33)
control_chars2 = range(127, 160)
control_chars3 = range(173, 174)

In [53]:
# Pairs (rank, byte value) of the remapped bytes, in increasing byte order
shift_table = list(enumerate(
    [*control_chars1, *control_chars2, *control_chars3]))

In [54]:
len(shift_table)

67

In [55]:
ord(' ')

32

In [56]:
chr(ord(' ') + 256)

'Ġ'

In [57]:
# Maps each remapped byte value to the character of code point 256 + rank
gtp2_special_bytes = {b: chr(n + 256) for n, b in shift_table}

In [58]:
gtp2_special_bytes

{0: 'Ā',
 1: 'ā',
 2: 'Ă',
 3: 'ă',
 4: 'Ą',
 5: 'ą',
 6: 'Ć',
 7: 'ć',
 8: 'Ĉ',
 9: 'ĉ',
 10: 'Ċ',
 11: 'ċ',
 12: 'Č',
 13: 'č',
 14: 'Ď',
 15: 'ď',
 16: 'Đ',
 17: 'đ',
 18: 'Ē',
 19: 'ē',
 20: 'Ĕ',
 21: 'ĕ',
 22: 'Ė',
 23: 'ė',
 24: 'Ę',
 25: 'ę',
 26: 'Ě',
 27: 'ě',
 28: 'Ĝ',
 29: 'ĝ',
 30: 'Ğ',
 31: 'ğ',
 32: 'Ġ',
 127: 'ġ',
 128: 'Ģ',
 129: 'ģ',
 130: 'Ĥ',
 131: 'ĥ',
 132: 'Ħ',
 133: 'ħ',
 134: 'Ĩ',
 135: 'ĩ',
 136: 'Ī',
 137: 'ī',
 138: 'Ĭ',
 139: 'ĭ',
 140: 'Į',
 141: 'į',
 142: 'İ',
 143: 'ı',
 144: 'Ĳ',
 145: 'ĳ',
 146: 'Ĵ',
 147: 'ĵ',
 148: 'Ķ',
 149: 'ķ',
 150: 'ĸ',
 151: 'Ĺ',
 152: 'ĺ',
 153: 'Ļ',
 154: 'ļ',
 155: 'Ľ',
 156: 'ľ',
 157: 'Ŀ',
 158: 'ŀ',
 159: 'Ł',
 173: 'ł'}