# MNIST Classification with PyTorch

This notebook demonstrates how to implement a neural network for MNIST digit classification using PyTorch. It follows a standard machine learning workflow:

1. Data loading and preprocessing
2. Model definition
3. Training and validation
4. Evaluation and visualization
5. Model saving and loading

## 1. Setup and Imports

First, we import the necessary libraries and set up our environment.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import time

# Check if CUDA is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Set random seed for reproducibility
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

## 2. Data Loading and Preprocessing

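We'll load the MNIST dataset using torchvision, normalize the images, and split the 60,000 training images 80/20 into training and validation sets. The 10,000-image test set is loaded separately and held out until the final evaluation.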

In [None]:
# Define transformations for the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize((0.1307,), (0.3081,))  # Normalize with mean and std of MNIST
])

# Load MNIST dataset directly from torchvision
train_full_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
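
As a quick check of what `transforms.Normalize((0.1307,), (0.3081,))` does: `ToTensor()` first scales pixel values to the range [0, 1], and `Normalize` then maps each value `x` to `(x - mean) / std`. The sketch below reproduces that arithmetic in plain Python. The helper name `normalize_pixel` is hypothetical and not part of torchvision.

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def normalize_pixel(x, mean=MNIST_MEAN, std=MNIST_STD):
    """Apply the same per-pixel formula as transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std

black = normalize_pixel(0.0)  # background pixel after ToTensor()
white = normalize_pixel(1.0)  # brightest pixel after ToTensor()
print(f"{black:.4f} {white:.4f}")
```

Background pixels therefore end up slightly below zero (about -0.42) and the brightest strokes near 2.82, which is why normalized images are no longer in the [0, 1] range.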

In [None]:
# Split the training dataset into training and validation sets
train_size = int(0.8 * len(train_full_dataset))
val_size = len(train_full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_full_dataset, [train_size, val_size])

# Create data loaders for batch processing
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"Training dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")
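
The split above relies on the global seed set earlier, so it changes if any other random call runs before it. One possible way to make the split itself repeatable is to pass a dedicated `torch.Generator` to `random_split`. The sketch below shows this on a small synthetic dataset (an assumption, standing in for MNIST so that nothing needs to be downloaded).

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Synthetic stand-in for MNIST: 100 fake "images" with integer labels.
dataset = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# A dedicated generator makes the split repeatable regardless of how many
# random numbers were drawn from the global generator beforehand.
split_a = random_split(dataset, [train_size, val_size],
                       generator=torch.Generator().manual_seed(42))
split_b = random_split(dataset, [train_size, val_size],
                       generator=torch.Generator().manual_seed(42))

train_subset, val_subset = split_a
print(len(train_subset), len(val_subset))
```

Both calls produce the same index assignment, and the two subsets never share a sample.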

## 3. Data Visualization

Let's create a function to visualize some examples from our dataset.

In [None]:
# Function to display sample images
def plot_example(images, labels, num_samples=5):
    """Plot a selection of images and their labels"""
    plt.figure(figsize=(12, 3))
    for i in range(num_samples):
        plt.subplot(1, num_samples, i + 1)
        
        # If input is a tensor, convert to numpy and reshape
        if isinstance(images, torch.Tensor):
            img = images[i].cpu().numpy()
            if img.shape[0] == 1:  # If it's in format [1, 28, 28]
                img = img.reshape(28, 28)
        else:
            img = images[i].reshape(28, 28)
            
        plt.imshow(img, cmap='gray')
        plt.xticks([])
        plt.yticks([])
        plt.title(f"Label: {labels[i]}")
    plt.tight_layout()
    plt.show()

In [None]:
# Display a few training examples
dataiter = iter(train_loader)
images, labels = next(dataiter)
plot_example(images, labels)

## 4. Model Definition

Now we'll define our neural network model for MNIST classification.

In [None]:
# Define the neural network model
class MNISTClassifier(nn.Module):
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        # Input layer: 28x28 = 784 input features
        # First hidden layer: 256 neurons with ReLU activation and dropout
        self.layer1 = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.ReLU(),
            nn.Dropout(0.4)
        )
        # Second hidden layer: 128 neurons with ReLU activation and dropout
        self.layer2 = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        # Output layer: 10 neurons (one for each digit 0-9)
        self.output = nn.Linear(128, 10)
    
    def forward(self, x):
        # Reshape input from [batch_size, 1, 28, 28] to [batch_size, 784]
        x = x.view(-1, 28*28)
        # Pass through first hidden layer
        x = self.layer1(x)
        # Pass through second hidden layer
        x = self.layer2(x)
        # Pass through output layer
        logits = self.output(x)
        return logits

#### Simpler alternative using only `nn.Sequential`

In [None]:
# Define the neural network model using nn.Sequential
# model = nn.Sequential(
#     nn.Flatten(),  # Reshape input from [batch_size, 1, 28, 28] to [batch_size, 784]
#     nn.Linear(28*28, 256),
#     nn.ReLU(),
#     nn.Dropout(0.4),
#     nn.Linear(256, 128),
#     nn.ReLU(),
#     nn.Dropout(0.3),
#     nn.Linear(128, 10)
# )
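
Either definition can be sanity-checked by pushing a dummy batch through it and inspecting the output shape. The sketch below uses the `nn.Sequential` form so that it is self-contained. The batch of four blank images is an assumption made purely for the check.

```python
import torch
import torch.nn as nn

# Same layer sizes as the classifier above.
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 256), nn.ReLU(), nn.Dropout(0.4),
    nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
    nn.Linear(128, 10),
)
net.eval()  # disable dropout so the check is deterministic

dummy = torch.zeros(4, 1, 28, 28)  # hypothetical batch of 4 blank images
with torch.no_grad():
    logits = net(dummy)                   # raw scores, one row per image
    probs = torch.softmax(logits, dim=1)  # probabilities, only needed for inspection

print(logits.shape)
```

The model returns raw logits because `nn.CrossEntropyLoss` applies the log-softmax internally. Softmax is monotonic, so taking the argmax of the logits gives the same predicted class as taking the argmax of the probabilities.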

In [None]:
# Instantiate the model and move it to the device
model = MNISTClassifier().to(device)
print(model)

# Calculate the total number of trainable parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total number of trainable model parameters: {num_params}")

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
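
The parameter count printed above can be checked by hand: each `nn.Linear(n_in, n_out)` layer holds `n_in * n_out` weights plus `n_out` biases, while `ReLU` and `Dropout` have no parameters. A short sketch of that arithmetic, assuming the three layer sizes from the model definition:

```python
# (inputs, outputs) for each nn.Linear layer in the model above
layer_shapes = [(28 * 28, 256), (256, 128), (128, 10)]

# Each linear layer has n_in * n_out weights plus one bias per output.
per_layer = [n_in * n_out + n_out for n_in, n_out in layer_shapes]
total = sum(per_layer)

print(per_layer)
print(total)
```

Almost all of the parameters sit in the first layer, which is typical for fully connected networks on flattened images.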

#### Quick explanation of Momentum in SGD

Momentum in SGD adds a "velocity" component to parameter updates, similar to how 
a ball rolling down a hill accumulates momentum:

$$v_t = \mu \, v_{t-1} + \eta \, \nabla J(\theta) \quad \text{(velocity update)}$$

$$\theta_t = \theta_{t-1} - v_t \quad \text{(parameter update)}$$

Where:
- μ is the momentum coefficient (typically 0.9)
- η is the learning rate
- ∇J(θ) is the gradient of the loss function

Example: picture a ball rolling down a valley. Without momentum (plain SGD), each step follows only the current gradient. With momentum, the ball keeps part of its velocity from earlier steps, which lets it:
- Roll through small bumps (shallow local minima)
- Move faster along directions where the gradient stays consistent
- Dampen oscillations in narrow valleys

In practice this often makes training faster and more stable, particularly on complex loss landscapes.
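
The update rule above can be traced in a few lines of plain Python. The sketch below is a toy illustration, assuming the one-dimensional loss J(θ) = θ² (not part of the original notebook). It also runs the slightly different form documented for `torch.optim.SGD`, where the velocity accumulates raw gradients and the learning rate is applied at the parameter step. With a constant learning rate the two forms give the same trajectory.

```python
def grad(theta):
    """Gradient of the toy loss J(theta) = theta**2."""
    return 2.0 * theta

mu, eta, steps = 0.9, 0.01, 200

# Form used in the text: the learning rate is folded into the velocity.
theta_text, v = 1.0, 0.0
# Form documented for torch.optim.SGD: the velocity accumulates raw gradients
# and the learning rate is applied when the parameter is updated.
theta_torch, b = 1.0, 0.0

for _ in range(steps):
    v = mu * v + eta * grad(theta_text)
    theta_text -= v

    b = mu * b + grad(theta_torch)
    theta_torch -= eta * b

print(theta_text, theta_torch)
```

Both parameters move from 1.0 toward the minimum at 0 and agree up to floating-point error, since `v` always equals `eta * b`.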


## 5. Evaluation Metrics

Let's define a function to compute accuracy and loss metrics.

In [None]:
# Function to compute accuracy and loss
def compute_metrics(model, dataloader, criterion=None, device='cpu', calculate_loss=True, calculate_accuracy=True):
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0
    running_loss = 0.0
    
    with torch.no_grad():
        for data, labels in dataloader:
            data, labels = data.to(device), labels.to(device)
            outputs = model(data)
            
            # Calculate loss only if requested and criterion is provided
            if calculate_loss and criterion is not None:
                loss = criterion(outputs, labels)
                running_loss += loss.item()
            
            # Calculate accuracy only if requested
            if calculate_accuracy:
                predicted = outputs.argmax(dim=1)  # index of the largest logit per sample
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    
    # Calculate average loss and accuracy based on what was requested
    avg_loss = running_loss / len(dataloader) if calculate_loss and criterion is not None else None
    accuracy = correct / total if calculate_accuracy else None
    
    return accuracy, avg_loss
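
One detail worth knowing about `running_loss / len(dataloader)`: `nn.CrossEntropyLoss` returns the mean loss of each batch by default, so this expression averages batch means and gives a smaller final batch the same weight as a full one. The sketch below uses made-up batch losses and sizes (assumptions, exaggerated for clarity) to show how that differs from a true per-sample average.

```python
# Hypothetical per-batch mean losses and batch sizes; the last batch is smaller.
batch_mean_losses = [1.0, 1.0, 4.0]
batch_sizes = [64, 64, 2]

# What `running_loss / len(dataloader)` computes: every batch counts equally.
mean_of_batch_means = sum(batch_mean_losses) / len(batch_mean_losses)

# Per-sample average: weight each batch mean by the number of samples in it.
total_samples = sum(batch_sizes)
per_sample_mean = sum(
    loss * n for loss, n in zip(batch_mean_losses, batch_sizes)
) / total_samples

print(mean_of_batch_means, per_sample_mean)
```

With the dataset sizes in this notebook only the last batch is short, so the discrepancy is likely negligible. If exact per-sample losses matter, accumulating `loss.item() * labels.size(0)` and dividing by the sample count would be one option.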

## 6. Training Function

Now we'll define the training function that will train our model over multiple epochs.

In [None]:
# Training function
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    # Lists to store metrics for plotting
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []
    
    # Start timer
    start_time = time.time()
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()  # Set model to training mode
        running_loss = 0.0
        
        for batch_idx, (data, labels) in enumerate(train_loader):
            # Move data to device
            data, labels = data.to(device), labels.to(device)
            
            # Zero the gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(data)
            loss = criterion(outputs, labels)
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
            
            # Update statistics
            running_loss += loss.item()
            
            # Print progress every 100 batches
            if (batch_idx + 1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], "
                      f"Batch [{batch_idx+1}/{len(train_loader)}], "
                      f"Loss: {loss.item():.4f}")
        
        # Calculate average loss for the epoch
        epoch_loss = running_loss / len(train_loader)
        train_losses.append(epoch_loss)
        
        # Calculate training accuracy in eval mode (dropout off); the loss is skipped
        # because the running average above already provides it
        train_accuracy, _ = compute_metrics(model, train_loader, device=device, calculate_loss=False)
        train_accuracies.append(train_accuracy)
        
        # Validation phase
        model.eval()  # Set model to evaluation mode
        val_accuracy, val_loss = compute_metrics(model, val_loader, criterion, device, calculate_loss=True, calculate_accuracy=True)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)
        
        # Print epoch statistics
        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {epoch_loss:.4f}, "
              f"Train Accuracy: {train_accuracy:.4f}, "
              f"Val Loss: {val_loss:.4f}, "
              f"Val Accuracy: {val_accuracy:.4f}")
    
    # Calculate total training time
    total_time = time.time() - start_time
    print(f"Training completed in {total_time:.2f} seconds")
    
    return train_losses, train_accuracies, val_losses, val_accuracies
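
The calls to `model.train()` and `model.eval()` matter here because the model contains dropout layers. The training loss is accumulated with dropout active, whereas `compute_metrics` evaluates with dropout disabled, so the reported training loss can sit above the validation loss early on. A minimal sketch of the two modes, using a standalone `nn.Dropout` layer and an all-ones input (both assumptions chosen for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.4)
x = torch.ones(1000)

drop.train()
y_train = drop(x)  # about 40% of entries zeroed, survivors scaled by 1 / (1 - p)

drop.eval()
y_eval = drop(x)   # identity: dropout is disabled in evaluation mode

kept = y_train[y_train != 0]
print(y_train.unique(), torch.equal(y_eval, x))
```

In training mode the surviving activations are scaled up so that the expected value stays the same. In evaluation mode the layer passes its input through unchanged.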

## 7. Model Training

Let's train our model and visualize the training progress.

In [None]:
# Train the model
print("Starting training...")
train_losses, train_accuracies, val_losses, val_accuracies = train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs=10
)

In [None]:
# Plot training metrics
plt.figure(figsize=(12, 8))

# Plot loss
plt.subplot(2, 1, 1)
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.title('Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Plot accuracy
plt.subplot(2, 1, 2)
plt.plot(train_accuracies, label='Training Accuracy')
plt.plot(val_accuracies, label='Validation Accuracy')
plt.title('Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

## 8. Model Evaluation

Now let's evaluate our model on the test set.

In [None]:
# Function to evaluate model on test set
def evaluate_on_test_set(model, test_loader, criterion, device):
    print("\n=== Final Evaluation on Test Set ===")
    print("Note: Test set has not been used during training or model selection")
    
    # Ensure model is in evaluation mode
    model.eval()
    
    # Compute metrics on test set
    test_accuracy, test_loss = compute_metrics(model, test_loader, criterion, device, calculate_loss=True, calculate_accuracy=True)
    
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")
    print("==================================\n")
    
    return test_accuracy, test_loss

# Evaluate the model on the test set (only at the end, after all training is complete)
test_accuracy, test_loss = evaluate_on_test_set(model, test_loader, criterion, device)

## 9. Visualizing Model Predictions

Let's visualize some of the model's predictions on the test set.

In [None]:
# Function to visualize model predictions
def visualize_predictions(model, dataloader, device, num_samples=5):
    model.eval()
    dataiter = iter(dataloader)
    images, labels = next(dataiter)
    
    with torch.no_grad():
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
    
    # Move tensors back to CPU for visualization
    images = images.cpu()
    labels = labels.cpu()
    predicted = predicted.cpu()
    
    # Plot images with true and predicted labels
    plt.figure(figsize=(12, 3))
    for i in range(num_samples):
        plt.subplot(1, num_samples, i + 1)
        img = images[i][0].numpy()
        plt.imshow(img, cmap='gray')
        plt.xticks([])
        plt.yticks([])
        
        # Green if correct, red if wrong
        color = 'green' if predicted[i] == labels[i] else 'red'
        plt.title(f"True: {labels[i]}\nPred: {predicted[i]}", color=color)
    
    plt.tight_layout()
    plt.show()

# Visualize some predictions
visualize_predictions(model, test_loader, device)

## 10. Analyzing Misclassifications

Let's find and visualize some examples that the model misclassified.

In [None]:
# Find and visualize misclassified examples
def visualize_errors(model, dataloader, device, num_errors=5):
    model.eval()
    errors_images = []
    errors_labels = []
    errors_preds = []
    
    with torch.no_grad():
        for data, labels in dataloader:
            data, labels = data.to(device), labels.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs, 1)
            
            # Find indices where predictions are wrong
            error_indices = (predicted != labels).nonzero(as_tuple=True)[0]
            
            for idx in error_indices:
                errors_images.append(data[idx].cpu())
                errors_labels.append(labels[idx].item())
                errors_preds.append(predicted[idx].item())
                
                if len(errors_images) >= num_errors:
                    break
            
            if len(errors_images) >= num_errors:
                break
    
    # Plot the errors
    plt.figure(figsize=(12, 3))
    for i in range(min(num_errors, len(errors_images))):
        plt.subplot(1, num_errors, i + 1)
        img = errors_images[i][0].numpy()
        plt.imshow(img, cmap='gray')
        plt.xticks([])
        plt.yticks([])
        plt.title(f"True: {errors_labels[i]}\nPred: {errors_preds[i]}", color='red')
    
    plt.tight_layout()
    plt.show()

# Visualize some misclassified examples
visualize_errors(model, test_loader, device)
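
Beyond looking at individual images, it can help to count which digit pairs are confused most often. The sketch below does this with `collections.Counter` on a few made-up (true, predicted) pairs. The label lists are hypothetical and stand in for values like those collected in `errors_labels` and `errors_preds`.

```python
from collections import Counter

# Hypothetical (true, predicted) labels for six misclassified images.
true_labels = [4, 9, 4, 7, 3, 4]
pred_labels = [9, 4, 9, 1, 5, 9]

# Count how often each (true, predicted) pair occurs.
confusions = Counter(zip(true_labels, pred_labels))
for (true, pred), count in confusions.most_common(3):
    print(f"true {true} -> predicted {pred}: {count}")
```

Run over all test errors rather than the first five, a tally like this would point to the digit pairs the model separates least well.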

## 11. Saving and Loading the Model

Finally, let's save our trained model and demonstrate how to load it back.

In [None]:
# Save the trained model
torch.save(model.state_dict(), "mnist_model.pth")
print("Model saved to mnist_model.pth")

In [None]:
# Example of how to load the model
def load_model():
    loaded_model = MNISTClassifier().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    loaded_model.load_state_dict(torch.load("mnist_model.pth", map_location=device))
    loaded_model.eval()
    return loaded_model

# Load the saved model and verify it works
loaded_model = load_model()
loaded_accuracy, loaded_loss = compute_metrics(
    loaded_model, test_loader, criterion, device,
    calculate_loss=True, calculate_accuracy=True
)

print(f"Loaded model test loss: {loaded_loss:.4f}")
print(f"Loaded model test accuracy: {loaded_accuracy:.4f}")

# Verify this matches the original model's accuracy
print(f"Original model test loss: {test_loss:.4f}")
print(f"Original model test accuracy: {test_accuracy:.4f}")
print(f"Accuracy difference: {abs(loaded_accuracy - test_accuracy):.8f}")
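
The same save-and-restore round trip can be tried without touching the disk. The sketch below is a minimal illustration under stated assumptions: a tiny `nn.Linear` layer stands in for `MNISTClassifier`, and an in-memory `io.BytesIO` buffer replaces the `mnist_model.pth` file.

```python
import io
import torch
import torch.nn as nn

torch.manual_seed(0)
original = nn.Linear(4, 2)  # tiny stand-in for MNISTClassifier

buffer = io.BytesIO()       # in-memory file instead of "mnist_model.pth"
torch.save(original.state_dict(), buffer)
buffer.seek(0)

restored = nn.Linear(4, 2)  # fresh instance with different random weights
state = torch.load(buffer, map_location="cpu")  # remap tensors to the CPU on load
restored.load_state_dict(state)

x = torch.randn(3, 4)
print(torch.equal(original(x), restored(x)))
```

Because a `state_dict` stores only tensors keyed by parameter name, the model class must be instantiated first, and its architecture has to match the one that was saved.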

### Exercise

Write a simple MNIST classifier using PyTorch:
1. Load only images of digits 0 and 8 from the MNIST dataset. Hint: use `torch.utils.data.Subset` to filter the dataset.
2. Create a simple neural network with one hidden layer (input -> 128 neurons -> output).
3. Train for 5 epochs and print the accuracy on the test set.
4. Create a confusion matrix (you can use scikit-learn with seaborn or matplotlib) to visualize the model's performance.

Your code should:
- Filter the MNIST dataset to keep only digits 0 and 8
- Use nn.Linear layers to build your model
- Include proper training and evaluation loops
- Display the confusion matrix as a heatmap