# PyTorch Lightning Tutorial

This notebook demonstrates how to use PyTorch Lightning to simplify deep learning workflows.

In [None]:
# Install PyTorch Lightning if not already installed (%pip targets this kernel's environment)
# %pip install pytorch-lightning

In [None]:
import os
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    TQDMProgressBar,
)
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, random_split

# Seed every RNG Lightning knows about so the whole notebook is repeatable
pl.seed_everything(42)

# Record the library versions and accelerator this run used
for label, value in (
    ("PyTorch Lightning version", pl.__version__),
    ("PyTorch version", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
):
    print(f"{label}: {value}")

## 1. Introduction to PyTorch Lightning

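PyTorch Lightning is a lightweight wrapper around PyTorch that:
- Eliminates boilerplate code
- Runs the optimization loop for you (backward pass, optimizer steps, zeroing gradients)
- Enables easy distributed training
- Integrates logging and checkpointing
- Makes runs easier to reproduce (global seeding, saved hyperparameters)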

## 2. Lightning Module

The LightningModule organizes PyTorch code into a standard structure.

In [None]:
class LitMNISTClassifier(pl.LightningModule):
    """A simple two-conv-layer CNN for MNIST classification using PyTorch Lightning.

    Logs ``train_loss``/``train_acc`` (per step and per epoch),
    ``val_loss``/``val_acc`` and ``test_loss``/``test_acc``.

    Args:
        learning_rate: Adam learning rate. Stored in ``self.hparams`` by
            ``save_hyperparameters`` so ``load_from_checkpoint`` can rebuild
            the model without arguments.
    """

    def __init__(self, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        # Assumes 1x28x28 inputs: two 2x poolings leave 64 feature maps of 7x7
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.25)
        # NOTE: `pl.metrics.Accuracy` was removed from Lightning (moved to the
        # separate torchmetrics package); accuracy is computed directly in
        # `_shared_step` instead, matching the rest of this notebook.

    def forward(self, x):
        """Map a batch of images to 10 unnormalized class logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

    def _shared_step(self, batch):
        """Return ``(loss, accuracy)`` for one ``(images, labels)`` batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (torch.argmax(logits, dim=1) == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Training step - called for each batch; returns the loss to optimize."""
        loss, acc = self._shared_step(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step - logs loss and accuracy."""
        loss, acc = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Test step - logs loss and accuracy."""
        loss, acc = self._shared_step(batch)
        self.log('test_loss', loss)
        self.log('test_acc', acc)

    def configure_optimizers(self):
        """Plain Adam at the saved learning rate; no scheduler."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

In [None]:
# Instantiate the classifier, then report its size and the stored hyperparameters
model = LitMNISTClassifier(learning_rate=1e-3)
n_params = sum(param.numel() for param in model.parameters())
print(f"Model created with {n_params:,} parameters")
print(f"Hyperparameters: {model.hparams}")

## 3. Data Module

DataModules encapsulate all data loading logic in a reusable class.

In [None]:
class MNISTDataModule(pl.LightningDataModule):
    """PyTorch Lightning DataModule for MNIST.

    Args:
        data_dir: Directory torchvision downloads MNIST into / reads it from.
        batch_size: Batch size used by every dataloader.
        num_workers: DataLoader worker processes (0 = load in the main process).
        val_size: Number of MNIST training images held out for validation.
        seed: Seed for the train/val split, so every ``setup`` call (each
            ``Trainer.fit`` re-runs it) produces the same split.
    """

    def __init__(self, data_dir='./data', batch_size=64, num_workers=2, val_size=5000, seed=42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_size = val_size
        self.seed = seed

        # ToTensor -> [0, 1] floats; Normalize with (presumably) the usual
        # MNIST training-set mean/std
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        """Download MNIST if needed.

        Lightning runs this from a single process, so it only downloads and
        does not assign any state on ``self``.
        """
        torchvision.datasets.MNIST(self.data_dir, train=True, download=True)
        torchvision.datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Build datasets for ``stage`` ('fit', 'validate', 'test', or None for all)."""
        # trainer.validate() calls setup('validate'), which also needs mnist_val
        if stage in ('fit', 'validate', None):
            mnist_full = torchvision.datasets.MNIST(
                self.data_dir, train=True, transform=self.transform
            )
            n_train = len(mnist_full) - self.val_size
            # A dedicated seeded generator makes the split independent of how far
            # the global RNG has advanced, so repeated setup() calls agree.
            generator = torch.Generator().manual_seed(self.seed)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [n_train, self.val_size], generator=generator
            )

        if stage in ('test', None):
            self.mnist_test = torchvision.datasets.MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader with this module's batch size and workers."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.mnist_train, shuffle=True)

    def val_dataloader(self):
        """Deterministic-order loader over the validation split."""
        return self._make_loader(self.mnist_val, shuffle=False)

    def test_dataloader(self):
        """Deterministic-order loader over the MNIST test set."""
        return self._make_loader(self.mnist_test, shuffle=False)

In [None]:
# Instantiate the data module, download MNIST and build the train/val split
data_module = MNISTDataModule(batch_size=64, num_workers=0)
data_module.prepare_data()
data_module.setup('fit')

for split_name, split in (("Train", data_module.mnist_train), ("Val", data_module.mnist_val)):
    print(f"{split_name} samples: {len(split)}")

# Pull one training batch and plot its first ten digits with their labels
train_loader = data_module.train_dataloader()
batch = next(iter(train_loader))
images, labels = batch

fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for ax, image, label in zip(axes.flat, images, labels):
    ax.imshow(image.squeeze(), cmap='gray')
    ax.set_title(f'Label: {label}')
    ax.axis('off')
plt.tight_layout()
plt.show()

## 4. Basic Training

Training with PyTorch Lightning is as simple as creating a Trainer and calling fit().

In [None]:
# Fresh model and data module for a clean training run
model = LitMNISTClassifier(learning_rate=1e-3)
data_module = MNISTDataModule(batch_size=64, num_workers=0)

# `gpus=` and `progress_bar_refresh_rate=` were removed from pl.Trainer;
# select the device with accelerator/devices and throttle the progress bar
# through a TQDMProgressBar callback instead.
trainer = pl.Trainer(
    max_epochs=5,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
)

trainer.fit(model, datamodule=data_module)

In [None]:
# Evaluate the freshly trained weights on the held-out MNIST test set
test_results = trainer.test(model, datamodule=data_module)
test_acc = test_results[0]['test_acc']
print(f"Test accuracy: {test_acc:.4f}")

## 5. Callbacks

Callbacks allow you to add functionality at various points during training.

In [None]:
# Keep the three checkpoints with the lowest validation loss
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='checkpoints/',
    filename='mnist-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min'
)

# Stop once val_loss has not improved for 3 consecutive validation runs
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    mode='min'
)

lr_monitor = LearningRateMonitor(logging_interval='step')

# `gpus=` / `progress_bar_refresh_rate=` were removed from pl.Trainer; use
# accelerator/devices and a TQDMProgressBar callback instead.
trainer = pl.Trainer(
    max_epochs=10,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    callbacks=[
        checkpoint_callback,
        early_stop_callback,
        lr_monitor,
        TQDMProgressBar(refresh_rate=20),
    ],
)

# Train a new model from scratch with the callbacks attached
model = LitMNISTClassifier(learning_rate=1e-3)
trainer.fit(model, datamodule=data_module)

In [None]:
# The checkpoint callback tracked the lowest val_loss; locate that snapshot
best_model_path = checkpoint_callback.best_model_path
print(f"Best model saved at: {best_model_path}")

# Rebuild the module from the checkpoint's saved hyperparameters and weights,
# then evaluate it on the test split
best_model = LitMNISTClassifier.load_from_checkpoint(best_model_path)
test_results = trainer.test(best_model, datamodule=data_module)
best_acc = test_results[0]['test_acc']
print(f"Best model test accuracy: {best_acc:.4f}")

## 6. Advanced Lightning Module

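Here's a more advanced example that adds batch normalization, AdamW with weight decay, and a `ReduceLROnPlateau` learning-rate scheduler; the next cell trains it with mixed precision, gradient clipping, and gradient accumulation.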

In [None]:
class AdvancedLitModel(pl.LightningModule):
    """Deeper MNIST CNN with batch norm, AdamW and a plateau-based LR scheduler.

    Args:
        learning_rate: Initial AdamW learning rate (saved in ``self.hparams``).
    """

    def __init__(self, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        # Two conv blocks, each ending in a 2x max-pool; the classifier's
        # 64 * 7 * 7 input assumes 28x28 images (28 -> 14 -> 7)
        self.features = nn.Sequential(
            # First block
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            # Second block
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

        self.classifier = nn.Sequential(
            nn.Linear(64 * 7 * 7, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )

        self.criterion = nn.CrossEntropyLoss()

        # Per-batch validation results for the epoch-end summary. Lightning 2.0
        # removed `validation_epoch_end(outputs)`, so the module collects the
        # outputs itself and reports them from `on_validation_epoch_end`.
        self._val_outputs = []

    def forward(self, x):
        """Return 10 class logits per input image."""
        features = self.features(x)
        features = torch.flatten(features, 1)
        return self.classifier(features)

    def training_step(self, batch, batch_idx):
        """Log epoch-level train loss/accuracy and return the loss."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        self.log_dict({
            'train_loss': loss,
            'train_acc': acc
        }, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log val loss/accuracy and buffer them for the epoch summary."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        # 'val_loss' is also the metric the LR scheduler monitors
        self.log_dict({
            'val_loss': loss,
            'val_acc': acc
        }, prog_bar=True)

        # Keep detached values plus the batch size so the epoch average can
        # weight a smaller final batch correctly
        self._val_outputs.append({
            'val_loss': loss.detach(),
            'val_acc': acc.detach(),
            'n': y.size(0),
        })

        return {'val_loss': loss, 'val_acc': acc}

    def on_validation_epoch_end(self):
        """Print the sample-weighted mean validation loss/accuracy, then reset the buffer."""
        if not self._val_outputs:
            return
        total = sum(out['n'] for out in self._val_outputs)
        avg_loss = sum(out['val_loss'].item() * out['n'] for out in self._val_outputs) / total
        avg_acc = sum(out['val_acc'].item() * out['n'] for out in self._val_outputs) / total
        self._val_outputs.clear()

        print(f"\nValidation - Loss: {avg_loss:.4f}, Acc: {avg_acc:.4f}")

    def configure_optimizers(self):
        """AdamW with weight decay, halving the LR when val_loss plateaus for 2 epochs."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=1e-4
        )

        # `verbose` is deprecated in recent PyTorch schedulers; LR changes can be
        # tracked with a LearningRateMonitor callback instead.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=2
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
                'interval': 'epoch',
                'frequency': 1
            }
        }

In [None]:
# Create advanced model
advanced_model = AdvancedLitModel(learning_rate=1e-3)
print(f"Advanced model parameters: {sum(p.numel() for p in advanced_model.parameters()):,}")

use_cuda = torch.cuda.is_available()

# Train with advanced features. `gpus=` / `progress_bar_refresh_rate=` were
# removed from pl.Trainer, and fp16 mixed precision is a CUDA feature that
# Lightning does not support on CPU, so fall back to 32-bit there.
trainer = pl.Trainer(
    max_epochs=10,
    accelerator='gpu' if use_cuda else 'cpu',
    devices=1,
    precision=16 if use_cuda else 32,  # Mixed precision training (GPU only)
    gradient_clip_val=1.0,  # Gradient clipping
    accumulate_grad_batches=2,  # Gradient accumulation
    callbacks=[TQDMProgressBar(refresh_rate=20)],
)

trainer.fit(advanced_model, datamodule=data_module)

## 7. Logging with TensorBoard

In [None]:
# TensorBoard logger: runs are grouped under save_dir using the experiment name
logger_config = {'save_dir': 'lightning_logs', 'name': 'mnist_experiment'}
tb_logger = TensorBoardLogger(**logger_config)

# Create model with custom logging
class LoggingLitModel(LitMNISTClassifier):
    """LitMNISTClassifier that also sends sample images and predictions to TensorBoard."""

    def training_step(self, batch, batch_idx):
        """Log loss/accuracy every step and, every 100 batches, an image grid and predictions."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        self.log('train_loss', loss)
        self.log('train_acc', acc)

        if batch_idx % 100 == 0:
            # add_image/add_text are TensorBoard SummaryWriter methods; skip the
            # extra logging when no logger (or a logger without them) is attached.
            writer = getattr(self.logger, 'experiment', None)
            if hasattr(writer, 'add_image'):
                # Inputs are mean/std-normalized (values outside [0, 1]);
                # normalize=True rescales the grid into [0, 1] for display.
                grid = torchvision.utils.make_grid(x[:8], normalize=True)
                writer.add_image('train_images', grid, self.global_step)

                writer.add_text(
                    'predictions',
                    f'True: {y[:8].tolist()}, Pred: {preds[:8].tolist()}',
                    self.global_step
                )

        return loss

# Train the image-logging variant with the TensorBoard logger defined above.
# `gpus=` / `progress_bar_refresh_rate=` were removed from pl.Trainer; use
# accelerator/devices and a TQDMProgressBar callback instead.
logging_model = LoggingLitModel()
trainer = pl.Trainer(
    max_epochs=3,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    logger=tb_logger,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
)

trainer.fit(logging_model, datamodule=data_module)

print(f"\nTensorBoard logs saved to: {tb_logger.log_dir}")
print("Run 'tensorboard --logdir=lightning_logs' to view logs")

## 8. Model Export and Inference

In [None]:
# Export the model trained in section 5 (the last value bound to `model`) to
# TorchScript. Move it to the CPU first so the traced graph and the saved file
# are device-independent and match the CPU example inputs below.
export_model = model.cpu().eval()
example_input = torch.randn(1, 1, 28, 28)
traced_model = export_model.to_torchscript(method='trace', example_inputs=example_input)

# Save traced model
torch.jit.save(traced_model, 'mnist_lightning_model.pt')
print("Model exported to TorchScript")

# Reload as a deployment target would, pinning tensors to the CPU
loaded_model = torch.jit.load('mnist_lightning_model.pt', map_location='cpu')
loaded_model.eval()

# Smoke-test inference. Random noise only exercises the pipeline (the predicted
# class is meaningless), but the exported graph must agree with the eager model.
with torch.no_grad():
    test_input = torch.randn(1, 1, 28, 28)
    output = loaded_model(test_input)
    torch.testing.assert_close(output, export_model(test_input))
    prediction = torch.argmax(output, dim=1)
    print(f"Inference test - Output shape: {output.shape}, Prediction: {prediction.item()}")

## 9. Best Practices Summary

In [None]:
# Condensed checklist of the practices demonstrated in this notebook
best_practices = """
# PyTorch Lightning Best Practices

## 1. Code Organization
- Keep model logic in LightningModule
- Use DataModules for data handling
- Separate training, validation, and test logic

## 2. Reproducibility
- Use pl.seed_everything()
- Save hyperparameters with save_hyperparameters()
- Version control your experiments

## 3. Monitoring
- Use self.log() for metrics
- Integrate with TensorBoard, W&B, etc.
- Monitor hardware utilization

## 4. Performance
- Use mixed precision training (precision=16)
- Enable gradient accumulation for large batches
- Profile your code to find bottlenecks

## 5. Distributed Training
- Start with DDP for multi-GPU
- Test on single GPU first
- Use appropriate batch sizes

## 6. Production
- Export to TorchScript for deployment
- Use ModelCheckpoint for saving best models
- Implement proper error handling
"""

summary_file = 'lightning_best_practices.md'

# Echo the checklist, then persist it next to the notebook
print(best_practices)
with open(summary_file, 'w') as out:
    out.write(best_practices)
print(f"\nBest practices saved to '{summary_file}'")

## Summary

In this notebook, we've covered:

1. **LightningModule**: Organizing PyTorch code into a standard structure
2. **DataModule**: Encapsulating data loading logic
3. **Trainer**: Automating the training loop with minimal code
4. **Callbacks**: Adding functionality like checkpointing and early stopping
5. **Logging**: Integrating with TensorBoard for experiment tracking
6. **Advanced Features**: Mixed precision, gradient accumulation, and more
7. **Export**: Converting models for production deployment

PyTorch Lightning simplifies deep learning workflows while maintaining flexibility and performance!