# Check hardware info

In [13]:
# Remove stale serialization outputs; -f keeps a fresh run quiet when the files do not exist yet.
!rm -f serialization_merge.json serialization_vocab.json

CPU

In [21]:
!lscpu

Architecture:             x86_64
  CPU op-mode(s):         32-bit, 64-bit
  Address sizes:          46 bits physical, 48 bits virtual
  Byte Order:             Little Endian
CPU(s):                   4
  On-line CPU(s) list:    0-3
Vendor ID:                GenuineIntel
  Model name:             Intel(R) Xeon(R) CPU @ 2.20GHz
    CPU family:           6
    Model:                79
    Thread(s) per core:   2
    Core(s) per socket:   2
    Socket(s):            1
    Stepping:             0
    BogoMIPS:             4399.99
    Flags:                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge m
                          ca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht sysc
                          all nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xt
                          opology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq
                           ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt
                           aes xsave avx f16c rdrand hypervisor 

Memory

In [22]:
!free -h

               total        used        free      shared  buff/cache   available
Mem:            31Gi       805Mi        21Gi       2.0Mi       9.3Gi        30Gi
Swap:             0B          0B          0B


GPU

In [23]:
!nvidia-smi

/bin/bash: line 1: nvidia-smi: command not found


Download training data

In [4]:
# -nc (no-clobber) skips the download when the file is already present, so re-running
# the notebook does not fetch the corpus again or create numbered duplicate copies.
!wget -nc https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt

--2025-08-27 11:35:41--  https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt
Resolving huggingface.co (huggingface.co)... 18.244.202.68, 18.244.202.73, 18.244.202.60, ...
Connecting to huggingface.co (huggingface.co)|18.244.202.68|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cas-bridge.xethub.hf.co/xet-bridge-us/645e8da96320b0efe40ade7a/02e40cc51c59a4bc6c51bd7bc9acda4316e208745be060558eaf500cd14e9f96?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=cas%2F20250827%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250827T113541Z&X-Amz-Expires=3600&X-Amz-Signature=8268432855661474f64d3b8caf47066646031449023491f314a8f85b2fe45953&X-Amz-SignedHeaders=host&X-Xet-Cas-Uid=public&response-content-disposition=inline%3B+filename*%3DUTF-8%27%27TinyStoriesV2-GPT4-train.txt%3B+filename%3D%22TinyStoriesV2-GPT4-train.txt%22%3B&response-content-type=text%2Fplain&x-id=GetObject&Expires

# Training

config file

In [24]:
%%file config_kaggle.yaml
special_tokens:
  - "<|endoftext|>"
enable_log: False
log_path: "gold.log"
serialization: True
serialization_vocab_path: "serialization_vocab.json"
serialization_merge_path: "serialization_merge.json"
traindata_path: "/kaggle/working/TinyStoriesV2-GPT4-train.txt"
vocab_size: 10000
gpt2_regex: True
parallel: True

Overwriting config_kaggle.yaml


Tokenizer

In [25]:
%%file tokenizer.py
import regex as re
import json
import logging
from collections import defaultdict
from typing import BinaryIO
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
import heapq

PAT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
PAT = re.compile(PAT_PATTERN)

# Wrapper for heap comparison: reverses "<" so the lexicographically greater value sorts first
class _Desc:
    # __slots__ = ['x']

    def __init__(self, x):
        self.x = x
    
    def __lt__(self, other):
        """
        Overwrite reversice, lexical greater
        """
        return self.x > other.x
    

