In [3]:
# -----------------------------
# MNIST DIGIT CLASSIFIER (PyTorch)
# -----------------------------

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import gradio as gr
from PIL import Image, ImageOps
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f'Using device: {device}')

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu


In [4]:
# -----------------------------
# 1. LOAD DATA
# Transforms are preprocessing steps that are applied automatically to every image
# loaded from a dataset.
# Think of a transform pipeline as a recipe that says:

# “Every time you give me an image, do X, then Y, then Z to it.”
# Here: “For every MNIST image, convert it to a PyTorch tensor and normalize it.”
# MNIST images arrive as PIL (Python Imaging Library) images,

# but the neural network expects tensors.
# -----------------------------
def _tensor_and_normalize():
    """Return fresh [ToTensor, Normalize] steps shared by the train and test pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]


# Training pipeline: RandomRotation(10) augmentation first, then the shared steps.
transform_train = transforms.Compose([transforms.RandomRotation(10)] + _tensor_and_normalize())

# Test pipeline: no augmentation, only tensor conversion and normalization.
transform_test = transforms.Compose(_tensor_and_normalize())

In [5]:
# Load train dataset
# Training split stored under ./data (fetched when download=True finds it missing),
# with every sample passed through the augmenting training pipeline.
train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform_train)

In [6]:
# Load test dataset
# Held-out split stored under ./data, passed through the non-augmenting test pipeline.
test_dataset = datasets.MNIST("./data", train=False, download=True, transform=transform_test)

In [7]:
# Make DataLoaders
BATCH_SIZE = 64

# Only the training batches are shuffled; the test order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

# Sanity check: list the distinct label values present in the training targets.
label_values = train_dataset.targets.tolist()
unique_labels = sorted(set(label_values))
print('Unique labels in training dataset:', unique_labels)

Unique labels in training dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [8]:
# -----------------------------
# 2. DEFINE NEURAL NETWORK
# -----------------------------
class SimpleNN(nn.Module):
    """Fully connected baseline: Flatten -> Linear(784, 128) -> LayerNorm -> ReLU -> Linear(128, 10).

    Both linear layers get Kaiming-normal weights and zero biases.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(784, 128)
        self.norm = nn.LayerNorm(128)
        self.fc2 = nn.Linear(128, 10)

        # fc1 is followed by a ReLU while fc2 emits raw logits, hence the
        # different gain modes passed to the initialiser.
        for layer, gain_mode in ((self.fc1, 'relu'), (self.fc2, 'linear')):
            nn.init.kaiming_normal_(layer.weight, nonlinearity=gain_mode)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return the 10 class logits for a batch whose samples flatten to 784 values."""
        hidden = self.fc1(self.flatten(x))
        hidden = torch.relu(self.norm(hidden))
        return self.fc2(hidden)
    
# Build the fully connected baseline, move it to the selected device, and print its layers.
model = SimpleNN()
model = model.to(device)
print(model)

SimpleNN(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


In [9]:
# -----------------------------
# 2b. DEFINE CNN MODEL (alternative to the fully connected network above)
# Convolutional layers exploit the 2-D spatial structure of images (nearby pixels are related), so we expect this model to outperform the plain fully connected network on MNIST.
# -----------------------------
class SimpleCNN(nn.Module):
    """Small CNN: two 3x3 conv+ReLU layers, a 2x2 max-pool, then a two-layer classifier head.

    Every Conv2d/Linear layer gets Kaiming-normal ('relu') weights and zero biases.
    """

    def __init__(self):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # 64 channels x 14 x 14 positions is what the head expects after the single pool.
        self.fc = nn.Sequential(
            nn.Linear(64 * 14 * 14, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
            nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return the 10 class logits for a batch of single-channel images."""
        feature_maps = self.conv_layers(x)
        flat = feature_maps.flatten(1)  # keep the batch dimension, collapse the rest
        return self.fc(flat)
    
# Replace the previous `model` with the CNN, move it to the selected device, and print its layers.
model = SimpleCNN()
model = model.to(device)
print(model)

