# MinBPE exercise

Implement the tokenizers described in the video and the accompanying exercise.

In [23]:
from IPython.display import display, Markdown
import regex as re
from multiprocessing import Pool
import os
from tqdm import tqdm

In [229]:
# Round-trip test inputs: every tokenizer below must satisfy decode(encode(s)) == s for each.
# Entries prefixed with "FILE:" are expanded to that file's contents by `unpack`.
test_strings = [
    "", # empty string
    "?", # single character
    "hello world!!!? (안녕하세요!) lol123 😉", # fun small string (non-ASCII + emoji)
    """The llama (/ˈlɑːmə/; Spanish pronunciation: [ˈʎama] or [ˈʝama]) (Lama glama) is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era.

Llamas are social animals and live with others as a herd. Their wool is soft and contains only a small amount of lanolin.[2] Llamas can learn simple tasks after a few repetitions. When using a pack, they can carry about 25 to 30% of their body weight for 8 to 13 km (5–8 miles).[3] The name llama (also historically spelled "lama" or "glama") was adopted by European settlers from native Peruvians.[4]""",
"FILE:test/taylorswift.txt" # large file; path is relative to the notebook's working directory
]

def unpack(text):
    """Resolve a test-string entry to the actual text to tokenize.

    Entries of the form ``"FILE:<path>"`` are replaced by the UTF-8 contents of
    ``<path>``; any other string is returned unchanged.

    The indirection exists because `pytest -v .` prints test arguments to the
    console, and dumping an entire file there creates a mess.

    Args:
        text: a literal test string, or ``"FILE:"`` followed by a file path.

    Returns:
        The string to tokenize.
    """
    prefix = "FILE:"
    if text.startswith(prefix):
        path = text[len(prefix):]
        # `with` closes the handle deterministically instead of leaking it.
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    return text

In [230]:
# Load the training corpus.
# Source: https://www.kaggle.com/datasets/ukveteran/big-text
# Explicit UTF-8 so the read does not depend on the platform's default locale encoding.
with open("train/big.txt", "r", encoding="utf-8") as file:
    text = file.read()

# Train on the first 1,000,000 *characters* (not words) of the corpus.
N_TRAIN_CHARS = 1_000_000
text = text[:N_TRAIN_CHARS]

len(text)

1000000

In [231]:
# Peek at the first 1000 characters; rendered as Markdown, so `*` / `#` get formatted.
display(Markdown(data=text[:1000]))

