# Module 1 Project 1: Tokenization

Recreate the GPT tokenizer by following the minBPE exercises.

## STEP 1: UTF-8 ENCODING 
- Introduce UTF-8 encoding, which turns each character of the text into one or more bytes, and store those byte values as a list of integers
- Display the original text and the 'tokenized' form

In [None]:
text = """The Tokenizer is a necessary and pervasive component of Large Language Models (LLMs), where it translates between strings and tokens (text chunks). Tokenizers are a completely separate stage of the LLM pipeline: they have their own training sets, training algorithms (Byte Pair Encoding), and after training implement two fundamental functions: encode() from strings to tokens, and decode() back from tokens to strings. In this lecture we build from scratch the Tokenizer used in the GPT series from OpenAI. In the process, we will see that a lot of weird behaviors and problems of LLMs actually trace back to tokenization. We'll go through a number of these issues, discuss why tokenization is at fault, and why someone out there ideally finds a way to delete this stage entirely."""

tokens = text.encode("utf-8")  # raw UTF-8 bytes
tokens = list(tokens)          # bytes -> list of ints, each in range 0..255
print(text)
print(f"Text: {len(text)} characters")
print(tokens)
print(f"Tokens: {len(tokens)}")
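The character count and the token (byte) count printed above can differ because UTF-8 is a variable-width encoding: ASCII characters take one byte, while other characters take two to four. A minimal sketch; the sample strings are arbitrary examples chosen for illustration:

```python
# UTF-8 uses 1 byte for ASCII characters and 2-4 bytes for everything else
samples = ["a", "é", "€", "👋"]

for ch in samples:
    encoded = ch.encode("utf-8")
    print(f"{ch!r}: {len(encoded)} byte(s) -> {list(encoded)}")

# A single accented character makes the byte count exceed the character count
word = "héllo"
print(len(word), "characters,", len(word.encode("utf-8")), "bytes")
```

This is why byte-level BPE starts from a fixed base vocabulary of 256 values: any string, in any script, can be represented with them.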

## STEP 2: GET PAIR COUNTS
- Iterate over consecutive pairs of byte values to find which pair occurs most frequently in the text
- Store the result in a dictionary, with each pair as a key and its count as the value

In [None]:
def get_pair_counts(token_ids):
    counts = {}
    for pair in zip(token_ids, token_ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

pairs = get_pair_counts(tokens)
print(pairs)

In [None]:
# Use 'max' to get the most commonly occurring pair
top_pair = max(pairs, key=pairs.get)
print(top_pair)
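The same counts can also be computed with `collections.Counter` from the standard library, which provides `most_common`. This is an alternative sketch rather than the approach used above, and `get_pair_counts_counter` is a hypothetical name. When several pairs tie for the highest count, both `max` and `most_common` return the pair that was counted first, so ties are broken by first position in the text.

```python
from collections import Counter

def get_pair_counts_counter(token_ids):
    # Count each adjacent pair (token_ids[i], token_ids[i + 1])
    return Counter(zip(token_ids, token_ids[1:]))

ids = list("abcabcab".encode("utf-8"))  # [97, 98, 99, 97, 98, 99, 97, 98]
counts = get_pair_counts_counter(ids)
top_pair, top_count = counts.most_common(1)[0]
print(top_pair, top_count)  # (97, 98) 3
```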

## STEP 3: ADD NEW TOKEN AS MERGE
- UTF-8 encoding gives us 256 possible byte values, so token IDs 0 to 255 are already taken.
- To merge common pairs (like the one found above), we need new IDs, starting at 256.
- The function below takes a list of token IDs, a pair, and a new ID, and replaces every occurrence of the pair with the new ID.

In [None]:
# Replace every occurrence of `pair` in token_ids with the new token ID `index`
def new_token(token_ids, pair, index):
    new_ids = []
    i = 0
    while i < len(token_ids):
        if i < len(token_ids) - 1 and token_ids[i] == pair[0] and token_ids[i+1] == pair[1]:
            new_ids.append(index)
            i += 2
        else:
            new_ids.append(token_ids[i])
            i += 1
    return new_ids

# Example: merge (32, 116), i.e. a space followed by 't', into new token 256
print(new_token(tokens, (32, 116), 256))
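Because `new_token` scans left to right and jumps past each match, a pair that overlaps itself is merged greedily: in a run of three identical tokens, only the first two are merged. A small check, with the function copied from above so the snippet runs on its own (the input lists are arbitrary):

```python
def new_token(token_ids, pair, index):
    # Copy of the merge function above: replace each non-overlapping
    # occurrence of `pair`, scanning left to right
    new_ids = []
    i = 0
    while i < len(token_ids):
        if i < len(token_ids) - 1 and token_ids[i] == pair[0] and token_ids[i + 1] == pair[1]:
            new_ids.append(index)
            i += 2  # skip both tokens of the matched pair
        else:
            new_ids.append(token_ids[i])
            i += 1
    return new_ids

print(new_token([5, 6, 6, 6, 7], (6, 6), 256))  # [5, 256, 6, 7]
print(new_token([1, 2, 3, 1, 2], (1, 2), 257))  # [257, 3, 257]
```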

## STEP 4: PERFORM MERGE AND BUILD VOCAB
- To go 'backwards' from tokens to text, we need a mapping from each token ID to the sequence of bytes it stands for
- This mapping is the 'vocab'; for a merged token, its entry is the concatenation of the bytes of the two tokens it was merged from
- The code below performs the merge step a fixed number of times (each merge adds one new token), then builds the vocab from the recorded merges

In [None]:
vocab_size = 276 # Target vocab size, chosen arbitrarily: 256 byte tokens + 20 merges
num_merges = vocab_size - 256
token_ids = list(tokens)

# Loop to perform the merge. Steps are: get pair counts, find the max occurrence, merge it into a new ID, repeat N times
merges = {}
for i in range(num_merges):
    stats = get_pair_counts(token_ids)
    pair = max(stats, key=stats.get)
    index = 256 + i
    print(f"Merging {pair} into new token {index}")
    token_ids = new_token(token_ids, pair, index)
    merges[pair] = index

# Compression ratio compared to original token length
print(f"Compression ratio {len(tokens) / len(token_ids):.2f}X")

# Building our vocab for decoding
vocab = {index: bytes([index]) for index in range(256)}
for (p0, p1), index in merges.items():
    vocab[index] = vocab[p0] + vocab[p1]

print(vocab)
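To see the merge loop and the vocab construction end to end, here is the same procedure on a tiny string, "aaabdaaabac" (a common textbook BPE example, used here only for illustration), with 3 merges. The helpers are repeated so the snippet runs on its own. Merge 257 is built from the earlier merged token 256, which is why the vocab has to be filled in merge order; the second merge also shows a tie (two pairs with count 2) resolved in favor of the pair seen first.

```python
def get_pair_counts(token_ids):
    counts = {}
    for pair in zip(token_ids, token_ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def new_token(token_ids, pair, index):
    new_ids, i = [], 0
    while i < len(token_ids):
        if i < len(token_ids) - 1 and token_ids[i] == pair[0] and token_ids[i + 1] == pair[1]:
            new_ids.append(index)
            i += 2
        else:
            new_ids.append(token_ids[i])
            i += 1
    return new_ids

ids = list("aaabdaaabac".encode("utf-8"))
merges = {}
for i in range(3):
    stats = get_pair_counts(ids)
    pair = max(stats, key=stats.get)  # most frequent pair; ties go to the first seen
    ids = new_token(ids, pair, 256 + i)
    merges[pair] = 256 + i

# Build the vocab in merge order (dicts preserve insertion order)
vocab = {i: bytes([i]) for i in range(256)}
for (p0, p1), index in merges.items():
    vocab[index] = vocab[p0] + vocab[p1]

print(merges)                   # {(97, 97): 256, (256, 97): 257, (257, 98): 258}
print(ids)                      # [258, 100, 258, 97, 99]
print([vocab[i] for i in ids])  # [b'aaab', b'd', b'aaab', b'a', b'c']
```

Eleven bytes become five tokens, and joining the vocab entries recovers the original bytes exactly.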

## STEP 5: ENCODING AND DECODING
- Decoding step is relatively straightforward, just concatenate bytes into a string, and decode from UTF-8
- Be sure to use `errors='replace'` with the call to 'decode' to ensure non-UTF-8 characters get handled appropriately
- The use of our previously built 'vocab' here helps us go from merged token pairs to their character representations.
- Encoding process uses our 'get_pair_counts' and 'new_token' methods to get the 'next' index that was merged into our encoding map, the same order that we used to build our vocab initially.
- Doing this allows us to go from individual token representations to merged pairs as tokens that are consistent with the way we built the initial vocab and merges dictionary
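Why `errors='replace'` matters: a byte-level vocab can emit token sequences whose bytes are not valid UTF-8, for example a lone continuation byte, or only the first two bytes of a three-byte character. A minimal sketch using only the built-in `bytes.decode`:

```python
euro = "€".encode("utf-8")  # 3 bytes: [226, 130, 172]
truncated = euro[:2]         # first two bytes only: an incomplete character

try:
    truncated.decode("utf-8")  # strict mode (the default) raises
except UnicodeDecodeError as err:
    print("strict:", err.reason)

# 'replace' substitutes U+FFFD for the invalid bytes instead of raising
print("replace:", repr(truncated.decode("utf-8", errors="replace")))
print("lone byte:", repr(bytes([128]).decode("utf-8", errors="replace")))
```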

In [None]:
# Given a list of token IDs, return the text representation
def decode(token_ids):
    tokens = b"".join(vocab[index] for index in token_ids)
    text = tokens.decode("utf-8", errors='replace')
    return text

# Given a string of text, encode into tokens using our BPE algorithm
def encode(text):
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        stats = get_pair_counts(tokens)
        # Pick the pair learned earliest (lowest merge index); pairs never merged rank as infinity
        pair = min(stats, key=lambda x: merges.get(x, float("inf")))
        if pair not in merges:
            break
        index = merges[pair]
        tokens = new_token(tokens, pair, index)
    return tokens

print(decode(encode("Hello! this is a test string!")))
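Why `encode` picks the pair with the lowest merge index rather than the most frequent pair: later merges can be built on top of earlier ones, so they must be applied in the order they were learned. A small sketch with a hand-written, hypothetical `merges` table in which token 257 depends on token 256:

```python
def get_pair_counts(token_ids):
    counts = {}
    for pair in zip(token_ids, token_ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def new_token(token_ids, pair, index):
    new_ids, i = [], 0
    while i < len(token_ids):
        if i < len(token_ids) - 1 and (token_ids[i], token_ids[i + 1]) == pair:
            new_ids.append(index)
            i += 2
        else:
            new_ids.append(token_ids[i])
            i += 1
    return new_ids

# Hypothetical merges: 256 = "aa", 257 = token 256 followed by "b"
merges = {(97, 97): 256, (256, 98): 257}

def encode(text):
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        stats = get_pair_counts(tokens)
        # Choose the pair that was learned earliest (lowest merge index)
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # no remaining pair has a learned merge
        tokens = new_token(tokens, pair, merges[pair])
    return tokens

print(encode("aab"))   # [257]: "aa" -> 256 first, then (256, "b") -> 257
print(encode("baab"))  # [98, 257]
```

Token 257 can only appear after 256 has been created, so applying the merges in their learned order is what makes this work.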

## STEP 6: PUT IT ALL TOGETHER
- Combine everything above into a single class, then test training, encoding, and decoding
- At this point we have recreated a simple byte-level BPE tokenizer following minBPE and covered the fundamentals of tokenization.

In [None]:
class Tokenizer:
    """Base class for Tokenizer"""

    def __init__(self):
        self.merges = {} # (int, int) -> int
        self.vocab = {} # int -> bytes

    # Get pair counts
    def get_pair_counts(self, token_ids):
        counts = {}
        for pair in zip(token_ids, token_ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    # Merge pair into new token with index
    def new_token(self, token_ids, pair, index):
        new_ids = []
        i = 0
        while i < len(token_ids):
            if i < len(token_ids) - 1 and token_ids[i] == pair[0] and token_ids[i+1] == pair[1]:
                new_ids.append(index)
                i += 2
            else:
                new_ids.append(token_ids[i])
                i += 1
        return new_ids

    # Train a vocab of a given size from the given text
    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        text_bytes = text.encode("utf-8")
        token_ids = list(text_bytes)

        merges = {} # (int, int) -> int
        vocab = {index: bytes([index]) for index in range(256)} # int -> bytes

        for i in range(num_merges):
            pair_counts = self.get_pair_counts(token_ids)
            if not pair_counts:
                break # fewer than 2 tokens left, so there is nothing more to merge
            pair = max(pair_counts, key=pair_counts.get)
            index = 256 + i
            token_ids = self.new_token(token_ids, pair, index)
            merges[pair] = index
            vocab[index] = vocab[pair[0]] + vocab[pair[1]]
            if verbose:
                print(f"merge {i + 1}/{num_merges}: {pair} -> {index} ({vocab[index]}) had {pair_counts[pair]} occurrences")
        
        self.merges = merges
        self.vocab = vocab

    # Given a string of text, encode into BPE representation (list of integers)
    def encode(self, text):
        tokens = list(text.encode("utf-8"))
        while len(tokens) >= 2:
            stats = self.get_pair_counts(tokens)
            # Pick the pair learned earliest (lowest merge index); pairs never merged rank as infinity
            pair = min(stats, key=lambda x: self.merges.get(x, float("inf")))
            if pair not in self.merges:
                break
            index = self.merges[pair]
            tokens = self.new_token(tokens, pair, index)
        return tokens

    # Given BPE representation (list of integers), return the string representation
    def decode(self, token_ids):
        tokens = b"".join(self.vocab[index] for index in token_ids)
        text = tokens.decode("utf-8", errors='replace')
        return text

token = Tokenizer()
text = "Norman Gene Macdonald[i] (October 17, 1959[ii] – September 14, 2021) was a Canadian stand-up comedian, actor, and writer whose style was characterized by deadpan delivery and the use of folksy, old-fashioned turns of phrase.[1][2][3] He appeared in many films and was a regular guest on late-night talk shows, where he became known for his chaotic, yet understated style of comedy.[4] Many critics and fellow comedians considered him to be the ultimate talk show guest, while prominent late-night figure David Letterman regarded him as 'the best' of stand-up comedians.[5][6] Earlier in his career, Macdonald's first work on television included writing for such comedies as Roseanne and The Dennis Miller Show. In 1993, Macdonald was hired as a writer and cast member on Saturday Night Live (SNL), spending a total of five seasons on the series, which included anchoring the show's Weekend Update segment for three and a half seasons.[7] He was removed as host of SNL's Weekend Update in 1998, allegedly for relentlessly mocking OJ Simpson during his murder trial, offending producer Don Ohlmeyer who was a close friend of Simpson.[8][9] After being fired from SNL, he wrote and starred in the 1998 film Dirty Work and headlined his own sitcom The Norm Show from 1999 to 2001. Macdonald was also a voice actor, and provided voice acting roles for Family Guy, The Fairly OddParents, Mike Tyson Mysteries, The Orville, and the Dr. Dolittle films."

token.train(text, 286) # 286 = 256 byte tokens + 30 merges (an arbitrary example size)
print(token.merges)
print(token.decode(token.encode("Hello! this is a text string!")))
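A useful sanity check for the finished tokenizer is that `decode(encode(s)) == s` for any string, including non-ASCII text, because encoding only ever merges bytes and decoding concatenates them back. The sketch below restates the training, encoding, and decoding logic as standalone functions (hypothetical names `train_bpe`, `bpe_encode`, `bpe_decode`) so that it runs on its own; the training text and test strings are arbitrary.

```python
def get_pair_counts(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, index):
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(index)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def train_bpe(text, vocab_size):
    ids = list(text.encode("utf-8"))
    merges, vocab = {}, {i: bytes([i]) for i in range(256)}
    for index in range(256, vocab_size):
        stats = get_pair_counts(ids)
        if not stats:
            break  # nothing left to merge
        pair = max(stats, key=stats.get)
        ids = merge(ids, pair, index)
        merges[pair] = index
        vocab[index] = vocab[pair[0]] + vocab[pair[1]]
    return merges, vocab

def bpe_encode(text, merges):
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        stats = get_pair_counts(ids)
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        ids = merge(ids, pair, merges[pair])
    return ids

def bpe_decode(ids, vocab):
    return b"".join(vocab[i] for i in ids).decode("utf-8", errors="replace")

merges, vocab = train_bpe("the cat sat on the mat; the café is near the hat shop", 280)
for s in ["the cat", "Hello! this is a test string!", "café ☕ naïve", ""]:
    ids = bpe_encode(s, merges)
    assert bpe_decode(ids, vocab) == s, s  # round trip must be lossless
    print(f"{s!r}: {len(s.encode('utf-8'))} bytes -> {len(ids)} tokens")
```

Strings containing characters never seen in training still round-trip; they simply fall back to more, shorter tokens.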