## user-end goals
- [ ] use the FineWeb-Edu dataset (`sample-10BT` subset)
- [ ] join documents with a `<BOS>` token
- [ ] add a maximum document length (2× the context length); split longer documents at a reasonable point (e.g. a paragraph or sentence boundary)
## model changes
- [ ] (maybe) add GQA (grouped-query attention)
- [ ] remove bias terms from linear layers
- [ ] add document masking (no attention across `<BOS>` document boundaries)

In [4]:
import requests
import os
import sys
import pyarrow
import pyarrow.parquet as pq
from pathlib import Path
from tqdm import tqdm

def get_project_info() -> tuple[Path, Path]:
	"""Find the project root by walking upward from the working directory.

	The root is the nearest directory, starting with the resolved cwd itself,
	that contains a ``toy_transformers`` entry. If no such directory exists,
	the cwd is used as the root.

	Returns:
		``(root, current)``, where ``current`` is the resolved cwd.
	"""
	here = Path.cwd().resolve()
	candidates = (here, *here.parents)
	root = next(
		(candidate for candidate in candidates if (candidate / "toy_transformers").exists()),
		here,
	)
	return root, here

# Resolve paths only on the first execution in this kernel session. The chdir
# below moves the cwd to ROOT_DIR, so re-resolving on a later run would report
# ROOT_DIR as the EXPERIMENT_DIR.
if "ROOT_DIR" not in globals():
	ROOT_DIR, EXPERIMENT_DIR = get_project_info()
	# Make the project package importable from this notebook.
	if str(ROOT_DIR) not in sys.path:
		sys.path.append(str(ROOT_DIR))
	# Work from the project root so relative paths are root-relative.
	if Path.cwd() != ROOT_DIR:
		os.chdir(ROOT_DIR)

# Directory where the downloaded parquet shards are stored.
DATA_DIR = ROOT_DIR.joinpath("data", "fineweb-edu", "sample", "10BT")
DATA_DIR.mkdir(exist_ok=True, parents=True)

from toy_transformers.tokenization import create_bpe, TokenizationMode, Vocabulary

In [3]:
BATCH_SIZE_BYTES = 100 * 1024 * 1024  # default batch-size threshold for stream_raw_ds (100 MiB)
API_URL = "https://huggingface.co/api/datasets/HuggingFaceFW/fineweb-edu/parquet/sample-10BT/train"

# Fetch the list of parquet shard URLs for the sample-10BT train split.
# requests applies no timeout by default, so without one a stalled
# connection would hang this cell indefinitely.
response = requests.get(API_URL, timeout=60)
response.raise_for_status()
SHARD_URLS = response.json()

