In [1]:
import os
import hashlib
from typing import List, Dict

from datasets import load_dataset
from transformers import AutoTokenizer
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import re

# Output directory for artifacts written by this notebook; created if missing.
DATA_DIR = "./processed_data"
os.makedirs(DATA_DIR, exist_ok=True)

# Name handed to AutoTokenizer.from_pretrained when the tokenizer is loaded.
TOKENIZER_NAME = "gpt2"
# Passed as `block_size` to tokenize_and_chunk: upper bound on ids per block.
BLOCK_SIZE = 512
# Cleaned documents with fewer whitespace-separated words than this are skipped.
MIN_WORDS = 50

# Overall corpus budget. NOTE: the collector compares this against len(text),
# i.e. it counts characters rather than encoded bytes, so "GB" is approximate.
TARGET_SIZE_GB = 1.5
TARGET_CHARS_TOTAL = int(TARGET_SIZE_GB * 1e9)

# Fraction of TARGET_CHARS_TOTAL allotted to each source. The keys are used to
# index the dict returned by load_streaming_corpora(), so they must match it.
SOURCE_FRACTIONS = {
    "wikipedia": 0.4,
    "fineweb": 0.4,
    "ag_news": 0.2,
}

# Destination of the example padded batch written with torch.save.
SAMPLE_PT_PATH = os.path.join(DATA_DIR, "sample_dataset.pt")


In [2]:
def load_streaming_corpora():
    """
    Open the three text sources and return them in a dict keyed by source name.

    Wikipedia and FineWeb-Edu are opened with `streaming=True`; AG News is
    loaded with a regular `load_dataset` call and then converted with
    `to_iterable_dataset()` so that all three values can be iterated the
    same way by the caller.

    Returns:
        dict with the keys "wikipedia", "fineweb" and "ag_news" (in that
        insertion order), each mapping to an iterable dataset.
    """
    corpora = {}

    print("Loading Wikipedia (wikimedia/wikipedia, streaming)...")
    corpora["wikipedia"] = load_dataset(
        "wikimedia/wikipedia", "20231101.en", split="train", streaming=True
    )

    print("Loading FineWeb-Edu (HuggingFaceFW/fineweb-edu, streaming)...")
    corpora["fineweb"] = load_dataset(
        "HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train", streaming=True
    )

    print("Loading AG News (SetFit/ag_news, regular load)...")
    ag_news_train = load_dataset("SetFit/ag_news", split="train")
    corpora["ag_news"] = ag_news_train.to_iterable_dataset()

    return corpora

# Open every source once; the bare name on the last line displays the dict.
streaming_corpora = load_streaming_corpora()
streaming_corpora



Loading Wikipedia (wikimedia/wikipedia, streaming)...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Loading FineWeb-Edu (HuggingFaceFW/fineweb-edu, streaming)...


README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/2410 [00:00<?, ?it/s]

Loading AG News (SetFit/ag_news, regular load)...


train.jsonl:   0%|          | 0.00/33.8M [00:00<?, ?B/s]

test.jsonl:   0%|          | 0.00/2.13M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]

{'wikipedia': IterableDataset({
     features: ['id', 'url', 'title', 'text'],
     num_shards: 41
 }),
 'fineweb': IterableDataset({
     features: ['text', 'id', 'dump', 'url', 'file_path', 'language', 'language_score', 'token_count', 'score', 'int_score'],
     num_shards: 14
 }),
 'ag_news': IterableDataset({
     features: ['text', 'label', 'label_text'],
     num_shards: 1
 })}

In [3]:
# Matches any run of whitespace (spaces, tabs, newlines, ...).
whitespace_re = re.compile(r"\s+")


def clean_text(text: str) -> str:
    """
    Normalize one raw document for downstream filtering and deduplication.

    Steps, in order: lowercase the text, replace every `<...>` span with a
    single space, collapse each whitespace run to one space, and strip
    leading/trailing whitespace.

    Args:
        text: Raw document text, or None.

    Returns:
        The cleaned text; an empty string when `text` is None.
    """
    if text is None:
        return ""

    # Tags become a space (not "") so that words on either side stay separated.
    without_tags = re.sub(r"<[^>]+>", " ", text.lower())
    return whitespace_re.sub(" ", without_tags).strip()


