# 05 - Data Pipeline: Dataset, DataLoader, and Augmentation

## Learning Objectives

By the end of this notebook, you will:

1. **Understand Dataset internals** - How `torch.utils.data.Dataset` works and when to use map-style vs iterable-style datasets
2. **Master DataLoader mechanics** - Batching, shuffling, collation, and the interaction between components
3. **Implement custom Samplers** - Control data ordering for stratified sampling, curriculum learning, etc.
4. **Optimize data loading** - Multi-worker loading, memory pinning, and prefetching
5. **Design augmentation pipelines** - Using `torchvision.transforms` and custom augmentations

---

## Setup

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import (
    Dataset, 
    IterableDataset,
    DataLoader, 
    Sampler,
    RandomSampler,
    SequentialSampler,
    BatchSampler,
    WeightedRandomSampler,
    SubsetRandomSampler,
    default_collate
)
import torchvision
import torchvision.transforms as T
from torchvision.datasets import MNIST, FashionMNIST
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import time
import os
from collections import Counter
from typing import Iterator, List, Tuple, Any, Optional

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")

# For reproducibility
torch.manual_seed(42)
np.random.seed(42)

---

## 1. Dataset Fundamentals

PyTorch provides two types of datasets:

1. **Map-style datasets** (`Dataset`): Index-based access via `__getitem__` and `__len__`
2. **Iterable-style datasets** (`IterableDataset`): Streaming access via `__iter__`

### 1.1 Map-Style Dataset Internals

A map-style dataset must implement:
- `__getitem__(index)`: Returns the sample at the given index
- `__len__()`: Returns the total number of samples
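
Those two methods are the whole contract. As a minimal sketch in plain Python (the `SquaresDataset` class and the `iterate_batches` helper are hypothetical names, not part of PyTorch), the loop below walks any object that implements them, which is roughly what a loader does before collation:

```python
class SquaresDataset:
    """Hypothetical map-style dataset: sample i is the pair (i, i * i)."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        if not 0 <= index < self.size:
            raise IndexError(index)
        return index, index * index


def iterate_batches(dataset, batch_size):
    """Yield lists of samples using only len(dataset) and dataset[i]."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        yield [dataset[i] for i in range(start, stop)]


batches = list(iterate_batches(SquaresDataset(5), batch_size=2))
print(batches)
```

The final batch holds a single sample because 5 is not a multiple of 2; this is the incomplete batch that `drop_last` controls in a real `DataLoader`.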

In [None]:
# Simple custom dataset example
class SimpleDataset(Dataset):
    """A dataset that returns (x, y) pairs where y = 2x + 1"""
    
    def __init__(self, size: int = 1000):
        self.size = size
        # Generate data once during initialization
        self.x = torch.randn(size, 1)
        self.y = 2 * self.x + 1 + 0.1 * torch.randn(size, 1)  # Add noise
    
    def __len__(self) -> int:
        return self.size
    
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.x[index], self.y[index]


# Create and inspect
dataset = SimpleDataset(100)
print(f"Dataset length: {len(dataset)}")
print(f"First sample: x={dataset[0][0].item():.3f}, y={dataset[0][1].item():.3f}")
print(f"Type of sample: {type(dataset[0])}")

### 1.2 Lazy Loading vs Pre-loading

A critical design decision is **when** to load data:

- **Pre-loading**: Load everything into memory in `__init__` (fast access, high memory)
- **Lazy loading**: Load each sample on-demand in `__getitem__` (slow access, low memory)
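
A quick estimate often settles the choice. The sketch below computes the approximate size of the decoded pixel data for a made-up dataset (50,000 RGB images at 224x224 stored as `uint8`); the `preload_bytes` helper and the numbers are assumptions for illustration, and the real footprint will be somewhat larger because of per-object overhead.

```python
def preload_bytes(num_images, height, width, channels=3, bytes_per_value=1):
    """Rough size of decoded pixel data; ignores Python and PIL object overhead."""
    return num_images * height * width * channels * bytes_per_value


# Hypothetical dataset: 50,000 RGB images of 224x224, one byte per channel value
needed_gib = preload_bytes(50_000, 224, 224) / 2**30
print(f"Approximate RAM for pre-loading: {needed_gib:.1f} GiB")
```

If the estimate is a large fraction of available RAM, lazy loading is usually the safer default.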

In [None]:
class LazyImageDataset(Dataset):
    """Loads images from disk on-demand (lazy loading)"""
    
    def __init__(self, image_paths: List[str], labels: List[int], transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
    
    def __len__(self) -> int:
        return len(self.image_paths)
    
    def __getitem__(self, index: int):
        # Load image from disk each time (lazy)
        image = Image.open(self.image_paths[index]).convert('RGB')
        label = self.labels[index]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label


class PreloadedImageDataset(Dataset):
    """Loads all images into memory upfront (pre-loading)"""
    
    def __init__(self, image_paths: List[str], labels: List[int], transform=None):
        self.transform = transform
        self.labels = labels
        
        # Load all images into memory
        print(f"Pre-loading {len(image_paths)} images...")
        self.images = [Image.open(p).convert('RGB') for p in image_paths]
        print("Done!")
    
    def __len__(self) -> int:
        return len(self.images)
    
    def __getitem__(self, index: int):
        image = self.images[index]
        label = self.labels[index]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label


print("LazyImageDataset: Low memory, disk I/O each access")
print("PreloadedImageDataset: High memory, fast access")
print("\nChoose based on: dataset size vs available RAM")

### 1.3 Iterable-Style Datasets

Use `IterableDataset` when:
- Data comes from a stream (network, database)
- Dataset is too large to index
- Random access is expensive or impossible

In [None]:
class StreamingDataset(IterableDataset):
    """Simulates streaming data that can't be indexed"""
    
    def __init__(self, num_samples: int = 1000):
        self.num_samples = num_samples
    
    def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        # Generate samples on-the-fly
        for _ in range(self.num_samples):
            x = torch.randn(1)
            y = 2 * x + 1 + 0.1 * torch.randn(1)
            yield x, y


# IterableDataset doesn't support __len__ or indexing
stream_dataset = StreamingDataset(100)

# Must iterate to access data
for i, (x, y) in enumerate(stream_dataset):
    if i >= 3:
        break
    print(f"Sample {i}: x={x.item():.3f}, y={y.item():.3f}")

# This would fail:
# stream_dataset[0]  # raises NotImplementedError (no __getitem__ is defined)

### 1.4 Worker-Safe Iterable Datasets

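When multiple workers load an `IterableDataset`, each worker process gets its own copy of the dataset and iterates it independently. Unless each copy restricts itself to a distinct shard, every sample is yielded once per worker.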
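
One way to shard is to give each worker a contiguous range of indices. The sketch below checks that arithmetic in plain Python; `worker_range` is a hypothetical helper, and inside a real `__iter__` the worker id and worker count would come from `torch.utils.data.get_worker_info()`.

```python
def worker_range(num_samples, num_workers, worker_id):
    """Return the [start, end) slice of sample indices owned by one worker."""
    per_worker = num_samples // num_workers
    start = worker_id * per_worker
    end = start + per_worker
    if worker_id == num_workers - 1:
        end = num_samples  # the last worker absorbs the remainder
    return start, end


ranges = [worker_range(20, 3, w) for w in range(3)]
print(ranges)

# Every index should appear exactly once across all workers
covered = [i for start, end in ranges for i in range(start, end)]
print(covered == list(range(20)))
```

With this scheme the last worker can receive noticeably more samples than the others when the remainder is large; an interleaved split (`range(worker_id, num_samples, num_workers)`) is a common alternative that balances the load.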

In [None]:
class WorkerSafeStreamingDataset(IterableDataset):
    """Properly handles multi-worker data loading"""
    
    def __init__(self, num_samples: int = 1000):
        self.num_samples = num_samples
    
    def __iter__(self) -> Iterator:
        # Get worker info
        worker_info = torch.utils.data.get_worker_info()
        
        if worker_info is None:
            # Single-process loading
            start = 0
            end = self.num_samples
        else:
            # Multi-process loading: split work among workers
            per_worker = self.num_samples // worker_info.num_workers
            worker_id = worker_info.id
            start = worker_id * per_worker
            end = start + per_worker
            if worker_id == worker_info.num_workers - 1:
                end = self.num_samples  # Last worker handles remainder
        
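        # Generate only this worker's portion
        for i in range(start, end):
            # Seed a local generator per sample: reproducible, and it leaves
            # the global RNG state untouched
            generator = torch.Generator().manual_seed(i)
            x = torch.randn(1, generator=generator)
            y = 2 * x + 1 + 0.1 * torch.randn(1, generator=generator)
            yield x, y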


# Test with multiple workers
safe_dataset = WorkerSafeStreamingDataset(20)
loader = DataLoader(safe_dataset, batch_size=4, num_workers=2)

all_samples = []
for batch_x, batch_y in loader:
    all_samples.extend(batch_x.tolist())

print(f"Total samples (with 2 workers): {len(all_samples)}")
print(f"Unique samples: {len(set(tuple(s) for s in all_samples))}")

---

## 2. DataLoader Deep Dive

For a map-style dataset, the `DataLoader` orchestrates four steps:
1. **Sampling**: Choosing which indices to load and in what order (via `Sampler`)
2. **Batching**: Grouping those indices into batches (via `BatchSampler`)
3. **Loading**: Fetching each sample from the dataset with `dataset[i]`
4. **Collation**: Combining the loaded samples into a batch (via `collate_fn`)
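
These stages can be reproduced by hand for a map-style dataset. The sketch below is a simplified single-process illustration, not the loader's actual implementation, and the `toy` dataset is a made-up example. It checks that the manual pipeline produces the same batches as a `DataLoader` with matching settings.

```python
import torch
from torch.utils.data import (
    BatchSampler, DataLoader, SequentialSampler, TensorDataset, default_collate
)

# Hypothetical toy dataset: 10 samples with one feature and one label each
toy = TensorDataset(torch.arange(10.0).unsqueeze(1), torch.arange(10))

sampler = SequentialSampler(toy)                                      # sampling
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)  # batching

manual_batches = []
for indices in batch_sampler:
    samples = [toy[i] for i in indices]              # loading
    manual_batches.append(default_collate(samples))  # collation

# The DataLoader should yield identical batches for the same configuration
loader_batches = list(DataLoader(toy, batch_size=4, shuffle=False))
for (mx, my), (lx, ly) in zip(manual_batches, loader_batches):
    assert torch.equal(mx, lx) and torch.equal(my, ly)

print([tuple(b[0].shape) for b in manual_batches])
```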

### 2.1 DataLoader Parameters Explained

In [None]:
# Create a dataset for demonstration
demo_dataset = SimpleDataset(100)

# DataLoader with all major parameters
loader = DataLoader(
    demo_dataset,
    
    # Batching
    batch_size=16,          # Samples per batch
    drop_last=False,        # If True, drop the final incomplete batch
    
    # Sampling
    shuffle=True,           # Randomize order each epoch
    # sampler=None,         # Custom sampler (mutually exclusive with shuffle)
    # batch_sampler=None,   # Custom batch sampler (overrides batch_size, shuffle, etc.)
    
    # Performance
    num_workers=0,          # Parallel data loading processes
    pin_memory=False,       # Pin memory for faster GPU transfer
    # prefetch_factor=2,    # Batches to prefetch per worker; only valid when num_workers > 0
    persistent_workers=False,  # Keep workers alive between epochs
    
    # Collation
    collate_fn=None,        # Custom function to merge samples
)

# Inspect a batch
batch_x, batch_y = next(iter(loader))
print(f"Batch shape: x={batch_x.shape}, y={batch_y.shape}")

### 2.2 Understanding Collation

The `collate_fn` converts a list of samples into a batch. The default behavior:
- Stacks tensors along a new dimension
- Recursively processes tuples, lists, dicts
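
The recursive behavior is easiest to see with dictionary samples. In this small sketch (the `features` and `label` keys are invented for illustration), `default_collate` returns a dictionary with the same keys, where each value has been batched separately:

```python
import torch
from torch.utils.data import default_collate

# Hypothetical samples: each is a dict holding a tensor and a plain Python int
samples = [
    {"features": torch.tensor([1.0, 2.0]), "label": 0},
    {"features": torch.tensor([3.0, 4.0]), "label": 1},
    {"features": torch.tensor([5.0, 6.0]), "label": 0},
]

batch = default_collate(samples)
print(batch["features"].shape)  # tensors are stacked along a new leading dimension
print(batch["label"])           # Python ints become a 1-D integer tensor
```

Stacking only works when every tensor under a given key has the same shape, which is why variable-length data needs a custom `collate_fn`.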

In [None]:
# Default collation behavior
samples = [demo_dataset[i] for i in range(4)]
print("Individual samples:")
for i, (x, y) in enumerate(samples):
    print(f"  Sample {i}: x.shape={x.shape}, y.shape={y.shape}")

# Default collate function stacks them
batch = default_collate(samples)
print(f"\nAfter collation:")
print(f"  Batch x.shape={batch[0].shape}, y.shape={batch[1].shape}")

In [None]:
# Custom collate function for variable-length sequences
class VariableLengthDataset(Dataset):
    """Dataset with variable-length sequences"""
    
    def __init__(self, size: int = 100):
        self.size = size
    
    def __len__(self) -> int:
        return self.size
    
    def __getitem__(self, index: int):
        # Variable length sequence (5 to 15 elements)
        length = 5 + index % 11
        sequence = torch.randn(length)
        label = index % 3
        return sequence, label


def pad_collate(batch):
    """Pad sequences to the longest in the batch"""
    sequences, labels = zip(*batch)
    
    # Find max length
    lengths = [len(seq) for seq in sequences]
    max_len = max(lengths)
    
    # Pad sequences
    padded = torch.zeros(len(sequences), max_len)
    for i, seq in enumerate(sequences):
        padded[i, :len(seq)] = seq
    
    return padded, torch.tensor(labels), torch.tensor(lengths)


var_dataset = VariableLengthDataset(20)
var_loader = DataLoader(var_dataset, batch_size=4, collate_fn=pad_collate)

for padded_seqs, labels, lengths in var_loader:
    print(f"Padded batch shape: {padded_seqs.shape}")
    print(f"Original lengths: {lengths.tolist()}")
    break

### 2.3 Using torch.nn.utils.rnn for Sequences

PyTorch provides utilities for handling variable-length sequences:

In [None]:
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

def advanced_pad_collate(batch):
    """Use PyTorch's pad_sequence for efficient padding"""
    sequences, labels = zip(*batch)
    
    # Sort by length (descending); pack_padded_sequence requires this unless enforce_sorted=False
    sorted_indices = sorted(range(len(sequences)), key=lambda i: len(sequences[i]), reverse=True)
    sequences = [sequences[i] for i in sorted_indices]
    labels = [labels[i] for i in sorted_indices]
    
    lengths = torch.tensor([len(seq) for seq in sequences])
    
    # pad_sequence expects list of tensors, pads to longest
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    
    return padded, torch.tensor(labels), lengths


var_loader2 = DataLoader(var_dataset, batch_size=4, collate_fn=advanced_pad_collate)

for padded_seqs, labels, lengths in var_loader2:
    print(f"Padded shape: {padded_seqs.shape}")
    print(f"Lengths (sorted): {lengths.tolist()}")
    
    # Can use with pack_padded_sequence for RNNs
    packed = pack_padded_sequence(padded_seqs.unsqueeze(-1), lengths.cpu(), batch_first=True)
    print(f"Packed data shape: {packed.data.shape}")
    break

---

## 3. Samplers: Controlling Data Order

Samplers determine **which indices** are loaded and in **what order**.
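
Because a sampler is just an iterable of indices, its order can be made reproducible by handing it a seeded generator. The sketch below assumes a data source of eight items (a plain `range`, used only for its length) and a hypothetical `shuffled_order` helper:

```python
import torch
from torch.utils.data import RandomSampler

data_source = range(8)  # only its length matters to the sampler


def shuffled_order(seed):
    """Return the index order produced by a RandomSampler with a seeded generator."""
    generator = torch.Generator().manual_seed(seed)
    return list(RandomSampler(data_source, generator=generator))


order_a = shuffled_order(0)
order_b = shuffled_order(0)
print(order_a == order_b)                 # same seed, same order
print(sorted(order_a) == list(range(8)))  # still a permutation of every index
```

`DataLoader` accepts a `generator` argument as well, which serves the same purpose when `shuffle=True`.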

### 3.1 Built-in Samplers

In [None]:
dataset = SimpleDataset(20)

# SequentialSampler: indices 0, 1, 2, ..., n-1
seq_sampler = SequentialSampler(dataset)
print(f"Sequential: {list(seq_sampler)[:10]}...")

# RandomSampler: random permutation
rand_sampler = RandomSampler(dataset)
print(f"Random: {list(rand_sampler)[:10]}...")

# RandomSampler with replacement (for oversampling)
rand_replace_sampler = RandomSampler(dataset, replacement=True, num_samples=30)
print(f"Random w/ replacement (30 samples): {list(rand_replace_sampler)[:15]}...")

# SubsetRandomSampler: random sample from specific indices
indices = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]  # Even indices only
subset_sampler = SubsetRandomSampler(indices)
print(f"Subset (even indices): {list(subset_sampler)}")

### 3.2 WeightedRandomSampler for Class Imbalance

`WeightedRandomSampler` draws each index with probability proportional to its weight, so giving minority-class samples larger weights oversamples them.
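
A common weighting is the inverse class frequency. The plain-Python sketch below uses a made-up 90/10 label list to show why that choice balances the classes: each class ends up with the same total weight, so each is drawn with roughly equal probability.

```python
from collections import Counter

labels = [0] * 900 + [1] * 100  # hypothetical 90/10 imbalance

class_counts = Counter(labels)
class_weights = {cls: 1.0 / count for cls, count in class_counts.items()}
sample_weights = [class_weights[label] for label in labels]

# The probability of drawing a class is its share of the total weight
total = sum(sample_weights)
class_probability = {
    cls: sum(w for w, label in zip(sample_weights, labels) if label == cls) / total
    for cls in class_counts
}
print(class_probability)
```

Balanced draws come at a cost: with replacement, minority samples repeat many times per epoch while some majority samples are never seen.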

In [None]:
class ImbalancedDataset(Dataset):
    """Dataset with class imbalance: 90% class 0, 10% class 1"""
    
    def __init__(self, size: int = 1000):
        self.size = size
        # Create imbalanced labels
        self.labels = torch.zeros(size, dtype=torch.long)
        self.labels[:size // 10] = 1  # 10% are class 1
        # Shuffle
        perm = torch.randperm(size)
        self.labels = self.labels[perm]
        self.data = torch.randn(size, 10)
    
    def __len__(self) -> int:
        return self.size
    
    def __getitem__(self, index: int):
        return self.data[index], self.labels[index]


imb_dataset = ImbalancedDataset(1000)
print(f"Original class distribution: {Counter(imb_dataset.labels.tolist())}")

# Calculate sample weights (inverse class frequency)
class_counts = Counter(imb_dataset.labels.tolist())
class_weights = {cls: 1.0 / count for cls, count in class_counts.items()}
sample_weights = [class_weights[label.item()] for label in imb_dataset.labels]

# Create weighted sampler
weighted_sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(imb_dataset),
    replacement=True  # Must be True for oversampling
)

# Compare class distributions in batches
regular_loader = DataLoader(imb_dataset, batch_size=100, shuffle=True)
weighted_loader = DataLoader(imb_dataset, batch_size=100, sampler=weighted_sampler)

# Check first epoch
regular_labels = []
weighted_labels = []

for _, labels in regular_loader:
    regular_labels.extend(labels.tolist())

for _, labels in weighted_loader:
    weighted_labels.extend(labels.tolist())

print(f"\nRegular loader distribution: {Counter(regular_labels)}")
print(f"Weighted loader distribution: {Counter(weighted_labels)}")

### 3.3 Custom Sampler Implementation

Create your own sampler by implementing `__iter__` and `__len__`:

In [None]:
class CurriculumSampler(Sampler):
    """
    Curriculum learning sampler: starts with easier samples,
    gradually introduces harder ones.
    
    `difficulties` holds one score per sample, in dataset index order (lower = easier).
    """
    
    def __init__(self, dataset, difficulties: List[float], epoch: int = 0, max_epochs: int = 10):
        self.dataset = dataset
        self.difficulties = difficulties
        self.epoch = epoch
        self.max_epochs = max_epochs
        
        # Sort indices by difficulty
        self.sorted_indices = sorted(range(len(difficulties)), key=lambda i: difficulties[i])
    
    def set_epoch(self, epoch: int):
        """Update epoch for curriculum progression"""
        self.epoch = epoch
    
    def __iter__(self) -> Iterator[int]:
        # Calculate how much of the dataset to use
        progress = min(1.0, (self.epoch + 1) / self.max_epochs)
        num_samples = int(progress * len(self.sorted_indices))
        num_samples = max(num_samples, len(self.sorted_indices) // 10)  # At least 10%
        
        # Use easiest samples up to current progress
        available_indices = self.sorted_indices[:num_samples]
        
        # Shuffle within available samples
        perm = torch.randperm(len(available_indices)).tolist()
        return iter([available_indices[i] for i in perm])
    
    def __len__(self) -> int:
        progress = min(1.0, (self.epoch + 1) / self.max_epochs)
        return max(int(progress * len(self.sorted_indices)), len(self.sorted_indices) // 10)


# Example usage
dataset = SimpleDataset(100)
# Simulate difficulties (e.g., based on noise level)
difficulties = torch.rand(100).tolist()

curriculum_sampler = CurriculumSampler(dataset, difficulties, epoch=0, max_epochs=10)

print("Curriculum learning progression:")
for epoch in [0, 2, 5, 9]:
    curriculum_sampler.set_epoch(epoch)
    indices = list(curriculum_sampler)
    avg_difficulty = sum(difficulties[i] for i in indices) / len(indices)
    print(f"  Epoch {epoch}: {len(indices)} samples, avg difficulty: {avg_difficulty:.3f}")

### 3.4 BatchSampler: Custom Batching Strategies
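
Batching sequences of similar length matters because every batch is padded to its longest member. The plain-Python sketch below estimates the fraction of padded slots for randomly composed batches versus length-sorted batches. The helpers `padding_fraction` and `chunk` are hypothetical, and the lengths follow the 5-to-15 pattern of `VariableLengthDataset`.

```python
import random


def padding_fraction(lengths, batches):
    """Fraction of padded slots when each batch is padded to its longest sequence."""
    padded = sum(max(lengths[i] for i in batch) * len(batch) for batch in batches)
    real = sum(lengths[i] for batch in batches for i in batch)
    return 1 - real / padded


def chunk(indices, batch_size):
    """Split a list of indices into consecutive batches."""
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]


lengths = [5 + i % 11 for i in range(100)]  # lengths cycle through 5..15
random.seed(0)
shuffled = random.sample(range(100), 100)
by_length = sorted(range(100), key=lambda i: lengths[i])

print(f"random batches:         {padding_fraction(lengths, chunk(shuffled, 8)):.1%}")
print(f"similar-length batches: {padding_fraction(lengths, chunk(by_length, 8)):.1%}")
```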

In [None]:
class SimilarLengthBatchSampler(Sampler):
    """
    Groups sequences of similar length into batches.
    Reduces padding waste in NLP tasks.
    """
    
    def __init__(self, lengths: List[int], batch_size: int, drop_last: bool = False):
        self.batch_size = batch_size
        self.drop_last = drop_last
        
        # Sort indices by length
        self.sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i])
        self.lengths = lengths
    
    def __iter__(self):
        # Create batches of similar-length sequences
        batches = []
        for i in range(0, len(self.sorted_indices), self.batch_size):
            batch = self.sorted_indices[i:i + self.batch_size]
            if len(batch) == self.batch_size or not self.drop_last:
                batches.append(batch)

        # Shuffle batch order (not within batches)
        perm = torch.randperm(len(batches)).tolist()
        for i in perm:
            yield batches[i]  # One list of indices per batch

    def __len__(self):
        # Number of batches, as expected of a batch sampler
        n = len(self.sorted_indices)
        if self.drop_last:
            return n // self.batch_size
        return (n + self.batch_size - 1) // self.batch_size


# Example with variable-length dataset
var_dataset = VariableLengthDataset(100)
lengths = [5 + i % 11 for i in range(100)]  # Match dataset's length pattern

similar_sampler = SimilarLengthBatchSampler(lengths, batch_size=8)

# Pass it as batch_sampler; batch_size, shuffle, sampler, and drop_last
# must then be left at their defaults
similar_loader = DataLoader(var_dataset, batch_sampler=similar_sampler, collate_fn=pad_collate)

padded_seqs, labels, seq_lengths = next(iter(similar_loader))
print(f"Lengths in one batch: {seq_lengths.tolist()}")
print("Similar-length batching reduces padding waste!")

---

## 4. Performance Optimization

### 4.1 Multi-Worker Loading
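
Before measuring, it helps to know what speedup to expect. The sketch below is an idealized model, assuming loading time is dominated by a fixed per-sample delay that workers can overlap perfectly; the `ideal_epoch_seconds` helper is hypothetical. Real measurements will be slower because of worker startup, inter-process transfer, and collation.

```python
def ideal_epoch_seconds(num_samples, per_sample_delay, num_workers):
    """Lower bound on loading time: pure I/O wait split evenly across workers."""
    parallelism = max(1, num_workers)  # num_workers=0 loads in the main process
    return num_samples * per_sample_delay / parallelism


# 10 batches of 16 samples, each sample taking 10 ms to load
for workers in (0, 2, 4):
    estimate = ideal_epoch_seconds(160, per_sample_delay=0.01, num_workers=workers)
    print(f"num_workers={workers}: at best about {estimate:.2f}s")
```

If a benchmark falls far short of this bound, the bottleneck is probably elsewhere, for example CPU-bound transforms or process overhead.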

In [None]:
class SlowDataset(Dataset):
    """Simulates a dataset with slow I/O (e.g., loading from disk)"""
    
    def __init__(self, size: int = 100, delay: float = 0.01):
        self.size = size
        self.delay = delay
        self.data = torch.randn(size, 64, 64)
        self.labels = torch.randint(0, 10, (size,))
    
    def __len__(self) -> int:
        return self.size
    
    def __getitem__(self, index: int):
        time.sleep(self.delay)  # Simulate I/O delay
        return self.data[index], self.labels[index]


def benchmark_loader(loader, num_batches: int = 10) -> float:
    """Measure time to load batches"""
    start = time.perf_counter()  # Monotonic, high-resolution clock for benchmarks
    for i, (data, labels) in enumerate(loader):
        if i >= num_batches:
            break
        # Simulate some processing
        _ = data.mean()
    return time.perf_counter() - start


slow_dataset = SlowDataset(200, delay=0.01)

# Compare different num_workers
print("Benchmarking multi-worker loading:")
print("(Note: First run may be slower due to worker startup)")

for num_workers in [0, 2, 4]:
    loader = DataLoader(
        slow_dataset, 
        batch_size=16, 
        num_workers=num_workers,
        persistent_workers=(num_workers > 0)
    )
    
    # Warm up
    _ = benchmark_loader(loader, 2)
    
    # Actual benchmark
    elapsed = benchmark_loader(loader, 10)
    print(f"  num_workers={num_workers}: {elapsed:.3f}s")

### 4.2 Memory Pinning

Pinned (page-locked) host memory enables faster CPU-to-GPU transfers and lets `non_blocking=True` copies run asynchronously:

In [None]:
if torch.cuda.is_available():
    # Create larger dataset for meaningful benchmark
    large_dataset = SimpleDataset(10000)
    
    def benchmark_transfer(loader, num_batches: int = 50):
        """Measure time to transfer batches to GPU"""
        torch.cuda.synchronize()
        start = time.perf_counter()

        for i, (x, y) in enumerate(loader):
            if i >= num_batches:
                break
            # Transfer to GPU with non_blocking when using pinned memory
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        # Wait for pending asynchronous copies before stopping the clock
        torch.cuda.synchronize()
        return time.perf_counter() - start
    
    # Without memory pinning
    loader_no_pin = DataLoader(large_dataset, batch_size=256, pin_memory=False)
    time_no_pin = benchmark_transfer(loader_no_pin)
    
    # With memory pinning
    loader_pinned = DataLoader(large_dataset, batch_size=256, pin_memory=True)
    time_pinned = benchmark_transfer(loader_pinned)
    
    print(f"Without pin_memory: {time_no_pin:.4f}s")
    print(f"With pin_memory: {time_pinned:.4f}s")
    print(f"Speedup: {time_no_pin / time_pinned:.2f}x")
else:
    print("GPU not available - pin_memory has no effect on CPU-only training")

### 4.3 Prefetching and Persistent Workers
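
Prefetching trades memory for latency: with `num_workers` workers and a given `prefetch_factor`, up to `num_workers * prefetch_factor` batches can be loaded ahead of the training loop. The sketch below gives a rough estimate of that buffer for a made-up workload; `prefetch_buffer_bytes` is a hypothetical helper and the figure ignores any extra copies made during transfer.

```python
def prefetch_buffer_bytes(batch_size, sample_bytes, num_workers, prefetch_factor=2):
    """Rough size of the batches that may be loaded ahead of use."""
    if num_workers == 0:
        return 0  # single-process loading does not prefetch
    batches_in_flight = num_workers * prefetch_factor
    return batches_in_flight * batch_size * sample_bytes


# Hypothetical workload: float32 images of shape 3x224x224, batch size 32, 4 workers
sample_bytes = 3 * 224 * 224 * 4
buffered = prefetch_buffer_bytes(batch_size=32, sample_bytes=sample_bytes, num_workers=4)
print(f"{buffered / 2**20:.0f} MiB may be queued ahead of the training loop")
```

Raising `prefetch_factor` only helps when loading time is uneven; if the workers cannot keep up on average, a deeper queue just uses more memory.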

In [None]:
# Best practices for DataLoader configuration
def create_optimized_loader(
    dataset,
    batch_size: int = 32,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True
):
    """
    Create an optimized DataLoader with best practices.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        
        # Performance settings
        num_workers=num_workers,
        pin_memory=pin_memory and torch.cuda.is_available(),
        
        # Prefetch batches while GPU is processing
        prefetch_factor=2 if num_workers > 0 else None,
        
        # Keep workers alive between epochs (saves startup time)
        persistent_workers=num_workers > 0,
        
        # Drop incomplete batches for consistent batch sizes
        drop_last=True,
    )


print("Reasonable starting points for DataLoader settings:")
print("  - num_workers: start with 4 and tune; stay at or below the number of CPU cores")
print("  - pin_memory: True for GPU training")
print("  - prefetch_factor: 2 (default) is usually good")
print("  - persistent_workers: True to avoid worker restart overhead")

---

## 5. Data Augmentation

Data augmentation increases training data diversity without collecting more data.

### 5.1 torchvision.transforms

In [None]:
# Download MNIST for augmentation examples
data_dir = '../data'
os.makedirs(data_dir, exist_ok=True)

# Basic transforms
basic_transform = T.Compose([
    T.ToTensor(),  # PIL Image -> Tensor, scales to [0, 1]
    T.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])

# Training transforms with augmentation
train_transform = T.Compose([
    T.RandomRotation(15),              # Rotate up to 15 degrees
    T.RandomAffine(                    # Random affine transformation
        degrees=0,
        translate=(0.1, 0.1),          # Shift up to 10%
        scale=(0.9, 1.1),              # Scale between 90% and 110%
    ),
    T.ToTensor(),
    T.Normalize((0.1307,), (0.3081,)),
    T.RandomErasing(p=0.2),            # Randomly erase patches
])

# Load datasets
train_dataset = MNIST(data_dir, train=True, download=True, transform=train_transform)
test_dataset = MNIST(data_dir, train=False, download=True, transform=basic_transform)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

In [None]:
# Visualize augmentations
def show_augmentations(dataset, index: int = 0, num_versions: int = 8):
    """Show multiple augmented versions of the same image"""
    cols = 4
    rows = (num_versions + cols - 1) // cols  # Enough rows for num_versions panels
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)

    for i, ax in enumerate(axes.flatten()):
        ax.axis('off')
        if i >= num_versions:
            continue  # Leave unused grid cells blank
        img, label = dataset[index]  # Each call applies random augmentation
        ax.imshow(img.squeeze(), cmap='gray')
        ax.set_title(f'Version {i+1} (Label: {label})')
    
    plt.suptitle('Same image with different random augmentations')
    plt.tight_layout()
    plt.show()


show_augmentations(train_dataset, index=0)

### 5.2 Advanced Transforms for Images

In [None]:
# Common augmentation pipeline for color images
imagenet_train_transform = T.Compose([
    # Spatial transformations
    T.RandomResizedCrop(224, scale=(0.8, 1.0)),  # Random crop and resize
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(15),
    
    # Color transformations
    T.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1
    ),
    T.RandomGrayscale(p=0.1),
    
    # Convert to tensor
    T.ToTensor(),
    
    # Normalize (ImageNet stats)
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
    
    # Regularization
    T.RandomErasing(p=0.2),
])

# Validation/test transform (no augmentation)
imagenet_val_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

print("ImageNet-style augmentation pipeline created!")

### 5.3 Custom Transforms
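
A transform only needs to be callable: it takes a sample and returns the transformed sample. The plain-Python sketch below mimics how a pipeline chains such callables. `ComposeSketch`, `Scale`, and `clamp_to_unit` are hypothetical stand-ins that operate on lists of floats rather than images, so this shows the pattern and is not torchvision's implementation.

```python
class ComposeSketch:
    """Hypothetical stand-in: applies a sequence of callables in order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, sample):
        for transform in self.transforms:
            sample = transform(sample)
        return sample


class Scale:
    """Toy transform with configuration: multiply every value by a factor."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, values):
        return [v * self.factor for v in values]


def clamp_to_unit(values):
    """Plain functions work too, since a transform only needs to be callable."""
    return [min(1.0, max(0.0, v)) for v in values]


pipeline = ComposeSketch([Scale(0.5), clamp_to_unit])
print(pipeline([0.4, 1.0, 3.0]))
```

Writing a transform as a class with `__init__` and `__call__` keeps its parameters separate from the data, which is the pattern the custom transforms in this section follow.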

In [None]:
class AddGaussianNoise:
    """Add Gaussian noise to tensor images"""
    
    def __init__(self, mean: float = 0.0, std: float = 0.1):
        self.mean = mean
        self.std = std
    
    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        noise = torch.randn_like(tensor) * self.std + self.mean
        return tensor + noise
    
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"


class Cutout:
    """Randomly mask out rectangular regions (cutout regularization)"""
    
    def __init__(self, n_holes: int = 1, length: int = 8):
        self.n_holes = n_holes
        self.length = length
    
    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        h, w = img.shape[-2:]
        mask = torch.ones_like(img)
        
        for _ in range(self.n_holes):
            y = torch.randint(0, h, (1,)).item()
            x = torch.randint(0, w, (1,)).item()
            
            y1 = max(0, y - self.length // 2)
            y2 = min(h, y + self.length // 2)
            x1 = max(0, x - self.length // 2)
            x2 = min(w, x + self.length // 2)
            
            mask[..., y1:y2, x1:x2] = 0
        
        return img * mask


# Use in a transform pipeline
custom_transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.1307,), (0.3081,)),
    AddGaussianNoise(std=0.1),
    Cutout(n_holes=2, length=6),
])

# Test custom transforms
custom_dataset = MNIST(data_dir, train=True, download=True, transform=custom_transform)

fig, axes = plt.subplots(1, 4, figsize=(12, 3))
for i, ax in enumerate(axes):
    img, label = custom_dataset[i]
    ax.imshow(img.squeeze(), cmap='gray')
    ax.set_title(f'Custom augmented (Label: {label})')
    ax.axis('off')
plt.tight_layout()
plt.show()

### 5.4 torchvision.transforms.v2 (Modern API)

In [None]:
from torchvision.transforms import v2

# v2 transforms work on both images AND bounding boxes, masks, etc.
modern_transform = v2.Compose([
    v2.RandomResizedCrop(224, scale=(0.8, 1.0), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ColorJitter(brightness=0.2, contrast=0.2),
    v2.ToImage(),  # Convert to a tv_tensors.Image (a torch.Tensor subclass)
    v2.ToDtype(torch.float32, scale=True),  # Convert to float32 and scale values to [0, 1]
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

print("v2 transforms offer:")
print("  - Consistent transforms across images, boxes, masks")
print("  - Support for tensor inputs, including batched tensors on the GPU")
print("  - TVTensor types that tell each transform how to treat its input")
print("  - Batch-level transforms (MixUp, CutMix)")

### 5.5 MixUp and CutMix
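
MixUp blends labels as well as inputs. In practice the blended label is rarely materialized; instead the loss is computed as a weighted sum of two ordinary losses. The sketch below, using made-up logits and labels, checks that this weighted sum equals a single cross-entropy against the mixed (soft) label distribution, which is why the two formulations can be used interchangeably.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 5
logits = torch.randn(4, num_classes)  # hypothetical model outputs for 4 samples
y_a = torch.tensor([0, 1, 2, 3])      # labels of the original samples
y_b = torch.tensor([3, 2, 1, 0])      # labels of the samples mixed in
lam = 0.7                             # mixing coefficient

# Form 1: weighted sum of two ordinary cross-entropy losses
loss_weighted = lam * F.cross_entropy(logits, y_a) + (1 - lam) * F.cross_entropy(logits, y_b)

# Form 2: one cross-entropy against the mixed label distribution
soft_targets = lam * F.one_hot(y_a, num_classes) + (1 - lam) * F.one_hot(y_b, num_classes)
log_probs = F.log_softmax(logits, dim=1)
loss_soft = -(soft_targets * log_probs).sum(dim=1).mean()

print(torch.allclose(loss_weighted, loss_soft))
```

The same identity applies to CutMix, where `lam` is the fraction of the image area that comes from the original sample.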

In [None]:
def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.2):
    """
    MixUp: Creates convex combinations of training examples.
    
    x_mixed = lambda * x_i + (1 - lambda) * x_j
    y_mixed = lambda * y_i + (1 - lambda) * y_j
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.0
    
    batch_size = x.size(0)
    index = torch.randperm(batch_size)
    
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """MixUp loss: weighted combination of losses"""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


# Demonstrate MixUp
batch_x, batch_y = next(iter(DataLoader(train_dataset, batch_size=4)))
mixed_x, y_a, y_b, lam = mixup_data(batch_x, batch_y, alpha=0.4)

fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i in range(4):
    axes[0, i].imshow(batch_x[i].squeeze(), cmap='gray')
    axes[0, i].set_title(f'Original (y={batch_y[i].item()})')
    axes[0, i].axis('off')
    
    axes[1, i].imshow(mixed_x[i].squeeze(), cmap='gray')
    axes[1, i].set_title(f'Mixed ({lam:.2f}*{y_a[i].item()} + {1-lam:.2f}*{y_b[i].item()})')
    axes[1, i].axis('off')

plt.suptitle(f'MixUp with lambda={lam:.2f}')
plt.tight_layout()
plt.show()

In [None]:
def cutmix_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
    """
    CutMix: Cuts a patch from one image and pastes it onto another.
    Labels are mixed proportionally to the area.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.0
    
    batch_size = x.size(0)
    index = torch.randperm(batch_size)
    
    # Get image dimensions
    _, _, h, w = x.shape
    
    # Calculate cut size
    cut_ratio = np.sqrt(1 - lam)
    cut_h = int(h * cut_ratio)
    cut_w = int(w * cut_ratio)
    
    # Random position for cut
    cy = np.random.randint(h)
    cx = np.random.randint(w)
    
    # Bound the cut region
    y1 = np.clip(cy - cut_h // 2, 0, h)
    y2 = np.clip(cy + cut_h // 2, 0, h)
    x1 = np.clip(cx - cut_w // 2, 0, w)
    x2 = np.clip(cx + cut_w // 2, 0, w)
    
    # Apply CutMix
    mixed_x = x.clone()
    mixed_x[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]
    
    # Adjust lambda based on actual cut area
    lam = 1 - ((y2 - y1) * (x2 - x1) / (h * w))
    
    y_a, y_b = y, y[index]
    
    return mixed_x, y_a, y_b, lam


# Demonstrate CutMix
batch_x, batch_y = next(iter(DataLoader(train_dataset, batch_size=4)))
cutmix_x, y_a, y_b, lam = cutmix_data(batch_x, batch_y, alpha=1.0)

fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i in range(4):
    axes[0, i].imshow(batch_x[i].squeeze(), cmap='gray')
    axes[0, i].set_title(f'Original (y={batch_y[i].item()})')
    axes[0, i].axis('off')
    
    axes[1, i].imshow(cutmix_x[i].squeeze(), cmap='gray')
    axes[1, i].set_title(f'CutMix ({lam:.2f}*{y_a[i].item()} + {1-lam:.2f}*{y_b[i].item()})')
    axes[1, i].axis('off')

plt.suptitle(f'CutMix with effective lambda={lam:.2f}')
plt.tight_layout()
plt.show()

---

## 6. Putting It All Together

Complete example with a real training pipeline:

In [None]:
class CompleteTrainingPipeline:
    """
    A complete data pipeline demonstrating all concepts.
    """
    
    def __init__(
        self,
        data_dir: str = '../data',
        batch_size: int = 64,
        num_workers: int = 2,
        use_augmentation: bool = True,
        use_mixup: bool = False,
    ):
        self.batch_size = batch_size
        self.use_mixup = use_mixup
        
        # Define transforms
        if use_augmentation:
            self.train_transform = T.Compose([
                T.RandomRotation(10),
                T.RandomAffine(degrees=0, translate=(0.1, 0.1)),
                T.ToTensor(),
                T.Normalize((0.1307,), (0.3081,)),
            ])
        else:
            self.train_transform = T.Compose([
                T.ToTensor(),
                T.Normalize((0.1307,), (0.3081,)),
            ])
        
        self.val_transform = T.Compose([
            T.ToTensor(),
            T.Normalize((0.1307,), (0.3081,)),
        ])
        
        # Load datasets. The training images are wrapped twice so that the
        # train and val splits can use different transforms.
        full_train = MNIST(data_dir, train=True, download=True, transform=self.train_transform)
        full_val = MNIST(data_dir, train=True, download=True, transform=self.val_transform)
        self.test_dataset = MNIST(data_dir, train=False, download=True, transform=self.val_transform)
        
        # Split training into train/val
        train_size = int(0.9 * len(full_train))
        
        # Share one permutation so the splits are disjoint, but index into
        # differently transformed copies: the val split is never augmented
        perm = torch.randperm(len(full_train)).tolist()
        self.train_dataset = torch.utils.data.Subset(full_train, perm[:train_size])
        self.val_dataset = torch.utils.data.Subset(full_val, perm[train_size:])
        
        # Create DataLoaders
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available(),
            persistent_workers=num_workers > 0,
            drop_last=True,
        )
        
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available(),
            persistent_workers=num_workers > 0,
        )
        
        self.test_loader = DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available(),
        )
    
    def get_batch(self, split: str = 'train'):
        """Get a single batch for debugging/visualization"""
        loader = getattr(self, f'{split}_loader')
        return next(iter(loader))
    
    def __repr__(self):
        return (
            f"CompleteTrainingPipeline(\n"
            f"  train_samples={len(self.train_dataset)},\n"
            f"  val_samples={len(self.val_dataset)},\n"
            f"  test_samples={len(self.test_dataset)},\n"
            f"  batch_size={self.batch_size},\n"
            f"  use_mixup={self.use_mixup}\n"
            f")"
        )


# Create pipeline
pipeline = CompleteTrainingPipeline(
    data_dir='../data',
    batch_size=64,
    num_workers=2,
    use_augmentation=True,
    use_mixup=False,
)

print(pipeline)

# Get a batch
batch_x, batch_y = pipeline.get_batch('train')
print(f"\nBatch shape: {batch_x.shape}")
print(f"Labels shape: {batch_y.shape}")

---

## Exercises

### Exercise 1: Stratified Train/Val Split

Write a function that splits a dataset into train and validation index lists with the same class distribution.

In [None]:
# Exercise 1: Implement stratified splitting

def stratified_split(dataset, labels: List[int], val_ratio: float = 0.1):
    """
    Split a dataset into train/val sets with the same class distribution.
    
    Args:
        dataset: The dataset to split
        labels: List of labels for each sample
        val_ratio: Fraction of data for validation
    
    Returns:
        train_indices, val_indices: Lists of indices for each split
    """
    # YOUR CODE HERE
    # Hint: Group indices by class, then sample from each class
    pass


# Test your implementation
# labels = [dataset[i][1] for i in range(len(dataset))]
# train_idx, val_idx = stratified_split(dataset, labels, val_ratio=0.2)

### Exercise 2: Implement a Dynamic Batch Sampler

Create a sampler that adjusts batch size based on sequence length to maintain roughly constant memory usage.

In [None]:
# Exercise 2: Implement dynamic batching

class DynamicBatchSampler(Sampler):
    """
    Sampler that creates batches with approximately equal total tokens.
    Longer sequences get smaller batches, shorter sequences get larger batches.
    
    Args:
        lengths: List of sequence lengths
        max_tokens: Maximum total tokens per batch
        shuffle: Whether to shuffle the data
    """
    
    def __init__(self, lengths: List[int], max_tokens: int = 1000, shuffle: bool = True):
        # YOUR CODE HERE
        pass
    
    def __iter__(self):
        # YOUR CODE HERE
        # Hint: Sort by length, then create batches that don't exceed max_tokens
        pass
    
    def __len__(self):
        # YOUR CODE HERE
        pass

### Exercise 3: Create a Multi-Dataset Loader

Implement a loader that samples from multiple datasets with specified mixing ratios.

In [None]:
# Exercise 3: Multi-dataset sampling

class MultiDatasetLoader:
    """
    Samples from multiple datasets according to specified ratios.
    
    Args:
        datasets: List of datasets
        ratios: Sampling ratios for each dataset (will be normalized)
        batch_size: Batch size
    """
    
    def __init__(self, datasets: List[Dataset], ratios: List[float], batch_size: int = 32):
        # YOUR CODE HERE
        pass
    
    def __iter__(self):
        # YOUR CODE HERE
        # Hint: Create samplers for each dataset and interleave according to ratios
        pass
    
    def __len__(self):
        # YOUR CODE HERE
        pass

---

## Solutions

In [None]:
# Solution 1: Stratified Split

def stratified_split(dataset, labels: List[int], val_ratio: float = 0.1):
    """
    Split a dataset into train/val sets with the same class distribution.
    """
    from collections import defaultdict
    
    # Group indices by class
    class_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        class_indices[label].append(idx)
    
    train_indices = []
    val_indices = []
    
    # Sample from each class
    for cls, indices in class_indices.items():
        # Shuffle indices for this class
        perm = torch.randperm(len(indices)).tolist()
        shuffled = [indices[i] for i in perm]
        
        # Split
        val_count = int(len(shuffled) * val_ratio)
        val_indices.extend(shuffled[:val_count])
        train_indices.extend(shuffled[val_count:])
    
    return train_indices, val_indices


# Test
mnist_train = MNIST('../data', train=True, download=True)
# Read labels from `.targets` rather than indexing the dataset, which would decode every image
labels = mnist_train.targets.tolist()
train_idx, val_idx = stratified_split(mnist_train, labels, val_ratio=0.2)

print(f"Train size: {len(train_idx)}, Val size: {len(val_idx)}")
print(f"Train class dist: {Counter([labels[i] for i in train_idx])}")
print(f"Val class dist: {Counter([labels[i] for i in val_idx])}")
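
The index lists become datasets once they are wrapped in `torch.utils.data.Subset`. A small sketch, assuming a synthetic `TensorDataset` and hand-written index lists in place of the MNIST split above:

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Synthetic stand-in: 10 samples, two classes (labels 0 and 1).
features = torch.arange(10, dtype=torch.float32).unsqueeze(1)
targets = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
dataset = TensorDataset(features, targets)

# Hypothetical result of a stratified split with val_ratio=0.2:
# one sample per class goes to validation.
train_idx = [1, 2, 3, 4, 6, 7, 8, 9]
val_idx = [0, 5]

train_subset = Subset(dataset, train_idx)
val_subset = Subset(dataset, val_idx)

train_loader = DataLoader(train_subset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=4, shuffle=False)

val_x, val_y = next(iter(val_loader))
print(f"train={len(train_subset)}, val={len(val_subset)}, val labels={val_y.tolist()}")
```

`Subset` only stores the indices, so both splits share the underlying data and any transform attached to it.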

In [None]:
# Solution 2: Dynamic Batch Sampler

class DynamicBatchSampler(Sampler):
    """
    Sampler that creates batches with approximately equal total tokens.
    """
    
    def __init__(self, lengths: List[int], max_tokens: int = 1000, shuffle: bool = True):
        self.lengths = lengths
        self.max_tokens = max_tokens
        self.shuffle = shuffle
        
        # Pre-compute batches
        self._create_batches()
    
    def _create_batches(self):
        # Sort indices by length
        sorted_indices = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
        
        self.batches = []
        current_batch = []
        current_max_len = 0
        
        for idx in sorted_indices:
            seq_len = self.lengths[idx]
            new_max_len = max(current_max_len, seq_len)
            new_batch_tokens = new_max_len * (len(current_batch) + 1)
            
            if new_batch_tokens > self.max_tokens and current_batch:
                # Start new batch
                self.batches.append(current_batch)
                current_batch = [idx]
                current_max_len = seq_len
            else:
                current_batch.append(idx)
                current_max_len = new_max_len
        
        if current_batch:
            self.batches.append(current_batch)
    
    def __iter__(self):
        # Yield whole batches (lists of indices) so the sampler can be passed
        # to DataLoader via `batch_sampler=`
        if self.shuffle:
            perm = torch.randperm(len(self.batches)).tolist()
            for i in perm:
                yield self.batches[i]
        else:
            yield from self.batches
    
    def __len__(self):
        # Number of batches per epoch
        return len(self.batches)


# Test
lengths = [10, 50, 15, 100, 25, 30, 80, 12, 45, 60]
dyn_sampler = DynamicBatchSampler(lengths, max_tokens=150, shuffle=False)

print(f"Batches created: {len(dyn_sampler.batches)}")
for i, batch in enumerate(dyn_sampler.batches):
    batch_lens = [lengths[j] for j in batch]
    total_tokens = max(batch_lens) * len(batch)
    print(f"  Batch {i}: size={len(batch)}, lengths={batch_lens}, total_tokens={total_tokens}")
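
Length-grouped batches only save memory if the collate function pads each batch to its own longest sequence. A sketch of that pairing, assuming a hypothetical `ToySequenceDataset` and index batches written out by hand for the same `lengths` and a 150-token budget; `DataLoader` accepts any iterable of index lists as `batch_sampler`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

class ToySequenceDataset(Dataset):
    """Hypothetical dataset: item i is a 1-D tensor with lengths[i] token ids."""

    def __init__(self, lengths):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        return torch.ones(self.lengths[idx], dtype=torch.long)

def pad_collate(batch):
    """Pad every sequence in the batch to the longest one (pad id 0)."""
    seq_lens = torch.tensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded, seq_lens

lengths = [10, 50, 15, 100, 25, 30, 80, 12, 45, 60]
# Index batches grouped by length under a 150-token budget (hard-coded here).
batches = [[0, 7, 2, 4, 5], [8, 1], [9], [6], [3]]

loader = DataLoader(
    ToySequenceDataset(lengths),
    batch_sampler=batches,  # any iterable of index lists is accepted
    collate_fn=pad_collate,
)

shapes = []
for padded, seq_lens in loader:
    shapes.append(tuple(padded.shape))
    print(f"padded shape={tuple(padded.shape)}, padded tokens={padded.numel()}")
```

No padded batch exceeds 150 tokens here, whereas a fixed batch size of 5 containing the length-100 sequence would need 500.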

In [None]:
# Solution 3: Multi-Dataset Loader

class MultiDatasetLoader:
    """
    Samples from multiple datasets according to specified ratios.
    """
    
    def __init__(self, datasets: List[Dataset], ratios: List[float], batch_size: int = 32):
        self.datasets = datasets
        self.batch_size = batch_size
        
        # Normalize ratios
        total = sum(ratios)
        self.ratios = [r / total for r in ratios]
        
        # Calculate samples per dataset for one epoch
        # Use the largest dataset as reference
        max_size = max(len(d) for d in datasets)
        self.samples_per_dataset = [
            int(max_size * ratio) for ratio in self.ratios
        ]
        
        # Total number of samples drawn per epoch across all datasets
        self.total_samples = sum(self.samples_per_dataset)
    
    def __iter__(self):
        # Create iterator for each dataset
        samplers = [
            iter(RandomSampler(d, replacement=True, num_samples=n))
            for d, n in zip(self.datasets, self.samples_per_dataset)
        ]
        
        # Create dataset selection weights
        remaining = list(self.samples_per_dataset)
        
        batch = []
        while sum(remaining) > 0:
            # Select dataset proportionally to remaining samples
            # (the loop condition guarantees total_remaining > 0)
            total_remaining = sum(remaining)
            weights = [r / total_remaining for r in remaining]
            dataset_idx = np.random.choice(len(self.datasets), p=weights)
            
            # Get sample
            try:
                sample_idx = next(samplers[dataset_idx])
                batch.append(self.datasets[dataset_idx][sample_idx])
                remaining[dataset_idx] -= 1
                
                if len(batch) == self.batch_size:
                    yield default_collate(batch)
                    batch = []
            except StopIteration:
                remaining[dataset_idx] = 0
        
        # Yield remaining
        if batch:
            yield default_collate(batch)
    
    def __len__(self):
        return (self.total_samples + self.batch_size - 1) // self.batch_size


# Test with MNIST and FashionMNIST
transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])
mnist = MNIST('../data', train=True, download=True, transform=transform)
fashion = FashionMNIST('../data', train=True, download=True, transform=transform)

multi_loader = MultiDatasetLoader(
    datasets=[mnist, fashion],
    ratios=[0.7, 0.3],  # 70% MNIST, 30% FashionMNIST
    batch_size=32
)

print(f"Total batches: {len(multi_loader)}")
print(f"Samples per dataset: {multi_loader.samples_per_dataset}")
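
A different way to get a similar mix from built-in components is `ConcatDataset` plus `WeightedRandomSampler`. This is an alternative technique, not the solution above. The sketch assumes two synthetic `TensorDataset`s whose label records the source dataset:

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Synthetic stand-ins: the label records which source a sample came from.
ds_a = TensorDataset(torch.randn(500, 4), torch.zeros(500, dtype=torch.long))
ds_b = TensorDataset(torch.randn(100, 4), torch.ones(100, dtype=torch.long))
datasets = [ds_a, ds_b]
ratios = [0.7, 0.3]

# Each sample's weight is its dataset's ratio spread evenly over that dataset,
# so the summed weight per dataset equals the requested ratio.
weights = torch.cat([
    torch.full((len(d),), r / len(d), dtype=torch.double)
    for d, r in zip(datasets, ratios)
])

sampler = WeightedRandomSampler(weights, num_samples=2000, replacement=True)
loader = DataLoader(ConcatDataset(datasets), batch_size=32, sampler=sampler)

source_ids = torch.cat([y for _, y in loader])
frac_a = (source_ids == 0).float().mean().item()
print(f"Fraction drawn from ds_a: {frac_a:.2f}")
```

Unlike the hand-written loader, this draws with replacement and so gives the ratio only in expectation, but it works unchanged with `num_workers`, `pin_memory`, and custom collate functions.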

---

## Summary

### Key Takeaways

1. **Dataset Types**:
   - `Dataset` (map-style): Index-based, use for most cases
   - `IterableDataset`: For streaming data, requires worker-aware splitting

2. **DataLoader Components**:
   - Sampler determines **which** indices to load
   - Collate function determines **how** to batch samples
   - Multi-worker loading parallelizes I/O

3. **Performance Optimization**:
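   - Use `num_workers > 0` when loading is bound by disk I/O or CPU-side transforms
   - Enable `pin_memory=True` for GPU training, and copy batches with `non_blocking=True`
   - Set `persistent_workers=True` (requires `num_workers > 0`) to avoid restarting workers every epoch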

4. **Class Imbalance**:
   - `WeightedRandomSampler` for oversampling minority classes
   - Custom samplers for curriculum learning, etc.

5. **Data Augmentation**:
   - Use `torchvision.transforms` for standard augmentations
   - MixUp/CutMix often improve generalization; weight the loss by the same `lam` used to mix the inputs
   - Augment training data only, not validation/test

### Common Pitfalls

- Forgetting to handle multi-worker splitting in `IterableDataset`
- Passing `shuffle=True` together with a custom `sampler` or `batch_sampler` (mutually exclusive; `DataLoader` raises a `ValueError`)
- Not accounting for variable-length sequences in collation
- Applying augmentation to validation data
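
The first pitfall is the easiest to hit: every worker receives its own copy of an `IterableDataset`, so without explicit sharding each sample is yielded once per worker. A minimal sketch of worker-aware splitting, assuming a hypothetical `RangeStream` dataset and a strided sharding scheme (one of several reasonable choices):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class RangeStream(IterableDataset):
    """Hypothetical streaming dataset that yields the integers in [start, end)."""

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def shard(self, worker_id, num_workers):
        # Worker k takes every num_workers-th item, starting at offset k.
        return range(self.start + worker_id, self.end, num_workers)

    def __iter__(self):
        info = get_worker_info()
        if info is None:
            # Single-process loading: this process yields everything.
            return iter(range(self.start, self.end))
        return iter(self.shard(info.id, info.num_workers))

stream = RangeStream(0, 10)

# Single-process check (num_workers=0): every item appears exactly once.
single = [int(x) for batch in DataLoader(stream, batch_size=4) for x in batch]

# Simulate what two workers would each yield, without spawning processes.
shards = [list(stream.shard(w, 2)) for w in range(2)]
print(single)
print(shards)
```

Passing `stream` to `DataLoader(..., num_workers=2)` takes the `get_worker_info()` branch inside each worker. If that branch were removed, each item would be yielded twice.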