In [15]:
import re

In [16]:
word_split_pat = re.compile(r"\W+")

In [17]:
text: str
with open("./corpora/sherlock.txt", "r") as f:
    text = f.read().lower()

In [18]:
def split_words(s: str) -> list[str]:
    return re.split(word_split_pat, s)


words = split_words(text)
words[:10]

['', 'a', 'study', 'in', 'scarlet', 'arthur', 'conan', 'doyle', 'table', 'of']

In [19]:
def split_chars(words: list[str]) -> list[list[str]]:
    return list(filter(lambda xs: len(xs) > 0, [[c for c in w] for w in words]))


chars = split_chars(words)
chars[:10]

[['a'],
 ['s', 't', 'u', 'd', 'y'],
 ['i', 'n'],
 ['s', 'c', 'a', 'r', 'l', 'e', 't'],
 ['a', 'r', 't', 'h', 'u', 'r'],
 ['c', 'o', 'n', 'a', 'n'],
 ['d', 'o', 'y', 'l', 'e'],
 ['t', 'a', 'b', 'l', 'e'],
 ['o', 'f'],
 ['c', 'o', 'n', 't', 'e', 'n', 't', 's']]

In [20]:
def reduce(tokens: list[list[str]]) -> tuple[str, list[list[str]]]:
    cache: dict[str, int] = {}

    for word in tokens:
        for fst, scnd in zip(word, word[1:]):
            pair = fst + " " + scnd
            if pair in cache.keys():
                continue
            pair_pat = re.compile(fst + scnd)
            freq = 0
            for wordlist in tokens:
                word = "".join(wordlist)
                freq += len(re.findall(pair_pat, word))

            cache[pair] = freq

    space_seperated_replacement_token, _ = max(
        cache.items(), key=lambda x: x[1]
    )  # still contains a whitespace
    fst, scnd = space_seperated_replacement_token.split()
    replacement_token = space_seperated_replacement_token.replace(" ", "")

    new_tokens = []

    for w in tokens:
        if fst in w and scnd in w:
            if w.index(fst) == w.index(scnd) - 1:
                new_word = (
                    w[: w.index(fst)] + [replacement_token] + w[w.index(scnd) + 1 :]
                )
                new_tokens.append(new_word)
                continue
        new_tokens.append(w)

    return replacement_token, new_tokens

In [21]:
def bpe(text: str, nsteps: int) -> tuple[set[str], list[list[str]]]:
    vocab = set()
    words = split_words(text)
    tokens = split_chars(words)

    for w in tokens:
        for c in w:
            vocab.add(c)

    for _ in range(nsteps):
        new_token, tokens = reduce(tokens)
        vocab.add(new_token)

    return vocab, tokens

In [22]:
vocab, tokens = bpe(text, 2)

In [23]:
sorted([t for t in vocab], key=lambda s: len(s), reverse=True)

['th',
 'q',
 '6',
 'u',
 'c',
 'l',
 'o',
 'g',
 's',
 '9',
 '0',
 't',
 'h',
 '7',
 '2',
 'v',
 '3',
 'x',
 'n',
 'r',
 'e',
 '8',
 'y',
 'b',
 'j',
 'd',
 'i',
 'ã',
 'k',
 '4',
 '5',
 'w',
 'p',
 'z',
 'f',
 'a',
 '1',
 'm']