In [43]:
# Read the whole SMILES corpus into memory as one str.
with open(file='smiles.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [52]:
# Train on the first 1,000,000 characters only, represented as raw UTF-8 byte values (ints 0..255).
tokens = list(text[:1_000_000].encode('utf-8'))

In [53]:
def get_stats(ids):
    """Count how often each adjacent pair of ids occurs.

    Returns a dict mapping (left, right) -> count. Keys appear in the order
    each pair is first seen; a sequence shorter than two gives an empty dict.
    """
    pair_counts = {}
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1
    return pair_counts

def merge(ids, pair, idx):
    """Replace every occurrence of the adjacent pair `pair` in `ids` with `idx`.

    The scan runs left to right and consumes both elements of a match, so
    overlapping occurrences are resolved greedily from the left
    (e.g. merging (5, 5) in [5, 5, 5] gives [idx, 5]). The input is not
    modified; a new list is returned.
    """
    out = []
    total = len(ids)
    pos = 0
    while pos < total:
        if pos + 1 < total and ids[pos] == pair[0] and ids[pos + 1] == pair[1]:
            out.append(idx)
            pos += 2
            continue
        out.append(ids[pos])
        pos += 1
    return out

# --- Train: repeatedly replace the most frequent adjacent pair with a new token id.
vocab_size = 276  # the desired final vocabulary size (256 raw bytes + learned merges)
num_merges = vocab_size - 256
ids = list(tokens)  # copy so we don't destroy the original list

merges = {}  # (int, int) -> int; insertion order == training order
for i in range(num_merges):
    stats = get_stats(ids)
    if not stats:
        # Fewer than two tokens left, so there is no pair to merge
        # (max() on an empty dict would raise ValueError).
        break
    # Most frequent pair; on ties max() keeps the pair seen first.
    pair = max(stats, key=stats.get)
    idx = 256 + i
    print(f"merging {pair} into a new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx

merging (99, 99) into a new token 256
merging (79, 41) into a new token 257
merging (99, 49) into a new token 258
merging (40, 61) into a new token 259
merging (259, 257) into a new token 260
merging (99, 40) into a new token 261
merging (67, 67) into a new token 262
merging (99, 50) into a new token 263
merging (41, 67) into a new token 264
merging (258, 256) into a new token 265
merging (99, 51) into a new token 266
merging (261, 256) into a new token 267
merging (67, 260) into a new token 268
merging (72, 93) into a new token 269
merging (10, 67) into a new token 270
merging (263, 256) into a new token 271
merging (91, 67) into a new token 272
merging (272, 64) into a new token 273
merging (40, 67) into a new token 274
merging (10, 265) into a new token 275


In [87]:
def decode(ids, merge_table=None):
    """Turn a sequence of BPE token ids back into a string.

    Ids 0-255 are raw UTF-8 bytes; every larger id was created by a merge and
    expands to the bytes of its two children. The concatenated bytes are
    decoded as UTF-8 with errors='replace', so a slice of ids that cuts a
    multi-byte character yields U+FFFD instead of raising.

    Args:
        ids: iterable of int token ids.
        merge_table: dict mapping (left_id, right_id) -> new_id. Defaults to
            the global `merges`.

    Returns:
        The decoded str.

    Raises:
        ValueError: if an id is neither a byte value (0-255) nor produced by
            a merge in the table.
    """
    table = merges if merge_table is None else merge_table
    # Build an id -> bytes table once instead of recursing for every token.
    # New ids are assigned as 256, 257, ... during training, so a merge's
    # children always have smaller ids than the merge itself; processing
    # merges in ascending id order guarantees both children are in `vocab`.
    vocab = {byte: bytes([byte]) for byte in range(256)}
    for (left, right), new_id in sorted(table.items(), key=lambda item: item[1]):
        vocab[new_id] = vocab[left] + vocab[right]
    try:
        raw = b"".join(vocab[token] for token in ids)
    except KeyError as exc:
        raise ValueError(f"unknown token id: {exc.args[0]!r}") from None
    return raw.decode('utf-8', errors='replace')
decoded = decode(ids=ids)  # map the merged training ids back to text

In [88]:
upto = 10_000  # only this many leading characters are compared
text[:upto] == decoded[:upto]

True

In [89]:
merges  # learned merges: (left_id, right_id) -> new token id, in training order

{(99, 99): 256,
 (79, 41): 257,
 (99, 49): 258,
 (40, 61): 259,
 (259, 257): 260,
 (99, 40): 261,
 (67, 67): 262,
 (99, 50): 263,
 (41, 67): 264,
 (258, 256): 265,
 (99, 51): 266,
 (261, 256): 267,
 (67, 260): 268,
 (72, 93): 269,
 (10, 67): 270,
 (263, 256): 271,
 (91, 67): 272,
 (272, 64): 273,
 (40, 67): 274,
 (10, 265): 275}

In [90]:
def encode(text, merge_table=None):
    """Convert a string into BPE token ids.

    The text is first turned into its UTF-8 bytes. Each learned merge is then
    applied in training order, replacing every non-overlapping occurrence of
    its pair (scanning left to right) with the merged id. This repeats what
    training did to the corpus.

    Args:
        text: the str to encode.
        merge_table: dict mapping (left_id, right_id) -> new_id. Defaults to
            the global `merges`.

    Returns:
        list[int] of token ids.
    """
    table = merges if merge_table is None else merge_table
    tokens = list(text.encode("utf-8"))
    # New ids are assigned as 256, 257, ... during training, so ascending id
    # order is training order regardless of the dict's insertion order.
    for (left, right), new_id in sorted(table.items(), key=lambda item: item[1]):
        if len(tokens) < 2:
            break  # no adjacent pair left for any remaining merge to combine
        merged = []
        i = 0
        while i < len(tokens):
            if i + 1 < len(tokens) and tokens[i] == left and tokens[i + 1] == right:
                merged.append(new_id)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens

first_20_ids = encode(text[:20])
first_20_ids


[67, 273, 64, 93, 49, 40, 268, 67, 61, 67, 40, 79, 49, 41]

In [91]:
sample = text[:200]
decode(encode(sample)) == sample  # encode -> decode should be lossless

True