### Create sample and load pre-token

In [1]:
full_text = "low low low low low lower lower widest widest widest newest newest newest newest newest newest"

freq_pre_token = {}
for p in full_text.split(" "):
    if p not in freq_pre_token: 
        freq_pre_token[p] = 1
    else:
        freq_pre_token[p] += 1

freq_pre_token
print(f"Get key low: 'low': {freq_pre_token['low']}")
freq_pre_token = {tuple(k.encode('utf-8')): v for k, v in freq_pre_token.items()}

print(f' low encode to: {[i for i in "low".encode("utf-8")]}')
print(f"Get key low: 'low': {freq_pre_token[tuple('low'.encode('utf-8'))]}")
freq_pre_token

Get key low: 'low': 5
 low encode to: [108, 111, 119]
Get key low: 'low': 5


{(108, 111, 119): 5,
 (108, 111, 119, 101, 114): 2,
 (119, 105, 100, 101, 115, 116): 3,
 (110, 101, 119, 101, 115, 116): 6}

In [2]:
# create vocab with 256 bytes values and <|endoftext|>
index_2_vocab = {i: bytes([i]) for i in range(256)}
index_2_vocab[256] = '<|endoftext|>'

index_2_vocab[108] + index_2_vocab[111] + index_2_vocab[119]

b'low'

In [3]:
index_2_vocab[108]

b'l'

In [4]:
index_2_vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'

In [5]:
# create linked list of each pre-token to tracking and update pair
class Node:
    def __init__(self, vocab_index, n_count=0):
        self.vocab_index = vocab_index
        self.next = None
        self.prev = None
        self.n_count = n_count

    def __repr__(self):
        return str(index_2_vocab[self.prev.vocab_index] if self.prev else None) + "-" + str(index_2_vocab[self.vocab_index]) + "-" + str(index_2_vocab[self.next.vocab_index] if self.next else None) + " (" + str(self.n_count) + ")"


freq_linked_list = {}

for k, v in freq_pre_token.items():
    nodes = tuple([Node(i, v) for i in k])
    for pre, n in zip(nodes, nodes[1:]):
        pre.next = n
        n.prev = pre
    freq_linked_list[nodes] = v


freq_linked_list

{(None-b'l'-b'o' (5), b'l'-b'o'-b'w' (5), b'o'-b'w'-None (5)): 5,
 (None-b'l'-b'o' (2),
  b'l'-b'o'-b'w' (2),
  b'o'-b'w'-b'e' (2),
  b'w'-b'e'-b'r' (2),
  b'e'-b'r'-None (2)): 2,
 (None-b'w'-b'i' (3),
  b'w'-b'i'-b'd' (3),
  b'i'-b'd'-b'e' (3),
  b'd'-b'e'-b's' (3),
  b'e'-b's'-b't' (3),
  b's'-b't'-None (3)): 3,
 (None-b'n'-b'e' (6),
  b'n'-b'e'-b'w' (6),
  b'e'-b'w'-b'e' (6),
  b'w'-b'e'-b's' (6),
  b'e'-b's'-b't' (6),
  b's'-b't'-None (6)): 6}

In [6]:
import heapq


def update_pair_count(n1, n2):
    pair_bytes = (index_2_vocab[n1.vocab_index], index_2_vocab[n2.vocab_index])
    if pair_bytes not in pair_count:
        pair_count[pair_bytes] = {'n_count': n1.n_count, 'pair_nodes': [(n1, n2)]}
    else:
        pair_count[pair_bytes]['n_count'] += n1.n_count
        pair_count[pair_bytes]['pair_nodes'].append((n1, n2))

def update_pair_version(n1, n2):
    pair_bytes = (index_2_vocab[n1.vocab_index], index_2_vocab[n2.vocab_index])
    if pair_bytes not in pair_version:
        pair_version[pair_bytes] = 1
    else:
        pair_version[pair_bytes] += 1

# init the pair_count

# pair_count is a dict, with key is pair of bytes, and value is the list of pair node.
# pair_version is a dict, with key is pair of bytes, and value is the version of the pair. Used to check out of date pair.
# pair_max_heap is max heap of pair_count, 

pair_version = {}
pair_count = {}
max_heap = []

for ns, v in freq_linked_list.items():
    for n, next_n in zip(ns, ns[1:]):
        update_pair_count(n, next_n)
        update_pair_version(n, next_n)
        pair_bytes = (index_2_vocab[n.vocab_index], index_2_vocab[next_n.vocab_index])
        # update max_heap
        heapq.heappush(max_heap, (-pair_count[pair_bytes]['n_count'], pair_bytes, pair_version[pair_bytes]))

max_heap

[(-9, (b'e', b's'), 2),
 (-9, (b's', b't'), 2),
 (-8, (b'w', b'e'), 2),
 (-7, (b'o', b'w'), 2),
 (-3, (b'e', b's'), 1),
 (-6, (b'n', b'e'), 1),
 (-7, (b'l', b'o'), 2),
 (-5, (b'o', b'w'), 1),
 (-3, (b'd', b'e'), 1),
 (-2, (b'w', b'e'), 1),
 (-3, (b's', b't'), 1),
 (-2, (b'e', b'r'), 1),
 (-5, (b'l', b'o'), 1),
 (-3, (b'w', b'i'), 1),
 (-6, (b'e', b'w'), 1),
 (-3, (b'i', b'd'), 1)]

In [7]:
merged_pair = []

In [8]:
# get max pair
max_pair = heapq.heappop(max_heap)

In [9]:
# merge max pair -> update pair_count -> update max_heap -> update pair_version
max_pair

(-9, (b'e', b's'), 2)

In [10]:
# update vocab
new_vocab_index = len(index_2_vocab)
index_2_vocab[new_vocab_index] = max_pair[1][0] + max_pair[1][1]

In [11]:
# update merge pair
merged_pair.append((max_pair[1][0], max_pair[1][1], max_pair[1][0] + max_pair[1][1]))

In [12]:
# update pair_count
pair_count

list_max_pair = pair_count[max_pair[1]]
for n1, n2 in list_max_pair['pair_nodes']:
    new_node = Node(new_vocab_index, n1.n_count)
    new_node.next = n2.next
    new_node.prev = n1.prev
    print(n1, n2, new_node)
    # 2 new pair nodes (n1.prev, new_node) and (new_node, n2.next)


b'd'-b'e'-b's' (3) b'e'-b's'-b't' (3) b'd'-b'es'-b't' (3)
b'w'-b'e'-b's' (6) b'e'-b's'-b't' (6) b'w'-b'es'-b't' (6)


In [13]:
new_node

b'w'-b'es'-b't' (6)

In [14]:
pair_bytes

(b's', b't')