In [1]:
import re
import collections
import urllib.request
from typing import Dict, List, Tuple, Set

import pandas as pd
from IPython.display import display, Markdown, Latex
from tokenizers import BertWordPieceTokenizer

## BPE (Byte Pair Encoding)

BPE is a data compression algorithm proposed in 1994.  
Basically, it works by finding a pair of consecutive words most frequently appeared, and merged it into one letter.

- e.g. `aaabdaaabac`
    - `Z=aa` $\rightarrow$   `ZabdZabac`
    - `Z=aa, Y=ab` $\rightarrow$  `ZYdZYac`
    - `Z=aa, Y=ab, X=ZY` $\rightarrow$  `XdXac`

In natural language preprocessing, BPE is a subword segmentation algorithm, which means it splits exsiting word.
- e.g. frequency of each word in train vocaburay
    ```python
    # dictionary (frequency of each word in train vocaburary)
    low : 5, lower : 2, newest : 6, widest : 3

    # vocabulary
    low, lower, newest, widest

        ↓

    # dictionary
    l o w : 5,  l o w e r : 2,  n e w e s t : 6,  w i d e s t : 3
    
    # vocabulary
    l, o, w, e, r, n, s, t, i, d

        ↓ (1st update, "(e, s)" is the most frequent pair)

    # dictionary update
    l o w : 5,
    l o w e r : 2,
    n e w es t : 6,
    w i d es t : 3

    # vocabulary update
    l, o, w, e, r, n, s, t, i, d, es

        ↓ (2nd update, "(es, t)" is the most frequent pair)

    # dictionary update
    l o w : 5,
    l o w e r : 2,
    n e w es t : 6,
    w i d es t : 3

    # vocabulary update
    l, o, w, e, r, n, s, t, i, d, es, est

    ↓ (3rd update, "(l, o)" is the most frequent pair)

    # dictionary update
    l o w : 5,
    l o w e r : 2,
    n e w es t : 6,
    w i d es t : 3

    # vocabulary update
    l, o, w, e, r, n, s, t, i, d, es, est, lo

    ...
    ```
- see [Sennrich et al., (2016)]


[Sennrich et al., (2016)]: https://arxiv.org/abs/1508.07909

In [2]:
def get_stats(dictionary: Dict[str, int]) -> Dict[str, int]:
    pairs = collections.defaultdict(int)
    for word, freq in dictionary.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i], symbols[i+1]] += freq
    return pairs


def merge_dictionary(pair: Tuple[str, str], v_in: Dict[str, int]) -> Dict[str, int]:
    v_out = {}
    bigram = re.escape(" ".join(pair))
    # (?<!\S) => negative lookbehind
    #   - ?<!X: case where there's no X right in front of the current location
    #   - \S: non-white-space character
    # (?!\S) => negative lookahead
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub("".join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

In [3]:
num_merges = 10

dictionary = {
    "l o w </w>": 5,
    "l o w e r </w>": 2,
    "n e w e s t </w>": 6,
    "w i d e s t </w>": 3,
}

In [4]:
get_stats(dictionary)

defaultdict(int,
            {('l', 'o'): 7,
             ('o', 'w'): 7,
             ('w', '</w>'): 5,
             ('w', 'e'): 8,
             ('e', 'r'): 2,
             ('r', '</w>'): 2,
             ('n', 'e'): 6,
             ('e', 'w'): 6,
             ('e', 's'): 9,
             ('s', 't'): 9,
             ('t', '</w>'): 9,
             ('w', 'i'): 3,
             ('i', 'd'): 3,
             ('d', 'e'): 3})

In [5]:
bpe_codes = {}
bpe_codes_reverse = {}
for i in range(num_merges):
    pairs = get_stats(dictionary)
    most_frequent_pair = max(pairs, key=pairs.get)
    
    dictionary = merge_dictionary(most_frequent_pair, dictionary)

    bpe_codes[most_frequent_pair] = i
    bpe_codes_reverse["".join(most_frequent_pair)] = most_frequent_pair

In [6]:
get_stats(dictionary)

defaultdict(int,
            {('low', 'e'): 2,
             ('e', 'r'): 2,
             ('r', '</w>'): 2,
             ('wi', 'd'): 3,
             ('d', 'est</w>'): 3})

In [7]:
bpe_codes

{('e', 's'): 0,
 ('es', 't'): 1,
 ('est', '</w>'): 2,
 ('l', 'o'): 3,
 ('lo', 'w'): 4,
 ('n', 'e'): 5,
 ('ne', 'w'): 6,
 ('new', 'est</w>'): 7,
 ('low', '</w>'): 8,
 ('w', 'i'): 9}

In [8]:
bpe_codes_reverse

{'es': ('e', 's'),
 'est': ('es', 't'),
 'est</w>': ('est', '</w>'),
 'lo': ('l', 'o'),
 'low': ('lo', 'w'),
 'ne': ('n', 'e'),
 'new': ('ne', 'w'),
 'newest</w>': ('new', 'est</w>'),
 'low</w>': ('low', '</w>'),
 'wi': ('w', 'i')}

In [9]:
def get_pairs(word: Tuple[str]) -> Set[str]:
    """Return set of symbol pairs in a word.
    Word is represented as a tuple of symbols (symbols being variable-length strings).
    """
    if not word:
        return set()
    
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def encode(word: str, bpe_codes: Dict[str, int]) ->  Tuple[str]:
    """Encode word based on list of BPE merge operations, which are applied consecutively"""

    # e.g. word="loki"
    #  - chars = ("l", "o", "k", "i", "</w>")
    #  - pairs = (("l", "o"), ("o", "k"), ("k", "i"), ("i", "</w>"))
    chars = tuple(word) + ("</w>", )
    pairs = get_pairs(word)

    if not pairs:
        return word

    num_iter = 0
    while True:
        num_iter += 1        
        bigram = min(pairs, key=lambda pair: bpe_codes.get(pair, float("inf")))

        # there's no further merge
        if bigram not in bpe_codes:
            break
        
        c1, c2 = bigram
        new = []
        i = 0
        while i < len(chars):
            cur = chars[i]
            # update i until cur == c1
            if c1 != cur:
                new.append(cur)
                i += 1
                continue

            # in case of (..., c1, c2, ...), merge c1 and c2
            #                  i       next 
            if c1 == cur and i < len(chars)-1 and c2 == chars[i+1]:                
                new.append(c1+c2)
                i += 2
            # in case of (..., c1, c3, ...), where c2 != c3
            #                  i   next
            else:
                new.append(c1)
                i += 1

        chars = tuple(new)
        if len(chars) == 1:
            break
        else:
            pairs = get_pairs(chars)

    # ignore </w> token
    if chars[-1] == "</w>":
        chars = chars[:-1]
    elif chars[-1].endswith("</w>"):
        chars = chars[:-1] + (chars[-1].replace("</w>", ""),)

    return chars

In [14]:
encode("", bpe_codes)

''

In [15]:
encode("loki", bpe_codes)

('lo', 'k', 'i')

In [16]:
encode("lowest", bpe_codes)

('low', 'est')

In [17]:
encode("lowing", bpe_codes)

('low', 'i', 'n', 'g')

In [18]:
encode("highing", bpe_codes)

('h', 'i', 'g', 'h', 'i', 'n', 'g')

## SentencePiece

## SubwodTextEncoder

## Huggingface Tokenizer