# Convolutional Neural Network (CNN) on MNIST

In this tutorial, we move from a simple **Multi-Layer Perceptron (MLP)** to a **Convolutional Neural Network (CNN)**, which is far more effective for image classification tasks.  
CNNs exploit the spatial structure of images by applying *convolutional filters* to detect edges, shapes, and patterns.

##### What is a CNN?

A **Convolutional Neural Network (CNN)** is designed to handle data with grid-like topology (e.g., images).  
Key building blocks:

- **Convolutional Layers**: Learn filters/kernels that slide over the input and extract local features (edges, corners, textures).
- **Activation Functions (ReLU)**: Introduce non-linearity, enabling the network to learn complex patterns.
- **Pooling Layers (MaxPooling)**: Downsample feature maps to reduce computation and improve generalization.
- **Fully Connected Layers**: Combine extracted features to perform classification.

##### Why CNNs for Images?

Unlike MLPs, which flatten images and ignore spatial locality, CNNs:
- Preserve the **2D structure** of the image.
- Share each filter's weights across all spatial positions of the image (parameter-efficient).
- Automatically learn hierarchical representations:
  - Lower layers → edges, blobs
  - Higher layers → shapes, digits, objects

##### Architecture in This Notebook

For the **MNIST handwritten digit dataset (28×28 grayscale)**:
1. **Conv Layer 1**: 1 input channel → 32 filters, kernel size = 3, padding = 1  
   → preserves 28×28 size.
2. **Conv Layer 2**: 32 → 64 filters, kernel size = 3, padding = 1  
   → followed by 2×2 MaxPooling → reduces size to 14×14.
3. **Fully Connected Layer 1**: Flattens into `64 × 14 × 14 = 12544` features → 128 neurons.
4. **Fully Connected Layer 2**: Outputs 10 logits (digits 0–9).

##### Training Objective

- **Loss Function**: Cross-Entropy Loss  
- **Optimizer**: Adam  
- **Evaluation Metric**: Accuracy  

CNNs generally achieve **98%+ accuracy on MNIST** with only a few epochs of training.

##### Steps in This Notebook

1. **Load MNIST dataset** with `torchvision`.
2. **Define CNN model** with 2 convolutional layers and 2 fully connected layers.
3. **Train the model** and monitor accuracy.
4. **Evaluate on test images**.
5. **Visualize predictions** for random samples.

# Code

### Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

### Load Data

In [2]:
def get_mnist_dataloaders(batch_size=64, root="../../data", shuffle_test=False):
    """Build the MNIST training and test ``DataLoader`` objects.

    Both splits are requested with ``download=True`` under ``root``, converted
    to tensors and normalized with mean 0.1307 and std 0.3081.

    Args:
        batch_size: Number of samples per batch, used for both loaders.
        root: Directory passed to ``datasets.MNIST`` as the dataset location.
        shuffle_test: Whether the test loader shuffles its samples. The
            training loader always shuffles.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    def make_loader(is_train, shuffle):
        # One split of MNIST wrapped in a loader with the shared batch size.
        split = datasets.MNIST(root, train=is_train, download=True, transform=preprocess)
        return DataLoader(split, batch_size=batch_size, shuffle=shuffle)

    return make_loader(True, True), make_loader(False, shuffle_test)

In [3]:
class SimpleCNN(nn.Module):
    """Small CNN: two 3x3 convolutions, one 2x2 max-pool, two linear layers.

    The first linear layer expects 64 * 14 * 14 features, which matches a
    single-channel 28x28 input (e.g. MNIST) after the pooling step.

    Args:
        num_classes: Number of output logits produced by the last layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Feature extractor: padding=1 with a 3x3 kernel keeps the spatial size.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        # 2x2 window with stride 2 halves height and width (28x28 -> 14x14).
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier head on the flattened feature maps.
        self.fc1 = nn.Linear(in_features=64 * 14 * 14, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        """Return the raw class logits for a batch of images ``x``."""
        features = F.relu(self.conv1(x))          # (N, 32, 28, 28)
        features = F.relu(self.conv2(features))   # (N, 64, 28, 28)
        features = self.pool(features)            # (N, 64, 14, 14)
        flat = features.view(features.shape[0], -1)  # (N, 12544)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


In [4]:
def accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring class equals the label.

    Args:
        outputs: Tensor of per-class scores; the class is taken along dim 1.
        labels: Tensor of target class indices, compared element-wise.

    Returns:
        A scalar float tensor holding the mean of the correct/incorrect mask.
    """
    predicted = outputs.max(dim=1).indices
    hits = predicted.eq(labels).float()
    return hits.mean()

In [5]:


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return sample-weighted (loss, accuracy).

    With an ``optimizer`` the model is switched to training mode and updated
    after every batch; without one it is switched to eval mode and gradient
    tracking is disabled.

    Args:
        model: Network to run.
        loader: Iterable yielding ``(data, target)`` batches.
        criterion: Loss callable; assumed to return the mean loss of the batch
            (as ``nn.CrossEntropyLoss()`` does by default).
        device: Device the batches are moved to.
        optimizer: Optimizer to step, or ``None`` for evaluation only.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over samples, or ``(0.0, 0.0)``
        if the loader yields nothing.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum, acc_sum, seen = 0.0, 0.0, 0
    with torch.set_grad_enabled(training):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, target)
            if training:
                loss.backward()
                optimizer.step()

            # Weight each batch by its size: the last batch is usually smaller,
            # so a plain mean of per-batch means would over-weight it.
            n = target.size(0)
            loss_sum += loss.item() * n
            acc_sum += accuracy(outputs, target).item() * n
            seen += n

    seen = max(seen, 1)  # avoid dividing by zero on an empty loader
    return loss_sum / seen, acc_sum / seen


def train(epochs=5, batch_size=64, lr=1e-3, device=None, save_path="simple_cnn.pth"):
    """Train a ``SimpleCNN`` on MNIST, print per-epoch metrics and save weights.

    Args:
        epochs: Number of passes over the training set.
        batch_size: Batch size for both the training and the test loader.
        lr: Learning rate for the Adam optimizer.
        device: Device to train on; defaults to CUDA when available, else CPU.
        save_path: File the final ``state_dict`` is written to.

    Returns:
        The trained model (left on ``device``).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    train_loader, test_loader = get_mnist_dataloaders(batch_size=batch_size)
    model = SimpleCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, epochs + 1):
        # Training metrics are accumulated while the weights are being updated.
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        # Evaluation on the test split after the epoch's updates.
        test_loss, test_acc = _run_epoch(model, test_loader, criterion, device)

        print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, Acc={train_acc:.4f} | "
              f"Test Loss={test_loss:.4f}, Acc={test_acc:.4f}")

    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

    return model


