In [13]:
"""
Streaming Language Modeling Data Pipeline with Hugging Face Datasets
--------------------------------------------------------------------
Goal:
    Demonstrate how to build a *true streaming* LM pipeline that:
    - Processes data without loading the entire dataset into RAM.
    - Tokenizes on the fly.
    - Concatenates text and chunks into fixed-length blocks for LM training.
    - Produces batches ready for training in PyTorch.

Key Teaching Points:
    1. Streaming allows us to work with web-scale corpora.
    2. We still can do grouping/chunking in a rolling fashion.
    3. This approach mimics real-world pipelines for large-scale LM training.
"""
# The leading string literal above is picked up as ``__doc__`` for this cell;
# echo it so the notebook opens with its overview.
print(__doc__, end="\n")


Streaming Language Modeling Data Pipeline with Hugging Face Datasets
--------------------------------------------------------------------
Goal:
    Demonstrate how to build a *true streaming* LM pipeline that:
    - Processes data without loading the entire dataset into RAM.
    - Tokenizes on the fly.
    - Concatenates text and chunks into fixed-length blocks for LM training.
    - Produces batches ready for training in PyTorch.

Key Teaching Points:
    1. Streaming allows us to work with web-scale corpora.
    2. We still can do grouping/chunking in a rolling fashion.
    3. This approach mimics real-world pipelines for large-scale LM training.



In [5]:
# Uncomment to install the dependencies into the running kernel's environment:
# %pip install datasets transformers torch

In [21]:
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import IterableDataset, DataLoader
import torch

In [14]:
# ============================================================
# 1. Load the dataset in STREAMING mode
# ============================================================
# With streaming=True, load_dataset hands back an IterableDataset: records are
# produced as you iterate rather than all being held in memory at once.
stream_dataset = load_dataset(
    path="wikitext",
    name="wikitext-2-raw-v1",
    split="train",
    streaming=True,
)

In [15]:
# ============================================================
# 2. Initialize the tokenizer
# ============================================================
# For GPT-2, there is no pad token by default, so we set pad_token = eos_token.
# The assignment is guarded so that swapping in a checkpoint that already
# defines its own pad token does not silently overwrite it.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [16]:
# ============================================================
# 3. Tokenization step
# ============================================================
# We do NOT pad/truncate here — we want raw token sequences.
# This keeps flexibility to later concatenate across documents.
def tokenize_function(examples):
    """Tokenize a batch of raw examples without padding or truncation.

    Args:
        examples: A batch passed by ``map(..., batched=True)`` — presumably a
            mapping of column name to list of values; only ``"text"`` is read.

    Returns:
        The tokenizer's encoding of the batch; the downstream grouping step
        reads its ``"input_ids"`` entry.
    """
    return tokenizer(examples["text"])

# Map tokenization lazily over the streaming dataset. The raw "text" column is
# dropped because the grouping step only consumes token ids, so there is no
# reason to carry the original strings through the rest of the pipeline.
tokenized_stream = stream_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)

In [20]:
# ============================================================
# 4. Rolling buffer for grouping into fixed-length blocks
# ============================================================
# Because streaming datasets are iterators, we can't look ahead arbitrarily.
# We'll keep a buffer that stores leftover tokens from the previous example,
# so we can concatenate and chunk consistently.
block_size = 128

def group_texts_streaming(dataset_iter, block_size):
    """Concatenate token streams and yield fixed-length LM training blocks.

    Token ids from successive examples are appended to a rolling buffer; each
    time the buffer holds at least ``block_size`` tokens, one block is emitted.
    Leftover tokens that never fill a complete block are dropped once
    ``dataset_iter`` is exhausted.

    Args:
        dataset_iter: Iterable of mappings, each with an ``"input_ids"`` list
            of token ids.
        block_size: Number of tokens per emitted block; must be >= 1.

    Yields:
        dict with ``"input_ids"`` (``block_size`` token ids) and
        ``"attention_mask"`` (``block_size`` ones — blocks are never padded).

    Raises:
        ValueError: if ``block_size`` < 1 (raised on the first ``next()``),
            which would otherwise loop forever yielding empty blocks.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be a positive integer, got {block_size!r}")

    buffer = []
    for example in dataset_iter:
        buffer.extend(example["input_ids"])
        # Walk an offset over the buffer instead of re-slicing the remainder
        # after every block; re-slicing copies the whole tail each time and is
        # quadratic when a single example spans many blocks.
        start = 0
        while len(buffer) - start >= block_size:
            yield {
                "input_ids": buffer[start:start + block_size],
                "attention_mask": [1] * block_size,
            }
            start += block_size
        # Keep only the not-yet-emitted tail for the next example.
        del buffer[:start]


In [23]:
# ============================================================
# 5. Wrap generator in an IterableDataset
# ============================================================
class StreamingLMIterableDataset(IterableDataset):
    """Expose the rolling block generator as a PyTorch ``IterableDataset``.

    Args:
        hf_iterable_dataset: The tokenized source to regroup; each item is
            expected to carry an ``"input_ids"`` list.
        block_size: Number of tokens per emitted block.
    """

    def __init__(self, hf_iterable_dataset, block_size):
        self.block_size = block_size
        self.dataset = hf_iterable_dataset

    def __iter__(self):
        # A new grouping generator is built on every call, so each pass starts
        # its own regrouping of ``self.dataset`` (assuming the wrapped dataset
        # is itself re-iterable).
        blocks = group_texts_streaming(self.dataset, self.block_size)
        return blocks

grouped_iterable_dataset = StreamingLMIterableDataset(
    hf_iterable_dataset=tokenized_stream,
    block_size=block_size,
)

In [24]:
# ============================================================
# 6. Collate function for batches
# ============================================================
def collate_fn(batch):
    """Stack a list of equal-length examples into LM training tensors.

    Args:
        batch: List of dicts, each with ``"input_ids"`` and ``"attention_mask"``
            lists of the same length.

    Returns:
        dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"`` as
        ``torch.long`` tensors; ``labels`` is an independent copy of
        ``input_ids``.
    """
    def _stack(field):
        # One row per example; torch.tensor requires all rows to be equal length.
        return torch.tensor([item[field] for item in batch], dtype=torch.long)

    ids = _stack("input_ids")
    return {
        "input_ids": ids,
        "attention_mask": _stack("attention_mask"),
        # Separate storage so in-place edits to labels never touch the inputs.
        "labels": ids.clone(),
    }

In [25]:
# ============================================================
# 7. DataLoader for streaming data
# ============================================================
# Number of fixed-length blocks per batch, so each batch's input_ids tensor is
# (batch_size, block_size). No shuffling is configured: blocks are batched in
# the order the stream produces them.
batch_size = 8
train_loader = DataLoader(grouped_iterable_dataset, batch_size=batch_size, collate_fn=collate_fn)

In [None]:
# ============================================================
# 8. Iterate over a few batches
# ============================================================
print("Sample streaming batches:")
# zip stops as soon as range(3) runs out, so exactly the first three batches
# are pulled and the rest of the stream is never consumed.
for i, batch in zip(range(3), train_loader):
    print(f"Batch {i} -> input_ids shape: {batch['input_ids'].shape}")

Sample streaming batches:
Batch 0 -> input_ids shape: torch.Size([8, 128])
Batch 1 -> input_ids shape: torch.Size([8, 128])
Batch 2 -> input_ids shape: torch.Size([8, 128])
