In [1]:
"""
Enhanced Streaming Language Modeling Data Pipeline with Hugging Face Datasets
----------------------------------------------------------------------------
Goal:
 Demonstrate an enhanced streaming LM pipeline that:
 - Uses a different dataset (WikiText-103) for larger, more diverse text content
 - Employs DistilGPT-2 tokenizer for faster processing
 - Processes data without loading the entire dataset into RAM
 - Tokenizes on the fly with custom preprocessing
 - Concatenates text and chunks into larger fixed-length blocks (256 tokens)
 - Produces batches ready for training in PyTorch
 - Includes data statistics and memory usage tracking

Key Enhancements:
 1. Different dataset: WikiText-103 (larger) instead of WikiText-2
 2. Different model: DistilGPT-2 instead of GPT-2
 3. Larger block size: 256 tokens for better context
 4. Enhanced monitoring: Statistics and memory tracking
 5. Additional analysis: Token distribution visualization
"""
# Echo the docstring above so the pipeline overview also appears in the cell output.
print(__doc__)



Enhanced Streaming Language Modeling Data Pipeline with Hugging Face Datasets
----------------------------------------------------------------------------
Goal:
 Demonstrate an enhanced streaming LM pipeline that:
 - Uses a different dataset (WikiText-103) for larger, more diverse text content
 - Employs DistilGPT-2 tokenizer for faster processing
 - Processes data without loading the entire dataset into RAM
 - Tokenizes on the fly with custom preprocessing
 - Concatenates text and chunks into larger fixed-length blocks (256 tokens)
 - Produces batches ready for training in PyTorch
 - Includes data statistics and memory usage tracking

Key Enhancements:
 1. Different dataset: WikiText-103 (larger) instead of WikiText-2
 2. Different model: DistilGPT-2 instead of GPT-2
 3. Larger block size: 256 tokens for better context
 4. Enhanced monitoring: Statistics and memory tracking
 5. Additional analysis: Token distribution visualization



In [2]:
# %pip install datasets transformers torch psutil


In [3]:
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import IterableDataset, DataLoader
import torch
import psutil
import os
from collections import Counter


In [4]:
# ============================================================
# 1. Load the dataset in STREAMING mode
# ============================================================
# Using WikiText-103 (larger version) instead of WikiText-2
# This provides more diverse content while maintaining compatibility
# Streaming mode returns an IterableDataset — you can iterate over it
# without having all the data in memory at once.
# Open the WikiText-103 (raw) training split as a stream rather than
# materialising it; every argument is spelled out by keyword for readability.
stream_dataset = load_dataset(
    path="wikitext",
    name="wikitext-103-raw-v1",
    split="train",
    streaming=True,
)


In [5]:
# ============================================================
# 1.5. Sample and analyze dataset characteristics
# ============================================================
# Peek at the first few non-empty records of the stream as a sanity check.
# The loop stops as soon as enough samples are collected, so only the
# records needed for the preview are read.
NUM_SAMPLES = 5        # number of non-empty examples to inspect
PREVIEW_CHARS = 200    # characters of each example shown in the preview

print("Analyzing dataset characteristics...")
sample_count = 0
total_chars = 0
sample_texts = []

for example in stream_dataset:
    # The text content is read from the "text" field (the same field the
    # tokenization step uses); records with empty text are skipped.
    text = example.get("text", "")
    if not text:
        continue

    sample_texts.append(text[:PREVIEW_CHARS])
    total_chars += len(text)  # full length, not just the preview
    sample_count += 1
    if sample_count >= NUM_SAMPLES:
        break

print(f"\nSample texts (first {PREVIEW_CHARS} chars each):")
for i, text in enumerate(sample_texts):
    print(f"\nExample {i+1}: {text}...")

# Guard the average against an empty stream (avoids division by zero).
if sample_count > 0:
    print(f"\nAverage text length (first {sample_count} samples): {total_chars / sample_count:.0f} characters")
else:
    print("\nNo samples found")


Analyzing dataset characteristics...

Sample texts (first 200 chars each):

Example 1:  = Valkyria Chronicles III = 
...

Example 2:  Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ p...

Example 3:  The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adju...

Example 4:  It met with positive sales in Japan , and was praised by both Japanese and western critics . After release , it received downloadable content , along with an expanded edition in November of that year...

Example 5:  = = Gameplay = = 
...