The Project Gutenberg EBook of The Adventures of Sherlock Holmes
by Sir Arthur Conan Doyle
(#15 in our series by Sir Arthur Conan Doyle)

Copyright laws are changing all over the world. Be sure to check the
copyright laws for your country before downloading or redistributing
this or any other Project Gutenberg eBook.

This header should be the first thing seen when viewing this Project
Gutenberg file.  Please do not remove it.  Do not change or edit the
header without written permission.

Please read the "legal small print," and other information about the
eBook and Project Gutenberg at the bottom of this file.  Included is
important information about your specific rights and restrictions in
how the file may be used.  You can also find out about how to make a
donation to Project Gutenberg, and how to get involved.


**Welcome To The World of Free Plain Vanilla Electronic Texts**

**eBooks Readable By Both Humans and By Computers, Since 1971**

*****These eBooks Were Prepared By Thousan

In [232]:
def get_stats(tokens):
    """Count how often each adjacent (left, right) token pair occurs.

    Args:
        tokens: sequence of token ids.

    Returns:
        dict mapping (left, right) -> occurrence count, with keys in the order
        each pair is first seen (this order decides ties in max/min later).
    """
    counts = {}
    for left, right in zip(tokens, tokens[1:]):
        key = (left, right)
        if key in counts:
            counts[key] += 1
        else:
            counts[key] = 1
    return counts

def replace_pair(tokens, pair, pair_idx):
    """Replace every non-overlapping occurrence of `pair` (scanned left to right).

    Args:
        tokens: sequence of token ids (not modified).
        pair: (left, right) tuple of ids to merge.
        pair_idx: id emitted in place of each matched pair.

    Returns:
        A new list of token ids.
    """
    out = []
    n = len(tokens)
    pos = 0
    while pos < n:
        at_pair = pos + 1 < n and (tokens[pos], tokens[pos + 1]) == pair
        if at_pair:
            out.append(pair_idx)
            pos += 2
            continue
        out.append(tokens[pos])
        pos += 1
    return out

## BasicTokenizer

In [246]:
class BasicTokenizer:
    """Byte-level BPE tokenizer trained directly on the raw UTF-8 bytes of a text
    (no regex pre-splitting, no special tokens)."""

    def __init__(self):
        self.merges = {}  # (left_id, right_id) -> merged_id, in the order learned
        self.vocab_size = 256  # ids 0..255 are the raw byte values
        self.vocab = {idx: bytes([idx]) for idx in range(self.vocab_size)}  # id -> bytes

    def train(self, text, vocab_size, verbose=False, verbose_iters=None):
        """Learn `vocab_size - self.vocab_size` merges from `text`.

        Args:
            text: training string.
            vocab_size: target vocabulary size; must exceed the current size.
            verbose: print progress and the final compression ratio.
            verbose_iters: log every `verbose_iters` merges; if None (or 0) only
                the final merge is logged.

        Training stops early when no adjacent pair is left to merge (e.g. the
        text collapsed into a single token), so the final vocab may be smaller
        than requested.
        """
        assert vocab_size > self.vocab_size # to ensure we have a larger vocab size
        num_merges = vocab_size - self.vocab_size
        # Avoid `% None` / `% 0` below: fall back to logging only the final merge.
        log_every = verbose_iters or num_merges

        tokens = list(text.encode("utf-8"))
        old_token_length = len(tokens)

        if verbose:
            start = "Start"
            print(f"{start:.20s} | No. of tokens: {len(tokens):6d} | Vocab size: {self.vocab_size:5d}")

        for i in range(num_merges):
            pair_stats = get_stats(tokens)
            if not pair_stats:
                # Nothing left to merge; max() on an empty dict would raise.
                if verbose:
                    print(f"Stopping early after {i} merges: no pairs left to merge.")
                break
            # Most frequent pair; ties go to the pair seen first.
            max_freq_pair = max(pair_stats, key=pair_stats.get)
            new_idx = self.vocab_size
            tokens = replace_pair(tokens, max_freq_pair, new_idx)
            self.merges[max_freq_pair] = new_idx
            self.vocab[new_idx] = self.vocab[max_freq_pair[0]] + self.vocab[max_freq_pair[1]]
            self.vocab_size += 1

            if verbose and (((i + 1) % log_every == 0) or (i == num_merges - 1)):
                print((f"Iteration {(i + 1):4d} | No. of tokens: {len(tokens):6d} | Merged pair: {str(max_freq_pair):10s} --> {self.merges[max_freq_pair]:5d}"))

        # Skip the ratio for empty text (would divide by zero).
        if verbose and tokens:
            compression = old_token_length / len(tokens)
            print(f"Compression : {compression:.2f} X")

    def encode(self, text):
        """Encode `text` into token ids by replaying the learned merges.

        At each step the adjacent pair with the lowest merge id (learned
        earliest) is merged, mirroring the order used during training.
        """
        tokens = list(text.encode("utf-8"))

        while len(tokens) >= 2:
            pair_stats = get_stats(tokens)
            # Unlearned pairs map to +inf, so min() prefers learned merges.
            merge_pair = min(pair_stats, key=lambda k: self.merges.get(k, float("inf")))
            if merge_pair not in self.merges:
                break  # no learned merge applies anymore
            tokens = replace_pair(tokens, merge_pair, self.merges[merge_pair])

        return tokens

    def decode(self, ids):
        """Decode token ids back to text; invalid UTF-8 is replaced with U+FFFD."""
        enc_text = b"".join(self.vocab[idx] for idx in ids)
        return enc_text.decode("utf-8", errors="replace")

In [247]:
# UTF-8 encoded length = number of base byte tokens before any merge
# (no need to materialise a list of a million ints just to count them).
len(text.encode("utf-8"))

1000000

In [9]:
# Quick sanity run: 44 merges (vocab 256 -> 300), logging every single merge.
basic_tokenizer = BasicTokenizer()
basic_tokenizer.train(text, 300, verbose=True, verbose_iters=1)

Start | No. of tokens: 1000000 | Vocab size:   256
Iteration    1 | No. of tokens: 970416 | Merged pair: (101, 32)  -->   256
Iteration    2 | No. of tokens: 948950 | Merged pair: (116, 104) -->   257
Iteration    3 | No. of tokens: 932251 | Merged pair: (100, 32)  -->   258
Iteration    4 | No. of tokens: 916862 | Merged pair: (115, 32)  -->   259
Iteration    5 | No. of tokens: 902588 | Merged pair: (116, 32)  -->   260
Iteration    6 | No. of tokens: 888972 | Merged pair: (105, 110) -->   261
Iteration    7 | No. of tokens: 875719 | Merged pair: (101, 114) -->   262
Iteration    8 | No. of tokens: 863648 | Merged pair: (97, 110)  -->   263
Iteration    9 | No. of tokens: 852577 | Merged pair: (44, 32)   -->   264
Iteration   10 | No. of tokens: 842084 | Merged pair: (257, 256) -->   265
Iteration   11 | No. of tokens: 832073 | Merged pair: (111, 110) -->   266
Iteration   12 | No. of tokens: 823657 | Merged pair: (121, 32)  -->   267
Iteration   13 | No. of tokens: 815311 | Merged p

In [17]:
# Learned merges: (left_id, right_id) -> new token id, in the order they were learned.
basic_tokenizer.merges

{(101, 32): 256,
 (116, 104): 257,
 (100, 32): 258,
 (115, 32): 259,
 (116, 32): 260,
 (105, 110): 261,
 (101, 114): 262,
 (97, 110): 263,
 (44, 32): 264,
 (257, 256): 265,
 (111, 110): 266,
 (121, 32): 267,
 (101, 110): 268,
 (111, 117): 269,
 (111, 32): 270,
 (102, 32): 271,
 (111, 114): 272,
 (46, 32): 273,
 (101, 258): 274,
 (111, 271): 275,
 (97, 114): 276,
 (32, 32): 277,
 (114, 101): 278,
 (263, 258): 279,
 (116, 105): 280,
 (116, 270): 281,
 (261, 103): 282,
 (97, 108): 283,
 (104, 105): 284,
 (115, 116): 285,
 (97, 32): 286,
 (104, 97): 287,
 (10, 10): 288,
 (32, 265): 289,
 (97, 259): 290,
 (97, 260): 291,
 (262, 32): 292,
 (101, 115): 293,
 (111, 109): 294,
 (282, 32): 295,
 (73, 32): 296,
 (99, 104): 297,
 (111, 108): 298,
 (261, 32): 299}

In [248]:
# Full run: 1000 merges (vocab 256 -> 1256), logging every 100 merges.
basic_tokenizer = BasicTokenizer()
basic_tokenizer.train(text, 1256, verbose=True, verbose_iters=100)

Start | No. of tokens: 1000000 | Vocab size:   256
Iteration  100 | No. of tokens: 576783 | Merged pair: (117, 112) -->   355
Iteration  200 | No. of tokens: 495579 | Merged pair: (266, 103) -->   455
Iteration  300 | No. of tokens: 452955 | Merged pair: (109, 105) -->   555
Iteration  400 | No. of tokens: 424971 | Merged pair: (119, 423) -->   655
Iteration  500 | No. of tokens: 403227 | Merged pair: (99, 326)  -->   755
Iteration  600 | No. of tokens: 386082 | Merged pair: (282, 264) -->   855
Iteration  700 | No. of tokens: 372312 | Merged pair: (668, 405) -->   955
Iteration  800 | No. of tokens: 360795 | Merged pair: (573, 265) -->  1055
Iteration  900 | No. of tokens: 350883 | Merged pair: (84, 73)   -->  1155
Iteration 1000 | No. of tokens: 342119 | Merged pair: (103, 346) -->  1255
Compression : 2.92 X


In [249]:
# Dump the learned merges in learning order, with each side rendered as text.
# Explicit UTF-8: pieces may contain U+FFFD / non-ASCII text that a non-UTF-8
# default locale encoding cannot write.
with open("basic_merges.txt", "w", encoding="utf-8") as file:
    for (left, right), idx in sorted(basic_tokenizer.merges.items(), key=lambda kv: kv[1]):
        left_str = basic_tokenizer.vocab[left].decode("utf-8", errors="replace")
        right_str = basic_tokenizer.vocab[right].decode("utf-8", errors="replace")
        file.write(f"[{left_str}][{right_str}]  ---->   {idx}\n")

In [251]:
# Round-trip check: decode(encode(s)) must reproduce every test string exactly.
# An explicit comparison is used instead of `assert`, which is stripped under
# `python -O` and would make every test report "Passed".
for i, test_string in enumerate(test_strings):
    test_text = unpack(test_string)
    if basic_tokenizer.decode(basic_tokenizer.encode(test_text)) == test_text:
        print(f"Test string: {i} Passed! :)")
    else:
        print(f"Test string: {i} Failed! :(")

Test string: 0 Passed! :)
Test string: 1 Passed! :)
Test string: 2 Passed! :)
Test string: 3 Passed! :)
Test string: 4 Passed! :)


