In [46]:
# text from https://www.reedbeta.com/blog/programmers-intro-to-unicode/
text = "Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! a12a12a12a12a1212😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to “support Unicode” in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don’t blame programmers for still finding the whole thing mysterious, even 30 years after Unicode’s inception."
# UTF-8 encode the sample text and keep the individual byte values as a list of ints
tokens = list(text.encode("utf-8"))

In [10]:
def get_stats(tokens, counts=None):
    """Count how often each pair of adjacent tokens occurs in `tokens`.

    Args:
        tokens: sliceable sequence of tokens (e.g. a list of ints).
        counts: optional dict to accumulate into; it is updated in place.
            When omitted (None), a fresh dict is created.

    Returns:
        The dict mapping (left, right) -> number of occurrences. This is the
        same object as `counts` when one was passed in.
    """
    if counts is None:
        counts = {}
    # Pair every element with its successor: (t0, t1), (t1, t2), ...
    for left, right in zip(tokens, tokens[1:]):
        key = (left, right)
        counts[key] = counts.get(key, 0) + 1
    return counts

def merge(ids, pair, idx):
    """Return a copy of `ids` in which each occurrence of `pair` is replaced by `idx`.

    The sequence is scanned left to right and matches do not overlap: once two
    elements have been replaced, scanning resumes after them. `ids` itself is
    not modified.

    Args:
        ids: sequence of tokens.
        pair: the two consecutive tokens (pair[0], pair[1]) to look for.
        idx: the token that replaces each matched pair.

    Returns:
        A new list with the replacements applied.
    """
    merged = []
    pos = 0
    total = len(ids)
    while pos < total:
        has_next = pos < total - 1
        if ids[pos] == pair[0] and has_next and ids[pos + 1] == pair[1]:
            # Matched both halves of the pair: emit the new token, skip two inputs.
            merged.append(idx)
            pos += 2
            continue
        merged.append(ids[pos])
        pos += 1
    return merged

In [11]:
class BasicTokenizer:
    """Minimal byte-level byte-pair-encoding (BPE) tokenizer.

    Token ids 0..255 stand for the raw byte values. Every id from 256 upward
    is created by `train` and stands for the concatenation of two existing
    tokens.
    """

    def __init__(self):
        self.merges = {}  # (int, int) -> int: token pair -> id that replaces it
        self.vocab = {idx: bytes([idx]) for idx in range(256)}  # id -> bytes

    def train(self, text, vocab_size, verbose=False):
        """Learn up to `vocab_size - 256` merges from `text`.

        On each round the most frequent adjacent token pair is replaced by a
        new token id. Training stops early when no pairs are left (text
        shorter than two bytes, or everything already merged into one token).
        Any merges learned by a previous call are discarded.

        Args:
            text: training string; it is UTF-8 encoded before training.
            vocab_size: target vocabulary size; values of 256 or less result
                in no merges.
            verbose: when True, print each merge as it is made.
        """
        tokens = list(text.encode("utf-8"))

        # Build the tables from scratch so that retraining the same instance
        # cannot leave stale pairs pointing at ids that are being reassigned.
        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}

        num_merges = vocab_size - 256
        for i in range(num_merges):
            stats = get_stats(tokens)
            if not stats:
                # Fewer than two tokens remain; max() on an empty dict would raise.
                break
            top_pair = max(stats, key=stats.get)  # most common pair
            idx = 256 + i

            if verbose:
                print(f'Merging {top_pair} into {idx}')

            tokens = merge(tokens, top_pair, idx)

            merges[top_pair] = idx
            vocab[idx] = vocab[top_pair[0]] + vocab[top_pair[1]]

        self.merges = merges
        self.vocab = vocab

    def encode(self, text):
        """Encode `text` into a list of token ids using the learned merges."""
        tokens = list(text.encode("utf-8"))

        while len(tokens) >= 2:
            stats = get_stats(tokens)
            # Pick the present pair with the lowest merge id, i.e. the one that
            # was learned earliest; pairs never learned rank as infinity.
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # nothing else to merge

            idx = self.merges[pair]
            tokens = merge(tokens, pair, idx)

        return tokens

    def decode(self, ids):
        """Turn a sequence of token ids back into a string.

        Byte sequences that are not valid UTF-8 are decoded with
        errors='replace' instead of raising.
        """
        tokens = b"".join([self.vocab[idx] for idx in ids])
        text = tokens.decode("utf-8", errors='replace')  # bytes -> characters
        return text
        

In [12]:
# Train the byte-level tokenizer on the sample text with a vocabulary of 270
# (256 byte tokens + 14 merges), printing each merge as it is learned.
tok = BasicTokenizer()
tok.train(text, 270, verbose=True)
# Show the token ids for a short string, then the encode -> decode round trip.
print(tok.encode('care '))
print(tok.decode(tok.encode('care ')))

Merging (101, 32) into 256
Merging (240, 159) into 257
Merging (226, 128) into 258
Merging (105, 110) into 259
Merging (115, 32) into 260
Merging (97, 110) into 261
Merging (116, 104) into 262
Merging (257, 133) into 263
Merging (257, 135) into 264
Merging (97, 114) into 265
Merging (239, 189) into 266
Merging (258, 140) into 267
Merging (267, 264) into 268
Merging (101, 114) into 269
[99, 265, 256]
care 


