<a href="https://colab.research.google.com/github/yuzhipeng588/llm/blob/main/tokenizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [89]:
import collections

import ast
import json
import os
import regex

from typing import Dict, List, Tuple

from transformers import PreTrainedTokenizer
from transformers import AutoTokenizer

def get_stats(tokens_trunk: list[list[int]]):
  token_count = collections.defaultdict(int)
  for tokens in tokens_trunk:
    for pair in zip(tokens[:-1], tokens[1:]):
      token_count[pair] = token_count.get(pair, 0) + 1

  return token_count

def merge(tokens_trunk: list[list[int]], new_token_pair: tuple, new_token: int) -> list[list[int]]:
  new_tokens_trunk = []
  for tokens in tokens_trunk:
    new_tokens = []
    i = 0
    while i < len(tokens):
      if tokens[i:i+2] == list(new_token_pair):
        new_tokens.append(new_token)
        i+=2
      else:
        new_tokens.append(tokens[i])
        i+=1
    new_tokens_trunk.append(new_tokens)
  return new_tokens_trunk

class SimpleUtf8Tokenizer(PreTrainedTokenizer):
  def __init__(self, vocab_file=None,
        merges_file=None,
        vocab=None,
        merges=None, **kwargs):

    # Handle both loading from files and instantiation from in-memory data
    if vocab_file and merges_file:
        with open(vocab_file, "r", encoding="utf-8") as f:
            vocab_lines = f.read().splitlines()
            self.vocab = {ast.literal_eval(items[0]): int(items[1]) for items in [line.rsplit(" ", 1) for line in vocab_lines]}
            self.reverse_vocab = {v: k for k, v in self.vocab.items()}
        with open(merges_file, "r", encoding="utf-8") as f:
            merges_lines = f.read().splitlines()
            self.merges = {(item[0], item[1]): int(item[2]) for item in [line.split() for line in merges_lines]}
    elif vocab and merges:
        self.reverse_vocab = vocab
        self.vocab = {v: k for k, v in vocab.items()}
        self.merges = merges
    else:
        raise ValueError(
            "You must provide either vocab/merges data or file paths."
        )
    print(self.vocab)
    self.id_to_token_str = {k: v.decode('latin-1') for k, v in self.reverse_vocab.items()}
    super().__init__(**kwargs)

  @classmethod
  def preprocess_to_token(cls, text: str) -> list[list[int]]:
    text_trunks = regex.findall(r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s""", text)
    try:
      tokens_trunk = [ list(map(int, t.encode('utf-8'))) for t in text_trunks]
    except UnicodeEncodeError:
      raise ValueError("Text must be encoded in UTF-8")
    return tokens_trunk

  def preprocess_to_bytes(cls, text: str) -> list[bytes]:
    text_trunks = regex.findall(r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s""", text)
    try:
      bytes_trunk = [ t.encode('utf-8') for t in text_trunks]
    except UnicodeEncodeError:
      raise ValueError("Text must be encoded in UTF-8")
    return bytes_trunk

  @classmethod
  def train(cls, text: str, vocab_size: int, **kwargs):
    tokens_trunk = cls.preprocess_to_token(text)

    special_tokens = set()
    if 'unk_token' in kwargs:
      special_tokens.add(kwargs['unk_token'])
    if 'pad_token' in kwargs:
      special_tokens.add(kwargs['pad_token'])
    if 'eos_token' in kwargs:
      special_tokens.add(kwargs['eos_token'])
    if 'sep_token' in kwargs:
      special_tokens.add(kwargs['sep_token'])

    assert vocab_size > 256 + len(special_tokens), "Token limit must be greater than 256 + special tokens"
    bytes_vocab = {i: bytes([i]) for i in range(256)}
    for token in special_tokens:
      bytes_vocab[len(bytes_vocab)] = token.encode('utf-8')
    reverse_bytes_vocab = {v: k for k, v in bytes_vocab.items()}
    merges = {}
    while vocab_size > len(bytes_vocab):
      stats = get_stats(tokens_trunk)
      new_token_pair = max(stats, key=stats.get)
      new_token = len(bytes_vocab)
      new_token_bytes = bytes_vocab[new_token_pair[0]] + bytes_vocab[new_token_pair[1]]

      bytes_vocab[new_token] = new_token_bytes
      reverse_bytes_vocab[new_token_bytes] = new_token
      tokens_trunk = merge(tokens_trunk, new_token_pair, new_token)
      merges[(new_token_pair[0], new_token_pair[1])] = new_token
      print("Vocab Size: ", len(bytes_vocab))

    #str_vocab = {v.decode('latin-1'): k for k, v in bytes_vocab.items()}
    #str_merges = [(k.decode('latin-1'), v.decode('latin-1')) for k, v in merges]
    return cls(vocab=bytes_vocab, merges=merges, **kwargs)

  @property
  def vocab_size(self) -> int:
      # The vocabulary consists of all 256 possible bytes.
      return len(self.vocab)

  def get_vocab(self) -> Dict[str, int]:
      """
      Returns the vocabulary as a dictionary of strings to integers.
      """
      # Create a mapping from the string representation of each byte to its integer value.
      return self.vocab

  def _convert_token_to_id(self, token: str) -> int:
      """
      Converts a token (a single-byte string) into its integer byte value.
      """
      # The token is a single character, and its ord() value is its byte value.
      return self.vocab[token]

  def _convert_id_to_token(self, index: int) -> str:
      """
      Converts an integer byte value into its single-byte string representation.
      """
      # Convert the integer to its character representation using latin-1.
      return self.id_to_token_str[index]

  def save_vocabulary(
      self, save_directory: str, filename_prefix: str | None = None
  ) -> Tuple[str]:
      """Saves the vocabulary and merges to files."""
      if not os.path.isdir(save_directory):
          os.makedirs(save_directory)

      # Save the vocabulary file
      vocab_file = os.path.join(
          save_directory, (filename_prefix or "") + "vocab.txt"
      )
      with open(vocab_file, "w", encoding="utf-8") as f:
          for p1, p2 in self.vocab.items():
              f.write(f"{p1} {p2}\n")

      # Save the merges file
      merges_file = os.path.join(
          save_directory, (filename_prefix or "") + "merges.txt"
      )
      with open(merges_file, "w", encoding="utf-8") as f:
          for k, v in self.merges.items():
              f.write(f"{k[0]} {k[1]} {v}\n")

      return (vocab_file, merges_file)

  # Encode the text with the longest tokens.
  def encode(self, text):
    encoded_text_trunks = self.preprocess_to_token(text)

    while True:
      stats = get_stats(encoded_text_trunks)
      next_token_pair = max(stats, key=lambda p: float('-inf') if p not in self.merges else stats.get(p))
      if next_token_pair in self.merges:
        encoded_text_trunks = merge(encoded_text_trunks, next_token_pair, self.merges[next_token_pair])
      else:
        break
    new_list = []
    for tokens in encoded_text_trunks:
      new_list.extend(tokens)
    return new_list

  def decode(self, tokens):
    print(self.reverse_vocab)
    return b''.join([self.reverse_vocab[token] for token in tokens]).decode('utf-8')