In [252]:
def encode(text, tokenizer=None):
    """Encode `text` with a trained tokenizer's merges and print how many merge steps ran.

    Debug twin of `BasicTokenizer.encode` that also reports the merge-step count.

    Args:
        text: string to encode.
        tokenizer: object with a `merges` dict ((left, right) -> id). Defaults to
            the global `basic_tokenizer`, which was the only behavior before.

    Returns:
        List of token ids.
    """
    if tokenizer is None:
        tokenizer = basic_tokenizer
    merges = tokenizer.merges
    tokens = list(text.encode("utf-8"))

    num_merges = 0
    # Need at least one pair: min() over the empty stats of a 0/1-byte input raises.
    while len(tokens) >= 2:
        pair_stats = get_stats(tokens)
        # The learned merge with the lowest id (learned earliest) is applied first.
        merge_pair = min(pair_stats, key=lambda k: merges.get(k, float("inf")))
        if merge_pair not in merges:
            break
        num_merges += 1
        tokens = replace_pair(tokens, merge_pair, merges[merge_pair])

    print(num_merges)

    return tokens

In [254]:
# Count merge steps for the Taylor Swift file; the trailing `;` hides the long token list.
encode(unpack(test_strings[4]));

883


## RegexTokenizer

In [255]:
# GPT-4 (cl100k_base) pre-tokenization pattern. `findall` splits text into chunks:
# contractions, letter runs with an optional leading non-letter, 1-3 digit numbers,
# punctuation runs, newlines and whitespace. The tokenizers below merge only within a chunk.
# It uses `\p{..}` classes and possessive quantifiers, hence the `regex` module (imported as `re`).
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
GPT4_PATTERN_REGEX = re.compile(GPT4_SPLIT_PATTERN)

In [256]:
class RegexTokenizer:
    """Byte-level BPE tokenizer that first splits text with `GPT4_PATTERN_REGEX`
    and never merges across the resulting chunks."""

    def __init__(self):
        self.merges = {}  # (left_id, right_id) -> merged_id, in the order learned
        self.vocab_size = 256  # ids 0..255 are the raw byte values
        self.vocab = {idx: bytes([idx]) for idx in range(self.vocab_size)}  # id -> bytes

    def train(self, text, vocab_size, verbose=False, verbose_iters=None):
        """Learn `vocab_size - self.vocab_size` merges from `text`.

        Args:
            text: training string.
            vocab_size: target vocabulary size; must exceed the current size.
            verbose: print progress and the final compression ratio.
            verbose_iters: log every `verbose_iters` merges; if None (or 0) only
                the final merge is logged.

        Training stops early when no chunk has an adjacent pair left, so the
        final vocab may be smaller than requested.
        """
        assert vocab_size > self.vocab_size # to ensure we have a larger vocab size
        num_merges = vocab_size - self.vocab_size
        # Avoid `% None` / `% 0` below: fall back to logging only the final merge.
        log_every = verbose_iters or num_merges

        split_text = GPT4_PATTERN_REGEX.findall(text)
        tokens_list = [list(t.encode("utf-8")) for t in split_text]
        old_token_length = sum(len(tokens) for tokens in tokens_list)

        if verbose:
            start = "Start"
            print(f"{start:.20s} | No. of tokens: {old_token_length:6d} | Vocab size: {self.vocab_size:5d}")

        for i in range(num_merges):
            # Pair counts summed over all chunks (pairs never straddle two chunks).
            pair_stats_full = {}
            for tokens in tokens_list:
                for pair, count in get_stats(tokens).items():
                    pair_stats_full[pair] = pair_stats_full.get(pair, 0) + count

            if not pair_stats_full:
                # Nothing left to merge; max() on an empty dict would raise.
                if verbose:
                    print(f"Stopping early after {i} merges: no pairs left to merge.")
                break

            # Most frequent pair; ties go to the pair seen first.
            max_freq_pair = max(pair_stats_full, key=pair_stats_full.get)
            new_idx = self.vocab_size
            tokens_list = [replace_pair(tokens, max_freq_pair, new_idx) for tokens in tokens_list]
            self.merges[max_freq_pair] = new_idx
            self.vocab[new_idx] = self.vocab[max_freq_pair[0]] + self.vocab[max_freq_pair[1]]
            self.vocab_size += 1

            if verbose and (((i + 1) % log_every == 0) or (i == num_merges - 1)):
                new_token_length = sum(len(tokens) for tokens in tokens_list)
                print((f"Iteration {(i + 1):4d} | No. of tokens: {new_token_length:6d} | Merged pair: {str(max_freq_pair):10s} --> {self.merges[max_freq_pair]:5d}"))

        if verbose:
            new_token_length = sum(len(tokens) for tokens in tokens_list)
            # Skip the ratio for empty text (would divide by zero).
            if new_token_length:
                compression = old_token_length / new_token_length
                print(f"Compression : {compression:.2f} X")

    def encode(self, text):
        """Encode `text`: regex-split, apply merges within each chunk (lowest
        merge id first), then concatenate the chunks' ids.

        Shows a tqdm progress bar over chunks and prints the total number of
        merge steps applied.
        """
        split_text = GPT4_PATTERN_REGEX.findall(text)
        tokens_list = [list(t.encode("utf-8")) for t in split_text]

        final_tokens = []
        num_merges = 0

        for tokens in tqdm(tokens_list):
            while len(tokens) >= 2:
                pair_stats = get_stats(tokens)
                # Unlearned pairs map to +inf, so min() prefers learned merges.
                merge_pair = min(pair_stats, key=lambda k: self.merges.get(k, float("inf")))
                if merge_pair not in self.merges:
                    break
                num_merges += 1
                tokens = replace_pair(tokens, merge_pair, self.merges[merge_pair])
            final_tokens.extend(tokens)

        print(num_merges)
        return final_tokens

    def decode(self, ids):
        """Decode token ids back to text; invalid UTF-8 is replaced with U+FFFD."""
        enc_text = b"".join(self.vocab[idx] for idx in ids)
        return enc_text.decode("utf-8", errors="replace")

