In [None]:
def get_stats(ids):
    """Count occurrences of each adjacent pair of ids.

    Returns a dict mapping ``(left, right)`` to its count; keys appear in
    the order each pair is first seen. Sequences shorter than two ids
    yield an empty dict.
    """
    pair_counts = {}
    for key in zip(ids[:-1], ids[1:]):
        pair_counts.setdefault(key, 0)
        pair_counts[key] += 1
    return pair_counts

In [None]:
# Count adjacent pairs in a toy id sequence: (1, 2) and (2, 1) each occur twice.
get_stats([1,2,1,2,1,3])

{(1, 2): 2, (2, 1): 2, (1, 3): 1}

In [None]:
def merge(ids, pair, idx):
    """Return a new list where each occurrence of *pair* in *ids* becomes *idx*.

    The scan is left to right and matches do not overlap: for ``[1, 1, 1]``
    and pair ``(1, 1)`` only the first two ids are merged. *ids* itself is
    not modified.
    """
    left, right = pair[0], pair[1]
    total = len(ids)
    result = []
    pos = 0
    while pos < total:
        has_next = pos + 1 < total
        if has_next and ids[pos] == left and ids[pos + 1] == right:
            result.append(idx)
            pos += 2
            continue
        result.append(ids[pos])
        pos += 1
    return result

In [None]:
# Replace every (1, 2) pair with the new id 4.
merge([1,2,1,2,1,3], (1,2), 4)

[4, 4, 1, 3]

In [None]:
class BasicTokenizer:
    """Byte-level BPE tokenizer: train merge rules on a string, then encode/decode.

    Text is encoded to UTF-8, so ids 0..255 are raw bytes. Each training step
    assigns the next free id (256, 257, ...) to the most frequent adjacent
    pair of ids.
    """

    def __init__(self):
        self.merges = {}  # (int, int) -> int, ids assigned in learning order
        # Start from the 256 single-byte tokens so an untrained tokenizer
        # already round-trips text as a plain byte-level tokenizer.
        self.vocab = {i: bytes([i]) for i in range(256)}  # int -> bytes

    def train(self, text: str, vocab_size: int, verbose=False):
        """Learn ``vocab_size - 256`` merge rules from *text*.

        Args:
            text: training corpus.
            vocab_size: target vocabulary size; must be >= 256
                (256 means no merges are learned).
            verbose: print each merge as it is learned.

        Raises:
            ValueError: if ``vocab_size < 256``.

        Training stops early if the id sequence shrinks below two ids (no
        pairs left), so very short texts may yield fewer merges than asked.
        """
        if vocab_size < 256:
            raise ValueError(f"vocab_size must be >= 256, got {vocab_size}")
        ids = list(text.encode('utf-8'))  # list of ids in range 0..255
        merges = {}
        vocab = {i: bytes([i]) for i in range(256)}
        for idx in range(256, vocab_size):
            stats = get_stats(ids)
            if not stats:
                break  # fewer than two ids remain: nothing left to merge
            pair = max(stats, key=stats.get)
            ids = merge(ids, pair, idx)
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            if verbose:
                # A merged token can be a partial UTF-8 sequence; replace
                # undecodable bytes instead of crashing mid-training.
                token = vocab[idx].decode('utf-8', errors='replace')
                print(f"most frequent pair {pair} {token} merged into {idx}")
        self.merges = merges
        self.vocab = vocab

    def decode(self, ids) -> str:
        """Map token ids back to text.

        Arbitrary id sequences can form invalid UTF-8; such bytes are
        replaced with U+FFFD rather than raising.
        """
        b = b"".join(self.vocab[i] for i in ids)
        return b.decode("utf-8", errors="replace")

    def encode(self, s: str) -> list[int]:
        """Encode *s* to token ids by replaying the learned merges.

        Repeatedly applies the adjacent pair with the lowest merge id (the
        merge learned earliest) until no adjacent pair has a merge rule.
        """
        ids = list(s.encode('utf-8'))
        while len(ids) >= 2:
            stats = get_stats(ids)
            # Pairs without a merge rule rank as +inf, so a real merge always wins.
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # no remaining pair can be merged
            ids = merge(ids, pair, self.merges[pair])
        return ids

In [None]:
# Create a tokenizer instance; it is trained in the cells below.
tok = BasicTokenizer()

In [None]:
# vocab_size 258 -> 258 - 256 = 2 merges; verbose prints each merged pair.
tok.train('hello!', 258,verbose=True)

most frequent pair (104, 101) he merged into 256
most frequent pair (256, 108) hel merged into 257


In [None]:
# Round-trip check: decoding the encoded ids should give back the input string.
tok.decode(tok.encode("hey jude"))

'hey jude'

In [None]:
# Read the training corpus as a list of lines. The encoding is explicit so
# decoding does not depend on the platform's locale default.
with open('mj.txt', 'r', encoding='utf-8') as f:
    text = f.readlines()

In [None]:
# Rejoin the lines read from mj.txt into one training string.
text = ''.join(text)

In [None]:
# Train 1024 - 256 = 768 merges on the full corpus.
tok.train(text,1024)

In [None]:
# Inspect the vocabulary: ids 0-255 are single bytes, ids 256+ are merged byte strings.
tok.vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'