In [None]:
from typing import List, Dict, Tuple

def merge(ids: List[int], pair: Tuple[int, int], idx: int) -> List[int]:
    """Return a copy of ``ids`` with each occurrence of ``pair`` replaced by ``idx``.

    The scan runs left to right and matches never overlap: once two adjacent
    elements have been replaced, the search resumes after them.

    Example: ids=[1, 2, 3, 1, 2, 4], pair=(1, 2), idx=7 -> [7, 3, 7, 4]
    """
    total = len(ids)
    out = []
    skip_next = False  # True when the current element was consumed by a match
    for pos, tok in enumerate(ids):
        if skip_next:
            skip_next = False
            continue
        has_right_neighbour = pos + 1 < total
        if has_right_neighbour and (tok, ids[pos + 1]) == pair:
            out.append(idx)
            skip_next = True
        else:
            out.append(tok)
    return out

def get_stats(ids: List[int]) -> Dict[Tuple[int, int], int]:
    """Count how many times each pair of adjacent elements occurs in ``ids``.

    Returns a dict mapping ``(left, right)`` to its number of occurrences;
    keys appear in order of first occurrence. Sequences with fewer than two
    elements produce an empty dict.
    """
    counts = {}
    # Pair every element with its right-hand neighbour.
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        counts[key] = counts.get(key, 0) + 1
    return counts

class Tokenizer:
    """Base class holding the state shared by tokenizers in this file.

    ``train``, ``encode`` and ``decode`` are placeholders that raise
    ``NotImplementedError``; subclasses override them.
    """

    def __init__(self):
        # (int, int) -> int: the merge "rules" learned during training and
        # used during inference.
        self.merges = {}
        self.pattern = ""  # str; not used by the code in this class
        # int -> bytes, derived from ``merges`` on top of the 256 single-byte
        # tokens.
        self.vocab = self._build_vocab()

    def train(self):
        """Placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def encode(self):
        """Placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def decode(self):
        """Placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def _build_vocab(self):
        """Build the ``id -> bytes`` table from ``self.merges``.

        Ids 0-255 map to their single byte. Each merge ``(i1, i2) -> idx``
        then maps ``idx`` to the concatenation of the bytes of ``i1`` and
        ``i2``. Merges are visited in dict insertion order, so both halves of
        a pair must already be in the table when the pair is reached;
        otherwise a ``KeyError`` is raised.
        """
        vocab = {idx: bytes([idx]) for idx in range(256)}
        # Iterate over items(): iterating the dict itself yields only the
        # (i1, i2) keys, which cannot be unpacked into ((i1, i2), idx).
        for (i1, i2), idx in self.merges.items():
            vocab[idx] = vocab[i1] + vocab[i2]

        return vocab
        

class BPE(Tokenizer):
    """Byte-level byte-pair-encoding tokenizer.

    Text is first converted to its UTF-8 bytes (token ids 0-255). Training
    repeatedly replaces the most frequent adjacent pair of ids with a new id
    and records that rule in ``self.merges``; encoding replays those rules.
    """

    def __init__(self):
        super().__init__()

    def train(self, text: str, vocab_size: int) -> None:
        """Learn merge rules from ``text`` and store them on the instance.

        Args:
            text: Training text; it is encoded to UTF-8 bytes before merging.
            vocab_size: Target vocabulary size. Must be at least 256 (the
                number of single-byte tokens); ``vocab_size - 256`` merges
                are attempted.

        Training stops early when the sequence has no adjacent pairs left.
        On return ``self.merges`` and ``self.vocab`` are replaced by the
        newly learned tables.
        """
        BASE_VOCAB_SIZE = 256
        assert vocab_size >= BASE_VOCAB_SIZE, "vocab_size must be at least 256"
        num_merges = vocab_size - BASE_VOCAB_SIZE

        # Start from fresh tables: reusing (and mutating) self.vocab would
        # leave entries from an earlier training run in the vocabulary even
        # though the merges they came from are discarded.
        merges = {}  # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(BASE_VOCAB_SIZE)}  # int -> bytes

        ids = list(text.encode("utf-8"))  # one int per UTF-8 byte

        for i in range(num_merges):
            freqs = get_stats(ids)
            if not freqs:  # fewer than two tokens left: nothing to merge
                break
            # Most frequent pair; ties go to the pair seen first.
            best_pair = max(freqs, key=freqs.get)
            new_idx = BASE_VOCAB_SIZE + i
            ids = merge(ids, best_pair, new_idx)
            merges[best_pair] = new_idx
            vocab[new_idx] = vocab[best_pair[0]] + vocab[best_pair[1]]

        self.merges, self.vocab = merges, vocab

    def encode(self, text: str) -> List[int]:
        """Encode ``text`` into token ids using the learned merges.

        Merges are applied in the order they were learned (lowest merge id
        first), and each one is applied to every occurrence in the sequence.
        Applying whichever mergeable pair happens to come first by position
        instead can yield a tokenization that differs from the one produced
        during training.
        """
        ids = list(text.encode("utf-8"))

        while len(ids) >= 2:
            present = get_stats(ids)
            # Among the pairs present, pick the one learned earliest; pairs
            # without a merge rule rank last.
            pair = min(present, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # no remaining pair has a merge rule
            ids = merge(ids, pair, self.merges[pair])

        return ids

    def decode(self, ids: List[int]) -> str:
        """Decode token ids back into text.

        Raises:
            KeyError: If an id is not in ``self.vocab``.
            UnicodeDecodeError: If the concatenated bytes are not valid UTF-8.
        """
        byte_strings = [self.vocab[i] for i in ids]
        return b''.join(byte_strings).decode('utf-8')
