<a href="https://colab.research.google.com/github/yuzhipeng588/llm/blob/main/tokenizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [78]:
from datatrove.pipeline.readers import ParquetReader

# Reader over the FineWeb parquet files on the Hugging Face Hub.
# `limit` determines how many documents will be streamed (remove it for all).
# To fetch a specific dump: hf://datasets/HuggingFaceFW/fineweb/data/CC-MAIN-2024-10
# Replace "data" with "sample/100BT" to use the 100BT sample.
data_reader = ParquetReader("hf://datasets/HuggingFaceFW/fineweb/data", limit=100)


import collections

import json
import os

from typing import Dict, List, Tuple

from transformers import PreTrainedTokenizer
from transformers import AutoTokenizer

# Build the training corpus: the text of every streamed document, one per line.
text = '\n'.join(doc.text for doc in data_reader())
# Vocabulary size passed to SimpleUtf8Tokenizer.train further down.
token_limit = 512
#text = data_reader().__next__().text
print("Text Length: ", len(text))

def get_next_token(tokens: list[int]) -> tuple[int, int]:
  """Return the most frequent pair of adjacent tokens.

  Every adjacent pair ``(tokens[i], tokens[i + 1])`` is counted, including
  overlapping ones. When several pairs share the highest count, the pair
  that occurs first in ``tokens`` is returned.

  Args:
    tokens: Sequence of token ids to scan.

  Returns:
    The ``(left, right)`` pair with the highest count.

  Raises:
    ValueError: If ``tokens`` has fewer than two elements, so no pair exists.
  """
  # Counter keeps keys in first-seen order, which makes the tie-break stable.
  pair_counts = collections.Counter(zip(tokens, tokens[1:]))
  if not pair_counts:
    raise ValueError("Need at least two tokens to find an adjacent pair")
  return max(pair_counts, key=pair_counts.get)

def merge(tokens: list[int], new_token_pair: tuple, new_token: int) -> list[int]:
  """Replace every occurrence of an adjacent pair with a single new token.

  The scan runs left to right and matches never overlap: after a pair is
  replaced, scanning resumes behind it (so ``[1, 1, 1]`` with pair
  ``(1, 1)`` becomes ``[new_token, 1]``).

  Args:
    tokens: Token ids to rewrite; the list itself is not modified.
    new_token_pair: The two adjacent ids to look for.
    new_token: The id that replaces each matched pair.

  Returns:
    A new list with all matches replaced.
  """
  # Build the comparison target once instead of on every loop iteration.
  target = list(new_token_pair)
  total = len(tokens)
  merged = []
  pos = 0
  while pos < total:
    if tokens[pos:pos + 2] == target:
      merged.append(new_token)
      pos += 2
    else:
      merged.append(tokens[pos])
      pos += 1
  return merged

