In [3]:
# -----------------------------
# MNIST DIGIT CLASSIFIER (PyTorch)
# -----------------------------

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [4]:
# -----------------------------
# 1. LOAD DATA
# Transforms are preprocessing steps applied automatically to every image
# loaded from a dataset. Think of a transform as a recipe that says:
#
# "Every time you give me an image, do X, then Y, then Z to it."
#
# Here the recipe has a single step: for every MNIST image, convert it to a
# PyTorch tensor. MNIST images are loaded as PIL (Python Imaging Library)
# images, but the neural network expects tensors.
# -----------------------------
# Single-step pipeline: ToTensor turns each PIL image into a float tensor
# of shape (1, 28, 28) with pixel values scaled to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

In [5]:
# Load training dataset (MNIST)
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=True
)


100.0%
100.0%
100.0%
100.0%


In [6]:
# Load the held-out MNIST test split, using the same preprocessing as training.
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)


In [9]:
# Make DataLoaders: mini-batches of BATCH_SIZE; only the training data is
# shuffled, so each epoch sees the samples in a new order.
BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Sanity check: the distinct class labels present in the training targets.
unique_labels = train_dataset.targets.unique(sorted=True).tolist()
print('Unique labels in training dataset:', unique_labels)

Unique labels in training dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [10]:
# -----------------------------
# 2. DEFINE NEURAL NETWORK
# Neural Network with 1 hidden layer of 128 neurons
# -----------------------------
class SimpleNN(nn.Module):
    """Fully connected MNIST classifier: 784 -> hidden_size -> num_classes.

    Args:
        hidden_size: Width of the single hidden layer (default 128).
        num_classes: Number of output logits (default 10, one per digit).
    """

    def __init__(self, hidden_size: int = 128, num_classes: int = 10):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for a batch of 28x28 images.

        Args:
            x: Images with 28*28 elements per sample, e.g. shape
               (batch, 1, 28, 28) or (batch, 784).

        Returns:
            Logits of shape (batch, num_classes). No softmax is applied,
            because nn.CrossEntropyLoss expects raw logits.
        """
        # Flatten image: (batch, 1, 28, 28) -> (batch, 784).
        # reshape (unlike view) also works on non-contiguous inputs,
        # e.g. transposed or permuted image tensors.
        x = x.reshape(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

In [11]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SimpleNN().to(device)
print(f'Using device: {device}')

Using device: cpu


In [12]:
# -----------------------------
# 3. LOSS FUNCTION + OPTIMIZER
# -----------------------------
# Cross-entropy on the model's raw logits, optimized with Adam.
LEARNING_RATE = 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [16]:
# -----------------------------
# 4. TRAINING LOOP
# -----------------------------

epochs = 20

# Each epoch is one full pass over train_loader. The printed "Avg Loss" is a
# per-sample mean: each batch's mean loss is weighted by its batch size,
# summed, then divided by the dataset size (the last batch may be smaller).
for epoch in range(epochs):
    model.train()
    total_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()                    # clear gradients from the previous step
        loss = criterion(model(images), labels)  # forward pass + loss
        loss.backward()                          # backprop
        optimizer.step()                         # update the weights from the gradients

        # CrossEntropyLoss (default reduction) averages over the batch,
        # so scale back up to a sum before accumulating.
        total_loss += loss.item() * images.size(0)

    avg_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{epochs}, Avg Loss: {avg_loss:.4f}")


Epoch 1/20, Avg Loss: 0.0062
Epoch 2/20, Avg Loss: 0.0047
Epoch 2/20, Avg Loss: 0.0047
Epoch 3/20, Avg Loss: 0.0014
Epoch 3/20, Avg Loss: 0.0014
Epoch 4/20, Avg Loss: 0.0040
Epoch 4/20, Avg Loss: 0.0040
Epoch 5/20, Avg Loss: 0.0018
Epoch 5/20, Avg Loss: 0.0018
Epoch 6/20, Avg Loss: 0.0004
Epoch 6/20, Avg Loss: 0.0004
Epoch 7/20, Avg Loss: 0.0002
Epoch 7/20, Avg Loss: 0.0002
Epoch 8/20, Avg Loss: 0.0002
Epoch 8/20, Avg Loss: 0.0002
Epoch 9/20, Avg Loss: 0.0021
Epoch 9/20, Avg Loss: 0.0021
Epoch 10/20, Avg Loss: 0.0102
Epoch 10/20, Avg Loss: 0.0102
Epoch 11/20, Avg Loss: 0.0007
Epoch 11/20, Avg Loss: 0.0007
Epoch 12/20, Avg Loss: 0.0021
Epoch 12/20, Avg Loss: 0.0021
Epoch 13/20, Avg Loss: 0.0044
Epoch 13/20, Avg Loss: 0.0044
Epoch 14/20, Avg Loss: 0.0022
Epoch 14/20, Avg Loss: 0.0022
Epoch 15/20, Avg Loss: 0.0030
Epoch 15/20, Avg Loss: 0.0030
Epoch 16/20, Avg Loss: 0.0013
Epoch 16/20, Avg Loss: 0.0013
Epoch 17/20, Avg Loss: 0.0041
Epoch 17/20, Avg Loss: 0.0041
Epoch 18/20, Avg Loss: 0.00

In [None]:
# -----------------------------
# 5. EVALUATION
# -----------------------------
# Count correct predictions over the whole test set, with the model in
# inference mode (model.eval()) and autograd disabled (torch.no_grad()).
correct = 0
total = 0
model.eval()

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)            # (batch, 10) logits
        # argmax over the class dimension; no need for the legacy `.data`
        # attribute, since no_grad() already disables autograd tracking.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")


Test Accuracy: 90.36%


In [11]:
# -----------------------------
# 6. TEST SINGLE PREDICTION (interactive Gradio demo)
# -----------------------------
# Drawings from the Gradio Sketchpad differ from MNIST samples:
#
# * they are full-color NumPy arrays (MNIST is single-channel grayscale)
# * the digit is drawn black on a white background (MNIST is light on dark)
# * they are much larger than 28x28 pixels
# * the digit has no consistent size or position
#
# Hence the preprocessing below, which maps a drawing onto the MNIST format.
# ------------------------------

def preprocess_image(image):
    """Convert a Gradio Sketchpad drawing into an MNIST-style input batch.

    Args:
        image: The Sketchpad value: either an image array, or a dict whose
            'composite' entry holds the flattened drawing as a NumPy array.

    Returns:
        A float tensor of shape (1, 1, 28, 28) with values in [0, 1].
        The digit is light on a dark background, like the MNIST training data.
    """
    # Gradio Sketchpad sometimes passes a dict with 'composite'
    if isinstance(image, dict):
        image = image['composite']  # this is a NumPy array

    # If the drawing has an alpha channel (H, W, 4), composite it over white.
    # The grayscale conversion below would otherwise simply drop alpha,
    # which turns a transparent background black. Opaque images are unchanged.
    # NOTE(review): assumes uint8 RGBA values in 0-255 — confirm for your Gradio version.
    if getattr(image, 'ndim', None) == 3 and image.shape[-1] == 4:
        rgba = torch.as_tensor(image).float()
        alpha = rgba[..., 3:] / 255.0
        rgb = rgba[..., :3] * alpha + 255.0 * (1.0 - alpha)
        image = rgb.round().to(torch.uint8).numpy()

    sketch_transform = transforms.Compose([
        transforms.ToPILImage(),               # NumPy (H, W, C) → PIL
        transforms.Grayscale(),                # ensure 1 channel
        transforms.Resize((28, 28)),           # 28x28 like MNIST
        transforms.ToTensor(),                 # → tensor, shape (1,28,28), values in [0,1]
        # Invert on the tensor (dark-on-light → light-on-dark). This avoids
        # PIL's ImageOps, which this notebook never imports.
        transforms.Lambda(lambda t: 1.0 - t),
    ])

    # Apply the preprocessing, then add a batch dimension → (1, 1, 28, 28)
    return sketch_transform(image).unsqueeze(0)

def predict_digit(image):
    """Predict which digit is drawn on the Gradio Sketchpad.

    Args:
        image: The Sketchpad value (an image array, or a dict with a
            'composite' entry), or None when nothing has been drawn.

    Returns:
        The predicted class index as a string, or "Draw something!" when
        there is no drawing to classify.
    """
    # --- STEP 1: CHECK IF SOMETHING HAS BEEN DRAWN ---
    # A dict value can still carry an empty (None) composite.
    if image is None or (isinstance(image, dict) and image.get('composite') is None):
        return "Draw something!"

    # --- STEP 2: PREPROCESS THE IMAGE ---
    # Put the input on the same device as the model's weights; otherwise
    # the forward pass fails when the model was moved to CUDA.
    model_device = next(model.parameters()).device
    img_tensor = preprocess_image(image).to(model_device)

    # --- STEP 3: RUN THE MODEL ---
    # SimpleNN has no dropout or batch-norm, but eval() keeps inference explicit.
    model.eval()
    with torch.no_grad():
        prediction = model(img_tensor)  # (1, num_classes) logits

        # Get the index of the highest score (the predicted digit)
        predicted_digit = prediction.argmax(dim=1).item()

    return str(predicted_digit)

# UI setup: the Sketchpad drawing is passed to predict_digit, and the
# returned string is shown as a label.
# NOTE(review): `gr` is never imported in this notebook. Add
# `import gradio as gr` to the imports cell, or this cell raises NameError.
sketchpad = gr.Sketchpad(label="Draw Here")
interface = gr.Interface(fn=predict_digit, inputs=sketchpad, outputs="label")
interface.queue().launch()

Actual label:     7
Predicted label:  7