class Tokenizer:
    def __init__(
            self, 
            special_tokens: list[str] | None = None, 
            enable_log: bool = False, 
            log_path: str = "",
            serialization: bool = False,
            serialization_vocab_path: str | None = None,
            serialization_merge_path: str | None = None,
        ):
        self.special_tokens = special_tokens or []
        self.next_id = 0
        self.vocab: dict[int, bytes] = {}
        self.merge: list[tuple[bytes, bytes]] = []
        self.enable_log = enable_log
        self._heap = []

        self.serialization = serialization
        self.serialization_vocab_path = serialization_vocab_path
        self.serialization_merge_path = serialization_merge_path

        if enable_log:
            if not log_path:
                raise ValueError("Logging is enable but no log path was provided")
            
            self._set_log_conifg(log_path)

    @staticmethod
    def _set_log_conifg(log_path: str):
        """Configure the ``logging`` module to write bare messages to ``log_path``.

        The file is opened in ``"w"`` mode, the level is ``INFO`` and each
        record is formatted as the message text only.
        """
        options = {
            "filename": log_path,
            "filemode": "w",
            "level": logging.INFO,
            "format": "%(message)s",
        }
        logging.basicConfig(**options)
    
    def dump_pair_count(self, pair_count: dict[tuple[bytes], int], merged_token: tuple[tuple[bytes], int], index: int):
        if self.enable_log:
            serial = { str(k): v for k, v in pair_count.items() }
            serial_merged_token = {str(merged_token[0]): merged_token[1]}
            logging.info(json.dumps({"step": index, "pair": serial, "merged": serial_merged_token}, ensure_ascii=False, sort_keys=True))
    
    def init_vocab(self):
        self.vocab = {x: bytes([x]) for x in range (256)}
        token_id_start = 256
    
        for i, special_token in enumerate(self.special_tokens):
            s_bytes = special_token.encode("utf-8")
            special_token_id = token_id_start + i
            self.vocab[special_token_id] = s_bytes
        
        self.next_id = special_token_id + 1
    
    def remove_special_tokens(self, text: str) -> list[str]:
        stokens_escaped = [re.escape(stoken) for stoken in self.special_tokens]
        return re.split("|".join(stokens_escaped), text)
    
    def remove_special_tokens_static(text: str, special_tokens: list[str]) -> list[str]:
        stokens_escaped = [re.escape(stoken) for stoken in special_tokens]
        return re.split("|".join(stokens_escaped), text)
    
    @staticmethod
    def find_chunk_boundaries(
        file: BinaryIO, 
        desired_num_chunks: int, 
        split_special_token: bytes
    ) -> list[int]:
        """
        Chunk the file into parts that can be counted independently.
        May return fewer chunks if the boundaries end up overlapping.
        """
        assert isinstance(split_special_token, bytes), (
            "Must represent special token as a bytestring"
        )
    
        # Get total file size in bytes
        file.seek(0, os.SEEK_END)
        file_size = file.tell()
        file.seek(0)
    
        chunk_size = file_size // desired_num_chunks
    
        # Initial guesses for chunk boundary locations, uniformly spaced
        # Chunks start on previous index, don't include last index
        chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
        chunk_boundaries[-1] = file_size
    
        mini_chunk_size = 4096  # Read ahead by 4k bytes at a time
    
        for bi in range(1, len(chunk_boundaries) - 1):
            initial_position = chunk_boundaries[bi]
            file.seek(initial_position)  # Start at boundary guess
            while True:
                mini_chunk = file.read(mini_chunk_size)  # Read a mini chunk
    
                # If EOF, this boundary should be at the end of the file
                if mini_chunk == b"":
                    chunk_boundaries[bi] = file_size
                    break
    
                # Find the special token in the mini chunk
                found_at = mini_chunk.find(split_special_token)
                if found_at != -1:
                    chunk_boundaries[bi] = initial_position + found_at
                    break
                initial_position += mini_chunk_size
    
        # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
        return sorted(set(chunk_boundaries))
    
    # @staticmethod
    # def pretokenize_and_count(docs: list[str], gpt2_regex: bool = False) -> dict[tuple[bytes], int]:
    #     token_count : dict[tuple[bytes], int] = {}
    
    #     for doc in docs:
    #         pre_tokens = None
    #         # Use a regex-based pre-tokenizer (used by GPT-2; Radford et al., 2019)
    #         if gpt2_regex:
    #             PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    #             pre_tokens = re.finditer(PAT, doc)
    #             pre_tokens = [match.group(0) for match in pre_tokens]
    #         else:
    #             pre_tokens = doc.split()
    
    #         for token in pre_tokens:
    #             bytes_token = token.encode("utf-8")
                
    #             tuple_bytes_token = tuple(bytes_token[i : i+1] for i in range (len(bytes_token)))
    #             token_count[tuple_bytes_token] = token_count.get(tuple_bytes_token, 0) + 1
            
    #     return token_count

    # @staticmethod
    # def pretokenize_and_count(docs: list[str], gpt2_regex: bool = False) -> dict[tuple[bytes], int]:
    #     token_count : dict[tuple[bytes], int] = {}
    #     # token_count_get = token_count.get
    
    #     for doc in docs:
    #         pre_tokens = None
    #         # Use a regex-based pre-tokenizer (used by GPT-2; Radford et al., 2019)
    #         if gpt2_regex:
    #             # PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    #             pre_tokens = PAT.finditer(doc)
    #             pre_tokens = [match.group(0) for match in pre_tokens]
    #         else:
    #             pre_tokens = doc.split()
    
    #         for token in pre_tokens:
    #             bytes_token = token.encode("utf-8")
                
    #             tuple_bytes_token = tuple(bytes_token[i : i+1] for i in range (len(bytes_token)))
    #             token_count[tuple_bytes_token] = token_count.get(tuple_bytes_token, 0) + 1
            
    #     return token_count

    # @staticmethod
    # def pretokenize_and_count(docs: list[str], gpt2_regex: bool = False) -> dict[tuple[bytes], int]:
    #     token_count : dict[tuple[bytes], int] = {}
    #     token_count_get = token_count.get
    
    #     for doc in docs:
    #         pre_tokens = None
    #         # Use a regex-based pre-tokenizer (used by GPT-2; Radford et al., 2019)
    #         if gpt2_regex:
    #             # PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    #             pre_tokens = PAT.finditer(doc)
    #             pre_tokens = [match.group(0) for match in pre_tokens]
    #         else:
    #             pre_tokens = doc.split()
    
    #         for token in pre_tokens:
    #             bytes_token = token.encode("utf-8")
                
    #             tuple_bytes_token = tuple(bytes_token[i : i+1] for i in range (len(bytes_token)))
    #             token_count[tuple_bytes_token] = token_count_get(tuple_bytes_token, 0) + 1
            
    #     return token_count
    
    # #OPED
    # @staticmethod
    # def pretokenize_and_count(docs: list[str], gpt2_regex: bool = False) -> dict[tuple[bytes], int]:
    #     token_count : dict[tuple[bytes], int] = {}
    #     token_count_get = token_count.get
    
    #     for doc in docs:
    #         # pre_tokens = None
    #         # Use a regex-based pre-tokenizer (used by GPT-2; Radford et al., 2019)
    #         if gpt2_regex:
    #             for token in PAT.finditer(doc):
    #                 token_str = token.group(0)
    #                 bytes_token = token_str.encode("utf-8")

    #                 tuple_bytes_token = tuple(bytes_token[i : i+1] for i in range (len(bytes_token)))
    #                 token_count[tuple_bytes_token] = token_count_get(tuple_bytes_token, 0) + 1
    #         else:
    #             for token in doc.split():
    #                 bytes_token = token.encode("utf-8")

    #                 tuple_bytes_token = tuple(bytes_token[i : i+1] for i in range (len(bytes_token)))
    #                 token_count[tuple_bytes_token] = token_count_get(tuple_bytes_token, 0) + 1

    #         # if gpt2_regex:
    #         #     # PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    #         #     pre_tokens = PAT.finditer(doc)
    #         #     pre_tokens = [match.group(0) for match in pre_tokens]
    #         # else:
    #         #     pre_tokens = doc.split()
    
    #         # for token in pre_tokens:
    #         #     bytes_token = token.encode("utf-8")
                
    #         #     tuple_bytes_token = tuple(bytes_token[i : i+1] for i in range (len(bytes_token)))
    #         #     token_count[tuple_bytes_token] = token_count_get(tuple_bytes_token, 0) + 1
            
    #     return token_count
    
    # @staticmethod
    # def pretokenize_and_count(docs: list[str], gpt2_regex: bool = False) -> dict[tuple[bytes], int]:
    #     token_count : dict[tuple[bytes], int] = {}
    #     token_count_get = token_count.get
    
    #     for doc in docs:
    #         # pre_tokens = None
    #         # Use a regex-based pre-tokenizer (used by GPT-2; Radford et al., 2019)
    #         if gpt2_regex:
    #             for token in PAT.finditer(doc):
    #                 token_str = token.group(0)
    #                 bytes_token = token_str.encode("utf-8")

    #                 length_bytes_token = len(bytes_token)

    #                 tuple_bytes_token = tuple(bytes_token[i : i+1] for i in range (length_bytes_token))
    #                 token_count[tuple_bytes_token] = token_count_get(tuple_bytes_token, 0) + 1
    #         else:
    #             for token in doc.split():
    #                 bytes_token = token.encode("utf-8")

    #                 tuple_bytes_token = tuple(bytes_token[i : i+1] for i in range (len(bytes_token)))
    #                 token_count[tuple_bytes_token] = token_count_get(tuple_bytes_token, 0) + 1

    #         # if gpt2_regex:
    #         #     # PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    #         #     pre_tokens = PAT.finditer(doc)
    #         #     pre_tokens = [match.group(0) for match in pre_tokens]
    #         # else:
    #         #     pre_tokens = doc.split()
    
    #         # for token in pre_tokens:
    #         #     bytes_token = token.encode("utf-8")
                
    #         #     tuple_bytes_token = tuple(bytes_token[i : i+1] for i in range (len(bytes_token)))
    #         #     token_count[tuple_bytes_token] = token_count_get(tuple_bytes_token, 0) + 1
            
    #     return token_count
    
    @staticmethod
    def pretokenize_and_count(docs: list[str], gpt2_regex: bool = False) -> dict[tuple[bytes], int]:
        token_count : dict[tuple[bytes], int] = {}
        token_count_get = token_count.get
        
        # Build cache mapping pretoken bytes format to tuple byte format 
        cache :dict[bytes, tuple[bytes]] = {}
        cache_get = cache.get
    
        for doc in docs:
            # pre_tokens = None
            # Use a regex-based pre-tokenizer (used by GPT-2; Radford et al., 2019)
            if gpt2_regex:
                for token in PAT.finditer(doc):
                    token_str = token.group(0)
                    bytes_token = token_str.encode("utf-8")

                    pretoken_tuplebytes = cache_get(bytes_token)

                    if pretoken_tuplebytes is None:
                        # Build cache
                        length_bytes_token = len(bytes_token)
                        pretoken_tuplebytes = tuple(bytes_token[i : i+1] for i in range (length_bytes_token))
                        cache[bytes_token] = pretoken_tuplebytes
                                        
                    token_count[pretoken_tuplebytes] = token_count_get(pretoken_tuplebytes, 0) + 1
            else:
                for token in doc.split():
                    bytes_token = token.encode("utf-8")

                    tuple_bytes_token = tuple(bytes_token[i : i+1] for i in range (len(bytes_token)))
                    token_count[tuple_bytes_token] = token_count_get(tuple_bytes_token, 0) + 1

            # if gpt2_regex:
            #     # PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
            #     pre_tokens = PAT.finditer(doc)
            #     pre_tokens = [match.group(0) for match in pre_tokens]
            # else:
            #     pre_tokens = doc.split()
    
            # for token in pre_tokens:
            #     bytes_token = token.encode("utf-8")
                
            #     tuple_bytes_token = tuple(bytes_token[i : i+1] for i in range (len(bytes_token)))
            #     token_count[tuple_bytes_token] = token_count_get(tuple_bytes_token, 0) + 1
            
        return token_count
    
    def pretokenize_and_count_task(path: str, start: int, end :int, special_token: list[str], gpt2_regex: bool) -> dict[tuple[bytes], int]:
        # Get chunk
        with open(path, "rb") as f:
            f.seek(start)
            chunk = f.read(end - start).decode("utf-8", errors="ignore")
        
        # Remove special token
        docs = Tokenizer.remove_special_tokens_static(chunk, special_token)

        # Build pretoken counts dict
        pretoken_counts = Tokenizer.pretokenize_and_count(docs, gpt2_regex)

        return pretoken_counts
    
    def pretokenize(self, input_path: str, gpt2_regex: bool) -> dict[tuple[bytes], int]:
        # Read training data
        with open(input_path, "r", encoding="utf-8") as f:
            text = f.read()
    
        # Removing special tokens
        docs = self.remove_special_tokens(text)

        # Pre-tokenization
        pretokens = self.pretokenize_and_count(docs, gpt2_regex)

        return pretokens
    
    def pretokenize_parallel(self, path: str, gpt2_regex: bool) -> dict[tuple[bytes], int]:
        # Get logical core number
        core_num = os.cpu_count()

        # Get boundaries of chunks
        # TODO: special token should not hardcode
        with open(path, "rb") as f:
            boundaries = Tokenizer.find_chunk_boundaries(
                f, core_num, "<|endoftext|>".encode("utf-8"))
        
        # Parallel pretoken
        with ProcessPoolExecutor(max_workers=core_num) as executor:
            futures = [executor.submit(Tokenizer.pretokenize_and_count_task, path, start, end, self.special_tokens, gpt2_regex) for start, end in zip(boundaries[:-1], boundaries[1:])]
        
        pretoken_counts = {}
        pretoken_counts_get = pretoken_counts.get

        for future in as_completed(futures):
            for pretoken, count in future.result().items():
                pretoken_counts[pretoken] = pretoken_counts_get(pretoken, 0) + count
        
        return pretoken_counts
  
    # @staticmethod
    # def build_paircount_and_cache(
    #     pretokens : dict[tuple[bytes, ...], int]
    # ) -> tuple[
    #     dict[tuple[bytes], int], 
    #     dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]]
    #     ]:
    
    #     pair_count: dict[tuple[bytes], int] = {}
    #     cache: dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]] = defaultdict(set)
    
    #     for k, v in pretokens.items():
    #         for i in range(len(k)-1):
    #             pair_count[k[i : i+2]] = pair_count.get(k[i : i+2], 0) + v
    
    #             cache[k[i : i+2]].add((k, v))
    
    #     return pair_count, cache


    # @staticmethod
    # def build_paircount_and_cache(
    #     pretokens : dict[tuple[bytes, ...], int]
    # ) -> tuple[
    #     dict[tuple[bytes], int], 
    #     dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]]
    #     ]:
    
    #     pair_count: dict[tuple[bytes], int] = {}
    #     cache: dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]] = defaultdict(set)
    
    #     for k, v in pretokens.items():
    #         for i in range(len(k)-1):
    #             pair = k[i : i+2]

    #             pair_count[pair] = pair_count.get(pair, 0) + v
    
    #             cache[pair].add((k, v))
    
    #     return pair_count, cache
    
    # @staticmethod
    # def build_paircount_and_cache(
    #     pretokens : dict[tuple[bytes, ...], int]
    # ) -> tuple[
    #     dict[tuple[bytes], int], 
    #     dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]]
    #     ]:
    
    #     pair_count: dict[tuple[bytes], int] = {}
    #     cache: dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]] = defaultdict(set)

    #     pc_get = pair_count.get
    
    #     for k, v in pretokens.items():
    #         for i in range(len(k)-1):
    #             pair = k[i : i+2]

    #             # pair_count[pair] = pair_count.get(pair, 0) + v
    #             pair_count[pair] = pc_get(pair, 0) + v

    #             cache[pair].add((k, v))
    
    #     return pair_count, cache
    
    @staticmethod
    def build_paircount_and_cache(
        pretokens : dict[tuple[bytes, ...], int]
    ) -> tuple[
        dict[tuple[bytes], int], 
        dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]]
        ]:
    
        pair_count: dict[tuple[bytes], int] = {}
        cache: dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]] = defaultdict(set)

        pc_get = pair_count.get
    
        for k, v in pretokens.items():
            length_k = len(k)-1
            for i in range(length_k):
                pair = k[i : i+2]

                # pair_count[pair] = pair_count.get(pair, 0) + v
                pair_count[pair] = pc_get(pair, 0) + v

                cache[pair].add((k, v))
    
        return pair_count, cache
    
    def build_heap(self, pair_count: dict[tuple[bytes], int]):
        # Prepare for heapify
        self._heap = [(-count, _Desc(pair)) for pair, count in pair_count.items()]

        heapq.heapify(self._heap)
    
    def update_heap(self, changed_paircount: dict[tuple[bytes], int]):
        for pair, count in changed_paircount.items():
            heapq.heappush(self._heap, (-count, _Desc(pair)))
    
    # @staticmethod
    # def _pick_best_mergetoken(pair_count: dict[tuple[bytes], int]) -> tuple[tuple[bytes], int]:
    #     try:
    #         return max(
    #             pair_count.items(),
    #             key = lambda kv: (kv[1], kv[0])
    #         )
    #     except Exception as e:

    #     # Log or print the freqs that caused the failure
    #         print("Error picking best token, pair_count was:", pair_count)
    #         raise
    
    def _pick_best_mergetoken(self, pair_count: dict[tuple[bytes], int]) -> tuple[tuple[bytes], int]:
        while len(self._heap):
            best_heap = self._heap[0]
            pair = best_heap[1].x
            count = -best_heap[0]

            if pair_count.get(pair, 0) == count:
                return (pair, count)
            else:
                heapq.heappop(self._heap)
    
    # @staticmethod
    # def _build_new_pretoken(
    #     old_pretoken: tuple[tuple[bytes, ...], int], 
    #     best_paircount: tuple[bytes, ...]
    #     ) ->  tuple[tuple[bytes, ...], int]:
    
    #     new_pretoken_pair = ()
    #     old_pretoken_pair = old_pretoken[0]
    #     best_pair = best_paircount
    #     i = 0
    
    #     while i < len(old_pretoken_pair)-1:
    #         if old_pretoken_pair[i : i+2] == best_pair:
    #             new_pretoken_pair = new_pretoken_pair + (old_pretoken_pair[i] + old_pretoken_pair[i+1],)
    
    #             if i == len(old_pretoken_pair)-3:
    #                 new_pretoken_pair = new_pretoken_pair + (old_pretoken_pair[i+2],)
    
    #             i = i+2
    #         else:
    #             new_pretoken_pair = new_pretoken_pair + (old_pretoken_pair[i],)
    
    #             if i == len(old_pretoken_pair)-2:
    #                 new_pretoken_pair = new_pretoken_pair + (old_pretoken_pair[i+1],)
    
    #             i = i+1
        
    #     new_pretoken = (new_pretoken_pair, old_pretoken[1])
    
    #     return new_pretoken
    
    @staticmethod
    def _build_new_pretoken(
        old_pretoken: tuple[tuple[bytes, ...], int], 
        best_paircount: tuple[bytes, ...]
        ) ->  tuple[tuple[bytes, ...], int]:
    
        # new_pretoken_pair = ()
        new_pretoken_pair: list[bytes] = []

        old_pretoken_pair = old_pretoken[0]
        best_pair = best_paircount

        i = 0
        L = len(old_pretoken_pair) - 1
    
        while i < L:
            if old_pretoken_pair[i : i+2] == best_pair:
                # new_pretoken_pair = new_pretoken_pair + (old_pretoken_pair[i] + old_pretoken_pair[i+1],)
                new_pretoken_pair.append(old_pretoken_pair[i] + old_pretoken_pair[i+1])
    
                # if i == len(old_pretoken_pair)-3:
                #     new_pretoken_pair = new_pretoken_pair + (old_pretoken_pair[i+2],)
    
                i = i+2
            else:
                # new_pretoken_pair = new_pretoken_pair + (old_pretoken_pair[i],)
                new_pretoken_pair.append(old_pretoken_pair[i])
    
                # if i == len(old_pretoken_pair)-2:
                #     new_pretoken_pair = new_pretoken_pair + (old_pretoken_pair[i+1],)
    
                i = i+1
            
            if i == L:
                new_pretoken_pair.append(old_pretoken_pair[i])
        
        new_pretoken = (tuple(new_pretoken_pair), old_pretoken[1])
    
        return new_pretoken
    
    # @staticmethod
    # def _delete_old_contribution(
    #     pretoken: tuple[tuple[bytes, ...], int], 
    #     pair_count: dict[tuple[bytes], int], 
    #     reversed_cache: dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]]
    #     ) -> tuple[dict[tuple[bytes], int], dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]]]:

    #     pretoken_pair = pretoken[0]
    #     pretoken_count = pretoken[1]
    
    #     for i in range (len(pretoken_pair)-1):
    #         pair = pretoken_pair[i : i+2]
    
    #         pair_count[pair] = pair_count[pair] - pretoken_count
    #         if pair_count[pair] == 0:
    #             del pair_count[pair]
    
    #         reversed_cache[pair].discard(pretoken)
    #         if not reversed_cache[pair]:
    #             del reversed_cache[pair]
        
    #     return pair_count, reversed_cache
    
    # @staticmethod
    # def _delete_old_contribution(
    #     pretoken: tuple[tuple[bytes, ...], int], 
    #     pair_count: dict[tuple[bytes], int], 
    #     reversed_cache: dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]],
    #     changed_paircount: dict[tuple[bytes], int]
    #     ) -> tuple[dict[tuple[bytes], int], dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]], dict[tuple[bytes], int]]:

    #     pretoken_pair = pretoken[0]
    #     pretoken_count = pretoken[1]

    #     changed_paircount_get = changed_paircount.get
    
    #     for i in range (len(pretoken_pair)-1):
    #         pair = pretoken_pair[i : i+2]
    
    #         pair_count[pair] = pair_count[pair] - pretoken_count
            
    #         # Record negative change 
    #         changed_paircount[pair] = changed_paircount_get(pair, 0) - pretoken_count

    #         if pair_count[pair] == 0:
    #             del pair_count[pair]
    
    #         reversed_cache[pair].discard(pretoken)
    #         if not reversed_cache[pair]:
    #             del reversed_cache[pair]
        
    #     return pair_count, reversed_cache, changed_paircount
    
    @staticmethod
    def _delete_old_contribution(
        pretoken: tuple[tuple[bytes, ...], int], 
        pair_count: dict[tuple[bytes], int], 
        reversed_cache: dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]],
        changed_paircount: dict[tuple[bytes], int]
        ) -> tuple[dict[tuple[bytes], int], dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]], dict[tuple[bytes], int]]:

        pretoken_pair = pretoken[0]
        pretoken_count = pretoken[1]

        changed_paircount_get = changed_paircount.get
        length_pretoken_pair = len(pretoken_pair)-1
    
        for i in range (length_pretoken_pair):
            pair = pretoken_pair[i : i+2]
    
            pair_count[pair] = pair_count[pair] - pretoken_count
            
            # Record negative change 
            changed_paircount[pair] = changed_paircount_get(pair, 0) - pretoken_count

            if pair_count[pair] == 0:
                del pair_count[pair]
    
            reversed_cache[pair].discard(pretoken)
            if not reversed_cache[pair]:
                del reversed_cache[pair]
        
        return pair_count, reversed_cache, changed_paircount
    
    # @staticmethod
    # def _add_new_contribution(
    #     pretoken: tuple[tuple[bytes, ...], int], 
    #     pair_count: dict[tuple[bytes], int], 
    #     reversed_cache: dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]]
    #     ) -> tuple[dict[tuple[bytes], int], dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]]]:

    #     reversed_cache = defaultdict(set, reversed_cache)
    #     pretoken_pair = pretoken[0]
    #     pretoken_count = pretoken[1]
    
    #     for i in range (len(pretoken_pair)-1):
    #         pair = pretoken_pair[i : i+2]
    
    #         pair_count[pair] = pair_count.get(pair, 0) + pretoken_count
    
    #         reversed_cache[pair].add(pretoken)
        
    #     return pair_count, reversed_cache
    
    @staticmethod
    def _add_new_contribution(
        pretoken: tuple[tuple[bytes, ...], int], 
        pair_count: dict[tuple[bytes], int], 
        reversed_cache: dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]],
        changed_paircount: dict[tuple[bytes], int]
        ) -> tuple[dict[tuple[bytes], int], dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]], dict[tuple[bytes], int]]:

        reversed_cache = defaultdict(set, reversed_cache)
        pretoken_pair = pretoken[0]
        pretoken_count = pretoken[1]

        pair_count_get = pair_count.get
        changed_paircount_get = changed_paircount.get

        length_pretoken_pair = len(pretoken_pair)-1
    
        for i in range (length_pretoken_pair):
            pair = pretoken_pair[i : i+2]
    
            pair_count[pair] = pair_count_get(pair, 0) + pretoken_count

            # Record positive change
            changed_paircount[pair] = changed_paircount_get(pair, 0) + pretoken_count
    
            reversed_cache[pair].add(pretoken)
        
        return pair_count, reversed_cache, changed_paircount

    # @staticmethod
    # def merge_new(
    #     pair_counts: dict[tuple[bytes], int], 
    #     reversed_cache: dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]],
    #     best_pair: tuple[bytes, ...]
    # ) -> tuple[dict[tuple[bytes], int], dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]]]:

    #     affected_pretokens = reversed_cache[best_pair].copy()
    
    #     for old_pretoken in affected_pretokens:
    #         new_pretoken = Tokenizer._build_new_pretoken(old_pretoken, best_pair)
    
    #         # Update, delete old pretoken contribution
    #         pair_counts, reversed_cache = Tokenizer._delete_old_contribution(old_pretoken, pair_counts, reversed_cache)
    #         # update, add new pretoken contrbution
    #         pair_counts, reversed_cache = Tokenizer._add_new_contribution(new_pretoken, pair_counts, reversed_cache)

    #     return pair_counts, reversed_cache
    
    @staticmethod
    def merge_new(
        pair_counts: dict[tuple[bytes], int],
        reversed_cache: dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]],
        best_pair: tuple[bytes, ...]
    ) -> tuple[dict[tuple[bytes], int], dict[tuple[bytes, ...], set[tuple[tuple[bytes, ...], int]]], dict[tuple[bytes], int]]:
        """
        Apply the merge of ``best_pair`` to every pretoken that contains it.

        Args:
            pair_counts: Current count for each adjacent pair.
            reversed_cache: Maps a pair to the set of pretokens indexed under it.
            best_pair: The pair selected for merging.

        Returns:
            A 3-tuple ``(pair_counts, reversed_cache, changed_paircount)``.
            ``pair_counts`` and ``reversed_cache`` are the values handed back by
            the contribution helpers. ``changed_paircount`` maps each pair whose
            accumulated delta is non-zero and which is still a key of
            ``pair_counts`` to its current count.
        """
        # Iterate over a snapshot: reversed_cache is passed to the helpers below
        # while we loop over the pretokens taken from it.
        stale_pretokens = reversed_cache[best_pair].copy()

        # Net change per pair, accumulated by the helpers across all pretokens.
        # NOTE(review): presumably the delete helper subtracts and the add helper
        # adds the pretoken count -- confirm against the helper definitions.
        deltas: dict[tuple[bytes], int] = {}

        for stale in stale_pretokens:
            merged = Tokenizer._build_new_pretoken(stale, best_pair)

            # Withdraw what the old pretoken contributed ...
            pair_counts, reversed_cache, deltas = Tokenizer._delete_old_contribution(
                stale, pair_counts, reversed_cache, deltas
            )
            # ... then account for the pretoken with the pair merged.
            pair_counts, reversed_cache, deltas = Tokenizer._add_new_contribution(
                merged, pair_counts, reversed_cache, deltas
            )

        # A zero delta means no net change; pairs that are no longer keys of
        # pair_counts are left out as well.
        changed_paircount = {
            pair: pair_counts[pair]
            for pair, delta in deltas.items()
            if delta and pair in pair_counts
        }

        return pair_counts, reversed_cache, changed_paircount
    
    
    def update_vocab(self, best_pair: tuple[tuple[bytes], int]):
        # sorted_vocab = sorted(self.vocab.items(), reverse=True)
        # new_index =  sorted_vocab[0][0] + 1
        
        k = best_pair[0]
        k = k[0] + k[1]
    
        self.vocab[self.next_id] = k

        self.next_id += 1
    
    def _save_vocabulary_merges(self):
        """
        Serialize the resulting vocabulary and merges to disk for further inspection

        Args:

        Returns:
        """
        vocab_serialized = {
            str(token_id): vocab_bytes.decode("utf-8", "replace")
            for token_id, vocab_bytes in self.vocab.items()
        }

        merge_serialized = [
            [first_bytes.decode("utf-8", "replace"), second_bytes.decode("utf-8", "replace")]
            for first_bytes, second_bytes in self.merge
        ]

        with open(self.serialization_vocab_path, 'w', encoding="utf-8") as f:
            json.dump(vocab_serialized, f, indent=2, ensure_ascii=False)
        
        with open(self.serialization_merge_path, 'w', encoding="utf-8") as f:
            json.dump(merge_serialized, f, indent=2, ensure_ascii=False)
    
    # def train_bpe(self, input_path: str, vocab_size: int, gpt2_regex: bool, enable_parallel: bool) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    #     # Init vocab
    #     self.init_vocab()

    #     if enable_parallel:
    #         pretokens = self.pretokenize_parallel(input_path, gpt2_regex)
    #     else:
    #         pretokens = self.pretokenize(input_path, gpt2_regex)
      
    #     # Build the first pair count and cache(pair to corresponding pretokens)
    #     pair_counts, reversed_cache = self.build_paircount_and_cache(pretokens)
    
    #     for i in range(vocab_size - 256 - len(self.special_tokens)):
    #         # Pick best adjcent tokens to merge
    #         best_pair = self._pick_best_mergetoken(pair_counts)
    
    #         # Log pair counts, best pair and step
    #         self.dump_pair_count(pair_counts, best_pair, i)
    
    #         # Update pair counts and cache
    #         pair_counts,  reversed_cache = self.merge_new(pair_counts, reversed_cache, best_pair[0])
    
    #         # TODO: optimize point, insert vocab and merges two times
    #         # Update vocabs
    #         self.update_vocab(best_pair)
    #         # Update merges
    #         self.merge.append((best_pair[0][0], best_pair[0][1]))
        
    #     return self.vocab, self.merge

    def train_bpe(self, input_path: str, vocab_size: int, gpt2_regex: bool, enable_parallel: bool) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
        # Init vocab
        self.init_vocab()

        if enable_parallel:
            pretokens = self.pretokenize_parallel(input_path, gpt2_regex)
        else:
            pretokens = self.pretokenize(input_path, gpt2_regex)
      
        # Build the first pair count and cache(pair to corresponding pretokens)
        pair_counts, reversed_cache = self.build_paircount_and_cache(pretokens)

        self.build_heap(pair_counts)

        merge_size = vocab_size - 256 - len(self.special_tokens)
    
        for i in range(merge_size):
            # Pick best adjcent tokens to merge
            best_pair = self._pick_best_mergetoken(pair_counts)
    
            # Log pair counts, best pair and step
            self.dump_pair_count(pair_counts, best_pair, i)
    
            # Update pair counts and cache
            pair_counts, reversed_cache, changed_paircount = self.merge_new(pair_counts, reversed_cache, best_pair[0])

            self.update_heap(changed_paircount)
    
            # TODO: optimize point, insert vocab and merges two times
            # Update vocabs
            self.update_vocab(best_pair)
            # Update merges
            self.merge.append((best_pair[0][0], best_pair[0][1]))
        
        if self.serialization:
            self._save_vocabulary_merges()
        
        return self.vocab, self.merge

