In [None]:
def get_stats(ids):
    """Count how many times each adjacent pair occurs in `ids`.

    Args:
        ids: An indexable sequence (e.g. a list of token ids).

    Returns:
        A dict mapping ``(ids[k], ids[k + 1])`` to its number of occurrences,
        with keys in order of first appearance. Sequences with fewer than two
        elements yield an empty dict.
    """
    pair_counts = {}
    for pos in range(len(ids) - 1):
        key = (ids[pos], ids[pos + 1])
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1
    return pair_counts

In [None]:
# Sanity check: count the adjacent pairs of a small hand-made sequence.
get_stats([1,2,1,2,1,3])

{(1, 2): 2, (2, 1): 2, (1, 3): 1}

In [None]:
def merge(ids, pair, idx):
    """Return a new list in which every occurrence of `pair` is replaced by `idx`.

    The scan runs left to right and does not overlap matches: once a pair is
    replaced, scanning resumes after its second element.

    Args:
        ids: Sequence of token ids to scan (not modified).
        pair: Two-element sequence ``(first, second)`` to look for.
        idx: The id emitted in place of each matched pair.

    Returns:
        A new list with the replacements applied.
    """
    merged = []
    last = len(ids) - 1
    pos = 0
    while pos <= last:
        # A match needs a right-hand neighbour, so the final element never starts one.
        hit = pos < last and ids[pos] == pair[0] and ids[pos + 1] == pair[1]
        if hit:
            merged.append(idx)
            pos += 2
        else:
            merged.append(ids[pos])
            pos += 1
    return merged

In [None]:
# Sanity check: replace every (1, 2) pair in the sequence with the new id 4.
merge([1,2,1,2,1,3], (1,2), 4)

[4, 4, 1, 3]

In [None]:
# Regex used to cut text into chunks before BPE merging (named after GPT-4;
# presumably mirrors that tokenizer's split pattern — TODO confirm the source).
# Alternatives, tried in order:
#   '(?i:[sdmt]|ll|ve|re)          apostrophe + s/d/m/t/ll/ve/re, case-insensitive
#   [^\r\n\p{L}\p{N}]?+\p{L}+      optional one non-letter/non-number/non-newline char, then letters
#   \p{N}{1,3}                     runs of one to three numeric characters
#    ?[^\s\p{L}\p{N}]++[\r\n]*     optional space, then non-space/non-letter/non-number chars, then newlines
#   \s*[\r\n]                      whitespace that ends in a newline
#   \s+(?!\S)                      whitespace not followed by a non-whitespace character
#   \s+                            any remaining whitespace
# The \p{...} property classes are not supported by the stdlib `re` module.
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

In [None]:
import regex as re
# Compile the split pattern once for reuse. Note that `re` here is the
# third-party `regex` package (aliased by the import above), not the stdlib `re`.
pattern = re.compile(GPT4_SPLIT_PATTERN)

In [None]:
# Split a sample sentence into chunks using the compiled pattern's own method.
chunks = pattern.findall("this is a sentence, and another sentence! then what \n yay")

In [None]:
# Convert every text chunk into the list of its UTF-8 byte values.
# Note: this rebinds `chunks` from strings to lists of ints, so the cell
# cannot be re-run without first re-running the cell that creates `chunks`.
chunks = [list(piece.encode()) for piece in chunks]
chunks

[[116, 104, 105, 115],
 [32, 105, 115],
 [32, 97],
 [32, 115, 101, 110, 116, 101, 110, 99, 101],
 [44],
 [32, 97, 110, 100],
 [32, 97, 110, 111, 116, 104, 101, 114],
 [32, 115, 101, 110, 116, 101, 110, 99, 101],
 [33],
 [32, 116, 104, 101, 110],
 [32, 119, 104, 97, 116],
 [32, 10],
 [32, 121, 97, 121]]

In [None]:
# Accumulate adjacent-pair counts over all chunks; pairs are counted inside
# each chunk only, so no pair spans a chunk boundary.
stats = {}
for ids in chunks:
    counts = get_stats(ids)
    for pair in counts:
        v = counts[pair]
        stats[pair] = v + stats.get(pair, 0)
stats

{(116, 104): 3,
 (104, 105): 1,
 (105, 115): 2,
 (32, 105): 1,
 (32, 97): 3,
 (32, 115): 2,
 (115, 101): 2,
 (101, 110): 5,
 (110, 116): 2,
 (116, 101): 2,
 (110, 99): 2,
 (99, 101): 2,
 (97, 110): 2,
 (110, 100): 1,
 (110, 111): 1,
 (111, 116): 1,
 (104, 101): 2,
 (101, 114): 1,
 (32, 116): 1,
 (32, 119): 1,
 (119, 104): 1,
 (104, 97): 1,
 (97, 116): 1,
 (32, 10): 1,
 (32, 121): 1,
 (121, 97): 1,
 (97, 121): 1}

