In [14]:
import re
import os
import regex as re
from collections import Counter, defaultdict
import multiprocessing

In [12]:
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

# bpe tokenizer that returns vocab and merges
def tokenizer(input_path, vocab_size, special_tokens):
  # read input file
  with open(input_path, "r") as f:
    text = f.read()

  # 0-255 initial vocab
  vocab = { i: chr(i) for i in range(256) }
  last_index = 256
  for i, token in enumerate(special_tokens):
    vocab[last_index] = token
    last_index += 1

  escaped_special_tokens = [re.escape(token) for token in special_tokens]
  special_token_split = "|".join(escaped_special_tokens)
  words = []

  for doc in re.split(special_token_split, text):
    doc_words = re.findall(PAT, doc)
    doc_words_ints = [word.encode("utf-8") for word in doc_words]
    words.extend(doc_words_ints)

  # while last_index < vocab_size:
  pair_counts = {}
  for word in words:
    for a, b in zip(word, word[1:]):
      pair = (a, b)
      if pair not in pair_counts:
        pair_counts[pair] = 0
      pair_counts[pair] += 1

  best_pair = max(pair_counts, key=pair_counts.get)
  a, b = best_pair
  new_token = a + b
  vocab[last_index] = new_token
  last_index += 1

  # merge
  for word in reversed(words):
    new_word = []
    i = 0
    while i < len(word):
      if word[i:i+2] == list(best_pair):
        new_word.append(new_token)
        i += 2







tokenizer("../tests/fixtures/tinystories_sample.txt", 100, ["<|endoftext|>"])

(32, 116)


In [15]:
PAT = re.compile(r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

def get_stats(ids_chunks):
    # count adjacent pairs
    pair_counts = Counter()
    for ids in ids_chunks:
        for pair in zip(ids, ids[1:]):
            pair_counts[pair] += 1
    return pair_counts

def merge_chunk(ids_chunk, pair, idx):
    # merge pairs
    new_ids_chunk = []
    for ids in ids_chunk:
        new_ids = []
        i = 0
        while i < len(ids):
            if i < len(ids) - 1 and (ids[i], ids[i+1]) == pair:
                new_ids.append(idx)
                i += 2
            else:
                new_ids.append(ids[i])
                i += 1
        new_ids_chunk.append(new_ids)
    return new_ids_chunk

def tokenizer(input_path, vocab_size, special_tokens):
    with open(input_path, "r", encoding="utf-8") as f:
        text = f.read()

    escaped_special_tokens = [re.escape(token) for token in special_tokens]
    special_token_split = "|".join(escaped_special_tokens)
    split_texts = re.split(special_token_split, text)
    text_words = []
    for split_text in split_texts:
      text_words.extend(PAT.findall(split_text))

    words = [list(word.encode("utf-8")) for word in text_words if word] # Ensure not empty strings

    # Initial vocab: map integers 0-255 to their byte representation
    vocab = {i: bytes([i]) for i in range(256)}
    merges = {}
    next_id = 256
    num_merges = vocab_size - 256 - len(special_tokens) # Reserve space for base bytes and special tokens

    if num_merges < 0:
        raise ValueError("vocab_size is too small for base bytes and special tokens.")

    num_processes = multiprocessing.cpu_count()
    pool = multiprocessing.Pool(processes=num_processes)

    for i in range(num_merges):
        # Parallelized pair counting
        chunk_size = (len(words) + num_processes - 1) // num_processes
        chunks = [words[j*chunk_size:(j+1)*chunk_size] for j in range(num_processes)]

        results = pool.map(get_stats, chunks)

        # Combine counts
        pair_counts = Counter()
        for result in results:
            pair_counts.update(result)

        if not pair_counts:
            print(f"No more pairs to merge. Stopping at iteration {i}.")
            break # No more pairs found

        # Find the best pair
        best_pair = max(pair_counts, key=pair_counts.get)
        new_id = next_id

        # Parallelized merging
        merge_args = [(chunks[j], best_pair, new_id) for j in range(num_processes)]
        merged_chunks = pool.starmap(merge_chunk, merge_args)

        # Combine merged chunks
        words = [item for sublist in merged_chunks for item in sublist] # Flatten list of lists

        # Update merges and vocab
        merges[best_pair] = new_id
        try:
             vocab[new_id] = vocab[best_pair[0]] + vocab[best_pair[1]]
        except KeyError:
            print(f"Error: Key not found in vocab. best_pair: {best_pair}")
            print(f"Current vocab keys: {list(vocab.keys())}")
            print(f"Current max ID in words: {max(max(w) for w in words if w)}") # Check if IDs got out of sync
            raise # Re-raise the error to understand it better if it occurs.

        next_id += 1
        print(f"Merge {i+1}/{num_merges}: {best_pair} -> {new_id} ({vocab[new_id]}) count={pair_counts[best_pair]}")


    # Add special tokens to the end of the vocab
    special_token_start_id = next_id
    for i, token in enumerate(special_tokens):
        current_special_id = special_token_start_id + i
        vocab[current_special_id] = token.encode('utf-8') # Store as bytes

    pool.close()
    pool.join()

    print(f"Final vocab size: {len(vocab)}")
    return vocab, merges

SyntaxError: '[' was never closed (3341508237.py, line 33)