In [20]:
# Unicode: every character is identified by an integer code point.
h = 'Я хочу учить французский!'
print(list(map(ord, h)))
# Unicode defines ~150k code points and keeps changing with every release,
# so a fixed, stable text representation is preferable to raw code points.

[1071, 32, 1093, 1086, 1095, 1091, 32, 1091, 1095, 1080, 1090, 1100, 32, 1092, 1088, 1072, 1085, 1094, 1091, 1079, 1089, 1082, 1080, 1081, 33]


In [21]:
# An encoding maps each Unicode code point (a unique integer) to a sequence
# of 1 to 4 bytes. However much the Unicode standard grows, the UTF-8 byte
# format itself stays the same.
byteobject = h.encode("utf-8")
integers = list(byteobject)
print(integers)
# The vocabulary is now small (only 256 byte values), but the price is a longer
# representation: each Cyrillic letter above becomes 2 bytes instead of 1 code point.
# With a fixed context window, longer sequences mean less text fits into the
# model's context, and attention becomes more expensive as sequence length grows.

[208, 175, 32, 209, 133, 208, 190, 209, 135, 209, 131, 32, 209, 131, 209, 135, 208, 184, 209, 130, 209, 140, 32, 209, 132, 209, 128, 208, 176, 208, 189, 209, 134, 209, 131, 208, 183, 209, 129, 208, 186, 208, 184, 208, 185, 33]


In [22]:
# Goal: support a larger vocabulary whose size we can tune as a hyperparameter,
# while still building on top of the raw UTF-8 byte representation.
string = 'Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to “support Unicode” in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don’t blame programmers for still finding the whole thing mysterious, even 30 years after Unicode’s inception.'
tokens = [byte for byte in string.encode('utf-8')]

def get_stats(ids: list) -> dict:
    """
    Count occurrences of every pair of neighbouring ids in ``ids``.

    Returns a dict mapping ``(left, right)`` -> number of times that pair
    appears consecutively. Keys are ordered by first appearance; fewer than
    two ids yields an empty dict.
    """
    counts = {}
    for pos in range(len(ids) - 1):
        pair = (ids[pos], ids[pos + 1])
        if pair in counts:
            counts[pair] += 1
        else:
            counts[pair] = 1
    return counts

stats = get_stats(tokens)
# Show (count, pair) tuples, most frequent adjacent pair first.
print(sorted([(count, pair) for pair, count in stats.items()], reverse=True))