In [None]:
class RegexTokenizer:
    """Byte-level BPE tokenizer that pre-splits text with a regex.

    Implements train, encode and decode for strings. Text is first cut into
    chunks by `pattern`; byte-pair merges are learned and applied inside each
    chunk only, so a merged token never spans a chunk boundary.
    """

    def __init__(self, pattern=None):
        """Create an untrained tokenizer.

        Args:
            pattern: Split pattern, in any form accepted as the first argument
                of ``re.findall``. Defaults to the compiled GPT4_SPLIT_PATTERN.
        """
        self.merges = {}  # (int, int) -> int
        # Start with the 256 single-byte tokens so encode/decode round-trip
        # even before train() has been called.
        self.vocab = {i: bytes([i]) for i in range(256)}  # int -> bytes
        self.pattern = re.compile(GPT4_SPLIT_PATTERN) if pattern is None else pattern

    def train(self, text: str, vocab_size: int, verbose=False):
        """Learn up to ``vocab_size - 256`` merges from `text`.

        Replaces any previously learned merges and vocabulary. Training stops
        early if no adjacent pair is left to merge (every chunk has collapsed
        to a single token), in which case the vocabulary is smaller than
        `vocab_size`.

        Args:
            text: Training text.
            vocab_size: Target vocabulary size; must be at least 256.
            verbose: If true, print each merge as it is learned.
        """
        assert vocab_size>255
        n_merges = vocab_size - 256
        # Each chunk becomes its own list of byte values in range 0..255.
        chunks = [list(ch.encode("utf-8")) for ch in re.findall(self.pattern, text)]
        merges = {}
        vocab = {i: bytes([i]) for i in range(256)}
        idx = 256
        for _ in range(n_merges):
            # Pair counts are summed over chunks, never across chunk boundaries.
            stats = {}
            for ids in chunks:
                for pair, count in get_stats(ids).items():
                    stats[pair] = stats.get(pair, 0) + count
            if not stats:
                # Nothing left to merge; max() on an empty dict would raise.
                if verbose:
                    print(f"no more pairs to merge; stopping at vocab size {len(vocab)}")
                break
            # On ties, max() keeps the pair that was seen first.
            pair = max(stats, key=stats.get)
            chunks = [merge(ids, pair, idx) for ids in chunks]
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            if verbose:
                print(f"most frequent pair {pair} {vocab[idx]} merged into {idx}")
            idx += 1
        self.merges = merges
        self.vocab = vocab

    def decode(self, ids) -> str:
        """Convert token ids back into a string.

        Raises:
            KeyError: If an id is not in the vocabulary.
            UnicodeDecodeError: If the concatenated bytes are not valid UTF-8.
        """
        b = b"".join(self.vocab[i] for i in ids)
        return b.decode("utf-8")

    def _encode_chunk(self, ids):
        """Apply the learned merges to one chunk's byte ids, earliest-learned merge first."""
        while len(ids) >= 2:
            stats = get_stats(ids)
            # Pairs without a learned merge rank as +inf, so they are picked
            # only when no mergeable pair remains.
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break
            ids = merge(ids, pair, self.merges[pair])
        return ids

    def encode(self, s: str) -> list[int]:
        """Convert a string into token ids.

        The string is split with the same pattern used in train() and each
        chunk is encoded independently, so merges are applied only where
        training could have observed them. Characters that the pattern does
        not match are not encoded, just as train() ignores them.
        """
        ids = []
        for chunk in re.findall(self.pattern, s):
            ids.extend(self._encode_chunk(list(chunk.encode("utf-8"))))
        return ids

In [None]:
# Smoke test on a tiny string: a vocab size of 258 asks for 258 - 256 = 2 merges.
tok = RegexTokenizer()
tok.train('hello!', vocab_size=258, verbose=True)

most frequent pair (104, 101) b'he' merged into 256
most frequent pair (256, 108) b'hel' merged into 257


In [None]:
# Round-trip check: encode a string, then decode the ids back into text.
tok.decode(tok.encode("hey jude"))

'hey jude'

In [None]:
# Read the training corpus as a list of lines. The encoding is pinned so the
# result does not depend on the platform's default locale encoding; the
# tokenizer itself works on UTF-8 bytes (assumes mj.txt is UTF-8 — TODO confirm).
with open('mj.txt', 'r', encoding='utf-8') as f:
    text = f.readlines()

In [None]:
# `text` holds the lines returned by readlines(); join them into one string.
text = ''.join(text)

In [None]:
# Retrain the same tokenizer on the full corpus, requesting 1024 - 256 = 768 merges.
tok.train(text, vocab_size=1024, verbose=True)

most frequent pair (32, 116) b' t' merged into 256
most frequent pair (105, 110) b'in' merged into 257
most frequent pair (104, 101) b'he' merged into 258
most frequent pair (111, 114) b'or' merged into 259
most frequent pair (97, 110) b'an' merged into 260
most frequent pair (101, 114) b'er' merged into 261
most frequent pair (101, 100) b'ed' merged into 262
most frequent pair (256, 258) b' the' merged into 263
most frequent pair (50, 48) b'20' merged into 264
most frequent pair (111, 110) b'on' merged into 265
most frequent pair (32, 97) b' a' merged into 266
most frequent pair (97, 114) b'ar' merged into 267
most frequent pair (99, 104) b'ch' merged into 268
most frequent pair (32, 74) b' J' merged into 269
most frequent pair (101, 116) b'et' merged into 270
most frequent pair (97, 108) b'al' merged into 271
most frequent pair (32, 65) b' A' merged into 272
most frequent pair (32, 115) b' s' merged into 273
most frequent pair (259, 100) b'ord' merged into 274
most frequent pair (118

In [None]:
# Inspect the learned vocabulary (token id -> bytes). This displays the whole dict.
tok.vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'