In [82]:
import torch
import heapq

def merge_tokens(tokens: torch.Tensor, pair: tuple[int, int], new_token: int) -> torch.Tensor:
    """Replace every non-overlapping occurrence of ``pair`` with ``new_token``.

    The scan is greedy and left-to-right, so overlapping candidates resolve in
    favour of the earliest match: ``[a, a, a]`` with pair ``(a, a)`` becomes
    ``[new_token, a]``.

    Args:
        tokens: 1-D integer tensor of token ids.
        pair: The two adjacent token ids to merge (any 2-element sequence).
        new_token: Id written in place of each matched pair.

    Returns:
        A new 1-D tensor with the same dtype and device as ``tokens``.
    """
    # Unpack instead of comparing against ``pair`` as a tuple, so a list-valued
    # pair matches as well.
    first, second = pair
    # One bulk copy to a Python list instead of an ``.item()`` call per element.
    values = tokens.tolist()
    last_index = len(values) - 1

    merged = []
    i = 0
    while i <= last_index:
        if i < last_index and values[i] == first and values[i + 1] == second:
            merged.append(new_token)
            i += 2  # skip both members of the matched pair
        else:
            merged.append(values[i])
            i += 1
    return torch.tensor(merged, dtype=tokens.dtype, device=tokens.device)


def unmerge_tokens(tokens: torch.Tensor, pair: tuple[int, int], new_token: int) -> torch.Tensor:
    """Expand every occurrence of ``new_token`` back into the two ids of ``pair``.

    Args:
        tokens: 1-D integer tensor of token ids.
        pair: The two token ids that ``new_token`` stands for.
        new_token: The merged id to expand.

    Returns:
        A new 1-D tensor with the same dtype and device as ``tokens``; its
        length grows by one for every occurrence of ``new_token``.
    """
    first, second = pair
    expanded = []
    # Walk a plain Python list: a single bulk copy replaces the per-element
    # tensor reads and writes (each a separate device round trip).
    for value in tokens.tolist():
        if value == new_token:
            expanded.append(first)
            expanded.append(second)
        else:
            expanded.append(value)
    return torch.tensor(expanded, dtype=tokens.dtype, device=tokens.device)


def get_token_pair_counts(tokens: torch.Tensor) -> dict[tuple[int, int], int]:
    """Count how often each pair of adjacent token ids occurs in ``tokens``.

    Args:
        tokens: 1-D integer tensor of token ids.

    Returns:
        A dict mapping ``(left, right)`` to the number of positions where
        ``left`` is immediately followed by ``right``. Keys appear in order of
        first occurrence; fewer than two tokens yields an empty dict.
    """
    # Row k of ``adjacent`` is (tokens[k], tokens[k + 1]).
    adjacent = torch.stack((tokens[:-1], tokens[1:]), dim=1)

    counts = {}
    for row in adjacent.tolist():
        key = tuple(row)
        counts[key] = counts.get(key, 0) + 1
    return counts

def generate_merges(tokens: torch.Tensor, num_merges: int) -> dict[tuple[int, int], int]:
    """Learn up to ``num_merges`` BPE merge rules from ``tokens``.

    Each round picks the most frequent adjacent pair, assigns it the next free
    token id (starting at 256) and applies that merge to the working sequence
    before counting again.

    Args:
        tokens: 1-D integer tensor of token ids to learn from.
        num_merges: Maximum number of merge rules to produce.

    Returns:
        A dict mapping each merged pair to its new token id, in the order the
        merges were learned. It holds fewer than ``num_merges`` entries when
        the sequence runs out of adjacent pairs.
    """
    merges = {}
    next_token = 256
    while len(merges) < num_merges:
        pair_counts = get_token_pair_counts(tokens)
        if not pair_counts:
            break  # fewer than two tokens left, nothing more to merge
        # Highest count wins; ties go to the smallest pair. This is the same
        # ordering a min-heap of (-count, pair) tuples pops first, without
        # building a heap that is discarded after a single pop.
        best_pair = min(pair_counts, key=lambda pair: (-pair_counts[pair], pair))
        tokens = merge_tokens(tokens, best_pair, next_token)
        merges[best_pair] = next_token
        next_token += 1
    return merges

In [76]:
import os
import pickle