Overwriting tokenizer.py


Train BPE

In [27]:
from tokenizer import Tokenizer
import time
import tracemalloc
import os
import psutil
import contextlib
import yaml

@contextlib.contextmanager
def perf_monitor(enabled: bool = True):
    """
    Context manager that prints a time and memory report for the wrapped block.

    Args:
        enabled: When false, the block runs without any measurement and
            nothing is printed.

    Yields:
        An empty dict in both modes.

    When enabled, the report is printed on exit (also if the block raises)
    and contains the elapsed wall-clock time, the peak memory traced by
    ``tracemalloc`` and the resident set size of the current process.
    """
    if not enabled:
        yield {}
        return

    # Begin tracing allocations, then start the clock.
    tracemalloc.start()
    started_at = time.perf_counter()

    try:
        yield {}
    finally:
        elapsed_seconds = time.perf_counter() - started_at
        # get_traced_memory() returns (current, peak); only the peak is reported.
        peak_bytes = tracemalloc.get_traced_memory()[1]
        tracemalloc.stop()

        # Resident set size of this process, converted from bytes to MB.
        rss_bytes = psutil.Process(os.getpid()).memory_info().rss
        rss_mem = rss_bytes / (1024 * 1024)
        peak_mb = peak_bytes / 1024 / 1024

        report = f"""
        Performence report
        -------------------------------
        Total time                      :{elapsed_seconds:.2f} seconds
        Peak memory managed by python   :{peak_mb:.2f} MB
        Total physical memory used(RSS) :{rss_mem:.2f} MB
        """

        print(report)