In [27]:
# Quick sanity run: 44 merges (vocab 256 -> 300), logging every 10 merges.
regex_tokenizer = RegexTokenizer()
regex_tokenizer.train(text, 300, verbose=True, verbose_iters=10)

Start | No. of tokens: 1000000 | Vocab size:   256
Iteration   10 | No. of tokens: 861074 | Merged pair: (101, 114) -->   265
Iteration   20 | No. of tokens: 790967 | Merged pair: (32, 99)   -->   275
Iteration   30 | No. of tokens: 740730 | Merged pair: (259, 103) -->   285
Iteration   40 | No. of tokens: 700386 | Merged pair: (105, 99)  -->   295
Iteration   44 | No. of tokens: 687143 | Merged pair: (32, 104)  -->   299
Compression : 1.46 X


In [257]:
# Full run: 1000 merges (vocab 256 -> 1256), logging every 100 merges.
regex_tokenizer = RegexTokenizer()
regex_tokenizer.train(text, 1256, verbose=True, verbose_iters=100)

Start | No. of tokens: 1000000 | Vocab size:   256
Iteration  100 | No. of tokens: 580314 | Merged pair: (97, 109)  -->   355
Iteration  200 | No. of tokens: 501016 | Merged pair: (32, 118)  -->   455
Iteration  300 | No. of tokens: 458709 | Merged pair: (453, 110) -->   555
Iteration  400 | No. of tokens: 430368 | Merged pair: (459, 110) -->   655
Iteration  500 | No. of tokens: 410025 | Merged pair: (300, 408) -->   755
Iteration  600 | No. of tokens: 394498 | Merged pair: (103, 103) -->   855
Iteration  700 | No. of tokens: 382035 | Merged pair: (290, 107) -->   955
Iteration  800 | No. of tokens: 371558 | Merged pair: (341, 285) -->  1055
Iteration  900 | No. of tokens: 362520 | Merged pair: (735, 1080) -->  1155
Iteration 1000 | No. of tokens: 354597 | Merged pair: (109, 98)  -->  1255
Compression : 2.82 X


In [258]:
# Dump the learned merges in learning order, with each side rendered as text.
# Explicit UTF-8: pieces may contain U+FFFD / non-ASCII text that a non-UTF-8
# default locale encoding cannot write.
with open("regex_merges.txt", "w", encoding="utf-8") as file:
    for (left, right), idx in sorted(regex_tokenizer.merges.items(), key=lambda kv: kv[1]):
        left_str = regex_tokenizer.vocab[left].decode("utf-8", errors="replace")
        right_str = regex_tokenizer.vocab[right].decode("utf-8", errors="replace")
        file.write(f"[{left_str}][{right_str}]  ---->   {idx}\n")

In [259]:
# Round-trip check: decode(encode(s)) must reproduce every test string exactly.
# An explicit comparison is used instead of `assert`, which is stripped under `python -O`.
for i, test_string in enumerate(test_strings):
    test_text = unpack(test_string)
    if regex_tokenizer.decode(regex_tokenizer.encode(test_text)) == test_text:
        print(f"Test string: {i} Passed! :)")
    else:
        print(f"Test string: {i} Failed! :(")

0it [00:00, ?it/s]


0
Test string: 0 Passed! :)


100%|██████████| 1/1 [00:00<00:00, 27235.74it/s]


0
Test string: 1 Passed! :)


100%|██████████| 9/9 [00:00<00:00, 84260.57it/s]


8
Test string: 2 Passed! :)


100%|██████████| 139/139 [00:00<00:00, 116415.39it/s]


328
Test string: 3 Passed! :)


100%|██████████| 46195/46195 [00:00<00:00, 293444.57it/s]

86282
Test string: 4 Passed! :)





In [262]:
# Number of tokens the regex tokenizer produces for the Taylor Swift file
# (encode also prints its merge-step count).
len(regex_tokenizer.encode(unpack(test_strings[4])))

100%|██████████| 46195/46195 [00:00<00:00, 303990.39it/s]

86282





99286

## GPT-4 tokenizer

In [263]:
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
sample = "hello world!!!? (안녕하세요!) lol123 😉"
ids = enc.encode(sample)
# Check the round trip instead of assigning to `text`, which holds the training corpus.
assert enc.decode(ids) == sample

In [264]:
# Token ids GPT-4 assigns to the sample string.
ids

