In [None]:
from datasets import load_dataset
import os

os.makedirs("data", exist_ok=True)

data = load_dataset("wikitext", 'wikitext-103-v1', split="train")
data_1m = ""
with open("data/wikitext_1m.txt", "w") as f:
    for line in data["text"][:1_000_000]:
        line = line.replace("<unk>", "").strip()
        if line:
            f.write(line+"\n")
            data_1m += line+"\n"


In [3]:
from tokenizers import (
    models,
    pre_tokenizers,
    trainers,
    Tokenizer,
)
import collections
dists = {}

def get_bpe_unigram(vocab_size):
    print(vocab_size, "training")
    tokenizer = Tokenizer(models.BPE())
    tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
    trainer = trainers.BpeTrainer(vocab_size=vocab_size)
    tokenizer.train(["data/wikitext_1m.txt"], trainer=trainer)
    print(vocab_size, "encoding")
    encoding = tokenizer.encode(data_1m)
    return collections.Counter(encoding.tokens)


dists["bpe_8k"] = get_bpe_unigram(8_000)
dists["bpe_28k"] = get_bpe_unigram(28_000)

8000 training



8000 encoding
28000 training



28000 encoding


In [None]:
from transformers import AutoTokenizer

gpt2tokenizer = AutoTokenizer.from_pretrained("gpt2")
dists["gpt2"] = collections.Counter(gpt2tokenizer.tokenize(data_1m))
new_collection = collections.Counter()
for k,v in dists["gpt2"].items():
    new_collection[k.replace("Ġ", "")] = v
dists["gpt2"] = new_collection

In [39]:
def vocab_to_unmerges(freqs):
    # TODO: this is observed vocab which might not be the real one
    vocab = set(freqs.keys())
    unmerges = {}
    for word in vocab:
        for i in range(len(word)):
            if word[:i] in vocab and word[i:] in vocab:
                unmerges[word] = (word[:i], word[i:])
                # go to next word
                break
    return unmerges

In [40]:
import copy
import random
import numpy as np

def renyi_entropy(P, alpha):
    scale = 1 / (1 - alpha)

    return scale * np.log2(np.sum([
        prob**alpha
        for prob in P
    ]))

def renyi_eff(P, alpha):
    return renyi_entropy(P, alpha)/np.log(len(P))

def shannon_entropy(P):
    P = np.array(P)
    return -np.sum(P * np.log2(P))

def shannon_eff(P):
    return shannon_entropy(P)/np.log(len(P))

def table_line(P, extra=[]):
    out = extra
    out.append(f"{shannon_entropy(P):.2f}")
    out.append(f"{renyi_entropy(P, 0.5):.2f}")
    out.append(f"{renyi_entropy(P, 3):.2f}")
    out.append(f"{shannon_eff(P):.2f}")
    out.append(f"{renyi_eff(P, 0.5):.2f}")
    out.append(f"{renyi_eff(P, 3):.2f}")
    return "& " + " & ".join(out) + r"\\"

def freqs_to_p(freqs):
    total = sum(freqs.values())
    return [v/total for k, v in freqs.most_common()]

def drop_bpe(freqs, unmerges, N, k):
    freqs = copy.deepcopy(freqs)
    # get top N
    words_top_N = [k for k, v in freqs.most_common() if len(k) > 1][:N]

    # sample k words
    dead_tokens = random.sample(words_top_N, k=k)
    for token in dead_tokens:
        # add the old frequency to the individual characters
        if token in unmerges:
            # remove existing token
            token_freq = freqs.pop(token)
            for c in unmerges[token]:
                freqs[c] += token_freq
        else:
            pass

    return freqs_to_p(freqs)


# NOTE for future self: the fact that the efficiency goes down is likely
# caused by the unmerges vocabulary being made of observed vocabulary and not
# the real unmerges (?)

random.seed(0)
for tokenizer, freqs in dists.items():
    unmerges = vocab_to_unmerges(freqs)
    print(tokenizer)
    print(table_line(freqs_to_p(freqs), extra=["", ""]))
    print(table_line(drop_bpe(freqs, unmerges, 2_500, 500), extra=["2500", "500"]))
    print(table_line(drop_bpe(freqs, unmerges, 2_500, 1000), extra=["2500", "1000"]))
    print(table_line(drop_bpe(freqs, unmerges, None, 500), extra=[r"$\infty$", "500"]))
    print(table_line(drop_bpe(freqs, unmerges, None, 1000), extra=[r"$\infty$", "1000"]))
    print(r"\hdashline", "\n")

bpe_8k
&  &  & 10.16 & 11.82 & 6.17 & 1.13 & 1.32 & 0.69\\
& 2500 & 500 & 9.46 & 11.47 & 5.70 & 1.06 & 1.28 & 0.64\\
& 2500 & 1000 & 9.05 & 11.20 & 5.84 & 1.02 & 1.26 & 0.66\\
& $\infty$ & 500 & 9.96 & 11.65 & 6.26 & 1.12 & 1.31 & 0.70\\
& $\infty$ & 1000 & 9.62 & 11.44 & 6.25 & 1.09 & 1.29 & 0.71\\
\hdashline 

bpe_28k
&  &  & 10.65 & 13.29 & 5.83 & 1.04 & 1.30 & 0.57\\
& 2500 & 500 & 10.43 & 13.14 & 5.99 & 1.02 & 1.29 & 0.59\\
& 2500 & 1000 & 9.31 & 12.77 & 5.21 & 0.91 & 1.25 & 0.51\\
& $\infty$ & 500 & 10.63 & 13.25 & 5.85 & 1.04 & 1.30 & 0.57\\
& $\infty$ & 1000 & 10.62 & 13.21 & 5.88 & 1.04 & 1.30 & 0.58\\
\hdashline 

gpt2
&  &  & 12.51 & 14.16 & 6.59 & 1.18 & 1.34 & 0.62\\
& 2500 & 500 & 11.88 & 14.04 & 5.16 & 1.12 & 1.33 & 0.49\\
& 2500 & 1000 & 11.28 & 13.94 & 4.08 & 1.07 & 1.32 & 0.39\\
& $\infty$ & 500 & 12.41 & 14.14 & 6.52 & 1.17 & 1.34 & 0.62\\
& $\infty$ & 1000 & 12.37 & 14.13 & 6.45 & 1.17 & 1.34 & 0.61\\
\hdashline 