def main(config_path: str = "/kaggle/working/config_kaggle.yaml", monitor: bool = False):
    """
    Train a BPE tokenizer from a YAML configuration and print a short report.

    Args:
        config_path: Path of the YAML configuration file. Defaults to the
            Kaggle working-directory location, so ``main()`` behaves as
            before; pass another path to run outside Kaggle.
        monitor: Forwarded to ``perf_monitor(enabled=...)``. Defaults to
            ``False``, matching the previous hard-wired value.

    The configuration must provide the keys ``special_tokens``,
    ``enable_log``, ``log_path``, ``serialization``,
    ``serialization_vocab_path``, ``serialization_merge_path``,
    ``traindata_path``, ``vocab_size``, ``gpt2_regex`` and ``parallel``.

    Returns:
        None. The training report is printed to stdout.
    """
    with perf_monitor(enabled=monitor):

        with open(config_path, "r") as f:
            # safe_load only builds plain Python objects from the YAML file.
            config = yaml.safe_load(f)

        # Init tokenizer
        tokenizer = Tokenizer(
            config["special_tokens"],
            enable_log=config["enable_log"],
            log_path=config["log_path"],
            serialization=config["serialization"],
            serialization_vocab_path=config["serialization_vocab_path"],
            serialization_merge_path=config["serialization_merge_path"]
        )

        # Training
        vocab, merges = tokenizer.train_bpe(
            config["traindata_path"],
            vocab_size=config["vocab_size"],
            gpt2_regex=config["gpt2_regex"],
            enable_parallel=config["parallel"]
        )

    # Build report (names bound inside the with-block remain in scope here)
    report = f"""
        BPE Tokenizer Training report
        -------------------------------
        Vocabuary size                  :{len(vocab)}
        Number of merges                :{len(merges)}
        First 5 merges                  :{merges[:5]}
    """

    print(report)

