In [1]:
import os
import pickle
import heapq
import torch

def merge_ids(ids: torch.Tensor, pair: tuple[int, int], idx: int) -> torch.Tensor:
    new_ids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i].item(), ids[i + 1].item()) == pair:
            new_ids.append(idx)
            i += 2 
        else:
            new_ids.append(ids[i].item())
            i += 1
    return torch.tensor(new_ids, dtype=ids.dtype, device=ids.device)


def get_pair_counts(ids: torch.Tensor) -> dict[tuple[int, int], int]:
    
    pairs = torch.stack((ids[:-1], ids[1:]), dim=1)
    pairs_tuple = [tuple(pair.tolist()) for pair in pairs]
    
    pair_counts = {}
    for pair in pairs_tuple:
        if pair in pair_counts:
            pair_counts[pair] += 1
        else:
            pair_counts[pair] = 1

    return pair_counts

def generate_merges(ids: torch.Tensor, num_merges: int) -> dict[tuple[int, int], int]:
    merges = {}
    i = 256
    count = 0
    pair_count = get_pair_counts(ids)
    max_heap = [(-count, pair) for pair, count in pair_count.items()]
    heapq.heapify(max_heap)
    
    while count < num_merges and max_heap:
        _, merge_pair = heapq.heappop(max_heap)
        merge_pair = tuple(merge_pair)
        ids = merge_ids(ids, merge_pair, i)
        merges[merge_pair] = i
        i += 1
        count += 1
        pair_count = get_pair_counts(ids)
        max_heap = [(-count, pair) for pair, count in pair_count.items()]
        heapq.heapify(max_heap)
    
    return merges

class Tokenizer:
    def __init__(self) -> None:
        self.vocab = {idx : bytes([idx]) for idx in range(256)}
        self.special = None
        self.merges = None
        self.isTrain = False
        self.device = torch.device('cpu')
        
    @classmethod
    def load(cls, path):
        tokenizer_file = os.path.join(path, "tokenizer.pkl")

        if not os.path.exists(path) or not os.path.exists(os.path.join(path, "tokenizer.pkl")):
            raise ValueError(cls.load.__name__ + ": No tokenizer found at the specified directory")

        with open(tokenizer_file, "rb") as pkl_file:
            return pickle.load(pkl_file)
        
    def save(self, path):
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "tokenizer.pkl"), 'wb') as pkl_file:
            pickle.dump(self, pkl_file)
            
    
    def train(self, corpus_path, vocab_size):
        
        with open(corpus_path) as f: corpus = f.read()
        ids = torch.tensor(list(corpus.encode('utf-8', errors='replace')), dtype=torch.int16, device=self.device)
        
        num_merges = max(0, vocab_size - 256)
        self.merges = generate_merges(ids, num_merges)
        
        for pair, new_idx in self.merges.items():
            self.vocab[new_idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
        self.isTrain = True
        
    def encode(self, string: str) -> torch.Tensor:
        
        if self.isTrain == False: print('Warning: Please train the tokenizer first!!!') 
               
        ids = torch.tensor(list(string.encode("utf-8", errors='replace')), dtype=torch.int16, device=self.device)
        for pair, target in self.merges.items():
            ids = merge_ids(ids, pair, target)
        return ids

    def decode(self, tokens: torch.Tensor) -> str:
        if self.isTrain == False: print('Warning: Please train the tokenizer first!!!')
        
        string = ""
        for token in tokens:
            string += self.vocab[token.item].decode('utf-8', errors='replace')
        return string
        

In [3]:
tok = Tokenizer()
tok.train('/Users/sam/Desktop/github/word2vec/src/sample.txt', 276)

In [4]:
dic = {'a':1, 'b':2, 'c':3}
dic['d'] = 4
dic

{'a': 1, 'b': 2, 'c': 3, 'd': 4}

In [7]:
dic['a'] = 5
dic

{'a': 5, 'b': 2, 'c': 3, 'd': 4}

In [8]:
del dic['d']
dic

{'a': 5, 'b': 2, 'c': 3}

In [11]:


graph = {
        'a' : ['b', 'c'],
        'b' : [ 'd'],
        'c' : ['e'],
        'd' : [],
        'e' : ['b'],
        'f' : ['d']  
}

start_node = 'a'

stack = list(start_node)

while len(stack) > 0:
    node = stack.pop()
    print(node)
    for val in graph[node]:
        stack.append(val)


a
c
e
b
d
b
d


In [16]:
graph = {
        'a' : ['b', 'c'],
        'b' : ['d'],
        'c' : ['e'],
        'd' : [],
        'e' : ['b'],
        'f' : ['d']  
}

from collections import deque
def BFS(start_node):
    q = deque(start_node)
    while len(q) > 0:
        node = q.popleft()
        print(node)
        for val in graph[node]:
            q.append(val)



start_node = 'd'
BFS(start_node)

d
