In [69]:
# -----------------------------
# MNIST DIGIT CLASSIFIER (PyTorch)
# -----------------------------

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import gradio as gr
from PIL import Image, ImageOps
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [51]:
# -----------------------------
# 1. LOAD DATA
# Transforms are preprocessing steps that get applied automatically to every image
# you load from a dataset. 
# Think of transforms as a recipe that says:

# “Every time you give me an image, do X, then Y, then Z to it.”
# “For every MNIST image: convert it to a PyTorch tensor.
# MNIST images come in as PIL images (Python Imaging Library).

# But your neural network expects tensors.
# -----------------------------
transform_train = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

In [52]:
# Load train dataset
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    transform=transform_train,
    download=True
)

In [53]:
# Load test dataset
test_dataset = datasets.MNIST(
    root="./data",
    train=False,
    transform=transform_test,
    download=True
)

In [54]:
# Make DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=64, shuffle=False)

unique_labels = sorted(set(train_dataset.targets.tolist()))
print('Unique labels in training dataset:', unique_labels)

Unique labels in training dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [60]:
# -----------------------------
# 2. DEFINE NEURAL NETWORK
# -----------------------------
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(784, 128)
        self.norm = nn.LayerNorm(128)
        self.fc2 = nn.Linear(128, 10)
        nn.init.kaiming_normal_(self.fc1.weight, nonlinearity='relu')
        nn.init.zeros_(self.fc1.bias)
        nn.init.kaiming_normal_(self.fc2.weight, nonlinearity='linear')
        nn.init.zeros_(self.fc2.bias)

    def forward(self, x):
        x = self.flatten(x)
        x = torch.relu(self.norm(self.fc1(x)))
        x = self.fc2(x)
        return x


In [71]:
# -----------------------------
# 2. DEFINE CNN MODEL
# -----------------------------
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),)
        self.fc = nn.Sequential(nn.Linear(64 * 14 * 14, 128), nn.ReLU(), nn.Linear(128, 10))
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


In [73]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu');
#model = SimpleNN().to(device);
model = SimpleCNN().to(device);
print(f'Using device: {device}');

Using device: cpu


In [74]:
# -----------------------------
# 3. LOSS FUNCTION + OPTIMIZER
# -----------------------------
criterion = nn.CrossEntropyLoss();
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)

In [77]:
# -----------------------------
# 4. TRAINING LOOP
# -----------------------------

epochs = 50
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

for epoch in range(epochs):
    model.train()
    total_loss = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item() * images.size(0)
        pbar.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{epochs}, Avg Loss: {avg_loss:.4f}")


  scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
Epoch 1/50: 100%|██████████| 938/938 [00:41<00:00, 22.69it/s, loss=0.0582]  
Epoch 1/50: 100%|██████████| 938/938 [00:41<00:00, 22.69it/s, loss=0.0582]


Epoch 1/50, Avg Loss: 0.0784


Epoch 2/50: 100%|██████████| 938/938 [02:04<00:00,  7.53it/s, loss=0.405]  
Epoch 2/50: 100%|██████████| 938/938 [02:04<00:00,  7.53it/s, loss=0.405] 


Epoch 2/50, Avg Loss: 0.0483


Epoch 3/50: 100%|██████████| 938/938 [00:41<00:00, 22.77it/s, loss=0.0484]  
Epoch 3/50: 100%|██████████| 938/938 [00:41<00:00, 22.77it/s, loss=0.0484]


Epoch 3/50, Avg Loss: 0.0383


Epoch 4/50: 100%|██████████| 938/938 [01:07<00:00, 13.82it/s, loss=0.00118] 
Epoch 4/50: 100%|██████████| 938/938 [01:07<00:00, 13.82it/s, loss=0.00118]


Epoch 4/50, Avg Loss: 0.0319


Epoch 5/50: 100%|██████████| 938/938 [00:41<00:00, 22.87it/s, loss=0.0101]  
Epoch 5/50: 100%|██████████| 938/938 [00:41<00:00, 22.87it/s, loss=0.0101]


