In [1]:
from datasets import load_dataset
from collections import Counter
from collections import defaultdict, OrderedDict
from random import randint
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = load_dataset('swag', 'regular', split='train[:1000]')
corpus = dataset['ending0']

In [3]:
words_freq = Counter(re.sub(r"([.,])", r" \1 ", ' '.join(corpus)).split())
vocab = []

for word in words_freq:
    if word[0] not in vocab:
        vocab.append(word[0])
    
    for c in word[1:]:
        if not f"##{c}" in vocab:
            vocab.append(f"##{c}")

vocab_itos = dict(zip(range(len(vocab)), vocab))
vocab_stoi = dict(zip(vocab, range(len(vocab))))

print("vocab:", sorted(vocab))
print("vocab_size: ", len(vocab))

splits = {
    word: [f"##{item}" if i!=0 else item for i, item in enumerate(list(word))]
    for word in words_freq
}

def make_pair(l: list):
    """
    creates pairs for the given list
    """
    return [(l[i], l[i+1]) for i in range(len(l)-1)]

def compute_score(pairs_count_, letter_counts_):
    scores = defaultdict(float)

    for k, v in pairs_count_.items():
        letter1_count = letter_counts_[k[0]]
        letter2_count = letter_counts_[k[1]]
        pair_count = v
        score = pair_count / (letter1_count * letter2_count)
        scores[k] = score
    return scores

def merge_pair(a, b):
    new_splits = defaultdict(list)
    new_tokens = set()

    for word, split in splits.items():
        new_split = split
        new_splits[word] = new_split
        for i in range(len(split)-1):
            if split[i] == a and split[i+1] == b:
                merge = a + b[2:] if b[:2] == "##" else a + b
                new_split = split[:i] + [merge] + split[i+2:]
                new_tokens.add(merge)
        new_splits[word] = new_split
    return (new_splits, new_tokens)

def encode_word(word):
    tokens = []
    found = True
    while word != "##" and found:
        found = False
        for i in range(len(word), -1, -1):
            if word[:i] in vocab_stoi:
                tokens.append(word[:i])
                word = f"##{word[i:]}"
                found = True
                break
        if not found:
            tokens.append("UNK")
    return tokens

vocab: ['##!', "##'", '##7', '##:', '##;', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##j', '##k', '##l', '##m', '##n', '##o', '##p', '##q', '##r', '##s', '##t', '##u', '##v', '##w', '##x', '##y', '##z', "'", ',', '-', '.', '1', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z']
vocab_size:  61


In [4]:
i = 0
max_iters = 1000

while True:
    letter_counts = Counter(sum([splits[w] * c for w, c in words_freq.items()], []))
    pair_counts = Counter(sum([make_pair(splits[w]) * c for w, c in words_freq.items()], []))

    scores = compute_score(pair_counts, letter_counts)
    if not scores or i == max_iters:
        break
    best_pair = max(scores.items(), key=lambda item: item[1])

    new_splits, new_tokens = merge_pair(*best_pair[0])
    splits = new_splits
    vocab.extend(new_tokens)
    i += 1

print(f"Total iterations: {i}")

Total iterations: 1000


In [5]:
vocab_itos = dict(zip(range(len(vocab)), vocab))
vocab_stoi = dict(zip(vocab, range(len(vocab))))

In [6]:
len(vocab)

1061

In [7]:
vocab

['p',
 '##a',
 '##s',
 '##e',
 'b',
 '##y',
 'w',
 '##l',
 '##k',
 '##i',
 '##n',
 '##g',
 'd',
 '##o',
 '##w',
 't',
 '##h',
 's',
 '##t',
 '##r',
 'i',
 '##u',
 '##m',
 '.',
 'a',
 '##d',
 'c',
 '##b',
 'o',
 'l',
 '##f',
 'e',
 '##c',
 'q',
 '##p',
 'f',
 'm',
 '##x',
 'h',
 'v',
 'g',
 '##z',
 "##'",
 'u',
 ',',
 '##v',
 'r',
 'k',
 '-',
 'j',
 'n',
 'z',
 '##q',
 'y',
 '##j',
 '##;',
 '##:',
 "'",
 '##!',
 '1',
 '##7',
 '17',
 'ex',
 '##bj',
 '##p:',
 'obj',
 'qu',
 '##qu',
 'up',
 'exp',
 'ju',
 'jum',
 'jump',
 'equ',
 'of',
 'off',
 'ov',
 '##ff',
 '##ck',
 '##ap:',
 'map:',
 'qui',
 'quick',
 'quickl',
 'quickly',
 'equi',
 'equip',
 'equipm',
 'ey',
 'ev',
 "n'",
 "n't",
 '##s;',
 'us;',
 "'s",
 '##s!',
 'expl',
 "##'s",
 'th',
 '##ch',
 '##dj',
 '##udj',
 'adj',
 '##udja',
 '##oudja',
 '##boudja',
 '##aboudja',
 '##raboudja',
 '##araboudja',
 'karaboudja',
 'adjo',
 'adjoi',
 'karaboudjan',
 'adjoin',
 'adjoini',
 'adjoinin',
 'adjoining',
 'zo',
 'zu',
 'zum',
 'zumb',
 'zu

In [8]:
# let's encode some random words
start = randint(0, len(corpus)-10)
end = start + 10
words = re.sub(r"([.,])", r" \1 ", ' '.join(corpus[start:end])).split()
max_word_len = len(max(words, key=len))

for word in words:
    print(f"{word:<{max_word_len}}: {encode_word(word)}")

perform      : ['p', '##e', '##r', '##f', '##o', '##r', '##m']
cartwheels   : ['cartwh', '##e', '##e', '##l', '##s']
,            : [',']
then         : ['th', '##e', '##n']
fly          : ['f', '##l', '##y']
in           : ['in']
the          : ['th', '##e']
water        : ['wat', '##e', '##r']
and          : ['and']
shoot        : ['sh', '##o', '##o', '##t']
another      : ['a', '##n', '##o', '##th', '##e', '##r']
.            : ['.']
takes        : ['tak', '##e', '##s']
off          : ['off']
another      : ['a', '##n', '##o', '##th', '##e', '##r']
mixture      : ['mixtur', '##e']
and          : ['and']
flashes      : ['flash', '##e', '##s']
radar        : ['r', '##a', '##d', '##a', '##r']
at           : ['a', '##t']
the          : ['th', '##e']
end          : ['end']
of           : ['of']
the          : ['th', '##e']
glass        : ['g', '##l', '##a', '##s', '##s']
.            : ['.']
take         : ['tak', '##e']
their        : ['th', '##e', '##i', '##r']
turns        : ['turns']