## user-end goals
- [ ] use the fineweb-edu `sample-10BT` dataset
- [ ] join documents with <BOS>
- [ ] cap document length at 2× the context length, splitting longer documents at a reasonable boundary (e.g. a paragraph or sentence end)
## model changes
- [ ] (maybe) add grouped-query attention (GQA)
- [ ] remove bias in linear layers
- [ ] document masking

In [1]:
import requests
import os
import sys
import pyarrow
import pyarrow.parquet as pq
from pathlib import Path
from tqdm import tqdm

def get_project_info() -> tuple[Path, Path]:
	"""Locate the project root relative to the current working directory.

	Starting at the resolved cwd (inclusive) and walking up through its
	parents, the first directory that contains a ``toy_transformers`` entry is
	taken as the project root. If no such directory exists, the cwd itself is
	used as the root.

	Returns:
		``(root, cwd)``, both resolved absolute paths.
	"""
	cwd = Path.cwd().resolve()
	search_path = (cwd, *cwd.parents)
	marked_dirs = (d for d in search_path if (d / "toy_transformers").exists())
	return next(marked_dirs, cwd), cwd

# Resolve project paths only once per kernel session. The os.chdir below moves
# the cwd to the project root, so re-running get_project_info() afterwards
# would report the root as EXPERIMENT_DIR as well.
if 'ROOT_DIR' not in globals():
	ROOT_DIR, EXPERIMENT_DIR = get_project_info()

	# Make the `toy_transformers` package importable from this notebook.
	if str(ROOT_DIR) not in sys.path:
		sys.path.append(str(ROOT_DIR))

	# From here on, relative paths resolve against the project root.
	if Path.cwd() != ROOT_DIR:
		os.chdir(ROOT_DIR)

# Local cache directory for the downloaded parquet shards.
DATA_DIR = ROOT_DIR.joinpath("data", "fineweb-edu", "sample", "10BT")
DATA_DIR.mkdir(exist_ok=True, parents=True)

from toy_transformers.tokenization import create_bpe, bulk_encode, Vocabulary, TokenizationMode

In [2]:
# Default flush threshold for stream_raw_ds: 100 MiB of in-memory Arrow data.
BATCH_SIZE_BYTES = 100 * 1024 * 1024
API_URL = "https://huggingface.co/api/datasets/HuggingFaceFW/fineweb-edu/parquet/sample-10BT/train"

# Fetch the list of parquet shard URLs for the sample-10BT train split. A list
# position is used as the shard id by stream_raw_ds (SHARD_URLS[i]).
# The timeout keeps a stalled connection from blocking the cell forever
# (requests waits indefinitely when no timeout is given).
response = requests.get(API_URL, timeout=60)
response.raise_for_status()
SHARD_URLS = response.json()

