# Simple Neural Network (MLP) on MNIST

In this tutorial, we implement the **simplest feedforward neural network** (also known as a *Multi-Layer Perceptron, MLP*) using **PyTorch**.  
This acts as the **"Hello World"** of deep learning and provides the foundation to understand more complex architectures (CNNs, RNNs, Transformers, etc.).

##### What is an MLP?

A **Multi-Layer Perceptron (MLP)** is a sequence of fully connected layers where:
- Each neuron in one layer is connected to every neuron in the next.
- Non-linear activation functions (like **ReLU**) allow the model to learn complex patterns.
- The final layer produces class scores (logits).

For MNIST (handwritten digit classification):
- **Input layer**: Flattened 28×28 = 784 pixels.
- **Hidden layers**: Two fully connected layers (128 and 64 units).
- **Output layer**: 10 units (digits 0–9).

##### Why MNIST?

The **MNIST dataset** is a classic benchmark of handwritten digits (0–9), with:
- 60,000 training images  
- 10,000 test images  
- Each image is 28×28 grayscale  

It is small, easy to train on a CPU/GPU, and perfect for learning the basics.

##### Training Objective

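We want the network to **minimize the Cross-Entropy Loss**, which measures the difference between the predicted class probabilities and the true labels. PyTorch's `nn.CrossEntropyLoss` applies the softmax internally, so the model itself outputs raw logits rather than probabilities.  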
We use the **Adam optimizer** for efficient gradient-based learning.

##### Steps in This Notebook

1. **Load MNIST dataset** with PyTorch’s `torchvision`.
2. **Define the MLP model** with two hidden layers.
3. **Train the model** using backpropagation.
4. **Evaluate on test data**.
5. **Visualize predictions** for random samples.

After completing this, you’ll have a working understanding of how a basic neural network works and be ready to move on to deeper architectures like CNNs.

# Code

### Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

### Load Data

In [2]:
def get_mnist_dataloaders(batch_size=64, root="./data", shuffle_test=False):
    """Build the MNIST train and test ``DataLoader`` objects.

    Both splits are requested with ``download=True`` under ``root``, converted
    to tensors and normalized with mean 0.1307 and std 0.3081.

    Args:
        batch_size: Number of samples per batch, used for both loaders.
        root: Directory handed to ``datasets.MNIST`` for storing the data.
        shuffle_test: Whether the test loader shuffles its samples. The
            training loader always shuffles.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    def build_loader(is_train, shuffle):
        # One split of MNIST wrapped in a loader with the shared batch size.
        split = datasets.MNIST(root, train=is_train, download=True, transform=preprocessing)
        return DataLoader(split, batch_size=batch_size, shuffle=shuffle)

    return build_loader(True, True), build_loader(False, shuffle_test)

In [3]:
# Build the loaders with default settings and report how many batches each holds.
train_data, test_data = get_mnist_dataloaders()
print(f"Number of training batches: {len(train_data)}")
print(f"Number of test batches: {len(test_data)}\n{'-'*40}")

# Pull a single training batch to inspect the tensor shapes.
data, target = next(iter(train_data))
print("Data shape:", data.shape)
print("Target shape:", target.shape)

# One row of five panels, each titled with the digit's label.
fig = make_subplots(
    rows=1,
    cols=5,
    subplot_titles=[f"Label: {lbl.item()}" for lbl in target[:5]],
)

for i, img in enumerate(data[:5]):
    # squeeze() drops the channel axis so the heatmap receives a 2-D array.
    fig.add_trace(
        go.Heatmap(z=img.squeeze().numpy(), colorscale="gray", showscale=False),
        row=1,
        col=i + 1,
    )

# Heatmaps put row 0 at the bottom by default; reverse y so digits are upright.
fig.update_yaxes(autorange="reversed", scaleanchor=None)

fig.update_layout(
    width=1000,
    height=250,
    margin={"l": 10, "r": 10, "b": 10, "t": 60},  # extra top margin for the title
    title={
        "text": "First 5 MNIST Images",
        "y": 0.95,  # near the top (0–1 range)
        "x": 0.5,   # horizontally centered
        "xanchor": 'center',
        "yanchor": 'top',
    },
)

fig.show()

# Drop the exploration variables so they do not leak into later cells.
del train_data, test_data, data, target, img, fig

Number of training batches: 938
Number of test batches: 157
----------------------------------------
Data shape: torch.Size([64, 1, 28, 28])
Target shape: torch.Size([64])


### Simple Neural Network

In [4]:
class SimpleMLP(nn.Module):
    """Feedforward network with two hidden layers and ReLU activations.

    Layout: ``input_size -> hidden_sizes[0] -> hidden_sizes[1] -> num_classes``.
    The output is raw class scores (logits); no softmax is applied.

    Args:
        input_size: Number of features per sample after flattening.
        hidden_sizes: Widths of the hidden layers. Only the first two entries
            are used.
        num_classes: Number of output units.
    """

    # The default is a tuple rather than a list so that it cannot be mutated
    # and shared between instances.
    def __init__(self, input_size=28*28, hidden_sizes=(128, 64), num_classes=10):
        super().__init__()
        # Attribute names fc1/fc2/fc3 define the state_dict keys; keep them
        # stable so previously saved checkpoints still load.
        self.fc1 = nn.Linear(input_size, hidden_sizes[0])
        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.fc3 = nn.Linear(hidden_sizes[1], num_classes)

    def forward(self, x):
        """Flatten each sample to a vector and return its logits."""
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. a
        # transposed or permuted image batch, copying only when necessary.
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


```mermaid
graph LR
    A["Input Layer (28x28 = 784)"] --> B["FC1: Linear (784 → 128)"]
    B --> C[ReLU]
    C --> D["FC2: Linear(128 → 64)"]
    D --> E[ReLU]
    E --> F["FC3: Linear(64 → 10)"]
    F --> G["Output Layer (10 classes)"]
```

In [5]:
def accuracy(outputs, labels):
    """Return the fraction of samples whose top-scoring class equals the label.

    The predicted class is the index of the maximum along dimension 1 of
    ``outputs``; the result is a zero-dimensional float tensor.
    """
    predicted = outputs.max(dim=1).indices
    hits = predicted.eq(labels)
    return hits.float().mean()


In [6]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return sample-weighted (loss, accuracy).

    With an ``optimizer`` the model is put in training mode and updated after
    every batch; without one it is put in eval mode and gradients are disabled.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    total_correct = 0.0
    total_samples = 0
    with torch.set_grad_enabled(training):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, target)
            if training:
                loss.backward()
                optimizer.step()

            # criterion and accuracy() both return per-batch means. Weight
            # them by the batch size so a smaller final batch does not count
            # as much as a full one in the epoch average.
            n_samples = target.size(0)
            total_loss += loss.item() * n_samples
            total_correct += accuracy(outputs, target).item() * n_samples
            total_samples += n_samples

    return total_loss / total_samples, total_correct / total_samples


def train(
    epochs=5,
    batch_size=64,
    learning_rate=1e-3,
    device=None
):
    """Train a ``SimpleMLP`` on MNIST and return the trained model.

    After every epoch the training metrics and the test-set metrics are
    printed. Both are averaged over samples, not over batches.

    Args:
        epochs: Number of passes over the training set.
        batch_size: Batch size used for both the train and test loaders.
        learning_rate: Learning rate passed to the Adam optimizer.
        device: Device to train on. When ``None``, CUDA is used if available,
            otherwise the CPU.

    Returns:
        The trained model, left on ``device``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    train_loader, test_loader = get_mnist_dataloaders(batch_size=batch_size)
    model = SimpleMLP().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(1, epochs + 1):
        avg_loss, avg_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        print(f"Epoch {epoch} — loss: {avg_loss:.4f}, acc: {avg_acc:.4f}")

        # evaluate on test set
        test_loss, test_acc = _run_epoch(model, test_loader, criterion, device)
        print(f" → Test — loss: {test_loss:.4f}, acc: {test_acc:.4f}")

    return model


In [7]:
# Train for 3 epochs with batches of 128 samples; train() prints the
# per-epoch train and test metrics as it goes.
model = train(epochs=3, batch_size=128, learning_rate=1e-3)

# Save only the parameters (state_dict), not the whole module object.
torch.save(model.state_dict(), "simple_mlp.pth")
print("Model saved to simple_mlp.pth")

Using device: cpu
Epoch 1 — loss: 0.3180, acc: 0.9075
 → Test — loss: 0.1528, acc: 0.9539
Epoch 2 — loss: 0.1333, acc: 0.9597
 → Test — loss: 0.1285, acc: 0.9612
Epoch 3 — loss: 0.0947, acc: 0.9715
 → Test — loss: 0.0991, acc: 0.9714
Model saved to simple_mlp.pth


In [8]:
# Load trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleMLP()
# The checkpoint holds only a state_dict, so weights_only=True is sufficient
# and stops torch.load from unpickling arbitrary objects in the file.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(
    torch.load("simple_mlp.pth", map_location=device, weights_only=True)
)
model.to(device)
# Switch to inference mode; as the last expression this also displays the
# model's layer summary in the notebook.
model.eval()

SimpleMLP(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
)

In [17]:
# Load test data (shuffled, so every run shows a different random batch)
_, test_loader = get_mnist_dataloaders(batch_size=5, shuffle_test=True)

# Draw a single batch and use it for both the shape report and the
# predictions, so the printed shapes describe the images actually plotted.
data, target = next(iter(test_loader))
print("Data shape:", data.shape)
print("Target shape:", target.shape)
data, target = data.to(device), target.to(device)

# Run prediction
with torch.no_grad():
    outputs = model(data)
    _, preds = torch.max(outputs, 1)

# Plot 5 samples with predicted + actual labels
fig = make_subplots(
    rows=1, cols=5,
    subplot_titles=[f"True: {lbl.item()}<br>Pred: {pred.item()}"
                    for lbl, pred in zip(target[:5], preds[:5])]
)

for i, img in enumerate(data[:5]):
    # Move back to the CPU before converting: numpy() needs a CPU tensor.
    img = img.squeeze().cpu().numpy()
    heatmap = go.Heatmap(
        z=img,
        colorscale="gray",
        showscale=False
    )
    fig.add_trace(heatmap, row=1, col=i+1)

# Fix orientation (avoid flipped digits)
fig.update_yaxes(autorange="reversed", scaleanchor=None)

fig.update_layout(
    width=1000, height=250,
    margin=dict(l=10, r=10, b=10, t=90),
    title=dict(
        # The loader is shuffled, so these are random samples, not the first 5.
        text="MNIST Predictions (5 Random Test Images)",
        y=0.95,
        x=0.5,
        xanchor='center',
        yanchor='top'
    )
)

fig.show()

Data shape: torch.Size([5, 1, 28, 28])
Target shape: torch.Size([5])