class SimpleUtf8Tokenizer(PreTrainedTokenizer):
  """Byte-level BPE tokenizer working on the UTF-8 bytes of the text.

  Every token stands for a byte string. Tokens are stored as ``str`` by
  decoding those bytes with latin-1, which maps each byte 0-255 to exactly
  one code point and back without loss. ``encode`` and ``decode`` therefore
  go through the same latin-1 view so that ids agree with what ``train``
  learned.
  """

  def __init__(self, vocab_file=None,
        merges_file=None,
        vocab=None,
        merges=None, **kwargs):
    """Create a tokenizer from files or from in-memory data.

    Args:
      vocab_file: Path to a JSON file mapping token string -> id.
      merges_file: Path to the merges file written by ``save_vocabulary``.
      vocab: In-memory mapping of token string -> id.
      merges: In-memory list of merged token pairs (may be empty).
      **kwargs: Passed on to ``PreTrainedTokenizer.__init__``.

    Raises:
      ValueError: If neither both file paths nor both in-memory values
        are provided.
    """
    # Handle both loading from files and instantiation from in-memory data
    if vocab_file and merges_file:
      with open(vocab_file, "r", encoding="utf-8") as f:
        self.vocab = json.load(f)
      with open(merges_file, "r", encoding="utf-8") as f:
        self.merges = [
            self._parse_merge_line(line)
            for line in f.read().splitlines()
            if line
        ]
    elif vocab and merges is not None:
      # An empty merge list is valid (nothing was merged), so only None
      # counts as "not provided".
      self.vocab = vocab
      self.merges = merges
    else:
      raise ValueError(
          "You must provide either vocab/merges data or file paths."
      )
    # Debug output: prints the whole vocabulary on every construction.
    print(self.vocab)
    # Reverse lookup (id -> token string); must exist before the base-class
    # initializer runs, as in the original ordering.
    self.id_to_token_str = {v:k for k,v in self.vocab.items()}
    super().__init__(**kwargs)

  @staticmethod
  def _parse_merge_line(line: str) -> tuple:
    """Parse one line of the merges file into a tuple of token strings.

    Lines written by ``save_vocabulary`` are JSON arrays. Anything else is
    treated as the older whitespace-separated format.
    """
    try:
      parsed = json.loads(line)
    except ValueError:
      parsed = None
    if isinstance(parsed, list) and all(isinstance(p, str) for p in parsed):
      return tuple(parsed)
    return tuple(line.split())

  @classmethod
  def train(cls, text: str, vocab_size: int, **kwargs):
    """Learn byte-pair merges from ``text`` and return a new tokenizer.

    Starts from the 256 single-byte tokens (plus ``unk_token`` if given in
    ``kwargs``) and repeatedly merges the most frequent adjacent pair until
    the vocabulary has ``vocab_size`` entries or no pair is left to merge.

    Args:
      text: Training corpus.
      vocab_size: Target number of vocabulary entries.
      **kwargs: Forwarded to the constructor; ``unk_token`` is also added
        to the vocabulary.

    Returns:
      A tokenizer built from the learned vocabulary and merges.

    Raises:
      ValueError: If ``text`` cannot be encoded as UTF-8.
      AssertionError: If ``vocab_size`` leaves no room for any merge.
    """
    try:
      tokens = list(text.encode('utf-8'))
    except UnicodeEncodeError as err:
      raise ValueError("Text must be encoded in UTF-8") from err

    special_tokens = set()
    if kwargs.get('unk_token') is not None:
      special_tokens.add(kwargs['unk_token'])

    assert vocab_size > 256 + len(special_tokens), "Token limit must be greater than 256 + special tokens"
    bytes_vocab = {i: bytes([i]) for i in range(256)}
    for token in special_tokens:
      bytes_vocab[len(bytes_vocab)] = token.encode('utf-8')
    merges = []
    while vocab_size > len(bytes_vocab):
      if len(tokens) < 2:
        # The corpus has collapsed to a single token (or was too short):
        # there is no pair left, so stop with a smaller vocabulary.
        break
      new_token_pair = get_next_token(tokens)
      new_token = len(bytes_vocab)
      left_bytes = bytes_vocab[new_token_pair[0]]
      right_bytes = bytes_vocab[new_token_pair[1]]

      bytes_vocab[new_token] = left_bytes + right_bytes
      tokens = merge(tokens, new_token_pair, new_token)
      merges.append((left_bytes, right_bytes))

    # latin-1 gives a lossless one-char-per-byte string form of each token.
    str_vocab = {v.decode('latin-1'): k for k, v in bytes_vocab.items()}
    str_merges = [(k.decode('latin-1'), v.decode('latin-1')) for k, v in merges]
    return cls(vocab=str_vocab, merges=str_merges, **kwargs)

  @property
  def vocab_size(self) -> int:
    """Number of entries in the vocabulary."""
    return len(self.vocab)

  def get_vocab(self) -> Dict[str, int]:
    """
    Returns the vocabulary as a dictionary of token strings to integer ids.
    """
    return self.vocab

  def _convert_token_to_id(self, token: str) -> int:
    """
    Converts a token string into its id.

    Raises KeyError if the token is not in the vocabulary.
    """
    return self.vocab[token]

  def _convert_id_to_token(self, index: int) -> str:
    """
    Converts an id into its token string.

    Raises KeyError if the id is not in the vocabulary.
    """
    return self.id_to_token_str[index]

  def save_vocabulary(
      self, save_directory: str, filename_prefix: str | None = None
  ) -> Tuple[str, str]:
    """Saves the vocabulary and merges to files.

    Args:
      save_directory: Target directory; created if it does not exist.
      filename_prefix: Optional string put directly in front of the
        file names ``vocab.json`` and ``merges.txt``.

    Returns:
      The paths of the vocabulary file and the merges file.
    """
    os.makedirs(save_directory, exist_ok=True)

    # Save the vocabulary file
    vocab_file = os.path.join(
        save_directory, (filename_prefix or "") + "vocab.json"
    )
    with open(vocab_file, "w", encoding="utf-8") as f:
      json.dump(self.vocab, f, ensure_ascii=False, indent=2)

    # Save the merges file, one JSON array per line. Tokens may contain
    # spaces and newlines, so a whitespace-separated line cannot be read
    # back reliably; JSON escapes those characters.
    merges_file = os.path.join(
        save_directory, (filename_prefix or "") + "merges.txt"
    )
    with open(merges_file, "w", encoding="utf-8") as f:
      for pair in self.merges:
        f.write(json.dumps(list(pair)) + "\n")

    return (vocab_file, merges_file)

  def encode(self, text):
    """Encode ``text`` into a list of token ids.

    The text is converted to UTF-8 bytes and scanned greedily: starting
    from one byte, the current token is extended by one byte for as long
    as the extended string is itself a vocabulary entry.

    Args:
      text: The string to encode.

    Returns:
      List of token ids.

    Raises:
      ValueError: If ``text`` cannot be encoded as UTF-8, or if a byte is
        missing from the vocabulary and no ``unk_token`` entry exists.
    """
    try:
      encoded_text = text.encode('utf-8')
    except UnicodeEncodeError as err:
      raise ValueError("Text must be encoded in UTF-8") from err

    # Vocabulary keys are latin-1 views of byte strings, so look at the
    # bytes through the same view (one character per byte).
    symbols = encoded_text.decode('latin-1')
    total = len(symbols)
    tokens = []
    start = 0
    while start < total:
      end = start + 1
      if symbols[start:end] not in self.vocab:
        # Always consume the byte so the loop cannot stall.
        unk = getattr(self, "unk_token", None)
        if unk is None or unk not in self.vocab:
          raise ValueError(
              f"Byte {encoded_text[start]!r} is not in the vocabulary"
          )
        tokens.append(self.vocab[unk])
        start = end
        continue
      while end < total and symbols[start:end + 1] in self.vocab:
        end += 1
      tokens.append(self.vocab[symbols[start:end]])
      start = end
    return tokens

  def decode(self, tokens):
    """Decode a sequence of token ids back into a string.

    Args:
      tokens: Iterable of token ids.

    Returns:
      The text obtained by joining the tokens' bytes and decoding them
      as UTF-8.

    Raises:
      KeyError: If an id is not in the vocabulary.
      UnicodeDecodeError: If the joined bytes are not valid UTF-8.
    """
    # Token strings are latin-1 views of bytes; encoding with latin-1
    # restores the exact original bytes (UTF-8 would expand bytes >= 0x80).
    raw = b''.join(
        self.id_to_token_str[token].encode('latin-1') for token in tokens
    )
    return raw.decode('utf-8')