In [6]:
# Train from scratch; train() also writes the weights to simple_cnn.pth.
model = train(epochs=5, batch_size=64, lr=1e-3)

# Rebuild the network from the saved weights to check the checkpoint round-trips.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load("simple_cnn.pth", map_location=device)
model = SimpleCNN()
model.load_state_dict(checkpoint)
# Move to the target device and switch to inference mode.
model.to(device).eval()

Using device: cpu
Epoch 1: Train Loss=0.1189, Acc=0.9638 | Test Loss=0.0618, Acc=0.9790
Epoch 2: Train Loss=0.0388, Acc=0.9878 | Test Loss=0.0344, Acc=0.9881
Epoch 3: Train Loss=0.0227, Acc=0.9921 | Test Loss=0.0377, Acc=0.9883
Epoch 4: Train Loss=0.0161, Acc=0.9949 | Test Loss=0.0455, Acc=0.9885
Epoch 5: Train Loss=0.0118, Acc=0.9959 | Test Loss=0.0432, Acc=0.9892
Model saved to simple_cnn.pth


SimpleCNN(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=12544, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [17]:
# Number of test images to draw and plot.
num_samples = 5

# Shuffle the test split so each run shows a different random sample.
_, test_loader = get_mnist_dataloaders(batch_size=num_samples, shuffle_test=True)

# Fetch a single batch and report the shapes of exactly what gets plotted.
data, target = next(iter(test_loader))
print("Data shape:", data.shape)
print("Target shape:", target.shape)
data, target = data.to(device), target.to(device)

# Run prediction
with torch.no_grad():
    outputs = model(data)
    _, preds = torch.max(outputs, 1)

# Plot the samples with predicted + actual labels
fig = make_subplots(
    rows=1, cols=num_samples,
    subplot_titles=[f"True: {lbl.item()}<br>Pred: {pred.item()}"
                    for lbl, pred in zip(target, preds)]
)

for col, img in enumerate(data, start=1):
    # Drop the channel dimension and hand plotly a 2D array.
    pixels = img.squeeze().cpu().numpy()
    heatmap = go.Heatmap(
        z=pixels,
        colorscale="gray",
        showscale=False
    )
    fig.add_trace(heatmap, row=1, col=col)

# Fix orientation (avoid flipped digits): heatmaps draw row 0 at the bottom.
fig.update_yaxes(autorange="reversed", scaleanchor=None)

fig.update_layout(
    width=1000, height=250,
    margin=dict(l=10, r=10, b=10, t=90),
    title=dict(
        # The loader is shuffled, so these are random samples, not the first ones.
        text=f"MNIST Predictions ({num_samples} Random Test Images)",
        y=0.95,
        x=0.5,
        xanchor='center',
        yanchor='top'
    )
)

fig.show()

Data shape: torch.Size([5, 1, 28, 28])
Target shape: torch.Size([5])
