# PyTorch Training Framework Tutorial

This notebook demonstrates how to use the PyTorch training framework with Lightning and Hydra for end-to-end machine learning workflows.

## Overview

The framework provides:
- **Hydra Configuration Management**: Flexible configuration system for experiments
- **Lightning Integration**: Scalable training with PyTorch Lightning
- **Model Zoo**: Pre-built models and data modules
- **Experiment Tracking**: Built-in logging and monitoring

We'll walk through a complete CIFAR-10 classification example.
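Hydra builds objects from config entries that name a `_target_` callable. With `_partial_: true` it returns a `functools.partial` instead of calling the target, which is presumably how the `functools.partial` optimizer used in Section 4 can come from a config file. The standard-library sketch below imitates that convention so the idea is visible without Hydra installed. `toy_instantiate` is a hypothetical, simplified helper, not Hydra's `hydra.utils.instantiate`.

```python
import datetime
import functools
import importlib
from typing import Any


def _locate(path: str) -> Any:
    """Resolve a dotted path such as 'datetime.timedelta' to the object it names."""
    module_name, _, attr_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr_name)


def toy_instantiate(cfg: dict) -> Any:
    """Hypothetical, simplified imitation of Hydra-style instantiation.

    `_target_` names the callable, `_partial_: True` returns a functools.partial,
    and nested dicts that carry their own `_target_` are built first.
    """
    cfg = dict(cfg)  # shallow copy so pop() does not mutate the caller's dict
    target = _locate(cfg.pop("_target_"))
    is_partial = cfg.pop("_partial_", False)
    kwargs = {
        key: toy_instantiate(value) if isinstance(value, dict) and "_target_" in value else value
        for key, value in cfg.items()
    }
    return functools.partial(target, **kwargs) if is_partial else target(**kwargs)


# Partial: some arguments come from config, the rest are supplied later at call time.
make_delta = toy_instantiate({"_target_": "datetime.timedelta", "_partial_": True, "hours": 1})
print(make_delta(minutes=30) == datetime.timedelta(hours=1, minutes=30))  # True

# Full instantiation, with the nested object built first.
holder = toy_instantiate({"_target_": "builtins.dict", "span": {"_target_": "datetime.timedelta", "days": 1}})
print(holder)  # {'span': datetime.timedelta(days=1)}
```

An optimizer entry such as `{_target_: torch.optim.Adam, _partial_: true, lr: 0.001}` would follow the same pattern; the exact keys used in this framework's configs are an assumption.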


## 1. Setup and Imports

First, let's import the necessary libraries and set up our environment.


In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import lightning as L
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Add the source directory to Python path
project_root = Path.cwd().parent
src_path = project_root / "src"
sys.path.append(str(src_path))

# Import our framework components
from tfh_train.model_zoo.cifar_clf.model import CifarClassifier
from tfh_train.model_zoo.cifar_clf.data_module import CifarClassifierLightningDataModule
from tfh_train.model_zoo.cifar_clf.model_module import CifarClassifierTraining

print(f"PyTorch version: {torch.__version__}")
print(f"Lightning version: {L.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name()}")


## 2. Understanding the Model Architecture

Let's examine the CIFAR-10 classifier model that's included in the framework.


In [None]:
# Create an instance of the model
model = CifarClassifier()
print("Model Architecture:")
print(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Test with a sample input
sample_input = torch.randn(1, 3, 32, 32)  # CIFAR-10 image size
with torch.no_grad():
    output = model(sample_input)
print(f"\nInput shape: {sample_input.shape}")
print(f"Output shape: {output.shape}")
print(f"Output (logits): {output}")
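The real `CifarClassifier` is defined inside the framework and its layers are not reproduced here. If you want to experiment without the framework installed, a hypothetical stand-in with the same contract (a `(N, 3, 32, 32)` image batch in, `(N, 10)` logits out) could look like the sketch below. The layer sizes are illustrative assumptions, and the last lines check the parameter count by hand: a `Conv2d` has `out * in * k * k` weights plus `out` biases, and a `Linear` has `in * out` weights plus `out` biases.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyCifarNet(nn.Module):
    """Hypothetical stand-in for CifarClassifier: (N, 3, 32, 32) -> (N, 10) logits."""

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)   # (N, 3, 32, 32) -> (N, 16, 32, 32)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)  # (N, 16, 16, 16) -> (N, 32, 16, 16)
        self.pool = nn.MaxPool2d(2)                                # halves height and width
        self.fc = nn.Linear(32 * 8 * 8, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))  # (N, 16, 16, 16)
        x = self.pool(F.relu(self.conv2(x)))  # (N, 32, 8, 8)
        x = torch.flatten(x, start_dim=1)     # (N, 2048)
        return self.fc(x)                     # raw logits; softmax is applied later if needed


net = TinyCifarNet()
logits = net(torch.randn(4, 3, 32, 32))
print(logits.shape)  # torch.Size([4, 10])

# Hand-computed parameter count: conv1 + conv2 + fc.
expected = (16 * 3 * 3 * 3 + 16) + (32 * 16 * 3 * 3 + 32) + (32 * 8 * 8 * 10 + 10)
print(sum(p.numel() for p in net.parameters()) == expected)  # True
```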


## 3. Data Loading and Exploration

Let's set up the data module and explore the CIFAR-10 dataset.


In [None]:
# CIFAR-10 class names
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
                   'dog', 'frog', 'horse', 'ship', 'truck']

# Create data module
data_module = CifarClassifierLightningDataModule(
    batch_size=32,
    num_workers=2,
    pin_memory=True
)

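# Download CIFAR-10 if needed, then build the datasets.
# The Trainer normally calls these hooks itself; we call them here to inspect the data.
data_module.prepare_data()
data_module.setup(stage="fit")   # training and validation splits
data_module.setup(stage="test")  # test split

print(f"Training dataset size: {len(data_module.train_dataset)}")
print(f"Validation dataset size: {len(data_module.validation_dataset)}")
print(f"Test dataset size: {len(data_module.test_dataset)}")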
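CIFAR-10 officially provides only a 50,000-image training split and a 10,000-image test split, so a validation set has to come from somewhere. How `CifarClassifierLightningDataModule` builds `validation_dataset` is not shown in this notebook. One common approach, sketched here under that caveat, holds out a fraction of the training images with a seeded `random_split`; the helper `split_train_val` and the 10% fraction are hypothetical.

```python
import torch
from torch.utils.data import Dataset, TensorDataset, random_split


def split_train_val(dataset: Dataset, val_fraction: float = 0.1, seed: int = 42):
    """Hypothetical helper: deterministic train/validation split."""
    n_total = len(dataset)
    n_val = round(n_total * val_fraction)
    n_train = n_total - n_val
    # A dedicated generator keeps the split independent of the global RNG state.
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [n_train, n_val], generator=generator)


# Stand-in data: 100 fake 3x32x32 images with integer labels in [0, 10).
fake = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))
train_part, val_part = split_train_val(fake)
print(len(train_part), len(val_part))  # 90 10
```

Because the generator is seeded, calling `split_train_val` again with the same seed yields exactly the same indices, so validation scores stay comparable across runs.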


In [None]:
# Visualize some sample images
def imshow(img, title=None, ax=None):
    """Display a normalized (C, H, W) image tensor on `ax` (defaults to the current axes)."""
    if ax is None:
        ax = plt.gca()
    # Undo Normalize(mean=0.5, std=0.5); clamp keeps values displayable if the stats differ.
    img = (img * 0.5 + 0.5).clamp(0, 1)
    ax.imshow(np.transpose(img.numpy(), (1, 2, 0)))  # (C, H, W) -> (H, W, C)
    if title:
        ax.set_title(title)
    ax.axis('off')

# Get a batch of training data
train_loader = data_module.train_dataloader()
dataiter = iter(train_loader)
images, labels = next(dataiter)

# Show the first 8 images with their labels
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i, ax in enumerate(axes.flat):
    imshow(images[i], title=f'Class: {cifar10_classes[labels[i]]}', ax=ax)

plt.tight_layout()
plt.show()

print(f"Batch shape: {images.shape}")
print(f"Labels shape: {labels.shape}")
print(f"Image range: [{images.min():.3f}, {images.max():.3f}]")
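The `* 0.5 + 0.5` step used for plotting only inverts the normalization if the data module normalizes each channel with mean 0.5 and standard deviation 0.5 (the common `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` setup); that is an assumption about the framework. A general inverse for arbitrary per-channel statistics, verified with a round trip, can be sketched as:

```python
import torch


def normalize(img: torch.Tensor, mean, std) -> torch.Tensor:
    """Per-channel (x - mean) / std for a (C, H, W) tensor."""
    mean = torch.as_tensor(mean).view(-1, 1, 1)
    std = torch.as_tensor(std).view(-1, 1, 1)
    return (img - mean) / std


def denormalize(img: torch.Tensor, mean, std) -> torch.Tensor:
    """Invert `normalize` (x * std + mean) and clip to [0, 1] for display."""
    mean = torch.as_tensor(mean).view(-1, 1, 1)
    std = torch.as_tensor(std).view(-1, 1, 1)
    return (img * std + mean).clamp(0.0, 1.0)


original = torch.rand(3, 32, 32)  # values in [0, 1), like ToTensor() output
normed = normalize(original, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
print(normed.min().item(), normed.max().item())  # both lie within [-1, 1]
restored = denormalize(normed, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
print(torch.allclose(restored, original, atol=1e-6))  # True
```

This also explains the `Image range` printed above: with mean = std = 0.5, pixel values in [0, 1] map to [-1, 1].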


## 4. Setting Up the Lightning Module

Now let's create the Lightning module that wraps our model with training logic.


In [None]:
import functools
import torch.optim as optim

# Create the model
model = CifarClassifier()

# Define loss function
criterion = nn.CrossEntropyLoss()

# Define optimizer (using functools.partial as expected by the framework)
optimizer_partial = functools.partial(optim.Adam, lr=0.001)

# Create the Lightning module
lightning_module = CifarClassifierTraining(
    model=model,
    criterion=criterion,
    optimizer=optimizer_partial
)

print("Lightning module created successfully!")
print(f"Model: {type(lightning_module.neural_net).__name__}")
print(f"Criterion: {type(lightning_module.criterion).__name__}")
print(f"Optimizer: {optimizer_partial.func.__name__}")

# Note: the Lightning module's forward() delegates to the wrapped model,
# so you can run inference directly with lightning_module(input_tensor).


## 5. Training the Model

Let's train the model with the PyTorch Lightning `Trainer`.


In [None]:
# Set up trainer
trainer = L.Trainer(
    max_epochs=3,  # Keep it short for demo
    accelerator="auto",  # Use GPU if available
    devices=1,
    logger=True,  # Enable default logger
    enable_checkpointing=True,
    enable_progress_bar=True,
    log_every_n_steps=50
)

print("Trainer configured:")
print(f"  Max epochs: {trainer.max_epochs}")
print(f"  Accelerator: {trainer.accelerator}")
print(f"  Devices: {trainer.num_devices}")


In [None]:
# Start training
print("Starting training...")
trainer.fit(lightning_module, datamodule=data_module)
print("Training completed!")
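`trainer.fit` hides the optimization loop. Stripped of validation, logging, checkpointing and device handling, one training epoch reduces roughly to the plain PyTorch loop below. This is a simplified illustration, not Lightning's actual implementation; the tiny linear model and random data are placeholders.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model: nn.Module, loader: DataLoader, criterion: nn.Module,
                    optimizer: torch.optim.Optimizer) -> float:
    """Run one pass over `loader` and return the mean training loss."""
    model.train()  # training-mode behaviour (dropout, batch-norm statistics updates)
    total_loss, n_batches = 0.0, 0
    for images, labels in loader:
        optimizer.zero_grad()                    # clear gradients from the previous step
        loss = criterion(model(images), labels)  # forward pass + loss
        loss.backward()                          # backpropagate
        optimizer.step()                         # update the weights
        total_loss += loss.item()
        n_batches += 1
    return total_loss / max(n_batches, 1)


torch.manual_seed(0)
toy_data = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
toy_loader = DataLoader(toy_data, batch_size=16, shuffle=True)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
toy_opt = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
losses = [train_one_epoch(toy_model, toy_loader, nn.CrossEntropyLoss(), toy_opt) for _ in range(3)]
print(losses)  # typically trends downward as the model fits the random labels
```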


## 6. Model Evaluation

Let's evaluate our trained model on the test set.


In [None]:
# Test the model
test_results = trainer.test(lightning_module, datamodule=data_module)
print("Test Results:")
for key, value in test_results[0].items():
    print(f"  {key}: {value:.4f}")
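The keys in `test_results` depend on what the framework's Lightning module logs (the final cell assumes one named `test_Accuracy`). To cross-check the reported accuracy independently, a small hypothetical helper can recompute it over a whole dataloader. It is shown below with a toy model and random data; in this notebook you would pass `lightning_module` and `data_module.test_dataloader()`.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def dataloader_accuracy(model: nn.Module, loader: DataLoader) -> float:
    """Fraction of samples whose arg-max logit equals the label."""
    model.eval()
    device = next(model.parameters()).device
    correct, total = 0, 0
    for images, labels in loader:
        preds = model(images.to(device)).argmax(dim=1).cpu()
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return correct / total


eval_data = TensorDataset(torch.randn(50, 3, 32, 32), torch.randint(0, 10, (50,)))
eval_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
acc = dataloader_accuracy(eval_model, DataLoader(eval_data, batch_size=16))
print(f"accuracy: {acc:.4f}")  # near chance (about 0.1) for an untrained model
```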


## 7. Making Predictions

Let's use our trained model to make predictions on some test images.


In [None]:
# Set model to evaluation mode
lightning_module.eval()

# Get a batch of test data
test_loader = data_module.test_dataloader()
test_iter = iter(test_loader)
test_images, test_labels = next(test_iter)

# Make predictions on the module's device, then bring the results back to the CPU
with torch.no_grad():
    outputs = lightning_module(test_images.to(lightning_module.device)).cpu()
    _, predicted = torch.max(outputs, 1)
    probabilities = F.softmax(outputs, dim=1)

# Visualize predictions
fig, axes = plt.subplots(2, 4, figsize=(15, 8))
for i in range(8):
    ax = axes[i//4, i%4]
    
    # Display image
    img = test_images[i] * 0.5 + 0.5  # Denormalize
    ax.imshow(np.transpose(img.numpy(), (1, 2, 0)))
    
    # Add prediction info
    true_class = cifar10_classes[test_labels[i]]
    pred_class = cifar10_classes[predicted[i]]
    confidence = probabilities[i][predicted[i]].item()
    
    color = 'green' if predicted[i] == test_labels[i] else 'red'
    ax.set_title(f'True: {true_class}\nPred: {pred_class}\nConf: {confidence:.2f}', 
                color=color, fontsize=10)
    ax.axis('off')

plt.tight_layout()
plt.show()

# Calculate accuracy for this batch
correct = (predicted == test_labels).sum().item()
total = test_labels.size(0)
batch_accuracy = 100 * correct / total
print(f"Batch accuracy: {batch_accuracy:.2f}% ({correct}/{total})")
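Accuracy on eight images says little about which classes the model confuses. A confusion matrix makes that visible; the sketch below builds one with `torch.bincount` from predicted and true labels (in this notebook, for example, `predicted` and `test_labels`, or predictions collected over the full test loader). The helper names are hypothetical.

```python
import torch


def confusion_matrix(preds: torch.Tensor, targets: torch.Tensor, num_classes: int) -> torch.Tensor:
    """(num_classes, num_classes) counts; row = true class, column = predicted class."""
    flat_index = targets * num_classes + preds  # one bucket per (true, predicted) pair
    counts = torch.bincount(flat_index, minlength=num_classes * num_classes)
    return counts.view(num_classes, num_classes)


def per_class_accuracy(cm: torch.Tensor) -> torch.Tensor:
    """Diagonal divided by row sums; classes with no samples come out as NaN."""
    return cm.diag().float() / cm.sum(dim=1).float()


targets = torch.tensor([0, 0, 1, 1, 2, 2])
preds = torch.tensor([0, 1, 1, 1, 2, 0])
cm = confusion_matrix(preds, targets, num_classes=3)
print(cm)
# tensor([[1, 1, 0],
#         [0, 2, 0],
#         [1, 0, 1]])
print(per_class_accuracy(cm))  # tensor([0.5000, 1.0000, 0.5000])
```

For CIFAR-10 you would call it with `num_classes=10` and read the rows against `cifar10_classes`.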


## 8. Summary and Next Steps

Congratulations! You've successfully completed an end-to-end machine learning workflow using the PyTorch training framework.

### What we accomplished:

1. ✅ **Explored the framework structure** - Understanding the modular design
2. ✅ **Loaded and visualized data** - Working with CIFAR-10 dataset  
3. ✅ **Trained a model** - Using PyTorch Lightning for scalable training
4. ✅ **Evaluated performance** - Analyzing results with metrics and visualizations
5. ✅ **Made predictions** - Using the trained model for inference

### Next Steps:

1. **Experiment with hyperparameters**: Modify learning rate, batch size, architecture
2. **Try different optimizers**: SGD, AdamW, etc.
3. **Add data augmentation**: Improve model generalization
4. **Implement callbacks**: Early stopping, learning rate scheduling
5. **Scale to multiple GPUs**: Use Lightning's distributed training features
6. **Create custom models**: Add your own architectures to the model zoo
7. **Experiment tracking**: Integrate with MLflow, Weights & Biases
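For trying other optimizers and learning-rate scheduling, the `functools.partial` pattern from Section 4 makes swapping optimizers a one-line change. The sketch below builds AdamW and SGD partials and attaches a cosine learning-rate schedule in plain PyTorch; the hyperparameter values are illustrative, and how the framework wires schedulers (if it does) is not covered here.

```python
import functools
import math

import torch.nn as nn
import torch.optim as optim

# Alternative optimizer partials; each still needs the parameters at call time.
adamw_partial = functools.partial(optim.AdamW, lr=1e-3, weight_decay=0.01)
sgd_partial = functools.partial(optim.SGD, lr=0.1, momentum=0.9, nesterov=True)

toy_net = nn.Linear(8, 2)
toy_opt = adamw_partial(toy_net.parameters())
# Cosine annealing from the base lr down to eta_min over T_max scheduler steps.
toy_sched = optim.lr_scheduler.CosineAnnealingLR(toy_opt, T_max=10, eta_min=0.0)

for epoch in range(5):
    toy_opt.step()    # normally preceded by a forward/backward pass
    toy_sched.step()  # advance the schedule once per epoch

lr_now = toy_opt.param_groups[0]["lr"]
expected = 1e-3 * (1 + math.cos(math.pi * 5 / 10)) / 2  # closed-form cosine value
print(math.isclose(lr_now, expected, rel_tol=1e-9))  # True
```

Halfway through the schedule (5 of 10 steps) the learning rate sits at half its base value, which is what the closed-form check confirms.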

### Framework Commands:

You can also use the framework from the command line:

```bash
# Train with default config
tfh-train

# Train with specific experiment
tfh-train experiment=cifar_clf/cifar_clf_training

# Override specific parameters
tfh-train trainer.max_epochs=20 data_module.batch_size=64

# Evaluate a trained model
tfh-evaluate ckpt_path=/path/to/checkpoint.ckpt
```
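Overrides such as `trainer.max_epochs=20` address nested config keys with dot paths. The stdlib sketch below mimics that on a plain dict to show where a dotted override lands in the config tree. It is a simplified illustration (hypothetical `apply_overrides` helper, naive value parsing), not Hydra's override grammar, which also handles config-group selection such as `experiment=...`, `+key=` additions and more. Like Hydra, it rejects keys that do not already exist. The sample config layout mirrors the command above but is an assumption about the framework.

```python
import ast
import copy
from typing import Any


def _parse_value(text: str) -> Any:
    """Interpret '20' as int, '0.1' as float, 'true'/'false' as bool; else keep the string."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def apply_overrides(cfg: dict, overrides: list[str]) -> dict:
    """Return a copy of `cfg` with each 'a.b.c=value' override applied."""
    result = copy.deepcopy(cfg)
    for override in overrides:
        dotted_key, _, raw_value = override.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = result
        for key in parents:
            if key not in node:
                raise KeyError(f"unknown config key: {dotted_key}")
            node = node[key]
        if leaf not in node:
            raise KeyError(f"unknown config key: {dotted_key}")
        node[leaf] = _parse_value(raw_value)
    return result


base = {"trainer": {"max_epochs": 3, "accelerator": "auto"}, "data_module": {"batch_size": 32}}
cfg = apply_overrides(base, ["trainer.max_epochs=20", "data_module.batch_size=64"])
print(cfg)
# {'trainer': {'max_epochs': 20, 'accelerator': 'auto'}, 'data_module': {'batch_size': 64}}
```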


In [None]:
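# Final summary statistics
# The metric key depends on what the Lightning module logs; 'test_Accuracy' is assumed here.
test_accuracy = test_results[0].get('test_Accuracy')
accuracy_text = f"{test_accuracy:.4f}" if test_accuracy is not None else "N/A"
print("🎉 Tutorial Complete!")
print(f"📊 Final Test Accuracy: {accuracy_text}")
print(f"🏗️  Model Parameters: {trainable_params:,}")
print(f"📚 Dataset Size: {len(data_module.train_dataset):,} training samples")
print("⚡ Framework: PyTorch Lightning + Hydra")
print("🚀 Ready for production scaling!")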
