In [4]:
import regex
from abc import ABC
from dataclasses import dataclass
from collections import defaultdict
import random

class Tokenizer(ABC):
    """Common interface for the tokenizers defined in this notebook.

    A tokenizer converts between a string and a list of integer token
    indices. The base methods only raise; concrete tokenizers override both.
    Note that neither method is marked abstract, so this class itself can
    still be instantiated.
    """

    def encode(self, string: str) -> list[int]:
        """Turn `string` into a list of token indices."""
        raise NotImplementedError()

    def decode(self, indices: list[int]) -> str:
        """Turn a list of token indices back into a string."""
        raise NotImplementedError()
    
def get_compression_ratio(string: str, indices: list[int]) -> float:
    """Return how many UTF-8 bytes of `string` each token covers on average.

    `indices` is the tokenization of `string`; the result is
    (number of UTF-8 bytes in `string`) / (number of tokens in `indices`).
    Raises ZeroDivisionError when `indices` is empty.
    """
    return len(bytes(string, encoding="utf-8")) / len(indices)

# Character Tokenizer

In [5]:
# ord() maps a single character to its Unicode code point (an integer) ...
assert ord("a") == 97
assert ord("🌍") == 127757

# ... and chr() is the inverse: it maps a code point back to its character.
assert chr(97) == "a"
assert chr(127757) == "🌍"

In [6]:
class CharacterTokenizer(Tokenizer):
    """Tokenizer whose token indices are the Unicode code points of the characters."""

    def encode(self, string: str) -> list[int]:
        """Return the code point (`ord`) of every character in `string`, in order."""
        return [ord(character) for character in string]

    def decode(self, indices: list[int]) -> str:
        """Rebuild the string by converting each code point back with `chr`."""
        return "".join(chr(code_point) for code_point in indices)

# Round-trip a string mixing ASCII, an emoji and CJK characters through the
# character tokenizer.
tokenizer = CharacterTokenizer()
string = "Hello, 🌍! 你好!"
print(f"string: {string}")  # @inspect string
indices = tokenizer.encode(string)
print(f"indices: {indices}")  # @inspect indices
reconstructed_string = tokenizer.decode(indices)
assert string == reconstructed_string  # decoding the encoding must give back the input

# The largest code point seen in this one string only bounds the vocabulary
# size from below.
vocabulary_size = max(indices) + 1  # This is a lower bound
# UTF-8 bytes per token; above 1 here because the non-ASCII characters take
# several bytes each but are a single token.
compression_ratio = get_compression_ratio(string, indices)
print(f"compression_ratio: {compression_ratio}")  # @inspect compression_ratio

string: Hello, 🌍! 你好!
indices: [72, 101, 108, 108, 111, 44, 32, 127757, 33, 32, 20320, 22909, 33]
compression_ratio: 1.5384615384615385


# Byte Tokenizer

In [7]:
# UTF-8 encodes the ASCII character "a" as a single byte, while the emoji
# takes four bytes.
assert bytes("a", encoding="utf-8") == b"a"
assert bytes("🌍", encoding="utf-8") == b"\xf0\x9f\x8c\x8d"

In [18]:
# Iterating over a bytes object yields one integer per byte, so b'H' becomes [72].
list(map(int, b'H'))

[72]

In [None]:
class ByteTokenizer(Tokenizer):
    """Tokenizer whose token indices are the bytes of the UTF-8 encoding."""

    def encode(self, string: str) -> list[int]:
        """Return the UTF-8 bytes of `string` as a list of integers.

        Also prints the encoded bytes object as a side effect.
        """
        string_bytes = string.encode("utf-8")
        print(f"string_bytes: {string_bytes}")  # @inspect string_bytes
        # Iterating a bytes object already yields one int per byte.
        return list(string_bytes)

    def decode(self, indices: list[int]) -> str:
        """Pack `indices` into a bytes object and decode it as UTF-8."""
        return bytes(indices).decode("utf-8")
    
# Round-trip the same mixed ASCII / emoji / CJK string through the byte tokenizer.
tokenizer = ByteTokenizer()
string = "Hello, 🌍! 你好!"
print(f"string: {string}")  # @inspect string
indices = tokenizer.encode(string)  # encode() also prints the raw bytes
print(f"indices: {indices}")  # @inspect indices
reconstructed_string = tokenizer.decode(indices)
assert string == reconstructed_string  # decoding the encoding must give back the input

# A byte can take 256 distinct values, so that is the whole vocabulary.
vocabulary_size = 256
compression_ratio = get_compression_ratio(string, indices)
# One token per byte means bytes / tokens is exactly 1.
assert compression_ratio == 1

string: Hello, 🌍! 你好!
string_bytes: b'Hello, \xf0\x9f\x8c\x8d! \xe4\xbd\xa0\xe5\xa5\xbd!'
indices: [72, 101, 108, 108, 111, 44, 32, 240, 159, 140, 141, 33, 32, 228, 189, 160, 229, 165, 189, 33]


# Word Tokenizer

In [9]:
string = "I'll say supercalifragilisticexpialidocious!"

# Simple split: each run of word characters is one segment, and any other
# single character (apostrophe, space, punctuation) is a segment of its own.
segments = regex.findall(r"\w+|.", string)
print(segments)

