<a href="https://colab.research.google.com/github/yuzhipeng588/llm/blob/main/tokenizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from datatrove.pipeline.readers import ParquetReader

# Reader over the FineWeb parquet files addressed by the hf:// dataset path below.
# limit determines how many documents will be streamed (remove for all)
# to fetch a specific dump: hf://datasets/HuggingFaceFW/fineweb/data/CC-MAIN-2024-10
# replace "data" with "sample/100BT" to use the 100BT sample
data_reader = ParquetReader("hf://datasets/HuggingFaceFW/fineweb/data", limit=100)

In [None]:
import collections

# Build the training corpus: the text of every streamed document, one
# document after another, separated by a newline.
text = '\n'.join(doc.text for doc in data_reader())
# Alternative for quick experiments -- use only the first document:
#text = data_reader().__next__().text
print("Text Length: ", len(text))

def get_next_token(tokens: list[int]) -> tuple[int, int]:
  """Return the most frequent pair of adjacent tokens in ``tokens``.

  Args:
    tokens: Sequence of token ids to scan.

  Returns:
    The ``(left, right)`` pair that occurs most often. When several pairs
    share the highest count, the one that appears first in ``tokens`` wins.

  Raises:
    ValueError: If ``tokens`` has fewer than two elements, so no pair exists.
  """
  # Counter keeps keys in first-seen order, so max() below resolves ties in
  # favour of the earliest pair in the input.
  pair_counts = collections.Counter(zip(tokens, tokens[1:]))
  if not pair_counts:
    raise ValueError("Need at least two tokens to find a pair to merge")
  return max(pair_counts, key=pair_counts.get)

def merge(tokens: list[int], new_token_pair: tuple, new_token: int) -> list[int]:
  """Replace every occurrence of ``new_token_pair`` in ``tokens`` with ``new_token``.

  The scan runs left to right and matches do not overlap: once a pair has
  been replaced, scanning resumes after it.

  Args:
    tokens: Token ids to rewrite; the list itself is not modified.
    new_token_pair: The two adjacent token ids to look for.
    new_token: Token id that stands in for the pair.

  Returns:
    A new list with the replacements applied.
  """
  wanted = list(new_token_pair)
  total = len(tokens)
  merged = []
  pos = 0
  while pos < total:
    if tokens[pos:pos + 2] == wanted:
      merged.append(new_token)
      pos += 2
      continue
    merged.append(tokens[pos])
    pos += 1
  return merged

class Tokenizer:
  """Byte-level tokenizer whose vocabulary is learned by merging frequent pairs.

  The vocabulary starts with one token per byte value (ids 0-255). Training
  repeatedly takes the most frequent adjacent token pair in the training
  text, assigns it the next free id, and rewrites the text with that id,
  until ``token_limit`` entries exist or the text can no longer be merged.

  The attribute names ``vacob`` / ``reverse_vacob`` (sic) are kept as-is so
  existing code that reads them keeps working.
  """

  def __init__(self, text, token_limit):
    """Train the vocabulary on ``text``.

    Args:
      text: Training text (``str``); it is encoded as UTF-8 before training.
      token_limit: Target vocabulary size; must be greater than 256. The
        final vocabulary can be smaller when ``text`` is too short to
        provide enough merges.
    """
    assert token_limit > 256, "Token limit must be greater than 256"
    self.token_limit = token_limit
    # token id -> byte string it stands for
    self.vacob = {i: bytes([i]) for i in range(256)}
    # byte string -> token id (inverse of ``vacob``)
    self.reverse_vacob = {bytes([i]): i for i in range(256)}
    tokens = list(text.encode('utf-8'))
    # Stop early once fewer than two tokens remain: there is no pair left to
    # merge, and asking for one would raise instead of finishing training.
    while len(self.vacob) < token_limit and len(tokens) >= 2:
      new_token_pair = get_next_token(tokens)
      new_token = len(self.vacob)
      new_token_bytes = self.vacob[new_token_pair[0]] + self.vacob[new_token_pair[1]]
      self.vacob[new_token] = new_token_bytes
      self.reverse_vacob[new_token_bytes] = new_token
      tokens = merge(tokens, new_token_pair, new_token)

  # Encode the text with the longest tokens.
  def encode(self, text):
    """Encode ``text`` into token ids using greedy longest-match.

    At every position the longest byte string present in the vocabulary is
    chosen, even when some shorter prefix of it is not itself a token.

    Args:
      text: The ``str`` to encode; it is encoded as UTF-8 first.

    Returns:
      A list of token ids; decoding it reproduces ``text``.

    Raises:
      KeyError: If no vocabulary entry matches at some position (cannot
        happen while all 256 single-byte tokens are present).
    """
    data = text.encode('utf-8')
    lookup = self.reverse_vacob
    # Longest byte string in the vocabulary bounds how far we must look ahead.
    longest = max(map(len, lookup), default=0)
    total = len(data)
    tokens = []
    i = 0
    while i < total:
      # Try candidates from longest to shortest; the first hit is the
      # longest match starting at position i.
      for size in range(min(longest, total - i), 0, -1):
        token = lookup.get(data[i:i + size])
        if token is not None:
          tokens.append(token)
          i += size
          break
      else:
        raise KeyError(data[i:i + 1])
    return tokens

  def decode(self, tokens):
    """Decode token ids back into a ``str``.

    Args:
      tokens: Iterable of token ids present in the vocabulary.

    Returns:
      The UTF-8 decoded concatenation of the tokens' byte strings.

    Raises:
      KeyError: If a token id is not in the vocabulary.
      UnicodeDecodeError: If the concatenated bytes are not valid UTF-8.
    """
    return b''.join([self.vacob[token] for token in tokens]).decode('utf-8')

# Train a tokenizer on the loaded text; token_limit=512 caps the vocabulary
# at 512 entries (the 256 single-byte tokens plus merged pairs).
tokenizer = Tokenizer(text, token_limit=512)

In [None]:
# Test encoding and decoding: encoding the training text and decoding the
# resulting tokens must reproduce the text exactly.
assert tokenizer.decode(tokenizer.encode(text)) == text