Epoch 5/50, Avg Loss: 0.0267


Epoch 6/50: 100%|██████████| 938/938 [02:24<00:00,  6.50it/s, loss=0.000541] 
Epoch 6/50: 100%|██████████| 938/938 [02:24<00:00,  6.50it/s, loss=0.000541]


Epoch 6/50, Avg Loss: 0.0223


Epoch 7/50: 100%|██████████| 938/938 [00:43<00:00, 21.74it/s, loss=0.000143]
Epoch 7/50: 100%|██████████| 938/938 [00:43<00:00, 21.74it/s, loss=0.000143]


Epoch 7/50, Avg Loss: 0.0193


Epoch 8/50: 100%|██████████| 938/938 [00:43<00:00, 21.63it/s, loss=0.135]   
Epoch 8/50: 100%|██████████| 938/938 [00:43<00:00, 21.63it/s, loss=0.135] 


Epoch 8/50, Avg Loss: 0.0188


Epoch 9/50: 100%|██████████| 938/938 [00:43<00:00, 21.40it/s, loss=0.00029] 
Epoch 9/50: 100%|██████████| 938/938 [00:43<00:00, 21.40it/s, loss=0.00029]


Epoch 9/50, Avg Loss: 0.0153


Epoch 10/50: 100%|██████████| 938/938 [00:43<00:00, 21.33it/s, loss=0.00271] 
Epoch 10/50: 100%|██████████| 938/938 [00:43<00:00, 21.33it/s, loss=0.00271]


Epoch 10/50, Avg Loss: 0.0134


Epoch 11/50: 100%|██████████| 938/938 [00:44<00:00, 21.04it/s, loss=0.00236] 
Epoch 11/50: 100%|██████████| 938/938 [00:44<00:00, 21.04it/s, loss=0.00236]


Epoch 11/50, Avg Loss: 0.0129


Epoch 12/50: 100%|██████████| 938/938 [00:40<00:00, 22.91it/s, loss=0.0042]  
Epoch 12/50: 100%|██████████| 938/938 [00:40<00:00, 22.91it/s, loss=0.0042] 


Epoch 12/50, Avg Loss: 0.0129


Epoch 13/50: 100%|██████████| 938/938 [00:40<00:00, 22.99it/s, loss=0.000106]
Epoch 13/50: 100%|██████████| 938/938 [00:40<00:00, 22.99it/s, loss=0.000106]


Epoch 13/50, Avg Loss: 0.0123


Epoch 14/50: 100%|██████████| 938/938 [01:46<00:00,  8.78it/s, loss=0.0216]   
Epoch 14/50: 100%|██████████| 938/938 [01:46<00:00,  8.78it/s, loss=0.0216] 


Epoch 14/50, Avg Loss: 0.0098


Epoch 15/50: 100%|██████████| 938/938 [01:30<00:00, 10.42it/s, loss=9.35e-5]  
Epoch 15/50: 100%|██████████| 938/938 [01:30<00:00, 10.42it/s, loss=9.35e-5]


Epoch 15/50, Avg Loss: 0.0097


Epoch 16/50: 100%|██████████| 938/938 [00:43<00:00, 21.62it/s, loss=0.0305]  
Epoch 16/50: 100%|██████████| 938/938 [00:43<00:00, 21.62it/s, loss=0.0305]  


Epoch 16/50, Avg Loss: 0.0097


Epoch 17/50: 100%|██████████| 938/938 [00:43<00:00, 21.67it/s, loss=0.004]   
Epoch 17/50: 100%|██████████| 938/938 [00:43<00:00, 21.67it/s, loss=0.004]  


Epoch 17/50, Avg Loss: 0.0074


Epoch 18/50: 100%|██████████| 938/938 [00:44<00:00, 21.30it/s, loss=6.14e-6] 
Epoch 18/50: 100%|██████████| 938/938 [00:44<00:00, 21.30it/s, loss=6.14e-6]


Epoch 18/50, Avg Loss: 0.0078


