In [2]:
import regex as re

In [3]:
# Pre-tokenization pattern. The alternatives are tried left to right:
#   '(?:[sdmt]|ll|ve|re)   an apostrophe followed by s, d, m, t, ll, ve or re
#    ?\p{L}+               an optional single space, then one or more letters
#    ?\p{N}+               an optional single space, then one or more numeric characters
#    ?[^\s\p{L}\p{N}]+     an optional single space, then a run of characters that are
#                          neither whitespace, letters nor numbers (punctuation etc.)
#   \s+(?!\S)              a whitespace run that is not followed by a non-whitespace char
#   \s+                    any remaining whitespace run
# The \p{L} / \p{N} Unicode property classes rely on the third-party `regex`
# package, which this notebook imports under the name `re`.
# NOTE(review): this looks like the GPT-2 pre-tokenization pattern - confirm the intended source.
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

In [4]:
# Demo: split a sample sentence into pre-tokens with PAT. Words keep their
# leading space, and the contraction suffix and the hyphen become separate tokens.
re.findall(PAT, "some text that i'll pre-tokenize")

['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize']

In [5]:
# Toy training corpus: three lines of space-separated, repeated words.
corpus = [
    "low low low low low",
    "lower lower widest widest widest",
    "newest newest newest newest newest newest"
]

In [6]:
from collections import Counter
from typing import Dict, Tuple, List
# Tally how often each whitespace-separated word occurs across the corpus.
word_freq = Counter()
for line in corpus:
    # Counter.update adds one count per element of the iterable it is given.
    word_freq.update(line.split())
print(word_freq)

Counter({'newest': 6, 'low': 5, 'widest': 3, 'lower': 2})


In [7]:
# ----- utility -----
def split_bytes(word: str) -> Tuple[bytes, ...]:
    """
    Break `word` into its UTF-8 bytes, one length-1 `bytes` object per byte.

    A character that needs several UTF-8 bytes therefore yields several symbols.
    """
    encoded = word.encode("utf-8")
    # Iterating over bytes yields ints; wrap each one back into a 1-byte bytes object.
    return tuple(bytes([value]) for value in encoded)

In [8]:
# Demo with a non-ASCII character: 'é' encodes to two UTF-8 bytes, so it
# becomes two separate symbols.
split_bytes("café")

(b'c', b'a', b'f', b'\xc3', b'\xa9')

In [9]:
# Initial BPE vocab: every distinct word, split into single-byte symbols,
# mapped to the number of times that word occurs in the corpus.
vocab: Dict[Tuple[bytes, ...], int] = dict(
    (split_bytes(word), count) for word, count in word_freq.items()
)
print(vocab)

{(b'l', b'o', b'w'): 5, (b'l', b'o', b'w', b'e', b'r'): 2, (b'w', b'i', b'd', b'e', b's', b't'): 3, (b'n', b'e', b'w', b'e', b's', b't'): 6}


In [10]:
def get_stats(vocab: Dict[Tuple[bytes, ...], int]) -> Counter:
    """
    Count adjacent symbol pairs across the vocab.

    Each occurrence of a pair inside a word contributes that word's
    frequency to the pair's total. Words with fewer than two symbols
    contribute nothing.
    """
    pair_counts = Counter()
    for word, count in vocab.items():
        # zip(word, word[1:]) walks the word as consecutive (left, right) pairs.
        for left, right in zip(word, word[1:]):
            pair_counts[(left, right)] += count
    return pair_counts

In [11]:
# Pair counts over the initial byte-level vocab.
stats = get_stats(vocab)
print(stats)

Counter({(b'e', b's'): 9, (b's', b't'): 9, (b'w', b'e'): 8, (b'l', b'o'): 7, (b'o', b'w'): 7, (b'n', b'e'): 6, (b'e', b'w'): 6, (b'w', b'i'): 3, (b'i', b'd'): 3, (b'd', b'e'): 3, (b'e', b'r'): 2})


In [12]:
def choose_pair(stats: Counter) -> Tuple[bytes, bytes]:
    """
    Pick the pair to merge next.

    The pair with the highest count wins. Ties are broken in favour of the
    lexicographically greater pair (pairs are compared as tuples, i.e. first
    symbol, then second symbol).

    Args:
        stats: mapping from (left, right) symbol pair to its count.

    Returns:
        The winning (left, right) pair.

    Raises:
        ValueError: if `stats` contains no pairs.
    """
    if not stats:
        # max() would raise ValueError here anyway; say what actually went wrong.
        raise ValueError("choose_pair() needs at least one pair, but stats is empty")
    # Sort key is (count, pair) so that the pair itself decides ties between equal counts.
    return max(stats.items(), key=lambda kv: (kv[1], kv[0]))[0]

In [13]:
# Select the first merge candidate from `stats`. (b'e', b's') and (b's', b't')
# tie on count, so the lexicographically greater pair is the one returned.
most_freq_pair = choose_pair(stats)
print(most_freq_pair)

(b's', b't')