[(20, (101, 32)), (15, (240, 159)), (12, (226, 128)), (12, (105, 110)), (10, (115, 32)), (10, (97, 110)), (10, (32, 97)), (9, (32, 116)), (8, (116, 104)), (7, (159, 135)), (7, (159, 133)), (7, (97, 114)), (6, (239, 189)), (6, (140, 240)), (6, (128, 140)), (6, (116, 32)), (6, (114, 32)), (6, (111, 114)), (6, (110, 103)), (6, (110, 100)), (6, (109, 101)), (6, (104, 101)), (6, (101, 114)), (6, (32, 105)), (5, (117, 115)), (5, (115, 116)), (5, (110, 32)), (5, (100, 101)), (5, (44, 32)), (5, (32, 115)), (4, (116, 105)), (4, (116, 101)), (4, (115, 44)), (4, (114, 105)), (4, (111, 117)), (4, (111, 100)), (4, (110, 116)), (4, (110, 105)), (4, (105, 99)), (4, (104, 97)), (4, (103, 32)), (4, (101, 97)), (4, (100, 32)), (4, (99, 111)), (4, (97, 109)), (4, (85, 110)), (4, (32, 119)), (4, (32, 111)), (4, (32, 102)), (4, (32, 85)), (3, (118, 101)), (3, (116, 115)), (3, (116, 114)), (3, (116, 111)), (3, (114, 116)), (3, (114, 115)), (3, (114, 101)), (3, (111, 102)), (3, (111, 32)), (3, (108, 108)), (

In [23]:
def merge(ids: list, target_pair: tuple, new_token: int) -> list:
    """
    Replace occurrences of ``target_pair`` in ``ids`` with ``new_token``.

    The scan goes left to right and matches never overlap (``[5, 5, 5]``
    merging ``(5, 5)`` gives ``[new, 5]``). A new list is returned; ``ids``
    is left unmodified.
    """
    merged = []
    skip_next = False
    for pos, token in enumerate(ids):
        if skip_next:
            # This id was already consumed as the right half of a merged pair.
            skip_next = False
            continue
        if pos + 1 < len(ids) and (token, ids[pos + 1]) == target_pair:
            merged.append(new_token)
            skip_next = True
        else:
            merged.append(token)
    return merged

# Merge the single most frequent pair into a fresh id (256 = first id past the byte range).
res = merge(tokens, max(stats, key=lambda pair: stats[pair]), 256)
len(res)

596

In [24]:
# Load the full article that serves as training text for the tokenizer.
with open('text.txt', mode='r', encoding='utf-8') as file:
    text = file.read()

tokens = [*text.encode('utf-8')]
print('Length in characters:', len(text))
print('Length in bytes:     ', len(tokens))

Length in characters: 22191
Length in bytes:      23431


In [29]:
vocab_size = 276  # desired final vocabulary size: 256 raw bytes + learned merges
num_merges = vocab_size - 256
ids = list(tokens)  # shallow copy so the original byte tokens stay untouched

merges = {}  # (int, int) -> int: adjacent id pair -> new merged id
for i in range(num_merges):
    stats = get_stats(ids)
    if not stats:
        # Fewer than two ids left, so there is nothing to merge; stop early
        # instead of letting max() raise ValueError on an empty dict.
        break
    pair = max(stats, key=stats.get)  # most frequent adjacent pair
    idx = 256 + i  # new ids are assigned right after the 256 byte values
    print(f'pair {pair} gets merged to --> {idx}')
    ids = merge(ids, pair, idx)
    merges[pair] = idx

# Guard the ratio against empty input (len(ids) == 0).
print(f'{"-"*55}\nbefore: {len(tokens)} | after: {len(ids)} | compression ratio: {(len(tokens) / len(ids) if ids else 1.0):.2f}x')

pair (101, 32) gets merged to --> 256
pair (105, 110) gets merged to --> 257
pair (115, 32) gets merged to --> 258
pair (116, 104) gets merged to --> 259
pair (101, 114) gets merged to --> 260
pair (116, 32) gets merged to --> 261
pair (99, 111) gets merged to --> 262
pair (226, 128) gets merged to --> 263
pair (44, 32) gets merged to --> 264
pair (97, 110) gets merged to --> 265
pair (111, 114) gets merged to --> 266
pair (100, 32) gets merged to --> 267
pair (97, 114) gets merged to --> 268
pair (101, 110) gets merged to --> 269
pair (257, 103) gets merged to --> 270
pair (262, 100) gets merged to --> 271
pair (121, 32) gets merged to --> 272
pair (259, 256) gets merged to --> 273
pair (97, 108) gets merged to --> 274
pair (111, 110) gets merged to --> 275
-------------------------------------------------------
before: 23431 | after: 18487 | compression ratio: 1.27x


In [65]:
# Here's my implementation, not the prettiest in the world, but it works
def my_decode(sequence: list, vocab_size: int, decodings: dict) -> str:
    """
    Decode a sequence of BPE token ids back into a Unicode string.

    Every merged id (>= 256) is repeatedly expanded into the pair of ids it
    was created from until only raw byte values (0..255) remain; those bytes
    are then decoded as UTF-8.

    Args:
        sequence: token ids to decode (not modified).
        vocab_size: vocabulary size used during training. Kept for backward
            compatibility only. The ids to expand are taken from
            ``decodings`` itself, so a mismatched value can no longer leave
            merged ids undecoded.
        decodings: merge table mapping ``(int, int)`` pair -> merged id.

    Returns:
        The decoded text.

    Raises:
        ValueError: if an id >= 256 has no entry in ``decodings``
            (``bytes()`` only accepts values in range(0, 256)).
        UnicodeDecodeError: if the expanded bytes are not valid UTF-8.
    """
    # Invert the merge table; merged ids are unique, so nothing is lost.
    inverse = {merged: pair for pair, merged in decodings.items()}  # int -> (int, int)

    # A merged id is always built from smaller, previously existing ids, so
    # expanding from the highest id downwards guarantees that any child id
    # produced by an expansion is itself expanded in a later iteration.
    for token in sorted(inverse, reverse=True):
        expanded = []
        for elem in sequence:
            if elem == token:
                expanded.extend(inverse[token])
            else:
                expanded.append(elem)
        sequence = expanded
    # integers -> bytes -> unicode characters
    return bytes(sequence).decode('utf-8')

res = my_decode(sequence=ids, vocab_size=vocab_size, decodings=merges)
print(text == res)  # round-trip check: decoding the merged ids reproduces the text

True


In [76]:
# Token id -> the raw bytes it stands for. Ids 0..255 are single bytes;
# every merged id is the concatenation of its two children's bytes.
vocab = {idx: bytes([idx]) for idx in range(256)}
# Visit merges in increasing id order so both children of a pair are already
# in `vocab`, even if `merges` was populated in a different order.
for (p0, p1), idx in sorted(merges.items(), key=lambda item: item[1]):
    vocab[idx] = vocab[p0] + vocab[p1]  # concatenation of byte objects

def decode(ids: list[int], vocab_map: "dict[int, bytes] | None" = None) -> str:
    """
    Decode token ids into text by looking up each id's bytes.

    Args:
        ids: token ids to decode.
        vocab_map: mapping token id -> the bytes it stands for. Defaults to
            the notebook-level ``vocab`` table built from the learned merges.

    Returns:
        The decoded string. Byte sequences that are not valid UTF-8 (for
        example a multi-byte character cut off at the end of ``ids``) are
        replaced with U+FFFD instead of raising.

    Raises:
        KeyError: if an id is missing from the vocabulary.
    """
    table = vocab if vocab_map is None else vocab_map
    raw = b"".join(table[idx] for idx in ids)
    return raw.decode('utf-8', errors='replace')

res2 = decode(ids)
print(text == res2)  # round trip through the vocab table should reproduce the text

True