In [13]:
# Implementing a GPT-4-like tokenizer
# NOTE: `re` below is the `regex` package imported under the name `re`; the
# pattern relies on \p{L} / \p{N} property classes and possessive quantifiers (?+, ++).
import regex as re
# Pattern used to split text into chunks before BPE. Alternatives, tried left to right:
#   '(?i:[sdmt]|ll|ve|re)       apostrophe followed by s/d/m/t/ll/ve/re, case-insensitive
#   [^\r\n\p{L}\p{N}]?+\p{L}+   a run of letters, optionally preceded by one character
#                               that is not CR, LF, a letter, or a number
#   \p{N}{1,3}                  one to three numeric characters
#    ?[^\s\p{L}\p{N}]++[\r\n]*  optional space, then a run of characters that are not
#                               whitespace/letters/numbers, then any trailing CR/LF
#   \s*[\r\n]                   whitespace ending in a CR or LF
#   \s+(?!\S)                   whitespace that is not followed by a non-whitespace character
#   \s+                         any other run of whitespace
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""


In [27]:
class RegexTokenizer(BasicTokenizer):
    """BPE tokenizer that splits text with GPT4_SPLIT_PATTERN before merging.

    The text is first cut into chunks by the regex; byte pairs are counted and
    merged only inside each chunk, so a learned token never spans a chunk
    boundary. `decode` is inherited from BasicTokenizer.
    """

    def __init__(self):
        super().__init__()  # sets up self.merges and the 256-entry byte vocab
        self.pattern = GPT4_SPLIT_PATTERN
        self.compiled_pattern = re.compile(self.pattern)

    def train(self, text, vocab_size, verbose=False):
        """Learn up to `vocab_size - 256` merges from `text`.

        Pair counts are accumulated over all chunks on every round, and the
        most frequent pair is merged in every chunk. Training stops early when
        no chunk has two tokens left to pair. Any merges learned by a previous
        call are discarded.

        Args:
            text: training string.
            vocab_size: target vocabulary size; values of 256 or less result
                in no merges.
            verbose: when True, print each merge as it is made.
        """
        num_merges = vocab_size - 256

        text_chunks = re.findall(self.compiled_pattern, text)
        ids = [list(ch.encode("utf-8")) for ch in text_chunks]

        # Build the tables from scratch so that retraining the same instance
        # cannot leave stale pairs pointing at ids that are being reassigned.
        merges = {}  # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)}  # idx -> bytes

        for i in range(num_merges):
            stats = {}
            for chunk_ids in ids:
                # get_stats accumulates into the dict it is given
                get_stats(chunk_ids, stats)

            if not stats:
                # Every chunk is down to a single token; max() on an empty
                # dict would raise.
                break

            top_pair = max(stats, key=stats.get)  # most common pair
            idx = 256 + i

            if verbose:
                print(f'Merging {top_pair} into {idx}')

            ids = [merge(chunk_ids, top_pair, idx) for chunk_ids in ids]

            merges[top_pair] = idx
            vocab[idx] = vocab[top_pair[0]] + vocab[top_pair[1]]

        self.merges = merges
        self.vocab = vocab

    def _encode_chunk(self, chunk_bytes):
        """Apply the learned merges to the bytes of a single chunk."""
        ids = list(chunk_bytes)
        while len(ids) >= 2:
            stats = get_stats(ids)
            # Pick the present pair with the lowest merge id, i.e. the one that
            # was learned earliest; pairs never learned rank as infinity.
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # nothing else to merge

            idx = self.merges[pair]
            ids = merge(ids, pair, idx)

        return ids

    def encode(self, text):
        """Encode `text` into token ids, chunk by chunk.

        Each regex chunk is encoded on its own and the results are
        concatenated in order.
        """
        text_chunks = re.findall(self.compiled_pattern, text)

        ids = []
        for chunk in text_chunks:
            chunk_bytes = chunk.encode("utf-8")
            ids.extend(self._encode_chunk(chunk_bytes))

        return ids

In [50]:
# Train the regex-splitting tokenizer on the same sample text with a vocabulary
# of 270 (256 byte tokens + 14 merges), printing each merge as it is learned.
tok = RegexTokenizer()
tok.train(text, 270, verbose=True)
# Show the token ids for a string containing letters, digits and punctuation.
print(tok.encode('care a12!'))
# Round trip; note this uses 'care ', not the string encoded on the line above.
print(tok.decode(tok.encode('care ')))

Merging (240, 159) into 256
Merging (226, 128) into 257
Merging (105, 110) into 258
Merging (32, 97) into 259
Merging (32, 116) into 260
Merging (260, 104) into 261
Merging (256, 133) into 262
Merging (256, 135) into 263
Merging (97, 114) into 264
Merging (239, 189) into 265
Merging (257, 140) into 266
Merging (266, 263) into 267
Merging (101, 114) into 268
Merging (111, 114) into 269
[99, 264, 101, 259, 49, 50, 33]
care 