In [14]:
def merge_pair_in_word(
    symbols: Tuple[bytes, ...],
    pair: Tuple[bytes, bytes],
    new_symbol: bytes,
) -> Tuple[bytes, ...]:
    """
    Return a copy of `symbols` in which each occurrence of `pair` is replaced
    by `new_symbol`.

    The scan runs left to right and matches do not overlap: once two symbols
    have been merged, the second one cannot start another match.
    """
    first, second = pair
    merged: List[bytes] = []
    total = len(symbols)
    skip_next = False  # True when the current symbol was consumed by the previous merge
    for idx, sym in enumerate(symbols):
        if skip_next:
            skip_next = False
            continue
        if sym == first and idx + 1 < total and symbols[idx + 1] == second:
            merged.append(new_symbol)
            skip_next = True
        else:
            merged.append(sym)
    return tuple(merged)

In [15]:
def apply_merge(
    vocab: Dict[Tuple[bytes, ...], int],
    pair: Tuple[bytes, bytes],
) -> Tuple[Dict[Tuple[bytes, ...], int], bytes]:
    """
    Merge `pair` into a new symbol (their byte concatenation) everywhere.

    Args:
        vocab: mapping from a word (tuple of byte symbols) to its frequency.
        pair: the (left, right) symbols to merge.

    Returns:
        A tuple `(new_vocab, new_symbol)`: a new dict with the pair merged in
        every word, and the merged symbol itself. The input `vocab` is not
        modified.
    """
    a, b = pair
    new_symbol = a + b
    new_vocab: Dict[Tuple[bytes, ...], int] = {}
    for word, freq in vocab.items():
        new_word = merge_pair_in_word(word, pair, new_symbol)
        # Accumulate instead of assigning: if two input keys become the same
        # symbol tuple after the merge (e.g. the vocab already held both a
        # merged and an unmerged segmentation), their frequencies must add up
        # rather than the later one overwriting the earlier one.
        new_vocab[new_word] = new_vocab.get(new_word, 0) + freq
    return new_vocab, new_symbol

In [16]:
def bpe_train(
    vocab: Dict[Tuple[bytes, ...], int],
    num_merges: int,
) -> Tuple[Dict[Tuple[bytes, ...], int], List[Tuple[Tuple[bytes, bytes], bytes]]]:
    """
    Run BPE for up to `num_merges` steps.

    Each step counts adjacent pairs, picks one with `choose_pair`, and merges
    it throughout the vocab.

    Args:
        vocab: mapping from a word (tuple of byte symbols) to its frequency.
        num_merges: maximum number of merges to learn.

    Returns:
        The final vocab and the list of merges as `(pair, new_symbol)` entries,
        in the order they were learned. The list is shorter than `num_merges`
        if the vocab runs out of adjacent pairs first.
    """
    merges: List[Tuple[Tuple[bytes, bytes], bytes]] = []
    for _ in range(num_merges):
        stats = get_stats(vocab)
        if not stats:
            # Every word is a single symbol (or the vocab is empty): nothing
            # left to merge, and choose_pair would fail on empty stats.
            break
        pair = choose_pair(stats)
        vocab, new_symbol = apply_merge(vocab, pair)
        merges.append((pair, new_symbol))
    return vocab, merges

In [17]:
# Train: learn 6 merges starting from the byte-level vocab built earlier.
final_vocab, merges = bpe_train(vocab, num_merges=6)

In [18]:
# Show the learned merges in the order they were learned.
# Symbols are raw byte strings and may hold only part of a multi-byte UTF-8
# character, so decode with errors='replace' (undecodable bytes are shown as
# U+FFFD) instead of letting a strict decode raise UnicodeDecodeError.
print("Merges:")
for pair, new_symbol in merges:
    a, b = pair
    print(
        f"({a.decode('utf-8', errors='replace')!r}, {b.decode('utf-8', errors='replace')!r}) "
        f"-> {new_symbol.decode('utf-8', errors='replace')!r}"
    )

Merges:
('s', 't') -> 'st'
('e', 'st') -> 'est'
('o', 'w') -> 'ow'
('l', 'ow') -> 'low'
('w', 'est') -> 'west'
('n', 'e') -> 'ne'


In [20]:
def bpe_encode(word: str, merges: List[Tuple[Tuple[bytes, bytes], bytes]]) -> Tuple[bytes, ...]:
    """
    Tokenize `word` using previously learned merges.

    The word is first split into single-byte symbols; every merge is then
    applied in turn, in the order given by `merges`.

    Returns:
        The resulting symbols as a tuple of bytes objects.
    """
    tokens = split_bytes(word)
    for merge_rule in merges:
        pair, merged_symbol = merge_rule
        tokens = merge_pair_in_word(tokens, pair, merged_symbol)
    return tokens

In [23]:
# Tokenize a sample word with the learned merges.
new = "newest"
encoded_new = bpe_encode(new, merges)
# Tokens are raw byte strings; a token may hold only part of a multi-byte
# UTF-8 character (try a word such as "café"), so decode with
# errors="replace" rather than letting a strict decode raise.
print(f"Tokenization of '{new}':", [s.decode("utf-8", errors="replace") for s in encoded_new])

Tokenization of 'newest': ['ne', 'west']
