In [5]:
# Read the sample text explicitly as UTF-8: it contains emoji and other
# non-ASCII characters, and open()'s default encoding is platform-dependent.
with open("text.txt", "r", encoding="utf-8") as file:
    text = file.read()

# Byte-level tokens: iterating a bytes object already yields ints in 0..255,
# so one character may become several tokens (hence tokens length > text length).
tokens = list(text.encode("utf-8"))

print("text length:", len(text))
print("tokens length:", len(tokens))
print(tokens)


text length: 533
tokens length: 616
[239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131, 239, 189, 143, 239, 189, 132, 239, 189, 133, 33, 32, 240, 159, 133, 164, 240, 159, 133, 157, 240, 159, 133, 152, 240, 159, 133, 146, 240, 159, 133, 158, 240, 159, 133, 147, 240, 159, 133, 148, 226, 128, 189, 32, 240, 159, 135, 186, 226, 128, 140, 240, 159, 135, 179, 226, 128, 140, 240, 159, 135, 174, 226, 128, 140, 240, 159, 135, 168, 226, 128, 140, 240, 159, 135, 180, 226, 128, 140, 240, 159, 135, 169, 226, 128, 140, 240, 159, 135, 170, 33, 32, 240, 159, 152, 132, 32, 84, 104, 101, 32, 118, 101, 114, 121, 32, 110, 97, 109, 101, 32, 115, 116, 114, 105, 107, 101, 115, 32, 102, 101, 97, 114, 32, 97, 110, 100, 32, 97, 119, 101, 32, 105, 110, 116, 111, 32, 116, 104, 101, 32, 104, 101, 97, 114, 116, 115, 32, 111, 102, 32, 112, 114, 111, 103, 114, 97, 109, 109, 101, 114, 115, 32, 119, 111, 114, 108, 100, 119, 105, 100, 101, 46, 32, 87, 101, 32, 97, 108, 108, 32, 107, 110, 111, 119, 32, 119, 101, 

In [24]:
from collections import defaultdict

def get_stats(tokens):
    """Count how often each adjacent pair of tokens occurs.

    Args:
        tokens: sliceable sequence of token ids.
    Returns:
        A plain dict mapping (left, right) pairs to their counts, with keys in
        order of first occurrence (max() relies on this for tie-breaking).
        An empty dict is returned for sequences shorter than two.
    """
    counts = {}
    for left, right in zip(tokens, tokens[1:]):
        key = (left, right)
        counts[key] = counts.get(key, 0) + 1
    return counts

# Pair frequencies over the raw byte tokens of the sample text.
stats = get_stats(tokens=tokens)

In [27]:
from pprint import pprint

def print_top_pairs(stats, n=5):
    """Pretty-print the `n` most frequent pairs as (count, pair) tuples.

    Entries are ordered by count descending; equal counts are ordered by the
    pair itself, descending.
    """
    ranked = [(count, pair) for pair, count in stats.items()]
    ranked.sort(reverse=True)
    pprint(ranked[:n])

# Show the five most frequent byte pairs of the sample text.
print_top_pairs(stats, n=5)

[(20, (101, 32)),
 (15, (240, 159)),
 (12, (226, 128)),
 (12, (105, 110)),
 (10, (115, 32))]


In [28]:
# Bytes below 128 are plain ASCII, so chr() shows the real characters: 'e' then a space.
tuple(map(chr, (101, 32)))

('e', ' ')

In [29]:
# chr() treats each value as a Unicode code point (Latin-1 range here), not as a
# UTF-8 byte. 240 (0xF0) is the lead byte of a 4-byte UTF-8 sequence and 159 a
# continuation byte, so these characters are not meaningful on their own.
tuple(chr(byte) for byte in (240, 159))

('ð', '\x9f')

In [30]:
# Most frequent adjacent pair; on ties max() keeps the first pair encountered.
top_pair = max(stats.keys(), key=lambda candidate: stats[candidate])
top_pair

(101, 32)

In [33]:
def merge(ids, pair, idx):
    """Return a new list in which each occurrence of `pair` in `ids` is replaced by `idx`.

    The scan goes left to right and skips both elements after a match, so
    overlapping occurrences are replaced greedily from the left
    (e.g. (5, 5) in [5, 5, 5] gives [idx, 5]). `ids` itself is not modified.

    Args:
        ids: sequence of token ids to scan.
        pair: (left id, right id) to replace.
        idx: id written in place of each matched pair.
    Returns:
        A new list of ids.
    """
    merged = []
    pos = 0
    total = len(ids)
    while pos < total:
        is_match = (
            pos + 1 < total
            and ids[pos] == pair[0]
            and ids[pos + 1] == pair[1]
        )
        if is_match:
            merged.append(idx)
            pos += 2
        else:
            merged.append(ids[pos])
            pos += 1
    return merged

# Sanity-check merge on a toy list before applying it to the real tokens.
assert merge([1, 2, 3, 4, 5], (1, 2), 6) == [6, 3, 4, 5]
# 256 is the first id past the 0..255 byte range, so it cannot collide with a raw byte.
tokens2 = merge(ids=tokens, pair=top_pair, idx=256)

In [35]:
# Length after vs. before the first merge, then the new pair ranking.
print(*map(len, (tokens2, tokens)))
stats2 = get_stats(tokens2)
print_top_pairs(stats2)

596 616
[(15, (240, 159)),
 (12, (226, 128)),
 (12, (105, 110)),
 (10, (115, 32)),
 (10, (97, 110))]


In [55]:
# Larger training corpus, read explicitly as UTF-8: it contains non-ASCII
# characters, and open()'s default encoding is platform-dependent.
# NOTE: this rebinds `text` and `tokens`, replacing the sample-text values above.
with open("long_text.txt", "r", encoding="utf-8") as file:
    text = file.read()

# Byte-level tokens: iterating a bytes object already yields ints in 0..255.
tokens = list(text.encode("utf-8"))

print("text length:", len(text))
print("tokens length:", len(tokens))

text length: 24141
tokens length: 25433


In [56]:
vocab_size = 276  # desired final vocab size: 256 raw byte values + learned merges
num_merges = vocab_size - 256
ids = list(tokens)  # copy, so the raw byte tokens stay intact

merges = {}  # (id, id) -> new id, in the order the merges were learned
for i in range(num_merges):
    stats = get_stats(ids)
    if not stats:
        # Fewer than two ids remain, so there is no pair left to merge;
        # max() on an empty dict would raise ValueError.
        print(f"no pairs left to merge; stopping after {i} merges")
        break
    # Greedily merge the most frequent pair (ties: first pair encountered).
    pair = max(stats, key=stats.get)
    idx = 256 + i
    print(f"merging {pair} into {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx

merging (101, 32) into 256
merging (105, 110) into 257
merging (115, 32) into 258
merging (116, 104) into 259
merging (101, 114) into 260
merging (99, 111) into 261
merging (116, 32) into 262
merging (226, 128) into 263
merging (44, 32) into 264
merging (97, 110) into 265
merging (111, 114) into 266
merging (100, 32) into 267
merging (97, 114) into 268
merging (101, 110) into 269
merging (257, 103) into 270
merging (261, 100) into 271
merging (32, 32) into 272
merging (121, 32) into 273
merging (97, 108) into 274
merging (111, 110) into 275


In [57]:
# Compare the raw byte count with the merged id count.
print(f"tokens length: {len(tokens)}", f"ids length: {len(ids)}", sep="\n")
print(f"compression ratio: {len(tokens) / len(ids):.2f}")

tokens length: 25433
ids length: 20248
compression ratio: 1.26


In [58]:
# build inverse merge map
vocab = {v:k for k,v in merges.items()}
pprint(list(vocab.items())[:5])

[(256, (101, 32)),
 (257, (105, 110)),
 (258, (115, 32)),
 (259, (116, 104)),
 (260, (101, 114))]


In [71]:
def decode(ids, inverse_merges=None):
    """Expand token ids back into a flat list of raw byte values.

    Ids below 256 are raw bytes and are kept as-is. Any other id is looked up in
    the id -> (left, right) mapping and both halves are expanded recursively.
    The result is suitable for `bytes(...)`.

    Args:
        ids: iterable of token ids.
        inverse_merges: optional id -> pair mapping. If omitted, the notebook's
            global `vocab` is used, which preserves the original behaviour.
    Returns:
        A list of ints in 0..255 (assuming the mapping bottoms out in bytes).
    Raises:
        KeyError: if an id >= 256 is missing from the mapping.
    """
    table = vocab if inverse_merges is None else inverse_merges
    out = []
    for token_id in ids:
        if token_id < 256:
            out.append(token_id)
        else:
            # Pass the same table down so nested expansions stay consistent.
            out.extend(decode(table[token_id], table))
    return out

decoded = decode(ids)
# Round trip: expanding every merged id back to bytes must reproduce the text exactly.
assert bytes(decoded).decode("utf-8") == text