# **MNIST CNN to ONNX: Can I guess the digit you drew?**

This notebook trains a tiny CNN on **MNIST**, **exports to ONNX** (the Open Neural Network eXchange), and serves a **Gradio** canvas app that runs **inference with ONNX Runtime** (no PyTorch in the serving path).

**Flow**  
1. Install deps (PyTorch, Torchvision, ONNX, ONNX Runtime, Gradio)  
2. Train small CNN (few epochs) **or** skip if `models/mnist_cnn.onnx` already exists  
3. Export to `models/mnist_cnn.onnx` (+ optional int8 quantization)  
4. Launch Gradio app with a large canvas and **MNIST-style preprocessing** (center/scale/pad) and a big upscaled 28×28 preview

**Key Components**
1. Model Training: Trains a small CNN on MNIST (or skips if ONNX model exists)
2. ONNX Export: Converts PyTorch model to ONNX format for deployment
3. Interactive Interface: Gradio app with drawing canvas for digit recognition
4. MNIST-style Preprocessing: Converts user drawings to 28x28 format matching training data

**Features**
- Large drawing canvas (560x560) with brush/eraser tools
- Automatic preprocessing (invert, crop, resize, blur, center-of-mass alignment)
- Real-time inference using ONNX Runtime (no PyTorch dependency)
- Confidence scores and probability distribution display
- Adjustable preprocessing parameters (threshold, blur)
- Optional int8 dynamic quantization for a smaller model and, on some CPUs, faster inference


## 0) Install/verify dependencies


In [None]:

# If these packages are already installed, this cell runs quickly. If pip upgraded anything, restart the kernel before continuing.
%pip -q install --upgrade torch torchvision gradio onnx onnxruntime ml-dtypes tqdm



## 1) Imports and setup


In [None]:
import os, json, random, math
from pathlib import Path
import numpy as np

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import onnx, onnxruntime as ort

from PIL import Image, ImageOps, ImageFilter, ImageChops
from tqdm import tqdm

random.seed(0); np.random.seed(0); torch.manual_seed(0)

DATA_DIR = "data"
MODELS_DIR = Path("models"); MODELS_DIR.mkdir(parents=True, exist_ok=True)
ONNX_PATH = MODELS_DIR / "mnist_cnn.onnx"
META_PATH = MODELS_DIR / "meta.json"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE


## 2) Dataset


In [None]:

# We'll use augmentations to better match mouse-drawn digits
train_transform = transforms.Compose([
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_ds = datasets.MNIST(root=DATA_DIR, train=True, download=True, transform=train_transform)
test_ds  = datasets.MNIST(root=DATA_DIR, train=False, download=True, transform=test_transform)

BATCH_SIZE = 128
PIN_MEMORY = torch.cuda.is_available()  # pinned memory only helps host-to-GPU copies
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=PIN_MEMORY)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=PIN_MEMORY)

len(train_ds), len(test_ds)
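To see what `Normalize((0.1307,), (0.3081,))` does to actual pixel values, here is a minimal pure-Python sketch; the helper `normalize_pixel` is hypothetical and not part of torchvision. `ToTensor()` first scales uint8 pixels from 0-255 to [0, 1], and `Normalize` then subtracts the mean and divides by the standard deviation, so the black MNIST background lands near -0.42 and full-intensity ink near 2.82. Whatever preprocessing serves the model later has to reproduce exactly this mapping, which is why the constants are saved to `meta.json`.

```python
# Hypothetical helper mirroring ToTensor() + Normalize((0.1307,), (0.3081,)) for a single pixel.
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081

def normalize_pixel(value_0_255: int) -> float:
    x = value_0_255 / 255.0               # ToTensor(): uint8 0-255 -> float in [0, 1]
    return (x - MNIST_MEAN) / MNIST_STD   # Normalize(): subtract mean, divide by std

background = normalize_pixel(0)    # black MNIST background
ink = normalize_pixel(255)         # full-intensity stroke
print(f"background -> {background:.4f}, ink -> {ink:.4f}")
```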


## 3) Model definition


In [None]:

class SmallCNN(nn.Module):
    '''Expects MNIST digits in PyTorch NCHW layout: (batch, 1, 28, 28), i.e. one grayscale channel of 28x28 pixels.'''
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),  # 3x3 magnifying glass
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # downsample 28 to 14
            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # 3x3 magnifying glass
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # downsample 14 to 7
            nn.Flatten(),  # flatten 64x7x7 to 3136 vector of learned features
            nn.Linear(64 * 7 * 7, 128),  # fully connected hidden layer: 3136 features -> 128 units
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classes),   # output layer, one score per class, produces 10 logits, no activation
        )
    def forward(self, x):
        return self.net(x)
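The `64 * 7 * 7` in the first `nn.Linear` follows from the standard output-size formula for convolution and pooling, `out = (in + 2*padding - kernel) // stride + 1`. A small pure-Python sketch (the helper names are hypothetical) traces a 28x28 input through the layers above:

```python
def conv2d_out(size: int, kernel: int = 3, padding: int = 1, stride: int = 1) -> int:
    """Spatial output size of a Conv2d layer (square input, no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

def maxpool_out(size: int, kernel: int = 2, stride: int = 2) -> int:
    """Spatial output size of MaxPool2d without padding (stride defaults to kernel size)."""
    return (size - kernel) // stride + 1

# (layer description, output channels, size function) in the same order as SmallCNN.net
layers = [
    ("conv1 3x3, pad 1", 32, conv2d_out),
    ("maxpool 2", 32, maxpool_out),
    ("conv2 3x3, pad 1", 64, conv2d_out),
    ("maxpool 2", 64, maxpool_out),
]

size = 28
sizes = [size]
for name, channels, out_fn in layers:
    size = out_fn(size)
    sizes.append(size)
    print(f"{name:<18} -> {channels} x {size} x {size}")

flat_features = 64 * size * size  # what nn.Flatten() hands to the first Linear layer
print("flatten ->", flat_features)
```

Padding of 1 keeps each 3x3 convolution size-preserving, so only the two pooling layers shrink the feature maps.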


## 4) Training utilities


In [None]:

def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running = 0.0
    for x, y in tqdm(loader, desc="train", leave=False):   # tqdm = progress bars
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        running += loss.item() * x.size(0)
    return running / len(loader.dataset)

@torch.no_grad()   # disable gradient calculation for faster evaluation, pure inference
def evaluate_acc(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    for x, y in tqdm(loader, desc="eval", leave=False):   # tqdm = progress bars
        x, y = x.to(device), y.to(device)
        logits = model(x)
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    return correct / total
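`train_one_epoch` multiplies each batch's mean loss by the batch size and divides by `len(loader.dataset)`, which yields a true per-sample average even though the last batch is smaller (60,000 training images with `BATCH_SIZE = 128` leave a final batch of 96). A pure-Python sketch with illustrative loss values (not from a real run) shows how a plain average of batch means would differ:

```python
import math

N_TRAIN, BATCH_SIZE = 60_000, 128
num_batches = math.ceil(N_TRAIN / BATCH_SIZE)          # batches per epoch
last_batch = N_TRAIN - (num_batches - 1) * BATCH_SIZE  # size of the final, partial batch

def epoch_mean_loss(batch_losses, batch_sizes):
    """Per-sample mean: weight each batch's mean loss by its size, as train_one_epoch does."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)

# Illustrative numbers only: three batches, the last one partial and with a higher loss.
losses = [0.2, 0.2, 1.0]
sizes = [128, 128, 96]
weighted = epoch_mean_loss(losses, sizes)
naive = sum(losses) / len(losses)  # averaging batch means overweights the small batch
print(num_batches, last_batch, round(weighted, 4), round(naive, 4))
```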


## 5) Train (or skip if ONNX already exists and you want to reuse it)


In [None]:

SKIP_TRAINING_IF_ONNX_EXISTS = True

if SKIP_TRAINING_IF_ONNX_EXISTS and ONNX_PATH.exists():
    print("ONNX model already exists; skipping training. Delete it to retrain.")
else:
    EPOCHS = 5   # more epochs usually help accuracy at the cost of training time
    LR = 1e-3    # Adam's default learning rate; tune for the speed/stability tradeoff

    model = SmallCNN().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0.0
    for epoch in range(1, EPOCHS + 1):
        loss = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
        acc = evaluate_acc(model, test_loader, DEVICE)
        print(f"Epoch {epoch}: loss={loss:.4f} | acc={acc:.4f}")
        best_acc = max(best_acc, acc)

    # Save PyTorch weights temporarily for export
    torch_ckpt_path = MODELS_DIR / "mnist_cnn.pt"
    torch.save(model.state_dict(), torch_ckpt_path)
    with open(META_PATH, "w") as f:
        json.dump({"num_classes": 10, "normalize_mean": 0.1307, "normalize_std": 0.3081}, f)
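    # The saved weights come from the final epoch, which is not necessarily the best one.
    print(f"Training complete. Final test acc: {acc:.4f} | best epoch acc: {best_acc:.4f}")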


## 6) Export to ONNX (if not present)


In [None]:

if not ONNX_PATH.exists():
    # Rebuild model and load weights just for a clean export step
    model = SmallCNN()
    torch_ckpt_path = MODELS_DIR / "mnist_cnn.pt"
    assert torch_ckpt_path.exists(), "No checkpoint found for export; run the training cell (step 5) first."
    model.load_state_dict(torch.load(torch_ckpt_path, map_location="cpu"))
    model.eval()

    dummy = torch.randn(1, 1, 28, 28)  # NCHW
    torch.onnx.export(
        model, dummy, ONNX_PATH.as_posix(),
        input_names=["input"], output_names=["logits"],
        dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
        opset_version=13
    )
    print("Exported:", ONNX_PATH.resolve())
else:
    print("Using existing:", ONNX_PATH.resolve())

# Validate ONNX model
m = onnx.load(ONNX_PATH.as_posix())
onnx.checker.check_model(m)
print("ONNX model is valid.")


## 7) Prepare ONNX Runtime session


In [None]:

sess = ort.InferenceSession(ONNX_PATH.as_posix(), providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

# Load meta
with open(META_PATH) as f:
    meta = json.load(f)
mean, std = meta["normalize_mean"], meta["normalize_std"]

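# Preprocess to model tensor with NumPy only, so serving needs no torch/torchvision.
# Mirrors ToTensor() (uint8 -> [0, 1]) followed by Normalize((mean,), (std,)).
def pil_to_input(pil_img: Image.Image) -> np.ndarray:
    img = pil_img.convert("L")
    if img.size != (28, 28):
        img = img.resize((28, 28), Image.BILINEAR)
    x = np.asarray(img, dtype=np.float32) / 255.0
    x = (x - mean) / std
    return x[np.newaxis, np.newaxis, :, :].astype(np.float32)  # NCHW: (1, 1, 28, 28)

def onnx_predict_probs_from_pil(pil_img: Image.Image):
    x = pil_to_input(pil_img)  # NCHW float32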
    logits = sess.run([output_name], {input_name: x})[0]  # (N, 10)
    # softmax
    logits = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(logits)
    probs = exp / exp.sum(axis=1, keepdims=True)
    p = probs[0]
    return {str(i): float(p[i]) for i in range(10)}
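The softmax in `onnx_predict_probs_from_pil` subtracts the row maximum before exponentiating. Softmax is unchanged when the same constant is added to every logit, so the shift leaves the probabilities intact while keeping `exp` from overflowing on large logits. A pure-Python sketch of the same idea (the `softmax` helper here is illustrative, not an ONNX Runtime API):

```python
import math

def softmax(logits):
    """Numerically stable softmax: shift by the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]  # largest term becomes exp(0) = 1
    total = sum(exps)
    return [e / total for e in exps]

big = [1000.0, 1001.0, 1002.0]
small = [0.0, 1.0, 2.0]  # the same logits shifted by -1000

probs_big = softmax(big)
probs_small = softmax(small)
print([round(p, 4) for p in probs_big])

try:
    math.exp(1002.0)  # the unshifted exponent is too large for a float
    overflowed = False
except OverflowError:
    overflowed = True
print("naive exp overflows:", overflowed)
```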


## 8) MNIST-style preprocessing for drawings (center/scale/pad, COM recenter)


In [None]:

def _center_of_mass_shift(im: Image.Image):
    """Return (dx, dy) that moves the intensity-weighted centroid to the image center."""
    arr = np.array(im, dtype=np.float32)
    total = arr.sum()
    if total <= 0: return 0, 0  # blank image: nothing to shift
    ys, xs = np.indices(arr.shape)
    cy = (ys * arr).sum() / total
    cx = (xs * arr).sum() / total
    return int(round(arr.shape[1]/2 - cx)), int(round(arr.shape[0]/2 - cy))

def preprocess_mnist_style(pil: Image.Image, threshold=10, blur=0.5) -> Image.Image:
    """Convert a free-form drawing into a 28x28 MNIST-like image (light digit on black)."""
    img = pil.convert("L")
    if np.array(img).mean() > 127:  # mostly white background: invert to match MNIST
        img = ImageOps.invert(img)
    # Stretch contrast so the strokes span the full 0-255 range
    arr = np.array(img, dtype=np.float32)
    if arr.max() > arr.min():
        arr = (arr - arr.min()) / (arr.max() - arr.min()) * 255.0
    arr_u8 = arr.astype("uint8")
    # Crop to the bounding box of pixels brighter than `threshold`, with a 2 px margin
    mask = arr_u8 > int(threshold)
    if mask.any():
        ys, xs = np.where(mask)
        y0, y1 = max(0, ys.min()-2), min(arr_u8.shape[0], ys.max()+3)
        x0, x1 = max(0, xs.min()-2), min(arr_u8.shape[1], xs.max()+3)
        img = Image.fromarray(arr_u8[y0:y1, x0:x1])
    else:
        img = Image.fromarray(arr_u8)  # nothing above threshold: keep the full image
    w, h = img.size
    if max(w, h) == 0:
        return Image.new("L", (28, 28), 0)
    # Resize so the longer side is 20 px, then paste centered on a black 28x28 canvas
    scale = 20 / max(w, h)
    nw, nh = max(1, int(round(w*scale))), max(1, int(round(h*scale)))
    img = img.resize((nw, nh), Image.BILINEAR)
    canvas = Image.new("L", (28, 28), 0)
    canvas.paste(img, ((28-nw)//2, (28-nh)//2))
    # Optional light Gaussian blur (Pillow's radius is the standard deviation)
    if blur and blur > 0:
        canvas = canvas.filter(ImageFilter.GaussianBlur(radius=float(blur)))
    # Recenter by center of mass; ImageChops.offset wraps pixels around the edges
    dx, dy = _center_of_mass_shift(canvas)
    return ImageChops.offset(canvas, dx, dy)
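Two pieces of arithmetic in `preprocess_mnist_style` are easy to check by hand: fitting the cropped digit into a 20x20 box inside the 28x28 canvas (the same size normalization used when MNIST was built), and the center-of-mass shift. The pure-Python sketch below restates both on plain lists; `fit_into_box` and `com_shift` are hypothetical helpers that mirror the notebook's logic, not drop-in replacements for it.

```python
def fit_into_box(w: int, h: int, box: int = 20, canvas: int = 28):
    """Scale (w, h) so the longer side is `box` px; return the new size and paste offset."""
    scale = box / max(w, h)
    nw, nh = max(1, int(round(w * scale))), max(1, int(round(h * scale)))
    offset = ((canvas - nw) // 2, (canvas - nh) // 2)  # top-left corner that centers the box
    return (nw, nh), offset

def com_shift(pixels):
    """(dx, dy) that moves the intensity-weighted centroid to the image center."""
    h, w = len(pixels), len(pixels[0])
    total = sum(sum(row) for row in pixels)
    if total <= 0:
        return 0, 0  # blank image: nothing to shift
    cy = sum(y * v for y, row in enumerate(pixels) for v in row) / total
    cx = sum(x * v for row in pixels for x, v in enumerate(row)) / total
    return int(round(w / 2 - cx)), int(round(h / 2 - cy))

# A tall, narrow crop (e.g. a "1") of 120 x 300 px becomes 8 x 20, pasted at (10, 4).
print(fit_into_box(120, 300))

# A single bright pixel at row 10, column 20 of a 28x28 image.
img = [[0] * 28 for _ in range(28)]
img[10][20] = 255
print(com_shift(img))  # (-6, 4): move left 6 px and down 4 px
```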


## 9) Gradio UI (ImageEditor, big canvas, upscaled preview)


In [None]:

import gradio as gr

def top_pred(probs_dict):
    label, conf = max(probs_dict.items(), key=lambda kv: kv[1])
    return f"Prediction: {label} (confidence {conf:.2%})"

def _to_pil_from_editor(value):
    if value is None: return None
    if isinstance(value, dict):      # ImageEditor (type="pil")
        return value.get("composite") or value.get("background")
    if isinstance(value, Image.Image):
        return value
    if isinstance(value, np.ndarray):
        return Image.fromarray(value.astype("uint8"))
    return Image.open(value)

def predict(editor_value, threshold, blur, preview_px):
    raw = _to_pil_from_editor(editor_value)
    if raw is None:
        return "Please draw or upload a digit.", {}, None
    processed = preprocess_mnist_style(raw, threshold=threshold, blur=blur)
    probs = onnx_predict_probs_from_pil(processed)
    preview = processed.resize((int(preview_px), int(preview_px)), Image.NEAREST)
    return top_pred(probs), probs, preview

"""
Gradio UI (Blocks) for MNIST → ONNX Runtime demo.

Purpose
-------
Provide an interactive, notebook-native UI where students draw a digit,
inspect the exact 28x28 tensor fed to the model, and run inference using
ONNX Runtime.

How it works
------------
1) Students draw/upload in the left ImageEditor canvas.
2) We preprocess to MNIST style: grayscale → invert (if needed) →
   contrast stretch → crop to digit via threshold → resize longest side
   to 20 px → pad to 28x28 → light blur → center-of-mass recenter.
3) The ONNX model sees ONLY the 28x28 tensor; the preview simply upscales it.

Controls
--------
- Preprocess threshold (0-50):
    Sets the pixel cutoff for detecting the digit's bounding box.
    • Increase to ignore faint specks/noise.
    • Decrease if thin/light strokes get clipped.

- Preprocess blur (sigma σ, 0.0-1.5):
    A small Gaussian blur after resizing softens jagged mouse strokes,
    bringing them closer to MNIST's anti-aliased digits.
    • Typical sweet spot: 0.4-0.8.
    • Too high will over-smooth skinny digits.

- Preview size (px):
    Upscales the 28x28 model input for inspection only.
    • Does NOT change the model's input resolution.

UI pieces
---------
- ImageEditor (left): draw with brush/eraser; large canvas for comfort.
- Preview (right): the 28x28 MNIST-style input, upscaled with nearest-neighbor.
- Predict button: runs preprocessing + ONNX inference and displays
  (a) the top prediction + confidence and (b) full class probabilities (0-9).

Tips
----
• Use a thicker brush and center the digit; adjust threshold/blur if predictions look off.
• The preview is your ground truth for what the model actually “sees.”
"""

with gr.Blocks(title="MNIST (ONNX Runtime)") as demo:
    gr.Markdown("### Draw a digit (0-9), then **Predict** — served via **ONNX Runtime**")

    with gr.Row():
        with gr.Column():
            canvas = gr.ImageEditor(
                image_mode="L", type="pil",
                sources=["upload", "clipboard", "webcam"],
                brush=gr.Brush(default_size=36),
                eraser=gr.Eraser(default_size=36),
                canvas_size=(560, 560),
                height=640, width=640, label="Canvas / Upload",
            )
            threshold = gr.Slider(0, 50, value=10, step=1, label="Preprocess threshold")
            blur = gr.Slider(0.0, 1.5, value=0.5, step=0.1, label="Preprocess blur (σ)")
            preview_px = gr.Slider(320, 768, value=560, step=80, label="Preview size (px)")
        processed_preview = gr.Image(label="Model input (28x28, upscaled)")

    btn = gr.Button("Predict (ONNX)", variant="primary")
    pred_text = gr.Markdown()
    probs_json = gr.JSON(label="Class probabilities")

    btn.click(
        predict,
        inputs=[canvas, threshold, blur, preview_px],
        outputs=[pred_text, probs_json, processed_preview],
    )

demo.launch(share=False, inline=True)



## Optional: Dynamic Quantization
Dynamic quantization can shrink the ONNX model and sometimes speeds up CPU inference. Use the quantized ONNX model when you run on CPU only, want a small download and fast startup, or plan to ship the `.onnx` file to a web/backend service.

### What you get
- Smaller file: weights go from float32 (4 bytes) → int8 (1 byte).
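    - Our ~421k-param model is ~1.7 MB in fp32. If every weight were stored as INT8 it would be about 4× smaller (≈0.4–0.5 MB); the actual saving depends on which ops the quantizer rewrites, so compare the two file sizes.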
- Faster CPU inference: ONNX Runtime can use vectorized INT8 kernels (e.g., AVX2/VNNI), which can beat fp32, especially on laptops/servers without powerful GPUs.
- Simpler deployment: smaller artifacts = faster downloads, lighter containers.
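The "~421k params, ~1.7 MB" figures follow directly from the layer shapes in `SmallCNN`: a `Conv2d` has `in*out*k*k + out` parameters and a `Linear` has `in*out + out`. A pure-Python sketch of that count and of the raw weight bytes at 4 vs 1 byte per value (real `.onnx` files also store the graph, and quantized ones add scale/zero-point tensors, so expect somewhat larger files):

```python
def conv2d_params(c_in: int, c_out: int, k: int) -> int:
    return c_in * c_out * k * k + c_out  # weights + biases

def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out          # weights + biases

# Layer shapes taken from SmallCNN above
layers = {
    "conv1": conv2d_params(1, 32, 3),
    "conv2": conv2d_params(32, 64, 3),
    "fc1": linear_params(64 * 7 * 7, 128),
    "fc2": linear_params(128, 10),
}
total = sum(layers.values())
fp32_bytes = total * 4  # float32: 4 bytes per value
int8_bytes = total * 1  # idealized: 1 byte per value, ignoring quantization params and float leftovers

for name, n in layers.items():
    print(f"{name:<5} {n:>7,} params ({n / total:.1%})")
print(f"total {total:,} params | fp32 ~{fp32_bytes / 1e6:.2f} MB | int8 ~{int8_bytes / 1e6:.2f} MB")
```

`fc1` alone holds about 95% of the weights, so the real file-size saving mostly depends on whether the quantizer rewrites that layer.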

### What you give up
- Tiny accuracy dip is possible (usually negligible for MNIST).
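- Op coverage: only op types supported by ORT's dynamic quantizer are converted, and the quantized ops need kernels on your execution provider. Load the int8 model in ONNX Runtime and compare its predictions with the fp32 model before shipping it.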
- Speedups are hardware-dependent: newer CPUs benefit more.
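The accuracy dip comes from rounding each weight to one of a limited set of integer levels. A minimal pure-Python sketch of symmetric per-tensor int8 quantization shows that the round-trip error is at most half a quantization step (`scale / 2`). This is a simplified illustration; ONNX Runtime's exact choice of scales, zero points, and per-channel handling may differ.

```python
def quantize_int8_symmetric(weights):
    """Map floats to int8 codes in [-127, 127] with one shared scale (simplified scheme)."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    codes = [max(-127, min(127, round(w / scale))) for w in weights]
    return codes, scale

def dequantize(codes, scale):
    """Recover approximate float weights from int8 codes."""
    return [c * scale for c in codes]

weights = [0.5, -0.25, 0.1, -0.003, 0.0]  # illustrative values, not taken from the model
codes, scale = quantize_int8_symmetric(weights)
restored = dequantize(codes, scale)
max_err = max(abs(w - r) for w, r in zip(weights, restored))
print(codes)
print(f"scale={scale:.6f}  max round-trip error={max_err:.6f} (<= scale/2 = {scale / 2:.6f})")
```

Because the per-weight error is bounded by half a step, small well-separated models like this CNN often tolerate it, but measuring test accuracy on the quantized model is the only reliable check.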


In [None]:
try:
    from onnxruntime.quantization import quantize_dynamic, QuantType
    INT8_PATH = (MODELS_DIR / "mnist_cnn.int8.onnx").as_posix()
    quantize_dynamic(
        model_input=ONNX_PATH.as_posix(),
        model_output=INT8_PATH,
        per_channel=False,
        reduce_range=False,
        weight_type=QuantType.QInt8,
    )
    print("Quantized model written to:", INT8_PATH)
except Exception as e:
    print("Quantization not available:", e)