In [4]:
def collect_docs_to_target_size(streaming_corpora, target_chars_total: int) -> List[str]:
    """
    Stream from multiple datasets and collect cleaned, deduplicated documents
    until roughly `target_chars_total` characters have been gathered.

    Each source listed in SOURCE_FRACTIONS receives a quota of
    `int(target_chars_total * fraction)` characters. Sources are consumed one
    after another; a source stops as soon as its quota is met or its data runs
    out. A shortfall in one source is NOT redistributed to the others, so the
    final total can be below `target_chars_total`; a warning is printed when
    that happens.

    Args:
        streaming_corpora: Mapping from source name to an iterable of rows.
            Each row must support `.get("text", "")`.
        target_chars_total: Overall budget, measured in characters of cleaned
            text (len(text)), not bytes.

    Returns:
        The cleaned documents, in the order they were collected.

    Raises:
        KeyError: If a source named in SOURCE_FRACTIONS is missing from
            `streaming_corpora`.
    """
    seen_hashes = set()
    cleaned_docs: List[str] = []

    total_chars_global = 0

    for source_name, fraction in SOURCE_FRACTIONS.items():
        ds = streaming_corpora[source_name]
        target_chars_source = int(target_chars_total * fraction)
        chars_source = 0

        print(f"\nCollecting from {source_name} (target ~{target_chars_source / 1e6:.1f}M chars)...")

        for row in tqdm(ds, desc=f"{source_name} streaming"):
            # All configured sources keep the document body under "text",
            # so no per-source branching is needed here.
            text = clean_text(row.get("text", ""))
            if not text:
                continue

            if len(text.split()) < MIN_WORDS:
                continue

            # Dedup key is the first 2000 cleaned characters: documents that
            # share that prefix are treated as duplicates even if they differ
            # later. The hash set is shared across sources.
            key_str = text[:2000]
            h = hashlib.md5(key_str.encode("utf-8")).hexdigest()
            if h in seen_hashes:
                continue
            seen_hashes.add(h)

            cleaned_docs.append(text)
            chars = len(text)
            chars_source += chars
            total_chars_global += chars

            if chars_source >= target_chars_source:
                print(f"Reached target for {source_name}: {chars_source / 1e6:.1f}M chars.")
                break
        else:
            # The loop ended without `break`: the source ran out of rows.
            if chars_source < target_chars_source:
                print(
                    f"WARNING: {source_name} exhausted before reaching its target "
                    f"({chars_source / 1e6:.1f}M of {target_chars_source / 1e6:.1f}M chars); "
                    "the shortfall is not redistributed to other sources."
                )

        print(f"{source_name}: collected {chars_source / 1e6:.1f}M chars.")

    print(f"\nTOTAL collected chars: {total_chars_global} → ~{total_chars_global / 1e9:.2f} GB")
    return cleaned_docs

# Build the corpus against the configured character budget; the last line
# displays how many documents were kept.
cleaned_docs = collect_docs_to_target_size(streaming_corpora, TARGET_CHARS_TOTAL)
len(cleaned_docs)



Collecting from wikipedia (target ~600.0M chars)...


wikipedia streaming: 133404it [00:56, 2374.14it/s]


Reached target for wikipedia: 600.0M chars.
wikipedia: collected 600.0M chars.

Collecting from fineweb (target ~600.0M chars)...


fineweb streaming: 125865it [01:10, 1783.97it/s]


Reached target for fineweb: 600.0M chars.
fineweb: collected 600.0M chars.

Collecting from ag_news (target ~300.0M chars)...


ag_news streaming: 120000it [00:14, 8185.77it/s]

ag_news: collected 3.5M chars.

TOTAL collected chars: 1203536778 → ~1.20 GB





262042

In [5]:
def tokenize_and_chunk(docs: List[str], tokenizer, block_size: int) -> List[List[int]]:
    """
    Tokenize each document and split its token ids into consecutive blocks.

    Every document is encoded on its own with `add_special_tokens=False`, and
    its ids are sliced into windows of `block_size`. Blocks never span two
    documents, so the final block of a document may be shorter than
    `block_size`. A document that encodes to zero tokens contributes no block.

    Args:
        docs: Cleaned document strings.
        tokenizer: Object exposing `encode(text, add_special_tokens=False)`
            that returns a list of token ids.
        block_size: Maximum number of token ids per block; must be positive.

    Returns:
        A flat list of blocks, each a non-empty list of at most `block_size`
        token ids, in document order.

    Raises:
        ValueError: If `block_size` is not positive.
    """
    # Fail fast: a negative step would silently yield no blocks at all, and a
    # zero step would only fail once the first document is reached.
    if block_size <= 0:
        raise ValueError(f"block_size must be a positive integer, got {block_size}")

    token_blocks: List[List[int]] = []

    for doc in tqdm(docs, desc="Tokenizing & chunking"):
        ids = tokenizer.encode(doc, add_special_tokens=False)

        # Each start index is < len(ids), so every slice is non-empty.
        token_blocks.extend(
            ids[start : start + block_size] for start in range(0, len(ids), block_size)
        )

    print(f"Total token blocks: {len(token_blocks)}")
    return token_blocks

# Load the fast tokenizer named in the config cell.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, use_fast=True)

# Padding needs a pad token; when the tokenizer defines none, reuse EOS.
# Consequence: pad_token_id and eos_token_id are then the same id.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Tokenize the whole corpus into blocks; the last line displays the block count.
token_blocks = tokenize_and_chunk(cleaned_docs, tokenizer, BLOCK_SIZE)
len(token_blocks)


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Tokenizing & chunking:   0%|          | 0/262042 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (8524 > 1024). Running this sequence through the model will result in indexing errors
Tokenizing & chunking: 100%|██████████| 262042/262042 [17:27<00:00, 250.07it/s] 

