In [1]:
# Turn a sample sentence into its mutable UTF-8 byte sequence and show it.
text = "Hello, I am Yashas"
byte_arr = bytearray(text.encode("utf-8"))
print(byte_arr)

bytearray(b'Hello, I am Yashas')


In [2]:
# Iterating a bytearray yields ints, so this displays the numeric value of
# every UTF-8 byte of the sentence (one entry per byte).
list(byte_arr)

[72,
 101,
 108,
 108,
 111,
 44,
 32,
 73,
 32,
 97,
 109,
 32,
 89,
 97,
 115,
 104,
 97,
 115]

In [4]:
import tiktoken

# Look up the encoding that tiktoken registers under the name "gpt2".
gpt2_tokenizer = tiktoken.get_encoding("gpt2")
# Encode the same sentence used for the raw-byte view above; the result is a
# list of integer token ids (7 ids here versus 18 individual bytes).
gpt2_tokenizer.encode("Hello, I am Yashas")

[15496, 11, 314, 716, 575, 1077, 292]

In [None]:
from collections import Counter
from functools import lru_cache

class BPETokenizerSimple:
    def __init__(self):
        self.vocab = {}
        self.inverse_vocab = {}
        self.bpe_merges = {}
    
    def train(self, text, vocab_size, allowed_specials="<|unk|>"):
        preprocessed_text = []
        for i, char in enumerate(text):
            if char == " " and i != 0:
                preprocessed_text.append("Ġ")
            if char != " ":
                preprocessed_text.append(char)
        preprocessed_text = "".join(preprocessed_text)
        
        unique_chars = [chr(i) for i in range(256)]
        unique_chars.extend(char for char in sorted(set(preprocessed_text)) if char not in unique_chars)

        if 'Ġ' not in unique_chars:
            unique_chars.append('Ġ')
        
        self.vocab = {i: char for i, char in enumerate(unique_chars)}     
        self.inverse_vocab = {char: i for i, char in self.vocab.items()}

        if allowed_specials is not None:
            for special in allowed_specials:
                self.vocab[len(vocab)] = special
                self.inverse_vocab[special] = len(vocab)

        token_ids = [self.inverse_vocab[char] for char in preprocessed_text]

        for new_id in range(len(self.vocab), vocab_size):
            pair_id = self.find_freq_pair(token_ids, mode="most")
            if pair_id is None:
                break
            token_ids = self.replace_pair(token_ids, pair_id, new_id)
            self.bpe_merges[pair_id] = new_id

        for (p0, p1), new_id in self.bpe_merges.items():
            merged_token = self.vocab[p0] + self.vocab[p1]
            self.vocab[new_id] = merged_token
            self.inverse_vocab[merged_token] = new_id

    def encode(self, text):
        tokens = []
        words = text.replace("\n", " \n ").split()

        for i, word in enumerate(words):
            if i > 0 and not word.startswith("\n"):
                tokens.append("Ġ" + word)
            else:
                tokens.append(word)
        
        token_ids = []
        for token in tokens:
            if token in self.inverse_vocab:
                token_id = self.inverse_vocab[token]
                token_ids.append(token_id)
            else:
                sub_token_ids = self.tokenize_with_bpe(token)
                token_ids.extend(sub_token_ids)
        
        return token_ids

    def tokenize_with_bpe(self, token):
        token_ids = [self.inverse_vocab.get(char, None) for char in token]

        if None in token_ids:
            missing_chars = [char for char, tid in zip(token, token_ids) if tid is None]
            raise ValueError(f"Token {token} contains characters not in vocabulary: {missing_chars}")
        
        can_merge = True
        while can_merge and len(token_ids) > 1:
            can_merge = False
            new_tokens = []
            i = 0
            while i < len(token_ids) - 1:
                pair = (token_ids[i], token_ids[i + 1])
                if pair in self.bpe_merges:
                    merged_token_id = self.bpe_merges[pair]
                    new_tokens.append(merged_token_id)
                    # Uncomment for educational purposes:
                    # print(f"Merged pair {pair} -> {merged_token_id} ('{self.vocab[merged_token_id]}')")
                    i += 2
                    can_merge = True
                else:
                    new_tokens.append(token_ids[i])
                    i += 1
            if i < len(token_ids):
                new_tokens.append(token_ids[i])
            token_ids = new_tokens

        return token_ids

    def decode(self, token_ids):
        decoded_string = ""
        for token_id in token_ids:
            if token_id not in self.vocab:
                raise ValueError(f"Token ID {token_id} not found in vocabulary")
            token = self.vocab[token_id]
            if token.startswith("Ġ"):
                decoded_string += " "
            decoded_string += token[1:]
        return decoded_string

    @lru_cache(maxsize=None)
    def get_special_token_id(self, token):
        return self.inverse_vocab.get(token, None)
    
    @staticmethod
    def find_freq_pair(token_ids, mode="most"):
        pairs = Counter(zip(token_ids, token_ids[1:]))

        if mode == "most":
            return max(pairs.items(), key=lambda x: x[1])[0]
        elif mode == "least":
            return min(pairs.items(), key=lambda x: x[1])[0]
        else:
            raise ValueError("mode must be 'most' or 'least'")
    
    @staticmethod
    def replace_pair(token_ids, pair_id, new_id):
        dq = deque(token_ids)
        replaced = []

        while dq:
            current = dq.popleft()
            if dq and (current, dq[0]) == pair_id:
                replaced.append(new_id)
                dq.popleft()
            else:
                replaced.append(current)
        return replaced