In [11]:
# -----------------------------
# MNIST DIGIT CLASSIFIER (PyTorch)
# -----------------------------

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import gradio as gr
from PIL import Image, ImageOps
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu');
print(f'Using device: {device}');

Using device: cpu


In [13]:
# -----------------------------
# 1. LOAD DATA
# Transforms are preprocessing steps that get applied automatically to every image
# you load from a dataset. 
# Think of transforms as a recipe that says:

# “Every time you give me an image, do X, then Y, then Z to it.”
# “For every MNIST image: convert it to a PyTorch tensor.
# MNIST images come in as PIL images (Python Imaging Library).

# But your neural network expects tensors.
# -----------------------------
transform_train = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

In [14]:
# Load train dataset
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    transform=transform_train,
    download=True
)

In [15]:
# Load test dataset
test_dataset = datasets.MNIST(
    root="./data",
    train=False,
    transform=transform_test,
    download=True
)

In [7]:
# Make DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=64, shuffle=False)

unique_labels = sorted(set(train_dataset.targets.tolist()))
print('Unique labels in training dataset:', unique_labels)

Unique labels in training dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [16]:
# -----------------------------
# 2. DEFINE NEURAL NETWORK
# -----------------------------
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(784, 128)
        self.norm = nn.LayerNorm(128)
        self.fc2 = nn.Linear(128, 10)
        nn.init.kaiming_normal_(self.fc1.weight, nonlinearity='relu')
        nn.init.zeros_(self.fc1.bias)
        nn.init.kaiming_normal_(self.fc2.weight, nonlinearity='linear')
        nn.init.zeros_(self.fc2.bias)

    def forward(self, x):
        x = self.flatten(x)
        x = torch.relu(self.norm(self.fc1(x)))
        x = self.fc2(x)
        return x
    
model = SimpleNN().to(device)
print(model)

SimpleNN(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


In [9]:
# -----------------------------
# 2. DEFINE CNN MODEL (better/alternate method that uses a convolutional neural network)
# The reason that we would want to use a CNN over a NN is that CNNs are just better at image related tasks, so we should expect better performance from this model.
# -----------------------------
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),)
        self.fc = nn.Sequential(nn.Linear(64 * 14 * 14, 128), nn.ReLU(), nn.Linear(128, 10))
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
    
model = SimpleCNN().to(device)
print(model)

SimpleCNN(
  (conv_layers): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=12544, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=10, bias=True)
  )
)


In [17]:
# -----------------------------
# 3. LOSS FUNCTION + OPTIMIZER
# -----------------------------
criterion = nn.CrossEntropyLoss();
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)

In [18]:
# -----------------------------
# 4. TRAINING LOOP
# -----------------------------

epochs = 35
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

for epoch in range(epochs):
    model.train()
    total_loss = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item() * images.size(0)
        pbar.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{epochs}, Avg Loss: {avg_loss:.4f}")


  scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
Epoch 1/35:   0%|          | 0/938 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
Epoch 1/35: 100%|██████████| 938/938 [00:07<00:00, 123.31it/s, loss=0.105] 


Epoch 1/35, Avg Loss: 0.2952


Epoch 2/35: 100%|██████████| 938/938 [00:09<00:00, 103.39it/s, loss=0.2]   


Epoch 2/35, Avg Loss: 0.1419


Epoch 3/35: 100%|██████████| 938/938 [00:09<00:00, 103.91it/s, loss=0.0326]


Epoch 3/35, Avg Loss: 0.1080


Epoch 4/35: 100%|██████████| 938/938 [00:06<00:00, 141.94it/s, loss=0.0222]


Epoch 4/35, Avg Loss: 0.0906


Epoch 5/35: 100%|██████████| 938/938 [00:09<00:00, 96.38it/s, loss=0.0653]  


Epoch 5/35, Avg Loss: 0.0798


Epoch 6/35: 100%|██████████| 938/938 [00:12<00:00, 73.85it/s, loss=0.0273] 