# Register the custom class with AutoTokenizer as a slow (pure-Python) tokenizer.
# NOTE(review): the first argument is a string here; AutoTokenizer.register
# is documented with a config class in that position — confirm this is intended.
AutoTokenizer.register("SimpleUtf8Tokenizer", slow_tokenizer_class=SimpleUtf8Tokenizer)

[32m2025-09-08 01:12:50.916[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file CC-MAIN-2013-20/000_00000.parquet, 1/27468[0m


Text Length:  252642


In [79]:
# Train on the corpus until the vocabulary reaches `token_limit` entries;
# "<unk>" is added to the vocabulary as a special token.
tokenizer = SimpleUtf8Tokenizer.train(text, token_limit, unk_token="<unk>")


# Write the tokenizer to disk through the Hugging Face save_pretrained API.
save_directory = "./simple_utf8_tokenizer"
print(f"Saving custom tokenizer to '{save_directory}'...")
tokenizer.save_pretrained(save_directory)
print("Tokenizer saved.")

{'\x00': 0, '\x01': 1, '\x02': 2, '\x03': 3, '\x04': 4, '\x05': 5, '\x06': 6, '\x07': 7, '\x08': 8, '\t': 9, '\n': 10, '\x0b': 11, '\x0c': 12, '\r': 13, '\x0e': 14, '\x0f': 15, '\x10': 16, '\x11': 17, '\x12': 18, '\x13': 19, '\x14': 20, '\x15': 21, '\x16': 22, '\x17': 23, '\x18': 24, '\x19': 25, '\x1a': 26, '\x1b': 27, '\x1c': 28, '\x1d': 29, '\x1e': 30, '\x1f': 31, ' ': 32, '!': 33, '"': 34, '#': 35, '$': 36, '%': 37, '&': 38, "'": 39, '(': 40, ')': 41, '*': 42, '+': 43, ',': 44, '-': 45, '.': 46, '/': 47, '0': 48, '1': 49, '2': 50, '3': 51, '4': 52, '5': 53, '6': 54, '7': 55, '8': 56, '9': 57, ':': 58, ';': 59, '<': 60, '=': 61, '>': 62, '?': 63, '@': 64, 'A': 65, 'B': 66, 'C': 67, 'D': 68, 'E': 69, 'F': 70, 'G': 71, 'H': 72, 'I': 73, 'J': 74, 'K': 75, 'L': 76, 'M': 77, 'N': 78, 'O': 79, 'P': 80, 'Q': 81, 'R': 82, 'S': 83, 'T': 84, 'U': 85, 'V': 86, 'W': 87, 'X': 88, 'Y': 89, 'Z': 90, '[': 91, '\\': 92, ']': 93, '^': 94, '_': 95, '`': 96, 'a': 97, 'b': 98, 'c': 99, 'd': 100, 'e': 101

In [80]:
from transformers import AutoTokenizer
print("\nLoading tokenizer back...")
# The vocab and merges paths are passed explicitly; presumably they are
# forwarded to SimpleUtf8Tokenizer.__init__ as keyword arguments — confirm.
loaded_tokenizer = AutoTokenizer.from_pretrained(
    save_directory, trust_remote_code=True, vocab_file=save_directory + "/vocab.json", merges_file=save_directory + "/merges.txt"
)

# Verify it works
# NOTE: this rebinds `text`, replacing the training corpus built earlier.
text = "hello custom"
encoded = loaded_tokenizer.encode(text)
print(f"\nOriginal text: '{text}'")
print(f"Encoded with loaded tokenizer: {encoded}")
decoded = loaded_tokenizer.decode(encoded)
print(f"Decoded text: '{decoded}'")


Loading tokenizer back...
{'\x00': 0, '\x01': 1, '\x02': 2, '\x03': 3, '\x04': 4, '\x05': 5, '\x06': 6, '\x07': 7, '\x08': 8, '\t': 9, '\n': 10, '\x0b': 11, '\x0c': 12, '\r': 13, '\x0e': 14, '\x0f': 15, '\x10': 16, '\x11': 17, '\x12': 18, '\x13': 19, '\x14': 20, '\x15': 21, '\x16': 22, '\x17': 23, '\x18': 24, '\x19': 25, '\x1a': 26, '\x1b': 27, '\x1c': 28, '\x1d': 29, '\x1e': 30, '\x1f': 31, ' ': 32, '!': 33, '"': 34, '#': 35, '$': 36, '%': 37, '&': 38, "'": 39, '(': 40, ')': 41, '*': 42, '+': 43, ',': 44, '-': 45, '.': 46, '/': 47, '0': 48, '1': 49, '2': 50, '3': 51, '4': 52, '5': 53, '6': 54, '7': 55, '8': 56, '9': 57, ':': 58, ';': 59, '<': 60, '=': 61, '>': 62, '?': 63, '@': 64, 'A': 65, 'B': 66, 'C': 67, 'D': 68, 'E': 69, 'F': 70, 'G': 71, 'H': 72, 'I': 73, 'J': 74, 'K': 75, 'L': 76, 'M': 77, 'N': 78, 'O': 79, 'P': 80, 'Q': 81, 'R': 82, 'S': 83, 'T': 84, 'U': 85, 'V': 86, 'W': 87, 'X': 88, 'Y': 89, 'Z': 90, '[': 91, '\\': 92, ']': 93, '^': 94, '_': 95, '`': 96, 'a': 97, 'b': 98, 

In [81]:
# Test encoding and decoding: encode followed by decode must give back the
# input. This uses the trained `tokenizer` (not `loaded_tokenizer`) and
# whatever `text` is currently bound to.
assert tokenizer.decode(tokenizer.encode(text)) == text

In [82]:
# Scratch check: decoding a byte >= 0x80 with latin-1 succeeds and yields a str.
type(b'\x80'.decode('latin-1'))

str

In [48]:
# Scratch check: bytes() built from a list of ints is a bytes object.
type(bytes([1,3]))

bytes