In [6]:
"""
This notebook, consists my version of reference bpe implementation  vs optimized python implementation vs huggingface tokeninzers implementation
vs tiktoken implementation vs rust based bpe implementation by Karpathy.
"""

# imports 
import os 
from collections import Counter, defaultdict
import numpy as np 
import time
import rustbpe 
import tiktoken 
import pytest

In [11]:
GPT4_SPLIT_REGEX = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]|\s[\r\n]|\s+(?!\S)|\s+"""
# \s denotes space and | denotes OR operator

Simple space based tokenizer implementation for understanding the basics

In [39]:
import regex as re  # IMPORTANT: use 'regex', not 're'

with open("the-verdict.txt", 'r', encoding='utf-8') as f:
    raw_text = f.read()

print(len(raw_text), raw_text[:100])

result = re.findall(GPT4_SPLIT_REGEX, raw_text)

# we can also define an regex for splitting based on spaces and puntuations 
# TEST_REGEX = r"[,.;:]|--|\s"
# result_test = re.split(TEST_REGEX, raw_text)
# result_test = [token for token in result_test if token]
# print(len(result_test), result_test[:20])

# Remove empty strings
result = [token for token in result if token]

print(f"Total tokens after splitting using GPT4 regex: {len(result)}")
print(result[:20])

20479 I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no g
Total tokens after splitting using GPT4 regex: 4068
['I', ' HAD', ' always', ' thought', ' Jack', ' Gisburn', ' rather', ' a', ' cheap', ' genius', '-though', ' a', ' good', ' fellow', ' enough', '-so', ' it', ' was', ' no', ' great']


In [42]:
all_words = sorted(set(result))
all_words.extend(["<|unk|>", "<|endoftext|>"])  # adding unknown token + endoftext token
vocab_size = len(all_words)
print(f"Vocab size: {vocab_size}") # before adding additonal tokens - 1250 

vocab = {token:integer for integer,token in enumerate(all_words)}
for i, item in enumerate(vocab.items()):
    print(item)
    if i >= 50:
        break


# let's define a simple tokenizer based on the above vocab
class SimpleTokenizer:
    def __init__(self, vocab):
        self.str_to_id = vocab
        self.id_to_str = { i:s for s, i in vocab.items() }
    
    def encode(self, text):
        preprocessed = re.findall(GPT4_SPLIT_REGEX, text)
        preprocessed = [token for token in preprocessed if token]
        # replaces unknown words by <unk> token
        preprocessed = [token if token in self.str_to_id else '<|unk|>' for token in preprocessed]

        return [self.str_to_id[token] for token in preprocessed]

    def decode(self, token_ids):
        txt = " ".join([self.id_to_str[token_id] for token_id in token_ids])
        txt = re.sub(r'\s+([,.:;?!"()\'])', r'\1', txt)
        return txt 

Vocab size: 1252
('\n', 0)
(' ', 1)
(' .\n', 2)
(' ."\n', 3)
(' A', 4)
(' Among', 5)
(' And', 6)
(' Arrt', 7)
(' At', 8)
(' Burlington', 9)
(' But', 10)
(' By', 11)
(' Carlo', 12)
(' Chicago', 13)
(' Claude', 14)
(' Croft', 15)
(' Devonshire', 16)
(' Don', 17)
(' Dubarry', 18)
(' Emperors', 19)
(' Florence', 20)
(' For', 21)
(' Gallery', 22)
(' Gideon', 23)
(' Gisburn', 24)
(' Grafton', 25)
(' Greek', 26)
(' Grindle', 27)
(' HAD', 28)
(' Had', 29)
(' He', 30)
(' Her', 31)
(' Hermia', 32)
(' His', 33)
(' I', 34)
(' If', 35)
(' It', 36)
(' Jack', 37)
(' Jove', 38)
(' Just', 39)
(' Lord', 40)
(' Made', 41)
(' Miss', 42)
(' Monte', 43)
(' Mr', 44)
(' Mrs', 45)
(' My', 46)
(' No', 47)
(' Now', 48)
(' Nutley', 49)
(' Of', 50)


In [41]:
tokenizer = SimpleTokenizer(vocab)
test_text = "Hi, I am idiot - which you may already know, you dumbassmf, and my favourite webtoon is the greatest_estate developer!!! lloyd = water, water = good, so lloyd is good."

encoded = tokenizer.encode(test_text)
print(f"Encoded: {encoded}")
decoded = tokenizer.decode(encoded)
print(f"Decoded: {decoded}")

Encoded: [1250, 34, 1250, 1250, 1, 1032, 1064, 613, 1250, 549, 1064, 1250, 125, 648, 1250, 1250, 538, 930, 466, 1250, 1250, 1250, 1, 1020, 1020, 1, 459, 855, 1250, 538, 459]
Decoded: <|unk|>  I <|unk|> <|unk|>    which  you  may <|unk|>  know  you <|unk|>  and  my <|unk|> <|unk|>  is  the  greatest <|unk|> <|unk|> <|unk|>    water  water    good  so <|unk|>  is  good


Reference BPE Implementation

In [54]:
# get stats function

def get_stats(token_ids, counts=None):
    """
    Given a list of token_ids, return a dictionary of pair frequencies, 
    which are then used to merge based on highest frequency.
    """
    counts = {} if counts is None else counts 
    for pair in zip(token_ids, token_ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts 

# def get_stats(token_ids, counts=None):
#     """
#     Given a list of token_ids, return a dictionary of pair frequencies, 
#     which are then used to merge based on highest frequency.
#     """
#     if counts is None:
#         counts = defaultdict(int)
#     else:
#         counts = defaultdict(int, counts)
    