[15339,
 1917,
 12340,
 30,
 320,
 31495,
 230,
 75265,
 243,
 92245,
 16715,
 28509,
 4513,
 57037]

In [37]:
enc._mergeable_ranks # token bytes -> rank; the rank is also the token id (the inverse of a vocab). NOTE: private tiktoken attribute, may change between versions

{b'!': 0,
 b'"': 1,
 b'#': 2,
 b'$': 3,
 b'%': 4,
 b'&': 5,
 b"'": 6,
 b'(': 7,
 b')': 8,
 b'*': 9,
 b'+': 10,
 b',': 11,
 b'-': 12,
 b'.': 13,
 b'/': 14,
 b'0': 15,
 b'1': 16,
 b'2': 17,
 b'3': 18,
 b'4': 19,
 b'5': 20,
 b'6': 21,
 b'7': 22,
 b'8': 23,
 b'9': 24,
 b':': 25,
 b';': 26,
 b'<': 27,
 b'=': 28,
 b'>': 29,
 b'?': 30,
 b'@': 31,
 b'A': 32,
 b'B': 33,
 b'C': 34,
 b'D': 35,
 b'E': 36,
 b'F': 37,
 b'G': 38,
 b'H': 39,
 b'I': 40,
 b'J': 41,
 b'K': 42,
 b'L': 43,
 b'M': 44,
 b'N': 45,
 b'O': 46,
 b'P': 47,
 b'Q': 48,
 b'R': 49,
 b'S': 50,
 b'T': 51,
 b'U': 52,
 b'V': 53,
 b'W': 54,
 b'X': 55,
 b'Y': 56,
 b'Z': 57,
 b'[': 58,
 b'\\': 59,
 b']': 60,
 b'^': 61,
 b'_': 62,
 b'`': 63,
 b'a': 64,
 b'b': 65,
 b'c': 66,
 b'd': 67,
 b'e': 68,
 b'f': 69,
 b'g': 70,
 b'h': 71,
 b'i': 72,
 b'j': 73,
 b'k': 74,
 b'l': 75,
 b'm': 76,
 b'n': 77,
 b'o': 78,
 b'p': 79,
 b'q': 80,
 b'r': 81,
 b's': 82,
 b't': 83,
 b'u': 84,
 b'v': 85,
 b'w': 86,
 b'x': 87,
 b'y': 88,
 b'z': 89,
 b'{': 90,
 b'|': 9

In [38]:
len(enc._mergeable_ranks) # 100256 entries = 256 single-byte base tokens + 100,000 merged tokens

100256

In [39]:
# Invert the rank table into a GPT-4 vocab: token id -> token bytes.
vocab = {rank: token_bytes for token_bytes, rank in enc._mergeable_ranks.items()}
len(vocab)

100256

In [40]:
# Inspect the token with id 100255 (the last regular token if ranks run 0..100255).
vocab[100255]

b' Conveyor'

### Now let's implement the GPT4Tokenizer
This first version does not handle special tokens.

In [278]:
class GPT4Tokenizer:
    """GPT-4 (cl100k_base) encode/decode using our own BPE loop over tiktoken's ranks.

    Merges are reconstructed from tiktoken's rank table. This version has no
    special-token handling and no training; it is redefined further down in the
    notebook with special-token support.
    """
    def __init__(self):
        self.tokenizer_path = "cl100k_base"
        self.enc = tiktoken.get_encoding(self.tokenizer_path)
        self._create_vocab_and_merges()
        # now here is another tricky part.
        # for some reason, the tokens corresponding to individual bytes
        # are permuted in a different order. This is completely non-sensical
        # and probably historical, but therefore we have to deal with it here.

        self.byte_shuffle = {i: self.enc._mergeable_ranks[bytes([i])] for i in range(256)} # raw byte value i -> token id (rank) of the single-byte token bytes([i])
        self.inverse_byte_shuffle = {v: k for k, v in self.byte_shuffle.items()}  # token id -> raw byte value; not used by the methods of this class

    def _get_split_tokens(self, mergeable_ranks, token, max_rank):
        """Replay BPE on `token` using only merges ranked below `max_rank`.

        Starting from single bytes, repeatedly merge the adjacent pair whose
        concatenation has the lowest rank (first one wins ties). For a token of
        rank `max_rank` this should stop at exactly the two parts that were
        merged to create it.

        Returns:
            List of bytes parts whose concatenation equals `token`.
        """
        # Idea: Get the most optimal split of the token into exactly two pre-existing tokens
        parts = [bytes([t]) for t in token]
        while True:
            min_id, min_rank = None, None # min_id -> position of the best pair to merge, min_rank -> its rank
            for i, pair in enumerate(zip(parts, parts[1:])):
                rank = mergeable_ranks.get(pair[0] + pair[1])

                if (rank is not None) and (min_rank is None or rank < min_rank):
                    min_id = i
                    min_rank = rank

            if (min_rank is None) or (min_rank >= max_rank): # no further merge allowed
                break
            # else merge the best pair in place
            assert min_id is not None
            parts = parts[:min_id] + [parts[min_id] + parts[min_id + 1]] + parts[min_id + 2:]

        return parts

    def _create_vocab_and_merges(self):
        """Build `vocab` (id -> bytes) and `merges` ((left_id, right_id) -> id) from tiktoken's ranks."""
        self.vocab = {v:k for k, v in self.enc._mergeable_ranks.items()}
        self.vocab_size = len(self.vocab)

        self.merges = {}
        for token, idx in self.enc._mergeable_ranks.items():
            if len(token) < 2:
                continue  # single bytes are base tokens, not merges
            parts = self._get_split_tokens(self.enc._mergeable_ranks, token, idx)
            assert len(parts) == 2
            self.merges[(self.enc._mergeable_ranks[parts[0]], self.enc._mergeable_ranks[parts[1]])] = idx


    def train(self, text, vocab_size, verbose=False, verbose_iters=None):
        """Not supported: the vocabulary comes pre-trained from tiktoken."""
        raise NotImplementedError

    def encode(self, text):
        """Encode `text` to cl100k_base ids (no special tokens).

        Regex-split the text, map each raw byte through `byte_shuffle`, then
        apply merges within each chunk (lowest merge id first). Prints the total
        number of merge steps applied.
        """
        split_text = GPT4_PATTERN_REGEX.findall(text)
        tokens_list = [list(t.encode("utf-8")) for t in split_text]
        # now shuffle the byte tokens (raw byte -> tiktoken's single-byte token id)
        shuffled_tokens_list = [[self.byte_shuffle[b] for b in tokens] for tokens in tokens_list]
        encoded_tokens_list = []

        num_merges = 0

        for tokens in shuffled_tokens_list:
            while len(tokens) >= 2:
                # first get the stats of the bigrams
                pair_stats = get_stats(tokens)
                # now check if there is a pair which is merged as per our tokenizer
                merge_pair = min(pair_stats, key=lambda k: self.merges.get(k, float("inf"))) # lowest-id learned merge; if none applies, min() returns an unmerged pair (all +inf)
                # check if there is actually a match
                if merge_pair not in self.merges:
                    break

                # now replace with the merges token
                num_merges += 1
                tokens = replace_pair(tokens, merge_pair, self.merges[merge_pair])

            encoded_tokens_list.append(tokens)

        print(f"Number of merges: {num_merges}")
        final_tokens = [item for sublist in encoded_tokens_list for item in sublist]

        return final_tokens


    def decode(self, ids):
        """Decode cl100k_base ids back to text; invalid UTF-8 is replaced with U+FFFD."""
        enc_text = b"".join(self.vocab[id] for id in ids)
        text = enc_text.decode("utf-8", errors="replace")
        return text

In [279]:
# Construction reconstructs every merge from tiktoken's ~100k-entry rank table.
gpt4_tokenizer = GPT4Tokenizer()

In [267]:
# Independent reference implementation used to cross-check the vocab and merges
# built by GPT4Tokenizer.
enc = tiktoken.get_encoding("cl100k_base")
vocab = {rank: token_bytes for token_bytes, rank in enc._mergeable_ranks.items()}

# Ref: https://github.com/openai/tiktoken/issues/60#issuecomment-1499977960

def bpe(mergeable_ranks, token, max_rank = None):
    """Run byte-level BPE on `token` using tiktoken-style ranks.

    Starting from single bytes, repeatedly merge the adjacent pair whose
    concatenation has the lowest rank in `mergeable_ranks` (the first one wins
    ties). Stop when no pair is mergeable or, if `max_rank` is given, when the
    best available rank is >= `max_rank`.

    Args:
        mergeable_ranks: dict mapping token bytes -> rank.
        token: bytes to split.
        max_rank: optional exclusive upper bound on ranks allowed to merge.

    Returns:
        List of bytes parts whose concatenation equals `token`.
    """
    parts = [bytes([byte]) for byte in token]
    while True:
        best = None  # (rank, position) of the best mergeable adjacent pair so far
        for pos in range(len(parts) - 1):
            candidate = mergeable_ranks.get(parts[pos] + parts[pos + 1])
            if candidate is None:
                continue
            if best is None or candidate < best[0]:
                best = (candidate, pos)
        if best is None:
            break  # nothing left to merge
        best_rank, best_pos = best
        if max_rank is not None and best_rank >= max_rank:
            break  # only merges ranked below max_rank are allowed
        parts[best_pos:best_pos + 2] = [parts[best_pos] + parts[best_pos + 1]]
    return parts

# Recover each merge: replaying BPE on a token with only lower-ranked merges
# must leave exactly two parts, which form the pair that produced the token.
merges = {}

for token_bytes, rank in enc._mergeable_ranks.items():
    if len(token_bytes) < 2:
        continue  # single bytes are base tokens, not merges
    halves = bpe(enc._mergeable_ranks, token_bytes, rank)
    assert len(halves) == 2
    left, right = halves
    merges[(enc._mergeable_ranks[left], enc._mergeable_ranks[right])] = rank


In [268]:
# Our reconstruction must match the reference implementation exactly.
assert vocab == gpt4_tokenizer.vocab
assert merges == gpt4_tokenizer.merges

In [269]:
# Dump the reconstructed GPT-4 merges in rank order, with each side rendered as text.
# Explicit UTF-8: pieces include non-ASCII text and U+FFFD that a non-UTF-8
# default locale encoding cannot write.
with open("gpt4_merges.txt", "w", encoding="utf-8") as file:
    for (left, right), idx in sorted(gpt4_tokenizer.merges.items(), key=lambda kv: kv[1]):
        left_str = gpt4_tokenizer.vocab[left].decode("utf-8", errors="replace")
        right_str = gpt4_tokenizer.vocab[right].decode("utf-8", errors="replace")
        file.write(f"[{left_str}][{right_str}]  ---->   {idx}\n")

In [276]:
# Round-trip check; stop at the first failure. An explicit comparison is used
# instead of `assert`, which is stripped under `python -O`.
for i, test_string in enumerate(test_strings):
    test_text = unpack(test_string)
    if gpt4_tokenizer.decode(gpt4_tokenizer.encode(test_text)) == test_text:
        print(f"Test string: {i} Passed! :)")
    else:
        print(f"Test string: {i} Failed! :(")
        break

0it [00:00, ?it/s]


Number of merges: 0
Test string: 0 Passed! :)