SimpleCNN(
  (conv_layers): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=12544, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=10, bias=True)
  )
)


In [10]:
# Squeeze-and-Excitation (SE) block: a channel-attention mechanism that improves performance by re-weighting the convolutional feature maps so the most informative channels are emphasised.
# This is especially useful in image classification, where some features matter more than others for telling classes apart (e.g. the digits in MNIST).
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Pools each channel to a single value, passes the result through a
    two-layer 1x1-conv bottleneck ending in a sigmoid, and multiplies the
    input by the resulting per-channel weights.

    Args:
        c: Number of input (and output) channels.
        r: Reduction ratio of the bottleneck. The hidden width is ``c // r``,
            clamped to at least 1 so that ``c < r`` does not create a
            zero-channel convolution.
    """

    def __init__(self, c, r=8):
        super().__init__()
        # Without the clamp, c < r gives c // r == 0 and a degenerate bottleneck.
        hidden = max(1, c // r)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),   # squeeze: one descriptor per channel
            nn.Conv2d(c, hidden, 1),   # reduce channel dimensionality
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, c, 1),   # restore the original channel count
            nn.Sigmoid(),              # per-channel weights between 0 and 1
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned importance weights."""
        w = self.se(x)  # one weight per channel, broadcast over the spatial dims
        return x * w