#     for pair in zip(token_ids, token_ids[1:]):
#         counts[pair] += 1
#     return counts

# def get_stats(token_ids):
#     return Counter(zip(token_ids, token_ids[1:]))

# test
test_token_ids = [1, 2, 3, 2, 3, 4, 1, 2, 3]
pair_counts = get_stats(test_token_ids)
print(pair_counts)

{(1, 2): 2, (2, 3): 3, (3, 2): 1, (3, 4): 1, (4, 1): 1}


In [None]:
# merge_token_ids function 

def merge_token_ids(ids, pair, idx):
    """
    In the list of token_ids, merge the given pair and assign it idx
    """
    i = 0
    new_ids = []
    while i < len(ids):
        # if not at the very end of the list and found the pair
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i + 1] == pair[1]:
            new_ids.append(idx)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids

# test 
pair = (1,2)
idx = 5
print(merge_token_ids(test_token_ids, pair, idx))

[5, 3, 2, 3, 4, 5, 3]


In [65]:
# Simple Regex tokenizer implementation 

class RegexTokenizer:
    
    def __init__(self, pattern):
        # pattern - optional string to override the default GPT4 regex pattern
        # special_tokens - str -> int dict of special tokens used during tokenization
        self.pattern = GPT4_SPLIT_REGEX if pattern is None else pattern 
        self.merges = {}
        self.compiled_pattern = re.compile(self.pattern) # what does compile do? 
        self.special_tokens = {} # like "<|unk|>" or "<|endoftext|>"
        self.inverse_special_tokens = {}
        self.vocab = self._build_vocab()
    
    def _build_vocab(self):
        # vocab is deterministic and is derived from the merges 
        vocab = {idx: bytes([idx]) for idx in range(256)} # initialize the vocab
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def train(self, text, vocab_size, verbose=True):
        assert vocab_size >= 256 
        num_merges = vocab_size - 256 

        # keep track of whether at any point, during the training, the merge is ambigious - counts of pairs are not unique 
        # good technique 

        text_chunks = re.findall(self.compiled_pattern, text)

        ids = [list(chunk.encode("utf-8")) for chunk in text_chunks]

        # iteratively merge the most common pairs to create new tokens until we reach the desired vocab_size or num_merges which is vocab_size - 256
        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes 

        for i in range(num_merges):
            # count the no of times every consecutive pair appears
            stats = {}
            # chunk_ids = token_ids 
            for chunk_ids in ids:
                # passing in stats will update in place, adding up counts - we are doing chunk_wise?
                get_stats(chunk_ids, stats)
            
            # find the pair with the highest count 
            pair = max(stats, key=stats.get) #type: ignore 
            # check if the merge is ambiguous - i.e max value is not unique 
            pair_count = stats[pair]
            pairs_with_max_count = [pair for pair, count in stats.items() if count == pair_count]
            if len(pairs_with_max_count) > 1:
                # we somehow have to break the tie 
                # print the top 10 pairs with their counts
                # print(f"{i} Merge is ambiguous! {pair} has {pair_count} occurrences")
                # for print_pair, print_count in sorted(stats.items(), key=lambda x: x[1], reverse=True)[:10]:
                #     print(f"{print_pair}: {print_count}")
                ambiguous = True 
            # mint a new token: assign it the next available id 
            idx = 256 + i 
            # replace all occurences of pair with idx - done by the merge function 
            ids = [merge_token_ids(chunk_ids, pair, idx) for chunk_ids in ids]
            # save the merge 
            merges[pair] = idx 
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints 
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
        
        # save class variables 
        self.merges = merges # used in encode() - updating from {}
        self.vocab = vocab # used in decode() - updating from the initial 256 dict 

        return ambiguous 
    
    def _encode_chunk(self, text_bytes):
        # return the token_ids 
        # convert all the bytes into integers in range 0...255 
        ids = list(text_bytes)
        while len(ids) >=2:
            # find the pair with the lowest marge index 
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will 
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily we can detect this terminating case by a membership check.
            if pair not in self.merges:
                break # nothing else can be merged anymore 
            # otherwise let's merge the best pair
            idx = self.merges[pair]
            ids = merge_token_ids(ids, pair, idx)
        return ids

    def encode_ordinary(self, text):
        # Encoding that ingnores any special tokens 
        # split text into chunks of text by catergories defined in regex pattern 
        text_chunks = re.findall(self.compiled_pattern, text)
        ids = []
        for chunk in text_chunks:
            chunk_bytes = chunk.encode("utf-8")
            chunk_ids = self._encode_chunk(chunk_bytes)
            ids.extend(chunk_ids)
        return ids

Faster Python tokenizer - optimized version of the reference tokenizer implementation

In [66]:
# just an clever implementation!
def fast_merge_inplace_token_ids(ids, pair, idx):
    """
    In the list of integers (ids), replace all the concurrent occurences of pair, with the new 
    integer token idx in place.
    """
    # find all the positions where the pair occurs 
    i = 0
    while i < len(ids) - 1:
        if ids[i] == pair[0] and ids[i+1] == pair[1]:
            ids[i] = idx 
            ids.pop(i+1)
        else:
            i += 1
    return ids 