Epoch 6/35, Avg Loss: 0.0716


Epoch 7/35: 100%|██████████| 938/938 [00:14<00:00, 66.26it/s, loss=0.172]  


Epoch 7/35, Avg Loss: 0.0641


Epoch 8/35: 100%|██████████| 938/938 [00:13<00:00, 68.57it/s, loss=0.196]  


Epoch 8/35, Avg Loss: 0.0595


Epoch 9/35: 100%|██████████| 938/938 [00:11<00:00, 84.47it/s, loss=0.0721]  


Epoch 9/35, Avg Loss: 0.0551


Epoch 10/35: 100%|██████████| 938/938 [00:07<00:00, 122.18it/s, loss=0.0373] 


Epoch 10/35, Avg Loss: 0.0522


Epoch 11/35: 100%|██████████| 938/938 [00:06<00:00, 150.70it/s, loss=0.0553] 


Epoch 11/35, Avg Loss: 0.0481


Epoch 12/35: 100%|██████████| 938/938 [00:05<00:00, 156.36it/s, loss=0.0736] 


Epoch 12/35, Avg Loss: 0.0457


Epoch 13/35: 100%|██████████| 938/938 [00:05<00:00, 160.26it/s, loss=0.0328] 


Epoch 13/35, Avg Loss: 0.0428


Epoch 14/35: 100%|██████████| 938/938 [00:05<00:00, 159.58it/s, loss=0.0163] 


Epoch 14/35, Avg Loss: 0.0408


Epoch 15/35: 100%|██████████| 938/938 [00:05<00:00, 159.90it/s, loss=0.0244] 


Epoch 15/35, Avg Loss: 0.0376


Epoch 16/35: 100%|██████████| 938/938 [00:05<00:00, 156.69it/s, loss=0.0624] 


Epoch 16/35, Avg Loss: 0.0352


Epoch 17/35: 100%|██████████| 938/938 [00:05<00:00, 160.38it/s, loss=0.0436] 


Epoch 17/35, Avg Loss: 0.0359


Epoch 18/35: 100%|██████████| 938/938 [00:05<00:00, 159.39it/s, loss=0.0708] 


Epoch 18/35, Avg Loss: 0.0345


Epoch 19/35: 100%|██████████| 938/938 [00:06<00:00, 152.76it/s, loss=0.018]   


Epoch 19/35, Avg Loss: 0.0320


Epoch 20/35: 100%|██████████| 938/938 [00:10<00:00, 87.23it/s, loss=0.00163] 


Epoch 20/35, Avg Loss: 0.0309


Epoch 21/35: 100%|██████████| 938/938 [00:06<00:00, 135.63it/s, loss=0.0351]  


Epoch 21/35, Avg Loss: 0.0297


Epoch 22/35: 100%|██████████| 938/938 [00:06<00:00, 139.99it/s, loss=0.0463]  


Epoch 22/35, Avg Loss: 0.0279


Epoch 23/35: 100%|██████████| 938/938 [00:06<00:00, 139.28it/s, loss=0.158]  


Epoch 23/35, Avg Loss: 0.0292


Epoch 24/35: 100%|██████████| 938/938 [00:06<00:00, 137.94it/s, loss=0.00252] 


Epoch 24/35, Avg Loss: 0.0265


Epoch 25/35: 100%|██████████| 938/938 [00:06<00:00, 138.52it/s, loss=0.00423] 


Epoch 25/35, Avg Loss: 0.0282


Epoch 26/35: 100%|██████████| 938/938 [00:06<00:00, 137.19it/s, loss=0.0626]  


Epoch 26/35, Avg Loss: 0.0246


Epoch 27/35: 100%|██████████| 938/938 [00:06<00:00, 144.32it/s, loss=0.00217] 


Epoch 27/35, Avg Loss: 0.0236


Epoch 28/35: 100%|██████████| 938/938 [00:06<00:00, 145.20it/s, loss=0.0019]  