100%|██████████| 1/1 [00:00<00:00, 29330.80it/s]


Number of merges: 0
Test string: 1 Passed! :)


100%|██████████| 9/9 [00:00<00:00, 37635.83it/s]


Number of merges: 32
Test string: 2 Passed! :)


100%|██████████| 139/139 [00:00<00:00, 121612.07it/s]


Number of merges: 445
Test string: 3 Passed! :)


100%|██████████| 46195/46195 [00:00<00:00, 251278.24it/s]

Number of merges: 136282
Test string: 4 Passed! :)





In [277]:
# Our ids must equal tiktoken's own encoder output; stop at the first mismatch.
# An explicit comparison is used instead of `assert`, which is stripped under `python -O`.
for i, test_string in enumerate(test_strings):
    test_text = unpack(test_string)
    if enc.encode(test_text) == gpt4_tokenizer.encode(test_text):
        print(f"Test string: {i} Passed! :)")
    else:
        print(f"Test string: {i} Failed! :(")
        break

0it [00:00, ?it/s]


Number of merges: 0
Test string: 0 Passed! :)


100%|██████████| 1/1 [00:00<00:00, 26214.40it/s]


Number of merges: 0
Test string: 1 Passed! :)


100%|██████████| 9/9 [00:00<00:00, 18513.36it/s]