Total token blocks: 662834





662834

In [6]:
class TokenBlockDataset(Dataset):
    """
    Map-style dataset over pre-tokenized blocks.

    Each item is a dict with a single key, "input_ids", holding the block's
    token ids as a 1-D `torch.long` tensor. Blocks may differ in length, so
    batching requires a padding collate function.
    """

    def __init__(self, token_blocks: List[List[int]]):
        # The list is stored by reference, not copied.
        self.token_blocks = token_blocks

    def __len__(self):
        """Return the number of token blocks."""
        return len(self.token_blocks)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Return block `idx` as {"input_ids": LongTensor}."""
        block_tensor = torch.tensor(self.token_blocks[idx], dtype=torch.long)
        return {"input_ids": block_tensor}


In [7]:
def collate_fn_pad(batch, pad_token_id: int):
    """
    Right-pad a batch of variable-length `input_ids` to a common length and
    build the matching attention mask.

    The mask is derived from each sequence's true length, not by comparing
    values against `pad_token_id`. This matters whenever the pad id doubles as
    a real token (e.g. when the pad token is set to the EOS token): genuine
    occurrences of that id inside a sequence must stay attended.

    Args:
        batch: Sequence of dicts, each with an "input_ids" 1-D tensor.
        pad_token_id: Id written into the padded positions.

    Returns:
        dict with
          - "input_ids": LongTensor of shape (len(batch), max_len)
          - "attention_mask": LongTensor of the same shape, 1 for real tokens
            and 0 for padding.

    Raises:
        ValueError: If `batch` is empty.
    """
    input_ids = [item["input_ids"] for item in batch]
    if not input_ids:
        raise ValueError("collate_fn_pad received an empty batch")

    lengths = [len(ids) for ids in input_ids]
    max_len = max(lengths)

    # Start fully padded / fully masked, then copy the real tokens in.
    padded_ids = torch.full((len(input_ids), max_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((len(input_ids), max_len), dtype=torch.long)

    for row, (ids, seq_len) in enumerate(zip(input_ids, lengths)):
        padded_ids[row, :seq_len] = ids
        attention_mask[row, :seq_len] = 1

    return {
        "input_ids": padded_ids,
        "attention_mask": attention_mask,
    }


In [8]:
token_dataset = TokenBlockDataset(token_blocks)

# Seed the shuffle with a dedicated generator so that a fresh
# Restart & Run All draws (and saves) the same sample batch every time.
SHUFFLE_SEED = 42
shuffle_generator = torch.Generator().manual_seed(SHUFFLE_SEED)

dataloader = DataLoader(
    token_dataset,
    batch_size=8,
    shuffle=True,
    generator=shuffle_generator,
    # Blocks vary in length, so each batch is padded to its longest block.
    collate_fn=lambda b: collate_fn_pad(b, tokenizer.pad_token_id),
)

# Pull a single batch as a smoke test of the dataset + collate pipeline.
sample_batch = next(iter(dataloader))

print("input_ids shape:", sample_batch["input_ids"].shape)
print("attention_mask shape:", sample_batch["attention_mask"].shape)

# Persist the example batch (a dict of tensors) for later inspection.
torch.save(sample_batch, SAMPLE_PT_PATH)
print(f"Saved sample batch to {SAMPLE_PT_PATH}")


input_ids shape: torch.Size([8, 512])
attention_mask shape: torch.Size([8, 512])
Saved sample batch to ./processed_data/sample_dataset.pt


In [9]:
# Sanity check: decode the first sequence of the sample batch back to text.
# skip_special_tokens=True drops special tokens (including any padding) from
# the decoded string; only the first 1000 characters are displayed.
example_ids = sample_batch["input_ids"][0]
decoded = tokenizer.decode(example_ids.tolist(), skip_special_tokens=True)
decoded[:1000]


'david gibbins (born 1962) is an underwater archaeologist and a bestselling novelist. early life gibbins was born in 1962 in saskatoon, saskatchewan, canada, to british parents who were academic scientists. he is related to the victorian historian henry de beltgens gibbins and to brigadier henry john gordon gale, dso and bar. after growing up in canada, new zealand and england he attended the university of bristol, where he was awarded a first class honours degree in ancient mediterranean studies. he spent part of 1984 in turkey funded by a travel scholarship from the british institute of archaeology in ankara. in 1984 he was awarded a research scholarship by corpus christi college, university of cambridge, where he completed a phd in archaeology in 1991. he qualified as a scuba diver in canada at the age of 15, and since then has dived extensively around the world. career academic career from 1991 to 1993 he held a postdoctoral fellowship at the university of cambridge from the canadi