class MNISTNet(nn.Module):
    """Four-stage CNN with SE attention for 10-class digit classification.

    ``features`` is a flat ``nn.Sequential`` of four stages (3x3 conv,
    batch norm, ReLU, SEBlock) with a 2x2 max-pool plus dropout after the
    second and fourth stage. ``classifier`` flattens the result and maps it
    through a 512-unit hidden layer to 10 logits.
    """

    @staticmethod
    def _conv_stage(in_ch, out_ch):
        """Return one stage as a flat list: [Conv3x3, BatchNorm, ReLU, SEBlock]."""
        return [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            SEBlock(out_ch),
        ]

    def __init__(self):
        super().__init__()
        # The stages are unpacked into ONE flat Sequential so the module
        # indices (0..19) are the same as when every layer is listed inline.
        self.features = nn.Sequential(
            *self._conv_stage(1, 32),      # indices 0-3
            *self._conv_stage(32, 64),     # indices 4-7
            nn.MaxPool2d(2),               # index 8: halve the spatial size
            nn.Dropout(0.2),               # index 9
            *self._conv_stage(64, 128),    # indices 10-13
            *self._conv_stage(128, 128),   # indices 14-17
            nn.MaxPool2d(2),               # index 18: halve the spatial size again
            nn.Dropout(0.3),               # index 19
        )
        # Head: 128 channels x 7 x 7 (a 28x28 input pooled twice) -> 512 -> 10 logits.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10),
        )
        # Re-initialise every conv / linear layer in the network.
        self.apply(self._init)

    def _init(self, m):
        """Kaiming-normal weights and zero bias for Conv2d / Linear modules; others untouched."""
        if not isinstance(m, (nn.Conv2d, nn.Linear)):
            return
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the feature extractor, then the classifier head, and return the logits."""
        feature_maps = self.features(x)
        return self.classifier(feature_maps)
    
# Instantiate the architecture on the same device that later cells send inputs to.
# Leaving the model on the CPU while batches go to `device` fails on CUDA machines.
model = MNISTNet().to(device)

# Load the saved weights straight onto that device. weights_only=True restricts
# unpickling to tensors and plain containers, which is all a state_dict needs.
_checkpoint_state = torch.load("99.68%.pth", map_location=device, weights_only=True)
model.load_state_dict(_checkpoint_state)

model.eval()  # inference mode; as the last expression this also displays the module summary

MNISTNet(
  (features): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): SEBlock(
      (se): Sequential(
        (0): AdaptiveAvgPool2d(output_size=1)
        (1): Conv2d(32, 4, kernel_size=(1, 1), stride=(1, 1))
        (2): ReLU(inplace=True)
        (3): Conv2d(4, 32, kernel_size=(1, 1), stride=(1, 1))
        (4): Sigmoid()
      )
    )
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): SEBlock(
      (se): Sequential(
        (0): AdaptiveAvgPool2d(output_size=1)
        (1): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
        (2): ReLU(inplace=True)
        (3): Conv2d(8, 64, kernel_size=(1, 1), stride=(1, 1))
        (4): Sigmoid()
      )
    )
    (8): MaxPool2d(kernel_size=2, stride=2,

In [10]:
# -----------------------------
# 3. LOSS FUNCTION + OPTIMIZER
# -----------------------------
# Cross-entropy on the raw logits, optimised with AdamW.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-3,
)

In [11]:
# -----------------------------
# 4. TRAINING LOOP
# -----------------------------

epochs = 35

# Mixed precision is only enabled when CUDA is present; on CPU both the scaler
# and the autocast context are disabled and act as pass-throughs.
# torch.amp.* replaces the deprecated torch.cuda.amp.* entry points, which emit
# a FutureWarning on recent PyTorch releases.
use_amp = torch.cuda.is_available()
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

for epoch in range(epochs):
    model.train()
    total_loss = 0.0  # sum of per-sample losses over the epoch

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()

        with torch.amp.autocast("cuda", enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Scale the loss before backward, then let the scaler drive the
        # optimizer step and adjust its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the scalar once per step; each .item() call forces a device sync.
        batch_loss = loss.item()
        total_loss += batch_loss * images.size(0)  # weight by batch size (last batch may be smaller)
        pbar.set_postfix(loss=batch_loss)

    avg_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{epochs}, Avg Loss: {avg_loss:.4f}")


  scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
Epoch 1/35: 100%|██████████| 938/938 [01:24<00:00, 11.09it/s, loss=0.0127] 


Epoch 1/35, Avg Loss: 0.2162


Epoch 2/35: 100%|██████████| 938/938 [01:17<00:00, 12.07it/s, loss=0.0628] 


Epoch 2/35, Avg Loss: 0.0580


Epoch 3/35: 100%|██████████| 938/938 [01:16<00:00, 12.21it/s, loss=0.00112] 


Epoch 3/35, Avg Loss: 0.0412


Epoch 4/35: 100%|██████████| 938/938 [-148:57:01<00:00, -0.00it/s, loss=0.00508] 


Epoch 4/35, Avg Loss: 0.0336


Epoch 5/35: 100%|██████████| 938/938 [01:53<00:00,  8.27it/s, loss=0.0021]  


Epoch 5/35, Avg Loss: 0.0280


Epoch 6/35: 100%|██████████| 938/938 [01:49<00:00,  8.57it/s, loss=0.00147] 


Epoch 6/35, Avg Loss: 0.0233


Epoch 7/35: 100%|██████████| 938/938 [01:55<00:00,  8.10it/s, loss=0.0516]  


Epoch 7/35, Avg Loss: 0.0226


Epoch 8/35: 100%|██████████| 938/938 [01:53<00:00,  8.23it/s, loss=0.0214]  


Epoch 8/35, Avg Loss: 0.0184


Epoch 9/35: 100%|██████████| 938/938 [01:51<00:00,  8.39it/s, loss=0.0578]  


Epoch 9/35, Avg Loss: 0.0176


Epoch 10/35: 100%|██████████| 938/938 [01:57<00:00,  7.98it/s, loss=0.00303] 


Epoch 10/35, Avg Loss: 0.0168


Epoch 11/35: 100%|██████████| 938/938 [01:51<00:00,  8.39it/s, loss=0.00141] 


Epoch 11/35, Avg Loss: 0.0143


Epoch 12/35: 100%|██████████| 938/938 [01:56<00:00,  8.04it/s, loss=0.000108]


Epoch 12/35, Avg Loss: 0.0132


Epoch 13/35: 100%|██████████| 938/938 [01:58<00:00,  7.94it/s, loss=0.000179]


Epoch 13/35, Avg Loss: 0.0117


Epoch 14/35: 100%|██████████| 938/938 [01:59<00:00,  7.88it/s, loss=4.02e-5] 


Epoch 14/35, Avg Loss: 0.0121


Epoch 15/35: 100%|██████████| 938/938 [01:52<00:00,  8.33it/s, loss=1.69e-5] 


Epoch 15/35, Avg Loss: 0.0104


Epoch 16/35: 100%|██████████| 938/938 [01:54<00:00,  8.20it/s, loss=0.00078] 


Epoch 16/35, Avg Loss: 0.0090


Epoch 17/35: 100%|██████████| 938/938 [01:59<00:00,  7.86it/s, loss=0.00832] 


Epoch 17/35, Avg Loss: 0.0099


Epoch 18/35: 100%|██████████| 938/938 [01:43<00:00,  9.07it/s, loss=0.000463]


Epoch 18/35, Avg Loss: 0.0089


Epoch 19/35: 100%|██████████| 938/938 [02:03<00:00,  7.58it/s, loss=8.49e-6] 


Epoch 19/35, Avg Loss: 0.0077


Epoch 20/35: 100%|██████████| 938/938 [01:54<00:00,  8.17it/s, loss=8.53e-5] 


Epoch 20/35, Avg Loss: 0.0078


Epoch 21/35: 100%|██████████| 938/938 [02:10<00:00,  7.16it/s, loss=0.00407] 


Epoch 21/35, Avg Loss: 0.0086


Epoch 22/35: 100%|██████████| 938/938 [01:57<00:00,  7.98it/s, loss=0.000281]


Epoch 22/35, Avg Loss: 0.0070


Epoch 23/35: 100%|██████████| 938/938 [01:42<00:00,  9.13it/s, loss=0.00289] 


Epoch 23/35, Avg Loss: 0.0075


Epoch 24/35: 100%|██████████| 938/938 [02:09<00:00,  7.23it/s, loss=1.23e-5] 


Epoch 24/35, Avg Loss: 0.0060


Epoch 25/35: 100%|██████████| 938/938 [02:58<00:00,  5.26it/s, loss=4.56e-5] 


Epoch 25/35, Avg Loss: 0.0056


Epoch 26/35: 100%|██████████| 938/938 [03:13<00:00,  4.86it/s, loss=0.000193]


Epoch 26/35, Avg Loss: 0.0070


Epoch 27/35: 100%|██████████| 938/938 [03:19<00:00,  4.69it/s, loss=8.08e-6] 


Epoch 27/35, Avg Loss: 0.0063


Epoch 28/35: 100%|██████████| 938/938 [03:03<00:00,  5.11it/s, loss=1.53e-6] 


Epoch 28/35, Avg Loss: 0.0050


Epoch 29/35: 100%|██████████| 938/938 [1:40:36<00:00,  6.43s/it, loss=0.000132]   


Epoch 29/35, Avg Loss: 0.0052


Epoch 30/35: 100%|██████████| 938/938 [01:40<00:00,  9.32it/s, loss=7.32e-5] 


Epoch 30/35, Avg Loss: 0.0062


Epoch 31/35: 100%|██████████| 938/938 [02:25<00:00,  6.43it/s, loss=3.51e-5] 


Epoch 31/35, Avg Loss: 0.0051


Epoch 32/35: 100%|██████████| 938/938 [01:44<00:00,  9.00it/s, loss=3.26e-6] 


Epoch 32/35, Avg Loss: 0.0045


Epoch 33/35: 100%|██████████| 938/938 [02:41<00:00,  5.80it/s, loss=2.16e-7] 


Epoch 33/35, Avg Loss: 0.0050


Epoch 34/35: 100%|██████████| 938/938 [02:57<00:00,  5.28it/s, loss=2.14e-5] 


Epoch 34/35, Avg Loss: 0.0053


Epoch 35/35: 100%|██████████| 938/938 [02:38<00:00,  5.91it/s, loss=0.0297]  

Epoch 35/35, Avg Loss: 0.0063





In [1]:
# -----------------------------
# 5. EVALUATION
# -----------------------------
# Switch the model to evaluation mode before scoring the test split.
model.eval()
correct, total = 0, 0

# Scoring needs no gradients, so autograd tracking is turned off.
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = torch.max(outputs, dim=1).indices  # index of the largest logit per row
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")


NameError: name 'model' is not defined

In [27]:
# -----------------------------
# 6. TEST SINGLE PREDICTION
# -----------------------------
# Gradio sketchpad returns a full-color NumPy array (H,W,3).
# MNIST images are grayscale (1x28x28) and normalized.
# This preprocessing converts user drawings into MNIST format.
# -----------------------------

# MNIST normalization values:
MNIST_MEAN = (0.1307,)
MNIST_STD  = (0.3081,)


def _flatten_alpha_onto_white(rgba_array):
    """Composite an (H, W, 4) uint8 RGBA array over opaque white; return an (H, W, 3) RGB array."""
    rgba = Image.fromarray(rgba_array)
    canvas = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return np.asarray(Image.alpha_composite(canvas, rgba).convert("RGB"))


def preprocess_image(image):
    """Convert Gradio Sketchpad output to a normalized (1, 1, 28, 28) tensor on `device`.

    Args:
        image: Either a NumPy image array, or a dict holding that array under
            the "composite" key.

    Returns:
        A float tensor of shape (1, 1, 28, 28), normalized with
        MNIST_MEAN / MNIST_STD and moved to `device`.
    """
    # Gradio may pass {'composite': array}
    if isinstance(image, dict) and "composite" in image:
        image = image["composite"]

    # Converting RGBA to grayscale discards the alpha channel, so strokes drawn
    # on a transparent canvas would be lost. Flatten onto white first; for fully
    # opaque input this leaves the colours unchanged.
    # NOTE(review): whether Sketchpad delivers RGB or RGBA depends on the
    # installed Gradio version - confirm; 3-channel input skips this branch.
    if (
        isinstance(image, np.ndarray)
        and image.ndim == 3
        and image.shape[2] == 4
        and image.dtype == np.uint8
    ):
        image = _flatten_alpha_onto_white(image)

    # Define preprocessing pipeline
    sketch_transform = transforms.Compose([
        transforms.ToPILImage(),                      # NumPy -> PIL
        transforms.Grayscale(num_output_channels=1),  # Convert to 1 channel
        transforms.Resize((28, 28)),                  # Match MNIST input
        transforms.Lambda(ImageOps.invert),           # Swap light and dark before normalizing
        transforms.ToTensor(),                        # -> (1, 28, 28), values in [0,1]
        transforms.Normalize(MNIST_MEAN, MNIST_STD),  # Match MNIST training normalization
    ])

    tensor = sketch_transform(image)                  # Shape: (1, 28, 28)
    tensor = tensor.unsqueeze(0)                      # Shape: (1, 1, 28, 28)
    return tensor.to(device)                          # Move to same device as model


def predict_digit(image):
    """Classify a raw Sketchpad drawing.

    Returns the string "Draw something!" when no image is supplied, otherwise
    "<digit>  (<confidence>% confidence)" for the most probable class.
    """
    # Nothing drawn yet: prompt the user instead of running the model.
    if image is None:
        return "Draw something!"

    batch = preprocess_image(image)

    # Make sure the model is in evaluation mode before predicting.
    model.eval()

    with torch.no_grad():
        # Softmax over the class dimension, then take the row of the single input.
        class_probs = torch.softmax(model(batch), dim=1)[0]
        predicted_class = int(class_probs.argmax())
        confidence = float(class_probs[predicted_class])

    return f"{predicted_class}  ({confidence * 100:.2f}% confidence)"



# -----------------------------
# GRADIO UI
# -----------------------------
# Drawing canvas that feeds predict_digit; its returned string is shown via the "label" output.
sketch_input = gr.Sketchpad(label="Draw Here")
interface = gr.Interface(fn=predict_digit, inputs=sketch_input, outputs="label", live=False)

# Enable request queuing, then start the local web server.
interface.queue().launch()


* Running on local URL:  http://127.0.0.1:7871
* To create a public link, set `share=True` in `launch()`.