Number of merges: 32
Test string: 2 Passed! :)


100%|██████████| 139/139 [00:00<00:00, 100362.93it/s]


Number of merges: 445
Test string: 3 Passed! :)


100%|██████████| 46195/46195 [00:00<00:00, 214573.11it/s]

Number of merges: 136282
Test string: 4 Passed! :)





### Ability to handle special tokens

In [281]:
# Without special-token handling, "<|endoftext|>" is encoded as ordinary text pieces.
test_text = "<|endoftext|>hello world"
[gpt4_tokenizer.vocab[i] for i in gpt4_tokenizer.encode(test_text)]

Number of merges: 15


[b'<', b'|', b'endo', b'ft', b'ext', b'|', b'>', b'hello', b' world']

In [284]:
# The same encoding, shown as raw ids.
gpt4_tokenizer.encode(test_text)

Number of merges: 15


[27, 91, 8862, 728, 428, 91, 29, 15339, 1917]

In [283]:
# Reference: with special tokens allowed, tiktoken maps "<|endoftext|>" to a single id.
ids = enc.encode("<|endoftext|>hello world", allowed_special="all")
ids

[100257, 15339, 1917]

In [288]:
# Special-token strings used with cl100k_base: end-of-text, fill-in-the-middle
# markers, and end-of-prompt.
ENDOFTEXT = "<|endoftext|>"
FIM_PREFIX = "<|fim_prefix|>"
FIM_MIDDLE = "<|fim_middle|>"
FIM_SUFFIX = "<|fim_suffix|>"
ENDOFPROMPT = "<|endofprompt|>"
SPECIAL_TOKENS = [ENDOFTEXT, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, ENDOFPROMPT]

In [292]:
# Define the demo string explicitly so this cell (and the split below) does not
# depend on a value left in the kernel by some other cell.
text = "Hello <|endoftext|> world!"
text

'Hello <|endoftext|> world!'

In [293]:
# Split around special tokens. The capturing group makes re.split keep the special
# tokens themselves as separate items in the result.
special_pattern = "(" + "|".join(re.escape(k) for k in SPECIAL_TOKENS) + ")"
special_chunks = re.split(special_pattern, text)
special_chunks

['Hello ', '<|endoftext|>', ' world!']

### Now let's handle special tokens