Epoch 19/50: 100%|██████████| 938/938 [00:46<00:00, 20.21it/s, loss=0.000621]
Epoch 19/50: 100%|██████████| 938/938 [00:46<00:00, 20.21it/s, loss=0.000621]


Epoch 19/50, Avg Loss: 0.0086


Epoch 20/50: 100%|██████████| 938/938 [00:44<00:00, 21.06it/s, loss=1.27e-5] 
Epoch 20/50: 100%|██████████| 938/938 [00:44<00:00, 21.06it/s, loss=1.27e-5]


Epoch 20/50, Avg Loss: 0.0071


Epoch 21/50: 100%|██████████| 938/938 [00:43<00:00, 21.47it/s, loss=3.93e-5] 
Epoch 21/50: 100%|██████████| 938/938 [00:43<00:00, 21.47it/s, loss=3.93e-5] 


Epoch 21/50, Avg Loss: 0.0073


Epoch 22/50: 100%|██████████| 938/938 [00:45<00:00, 20.52it/s, loss=4.17e-6] 
Epoch 22/50: 100%|██████████| 938/938 [00:45<00:00, 20.52it/s, loss=4.17e-6] 


Epoch 22/50, Avg Loss: 0.0062


Epoch 23/50: 100%|██████████| 938/938 [00:44<00:00, 21.03it/s, loss=1.02e-6] 
Epoch 23/50: 100%|██████████| 938/938 [00:44<00:00, 21.03it/s, loss=1.02e-6]


Epoch 23/50, Avg Loss: 0.0080


Epoch 24/50: 100%|██████████| 938/938 [01:14<00:00, 12.58it/s, loss=1.05e-5] 
Epoch 24/50: 100%|██████████| 938/938 [01:14<00:00, 12.58it/s, loss=1.05e-5] 


Epoch 24/50, Avg Loss: 0.0070


Epoch 25/50: 100%|██████████| 938/938 [00:42<00:00, 21.87it/s, loss=1.42e-5] 
Epoch 25/50: 100%|██████████| 938/938 [00:42<00:00, 21.87it/s, loss=1.42e-5] 


Epoch 25/50, Avg Loss: 0.0063


Epoch 26/50: 100%|██████████| 938/938 [00:43<00:00, 21.49it/s, loss=0.225]   
Epoch 26/50: 100%|██████████| 938/938 [00:43<00:00, 21.49it/s, loss=0.225]  


Epoch 26/50, Avg Loss: 0.0059


Epoch 27/50: 100%|██████████| 938/938 [00:45<00:00, 20.64it/s, loss=2.05e-6] 
Epoch 27/50: 100%|██████████| 938/938 [00:45<00:00, 20.64it/s, loss=2.05e-6]


Epoch 27/50, Avg Loss: 0.0058


Epoch 28/50: 100%|██████████| 938/938 [00:48<00:00, 19.34it/s, loss=3.78e-6] 
Epoch 28/50: 100%|██████████| 938/938 [00:48<00:00, 19.34it/s, loss=3.78e-6]


Epoch 28/50, Avg Loss: 0.0051


Epoch 29/50: 100%|██████████| 938/938 [00:43<00:00, 21.58it/s, loss=0.0308]  
Epoch 29/50: 100%|██████████| 938/938 [00:43<00:00, 21.58it/s, loss=0.0308]  


Epoch 29/50, Avg Loss: 0.0062


Epoch 30/50: 100%|██████████| 938/938 [00:43<00:00, 21.76it/s, loss=8.97e-5] 
Epoch 30/50: 100%|██████████| 938/938 [00:43<00:00, 21.76it/s, loss=8.97e-5] 


Epoch 30/50, Avg Loss: 0.0030


Epoch 31/50: 100%|██████████| 938/938 [00:43<00:00, 21.66it/s, loss=0.00172] 
Epoch 31/50: 100%|██████████| 938/938 [00:43<00:00, 21.66it/s, loss=0.00172]


Epoch 31/50, Avg Loss: 0.0055