Average text length (first 5 samples): 371 characters


In [6]:
# ============================================================
# 1.6. Memory usage tracking (baseline, before the tokenizer is loaded)
# ============================================================
def get_memory_usage():
    """Return the resident set size (RSS) of the current process.

    Returns:
        float: RSS in MB, computed as bytes / 1024 / 1024.
    """
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024

# This cell runs BEFORE the tokenizer cell, so the reading is a baseline that
# excludes the tokenizer. `initial_memory` is the reference point used later
# to report the memory increase after batches have been processed.
initial_memory = get_memory_usage()
print(f"Baseline memory usage (before tokenizer initialization): {initial_memory:.2f} MB")


Memory usage after tokenizer initialization: 459.79 MB


In [7]:
# ============================================================
# 2. Initialize the tokenizer
# ============================================================
# Using DistilGPT-2 - a distilled version of GPT-2
# Faster and smaller while maintaining good performance
# For DistilGPT-2, there is no pad token by default, so we set pad_token = eos_token.
# NOTE: the grouping step in this notebook only emits full blocks with an
# all-ones attention mask, so this pad token is never inserted into a batch
# by the pipeline itself; it is set so that any padding-aware call still works.
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token


In [8]:
# ============================================================
# 3. Tokenization step
# ============================================================
# We do NOT pad/truncate here — we want raw token sequences.
# This keeps flexibility to later concatenate across documents.
def tokenize_function(examples):
    """Tokenize the ``"text"`` field of a batch of examples.

    The tokenizer is called with no padding or truncation arguments, so the
    token sequences keep their natural lengths and can be concatenated and
    re-chunked by a later step.

    Args:
        examples: Mapping with a ``"text"`` entry holding the texts to encode.

    Returns:
        Whatever the module-level ``tokenizer`` returns for those texts.
    """
    texts = examples["text"]
    return tokenizer(texts)


# Map tokenization lazily over the streaming dataset, one batch of examples
# per call to `tokenize_function`.
tokenized_stream = stream_dataset.map(
    tokenize_function,
    batched=True,
)


In [9]:
# ============================================================
# 4. Rolling buffer for grouping into fixed-length blocks
# ============================================================
# Using larger block size (256) for better context understanding
# This allows models to see longer sequences
# Because streaming datasets are iterators, we can't look ahead arbitrarily.
# We'll keep a buffer that stores leftover tokens from the previous batch,
# so we can concatenate and chunk consistently.
block_size = 256  # Changed from 128 to 256 for better context


