
# GPT Tokenization Lecture
- Andrej Karpathy's video @ https://youtu.be/zduSFxRajkE?si=lB2uuYe0Rg1K6FgN

In [1]:
string = "안녕하세요 👋 (hello in Korean!)"

In Python, strings are immutable sequences of Unicode code points; `ord` returns the integer code point of each character.

In [2]:
[ord(x) for x in string]

[50504,
 45397,
 54616,
 49464,
 50836,
 32,
 128075,
 32,
 40,
 104,
 101,
 108,
 108,
 111,
 32,
 105,
 110,
 32,
 75,
 111,
 114,
 101,
 97,
 110,
 33,
 41]

UTF-8: encodes every code point as a sequence of 1 to 4 bytes.
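As a quick illustration of the 1 to 4 byte range, the sketch below encodes one character of each length. The characters are chosen for illustration; 안 and 👋 are the same code points (50504 and 128075) that appear in the `ord` output above.

```python
# One sample character per UTF-8 length: a, é, 안, 👋
samples = ["a", "\u00e9", "\uc548", "\U0001F44B"]

for ch in samples:
    encoded = ch.encode("utf-8")
    # print the code point, how many bytes UTF-8 uses, and the byte values
    print(f"U+{ord(ch):04X} -> {len(encoded)} byte(s): {list(encoded)}")
```

The 3-byte output for 안, `[236, 149, 136]`, matches the first three bytes of `list(string.encode("utf-8"))` above.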

In [3]:
string.encode("utf-8")

b'\xec\x95\x88\xeb\x85\x95\xed\x95\x98\xec\x84\xb8\xec\x9a\x94 \xf0\x9f\x91\x8b (hello in Korean!)'

In [4]:
list(string.encode("utf-8"))

[236,
 149,
 136,
 235,
 133,
 149,
 237,
 149,
 152,
 236,
 132,
 184,
 236,
 154,
 148,
 32,
 240,
 159,
 145,
 139,
 32,
 40,
 104,
 101,
 108,
 108,
 111,
 32,
 105,
 110,
 32,
 75,
 111,
 114,
 101,
 97,
 110,
 33,
 41]

Byte Pair Encoding (BPE): compresses a byte sequence by repeatedly merging the most frequent pair of adjacent tokens into a new token.

- Find the pair of adjacent tokens that occurs most frequently.
- Replace every occurrence of that pair with a new token, and append the new token to the vocabulary.
- For example, "aaabdaaabac" becomes "ZabdZabac" (Z = aa), which then becomes "ZYdZYac" (Y = ab).
- One more merge (X = ZY) gives "XdXac".
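The merge loop above can be traced directly on characters. This is a minimal sketch that uses single letters as the new symbols, so it is easy to compare with the example. The helper names are hypothetical.

```python
def most_frequent_pair(symbols):
    counts = {}
    for pair in zip(symbols, symbols[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    # max() returns the first key with the highest count, so ties go to
    # the pair that first appears earliest in the sequence
    return max(counts, key=counts.get)

def merge_pair(symbols, pair, new_symbol):
    # Replace non-overlapping occurrences of pair, scanning left to right.
    out, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            out.append(new_symbol)
            i += 2
        else:
            out.append(symbols[i])
            i += 1
    return out

def bpe_trace(text, new_symbols="ZYX"):
    # Return the string after each merge, starting with the original text.
    symbols = list(text)
    steps = ["".join(symbols)]
    for new_symbol in new_symbols:
        pair = most_frequent_pair(symbols)
        symbols = merge_pair(symbols, pair, new_symbol)
        steps.append("".join(symbols))
    return steps

print(bpe_trace("aaabdaaabac"))
```

In the second step, "Za" and "ab" both occur twice. This tie-breaking picks "Za" (so Y = Za), which gives "YbdYbac" instead of "ZYdZYac". The final string is still "XdXac", where X stands for "aaab" either way.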

In [5]:
text = "My name is Yoshikage Kira. I'm 33 years old. My house is in the northeast section of Morioh, where all the villas are, and I am not married. I work as an employee for the Kame Yu department stores, and I get home every day by 8 PM at the latest. I don't smoke, but I occasionally drink. I'm in bed by 11 PM, and make sure I get eight hours of sleep, no matter what. After having a glass of warm milk and doing about twenty minutes of stretches before going to bed, I usually have no problems sleeping until morning. Just like a baby, I wake up without any fatigue or stress in the morning. I was told there were no issues at my last check-up. I'm trying to explain that I'm a person who wishes to live a very quiet life. I take care not to trouble myself with any enemies, like winning and losing, that would cause me to lose sleep at night. That is how I deal with society, and I know that is what brings me happiness. Although, if I were to fight I wouldn't lose to anyone."
tokens = list(map(int, text.encode("utf-8")))
print(tokens)
print(len(tokens))

[77, 121, 32, 110, 97, 109, 101, 32, 105, 115, 32, 89, 111, 115, 104, 105, 107, 97, 103, 101, 32, 75, 105, 114, 97, 46, 32, 73, 39, 109, 32, 51, 51, 32, 121, 101, 97, 114, 115, 32, 111, 108, 100, 46, 32, 77, 121, 32, 104, 111, 117, 115, 101, 32, 105, 115, 32, 105, 110, 32, 116, 104, 101, 32, 110, 111, 114, 116, 104, 101, 97, 115, 116, 32, 115, 101, 99, 116, 105, 111, 110, 32, 111, 102, 32, 77, 111, 114, 105, 111, 104, 44, 32, 119, 104, 101, 114, 101, 32, 97, 108, 108, 32, 116, 104, 101, 32, 118, 105, 108, 108, 97, 115, 32, 97, 114, 101, 44, 32, 97, 110, 100, 32, 73, 32, 97, 109, 32, 110, 111, 116, 32, 109, 97, 114, 114, 105, 101, 100, 46, 32, 73, 32, 119, 111, 114, 107, 32, 97, 115, 32, 97, 110, 32, 101, 109, 112, 108, 111, 121, 101, 101, 32, 102, 111, 114, 32, 116, 104, 101, 32, 75, 97, 109, 101, 32, 89, 117, 32, 100, 101, 112, 97, 114, 116, 109, 101, 110, 116, 32, 115, 116, 111, 114, 101, 115, 44, 32, 97, 110, 100, 32, 73, 32, 103, 101, 116, 32, 104, 111, 109, 101, 32, 101, 118, 101,

In [6]:
def get_pairs(tokens):
  counts = {}
  for p in zip(tokens, tokens[1:]):
    counts[p] = counts.get(p, 0) + 1
  return counts

pairs = get_pairs(tokens)
print(sorted(((v, k) for k, v in pairs.items()), reverse=True))

[(32, (101, 32)), (24, (116, 32)), (22, (32, 97)), (21, (32, 116)), (18, (105, 110)), (18, (32, 73)), (17, (115, 32)), (17, (32, 119)), (14, (116, 104)), (14, (73, 32)), (13, (121, 32)), (12, (97, 116)), (12, (46, 32)), (12, (44, 32)), (11, (114, 101)), (11, (111, 32)), (11, (104, 101)), (11, (32, 109)), (10, (110, 103)), (10, (97, 110)), (10, (32, 115)), (9, (116, 111)), (9, (111, 114)), (9, (104, 97)), (9, (101, 115)), (9, (101, 114)), (9, (100, 32)), (9, (32, 108)), (9, (32, 105)), (8, (111, 117)), (8, (32, 110)), (8, (32, 98)), (7, (115, 116)), (7, (110, 111)), (7, (110, 32)), (7, (104, 111)), (7, (97, 115)), (7, (32, 111)), (7, (32, 104)), (6, (115, 101)), (6, (110, 100)), (6, (109, 101)), (6, (109, 32)), (6, (107, 101)), (6, (105, 115)), (6, (103, 32)), (6, (102, 32)), (6, (97, 114)), (6, (32, 100)), (5, (119, 105)), (5, (111, 110)), (5, (108, 101)), (5, (108, 97)), (5, (101, 116)), (5, (32, 101)), (4, (119, 104)), (4, (118, 101)), (4, (117, 116)), (4, (117, 115)), (4, (116, 114)

In [7]:
top_pair = max(pairs, key=pairs.get)
top_pair

(101, 32)

In [8]:
def merge(tokens, pair, newtoken):
  # In the token list, replace all pairs with newtoken.
  newtokens = []
  i = 0
  while i < len(tokens):
    if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i+1] == pair[1]:
      newtokens.append(newtoken)
      i += 2
    else:
      newtokens.append(tokens[i])
      i += 1
  return newtokens

print(merge(tokens, (101, 32), 256))

[77, 121, 32, 110, 97, 109, 256, 105, 115, 32, 89, 111, 115, 104, 105, 107, 97, 103, 256, 75, 105, 114, 97, 46, 32, 73, 39, 109, 32, 51, 51, 32, 121, 101, 97, 114, 115, 32, 111, 108, 100, 46, 32, 77, 121, 32, 104, 111, 117, 115, 256, 105, 115, 32, 105, 110, 32, 116, 104, 256, 110, 111, 114, 116, 104, 101, 97, 115, 116, 32, 115, 101, 99, 116, 105, 111, 110, 32, 111, 102, 32, 77, 111, 114, 105, 111, 104, 44, 32, 119, 104, 101, 114, 256, 97, 108, 108, 32, 116, 104, 256, 118, 105, 108, 108, 97, 115, 32, 97, 114, 101, 44, 32, 97, 110, 100, 32, 73, 32, 97, 109, 32, 110, 111, 116, 32, 109, 97, 114, 114, 105, 101, 100, 46, 32, 73, 32, 119, 111, 114, 107, 32, 97, 115, 32, 97, 110, 32, 101, 109, 112, 108, 111, 121, 101, 256, 102, 111, 114, 32, 116, 104, 256, 75, 97, 109, 256, 89, 117, 32, 100, 101, 112, 97, 114, 116, 109, 101, 110, 116, 32, 115, 116, 111, 114, 101, 115, 44, 32, 97, 110, 100, 32, 73, 32, 103, 101, 116, 32, 104, 111, 109, 256, 101, 118, 101, 114, 121, 32, 100, 97, 121, 32, 98, 121

In [9]:
len(merge(tokens, (101, 32), 256))

943

In [11]:
# Assuming we do 20 merges
vocab_size = 276
num_merges = vocab_size - 256
temp = list(tokens)
merges = {}
for i in range(num_merges):
  newtoken = 256 + i
  pairs = get_pairs(temp)
  pair = max(pairs, key=pairs.get)
  print(f"Merge {pair} into token {newtoken}")
  temp = merge(temp, pair, newtoken)
  merges[pair] = newtoken

Merge (101, 32) into token 256
Merge (116, 32) into token 257
Merge (105, 110) into token 258
Merge (32, 97) into token 259
Merge (32, 116) into token 260
Merge (32, 73) into token 261
Merge (32, 119) into token 262
Merge (121, 32) into token 263
Merge (115, 32) into token 264
Merge (258, 103) into token 265
Merge (111, 114) into token 266
Merge (101, 114) into token 267
Merge (259, 110) into token 268
Merge (104, 97) into token 269
Merge (111, 117) into token 270
Merge (111, 32) into token 271
Merge (46, 261) into token 272
Merge (101, 115) into token 273
Merge (116, 104) into token 274
Merge (97, 115) into token 275


In [13]:
print("Original token length:", len(tokens))
print("Compressed token length:", len(temp))
cr = len(tokens) / len(temp)
print("Compression ratio:", cr)

Original token length: 975
Compressed token length: 728
Compression ratio: 1.3392857142857142


Decoding: getting the original string given a list of tokens.

In [17]:
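# Build id -> bytes. Dicts keep insertion order, so both children of a merged
# token are already in vocab by the time the merged token is added.
vocab = {x: bytes([x]) for x in range(256)}
for (p0, p1), x in merges.items():
  vocab[x] = vocab[p0] + vocab[p1]

def decode(ids):
  tokens = b"".join(vocab[x] for x in ids)
  # Not every byte sequence is valid UTF-8; replace invalid bytes with U+FFFD.
  text = tokens.decode("utf-8", errors="replace")
  return text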
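Why `errors="replace"`? A model can emit ids whose concatenated bytes are not valid UTF-8, for example only the first byte of a multi-byte character. The small sketch below shows the failure in strict mode and the replacement behaviour. `safe_decode` is a hypothetical helper name.

```python
def safe_decode(raw: bytes) -> str:
    # Invalid byte sequences become U+FFFD instead of raising UnicodeDecodeError.
    return raw.decode("utf-8", errors="replace")

truncated = "안".encode("utf-8")[:1]  # only the first of the 3 bytes of 안

try:
    truncated.decode("utf-8")  # the default (strict) error handler raises
except UnicodeDecodeError as e:
    print("strict decode failed:", e.reason)

print(safe_decode(truncated) == "\ufffd")  # True
```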

Encoding: getting a list of tokens given a string.

In [20]:
def encode(text):
  tokens = list(text.encode("utf-8"))
  while len(tokens) > 1:
    pairs = get_pairs(tokens)
    # Choose the pair whose merge was learned earliest (lowest new id);
    # pairs that were never merged get infinity.
    pair = min(pairs, key=lambda p: merges.get(p, float("inf")))
    if pair not in merges:
      break  # no remaining pair can be merged
    idx = merges[pair]
    tokens = merge(tokens, pair, idx)
  return tokens
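To check that encode and decode are inverses, here is a self-contained sketch that trains a few merges, then encodes and decodes. Every name is prefixed with `bpe_` or `toy_` (hypothetical names) so that running it does not overwrite the notebook's `merges`, `encode`, or `decode`.

```python
def bpe_pairs(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def bpe_merge(ids, pair, new_id):
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def bpe_train(text, num_merges):
    ids = list(text.encode("utf-8"))
    merges = {}
    for i in range(num_merges):
        pairs = bpe_pairs(ids)
        if not pairs:
            break
        pair = max(pairs, key=pairs.get)
        merges[pair] = 256 + i
        ids = bpe_merge(ids, pair, 256 + i)
    return merges

def bpe_encode(text, merges):
    ids = list(text.encode("utf-8"))
    while len(ids) > 1:
        pairs = bpe_pairs(ids)
        # earliest-learned merge first; unknown pairs rank as infinity
        pair = min(pairs, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        ids = bpe_merge(ids, pair, merges[pair])
    return ids

def bpe_decode(ids, merges):
    vocab = {i: bytes([i]) for i in range(256)}
    for (a, b), new_id in merges.items():  # insertion order = training order
        vocab[new_id] = vocab[a] + vocab[b]
    return b"".join(vocab[i] for i in ids).decode("utf-8", errors="replace")

toy_merges = bpe_train("aaabdaaabac", 3)
ids = bpe_encode("aaabdaaabac", toy_merges)
print(ids)                          # [258, 100, 258, 97, 99]
print(bpe_decode(ids, toy_merges))  # aaabdaaabac
```

Text the merges never saw, such as Korean, simply stays as raw byte ids, and it still decodes back exactly.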

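Forced splits using regex patterns: GPT-2 first splits the text into chunks with the regex below (contractions, letters, numbers, punctuation, whitespace). It runs BPE within each chunk and concatenates the results, so no merge ever crosses a chunk boundary.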

In [32]:
import regex as re  # third-party `regex` module; the stdlib `re` does not support \p{L} / \p{N}
pattern = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

In [34]:
print(re.findall(pattern, "Hello how's are you'll?!?!?!"))

['Hello', ' how', "'s", ' are', ' you', "'ll", '?!?!?!']


In [36]:
!pip install tiktoken

Collecting tiktoken
  Downloading tiktoken-0.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)
Installing collected packages: tiktoken
Successfully installed tiktoken-0.6.0


In [37]:
import tiktoken
enc = tiktoken.get_encoding("gpt2")
print(enc.encode("     hello123's world!?!?!"))
enc = tiktoken.get_encoding("cl100k_base")
print(enc.encode("     hello123's world!?!?!"))

[220, 220, 220, 220, 23748, 10163, 338, 995, 0, 12248, 12248]
[257, 24748, 4513, 596, 1917, 0, 27074, 27074]


In [38]:
!wget https://openaipublic.blob.core.windows.net/gpt-2/models/1558M/vocab.bpe
!wget https://openaipublic.blob.core.windows.net/gpt-2/models/1558M/encoder.json

--2024-02-24 04:29:34--  https://openaipublic.blob.core.windows.net/gpt-2/models/1558M/vocab.bpe
Resolving openaipublic.blob.core.windows.net (openaipublic.blob.core.windows.net)... 20.60.179.33
Connecting to openaipublic.blob.core.windows.net (openaipublic.blob.core.windows.net)|20.60.179.33|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 456318 (446K) [application/octet-stream]
Saving to: ‘vocab.bpe’


2024-02-24 04:29:35 (514 KB/s) - ‘vocab.bpe’ saved [456318/456318]

--2024-02-24 04:29:35--  https://openaipublic.blob.core.windows.net/gpt-2/models/1558M/encoder.json
Resolving openaipublic.blob.core.windows.net (openaipublic.blob.core.windows.net)... 20.60.179.33
Connecting to openaipublic.blob.core.windows.net (openaipublic.blob.core.windows.net)|20.60.179.33|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1042301 (1018K) [application/json]
Saving to: ‘encoder.json’


2024-02-24 04:29:37 (911 KB/s) - ‘encoder.json’ saved [1042301/1

In [40]:
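import os, json
# encoder.json maps each token (as a string) to its integer id, roughly the inverse of our `vocab`
with open("encoder.json", "r", encoding="utf-8") as f:
  encoder = json.load(f)

# vocab.bpe lists the merges in training order (our `merges`);
# [1:-1] skips the version header line and the trailing empty line
with open("vocab.bpe", "r", encoding="utf-8") as f:
  bpe_data = f.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]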

Special Tokens

In [41]:
# 256 raw byte tokens, 50000 merges, +1 special token
len(encoder)

50257

In [42]:
encoder["<|endoftext|>"]

50256

Sentencepiece
- Can efficiently both train and run inference with BPE tokenizers. Used by both Llama and Mistral.
- tiktoken encodes text to UTF-8 and then runs BPE on the bytes.
- Sentencepiece runs BPE on the code points themselves. It optionally falls back to UTF-8 bytes for rare code points, with rarity set by the `character_coverage` hyperparameter. Those bytes are then translated into byte tokens such as `<0xEC>`.
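The byte-fallback step on its own can be illustrated with a toy sketch. It assumes a fixed set of "known" characters, whereas the real sentencepiece also applies learned merges and chooses known characters via `character_coverage`. Any code point outside the set is spelled out as one `<0xNN>` piece per UTF-8 byte, like the pieces `sp.encode` produces for Korean text later in this notebook. `byte_fallback_pieces` and `known_chars` are hypothetical names.

```python
def byte_fallback_pieces(text, known_chars):
    """Toy illustration: keep known code points, spell unknown ones as byte pieces."""
    pieces = []
    for ch in text:
        if ch in known_chars:
            pieces.append(ch)
        else:
            # one "<0xNN>" piece per UTF-8 byte of the unknown code point
            pieces.extend(f"<0x{b:02X}>" for b in ch.encode("utf-8"))
    return pieces

print(byte_fallback_pieces("hi 안", known_chars=set("hi ")))
# ['h', 'i', ' ', '<0xEC>', '<0x95>', '<0x88>']
```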

In [43]:
import sentencepiece as spm

In [44]:
with open("toy.txt", "w", encoding="utf-8") as f:
  f.write("My name is Yoshikage Kira. I'm 33 years old. My house is in the northeast section of Morioh, where all the villas are, and I am not married. I work as an employee for the Kame Yu department stores, and I get home every day by 8 PM at the latest. I don't smoke, but I occasionally drink. I'm in bed by 11 PM, and make sure I get eight hours of sleep, no matter what. After having a glass of warm milk and doing about twenty minutes of stretches before going to bed, I usually have no problems sleeping until morning. Just like a baby, I wake up without any fatigue or stress in the morning. I was told there were no issues at my last check-up. I'm trying to explain that I'm a person who wishes to live a very quiet life. I take care not to trouble myself with any enemies, like winning and losing, that would cause me to lose sleep at night. That is how I deal with society, and I know that is what brings me happiness. Although, if I were to fight I wouldn't lose to anyone.")

In [45]:
options = dict(
  # input spec
  input="toy.txt",
  input_format="text",
  # output spec
  model_prefix="tok400", # output filename prefix
  # algorithm spec
  # BPE alg
  model_type="bpe",
  vocab_size=400,
  # normalization
  normalization_rule_name="identity", # ew, turn off normalization
  remove_extra_whitespaces=False,
  input_sentence_size=200000000, # max number of training sentences
  max_sentence_length=4192, # max number of bytes per sentence
  seed_sentencepiece_size=1000000,
  shuffle_input_sentence=True,
  # rare word treatment
  character_coverage=0.99995,
  byte_fallback=True,
  # merge rules
  split_digits=True,
  split_by_unicode_script=True,
  split_by_whitespace=True,
  split_by_number=True,
  max_sentencepiece_length=16,
  add_dummy_prefix=True,
  allow_whitespace_only_pieces=True,
  # special tokens
  unk_id=0, # the UNK token MUST exist
  bos_id=1, # the others are optional, set to -1 to turn off
  eos_id=2,
  pad_id=-1,
  # systems
  num_threads=os.cpu_count(), # use ~all system resources
)

spm.SentencePieceTrainer.train(**options)

In [46]:
sp = spm.SentencePieceProcessor()
sp.load('tok400.model')
vocab = [[sp.id_to_piece(idx), idx] for idx in range(sp.get_piece_size())]
vocab

[['<unk>', 0],
 ['<s>', 1],
 ['</s>', 2],
 ['<0x00>', 3],
 ['<0x01>', 4],
 ['<0x02>', 5],
 ['<0x03>', 6],
 ['<0x04>', 7],
 ['<0x05>', 8],
 ['<0x06>', 9],
 ['<0x07>', 10],
 ['<0x08>', 11],
 ['<0x09>', 12],
 ['<0x0A>', 13],
 ['<0x0B>', 14],
 ['<0x0C>', 15],
 ['<0x0D>', 16],
 ['<0x0E>', 17],
 ['<0x0F>', 18],
 ['<0x10>', 19],
 ['<0x11>', 20],
 ['<0x12>', 21],
 ['<0x13>', 22],
 ['<0x14>', 23],
 ['<0x15>', 24],
 ['<0x16>', 25],
 ['<0x17>', 26],
 ['<0x18>', 27],
 ['<0x19>', 28],
 ['<0x1A>', 29],
 ['<0x1B>', 30],
 ['<0x1C>', 31],
 ['<0x1D>', 32],
 ['<0x1E>', 33],
 ['<0x1F>', 34],
 ['<0x20>', 35],
 ['<0x21>', 36],
 ['<0x22>', 37],
 ['<0x23>', 38],
 ['<0x24>', 39],
 ['<0x25>', 40],
 ['<0x26>', 41],
 ['<0x27>', 42],
 ['<0x28>', 43],
 ['<0x29>', 44],
 ['<0x2A>', 45],
 ['<0x2B>', 46],
 ['<0x2C>', 47],
 ['<0x2D>', 48],
 ['<0x2E>', 49],
 ['<0x2F>', 50],
 ['<0x30>', 51],
 ['<0x31>', 52],
 ['<0x32>', 53],
 ['<0x33>', 54],
 ['<0x34>', 55],
 ['<0x35>', 56],
 ['<0x36>', 57],
 ['<0x37>', 58],
 ['<0x38>', 5

In [47]:
ids = sp.encode("hello 안녕하세요")
print(ids)
print([sp.id_to_piece(idx) for idx in ids])

[360, 264, 299, 364, 360, 239, 152, 139, 238, 136, 152, 240, 152, 155, 239, 135, 187, 239, 157, 151]
['▁', 'he', 'll', 'o', '▁', '<0xEC>', '<0x95>', '<0x88>', '<0xEB>', '<0x85>', '<0x95>', '<0xED>', '<0x95>', '<0x98>', '<0xEC>', '<0x84>', '<0xB8>', '<0xEC>', '<0x9A>', '<0x94>']


## MinBPE Exercise



### Step 1

Write the `BasicTokenizer` class, with the following three core functions:

- `def train(self, text, vocab_size, verbose=False)`
- `def encode(self, text)`
- `def decode(self, ids)`

Train your tokenizer on whatever text you like and visualize the merged tokens. Do they look reasonable? One default test you may wish to use is the text file `tests/taylorswift.txt`.

In [49]:
class BasicTokenizer:
  def __init__(self):
    self.merges = {}  # (int, int) -> int, in training order
    self.pattern = ""
    self.special = {}
    # id -> bytes; starts as the 256 raw bytes so decode works before training
    self.vocab = {x: bytes([x]) for x in range(256)}

  @staticmethod
  def get_pairs(tokens):
    # Count how often each adjacent pair of ids occurs.
    counts = {}
    for p in zip(tokens, tokens[1:]):
      counts[p] = counts.get(p, 0) + 1
    return counts

  @staticmethod
  def merge(tokens, pair, newtoken):
    # In the token list, replace all occurrences of pair with newtoken.
    newtokens = []
    i = 0
    while i < len(tokens):
      if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i+1] == pair[1]:
        newtokens.append(newtoken)
        i += 2
      else:
        newtokens.append(tokens[i])
        i += 1
    return newtokens

  def train(self, text, vocab_size, verbose=False):
    assert vocab_size >= 256
    num_merges = vocab_size - 256
    tokens = list(text.encode("utf-8"))
    merges = {}
    vocab = {x: bytes([x]) for x in range(256)}
    for i in range(num_merges):
      pairs = self.get_pairs(tokens)
      if not pairs:
        break  # text is too short to merge further
      pair = max(pairs, key=pairs.get)
      newtoken = 256 + i
      tokens = self.merge(tokens, pair, newtoken)
      merges[pair] = newtoken
      vocab[newtoken] = vocab[pair[0]] + vocab[pair[1]]
      if verbose:
        print(f"Merge {pair} -> {newtoken} ({vocab[newtoken]})")

    self.merges = merges
    self.vocab = vocab

  def encode(self, text):
    tokens = list(text.encode("utf-8"))
    while len(tokens) > 1:
      pairs = self.get_pairs(tokens)
      # Apply the merge learned earliest (lowest new id) among the pairs present.
      pair = min(pairs, key=lambda p: self.merges.get(p, float("inf")))
      if pair not in self.merges:
        break  # nothing left to merge
      tokens = self.merge(tokens, pair, self.merges[pair])
    return tokens

  def decode(self, ids):
    tokens = b"".join(self.vocab[x] for x in ids)
    text = tokens.decode("utf-8", errors="replace")
    return text

### Step 2

Convert your `BasicTokenizer` into a `RegexTokenizer`, which takes a regex pattern and splits the text exactly as GPT-4 would. Process the parts separately as before, then concatenate the results. Retrain your tokenizer and compare the results before and after. You should see that you will now have no tokens that go across categories (numbers, letters, punctuation, more than one whitespace). Use the GPT-4 pattern:

```
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
```
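The key change from `BasicTokenizer` is that pair counts are accumulated per chunk, so a pair that straddles two chunks is never counted. The sketch below shows this with a simplified, ASCII-only split pattern. This pattern is an assumption standing in for `GPT4_SPLIT_PATTERN`, because the stdlib `re` has no `\p{L}` support. `SIMPLE_SPLIT` and `chunk_pair_counts` are hypothetical names.

```python
import re

# Simplified stand-in for the GPT-4 split pattern (ASCII letters/digits only).
SIMPLE_SPLIT = re.compile(r"'(?:s|t|re|ve|m|ll|d)| ?[A-Za-z]+| ?[0-9]{1,3}| ?[^\sA-Za-z0-9]+|\s+")

def chunk_pair_counts(text):
    # Count adjacent byte pairs inside each chunk, never across chunks.
    counts = {}
    for chunk in SIMPLE_SPLIT.findall(text):
        ids = list(chunk.encode("utf-8"))
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
    return counts

print(SIMPLE_SPLIT.findall("dog. dog."))  # ['dog', '.', ' dog', '.']
counts = chunk_pair_counts("dog. dog.")
print((103, 46) in counts)  # False: "g" followed by "." spans two chunks
```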


In [50]:
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

In [None]:
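import regex as re  # third-party module: supports \p{L}, \p{N} and possessive quantifiers

# train/encode/decode work; save/load (needed by test_save_load) are not implemented yet.

class RegexTokenizer:
  def __init__(self, pattern=None):
    self.merges = {}  # (int, int) -> int
    self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
    self.compiled_pattern = re.compile(self.pattern)
    self.special = {}  # str -> int
    self.inverse = {}  # int -> str
    self.vocab = {x: bytes([x]) for x in range(256)}  # int -> bytes

  @staticmethod
  def get_pairs(tokens, counts=None):
    # Count adjacent pairs, optionally adding to an existing counts dict.
    counts = {} if counts is None else counts
    for p in zip(tokens, tokens[1:]):
      counts[p] = counts.get(p, 0) + 1
    return counts

  @staticmethod
  def merge(tokens, pair, newtoken):
    # In the token list, replace all occurrences of pair with newtoken.
    newtokens = []
    i = 0
    while i < len(tokens):
      if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i+1] == pair[1]:
        newtokens.append(newtoken)
        i += 2
      else:
        newtokens.append(tokens[i])
        i += 1
    return newtokens

  def register_special_tokens(self, special_tokens):
    # special_tokens: dict of str -> int, e.g. {"<|endoftext|>": 100257}
    self.special = special_tokens
    self.inverse = {v: k for k, v in special_tokens.items()}

  def train(self, text, vocab_size, verbose=False):
    assert vocab_size >= 256
    num_merges = vocab_size - 256
    # Split into chunks first; pairs are counted within chunks only.
    chunks = self.compiled_pattern.findall(text)
    tokens = [list(chunk.encode("utf-8")) for chunk in chunks]
    merges = {}
    vocab = {x: bytes([x]) for x in range(256)}
    for i in range(num_merges):
      pairs = {}
      for chunk in tokens:
        self.get_pairs(chunk, pairs)
      if not pairs:
        break  # nothing left to merge
      pair = max(pairs, key=pairs.get)
      newtoken = 256 + i
      tokens = [self.merge(chunk, pair, newtoken) for chunk in tokens]
      merges[pair] = newtoken
      vocab[newtoken] = vocab[pair[0]] + vocab[pair[1]]
      if verbose:
        print(f"Merge {pair} -> {newtoken} ({vocab[newtoken]})")

    self.merges = merges
    self.vocab = vocab

  def _encode_chunk(self, chunk_bytes):
    tokens = list(chunk_bytes)
    while len(tokens) > 1:
      pairs = self.get_pairs(tokens)
      pair = min(pairs, key=lambda p: self.merges.get(p, float("inf")))
      if pair not in self.merges:
        break
      tokens = self.merge(tokens, pair, self.merges[pair])
    return tokens

  def encode_ordinary(self, text):
    # Ignore special tokens: split with the regex and BPE each chunk separately.
    ids = []
    for chunk in self.compiled_pattern.findall(text):
      ids.extend(self._encode_chunk(chunk.encode("utf-8")))
    return ids

  def encode(self, text, allowed_special="none"):
    # "all": registered special strings become their ids; "none": encode them as text.
    if allowed_special == "none" or not self.special:
      return self.encode_ordinary(text)
    if allowed_special != "all":
      raise ValueError(f"allowed_special={allowed_special!r} is not supported")
    names = sorted(self.special, key=len, reverse=True)  # longest match first
    special_pattern = "(" + "|".join(re.escape(k) for k in names) + ")"
    ids = []
    for part in re.split(special_pattern, text):
      if part in self.special:
        ids.append(self.special[part])
      elif part:
        ids.extend(self.encode_ordinary(part))
    return ids

  def decode(self, ids):
    parts = []
    for x in ids:
      if x in self.vocab:
        parts.append(self.vocab[x])
      elif x in self.inverse:
        parts.append(self.inverse[x].encode("utf-8"))
      else:
        raise ValueError(f"invalid token id: {x}")
    return b"".join(parts).decode("utf-8", errors="replace")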

### Step 3

You're now ready to load the merges from the GPT-4 tokenizer and show that your tokenizer produces the identical results for both `encode` and `decode`, matching [tiktoken](https://github.com/openai/tiktoken).

```
# match this
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("hello world!!!? (안녕하세요!) lol123 😉")
text = enc.decode(ids) # get the same text back
```

Unfortunately, you will run into two issues:

1. It is not trivial to recover the raw merges from the GPT-4 tokenizer. You can easily recover what we call `vocab` here, and what they call and store under `enc._mergeable_ranks`. Feel free to copy paste the `recover_merges` function in `minbpe/gpt4.py`, which takes these ranks and returns the raw merges. If you wish to know how this function works, read [this](https://github.com/openai/tiktoken/issues/60) and [this](https://github.com/karpathy/minbpe/issues/11#issuecomment-1950805306). Basically, under some conditions it is enough to only store the parent nodes (and their rank) and get rid of the precise details of which children merged up to any parent.
2. The GPT-4 tokenizer, for some reason, permutes its raw bytes. It stores this permutation in the first 256 elements of the mergeable ranks, so you can recover the byte shuffle relatively simply as `byte_shuffle = {i: enc._mergeable_ranks[bytes([i])] for i in range(256)}`. In both your encode and decode, you'll have to shuffle bytes around accordingly. If you're stuck, reference the `minbpe/gpt4.py` file for hints.
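The shuffle bookkeeping can be sketched independently of tiktoken. The permutation below is a hypothetical stand-in for the real `byte_shuffle` (7 is odd, so `i -> (7*i + 3) % 256` is a bijection on 0..255). The idea is to apply the forward map to the UTF-8 bytes before running the merges in encode, and the inverse map to the concatenated token bytes at the end of decode. `shuffle_bytes` and `unshuffle_bytes` are hypothetical helper names.

```python
# Hypothetical permutation standing in for the one recovered from _mergeable_ranks.
byte_shuffle = {i: (7 * i + 3) % 256 for i in range(256)}
inverse_byte_shuffle = {v: k for k, v in byte_shuffle.items()}

def shuffle_bytes(raw: bytes) -> bytes:
    # encode side: permute the UTF-8 bytes before applying BPE merges
    return bytes(byte_shuffle[b] for b in raw)

def unshuffle_bytes(raw: bytes) -> bytes:
    # decode side: undo the permutation after joining the token bytes
    return bytes(inverse_byte_shuffle[b] for b in raw)

data = "hello 안녕".encode("utf-8")
print(unshuffle_bytes(shuffle_bytes(data)) == data)  # True
```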

### Step 4

(Optional, irritating, not obviously useful) Add the ability to handle special tokens. You'll then be able to match the output of tiktoken even when special tokens are present, e.g.:

```
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("<|endoftext|>hello world", allowed_special="all")
```

Without `allowed_special` tiktoken will error.
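One way to approach special tokens is to split the text on the special strings first. Each special string maps directly to its id, and only the text in between goes through the regular regex + BPE path. The sketch below shows only that splitting step, using the stdlib `re` and ids from the test cases below. `split_special` is a hypothetical name.

```python
import re

def split_special(text, special_tokens):
    """Return a list mixing special-token ids (int) and ordinary text spans (str)."""
    if not special_tokens:
        return [text] if text else []
    # Longest names first, so a token that is a prefix of another cannot shadow it.
    names = sorted(special_tokens, key=len, reverse=True)
    # The capturing group makes re.split keep the matched special strings.
    pattern = "(" + "|".join(re.escape(name) for name in names) + ")"
    out = []
    for part in re.split(pattern, text):
        if part in special_tokens:
            out.append(special_tokens[part])
        elif part:  # drop empty strings produced at the boundaries
            out.append(part)
    return out

specials = {"<|endoftext|>": 100257, "<|endofprompt|>": 100276}
print(split_special("<|endoftext|>hello world<|endofprompt|>", specials))
# [100257, 'hello world', 100276]
```

Each remaining `str` span would then be encoded as ordinary text.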

### Step 5

If you've made it this far, you're now a pro at LLM tokenization! Sadly, you're not quite done yet, because many LLMs outside of OpenAI (e.g. Llama, Mistral) use [sentencepiece](https://github.com/google/sentencepiece) instead. The primary difference is that sentencepiece runs BPE directly on Unicode code points instead of on UTF-8 encoded bytes. Feel free to explore sentencepiece on your own (good luck, it's not too pretty). As a stretch goal, if you really suffer from the burden of too much time, rewrite your BPE to operate on Unicode code points and match the Llama 2 tokenizer.
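As a starting point for the code-point variant, here is a minimal sketch of BPE training where each symbol is a string of code points rather than a byte id. It deliberately omits everything else sentencepiece does: whitespace handling with `▁`, `character_coverage`, and byte fallback. `train_codepoint_bpe` is a hypothetical name.

```python
def train_codepoint_bpe(text, num_merges):
    symbols = list(text)  # one symbol per Unicode code point, not per UTF-8 byte
    merges = []
    for _ in range(num_merges):
        counts = {}
        for pair in zip(symbols, symbols[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        if not counts:
            break
        best = max(counts, key=counts.get)
        merges.append(best)
        merged, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                merged.append(best[0] + best[1])  # new symbol = concatenated string
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return merges, symbols

merges_cp, symbols = train_codepoint_bpe("안녕 안녕", 1)
print(merges_cp, symbols)  # [('안', '녕')] ['안녕', ' ', '안녕']
```

A byte-level BPE would need extra merges just to reassemble the 3 UTF-8 bytes of each Korean character. Here the first merge already joins two whole characters.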

### Test Cases for Exercises

In [48]:
import pytest

In [None]:
test_strings = [
    "", # empty string
    "?", # single character
    "hello world!!!? (안녕하세요!) lol123 😉", # fun small string
    "FILE:taylorswift.txt", # FILE: is handled as a special string in unpack()
]
def unpack(text):
    # we do this because `pytest -v .` prints the arguments to console, and we don't
    # want to print the entire contents of the file, it creates a mess. So here we go.
    if text.startswith("FILE:"):
        dirname = os.path.dirname(os.path.abspath(__file__))
        taylorswift_file = os.path.join(dirname, text[5:])
        contents = open(taylorswift_file, "r", encoding="utf-8").read()
        return contents
    else:
        return text

specials_string = """
<|endoftext|>Hello world this is one document
<|endoftext|>And this is another document
<|endoftext|><|fim_prefix|>And this one has<|fim_suffix|> tokens.<|fim_middle|> FIM
<|endoftext|>Last document!!! 👋<|endofprompt|>
""".strip()
special_tokens = {
    '<|endoftext|>': 100257,
    '<|fim_prefix|>': 100258,
    '<|fim_middle|>': 100259,
    '<|fim_suffix|>': 100260,
    '<|endofprompt|>': 100276
}
llama_text = """
<|endoftext|>The llama (/ˈlɑːmə/; Spanish pronunciation: [ˈʎama] or [ˈʝama]) (Lama glama) is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era.
Llamas are social animals and live with others as a herd. Their wool is soft and contains only a small amount of lanolin.[2] Llamas can learn simple tasks after a few repetitions. When using a pack, they can carry about 25 to 30% of their body weight for 8 to 13 km (5–8 miles).[3] The name llama (in the past also spelled "lama" or "glama") was adopted by European settlers from native Peruvians.[4]
The ancestors of llamas are thought to have originated from the Great Plains of North America about 40 million years ago, and subsequently migrated to South America about three million years ago during the Great American Interchange. By the end of the last ice age (10,000–12,000 years ago), camelids were extinct in North America.[3] As of 2007, there were over seven million llamas and alpacas in South America and over 158,000 llamas and 100,000 alpacas, descended from progenitors imported late in the 20th century, in the United States and Canada.[5]
<|fim_prefix|>In Aymara mythology, llamas are important beings. The Heavenly Llama is said to drink water from the ocean and urinates as it rains.[6] According to Aymara eschatology,<|fim_suffix|> where they come from at the end of time.[6]<|fim_middle|> llamas will return to the water springs and ponds<|endofprompt|>
""".strip()

# -----------------------------------------------------------------------------
# tests

# test encode/decode identity for a few different strings
@pytest.mark.parametrize("tokenizer_factory", [BasicTokenizer, RegexTokenizer, GPT4Tokenizer])
@pytest.mark.parametrize("text", test_strings)
def test_encode_decode_identity(tokenizer_factory, text):
    text = unpack(text)
    tokenizer = tokenizer_factory()
    ids = tokenizer.encode(text)
    decoded = tokenizer.decode(ids)
    assert text == decoded

# test that our tokenizer matches the official GPT-4 tokenizer
@pytest.mark.parametrize("text", test_strings)
def test_gpt4_tiktoken_equality(text):
    text = unpack(text)
    tokenizer = GPT4Tokenizer()
    enc = tiktoken.get_encoding("cl100k_base")
    tiktoken_ids = enc.encode(text)
    gpt4_tokenizer_ids = tokenizer.encode(text)
    assert gpt4_tokenizer_ids == tiktoken_ids

# test the handling of special tokens
def test_gpt4_tiktoken_equality_special_tokens():
    tokenizer = GPT4Tokenizer()
    enc = tiktoken.get_encoding("cl100k_base")
    tiktoken_ids = enc.encode(specials_string, allowed_special="all")
    gpt4_tokenizer_ids = tokenizer.encode(specials_string, allowed_special="all")
    assert gpt4_tokenizer_ids == tiktoken_ids

# reference test to add more tests in the future
@pytest.mark.parametrize("tokenizer_factory", [BasicTokenizer, RegexTokenizer])
def test_wikipedia_example(tokenizer_factory):
    """
    Quick unit test, following along the Wikipedia example:
    https://en.wikipedia.org/wiki/Byte_pair_encoding

    According to Wikipedia, running bpe on the input string:
    "aaabdaaabac"

    for 3 merges will result in string:
    "XdXac"

    where:
    X=ZY
    Y=ab
    Z=aa

    Keep in mind that for us a=97, b=98, c=99, d=100 (ASCII values)
    so Z will be 256, Y will be 257, X will be 258.

    So we expect the output list of ids to be [258, 100, 258, 97, 99]
    """
    tokenizer = tokenizer_factory()
    text = "aaabdaaabac"
    tokenizer.train(text, 256 + 3)
    ids = tokenizer.encode(text)
    assert ids == [258, 100, 258, 97, 99]
    assert tokenizer.decode(tokenizer.encode(text)) == text

@pytest.mark.parametrize("special_tokens", [{}, special_tokens])
def test_save_load(special_tokens):
    # take a bit more complex piece of text and train the tokenizer, chosen at random
    text = llama_text
    # create a Tokenizer and do 64 merges
    tokenizer = RegexTokenizer()
    tokenizer.train(text, 256 + 64)
    tokenizer.register_special_tokens(special_tokens)
    # verify that decode(encode(x)) == x
    assert tokenizer.decode(tokenizer.encode(text, "all")) == text
    # verify that save/load work as expected
    ids = tokenizer.encode(text, "all")
    # save the tokenizer (TODO use a proper temporary directory)
    tokenizer.save("test_tokenizer_tmp")
    # re-load the tokenizer
    tokenizer = RegexTokenizer()
    tokenizer.load("test_tokenizer_tmp.model")
    # verify that decode(encode(x)) == x
    assert tokenizer.decode(ids) == text
    assert tokenizer.decode(tokenizer.encode(text, "all")) == text
    assert tokenizer.encode(text, "all") == ids
    # delete the temporary files
    for file in ["test_tokenizer_tmp.model", "test_tokenizer_tmp.vocab"]:
        os.remove(file)

if __name__ == "__main__":
    pytest.main()