In [345]:
class GPT4Tokenizer:
    """GPT-4 (cl100k_base) tokenizer built on our own BPE loop over tiktoken's
    ranks, with optional special-token handling.

    Args:
        special_token_map: mapping special-token string -> token id
            (e.g. {"<|endoftext|>": 100257}). Defaults to no special tokens.
    """
    def __init__(self, special_token_map=None):
        self.tokenizer_path = "cl100k_base"
        self.enc = tiktoken.get_encoding(self.tokenizer_path)
        self._create_vocab_and_merges()
        # Copy, so instances never share a mutable default or the caller's dict.
        self.special_token_map = dict(special_token_map) if special_token_map else {}
        self.inv_special_token_map = {v: k for k, v in self.special_token_map.items()}
        # tiktoken's single-byte tokens are not ranked in byte order, so raw byte i is
        # first mapped to the id (rank) of the token bytes([i]).
        self.byte_shuffle = {i: self.enc._mergeable_ranks[bytes([i])] for i in range(256)}
        self.inverse_byte_shuffle = {v: k for k, v in self.byte_shuffle.items()}

    def _get_split_tokens(self, mergeable_ranks, token, max_rank):
        """Replay BPE on `token` using only merges ranked below `max_rank`.

        For a token of rank `max_rank` this stops at exactly the two parts that
        were merged to create it.

        Returns:
            List of bytes parts whose concatenation equals `token`.
        """
        parts = [bytes([t]) for t in token]
        while True:
            min_id, min_rank = None, None  # position / rank of the best mergeable pair
            for i, pair in enumerate(zip(parts, parts[1:])):
                rank = mergeable_ranks.get(pair[0] + pair[1])
                if (rank is not None) and (min_rank is None or rank < min_rank):
                    min_id = i
                    min_rank = rank

            if (min_rank is None) or (min_rank >= max_rank):  # no further merge allowed
                break
            assert min_id is not None
            parts = parts[:min_id] + [parts[min_id] + parts[min_id + 1]] + parts[min_id + 2:]

        return parts

    def _create_vocab_and_merges(self):
        """Build `vocab` (id -> bytes) and `merges` ((left_id, right_id) -> id) from tiktoken's ranks."""
        self.vocab = {v: k for k, v in self.enc._mergeable_ranks.items()}
        self.vocab_size = len(self.vocab)

        self.merges = {}
        for token, idx in self.enc._mergeable_ranks.items():
            if len(token) < 2:
                continue  # single bytes are base tokens, not merges
            parts = self._get_split_tokens(self.enc._mergeable_ranks, token, idx)
            assert len(parts) == 2
            self.merges[(self.enc._mergeable_ranks[parts[0]], self.enc._mergeable_ranks[parts[1]])] = idx

    def train(self, text, vocab_size, verbose=False, verbose_iters=None):
        """Not supported: the vocabulary comes pre-trained from tiktoken."""
        raise NotImplementedError

    def encode_ordinary(self, text):
        """Encode `text`, treating any special-token strings as ordinary text.

        Returns:
            (token_ids, num_merges), where num_merges counts merge steps applied.
        """
        final_tokens = []
        num_merges = 0
        for chunk in GPT4_PATTERN_REGEX.findall(text):
            # Raw bytes -> tiktoken's (permuted) single-byte token ids.
            tokens = [self.byte_shuffle[b] for b in chunk.encode("utf-8")]
            while len(tokens) >= 2:
                pair_stats = get_stats(tokens)
                # Unlearned pairs map to +inf, so min() prefers the lowest-id learned merge.
                merge_pair = min(pair_stats, key=lambda k: self.merges.get(k, float("inf")))
                if merge_pair not in self.merges:
                    break
                tokens = replace_pair(tokens, merge_pair, self.merges[merge_pair])
                num_merges += 1
            final_tokens.extend(tokens)

        return final_tokens, num_merges

    def encode(self, text, allowed_special="none"):
        """Encode `text`, optionally recognising special tokens.

        Args:
            text: string to encode.
            allowed_special: "none" (special-token strings are ordinary text),
                "all" (every token in `special_token_map`), or a set/frozenset/
                list/tuple of special-token strings to recognise.

        Returns:
            List of token ids. Also prints the total number of merge steps applied.

        Raises:
            NotImplementedError: for any other `allowed_special` value.
        """
        if isinstance(allowed_special, str):
            if allowed_special == "none":
                special = {}
            elif allowed_special == "all":
                special = self.special_token_map
            else:
                raise NotImplementedError
        elif isinstance(allowed_special, (set, frozenset, list, tuple)):
            special = {k: v for k, v in self.special_token_map.items() if k in allowed_special}
        else:
            raise NotImplementedError

        if not special:
            # Also covers "all" with an empty map: splitting on the empty pattern "()"
            # would cut between every character and prevent all merges.
            tokenized_text, num_merges = self.encode_ordinary(text)
            print(f"Number of merges = {num_merges}")
            return tokenized_text

        # Longest first so a special token that is a prefix of another cannot shadow it.
        ordered = sorted(special, key=len, reverse=True)
        special_pattern = "(" + "|".join(re.escape(k) for k in ordered) + ")"

        final_tokens = []
        num_merges = 0
        # The capturing group keeps the special tokens themselves in the split result.
        for chunk in re.split(special_pattern, text):
            if chunk in special:
                final_tokens.append(special[chunk])
            else:
                chunk_ids, chunk_merges = self.encode_ordinary(chunk)
                final_tokens.extend(chunk_ids)
                num_merges += chunk_merges

        print(f"Number of merges = {num_merges}")
        return final_tokens

    def decode(self, ids):
        """Decode regular and special token ids back to text.

        Invalid UTF-8 is replaced with U+FFFD.

        Raises:
            ValueError: if an id is neither a regular token nor a registered special token.
        """
        parts = []
        for idx in ids:
            if idx in self.vocab:
                parts.append(self.vocab[idx])
            elif idx in self.inv_special_token_map:
                parts.append(self.inv_special_token_map[idx].encode("utf-8"))
            else:
                # Fail loudly instead of silently dropping the id.
                raise ValueError(f"invalid token id: {idx}")
        return b"".join(parts).decode("utf-8", errors="replace")

In [346]:
# Special-token ids for cl100k_base, chosen to match tiktoken (verified against
# enc.encode further below). Note the gap before <|endofprompt|>.
gpt4_tokenizer_special = GPT4Tokenizer(special_token_map={
    ENDOFTEXT: 100257,
    FIM_PREFIX: 100258,
    FIM_MIDDLE: 100259,
    FIM_SUFFIX: 100260,
    ENDOFPROMPT: 100276,
})

In [347]:
# Round trip with special tokens enabled: the special token survives as one id.
gpt4_tokenizer_special.decode(gpt4_tokenizer_special.encode("Hello <|endoftext|> world!", allowed_special="all"))

Number of merges = 9


'Hello <|endoftext|> world!'

In [348]:
# Multi-document sample exercising every special token, including adjacent ones.
specials_string = """
<|endoftext|>Hello world this is one document
<|endoftext|>And this is another document
<|endoftext|><|fim_prefix|>And this one has<|fim_suffix|> tokens.<|fim_middle|> FIM
<|endoftext|>Last document!!! 👋<|endofprompt|>
""".strip()
specials_string

'<|endoftext|>Hello world this is one document\n<|endoftext|>And this is another document\n<|endoftext|><|fim_prefix|>And this one has<|fim_suffix|> tokens.<|fim_middle|> FIM\n<|endoftext|>Last document!!! 👋<|endofprompt|>'

In [350]:
# With all special tokens enabled, the round trip must reproduce the string exactly.
assert specials_string == gpt4_tokenizer_special.decode(gpt4_tokenizer_special.encode(specials_string, allowed_special="all"))

Number of merges = 85


In [351]:
# Our ids must match tiktoken's for the same input and allowed_special setting.
assert enc.encode(specials_string, allowed_special="all") == gpt4_tokenizer_special.encode(specials_string, allowed_special="all")

Number of merges = 85


In [353]:
# Ordinary strings must still round-trip with special-token handling disabled.
# Stop at the first failure; an explicit comparison is used instead of `assert`,
# which is stripped under `python -O`.
for i, test_string in enumerate(test_strings):
    test_text = unpack(test_string)
    decoded = gpt4_tokenizer_special.decode(gpt4_tokenizer_special.encode(test_text, allowed_special="none"))
    if decoded == test_text:
        print(f"Test string: {i} Passed! :)")
    else:
        print(f"Test string: {i} Failed! :(")
        break

Number of merges = 0
Test string: 0 Passed! :)
Number of merges = 0
Test string: 1 Passed! :)
Number of merges = 32
Test string: 2 Passed! :)
Number of merges = 445
Test string: 3 Passed! :)
Number of merges = 136282
Test string: 4 Passed! :)