# if __name__ == "__main__":
#     main()

# Profile main() inline with IPython's %prun line magic; "-s cumtime" sorts
# the resulting report by cumulative time.
%prun -s cumtime main()


        BPE Tokenizer Training report
        -------------------------------
        Vocabuary size                  :10000
        Number of merges                :9743
        First 5 merges                  :[(b' ', b't'), (b'h', b'e'), (b' ', b'a'), (b' ', b's'), (b' ', b'w')]
    
 

         10959364 function calls (10910628 primitive calls) in 830.515 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000  830.515  830.515 {built-in method builtins.exec}
        1    0.064    0.064  830.515  830.515 <string>:1(<module>)
        1    0.040    0.040  830.451  830.451 4062424040.py:42(main)
        1   11.406   11.406  830.406  830.406 tokenizer.py:823(train_bpe)
        1    0.141    0.141  535.781  535.781 tokenizer.py:357(pretokenize_parallel)
       66  535.472    8.113  535.472    8.113 {method 'acquire' of '_thread.lock' objects}
        1    0.000    0.000  535.464  535.464 _base.py:646(__exit__)
        1    0.000    0.000  535.464  535.464 process.py:842(shutdown)
       26    0.000    0.000  535.464   20.595 threading.py:1125(_wait_for_tstate_lock)
        1    0.000    0.000  535.464  535.464 threading.py:1087(join)
     9743   85.004    0.009  279.606    0.029 tokenizer