Epoch 32/50: 100%|██████████| 938/938 [00:43<00:00, 21.75it/s, loss=1.92e-6] 
Epoch 32/50: 100%|██████████| 938/938 [00:43<00:00, 21.75it/s, loss=1.92e-6]


Epoch 32/50, Avg Loss: 0.0063


Epoch 33/50: 100%|██████████| 938/938 [00:42<00:00, 22.23it/s, loss=2.82e-5] 
Epoch 33/50: 100%|██████████| 938/938 [00:42<00:00, 22.23it/s, loss=2.82e-5]


Epoch 33/50, Avg Loss: 0.0041


Epoch 34/50: 100%|██████████| 938/938 [00:43<00:00, 21.78it/s, loss=9.26e-6] 
Epoch 34/50: 100%|██████████| 938/938 [00:43<00:00, 21.78it/s, loss=9.26e-6]


Epoch 34/50, Avg Loss: 0.0056


Epoch 35/50: 100%|██████████| 938/938 [00:44<00:00, 21.05it/s, loss=3.73e-9] 
Epoch 35/50: 100%|██████████| 938/938 [00:44<00:00, 21.05it/s, loss=3.73e-9] 


Epoch 35/50, Avg Loss: 0.0060


Epoch 36/50: 100%|██████████| 938/938 [00:44<00:00, 20.87it/s, loss=5.63e-7] 
Epoch 36/50: 100%|██████████| 938/938 [00:44<00:00, 20.87it/s, loss=5.63e-7]


Epoch 36/50, Avg Loss: 0.0056


Epoch 37/50: 100%|██████████| 938/938 [00:44<00:00, 21.09it/s, loss=3.73e-9] 
Epoch 37/50: 100%|██████████| 938/938 [00:44<00:00, 21.09it/s, loss=3.73e-9]


Epoch 37/50, Avg Loss: 0.0046


Epoch 38/50: 100%|██████████| 938/938 [00:45<00:00, 20.56it/s, loss=1.98e-6] 
Epoch 38/50: 100%|██████████| 938/938 [00:45<00:00, 20.56it/s, loss=1.98e-6]


Epoch 38/50, Avg Loss: 0.0052


Epoch 39/50: 100%|██████████| 938/938 [00:46<00:00, 20.31it/s, loss=3.15e-5] 
Epoch 39/50: 100%|██████████| 938/938 [00:46<00:00, 20.31it/s, loss=3.15e-5] 


Epoch 39/50, Avg Loss: 0.0026


Epoch 40/50: 100%|██████████| 938/938 [00:44<00:00, 21.11it/s, loss=1.49e-8] 
Epoch 40/50: 100%|██████████| 938/938 [00:44<00:00, 21.11it/s, loss=1.49e-8]


Epoch 40/50, Avg Loss: 0.0058


Epoch 41/50: 100%|██████████| 938/938 [00:47<00:00, 19.95it/s, loss=0.0333]  
Epoch 41/50: 100%|██████████| 938/938 [00:47<00:00, 19.95it/s, loss=0.0333] 


Epoch 41/50, Avg Loss: 0.0043


Epoch 42/50: 100%|██████████| 938/938 [00:42<00:00, 22.21it/s, loss=1.27e-5] 
Epoch 42/50: 100%|██████████| 938/938 [00:42<00:00, 22.21it/s, loss=1.27e-5]


Epoch 42/50, Avg Loss: 0.0041


Epoch 43/50: 100%|██████████| 938/938 [00:41<00:00, 22.53it/s, loss=6.67e-7] 
Epoch 43/50: 100%|██████████| 938/938 [00:41<00:00, 22.53it/s, loss=6.67e-7]


Epoch 43/50, Avg Loss: 0.0055


Epoch 44/50: 100%|██████████| 938/938 [00:43<00:00, 21.70it/s, loss=5.38e-6] 
Epoch 44/50: 100%|██████████| 938/938 [00:43<00:00, 21.70it/s, loss=5.38e-6]


Epoch 44/50, Avg Loss: 0.0034