AutoTokenizer.register("SimpleUtf8Tokenizer", slow_tokenizer_class=SimpleUtf8Tokenizer)

In [91]:
!pip install datatrove
from datatrove.pipeline.readers import ParquetReader

# limit determines how many documents will be streamed (remove for all)
# to fetch a specific dump: hf://datasets/HuggingFaceFW/fineweb/data/CC-MAIN-2024-10
# replace "data" with "sample/100BT" to use the 100BT sample
data_reader = ParquetReader("hf://datasets/HuggingFaceFW/fineweb/data", limit=100)

text = '\n'.join([doc.text for doc in data_reader()])
token_limit = 2560
#text = data_reader().__next__().text
print("Text Length: ", len(text))

tokenizer = SimpleUtf8Tokenizer.train(text, token_limit, unk_token="<unk>")


save_directory = "./simple_utf8_tokenizer"
print(f"Saving custom tokenizer to '{save_directory}'...")
tokenizer.save_pretrained(save_directory)
print("Tokenizer saved.")



[32m2025-09-14 05:11:18.969[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file CC-MAIN-2013-20/000_00000.parquet, 1/27468[0m


Text Length:  252642
Vocab Size:  258
Vocab Size:  259
Vocab Size:  260
Vocab Size:  261
Vocab Size:  262
Vocab Size:  263
Vocab Size:  264
Vocab Size:  265
Vocab Size:  266
Vocab Size:  267
Vocab Size:  268
Vocab Size:  269
Vocab Size:  270
Vocab Size:  271
Vocab Size:  272
Vocab Size:  273
Vocab Size:  274
Vocab Size:  275
Vocab Size:  276
Vocab Size:  277
Vocab Size:  278
Vocab Size:  279
Vocab Size:  280
Vocab Size:  281
Vocab Size:  282
Vocab Size:  283
Vocab Size:  284
Vocab Size:  285
Vocab Size:  286
Vocab Size:  287
Vocab Size:  288
Vocab Size:  289
Vocab Size:  290
Vocab Size:  291
Vocab Size:  292
Vocab Size:  293
Vocab Size:  294
Vocab Size:  295
Vocab Size:  296
Vocab Size:  297
Vocab Size:  298
Vocab Size:  299
Vocab Size:  300
Vocab Size:  301
Vocab Size:  302
Vocab Size:  303
Vocab Size:  304
Vocab Size:  305
Vocab Size:  306
Vocab Size:  307
Vocab Size:  308
Vocab Size:  309
Vocab Size:  310
Vocab Size:  311
Vocab Size:  312
Vocab Size:  313
Vocab Size:  314
Vocab Size

In [74]:
# Verify it works
text = "hello custom"
encoded = tokenizer.encode(text)
print(f"\nOriginal text: '{text}'")
print(f"Encoded with loaded tokenizer: {encoded}")
decoded = tokenizer.decode(encoded)
print(f"Decoded text: '{decoded}'")


Original text: 'hello custom'
Encoded with loaded tokenizer: [259, 108, 108, 111, 32, 99, 117, 115, 116, 111, 109]
{0: b'\x00', 1: b'\x01', 2: b'\x02', 3: b'\x03', 4: b'\x04', 5: b'\x05', 6: b'\x06', 7: b'\x07', 8: b'\x08', 9: b'\t', 10: b'\n', 11: b'\x0b', 12: b'\x0c', 13: b'\r', 14: b'\x0e', 15: b'\x0f', 16: b'\x10', 17: b'\x11', 18: b'\x12', 19: b'\x13', 20: b'\x14', 21: b'\x15', 22: b'\x16', 23: b'\x17', 24: b'\x18', 25: b'\x19', 26: b'\x1a', 27: b'\x1b', 28: b'\x1c', 29: b'\x1d', 30: b'\x1e', 31: b'\x1f', 32: b' ', 33: b'!', 34: b'"', 35: b'#', 36: b'$', 37: b'%', 38: b'&', 39: b"'", 40: b'(', 41: b')', 42: b'*', 43: b'+', 44: b',', 45: b'-', 46: b'.', 47: b'/', 48: b'0', 49: b'1', 50: b'2', 51: b'3', 52: b'4', 53: b'5', 54: b'6', 55: b'7', 56: b'8', 57: b'9', 58: b':', 59: b';', 60: b'<', 61: b'=', 62: b'>', 63: b'?', 64: b'@', 65: b'A', 66: b'B', 67: b'C', 68: b'D', 69: b'E', 70: b'F', 71: b'G', 72: b'H', 73: b'I', 74: b'J', 75: b'K', 76: b'L', 77: b'M', 78: b'N', 79: b'O', 80:

In [9]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [92]:
from huggingface_hub import HfApi

api = HfApi(token=os.getenv("HF_TOKEN"))
api.upload_folder(
    folder_path=save_directory,
    repo_id="thaitea2021/experimental",
    repo_type="model",
)

CommitInfo(commit_url='https://huggingface.co/thaitea2021/experimental/commit/d556d1a512c3229b97edfbcc4a74c0949e0904f4', commit_message='Upload folder using huggingface_hub', commit_description='', oid='d556d1a512c3229b97edfbcc4a74c0949e0904f4', pr_url=None, repo_url=RepoUrl('https://huggingface.co/thaitea2021/experimental', endpoint='https://huggingface.co', repo_type='model', repo_id='thaitea2021/experimental'), pr_revision=None, pr_num=None)

In [93]:
from transformers import AutoTokenizer
print("\nLoading tokenizer back...")
loaded_tokenizer = AutoTokenizer.from_pretrained(
    save_directory, trust_remote_code=True, vocab_file=save_directory + "/vocab.txt", merges_file=save_directory + "/merges.txt"
)

# Verify it works
text = "hello custom"
encoded = loaded_tokenizer.encode(text)
print(f"\nOriginal text: '{text}'")
print(f"Encoded with loaded tokenizer: {encoded}")
decoded = loaded_tokenizer.decode(encoded)
print(f"Decoded text: '{decoded}'")


Loading tokenizer back...
{b'\x00': 0, b'\x01': 1, b'\x02': 2, b'\x03': 3, b'\x04': 4, b'\x05': 5, b'\x06': 6, b'\x07': 7, b'\x08': 8, b'\t': 9, b'\n': 10, b'\x0b': 11, b'\x0c': 12, b'\r': 13, b'\x0e': 14, b'\x0f': 15, b'\x10': 16, b'\x11': 17, b'\x12': 18, b'\x13': 19, b'\x14': 20, b'\x15': 21, b'\x16': 22, b'\x17': 23, b'\x18': 24, b'\x19': 25, b'\x1a': 26, b'\x1b': 27, b'\x1c': 28, b'\x1d': 29, b'\x1e': 30, b'\x1f': 31, b' ': 32, b'!': 33, b'"': 34, b'#': 35, b'$': 36, b'%': 37, b'&': 38, b"'": 39, b'(': 40, b')': 41, b'*': 42, b'+': 43, b',': 44, b'-': 45, b'.': 46, b'/': 47, b'0': 48, b'1': 49, b'2': 50, b'3': 51, b'4': 52, b'5': 53, b'6': 54, b'7': 55, b'8': 56, b'9': 57, b':': 58, b';': 59, b'<': 60, b'=': 61, b'>': 62, b'?': 63, b'@': 64, b'A': 65, b'B': 66, b'C': 67, b'D': 68, b'E': 69, b'F': 70, b'G': 71, b'H': 72, b'I': 73, b'J': 74, b'K': 75, b'L': 76, b'M': 77, b'N': 78, b'O': 79, b'P': 80, b'Q': 81, b'R': 82, b'S': 83, b'T': 84, b'U': 85, b'V': 86, b'W': 87, b'X': 88, b'

In [94]:
# Test encoding and decoding
assert tokenizer.decode(tokenizer.encode(text)) == text

{0: b'\x00', 1: b'\x01', 2: b'\x02', 3: b'\x03', 4: b'\x04', 5: b'\x05', 6: b'\x06', 7: b'\x07', 8: b'\x08', 9: b'\t', 10: b'\n', 11: b'\x0b', 12: b'\x0c', 13: b'\r', 14: b'\x0e', 15: b'\x0f', 16: b'\x10', 17: b'\x11', 18: b'\x12', 19: b'\x13', 20: b'\x14', 21: b'\x15', 22: b'\x16', 23: b'\x17', 24: b'\x18', 25: b'\x19', 26: b'\x1a', 27: b'\x1b', 28: b'\x1c', 29: b'\x1d', 30: b'\x1e', 31: b'\x1f', 32: b' ', 33: b'!', 34: b'"', 35: b'#', 36: b'$', 37: b'%', 38: b'&', 39: b"'", 40: b'(', 41: b')', 42: b'*', 43: b'+', 44: b',', 45: b'-', 46: b'.', 47: b'/', 48: b'0', 49: b'1', 50: b'2', 51: b'3', 52: b'4', 53: b'5', 54: b'6', 55: b'7', 56: b'8', 57: b'9', 58: b':', 59: b';', 60: b'<', 61: b'=', 62: b'>', 63: b'?', 64: b'@', 65: b'A', 66: b'B', 67: b'C', 68: b'D', 69: b'E', 70: b'F', 71: b'G', 72: b'H', 73: b'I', 74: b'J', 75: b'K', 76: b'L', 77: b'M', 78: b'N', 79: b'O', 80: b'P', 81: b'Q', 82: b'R', 83: b'S', 84: b'T', 85: b'U', 86: b'V', 87: b'W', 88: b'X', 89: b'Y', 90: b'Z', 91: b'[',