def download_shard(url, dst: Path, timeout=60):
	"""Download ``url`` to ``dst`` unless ``dst`` already exists.

	The response is streamed into a sibling ``.tmp`` file and moved onto
	``dst`` only after the whole body has been received. An existing ``dst``
	short-circuits the call and is never re-fetched. This staging therefore
	keeps a failed or interrupted transfer from leaving behind a partial
	shard that would be silently reused later.

	Args:
		url: URL to fetch.
		dst: destination file path.
		timeout: seconds passed to ``requests.get`` as its timeout;
			``None`` waits indefinitely.

	Raises:
		requests.HTTPError: if the server returns an error status.
		OSError: if the number of bytes received differs from the advertised
			Content-Length.
		Any other exception raised while downloading propagates unchanged;
		the temp file is removed first.
	"""
	if dst.exists():
		return

	tmp = dst.with_suffix(".tmp")
	try:
		with requests.get(url, stream=True, headers={"User-Agent": "python"}, timeout=timeout) as r:
			r.raise_for_status()

			total = int(r.headers.get("Content-Length", 0))
			# With a Content-Encoding, iter_content yields decoded bytes whose
			# count need not equal Content-Length; only verify the size otherwise.
			check_size = total > 0 and "Content-Encoding" not in r.headers
			written = 0
			with open(tmp, "wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=dst.name) as bar:
				for chunk in r.iter_content(chunk_size=1024 * 1024):
					if chunk:
						f.write(chunk)
						written += len(chunk)
						bar.update(len(chunk))

		if check_size and written != total:
			raise OSError(f"incomplete download of {dst.name}: received {written} of {total} bytes")
		# replace() (unlike rename()) also overwrites an existing target on Windows.
		tmp.replace(dst)
	finally:
		# On any failure, including KeyboardInterrupt, discard the partial file.
		# After a successful replace() tmp no longer exists, so this is a no-op.
		tmp.unlink(missing_ok=True)

def stream_raw_ds(columns=None, shards=None, batch_size_bytes=BATCH_SIZE_BYTES):
	"""Yield the dataset as pyarrow Tables, downloading shards on demand.

	Row groups are read one at a time and accumulated until their combined
	in-memory size (``Table.nbytes``) reaches ``batch_size_bytes``. The
	accumulated tables are then concatenated and yielded. Batches may span
	shard boundaries, and a final, smaller batch holds any remainder.

	Args:
		columns: column names to read; ``None`` reads all columns.
		shards: iterable of indices into ``SHARD_URLS``; ``None`` means every shard.
		batch_size_bytes: size threshold at which an accumulated batch is yielded.

	Yields:
		pyarrow.Table batches.
	"""
	idxs = range(len(SHARD_URLS)) if shards is None else shards

	batch_tables = []
	batch_bytes = 0

	for i in idxs:
		dst = DATA_DIR / f"{i:02d}.parquet"
		download_shard(SHARD_URLS[i], dst)

		# Close each shard deterministically once its row groups are read,
		# or when the consumer stops iterating, instead of relying on GC.
		with pq.ParquetFile(dst) as pf:
			for rg in range(pf.metadata.num_row_groups):
				table = pf.read_row_group(rg, columns=columns)
				batch_tables.append(table)
				batch_bytes += table.nbytes

				if batch_bytes >= batch_size_bytes:
					yield pyarrow.concat_tables(batch_tables)
					batch_tables = []
					batch_bytes = 0

	if batch_tables:
		yield pyarrow.concat_tables(batch_tables)

In [4]:
VOCAB_SIZE = 2 ** 15  # 32768 tokens
BOS = "<BOS>"  # special token placed between documents in the training stream
SPECIAL_TOKENS = [BOS]

def stream_texts(shards=None):
	"""Yield one UTF-8 ``bytes`` blob per raw dataset batch.

	Within a blob, documents are separated by ``BOS``. No ``BOS`` is added
	before a batch's first document or after its last one.

	Args:
		shards: shard indices forwarded to ``stream_raw_ds``; ``None`` means all.
	"""
	for table in stream_raw_ds(columns=["text"], shards=shards):
		docs = table.column("text").to_pylist()
		joined = BOS.join(docs)
		yield joined.encode("utf-8")

# Train a byte-mode BPE vocabulary on shards 0-4, with BOS registered as a
# special token so it is never merged with ordinary text.
vocab = create_bpe(
	data_iter=stream_texts(shards=[0, 1, 2, 3, 4]),
	vocab_size=VOCAB_SIZE,
	mode=TokenizationMode.BYTES,
	special_tokens=SPECIAL_TOKENS,
)
# EXPERIMENT_DIR/data is not guaranteed to exist at this point on a fresh
# checkout; create it idempotently before saving.
(EXPERIMENT_DIR / "data").mkdir(parents=True, exist_ok=True)
vocab.save(EXPERIMENT_DIR / "data/vocab_32k.json")

preprocessing: 0shard [00:00, ?shard/s]

01.parquet: 100%|██████████| 2.15G/2.15G [00:44<00:00, 48.5MB/s]
02.parquet: 100%|██████████| 2.15G/2.15G [00:46<00:00, 46.1MB/s]
03.parquet: 100%|██████████| 2.15G/2.15G [00:43<00:00, 49.3MB/s]
04.parquet: 100%|██████████| 2.15G/2.15G [00:48<00:00, 44.1MB/s]
preprocessing: 162shard [07:40,  2.84s/shard]


starting merging...


BPE Training: 100%|██████████| 32511/32511 [02:27<00:00, 220.66it/s] 


In [5]:
vocab = Vocabulary.load(EXPERIMENT_DIR.joinpath("data", "vocab_32k.json"))  # same path the training cell saves to

In [39]:
import random

# Only the first batch of shard 0 is needed; pick one document from it at random.
batch = next(stream_raw_ds(columns=["text"], shards=[0]))
docs = batch["text"].to_pylist()
doc = random.choice(docs)

# Use the BOS constant the tokenizer was trained with (SPECIAL_TOKENS) rather
# than a duplicated literal, so this check cannot drift from the training setup.
text = BOS + doc[:500]  # first 500 chars to keep output readable
encoded = vocab.encode(text.encode("utf-8"))
decoded_tokens = vocab.decode(encoded)
# Join the token bytes before decoding: a single token may hold only part of a
# multi-byte UTF-8 character.
reconstructed = b"".join(decoded_tokens).decode("utf-8")

print(f"\n=== Original (first 200 chars) ===\n{text[:200]}")
print(f"\n=== Token IDs (first 30) ===\n{encoded[:30]}")
print("\n=== Decoded tokens (first 30) ===")
for t in decoded_tokens[:30]:
	print(f"  {t.decode('utf-8', errors='replace')!r}")
print(f"\n=== Roundtrip matches: {reconstructed == text}")


=== Original (first 200 chars) ===
<BOS>Water: What You Can Do
What You Can Do
Provides information on how you can get involved including ways to protect human health and the environment by raising awareness about potential threats to 

=== Token IDs (first 30) ===
[0, 11148, 59, 1759, 1227, 1606, 2474, 11, 1604, 1227, 1606, 2474, 11, 16295, 1718, 1019, 334, 656, 355, 411, 858, 2701, 1221, 1912, 289, 1972, 1157, 818, 288, 262]

=== Decoded tokens (first 30) ===
  '<BOS>'
  'Water'
  ':'
  ' What'
  ' You'
  ' Can'
  ' Do'
  '\n'
  'What'
  ' You'
  ' Can'
  ' Do'
  '\n'
  'Prov'
  'ides'
  ' information'
  ' on'
  ' how'
  ' you'
  ' can'
  ' get'
  ' involved'
  ' including'
  ' ways'
  ' to'
  ' protect'
  ' human'
  ' health'
  ' and'
  ' the'

=== Roundtrip matches: True