Epoch 28/35, Avg Loss: 0.0216


Epoch 29/35: 100%|██████████| 938/938 [00:09<00:00, 101.76it/s, loss=0.00344] 


Epoch 29/35, Avg Loss: 0.0228


Epoch 30/35: 100%|██████████| 938/938 [00:08<00:00, 108.89it/s, loss=0.0171]  


Epoch 30/35, Avg Loss: 0.0209


Epoch 31/35: 100%|██████████| 938/938 [00:08<00:00, 114.23it/s, loss=0.00886] 


Epoch 31/35, Avg Loss: 0.0203


Epoch 32/35: 100%|██████████| 938/938 [00:09<00:00, 94.30it/s, loss=0.00898]  


Epoch 32/35, Avg Loss: 0.0206


Epoch 33/35: 100%|██████████| 938/938 [00:08<00:00, 114.27it/s, loss=0.0052]  


Epoch 33/35, Avg Loss: 0.0194


Epoch 34/35: 100%|██████████| 938/938 [00:11<00:00, 84.31it/s, loss=0.00403]  


Epoch 34/35, Avg Loss: 0.0186


Epoch 35/35: 100%|██████████| 938/938 [00:16<00:00, 55.73it/s, loss=0.0427]  

Epoch 35/35, Avg Loss: 0.0180





In [19]:
# -----------------------------
# 5. EVALUATION
# -----------------------------
correct = 0
total = 0
model.eval()

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total+=labels.size(0)
        correct+=(predicted==labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")


Test Accuracy: 98.20%


In [20]:
# -----------------------------
# 6. TEST SINGLE PREDICTION
# -----------------------------
# Gradio sketchpad returns a full-color NumPy array (H,W,3).
# MNIST images are grayscale (1x28x28) and normalized.
# This preprocessing converts user drawings into MNIST format.
# -----------------------------

# MNIST normalization values:
MNIST_MEAN = (0.1307,)
MNIST_STD  = (0.3081,)

def preprocess_image(image):
    """Convert Gradio Sketchpad output to a normalized 1x28x28 tensor."""

    # Gradio may pass {'composite': array}
    if isinstance(image, dict) and "composite" in image:
        image = image["composite"]

    # Define preprocessing pipeline
    sketch_transform = transforms.Compose([
        transforms.ToPILImage(),                      # NumPy → PIL
        transforms.Grayscale(num_output_channels=1),  # Convert to 1 channel
        transforms.Resize((28, 28)),                  # Match MNIST input
        transforms.Lambda(lambda img: ImageOps.invert(img)),  
        transforms.ToTensor(),                        # → (1, 28, 28), values in [0,1]
        transforms.Normalize(MNIST_MEAN, MNIST_STD),  # Match MNIST training normalization
    ])

    tensor = sketch_transform(image)                  # Shape: (1, 28, 28)
    tensor = tensor.unsqueeze(0)                      # Shape: (1, 1, 28, 28)
    return tensor.to(device)                          # Move to same device as model


def predict_digit(image):
    """Take raw Sketchpad input → return predicted digit + confidence."""
    
    if image is None:
        return "Draw something!"

    # Convert to model input format
    input_tensor = preprocess_image(image)

    # Ensure model is in eval mode
    model.eval()
    
    with torch.no_grad():
        logits = model(input_tensor)
        
        # turn logits into probabilities
        probs = torch.softmax(logits, dim=1)

        # predicted class index
        predicted_class = torch.argmax(probs, dim=1).item()

        # confidence of that class
        confidence = probs[0, predicted_class].item()

    # return nicely formatted output
    return f"{predicted_class}  ({confidence * 100:.2f}% confidence)"



# -----------------------------
# GRADIO UI
# -----------------------------
interface = gr.Interface(
    fn=predict_digit,
    inputs=gr.Sketchpad(label="Draw Here"),
    outputs="label",
    live=False,
)

interface.queue().launch()


* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.


