## user-end goals
- [ ] use fineweb-edu, sample-10B dataset
- [ ] join documents with <BOS>
- [ ] add max document length (2x context length) - split at reasonable point
## model changes
- [ ] (maybe) add GQA
- [ ] remove bias in linear layers
- [ ] document masking

In [1]:
import requests
import os
import sys
import pyarrow
import pyarrow.parquet as pq
from pathlib import Path
from tqdm import tqdm

def get_project_info() -> tuple[Path, Path]:
  """Locate the project root and the directory the notebook was started in.

  The resolved working directory and each of its ancestors are inspected in
  order, nearest first; the first one containing a ``toy_transformers`` entry
  is taken as the project root. If none of them does, the working directory
  itself is used as the root.

  Returns:
    A ``(root, current)`` pair of absolute paths, where ``current`` is the
    resolved working directory at call time.
  """
  here = Path.cwd().resolve()
  matching = (
    folder
    for folder in (here, *here.parents)
    if (folder / "toy_transformers").exists()
  )
  return next(matching, here), here

# Bootstrap once per kernel session. The guard matters because this cell changes
# the working directory to ROOT_DIR below: calling get_project_info() again after
# that would report ROOT_DIR as the "current" directory and overwrite
# EXPERIMENT_DIR with it.
if 'ROOT_DIR' not in globals():
	ROOT_DIR, EXPERIMENT_DIR = get_project_info()
	# Make the project root importable for the toy_transformers import that follows.
	if str(ROOT_DIR) not in sys.path:
		sys.path.append(str(ROOT_DIR))
	# Run the rest of the notebook from the project root
	# (presumably so relative paths inside the package resolve — confirm).
	if Path.cwd() != ROOT_DIR:
		os.chdir(ROOT_DIR)

# Local cache directory for the downloaded parquet shards (used by stream_raw_ds).
DATA_DIR = ROOT_DIR / "data/fineweb-edu/sample/10BT"
DATA_DIR.mkdir(parents=True, exist_ok=True)

from toy_transformers.tokenization import create_bpe, bulk_encode, Vocabulary, TokenizationMode

In [2]:
# Threshold (in bytes of Arrow table data) at which stream_raw_ds emits a batch.
BATCH_SIZE_BYTES = 100 * 1024 * 1024
API_URL = "https://huggingface.co/api/datasets/HuggingFaceFW/fineweb-edu/parquet/sample-10BT/train"

# Fetch the list of parquet shard URLs for the split. A timeout is set because
# requests waits indefinitely by default, which would hang the cell on a
# stalled connection instead of failing.
response = requests.get(API_URL, timeout=30)
response.raise_for_status()
# Indexed by shard number in stream_raw_ds; each entry is passed to download_shard as a URL.
SHARD_URLS = response.json()