def download_shard(url, dst: Path, timeout: float = 60.0):
	"""Download ``url`` to ``dst`` unless ``dst`` already exists.

	The response body is streamed into a sibling ``.tmp`` file, which is moved
	onto ``dst`` only after the transfer completes. An existing ``dst`` can
	therefore be treated as a finished download and is never re-fetched.

	Args:
		url: URL of the file to fetch.
		dst: destination path.
		timeout: seconds passed to ``requests.get``. Without it, a stalled
			connection would hang forever instead of raising.

	Raises:
		requests.HTTPError: on a non-success HTTP status.
		Any exception raised while downloading or writing is re-raised after
		the partial ``.tmp`` file has been removed.
	"""
	if dst.exists():
		return

	tmp = dst.with_suffix(".tmp")
	try:
		with requests.get(url, stream=True, timeout=timeout, headers={"User-Agent": "python"}) as r:
			r.raise_for_status()

			# Progress-bar total; falls back to 0 when the header is absent.
			total = int(r.headers.get("Content-Length", 0))
			with open(tmp, "wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=dst.name) as bar:
				for chunk in r.iter_content(chunk_size=1024 * 1024):
					if chunk:
						f.write(chunk)
						bar.update(len(chunk))
		# replace() overwrites an existing target on every platform; rename()
		# raises on Windows if the target exists.
		tmp.replace(dst)
	except BaseException:
		# BaseException also covers KeyboardInterrupt (stopping a notebook cell).
		# `except Exception` would miss it and leave a multi-GB .tmp file behind.
		tmp.unlink(missing_ok=True)
		raise

def stream_raw_ds(columns=None, shards=None, batch_size_bytes=BATCH_SIZE_BYTES):
	"""Stream the dataset as pyarrow Tables of roughly ``batch_size_bytes`` each.

	Each requested shard is downloaded on demand into DATA_DIR as
	``NN.parquet`` and read one row group at a time. Row groups accumulate
	until their combined ``Table.nbytes`` reaches ``batch_size_bytes``. They
	are then concatenated and yielded as a single Table. Whatever remains at
	the end is yielded as one final, smaller Table.

	Args:
		columns: column names to read (None reads all columns).
		shards: indices into SHARD_URLS to process (None means every shard).
		batch_size_bytes: flush threshold, in bytes of in-memory Arrow data.
	"""
	shard_ids = range(len(SHARD_URLS)) if shards is None else shards

	pending = []
	pending_bytes = 0

	for shard_id in shard_ids:
		local_path = DATA_DIR / f"{shard_id:02d}.parquet"
		download_shard(SHARD_URLS[shard_id], local_path)

		parquet_file = pq.ParquetFile(local_path)
		for group_idx in range(parquet_file.metadata.num_row_groups):
			group = parquet_file.read_row_group(group_idx, columns=columns)
			pending.append(group)
			pending_bytes += group.nbytes

			if pending_bytes < batch_size_bytes:
				continue
			yield pyarrow.concat_tables(pending)
			pending, pending_bytes = [], 0

	if pending:
		yield pyarrow.concat_tables(pending)

In [3]:
VOCAB_SIZE = 1 << 15  # 32,768 entries (the "32k" in VOCAB_PATH)
BOS = "<BOS>"
SPECIAL_TOKENS = [BOS]

VOCAB_PATH = EXPERIMENT_DIR / "data/vocab_32k.json"
# Reuse a previously trained vocabulary when one is on disk. On a fresh
# checkout the file does not exist yet, so leave `vocab` as None instead of
# crashing here. The BPE training cell below creates and saves it.
vocab = Vocabulary.load(VOCAB_PATH) if VOCAB_PATH.exists() else None
OUTPUT_DIR = EXPERIMENT_DIR / "data/encoded"

def stream_texts(shards=None):
	"""Yield one UTF-8 byte string per raw batch from stream_raw_ds.

	Every document in the batch is preceded by the BOS marker: the batch
	starts with BOS and BOS separates consecutive documents. Document
	boundaries are therefore recoverable from the byte stream.
	"""
	for table in stream_raw_ds(columns=["text"], shards=shards):
		documents = table["text"].to_pylist()
		joined = BOS.join(documents)
		yield f"{BOS}{joined}".encode('utf-8')

In [None]:
# Train the BPE tokenizer (TokenizationMode.BYTES) on shards 0 and 1. Persist it
# to VOCAB_PATH, the same file the vocab-loading cell reads, instead of
# repeating the path literal.
vocab = create_bpe(
	data_iter=stream_texts(shards=[0, 1]),
	vocab_size=VOCAB_SIZE,
	mode=TokenizationMode.BYTES,
	special_tokens=SPECIAL_TOKENS,
)

# The experiment's data/ directory may not exist yet on a fresh checkout.
VOCAB_PATH.parent.mkdir(parents=True, exist_ok=True)
vocab.save(VOCAB_PATH)

preprocessing: 0shard [00:00, ?shard/s]

01.parquet: 100%|██████████| 2.15G/2.15G [01:12<00:00, 29.7MB/s]
preprocessing: 65shard [02:43,  2.51s/shard]


starting merging...


BPE Training: 100%|██████████| 32511/32511 [01:08<00:00, 476.79it/s] 


In [4]:
# Tokenize shard 0 with the trained vocabulary and write the token ids into
# OUTPUT_DIR (as shard_NNNN.bin files, per the output below).
# NOTE(review): split_token=BOS presumably tells bulk_encode where documents
# begin, since stream_texts prefixes every document with BOS. Confirm in bulk_encode.
bulk_encode(
	vocab=vocab,
	vocab_path=VOCAB_PATH,
	doc_iter=stream_texts(shards=[0]),
	split_token=BOS,
	output_dir=OUTPUT_DIR,
)

encoding: 5chunk [00:19,  2.69s/chunk]

wrote 99,988,208 tokens to shard_0000.bin


encoding: 9chunk [00:33,  2.83s/chunk]

wrote 99,997,731 tokens to shard_0001.bin


encoding: 13chunk [00:49,  4.81s/chunk]

wrote 99,996,422 tokens to shard_0002.bin


encoding: 18chunk [00:55,  1.87s/chunk]

wrote 99,990,484 tokens to shard_0003.bin


encoding: 22chunk [01:09,  2.33s/chunk]

wrote 99,997,550 tokens to shard_0004.bin


encoding: 26chunk [01:23,  3.42s/chunk]

wrote 99,995,174 tokens to shard_0005.bin


encoding: 33chunk [01:35,  2.89s/chunk]

wrote 99,987,983 tokens to shard_0006.bin
wrote 51,902,347 tokens to shard_0007.bin





In [9]:
from toy_transformers.tokenization import _read_shard

# Sanity-check one encoded shard: size/dtype, a decoded prefix and BOS counts.
N_PREVIEW_TOKENS = 200  # keeps the decode slice and its printed label in sync

shard = _read_shard(OUTPUT_DIR / "shard_0003.bin")
print(f"Shard shape: {shard.shape}, dtype: {shard.dtype}")
print(f"First 20 token IDs: {shard[:20]}")

# Decode a prefix. vocab.decode yields byte pieces, and a token prefix may end
# mid-character, hence errors="replace".
decoded = vocab.decode(shard[:N_PREVIEW_TOKENS].tolist())
text = b"".join(decoded).decode("utf-8", errors="replace")
print(f"\n=== First {N_PREVIEW_TOKENS} tokens decoded ===\n{text}")

# stream_texts prefixes every document with BOS, so BOS positions mark
# document starts. Derive the key from the BOS constant rather than a literal.
bos_id = vocab.token_to_idx[BOS.encode("utf-8")]
bos_positions = (shard == bos_id).nonzero()[0]
print(f"\nBOS id: {bos_id}")
print(f"Number of documents in shard: {len(bos_positions)}")
print(f"First 5 BOS positions: {bos_positions[:5]}")

Shard shape: (99999170,), dtype: uint16
First 20 token IDs: [    0    83   323  9422  4949   288   477  1187   387  4536    11    46
 22609  2385   288  1205  2917   281  1699  8552]

=== First 200 tokens decoded ===
<BOS>Rotavirus vaccine and intussusception
- Inform parents and carers of young infants receiving rotavirus vaccine of the rare risk of intussusception following the vaccine and how to be alert to the signs and symptoms of the condition.
- Do not give rotavirus vaccine outside the recommended age limits.
- Do not give rotavirus vaccine to a baby with a history of intussusception.
- Report any cases of intussusception following rotavirus vaccination through the usual reporting arrangements for adverse events following immunisation in your State and Territory.
Risk of intussusception
- There is new evidence from Australian and overseas studies suggesting a small increased risk of intussusception in infants following rotavirus vaccination.
- The increased risk appears to occu

In [5]:
from toy_transformers.tokenization import shuffle_shards

# Shuffle the encoded shards from OUTPUT_DIR into <experiment>/data/shuffled.
shuffle_shards(OUTPUT_DIR, EXPERIMENT_DIR.joinpath("data", "shuffled"))

shuffling: 162729doc [00:01, 122486.18doc/s]


In [9]:
from toy_transformers.tokenization import _read_shard

# Sanity-check one shuffled shard: size/dtype, a decoded prefix and BOS counts.
# The cell previously decoded 300 tokens while labelling them "First 200".
# A single constant keeps the slice and the label consistent.
N_PREVIEW_TOKENS = 300

shard = _read_shard(EXPERIMENT_DIR / "data/shuffled/shard_0001.bin")
print(f"Shard shape: {shard.shape}, dtype: {shard.dtype}")
print(f"First 20 token IDs: {shard[:20]}")

# Decode a prefix. vocab.decode yields byte pieces, and a token prefix may end
# mid-character, hence errors="replace".
decoded = vocab.decode(shard[:N_PREVIEW_TOKENS].tolist())
text = b"".join(decoded).decode("utf-8", errors="replace")
print(f"\n=== First {N_PREVIEW_TOKENS} tokens decoded ===\n{text}")

# stream_texts prefixes every document with BOS, so BOS positions mark
# document starts. Derive the key from the BOS constant rather than a literal.
bos_id = vocab.token_to_idx[BOS.encode("utf-8")]
bos_positions = (shard == bos_id).nonzero()[0]
print(f"\nBOS id: {bos_id}")
print(f"Number of documents in shard: {len(bos_positions)}")
print(f"First 5 BOS positions: {bos_positions[:5]}")

Shard shape: (94997855,), dtype: uint16
First 20 token IDs: [    0   457 17175 11597    11  2030   499   262  2655    45 27705   288
 24267   284 11597 31772   383   116  2980    45]

=== First 200 tokens decoded ===
<BOS>The Independent Jane
For all the love, romance and scandal in Jane Austen’s books, what they are really about is freedom and independence. Independence of thought and the freedom to choose.
Elizabeth’s refusal of Mr. Collins offer of marriage showed an independence seldom seen in heroines of the day. Her refusal of Mr. Darcy while triggered by anger showed a level of independence that left him shocked and stunned.
The freedom she exhibited in finally accepting him in direct defiance of Lady Catherine and knowing her father would disapprove was unusual even for Austen. In her last book Anne Elliot is persuaded to refuse Captain Wentworth at Lady Russel’s insistence.
Although Jane played by the rules of the day, all of her writing is infused with how she wanted life to 