Epoch 45/50: 100%|██████████| 938/938 [00:43<00:00, 21.69it/s, loss=6e-7]    
Epoch 45/50: 100%|██████████| 938/938 [00:43<00:00, 21.69it/s, loss=6e-7]  


Epoch 45/50, Avg Loss: 0.0033


Epoch 46/50: 100%|██████████| 938/938 [00:41<00:00, 22.52it/s, loss=0.000428]
Epoch 46/50: 100%|██████████| 938/938 [00:41<00:00, 22.52it/s, loss=0.000428]


Epoch 46/50, Avg Loss: 0.0039


Epoch 47/50: 100%|██████████| 938/938 [00:43<00:00, 21.69it/s, loss=1.49e-7] 
Epoch 47/50: 100%|██████████| 938/938 [00:43<00:00, 21.69it/s, loss=1.49e-7] 


Epoch 47/50, Avg Loss: 0.0041


Epoch 48/50: 100%|██████████| 938/938 [00:42<00:00, 22.01it/s, loss=3.73e-9] 
Epoch 48/50: 100%|██████████| 938/938 [00:42<00:00, 22.01it/s, loss=3.73e-9]


Epoch 48/50, Avg Loss: 0.0051


Epoch 49/50: 100%|██████████| 938/938 [00:43<00:00, 21.54it/s, loss=0]       
Epoch 49/50: 100%|██████████| 938/938 [00:43<00:00, 21.54it/s, loss=0]       


Epoch 49/50, Avg Loss: 0.0028


Epoch 50/50: 100%|██████████| 938/938 [00:45<00:00, 20.62it/s, loss=5.59e-8] 

Epoch 50/50, Avg Loss: 0.0039





In [79]:
# -----------------------------
# 5. EVALUATION
# -----------------------------
correct = 0
total = 0
model.eval()

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total+=labels.size(0)
        correct+=(predicted==labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")


Test Accuracy: 99.18%


In [80]:
# -----------------------------
# 6. TEST SINGLE PREDICTION
# -----------------------------
# -----------------------------
# 6. TEST SINGLE PREDICTION
# -----------------------------
# ------------------------------
# Gradio Sketchpad gives you:

# * a full-color NumPy array

# * black digit on white background

# * large resolution

# * no consistent scale
#
# Hence the preprocessing
# ------------------------------

def preprocess_image(image):
    sketch_transform = transforms.Compose([
    transforms.ToPILImage(),                      # NumPy → PIL
    transforms.Grayscale(),                       # ensure 1 channel
    transforms.Resize((28, 28)),                  # 28x28 like MNIST
    transforms.Lambda(lambda img: ImageOps.invert(img)),  # invert colors
    transforms.ToTensor(),                        # → tensor, shape (1,28,28), values in [0,1]
    ])
    # Gradio Sketchpad sometimes passes a dict with 'composite'
    if isinstance(image, dict):
        image = image['composite']   # this is a NumPy array
    
    # Apply the preprocessing transform
    img_tensor = sketch_transform(image)  # (1, 28, 28)
    
    # Add batch dimension → (1, 1, 28, 28)
    img_tensor = img_tensor.unsqueeze(0)

    return img_tensor

def predict_digit(image):
    # --- STEP 1: CHECK IF SOMETHING HAS BEEN DRAWN ---
    if image is None: return "Draw something!"

    # --- STEP 2: PREPROCESS THE IMAGE ---
    img_tensor = preprocess_image(image)
    
    # --- STEP 3: RUN THE MODEL ---
    with torch.no_grad():
        prediction = model(img_tensor)
        
        # Get the index of the highest score (the predicted digit)
        predicted_digit = torch.argmax(prediction).item()
        
    return str(predicted_digit)

# UI Setup
interface = gr.Interface(fn=predict_digit, inputs=gr.Sketchpad(label="Draw Here"), outputs="label")
interface.queue().launch()

* Running on local URL:  http://127.0.0.1:7862
* To create a public link, set `share=True` in `launch()`.