# Fancier version: keeps contractions such as "'ll" together and attaches a
# leading space to the word, number or punctuation run that follows it.
# The \p{L} / \p{N} classes are why the `regex` module is used here.
GPT2_TOKENIZER_REGEX = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
pattern = GPT2_TOKENIZER_REGEX
segments = regex.findall(pattern, string)
print(segments)

['I', "'", 'll', ' ', 'say', ' ', 'supercalifragilisticexpialidocious', '!']
['I', "'ll", ' say', ' supercalifragilisticexpialidocious', '!']


# BPE Tokenizer

In [10]:
@dataclass(frozen=True)
class BPETokenizerParams:
    """All you need to specify a BPETokenizer.

    The dataclass is frozen, so the two attributes cannot be rebound after
    construction (the dictionaries they refer to are still ordinary dicts).
    """
    # Maps each token index to the byte string that token stands for.
    vocab: dict[int, bytes]     # index -> bytes
    # Maps a pair of adjacent token indices to the index of the token that
    # replaces them when the pair is merged.
    merges: dict[tuple[int, int], int]  # index1,index2 -> new_index

def merge(indices: list[int], pair: tuple[int, int], new_index: int) -> list[int]:
    """Return a copy of `indices` in which each occurrence of `pair` becomes `new_index`.

    The scan runs left to right and skips past both elements of a match, so
    occurrences never overlap (e.g. merging (1, 1) in [1, 1, 1] gives
    [new_index, 1]). The input list is not modified.
    """
    merged = []
    cursor = 0
    last = len(indices) - 1
    while cursor <= last:
        # A pair can only start here if there is still a following element.
        if cursor < last and (indices[cursor], indices[cursor + 1]) == pair:
            merged.append(new_index)
            cursor += 2
            continue
        merged.append(indices[cursor])
        cursor += 1
    return merged

class BPETokenizer(Tokenizer):
    """Byte-pair-encoding tokenizer driven by a fixed vocabulary and merge table."""

    def __init__(self, params: BPETokenizerParams):
        """Store `params`, which supplies the `vocab` and `merges` dictionaries."""
        self.params = params

    def encode(self, string: str) -> list[int]:
        """Encode `string` by applying every stored merge to its UTF-8 bytes.

        Merges are applied one after another in the iteration order of
        `params.merges`. This is a deliberately simple and very slow
        approach: each merge rescans the whole sequence.
        """
        token_ids = list(string.encode("utf-8"))
        for byte_pair, merged_index in self.params.merges.items():
            token_ids = merge(token_ids, byte_pair, merged_index)
        return token_ids

    def decode(self, indices: list[int]) -> str:
        """Look up the bytes of each index in `params.vocab`, join them, and decode as UTF-8."""
        lookup = self.params.vocab.get
        pieces = [lookup(index) for index in indices]
        return b"".join(pieces).decode("utf-8")

## Training the Tokenizer

In [11]:
def train_bpe(string: str, num_merges: int) -> BPETokenizerParams:
    """Learn a byte-pair-encoding vocabulary and merge table from `string`.

    Training starts from the UTF-8 bytes of `string` and a vocabulary holding
    the 256 single-byte tokens. Each round counts adjacent token pairs,
    merges the most frequent one into a new token (indices 256, 257, ...),
    and rewrites the sequence with that merge applied. When several pairs
    share the highest count, the one that appears first in the sequence wins.

    Args:
        string: Training text.
        num_merges: Maximum number of merges to learn.

    Returns:
        A BPETokenizerParams holding the vocabulary and the merges in the
        order they were learned. Fewer than `num_merges` merges are returned
        when the sequence runs out of adjacent pairs first (for example an
        empty or one-byte `string`, or a sequence already merged down to a
        single token).
    """
    # start with the list of bytes of string
    indices = list(map(int, string.encode('utf-8')))
    merges: dict[tuple[int, int], int] = {}
    vocab: dict[int, bytes] = {x: bytes([x]) for x in range(256)}

    for i in range(num_merges):
        # count the number of occurrences of each pair of tokens
        counts = defaultdict(int)
        for index1, index2 in zip(indices, indices[1:]):
            counts[(index1, index2)] += 1

        # Fewer than two tokens left: there is nothing to merge, and calling
        # max() on the empty dict would raise ValueError, so stop early.
        if not counts:
            break

        # find the most common pair
        pair = max(counts, key=counts.get)
        index1, index2 = pair

        # merge the pair
        new_index = 256 + i
        merges[pair] = new_index
        vocab[new_index] = vocab[index1] + vocab[index2]
        indices = merge(indices, pair, new_index)

    return BPETokenizerParams(merges=merges, vocab=vocab)

# Learn three merges from a tiny training string.
string = "the cat in the hat"
params = train_bpe(string, num_merges=3)

In [12]:
# Build a tokenizer from the trained parameters and round-trip a string that
# differs from the training text.
tokenizer = BPETokenizer(params)
string = "the quick brown fox"  # @inspect string
indices = tokenizer.encode(string)  # @inspect indices
reconstructed_string = tokenizer.decode(indices)  # @inspect reconstructed_string
assert string == reconstructed_string  # decoding the encoding must give back the input
