<a href="https://colab.research.google.com/github/rajdeep-biswas/Building-GPT/blob/master/Building_GPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

I started with [Let's build GPT: from scratch, in code, spelled out](https://www.youtube.com/watch?v=kCc8FmEb1nY), but apparently he has just used character level tokenization (ASCII) and the language model also predicts that way.  

Since I want to replicate a more token-level model, I will be following [Let's build the GPT Tokenizer](https://www.youtube.com/watch?v=zduSFxRajkE) first.

# Tokenization

Tokens vocabularies are constructed not at character level but at "chunk level". And these are constructed using [Byte pair encoding](https://en.wikipedia.org/wiki/Byte_pair_encoding) algorithm.

[Language Models are Unsupervised Multitask Learners](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) is the "GPT2 paper" that introduced this algorithm.

Interesting notes about [GPT2 tokens](https://tiktokenizer.vercel.app/?model=gpt2) -
* "the" and " the" (whitespaced) etc. are counted as different tokens.
* "Egg" at the beginning of a sentence is two tokens "E" + "gg", but in a different sentence it is a single token " Egg".
* when it generated code in python, it uses single whitespaces for indentation which is pretty _wasteful_ in terms of how many tokens are bloated and the context window is lost / forgotten. OpenAI deliberately included many more grouped whitespace tokens in GPT4 etc tokenizers to mitigate this problem.

In [None]:
# using ord() has the exact same encoding result as using utf-8 as long as you stick to ASCII characters
test_string = "hello"
list(test_string.encode("utf-8")) == [ord(x) for x in test_string]

True

In [None]:
# not sure what Karpathy is saying makes sense when you're expanding to other language characters and emoji
# is it better because it's reducing the vocab size (but also is increasing the sequence length as a consequence)?
test_string = "안녕하세요 👋 (hello in Korean!)"
list(test_string.encode("utf-8")) == [ord(x) for x in test_string]

False

In [None]:
len(list(test_string.encode("utf-8"))), len([ord(x) for x in test_string])

(39, 26)

In [93]:
"""
byte pair encoding sounds like a recursive algorithm where you replace most frequently occuring byte pairs by minting a new token
and we repeat this until no byte pairs repeat any longer?
a byte is just a pair of two adjacent characters in a string
so, for example, we go from a sequence length of 11 tokens (with vocab length of 4), we go to a sequence length of 5 (but with vocab length of 7)
"""

sample_string = "Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to “support Unicode” in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don’t blame programmers for still finding the whole thing mysterious, even 30 years after Unicode’s inception."
tokens = sample_string.encode('utf-8')
tokens = list(map(int, tokens))
print(tokens)

[239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131, 239, 189, 143, 239, 189, 132, 239, 189, 133, 33, 32, 240, 159, 133, 164, 240, 159, 133, 157, 240, 159, 133, 152, 240, 159, 133, 146, 240, 159, 133, 158, 240, 159, 133, 147, 240, 159, 133, 148, 226, 128, 189, 32, 240, 159, 135, 186, 226, 128, 140, 240, 159, 135, 179, 226, 128, 140, 240, 159, 135, 174, 226, 128, 140, 240, 159, 135, 168, 226, 128, 140, 240, 159, 135, 180, 226, 128, 140, 240, 159, 135, 169, 226, 128, 140, 240, 159, 135, 170, 33, 32, 240, 159, 152, 132, 32, 84, 104, 101, 32, 118, 101, 114, 121, 32, 110, 97, 109, 101, 32, 115, 116, 114, 105, 107, 101, 115, 32, 102, 101, 97, 114, 32, 97, 110, 100, 32, 97, 119, 101, 32, 105, 110, 116, 111, 32, 116, 104, 101, 32, 104, 101, 97, 114, 116, 115, 32, 111, 102, 32, 112, 114, 111, 103, 114, 97, 109, 109, 101, 114, 115, 32, 119, 111, 114, 108, 100, 119, 105, 100, 101, 46, 32, 87, 101, 32, 97, 108, 108, 32, 107, 110, 111, 119, 32, 119, 101, 32, 111, 117, 103, 104, 116, 32, 116

In [103]:
def byte_pair_encode(tokens):

  pair_mapping_dict = dict()
  set_of_used_tokens = set(tokens)

  while True:

    highest_count = 0

    pair_frequency_dict = dict()

    for i in range(len(tokens) - 1):
      byte_pair_tuple = (tokens[i], tokens[i + 1])
      pair_frequency_dict[byte_pair_tuple] = 1 + pair_frequency_dict.get(byte_pair_tuple, 0)

      highest_count = max(highest_count, pair_frequency_dict[byte_pair_tuple])

    if highest_count <= 1:
      break

    new_tokens = []

    i = 0
    while i < len(tokens) - 1:
      byte_pair_tuple = (tokens[i], tokens[i + 1])

      if pair_frequency_dict[byte_pair_tuple] > 1:

        find_new_token = 1
        while find_new_token in set_of_used_tokens:
          find_new_token += 1

        set_of_used_tokens.add(find_new_token)

        pair_mapping_dict[find_new_token] = byte_pair_tuple

        new_tokens.append(find_new_token)
        i += 2

      else:
        new_tokens.append(byte_pair_tuple[0])
        i += 1

    if i == len(tokens) - 1:
      new_tokens.append(tokens[-1])

    tokens = new_tokens

  return tokens, pair_mapping_dict

In [104]:
tokens = sample_string.encode('utf-8')
tokens = list(map(int, tokens))

encoded_tokens, pair_mapping_dict = byte_pair_encode(tokens)

print(tokens)
print(encoded_tokens)
print(pair_mapping_dict)

[239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131, 239, 189, 143, 239, 189, 132, 239, 189, 133, 33, 32, 240, 159, 133, 164, 240, 159, 133, 157, 240, 159, 133, 152, 240, 159, 133, 146, 240, 159, 133, 158, 240, 159, 133, 147, 240, 159, 133, 148, 226, 128, 189, 32, 240, 159, 135, 186, 226, 128, 140, 240, 159, 135, 179, 226, 128, 140, 240, 159, 135, 174, 226, 128, 140, 240, 159, 135, 168, 226, 128, 140, 240, 159, 135, 180, 226, 128, 140, 240, 159, 135, 169, 226, 128, 140, 240, 159, 135, 170, 33, 32, 240, 159, 152, 132, 32, 84, 104, 101, 32, 118, 101, 114, 121, 32, 110, 97, 109, 101, 32, 115, 116, 114, 105, 107, 101, 115, 32, 102, 101, 97, 114, 32, 97, 110, 100, 32, 97, 119, 101, 32, 105, 110, 116, 111, 32, 116, 104, 101, 32, 104, 101, 97, 114, 116, 115, 32, 111, 102, 32, 112, 114, 111, 103, 114, 97, 109, 109, 101, 114, 115, 32, 119, 111, 114, 108, 100, 119, 105, 100, 101, 46, 32, 87, 101, 32, 97, 108, 108, 32, 107, 110, 111, 119, 32, 119, 101, 32, 111, 117, 103, 104, 116, 32, 116

In [105]:
def byte_pair_decode(tokens, pair_mapping_dict):

  while True:

    match_found = False

    for encoded_token in sorted(pair_mapping_dict.keys())[::-1]:

      original_tokens = []

      for token in tokens:
        if token == encoded_token:
          original_tokens.extend(pair_mapping_dict[token])
          match_found = True
        else:
          original_tokens.append(token)

      tokens = original_tokens

      if not match_found:
        break

    return tokens

In [106]:
original_tokens = byte_pair_decode(encoded_tokens, pair_mapping_dict)

print(pair_mapping_dict)
print(original_tokens)

{1: (239, 189), 2: (239, 189), 3: (239, 189), 4: (239, 189), 5: (239, 189), 6: (239, 189), 7: (33, 32), 8: (240, 159), 9: (240, 159), 10: (240, 159), 11: (240, 159), 12: (240, 159), 13: (240, 159), 14: (240, 159), 15: (226, 128), 16: (32, 240), 17: (159, 135), 18: (226, 128), 19: (140, 240), 20: (159, 135), 21: (226, 128), 22: (140, 240), 23: (159, 135), 24: (226, 128), 25: (140, 240), 26: (159, 135), 27: (226, 128), 28: (140, 240), 29: (159, 135), 30: (226, 128), 31: (140, 240), 34: (159, 135), 35: (226, 128), 36: (140, 240), 37: (159, 135), 38: (33, 32), 39: (240, 159), 42: (104, 101), 43: (118, 101), 47: (114, 121), 49: (32, 110), 50: (97, 109), 52: (101, 32), 53: (115, 116), 54: (114, 105), 55: (107, 101), 56: (115, 32), 57: (101, 97), 58: (114, 32), 59: (97, 110), 60: (100, 32), 61: (119, 101), 62: (32, 105), 64: (110, 116), 65: (111, 32), 67: (116, 104), 68: (101, 32), 69: (104, 101), 70: (97, 114), 71: (116, 115), 72: (32, 111), 74: (102, 32), 75: (112, 114), 76: (111, 103), 77:

In [107]:
print(encoded_tokens)
print(tokens)
print(original_tokens)
print(tokens == original_tokens)

[239, 188, 181, 1, 142, 2, 137, 3, 131, 4, 143, 5, 132, 6, 133, 7, 8, 133, 164, 9, 133, 157, 10, 133, 152, 11, 133, 146, 12, 133, 158, 13, 133, 147, 14, 133, 148, 15, 189, 16, 17, 186, 18, 19, 20, 179, 21, 22, 23, 174, 24, 25, 26, 168, 27, 28, 29, 180, 30, 31, 34, 169, 35, 36, 37, 170, 38, 39, 152, 132, 32, 84, 42, 32, 43, 47, 49, 50, 52, 53, 54, 55, 56, 102, 57, 58, 59, 60, 97, 61, 62, 64, 65, 67, 68, 69, 70, 71, 72, 74, 75, 76, 77, 78, 79, 80, 119, 81, 108, 100, 119, 82, 101, 86, 87, 88, 89, 90, 107, 91, 119, 92, 93, 94, 96, 106, 113, 32, 123, 156, 124, 125, 126, 127, 129, 130, 134, 136, 128, 157, 138, 139, 141, 144, 115, 145, 116, 119, 149, 150, 40, 151, 154, 155, 160, 161, 162, 163, 165, 166, 115, 167, 148, 171, 172, 32, 173, 175, 176, 119, 99, 177, 114, 95, 178, 182, 183, 184, 185, 187, 190, 191, 192, 193, 194, 195, 105, 196, 116, 63, 41, 197, 66, 117, 198, 199, 200, 201, 202, 203, 204, 205, 206, 98, 207, 114, 208, 101, 209, 210, 211, 212, 118, 213, 214, 215, 216, 217, 218, 219, 2