<a href="https://colab.research.google.com/github/rrankawat/stm32/blob/main/Mnist_Pytorch_To_Int8_Onnx.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install onnx onnxruntime onnxscript onnxruntime-tools

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.1/18.1 MB[0m [31m87.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.4/17.4 MB[0m [31m94.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m693.4/693.4 kB[0m [31m42.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.7/212.7 kB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m133.1/133.1 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.5/55.5 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [3]:
# ONNX + ORT quantization
import onnx
import onnxruntime as ort

from onnxruntime.quantization import quantize_dynamic, QuantType, QuantFormat, CalibrationDataReader

In [4]:
# -----------------------
# Variables & Utils
# -----------------------

batch_size=64
epochs = 5

def total_time_minutes(start_time):
  return (time.time() - start_time) / 60

In [5]:
# -----------------------
# Data Loaders
# -----------------------
transform = transforms.Compose([transforms.ToTensor()])

train_transform = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_transform = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_transform, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_transform, batch_size=batch_size, shuffle=False)

100%|██████████| 9.91M/9.91M [00:00<00:00, 18.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 482kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.57MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 1.15MB/s]


In [6]:
# -----------------------
# Model
# -----------------------
class MNISTTinyCNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(1, 6, 3, 1) # 28 -> 26
    self.conv2 = nn.Conv2d(6, 16, 3, 1) # 13 -> 11
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    x = F.relu(self.conv1(x))
    x = F.max_pool2d(x, 2) # 28 -> 14

    x = F.relu(self.conv2(x))
    x = F.max_pool2d(x, 2) # 14 -> 7

    x = x.view(-1, 16 * 5 * 5)

    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x) # logits

    return x

In [7]:
# -----------------------
# Random Seeds and Model Instance
# -----------------------
torch.manual_seed(41)
model = MNISTTinyCNN()

In [8]:
# -----------------------
# Loss & Optimizer
# -----------------------
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [9]:
# -----------------------
# Training
# -----------------------
def train(epoch, model, train_loader, criterian, optimizer, log_every=600):
  model.train()
  trn_corr = 0
  last_loss = None

  for b, (X_train, y_train) in enumerate(train_loader):
    y_pred = model(X_train)
    loss = criterian(y_pred, y_train)
    last_loss = loss.item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    predicted = y_pred.argmax(dim=1)
    trn_corr += (predicted == y_train).sum().item()

    if b % log_every == 0:
      seen = b * len(X_train)
      total = len(train_loader.dataset)
      pct = 100.0 * b / len(train_loader)
      print(f"Epoch {epoch+1} [{seen}/{total} ({pct:.0f}%)]  Loss: {last_loss:.6f}")

  tran_acc = trn_corr / len(train_loader.dataset)
  return last_loss, tran_acc

In [10]:
train_losses = []
train_accs = []
start_time = time.time()

for epoch in range(epochs):
  train_loss, train_acc = train(epoch, model, train_loader, criterion, optimizer)

  train_losses.append(train_loss)
  train_accs.append(train_acc)

  print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")

print(f"Time taken: {total_time_minutes(start_time)} minutes!")

Epoch 1/5 | Train Loss: 0.1874 | Train Acc: 0.8975
Epoch 2/5 | Train Loss: 0.1256 | Train Acc: 0.9702
Epoch 3/5 | Train Loss: 0.0080 | Train Acc: 0.9791
Epoch 4/5 | Train Loss: 0.0689 | Train Acc: 0.9834
Epoch 5/5 | Train Loss: 0.0030 | Train Acc: 0.9865
Time taken: 1.5691192150115967 minutes!


In [11]:
# -----------------------
# Testing
# -----------------------
def test(model, test_loader, criterian):
  model.eval()
  test_corr = 0
  total_loss = 0
  total = 0

  for X_test, y_test in test_loader:
    y_val = model(X_test)
    loss = criterian(y_val, y_test)

    total_loss += loss.item() * y_test.size(0)
    total += y_test.size(0)

    predicted = y_val.argmax(dim=1)
    test_corr += (predicted == y_test).sum().item()

  test_loss = total_loss / total
  test_acc = test_corr / total

  return test_loss, test_acc

In [12]:
test_losses = []
test_accs = []
start_time = time.time()

for epoch in range(epochs):
  test_loss, test_acc = test(model, test_loader, criterion)

  test_losses.append(test_loss)
  test_accs.append(test_acc)

  print(f"Epoch {epoch+1}/{epochs} | Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")

print(f"Time taken: {total_time_minutes(start_time)} minutes!")

Epoch 1/5 | Test Loss: 0.0370 | Test Acc: 0.9883
Epoch 2/5 | Test Loss: 0.0370 | Test Acc: 0.9883
Epoch 3/5 | Test Loss: 0.0370 | Test Acc: 0.9883
Epoch 4/5 | Test Loss: 0.0370 | Test Acc: 0.9883
Epoch 5/5 | Test Loss: 0.0370 | Test Acc: 0.9883
Time taken: 0.1804893453915914 minutes!


In [13]:
# -----------------------
# STM32 Pipeline
# -----------------------
def export_onnx(model, onnx_path):
  model.eval()

  dummy = torch.randn(1, 1, 28, 28)

  torch.onnx.export(
        model,
        dummy,
        onnx_path,
        input_names=["input"],
        output_names=["logits"],
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
        dynamo=False
    )
  onnx.checker.check_model(onnx_path, full_check=False)
  print(f"ONNX model saved to: {onnx_path}")

export_onnx(model, "mnist_lenet_fp32.onnx")

ONNX model saved to: mnist_lenet_fp32.onnx


  torch.onnx.export(


In [14]:
# -----------------------
# 20 MNIST test images
# -----------------------
test = datasets.MNIST(root="./data", train=False, download=True, transform=transforms.ToTensor())

N = 20

# Inputs in NHWC float32 (matches ST log format)
x = np.zeros((N, 28, 28, 1), dtype=np.float32)
for i in range(N):
    img, _ = test[i]
    x[i, :, :, 0] = img.squeeze(0).numpy()

# Save INPUTS ONLY
np.savez("mnist_20.npz", input=x)

print("Saved mnist_20.npz with keys:", np.load("mnist_20.npz").files)

Saved mnist_20.npz with keys: ['input']


### Quantization

In [15]:
# -----------------------
# Load MNIST (NO normalization for now; matches what you already ran)
# -----------------------
test_loader_quant = DataLoader(test_transform, batch_size=1, shuffle=False)

In [16]:
# -----------------------
# Load trained weights
# -----------------------
model.eval()

MNISTTinyCNN(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [25]:
# -----------------------
# Collect N samples
# -----------------------
N = 200
inputs_nhwc = []
logits = []
labels = []

with torch.no_grad():
  for i, (x, y) in enumerate(test_loader_quant):
    if i >= N:
      break

    out = model(x)
    # convert input to NHWC (1x28x28x1)
    x_nhwc = x.numpy().transpose(0, 2, 3, 1).astype(np.float32)
    inputs_nhwc.append(x_nhwc[0])
    logits.append(out.numpy()[0].astype(np.float32))
    labels.append(int(y.item()))

inputs_nhwc = np.stack(inputs_nhwc, axis=0)
logits = np.stack(logits, axis=0)
labels = np.array(labels, dtype=np.int32)

print("Inputs:", inputs_nhwc.shape, inputs_nhwc.dtype)
print("Logits:", logits.shape, logits.dtype)
print("Labels:", labels.shape, logits.dtype)

Inputs: (200, 28, 28, 1) float32
Logits: (200, 10) float32
Labels: (200,) float32


In [26]:
# -----------------------
# Save files
# -----------------------
# 1) Calibration file (inputs only)
np.savez("mnist_calib_200.npz", input=inputs_nhwc)

# 2) Validation file for stedgeai: input + ONE output array
# Use name "logits" OR "output". If one fails, try the other.
np.savez("mnist_val_200_io.npz", input=inputs_nhwc, logits=logits)

# 3) Labels saved separately (for accuracy in Python)
np.savez("mnist_labels_200.npz", label=labels)

In [27]:
# # load logits (reference outputs) and labels (ground truth)
# val = np.load("mnist_val_200_io.npz")
# lab = np.load("mnist_labels_200.npz")

# logits = val["logits"]          # shape (200, 10)
# labels = lab["label"]           # shape (200,)

# # predicted class = argmax over 10 logits
# pred = np.argmax(logits, axis=1)

# acc = (pred == labels).mean() * 100.0
# print(f"Accuracy on these 200 samples: {acc:.2f}%")

# # optional: confusion matrix
# cm = np.zeros((10, 10), dtype=int)
# for y, p in zip(labels, pred):
#     cm[y, p] += 1
# print("Confusion matrix (rows=true, cols=pred):")
# print(cm)

In [29]:
import numpy as np

val = np.load("mnist_val_200_io.npz")
labels = np.load("mnist_labels_200.npz")["label"]

out = np.load("network_val_io.npz")
logits = out["c_outputs_1"].reshape(len(labels), 10)

acc = (np.argmax(logits, axis=1) == labels).mean()
print("STM32 accuracy:", acc)

STM32 accuracy: 0.99
