# Data Loading, Preprocessing, and Augmentation in PyTorch

This notebook provides a comprehensive guide to efficiently loading, preprocessing, and augmenting data in PyTorch. Effective data handling is critical for any machine learning pipeline.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import os
import pandas as pd
from pathlib import Path
import glob
import time

# Reproducibility: seed torch's RNG and NumPy's legacy global generator.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

# Folder for any artifacts this notebook writes; created if it does not exist.
output_dir = "05_data_loading_preprocessing_outputs"
Path(output_dir).mkdir(parents=True, exist_ok=True)

## 1. Introduction to Data Handling

Data loading and preprocessing are critical steps in any machine learning pipeline:

- **Loading:** Reading data from various sources (files, databases)
- **Preprocessing:** Cleaning, transforming, and structuring data
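- **Augmentation:** Applying random, label-preserving transformations (e.g., flips, crops, small rotations) to training samples, typically on the fly each epoch, so the model sees more variation and generalizes better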

In [None]:
# torchvision ships ready-made Dataset classes; MNIST is the canonical example.
header = "Using Built-in Datasets:"
print(header)
print("-" * 30)

# Training split of MNIST, fetched into ./data on first run; every image is
# passed through ToTensor() before being returned.
to_tensor = transforms.ToTensor()
mnist_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=to_tensor,
)

print(f"Dataset size: {len(mnist_dataset)}")

# Indexing a Dataset yields an (image, label) pair; keep the first one around
# because the visualization cell below reuses `sample` and `label`.
sample, label = mnist_dataset[0]
for summary_line in (
    f"Sample shape: {sample.shape}",
    f"Sample dtype: {sample.dtype}",
    f"Label: {label}",
):
    print(summary_line)

# Show the first sample (loaded in the previous cell as `sample`/`label`) on the
# left and a grid of the first few dataset samples on the right, all in ONE
# figure. Explicit Figure/Axes objects are used so the grid is placed inside the
# same figure. Mixing `plt.subplot(...)` with `plt.subplots(...)` leaves an empty
# panel behind and opens a second figure.
n_rows, n_cols = 2, 3      # size of the thumbnail grid
main_span = 2              # grid columns occupied by the large left-hand panel

fig = plt.figure(figsize=(10, 4))
grid = fig.add_gridspec(n_rows, n_cols + main_span)

# Left: the single highlighted sample, spanning all rows.
ax_main = fig.add_subplot(grid[:, :main_span])
ax_main.imshow(sample.squeeze(), cmap='gray')
ax_main.set_title(f'MNIST Sample (Label: {label})')
ax_main.axis('off')

# Right: one thumbnail per grid cell. min() keeps this safe for tiny datasets.
for idx in range(min(n_rows * n_cols, len(mnist_dataset))):
    row, col = divmod(idx, n_cols)
    ax = fig.add_subplot(grid[row, main_span + col])
    img, lbl = mnist_dataset[idx]
    ax.imshow(img.squeeze(), cmap='gray')
    ax.set_title(f'Label: {lbl}')
    ax.axis('off')

fig.tight_layout()
plt.show()