def group_texts_streaming(dataset_iter, block_size):
    """Concatenate tokenized examples and yield fixed-length blocks.

    Token ids from consecutive examples are appended to a rolling buffer.
    Whenever the buffer holds at least ``block_size`` tokens, every complete
    block is emitted and only the remainder is carried over to the next
    example. Tokens left in the buffer when the input is exhausted (fewer
    than ``block_size``) are dropped.

    Args:
        dataset_iter: Iterable of mappings, each with an ``"input_ids"``
            sequence of token ids.
        block_size: Number of tokens per emitted block; must be >= 1.

    Yields:
        dict: ``{"input_ids": <list of block_size ids>,
        "attention_mask": [1] * block_size}``.

    Raises:
        ValueError: If ``block_size`` is smaller than 1. Because this is a
            generator, the error surfaces on the first iteration step.
    """
    # A non-positive block size would make the chunking loop spin forever.
    if block_size < 1:
        raise ValueError(f"block_size must be a positive integer, got {block_size}")

    buffer = []
    for example in dataset_iter:
        buffer.extend(example["input_ids"])

        # Number of buffered tokens that fit into complete blocks.
        usable = (len(buffer) // block_size) * block_size
        for start in range(0, usable, block_size):
            yield {
                "input_ids": buffer[start:start + block_size],
                # Fresh list per block so consumers never share a mask object.
                "attention_mask": [1] * block_size,
            }

        # Trim the consumed prefix once per example instead of once per block;
        # re-slicing the tail for every block is quadratic on long inputs.
        if usable:
            del buffer[:usable]


In [10]:
# ============================================================
# 5. Wrap generator in an IterableDataset
# ============================================================
class StreamingLMIterableDataset(IterableDataset):
    """PyTorch ``IterableDataset`` adapter around a tokenized example stream.

    Iterating over an instance yields whatever ``group_texts_streaming``
    produces for the wrapped stream and the configured block size.
    """

    def __init__(self, hf_iterable_dataset, block_size):
        """Store the wrapped stream and the block length; no data is read here."""
        self.block_size = block_size
        self.dataset = hf_iterable_dataset

    def __iter__(self):
        """Start a new grouping pass over the wrapped stream."""
        block_stream = group_texts_streaming(self.dataset, self.block_size)
        return block_stream

# Wrap the tokenized stream so that iterating it yields fixed-length blocks.
grouped_iterable_dataset = StreamingLMIterableDataset(
    hf_iterable_dataset=tokenized_stream,
    block_size=block_size,
)


In [11]:
# ============================================================
# 6. Collate function for batches
# ============================================================
def collate_fn(batch):
    """Stack a list of block dicts into a dict of ``torch.long`` tensors.

    Args:
        batch: List of mappings, each with ``"input_ids"`` and
            ``"attention_mask"`` sequences (equal lengths across the batch).

    Returns:
        dict: ``"input_ids"`` and ``"attention_mask"`` tensors built from the
        per-example sequences, plus ``"labels"`` as an independent copy of
        ``"input_ids"``.
    """
    def _column(key):
        # Gather one field across the whole batch, preserving example order.
        return [example[key] for example in batch]

    input_ids = torch.tensor(_column("input_ids"), dtype=torch.long)
    attention_mask = torch.tensor(_column("attention_mask"), dtype=torch.long)

    collated = {"input_ids": input_ids, "attention_mask": attention_mask}
    # clone() gives labels their own storage, separate from input_ids.
    collated["labels"] = input_ids.clone()
    return collated


In [12]:
# ============================================================
# 7. DataLoader for streaming data
# ============================================================
# Using larger batch size for more efficient training
# Batches are assembled by `collate_fn` from the blocks the dataset yields.
train_loader = DataLoader(
    dataset=grouped_iterable_dataset,
    batch_size=16,  # Changed from 8 to 16
    collate_fn=collate_fn,
)


In [13]:
# ============================================================
# 8. Iterate over batches with enhanced monitoring
# ============================================================
NUM_TEST_BATCHES = 5  # how many batches to pull for this smoke test

print("Sample streaming batches:")
# Report the live configuration rather than hard-coded numbers so the banner
# cannot drift from the actual loader / block settings.
print(
    f"Testing with {NUM_TEST_BATCHES} batches "
    f"(batch_size={train_loader.batch_size}, block_size={block_size})..."
)

batch_shapes = []
token_counts = []

# zip(range(n), loader) stops after n batches without fetching an extra one,
# and simply ends early if the loader runs out first.
for i, batch in zip(range(NUM_TEST_BATCHES), train_loader):
    shape = batch['input_ids'].shape
    n_tokens = batch['input_ids'].numel()
    batch_shapes.append(shape)
    token_counts.append(n_tokens)

    print(f"Batch {i} -> input_ids shape: {shape}")
    print(f"         -> Total tokens: {n_tokens}")
    print(f"         -> attention_mask shape: {batch['attention_mask'].shape}")
    print(f"         -> labels shape: {batch['labels'].shape}")

    # Validation checks
    assert batch['input_ids'].shape == batch['labels'].shape, "Input IDs and labels must have same shape"
    assert batch['input_ids'].shape[1] == block_size, f"Sequence length must be {block_size}"

print("\n" + "="*60)
print("Batch Statistics:")
if batch_shapes:
    # Sample memory once so both figures below describe the same instant.
    current_memory = get_memory_usage()
    print(f"  Average batch size: {sum(s[0] for s in batch_shapes) / len(batch_shapes)}")
    print(f"  Total tokens processed: {sum(token_counts):,}")
    print(f"  Memory usage after processing: {current_memory:.2f} MB")
    print(f"  Memory increase: {current_memory - initial_memory:.2f} MB")
    print("="*60)
    print("\n✅ Test completed successfully! All batches have correct shapes.")
else:
    # Avoid a ZeroDivisionError (and a misleading success banner) when the
    # loader yields nothing.
    print("  No batches were produced; the data stream is empty or too short to fill one block.")
    print("="*60)


Sample streaming batches:
Testing with 5 batches (batch_size=16, block_size=256)...
Batch 0 -> input_ids shape: torch.Size([16, 256])
         -> Total tokens: 4096
         -> attention_mask shape: torch.Size([16, 256])
         -> labels shape: torch.Size([16, 256])
Batch 1 -> input_ids shape: torch.Size([16, 256])
         -> Total tokens: 4096
         -> attention_mask shape: torch.Size([16, 256])
         -> labels shape: torch.Size([16, 256])
Batch 2 -> input_ids shape: torch.Size([16, 256])
         -> Total tokens: 4096
         -> attention_mask shape: torch.Size([16, 256])
         -> labels shape: torch.Size([16, 256])
Batch 3 -> input_ids shape: torch.Size([16, 256])
         -> Total tokens: 4096
         -> attention_mask shape: torch.Size([16, 256])
         -> labels shape: torch.Size([16, 256])
Batch 4 -> input_ids shape: torch.Size([16, 256])
         -> Total tokens: 4096
         -> attention_mask shape: torch.Size([16, 256])
         -> labels shape: torch.Size([1

In [14]:
# ============================================================
# 9. Enhanced validation and token analysis
# ============================================================
print("\n" + "="*60)
print("Enhanced Validation Tests")
print("="*60)

vocab_size = len(tokenizer)

print("\nTokenizer Information:")
print("  Model: DistilGPT-2")
print(f"  Vocab size: {vocab_size:,}")
print(f"  Block size: {block_size}")
# Read the batch size from the loader so the report matches the real setting.
print(f"  Batch size: {train_loader.batch_size}")

# Get a sample batch for analysis (a fresh iterator restarts the stream).
sample_batch = next(iter(train_loader))

print("\nSample Batch Data Types:")
print(f"  input_ids dtype: {sample_batch['input_ids'].dtype}")
print(f"  attention_mask dtype: {sample_batch['attention_mask'].dtype}")
print(f"  labels dtype: {sample_batch['labels'].dtype}")

print("\nSample Batch Value Ranges:")
min_val = sample_batch['input_ids'].min().item()
max_val = sample_batch['input_ids'].max().item()
ids_non_negative = min_val >= 0
ids_in_vocab = max_val < vocab_size
print(f"  input_ids min: {min_val}, max: {max_val}")
print(f"  All values are valid token IDs: {ids_non_negative}")
print(f"  Values within vocab range: {ids_in_vocab}")

# Enforce the checks: previously the success banner was printed even when
# one of the flags above was False.
assert ids_non_negative, f"Found negative token id: {min_val}"
assert ids_in_vocab, f"Token id {max_val} is outside the vocabulary (size {vocab_size})"

# Token frequency analysis
all_tokens = sample_batch['input_ids'].flatten().tolist()
token_freq = Counter(all_tokens)
most_common = token_freq.most_common(10)

print("\nTop 10 Most Frequent Token IDs in Sample Batch:")
for token_id, count in most_common:
    try:
        token = tokenizer.decode([token_id])
    except Exception:
        # Best effort: still report the count when an id cannot be decoded.
        # `Exception` (not a bare except) lets KeyboardInterrupt propagate.
        print(f"  Token ID {token_id:5d}: {count:4d} occurrences")
    else:
        print(f"  Token ID {token_id:5d} ({token:20s}): {count:4d} occurrences")

print("\n" + "="*60)
print("✅ All validation tests passed!")
print("="*60)



Enhanced Validation Tests

Tokenizer Information:
  Model: DistilGPT-2
  Vocab size: 50,257
  Block size: 256
  Batch size: 16

Sample Batch Data Types:
  input_ids dtype: torch.int64
  attention_mask dtype: torch.int64
  labels dtype: torch.int64

Sample Batch Value Ranges:
  input_ids min: 11, max: 49907
  All values are valid token IDs: True
  Values within vocab range: True

Top 10 Most Frequent Token IDs in Sample Batch:
  Token ID   262 ( the                ):  207 occurrences
  Token ID   837 ( ,                  ):  178 occurrences
  Token ID   764 ( .                  ):  128 occurrences
  Token ID   284 ( to                 ):   81 occurrences
  Token ID   286 ( of                 ):   75 occurrences
  Token ID   290 ( and                ):   73 occurrences
  Token ID   257 ( a                  ):   60 occurrences
  Token ID   569 ( V                  ):   53 occurrences
  Token ID 18354 (alky                ):   53 occurrences
  Token ID  7496 (ria                 ):   49 o