def download_shard(url, dst: Path, timeout=60):
	"""Download ``url`` to ``dst`` unless ``dst`` already exists.

	The response body is streamed into a sibling ``<dst name>.tmp`` file and only
	moved onto ``dst`` after the transfer loop finished, so a file found at
	``dst`` is never a partially written download from this function.

	Args:
		url: Address of the file to fetch.
		dst: Final location of the downloaded file. Nothing is done if it exists.
		timeout: Timeout in seconds handed to ``requests.get`` so a stalled
			connection raises instead of blocking forever.

	Raises:
		requests.HTTPError: If the server answers with an error status.
		Any other exception raised by ``requests`` or by file I/O propagates
		unchanged; the temporary file is removed first.
	"""
	if dst.exists():
		return

	# Append to the full name (instead of swapping the suffix) so two targets
	# that share a stem can never collide on the same temporary file.
	tmp = dst.with_name(dst.name + ".tmp")
	try:
		with requests.get(url, stream=True, headers={"User-Agent": "python"}, timeout=timeout) as r:
			r.raise_for_status()

			# Without a Content-Length header let tqdm run without a total
			# rather than showing a bogus 0-byte target.
			total = int(r.headers.get("Content-Length", 0)) or None
			with open(tmp, "wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=dst.name) as bar:
				for chunk in r.iter_content(chunk_size=1024*1024):
					if chunk:
						f.write(chunk)
						bar.update(len(chunk))
		# replace() overwrites atomically where the platform allows it and,
		# unlike rename(), does not fail on Windows if dst appeared meanwhile.
		tmp.replace(dst)
	finally:
		# After a successful replace() the temp file is gone and this is a
		# no-op; on any failure or interrupt (including KeyboardInterrupt)
		# it removes the partial download.
		tmp.unlink(missing_ok=True)

def stream_raw_ds(columns=None, shards=None, batch_size_bytes=BATCH_SIZE_BYTES):
	"""Yield the dataset as a stream of Arrow tables of roughly bounded size.

	Each requested shard is fetched into ``DATA_DIR`` as ``NN.parquet`` (via
	``download_shard``) and read one row group at a time. Row groups are
	collected until their combined ``nbytes`` reaches ``batch_size_bytes``, at
	which point they are concatenated and yielded. Collection carries over from
	one shard to the next, and whatever is left at the end is yielded as a
	final, smaller table.

	Args:
		columns: Column names to read from each row group; ``None`` is passed
			through to ``read_row_group`` unchanged.
		shards: Iterable of indices into ``SHARD_URLS``; ``None`` selects every
			shard.
		batch_size_bytes: Size threshold that triggers a yield.

	Yields:
		The concatenation of the row-group tables gathered for each batch.
	"""
	shard_ids = range(len(SHARD_URLS)) if shards is None else shards
	pending, pending_bytes = [], 0

	for shard_id in shard_ids:
		local_path = DATA_DIR / f"{shard_id:02d}.parquet"
		download_shard(SHARD_URLS[shard_id], local_path)
		parquet = pq.ParquetFile(local_path)

		for group_idx in range(parquet.metadata.num_row_groups):
			row_group = parquet.read_row_group(group_idx, columns=columns)
			pending.append(row_group)
			pending_bytes += row_group.nbytes

			# Keep collecting until the threshold is reached.
			if pending_bytes < batch_size_bytes:
				continue

			yield pyarrow.concat_tables(pending)
			pending, pending_bytes = [], 0

	# Flush the remainder that never reached the threshold.
	if pending:
		yield pyarrow.concat_tables(pending)

In [7]:
# Passed to create_bpe as vocab_size: 1 << 15 == 32,768.
VOCAB_SIZE = 1 << 15
# Marker string that stream_texts places in front of every document.
BOS = "<BOS>"
SPECIAL_TOKENS = [BOS]

VOCAB_PATH = EXPERIMENT_DIR / "data/vocab_32k.json"
# NOTE(review): this reads the same file the BPE training cell saves. Presumably
# it fails when that file does not exist yet (first run, before training) —
# confirm how Vocabulary.load behaves and whether the cell order should change.
vocab = Vocabulary.load(VOCAB_PATH)
# Handed to bulk_encode as output_dir; the shard_*.bin files are read back from here.
OUTPUT_DIR = EXPERIMENT_DIR / "data/encoded"

def stream_texts(shards=None):
	"""Yield one UTF-8 encoded blob of BOS-prefixed documents per raw batch.

	Only the ``text`` column is read. Within a batch every document is preceded
	by ``BOS``, so each yielded blob begins with the marker and documents are
	separated by it.

	Args:
		shards: Shard indices forwarded to ``stream_raw_ds``; ``None`` means all.

	Yields:
		``bytes`` holding the concatenated documents of one batch.
	"""
	for table in stream_raw_ds(columns=["text"], shards=shards):
		documents = table["text"].to_pylist()
		# Leading BOS for the first document, BOS as separator for the rest.
		blob = BOS + BOS.join(documents)
		yield blob.encode('utf-8')

In [None]:
# Train a BPE vocabulary (TokenizationMode.BYTES) on the text of shards 0 and 1.
vocab = create_bpe(
	data_iter=stream_texts(shards=[0, 1]),
	vocab_size=VOCAB_SIZE,
	mode=TokenizationMode.BYTES,
	special_tokens=SPECIAL_TOKENS,
)

# Save to VOCAB_PATH rather than repeating the path literal, so the file written
# here is by construction the one that Vocabulary.load and bulk_encode read.
vocab.save(VOCAB_PATH)

preprocessing: 0shard [00:00, ?shard/s]

01.parquet: 100%|██████████| 2.15G/2.15G [01:12<00:00, 29.7MB/s]
preprocessing: 65shard [02:43,  2.51s/shard]


starting merging...


BPE Training: 100%|██████████| 32511/32511 [01:08<00:00, 476.79it/s] 


In [8]:
# Tokenize the text of shard 0 with the loaded vocabulary. The recorded output
# below shows the result being written to OUTPUT_DIR as shard_XXXX.bin files.
# NOTE(review): split_token=BOS presumably lets bulk_encode cut output shards at
# document boundaries (the boundary check later in the notebook is consistent
# with that) — confirm in toy_transformers.tokenization.
bulk_encode(
	doc_iter=stream_texts(shards=[0]),
	vocab=vocab,
	vocab_path=VOCAB_PATH,
	output_dir=OUTPUT_DIR,
	split_token=BOS
)

encoding: 5chunk [00:12,  1.50s/chunk]

wrote 99,983,250 tokens to shard_0000.bin


encoding: 9chunk [00:14,  1.49chunk/s]

wrote 99,999,435 tokens to shard_0001.bin


encoding: 13chunk [00:20,  1.66s/chunk]

wrote 99,999,023 tokens to shard_0002.bin


encoding: 18chunk [00:23,  1.14chunk/s]

wrote 99,999,170 tokens to shard_0003.bin


encoding: 22chunk [00:26,  1.40chunk/s]

wrote 99,981,093 tokens to shard_0004.bin


encoding: 26chunk [00:30,  1.13chunk/s]

wrote 99,999,188 tokens to shard_0005.bin


encoding: 31chunk [00:34,  1.32chunk/s]

wrote 99,998,612 tokens to shard_0006.bin


encoding: 33chunk [00:34,  1.04s/chunk]

wrote 51,858,617 tokens to shard_0007.bin





In [9]:
from toy_transformers.tokenization import _read_shard

# Spot-check one encoded shard: load it and show its shape, dtype and first ids.
shard = _read_shard(OUTPUT_DIR / "shard_0003.bin")
print(f"Shard shape: {shard.shape}, dtype: {shard.dtype}")
print(f"First 20 token IDs: {shard[:20]}")

# Decode the first 200 tokens back to text. vocab.decode returns a sequence of
# byte strings that are joined here; errors="replace" keeps the decode from
# failing if the 200-token cut lands inside a multi-byte UTF-8 character.
decoded = vocab.decode(shard[:200].tolist())
text = b"".join(decoded).decode("utf-8", errors="replace")
print(f"\n=== First 200 tokens decoded ===\n{text}")

# Locate BOS tokens. stream_texts puts one BOS in front of every document, so
# the number of BOS positions is reported as the number of documents.
bos_id = vocab.token_to_idx[b"<BOS>"]
bos_positions = (shard == bos_id).nonzero()[0]
print(f"\nBOS id: {bos_id}")
print(f"Number of documents in shard: {len(bos_positions)}")
print(f"First 5 BOS positions: {bos_positions[:5]}")

Shard shape: (99999170,), dtype: uint16
First 20 token IDs: [    0    83   323  9422  4949   288   477  1187   387  4536    11    46
 22609  2385   288  1205  2917   281  1699  8552]

=== First 200 tokens decoded ===
<BOS>Rotavirus vaccine and intussusception
- Inform parents and carers of young infants receiving rotavirus vaccine of the rare risk of intussusception following the vaccine and how to be alert to the signs and symptoms of the condition.
- Do not give rotavirus vaccine outside the recommended age limits.
- Do not give rotavirus vaccine to a baby with a history of intussusception.
- Report any cases of intussusception following rotavirus vaccination through the usual reporting arrangements for adverse events following immunisation in your State and Territory.
Risk of intussusception
- There is new evidence from Australian and overseas studies suggesting a small increased risk of intussusception in infants following rotavirus vaccination.
- The increased risk appears to occu

In [10]:
# Cell - Verify document integrity across shards
import numpy as np
from toy_transformers.tokenization import _read_shard

shard_dir = OUTPUT_DIR
shard_files = sorted(shard_dir.glob("shard_*.bin"))
bos_id = vocab.token_to_idx[b"<BOS>"]

print(f"BOS id: {bos_id}")
print(f"Found {len(shard_files)} shards\n")

# 1) Per-shard summary: token count, number of BOS tokens (reported as the
#    document count, since stream_texts prefixes every document with BOS) and
#    whether the shard opens with BOS. An empty shard is reported as not
#    starting with BOS.
for path in shard_files:
    shard = _read_shard(path)
    starts_with_bos = shard[0] == bos_id if len(shard) > 0 else False
    bos_count = np.sum(shard == bos_id)
    print(f"{path.name}: {len(shard):>12,} tokens, {bos_count:>6,} docs, starts_with_bos={starts_with_bos}")

# 2) Check cross-shard boundaries: shard N+1 starting with BOS means shard N
#    ended at a document boundary. The previously loaded shard is carried over
#    to the next iteration so each file is read once here instead of twice.
print("\n=== Cross-shard boundary check ===")
issues = 0
cur = _read_shard(shard_files[0]) if len(shard_files) > 1 else None
for i, nxt_path in enumerate(shard_files[1:]):
    nxt = _read_shard(nxt_path)

    # An empty follow-up shard cannot continue a document, so it is not an issue.
    next_starts_bos = nxt[0] == bos_id if len(nxt) > 0 else True
    if not next_starts_bos:
        # decode tokens around the boundary to see what happened
        tail = vocab.decode(cur[-10:].tolist())
        head = vocab.decode(nxt[:10].tolist())
        tail_str = b"".join(tail).decode("utf-8", errors="replace")
        head_str = b"".join(head).decode("utf-8", errors="replace")
        print(f"SPLIT at shard {i}/{i+1}:")
        print(f"  tail: ...{tail_str!r}")
        print(f"  head: {head_str!r}...")
        issues += 1

    cur = nxt

if issues == 0:
    print("All shard boundaries align with document boundaries.")
else:
    print(f"\n{issues} document(s) split across shards!")

# 3) Check stream_texts: verify that the first few chunks start with BOS.
print("\n=== stream_texts BOS check ===")
for i, chunk in enumerate(stream_texts(shards=[0])):
    # Only the start of each chunk is inspected. Chunks carry no trailing BOS;
    # that is fine because the following chunk begins with one.
    starts_bos = chunk.startswith(b"<BOS>")
    print(f"  chunk {i}: {len(chunk):,} bytes, starts_with_BOS={starts_bos}")
    if i >= 4:
        # i is the index of the last chunk shown, not a count of remaining
        # chunks, so report how many were inspected before stopping.
        print(f"  ... (stopped after {i + 1} chunks)")
        break

BOS id: 0
Found 8 shards

shard_0000.bin:   99,983,250 tokens, 21,437 docs, starts_with_bos=True
shard_0001.bin:   99,999,435 tokens, 21,952 docs, starts_with_bos=True
shard_0002.bin:   99,999,023 tokens, 21,885 docs, starts_with_bos=True
shard_0003.bin:   99,999,170 tokens, 21,447 docs, starts_with_bos=True
shard_0004.bin:   99,981,093 tokens, 21,560 docs, starts_with_bos=True
shard_0005.bin:   99,999,188 tokens, 21,780 docs, starts_with_bos=True
shard_0006.bin:   99,998,612 tokens, 21,738 docs, starts_with_bos=True
shard_0007.bin:   51,858,617 tokens, 10,930 docs, starts_with_bos=True

=== Cross-shard boundary check ===
All shard boundaries align with document boundaries.

=== stream_texts BOS check ===
  chunk 0: 107,504,256 bytes, starts_with_BOS=True
  chunk 1: 105,618,284 bytes, starts_with_BOS=True
  chunk 2: 107,026,664 bytes, starts_with_BOS=True
  chunk 3: 105,112,500 bytes, starts_with_BOS=True
  chunk 4: 107,668,853 bytes, starts_with_BOS=True
  ... (4 more chunks)
