## 2.5 Problem: BPE Training on TinyStories (2 points)
`(train_bpe_tinystories)`

**(a)** Train a byte-level BPE tokenizer on the **TinyStories** dataset, using a maximum vocabulary size of **10,000**. Make sure to add the TinyStories <|endoftext|> special token to the vocabulary. Serialize the resulting vocabulary and merges to disk for further inspection. How many hours and memory did training take? What is the longest token in the vocabulary? Does it make sense?

**Resource requirements:** ≤ 30 minutes (no GPUs), ≤ 30GB RAM 

**Hint** You should be able to get under 2 minutes for BPE training using multiprocessing during pretokenization and the following two facts:

- The <|endoftext|> token delimits documents in the data files.

- The <|endoftext|> token is handled as a special case before the BPE merges are applied.
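
The two facts in the hint can be sketched as follows: split each chunk on the special token so no pretoken spans a document boundary, count pretokens per chunk, then merge the per-chunk counts. This is a minimal illustration, not the implementation used in this notebook. `count_pretokens`, `merge_counts`, and the simplified `PRETOKEN_RE` pattern are assumptions made for the example; the pattern stands in for the real pretokenization regex.

```python
import re
from collections import Counter

SPECIAL = "<|endoftext|>"
# Simplified stand-in for the real pretokenization pattern (assumption).
PRETOKEN_RE = re.compile(r" ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+")


def count_pretokens(chunk: str) -> Counter:
    """Count pretokens in one chunk, never crossing a document boundary."""
    counts = Counter()
    for doc in chunk.split(SPECIAL):  # the special token is removed here
        for match in PRETOKEN_RE.finditer(doc):
            counts[match.group().encode("utf-8")] += 1
    return counts


def merge_counts(per_chunk_counts) -> Counter:
    """Combine the Counters produced for each chunk."""
    total = Counter()
    for counts in per_chunk_counts:
        total.update(counts)
    return total


chunks = ["the cat<|endoftext|>the dog", "a cat"]
total = merge_counts(map(count_pretokens, chunks))
print(total.most_common(3))
```

Because `count_pretokens` depends only on its own chunk, the `map` call could be replaced by a process pool (for example `multiprocessing.Pool.map`) when chunk boundaries fall on `<|endoftext|>`.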

**(b)** Profile your code. What part of the tokenizer training process takes the most time?

### **Answers**

**(a) Time & Memory Usage**

TinyStoriesV2-GPT4-train.txt, vocab_size = 10000.

Without inverted index optimization:
- num_processes = 4: 
    - time: 172.03s
- num_processes = 40:
    - time: 60.81s
    - memory: peak Python heap: 67.1 MB, peak RSS: 144.1 MB

After optimizing the merge step with an inverted index and a heap with lazy deletion:
- num_processes = 40:
    - time: **16.6s**
    - memory: peak Python heap: 126.2 MB, peak RSS: 391.0 MB
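
A heap with lazy deletion can be sketched as below. This is a minimal illustration under stated assumptions, not the code in `cs336_basics/bpe.py`: the class name `LazyMaxPairHeap` and its methods are hypothetical, and ties between equal counts are broken by ordinary tuple ordering, so whatever tie-breaking rule the trainer uses would need to be encoded in the heap key.

```python
import heapq


class LazyMaxPairHeap:
    """Max-heap over pair counts; stale entries are discarded when popped."""

    def __init__(self, pair_counts):
        self.counts = dict(pair_counts)
        # heapq is a min-heap, so store negated counts
        self.heap = [(-count, pair) for pair, count in self.counts.items()]
        heapq.heapify(self.heap)

    def update(self, pair, new_count):
        """Record a new count; the old heap entry stays behind and becomes stale."""
        if new_count > 0:
            self.counts[pair] = new_count
            heapq.heappush(self.heap, (-new_count, pair))
        else:
            self.counts.pop(pair, None)

    def pop_max(self):
        """Return (pair, count) for the most frequent live pair, or None."""
        while self.heap:
            neg_count, pair = heapq.heappop(self.heap)
            # An entry is live only if it matches the current count
            if self.counts.get(pair) == -neg_count:
                del self.counts[pair]
                return pair, -neg_count
        return None


heap = LazyMaxPairHeap({(b"a", b"b"): 5, (b"b", b"c"): 3})
heap.update((b"a", b"b"), 2)   # the (5, ab) entry is now stale
first = heap.pop_max()
second = heap.pop_max()
third = heap.pop_max()
```

Each count change costs one `heappush` instead of a full scan, which is the trade the optimized run makes against the `max` over all pairs seen in the baseline profile.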

The longest token: `id=7160, len=15 bytes, text=' accomplishment'`

**(b) Code Profiling** 

Notes: Machine info
- OS: Ubuntu 22.04.5 LTS (kernel 6.8 on x86_64)
- CPU: AMD EPYC 9V84 (40 vCPUs visible)
- RAM: 314 GiB total (≈296 GiB available when recorded)
- Disk: Root filesystem ≈2.0 TB total, ≈2.0 TB free
- GPU: NVIDIA GA103 detected on PCIe (PCI ID 10de:2321), but NVIDIA driver/module is not loaded

In [None]:
### Helper functions

import yaml
import os


def save_tokenizer_yaml(vocab, merges, fname):
    "Save vocab and merges to a YAML file with UTF-8 decoding for readability."
    # Convert bytes → string for readability
    vocab_serializable = {
        k: v.decode("utf-8", errors="replace") if isinstance(v, bytes) else v
        for k, v in vocab.items()
    }
    merges_serializable = [
        (a.decode("utf-8", errors="replace"), b.decode("utf-8", errors="replace"))
        for a, b in merges
    ]

    # Ensure the parent directory exists before writing
    dirpath = os.path.dirname(fname)
    if dirpath:
        os.makedirs(dirpath, exist_ok=True)

    with open(fname, "w", encoding="utf-8") as f:
        yaml.dump(
            {"vocab": vocab_serializable, "merges": merges_serializable},
            f,
            allow_unicode=True,
            sort_keys=False
        )
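
The YAML helper above decodes with `errors="replace"`, which is readable but lossy: any token that is not valid UTF-8 cannot be restored from the file. If an exact round trip is needed, one option is to store the raw bytes as hex. The sketch below is an assumption-labeled alternative, not part of the original notebook; `save_tokenizer_json` and `load_tokenizer_json` are hypothetical names.

```python
import json


def save_tokenizer_json(vocab, merges, fname):
    """Save vocab and merges losslessly by hex-encoding every bytes value."""
    payload = {
        "vocab": {str(i): tok.hex() for i, tok in vocab.items()},
        "merges": [[a.hex(), b.hex()] for a, b in merges],
    }
    with open(fname, "w", encoding="utf-8") as f:
        json.dump(payload, f)


def load_tokenizer_json(fname):
    """Inverse of save_tokenizer_json: returns (vocab, merges) with bytes values."""
    with open(fname, encoding="utf-8") as f:
        payload = json.load(f)
    vocab = {int(i): bytes.fromhex(h) for i, h in payload["vocab"].items()}
    merges = [(bytes.fromhex(a), bytes.fromhex(b)) for a, b in payload["merges"]]
    return vocab, merges
```

The hex file is not human-readable, so it complements the YAML dump rather than replacing it: YAML for inspection, hex JSON for reloading.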


In [None]:
import time, os
from cs336_basics.bpe import run_train_bpe

file_name = "data/TinyStoriesV2-GPT4-train.txt"
vocab_size = 10000
num_processes = os.cpu_count()

### Training

start = time.perf_counter()
vocab, merges = run_train_bpe(file_name, vocab_size, ["<|endoftext|>"], num_processes)
elapsed_s = time.perf_counter() - start
print(f"time: {elapsed_s:.2f}s")

### Longest token in the vocabulary

save_tokenizer_yaml(vocab, merges, "artifacts/tinystories_bpe.yaml")
longest_id, longest_bytes = max(vocab.items(), key=lambda kv: len(kv[1]))
print(f"longest token: id={longest_id}, len={len(longest_bytes)} bytes, text={longest_bytes.decode('utf-8','replace')!r}")
# longest token: id=7160, len=15 bytes, text=' accomplishment'

In [None]:
### Time & Memory usage; Profile code

import tracemalloc, resource, time
import os
import cProfile, pstats, io
from cs336_basics.bpe import run_train_bpe

file_name = "data/TinyStoriesV2-GPT4-train.txt"
vocab_size = 10000

# Start fresh
pr = cProfile.Profile()
pr.enable()
tracemalloc.start()
start = time.perf_counter()

# Training
vocab, merges = run_train_bpe(file_name, vocab_size, ["<|endoftext|>"], num_processes=os.cpu_count())

# Stop and report
elapsed_s = time.perf_counter() - start
pr.disable()
cur, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

# ru_maxrss is reported in KiB on Linux. RUSAGE_SELF covers only this (parent)
# process, so memory used by the pretokenization workers is not included.
peak_rss_mb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024
print(f"time: {elapsed_s:.2f}s ({elapsed_s/3600:.2f} h), peak Python heap: {peak/1e6:.1f} MB, peak RSS: {peak_rss_mb:.1f} MB")

s = io.StringIO()
pstats.Stats(pr, stream=s).sort_stats('cumtime').print_stats(30)
print(s.getvalue())

""" Profiling Results:

time: 238.98s (0.07 h), peak Python heap: 126.2 MB, peak RSS: 391.0 MB
         19354352 function calls (19353456 primitive calls) in 238.873 seconds

   Ordered by: cumulative time
   List reduced from 408 to 30 due to restriction <30>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      205    0.054    0.000  369.225    1.801 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/asyncio/base_events.py:1970(_run_once)
        5    0.000    0.000  238.977   47.795 /home/haoru/.cache/uv/archive-v0/ZzdUiPCnVvyZrxKy_ds_G/lib/python3.13/site-packages/IPython/core/interactiveshell.py:3663(run_code)
     10/5    0.000    0.000  238.977   47.795 {built-in method builtins.exec}
      205    0.090    0.000  214.202    1.045 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/selectors.py:435(select)
       40    0.056    0.001  154.003    3.850 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/multiprocessing/queues.py:96(get)
       40    0.000    0.000  153.372    3.834 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/multiprocessing/connection.py:208(recv_bytes)
      205  141.259    0.689  142.454    0.695 {method 'poll' of 'select.epoll' objects}
     9743   45.666    0.005   75.383    0.008 /home/haoru/cs336/cs336-hw1/cs336_basics/bpe.py:159(merge_pair_and_update_counts)
       40    0.000    0.000   13.920    0.348 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/multiprocessing/connection.py:429(_recv_bytes)
       80    0.001    0.000   13.919    0.174 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/multiprocessing/connection.py:390(_recv)
      280   13.914    0.050   13.914    0.050 {built-in method posix.read}
  7123718   10.989    0.000   10.989    0.000 {method 'get' of 'dict' objects}
        1    0.601    0.601    9.150    9.150 /home/haoru/cs336/cs336-hw1/cs336_basics/bpe.py:322(run_train_bpe)
        1    0.092    0.092    8.549    8.549 /home/haoru/cs336/cs336-hw1/cs336_basics/bpe.py:261(train_bpe_on_pretokens)
  1456506    4.021    0.000    4.021    0.000 {method 'add' of 'set' objects}
     9743    1.156    0.000    3.519    0.000 /home/haoru/cs336/cs336-hw1/cs336_basics/bpe.py:109(find_most_frequent_pair)
  1081253    3.467    0.000    3.467    0.000 {method 'setdefault' of 'dict' objects}
  1079593    3.291    0.000    3.291    0.000 {method 'discard' of 'set' objects}
  1368847    3.033    0.000    3.033    0.000 {method 'append' of 'list' objects}
  2632440    2.583    0.000    2.583    0.000 {built-in method builtins.len}
   749810    0.626    0.000    2.047    0.000 /home/haoru/cs336/cs336-hw1/cs336_basics/bpe.py:65(_desc_key)
   551932    1.951    0.000    1.951    0.000 {built-in method _heapq.heappop}
  1499620    0.766    0.000    1.421    0.000 /home/haoru/cs336/cs336-hw1/cs336_basics/bpe.py:47(_inv)
        1    0.730    0.730    1.189    1.189 /home/haoru/cs336/cs336-hw1/cs336_basics/bpe.py:72(build_pair_counts_and_index)
   747725    1.189    0.000    1.189    0.000 {built-in method _heapq.heappush}
   606448    0.993    0.000    0.993    0.000 {method 'pop' of 'dict' objects}
   277855    0.945    0.000    0.945    0.000 {method 'items' of 'dict' objects}
       41    0.504    0.012    0.702    0.017 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/collections/__init__.py:673(update)
       40    0.572    0.014    0.572    0.014 {built-in method _pickle.loads}
       40    0.002    0.000    0.203    0.005 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/multiprocessing/process.py:110(start)
"""

In [None]:
### BPE training on TinyStories without the inverted index optimization

# The profiling result below shows:
#    run_train_bpe: 125.404s total runtime (the baseline for the other numbers).
#    train_bpe_on_pretokens: 94.138s (~75% of total), i.e. the BPE merge loop.
#
# Within the merge loop,
#    the dominant cost is find_most_frequent_pair (65.180s cumulative), which calls
#    max over the full pair_counts dict once per merge (9,743 calls).
#         builtins.max inside it: tottime 44.085s
#         the key lambda used by max: 369,218,707 calls, 21.031s
#    merge_pair_and_update_counts accounts for 24.489s cumulative.
#
# ~30.62s (~24%) is spent in Queue.get, waiting for and receiving pretoken counts
# from the worker processes.
#
# How to read the cProfile table:
#    Each row is a function.
#
#    ncalls: how many times it was called
#    tottime: time spent in that function body only (exclusive)
#    cumtime: time spent in that function plus all functions it called (inclusive)
#    percall: the preceding column (tottime or cumtime) divided by ncalls
#    Sorted by cumtime (largest “overall cost” at top)

""" Profiling Results:

# Results
   Ordered by: cumulative time
   List reduced from 126 to 30 due to restriction <30>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.041    0.041  125.404  125.404 /home/haoru/cs336/cs336-hw1/cs336_basics/train_bpe.py:244(run_train_bpe)
        1    4.192    4.192   94.138   94.138 /home/haoru/cs336/cs336-hw1/cs336_basics/train_bpe.py:184(train_bpe_on_pretokens)
     9743    0.038    0.000   65.180    0.007 /home/haoru/cs336/cs336-hw1/cs336_basics/train_bpe.py:71(find_most_frequent_pair)
     9743   44.085    0.005   65.116    0.007 {built-in method builtins.max}
       40    0.062    0.002   30.797    0.770 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/multiprocessing/queues.py:96(get)
       40    0.000    0.000   30.621    0.766 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/multiprocessing/connection.py:208(recv_bytes)
       40    0.000    0.000   30.621    0.766 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/multiprocessing/connection.py:429(_recv_bytes)
       80    0.000    0.000   30.621    0.383 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/multiprocessing/connection.py:390(_recv)
      280   30.618    0.109   30.618    0.109 {built-in method posix.read}
     9743    2.374    0.000   24.489    0.003 /home/haoru/cs336/cs336-hw1/cs336_basics/train_bpe.py:106(merge_pair_and_update_counts)
369218707   21.031    0.000   21.031    0.000 /home/haoru/cs336/cs336-hw1/cs336_basics/train_bpe.py:78(<lambda>)
     9784    0.170    0.000   17.112    0.002 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/collections/__init__.py:673(update)
     9744    0.012    0.000   16.841    0.002 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/collections/__init__.py:734(copy)
     9745    0.025    0.000   16.829    0.002 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/collections/__init__.py:599(__init__)
     9745   16.767    0.002   16.767    0.002 {function Counter.update at 0x7ce170c35f80}
     9783    2.792    0.000    2.792    0.000 {method 'copy' of 'dict' objects}
   277780    0.590    0.000    0.900    0.000 /home/haoru/cs336/cs336-hw1/cs336_basics/train_bpe.py:83(new_pretoken_after_merge_pair)
  4068852    0.772    0.000    0.772    0.000 {method 'get' of 'dict' objects}
  1457032    0.306    0.000    0.306    0.000 {method 'add' of 'set' objects}
  1079666    0.276    0.000    0.276    0.000 {method 'discard' of 'set' objects}
        1    0.145    0.145    0.249    0.249 /home/haoru/cs336/cs336-hw1/cs336_basics/train_bpe.py:41(build_pair_counts_and_index)
  3641926    0.235    0.000    0.235    0.000 {built-in method builtins.len}
  1081780    0.185    0.000    0.185    0.000 {method 'setdefault' of 'dict' objects}
  1362718    0.152    0.000    0.152    0.000 {method 'keys' of 'dict' objects}
  1369343    0.133    0.000    0.133    0.000 {method 'append' of 'list' objects}
       40    0.002    0.000    0.115    0.003 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/multiprocessing/process.py:110(start)
       40    0.114    0.003    0.114    0.003 {built-in method _pickle.loads}
       40    0.001    0.000    0.109    0.003 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/multiprocessing/context.py:222(_Popen)
       40    0.002    0.000    0.108    0.003 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/multiprocessing/context.py:279(_Popen)
       40    0.001    0.000    0.106    0.003 /home/haoru/.local/share/uv/python/cpython-3.13.7-linux-x86_64-gnu/lib/python3.13/multiprocessing/popen_fork.py:16(__init__)
"""


## 2.5 Problem: BPE Training on OpenWebText (2 points)

`(train_bpe_expts_owt)`

**(a)** Train a byte-level BPE tokenizer on the OpenWebText dataset, using a maximum vocabulary size of 32,000. Serialize the resulting vocabulary and merges to disk for further inspection. What is the longest token in the vocabulary? Does it make sense?

**Resource requirements:** ≤ 12 hours (no GPUs), ≤ 100GB RAM

### **Answers**

- Training time: **4109.53s ≈ 1.14 hours** (with the inverted index and heap optimizations)
    
    `"data/owt_train.txt"; vocab_size = 32000, num_processes = 40`
- Longest token:

    `id=25822, len=64 bytes, text='ÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂÃÂ'`
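
A quick way to sanity-check the reported length: the printed text has 32 characters, and each of `Ã` (U+00C3) and `Â` (U+00C2) takes two bytes in UTF-8, which gives the 64 bytes shown. The snippet below only verifies that arithmetic. Runs like this are commonly the residue of text that was decoded with the wrong encoding and then re-encoded, so finding them in a web-scraped corpus is plausible, but that reading is an interpretation and was not checked against the dataset.

```python
# The reported token text: 'Ã' (U+00C3) and 'Â' (U+00C2) alternating, 16 times.
token_text = "\u00c3\u00c2" * 16
token_bytes = token_text.encode("utf-8")

print(len(token_text))        # number of characters
print(len(token_bytes))       # number of bytes
print(token_bytes[:4].hex())  # the first two characters as hex
```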

**(b)** Compare and contrast the tokenizer that you get training on TinyStories versus OpenWebText.


## 2.7 Problem: Experiments with tokenizers (4 points)

`(tokenizer_experiments)`

**(a)** Sample 10 documents from TinyStories and OpenWebText. Using your previously-trained TinyStories and OpenWebText tokenizers (10K and 32K vocabulary size, respectively), encode these sampled documents into integer IDs. What is each tokenizer’s compression ratio (bytes/token)?

**(b)** What happens if you tokenize your OpenWebText sample with the TinyStories tokenizer? Compare the compression ratio and/or qualitatively describe what happens.


**(c)** Estimate the throughput of your tokenizer (e.g., in bytes/second). How long would it take to tokenize the Pile dataset (825GB of text)?
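
One way to approach this estimate is to time `encode` over a sample and extrapolate linearly. The sketch below is a hedged illustration: `measure_throughput` and `hours_to_tokenize` are hypothetical helpers, `encode` is any callable that takes a string (for example a tokenizer's `encode` method), and the extrapolation assumes throughput on the sample is representative of the full dataset.

```python
import time

PILE_BYTES = 825e9  # 825 GB, as stated in the problem


def measure_throughput(encode, texts):
    """Return bytes/second achieved by `encode` over a list of strings."""
    texts = list(texts)
    total_bytes = sum(len(t.encode("utf-8")) for t in texts)  # counted outside the timer
    start = time.perf_counter()
    for text in texts:
        encode(text)
    elapsed = time.perf_counter() - start
    return total_bytes / max(elapsed, 1e-9)  # guard against a zero reading


def hours_to_tokenize(dataset_bytes, bytes_per_second):
    """Linear extrapolation from a measured throughput."""
    return dataset_bytes / bytes_per_second / 3600


# Toy stand-in for a real tokenizer, used only so the example runs
toy_encode = lambda s: list(s.encode("utf-8"))
rate = measure_throughput(toy_encode, ["hello world"] * 1000)
print(f"{rate:.0f} bytes/s -> {hours_to_tokenize(PILE_BYTES, rate):.2f} h for the Pile")
```

As a scale reference, a hypothetical single-process rate of 1 MB/s would put 825 GB at roughly 229 hours.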


**(d)** Using your TinyStories and OpenWebText tokenizers, encode the respective training and development datasets into a sequence of integer token IDs. We’ll use this later to train our language model. We recommend serializing the token IDs as a NumPy array of datatype uint16. Why is uint16 an appropriate choice?
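
The reasoning behind uint16 can be checked with a little arithmetic. The sketch below assumes token IDs run from 0 to `vocab_size - 1`; `fits_uint16` and `storage_gb` are hypothetical helpers, and the token count is the one printed by the TinyStories encoding run recorded later in this notebook.

```python
UINT16_MAX = 2**16 - 1  # 65535, the largest value a 16-bit unsigned integer holds


def fits_uint16(vocab_size):
    """True if every ID in range(vocab_size) is representable as uint16."""
    return vocab_size - 1 <= UINT16_MAX


def storage_gb(num_tokens, bytes_per_token):
    """Raw array size in GB (10**9 bytes), ignoring file headers."""
    return num_tokens * bytes_per_token / 1e9


for name, size in {"TinyStories": 10_000, "OpenWebText": 32_000}.items():
    print(name, fits_uint16(size))

num_tokens = 542_447_487  # TinyStories train tokens from the encoding log
print(f"uint16: {storage_gb(num_tokens, 2):.2f} GB, uint32: {storage_gb(num_tokens, 4):.2f} GB")
```

Both vocabularies fit with room to spare, and the two-byte type halves the storage relative to a four-byte integer.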


### **Answers**

**(a) & (b)**

Used 10/10 documents for calculation...

Compression ratio:
- TinyStories tokenizer (10K vocab) on TinyStories docs: 4.191 bytes/token
- OpenWebText tokenizer (32K vocab) on OpenWebText docs: 4.311 bytes/token

Cross-domain comparison **(less efficient)**:
- TinyStories tokenizer on OpenWebText docs: 3.270 bytes/token
- OpenWebText tokenizer on TinyStories docs: 4.085 bytes/token



In [10]:
### Problem (tokenizer_experiments): Experiments with tokenizers

import random
from cs336_basics.tokenizer import Tokenizer

def load_tokenizer_yaml(fname):
    """
    Load vocab and merges from a YAML file, handling Python tuple tags safely.

    Guarantees the 0..255 byte symbols exist (so no KeyError on non-ASCII input).
    Higher-id tokens and merges are re-encoded from their saved UTF-8 text.
    Caveat: the file was written with errors="replace", so any token that was not
    valid UTF-8 was saved with U+FFFD in place of the bad bytes. Such tokens cannot
    be recovered exactly; they are loaded as the UTF-8 encoding of the saved text.
    """
    import yaml

    # yaml.dump writes tuples with a python/tuple tag; teach SafeLoader to read it
    yaml.SafeLoader.add_constructor('tag:yaml.org,2002:python/tuple',
        lambda l, n: tuple(l.construct_sequence(n)))

    with open(fname, "r", encoding="utf-8") as f:
        d = yaml.safe_load(f)

    vocab = {i: bytes([i]) for i in range(256)}
    for k, v in d["vocab"].items():
        i = int(k)
        if i < 256:  # skip base bytes already added
            continue

        if isinstance(v, (list, tuple)):
            b = bytes(v)
        elif isinstance(v, str):
            # "ignore" only drops code points that cannot be encoded (lone
            # surrogates); ordinary non-ASCII text, including U+FFFD, is kept.
            b = v.encode("utf-8", "ignore")
        else:
            continue  # unexpected value type
        if b:  # skip empty bytes
            vocab[i] = b

    merges = []
    for a, b in d["merges"]:
        if isinstance(a, (list, tuple)):
            merges.append((bytes(a), bytes(b)))
        else:
            merge_a, merge_b = a.encode("utf-8", "ignore"), b.encode("utf-8", "ignore")
            if merge_a and merge_b:
                merges.append((merge_a, merge_b))
    return vocab, merges

def sample_documents(file_path, num_samples=10, seed=42):
    """Sample documents from a dataset file separated by <|endoftext|>.

    Only the first num_samples * 10 documents are read, so the sample is drawn
    from the start of the file rather than from the whole dataset.
    """
    random.seed(seed)
    pool_size = num_samples * 10

    # Stream the file line by line so large files are never fully loaded
    documents = []
    reached_eof = True
    with open(file_path, 'r', encoding='utf-8') as f:
        current_doc = ""
        for line in f:
            if '<|endoftext|>' in line:
                # Split on <|endoftext|> and handle multiple occurrences per line
                parts = line.split('<|endoftext|>')
                current_doc += parts[0]
                if current_doc.strip():
                    documents.append(current_doc.strip())

                # Documents that start and end on this line
                for part in parts[1:-1]:
                    if part.strip():
                        documents.append(part.strip())

                current_doc = parts[-1]

                # Stop once the candidate pool is large enough
                if len(documents) >= pool_size:
                    reached_eof = False
                    break
            else:
                current_doc += line

    # Add the final document only if the file was read to the end; after an
    # early stop, current_doc holds just the beginning of the next document
    if reached_eof and current_doc.strip():
        documents.append(current_doc.strip())

    # Sample random documents
    return random.sample(documents, min(num_samples, len(documents)))

def can_tokenize(text, tokenizer):
    """Check if a text can be tokenized without errors."""
    try:
        tokenizer.encode(text)
        return True
    except KeyError as e:
        print("Doc failed; missing byte:", e)
        return False

def calculate_compression_ratio(documents, tokenizer):
    """Calculate compression ratio (bytes/token) for a list of documents."""
    total_bytes = 0
    total_tokens = 0
    valid_docs = 0
    
    for doc in documents:
        # Skip documents that can't be tokenized
        if not can_tokenize(doc, tokenizer):
            continue
            
        # Count bytes (UTF-8 encoding)
        doc_bytes = len(doc.encode('utf-8'))
        total_bytes += doc_bytes
        
        # Count tokens
        tokens = tokenizer.encode(doc)
        total_tokens += len(tokens)
        valid_docs += 1
    
    print(f"    Used {valid_docs}/{len(documents)} documents for calculation")
    return total_bytes / total_tokens if total_tokens > 0 else 0

# Load the trained tokenizers
tinystories_vocab, tinystories_merges = load_tokenizer_yaml("artifacts/tinystories_bpe.yaml")
owt_vocab, owt_merges = load_tokenizer_yaml("artifacts/owt_bpe.yaml")

# Create tokenizer objects
tinystories_tokenizer = Tokenizer(tinystories_vocab, tinystories_merges, special_tokens=["<|endoftext|>"])
owt_tokenizer = Tokenizer(owt_vocab, owt_merges, special_tokens=["<|endoftext|>"])

# Sample documents from both datasets
print("Sampling documents...")
tinystories_docs = sample_documents("data/TinyStoriesV2-GPT4-train.txt", num_samples=10)
owt_docs = sample_documents("data/owt_train.txt", num_samples=10)

print(f"Sampled {len(tinystories_docs)} TinyStories documents")
print(f"Sampled {len(owt_docs)} OpenWebText documents")

# Calculate compression ratios
print("\nCalculating compression ratios...")

# TinyStories tokenizer on TinyStories documents
ts_on_ts_ratio = calculate_compression_ratio(tinystories_docs, tinystories_tokenizer)
print(f"TinyStories tokenizer (10K vocab) on TinyStories docs: {ts_on_ts_ratio:.3f} bytes/token")

# OpenWebText tokenizer on OpenWebText documents  
owt_on_owt_ratio = calculate_compression_ratio(owt_docs, owt_tokenizer)
print(f"OpenWebText tokenizer (32K vocab) on OpenWebText docs: {owt_on_owt_ratio:.3f} bytes/token")

# Cross-domain evaluation for comparison
ts_on_owt_ratio = calculate_compression_ratio(owt_docs, tinystories_tokenizer)
owt_on_ts_ratio = calculate_compression_ratio(tinystories_docs, owt_tokenizer)

print(f"\nCross-domain comparison:")
print(f"TinyStories tokenizer on OpenWebText docs: {ts_on_owt_ratio:.3f} bytes/token")
print(f"OpenWebText tokenizer on TinyStories docs: {owt_on_ts_ratio:.3f} bytes/token")

Sampling documents...
Sampled 10 TinyStories documents
Sampled 10 OpenWebText documents

Calculating compression ratios...
    Used 10/10 documents for calculation
TinyStories tokenizer (10K vocab) on TinyStories docs: 4.191 bytes/token
    Used 10/10 documents for calculation
OpenWebText tokenizer (32K vocab) on OpenWebText docs: 4.311 bytes/token
    Used 10/10 documents for calculation
    Used 10/10 documents for calculation

Cross-domain comparison:
TinyStories tokenizer on OpenWebText docs: 3.270 bytes/token
OpenWebText tokenizer on TinyStories docs: 4.085 bytes/token


In [None]:
import numpy as np
from multiprocessing import Process, Queue

from cs336_basics.tokenizer import Tokenizer
from cs336_basics.pretokenization import find_chunk_boundaries

def encode_worker(idx: int, start: int, end: int, input_path: str, tokenizer, q: Queue):
    """
    Worker function: encode one byte range of the file and report it with its chunk index.
    """
    with open(input_path, "rb") as f:
        f.seek(start)
        chunk = f.read(end - start).decode("utf-8", errors="ignore")

    # Encode the chunk; a uint16 array keeps the payload sent through the queue small
    tokens = np.fromiter(tokenizer.encode_iterable([chunk]), dtype=np.uint16)
    q.put((idx, tokens))

def encode_dataset_parallel(input_path, output_path, tokenizer, num_chunks):
    """
    Encode a dataset in parallel: split the file at <|endoftext|> boundaries
    (find_chunk_boundaries), encode each chunk in its own process, and save the
    tokens in file order as a uint16 array.
    """
    with open(input_path, "rb") as f:
        # Use the same chunking logic as pretokenization.py
        boundaries = find_chunk_boundaries(f, num_chunks, b"<|endoftext|>")

    # One process per chunk (same pattern as run_train_bpe)
    spans = list(zip(boundaries[:-1], boundaries[1:]))
    print(f"Starting parallel encoding with {len(spans)} processes")
    q = Queue()
    processes = []
    for idx, (start, end) in enumerate(spans):
        p = Process(target=encode_worker, args=(idx, start, end, input_path, tokenizer, q))
        p.start()
        processes.append(p)

    # Workers finish in arbitrary order, so key each result by its chunk index
    results = {}
    for _ in processes:
        idx, tokens = q.get()
        results[idx] = tokens

    # Wait for all processes to complete
    for p in processes:
        p.join()

    # Concatenate in file order and save as uint16
    tokens_array = np.concatenate([results[i] for i in range(len(spans))])
    np.save(output_path, tokens_array)
    print(f"Saved {len(tokens_array)} tokens to {output_path}")
    return tokens_array

# Load tokenizers
# tinystories_vocab, tinystories_merges = load_tokenizer_yaml("artifacts/tinystories_bpe.yaml")
owt_vocab, owt_merges = load_tokenizer_yaml("artifacts/owt_bpe.yaml")

# ts_tokenizer = Tokenizer(tinystories_vocab, tinystories_merges, special_tokens=["<|endoftext|>"])
owt_tokenizer = Tokenizer(owt_vocab, owt_merges, special_tokens=["<|endoftext|>"])

# Encode datasets
# print("Encoding TinyStories dataset...")
# encode_dataset_parallel("data/TinyStoriesV2-GPT4-train.txt", "data/tinystories_train_tokens.npy", ts_tokenizer)

print("Encoding OpenWebText dataset...")
encode_dataset_parallel("data/owt_train.txt", "data/owt_train_tokens.npy", owt_tokenizer, num_chunks=100)

"""
23m 25.6s
Encoding TinyStories dataset...
Processing chunk 1/4, size: 556710563 chars
Processing chunk 2/4, size: 556712694 chars
Processing chunk 3/4, size: 556712530 chars
Processing chunk 4/4, size: 556709481 chars
Saved 542447487 tokens to data/tinystories_train_tokens.npy
array([ 10, 430, 439, ..., 317,  89, 111],
      shape=(542447487,), dtype=uint16)
"""

Encoding OpenWebText dataset...
Processing chunk 1/4, size: 2954033787 chars
Processing chunk 2/4, size: 2953974483 chars
Processing chunk 3/4, size: 2953958612 chars