class Tokenizer:
    """Byte-level BPE tokenizer built on the module's merge helpers.

    Attributes are populated by ``train`` (or restored by ``load``):
    ``merges`` maps a token pair to the id that replaces it, and ``vocab``
    maps every known token id to a contiguous index.
    """

    def __init__(self) -> None:
        self.vocab = None
        self.special = None
        self.merges = None
        # Prefer Apple's MPS backend when present, otherwise run on the CPU.
        self.device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

    @classmethod
    def load(cls, path):
        """Load a tokenizer previously written by ``save``.

        Args:
            path: Directory that contains ``tokenizer.pkl``.

        Returns:
            The unpickled object stored in that file.

        Raises:
            ValueError: If ``tokenizer.pkl`` does not exist under ``path``.
        """
        tokenizer_file = os.path.join(path, "tokenizer.pkl")

        # A missing directory implies a missing file, so one check covers both.
        if not os.path.exists(tokenizer_file):
            raise ValueError(cls.load.__name__ + ": No tokenizer found at the specified directory")

        # NOTE(review): unpickling can execute arbitrary code; only load files
        # from a trusted source.
        with open(tokenizer_file, "rb") as pkl_file:
            return pickle.load(pkl_file)

    def save(self, path):
        """Pickle this tokenizer to ``<path>/tokenizer.pkl``, creating ``path`` if needed."""
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "tokenizer.pkl"), 'wb') as pkl_file:
            pickle.dump(self, pkl_file)

    def train(self, corpus_path, vocab_size):
        """Learn merges and the vocabulary from a text file.

        Args:
            corpus_path: Path of the training text, read as UTF-8.
            vocab_size: Target number of vocabulary entries. No merges are
                learned when the corpus already has at least this many
                distinct byte values.
        """
        # Read explicitly as UTF-8 so the bytes seen here do not depend on the
        # platform's default text encoding.
        with open(corpus_path, encoding="utf-8") as f:
            corpus = f.read()
        tokens = torch.tensor(list(corpus.encode('utf-8', errors='replace')), dtype=torch.int32, device=self.device)

        # Byte values present in the corpus get the first indices, in the
        # order torch.unique returns them.
        vocab = {token: index for index, token in enumerate(torch.unique(tokens).tolist())}
        num_merges = max(0, vocab_size - len(vocab))
        self.merges = generate_merges(tokens, num_merges)

        # Merged tokens are appended after the byte entries, in learning order.
        base = len(vocab)
        for offset, new_token in enumerate(self.merges.values()):
            vocab[new_token] = base + offset
        self.vocab = vocab

    def _trained_merges(self):
        """Return ``self.merges``, raising a clear error if none have been set."""
        if self.merges is None:
            raise ValueError("Tokenizer has no merges; call train() or load() first")
        return self.merges

    def encode(self, string: str) -> torch.Tensor:
        """Turn ``string`` into token ids by applying the merges in learned order.

        Returns:
            A 1-D int32 tensor on ``self.device``.

        Raises:
            ValueError: If the tokenizer has no merges yet.
        """
        merges = self._trained_merges()
        tokens = torch.tensor(list(string.encode("utf-8", errors='replace')), dtype=torch.int32, device=self.device)
        for pair, target in merges.items():
            tokens = merge_tokens(tokens, pair, target)
        return tokens

    def decode(self, tokens: torch.Tensor) -> str:
        """Turn token ids back into text.

        Merges are undone in reverse learning order, then the remaining ids
        are interpreted as bytes and decoded as UTF-8 (invalid sequences are
        replaced rather than raising).

        Raises:
            ValueError: If the tokenizer has no merges yet.
        """
        merges = self._trained_merges()
        for pair, target in reversed(list(merges.items())):
            tokens = unmerge_tokens(tokens, pair, target)
        return bytes(tokens.tolist()).decode('utf-8', errors='replace')
        
        


In [83]:
# Build a tokenizer and train it on the sample corpus with a target vocabulary size of 100.
tokenizer = Tokenizer()

# NOTE(review): absolute, machine-specific path; this cell only runs where that file exists.
tokenizer.train('/Users/sam/Desktop/github/word2vec/src/sample.txt', 100)

In [22]:
# Gotcha: bytes(int) builds that many zero bytes (see the recorded output), it does not encode the integer.
bytes(97).decode('utf-8')

'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'

In [28]:
import torch
# Token ids that are plain byte values.
tokens = torch.tensor([98,99,97])

# Same conversion as the last step of Tokenizer.decode: ids -> bytes -> UTF-8 text.
string = bytes(tokens.tolist()).decode('utf-8', errors='replace')
string

'bca'

In [29]:
# bytes(list_of_ints) treats each int as one byte value; relies on `tokens` from the previous cell.
bytes(tokens.tolist())

b'bca'

In [30]:
# Unicode code point of '£'.
ord('£')

163