<a href="https://colab.research.google.com/github/rrankawat/stm32/blob/main/MNIST_STM32_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Load Trained MNIST Model

In [11]:
# Install the ONNX-related packages into the Colab runtime (-q keeps pip quiet).
!pip -q install onnx onnxruntime onnxscript onnxruntime-tools

In [12]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import onnx
import onnxruntime as ort

from onnxruntime.quantization import quantize_dynamic, quantize_static, QuantType, QuantFormat, CalibrationDataReader

In [13]:
# Mount Google Drive under /content/drive; the trained weights are loaded
# from a path below this mount point further down.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [14]:
# Data Loaders
batch_size = 1  # one image per batch; later cells index [0] / call .item() on each batch

# Convert images to tensors, then normalise with the mean/std given below.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# NOTE: despite its name, `test_transform` holds the MNIST test *dataset*
# (train=False), not a transform.
test_transform = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
# shuffle=False keeps the sample order fixed across runs.
test_loader = DataLoader(test_transform, shuffle=False, batch_size=batch_size)

In [15]:
# Model
class MNISTTinyCNN(nn.Module):
  """Small CNN: two (3x3 conv -> ReLU -> 2x2 max-pool) stages, then three linear layers.

  The first linear layer expects 16 * 7 * 7 features, which is what a
  1x28x28 input yields after the two poolings (28 -> 14 -> 7).
  """

  def __init__(self):
    super().__init__()
    # Feature extractor; padding=1 with a 3x3 kernel keeps the spatial size.
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3, padding=1)
    # Classifier head: 784 -> 196 -> 49 -> 10.
    self.fc1 = nn.Linear(in_features=16 * 7 * 7, out_features=196)
    self.fc2 = nn.Linear(in_features=196, out_features=49)
    self.fc3 = nn.Linear(in_features=49, out_features=10)

  def forward(self, x):
    """Return the raw (un-normalised) logits, one row of 10 values per sample."""
    stage1 = F.max_pool2d(F.relu(self.conv1(x)), 2)       # 28 -> 14
    stage2 = F.max_pool2d(F.relu(self.conv2(stage1)), 2)  # 14 -> 7

    # Flatten everything except the batch axis.
    flat = stage2.view(stage2.size(0), -1)

    hidden = F.relu(self.fc1(flat))    # 784 -> 196
    hidden = F.relu(self.fc2(hidden))  # 196 -> 49
    return self.fc3(hidden)            # 49 -> 10 (logits)

In [16]:
# Load weights
model = MNISTTinyCNN()
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
# map_location keeps every tensor on the CPU regardless of where it was saved.
state_dict = torch.load(
    "/content/drive/My Drive/Colab Notebooks/stm_mnist_model.pth",
    map_location=torch.device('cpu'),
    weights_only=True,
)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [37]:
# Accuracy on the first 200 samples
model.eval()
correct, total = 0, 0

# Gradients are not needed for evaluation.
with torch.no_grad():
    for i, (x, y) in enumerate(test_loader):
        if i >= 200:
            break
        out = model(x)
        # argmax over the class axis gives the predicted digit per sample.
        matches = out.argmax(dim=1) == y
        correct += matches.sum().item()
        total += y.size(0)

print("PyTorch accuracy on first 200:", 100 * correct / total)

PyTorch accuracy on first 200: 99.5


### FP32 Pipeline

In [17]:
# Export To ONNX
def export_onnx(model, onnx_path):
    """Export `model` to an ONNX file at `onnx_path` and run the ONNX checker on it.

    The model is switched to eval mode and exported using a random
    1x1x28x28 example input. The graph input is named "input", the output
    "logits", and axis 0 of both is declared dynamic ("batch_size").
    Prints the destination path once the checker has passed.
    """
    model.eval()
    example_input = torch.randn(1, 1, 28, 28)  # NCHW

    export_options = dict(
        input_names=["input"],
        output_names=["logits"],
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        # Axis 0 of both tensors stays symbolic so the batch size is not fixed to 1.
        dynamic_axes={"input": {0: "batch_size"}, "logits": {0: "batch_size"}},
        dynamo=False,
    )
    torch.onnx.export(model, example_input, onnx_path, **export_options)

    # Validate the file that was just written.
    onnx.checker.check_model(onnx_path, full_check=False)
    print(f"ONNX model saved to: {onnx_path}")

# Write the FP32 ONNX file; the relative path resolves against the current working directory.
export_onnx(model, "mnist_lenet_fp32.onnx")

ONNX model saved to: mnist_lenet_fp32.onnx


  torch.onnx.export(


In [18]:
# Collect N samples
model.eval()

N = 200
inputs, logits, labels = [], [], []

with torch.no_grad():
    for i, (x, y) in enumerate(test_loader):
        if i >= N:
            break

        out = model(x)

        # Go through .cpu() before .numpy() so this also works for CUDA tensors.
        x_np = x.detach().cpu().numpy()
        x_nhwc = x_np.transpose(0, 2, 3, 1).astype(np.float32)  # NCHW -> NHWC, (1,28,28,1)
        out_np = out.detach().cpu().numpy()[0].astype(np.float32)  # (10,)

        # Drop the batch axis (batch size is 1) before storing.
        inputs.append(x_nhwc[0])      # (28,28,1)
        logits.append(out_np)         # (10,)
        labels.append(int(y.item()))

# Turn the per-sample lists into arrays with a leading sample axis.
inputs = np.stack(inputs)                    # (N,28,28,1)
logits = np.stack(logits)                    # (N,10)
labels = np.asarray(labels, dtype=np.int32)  # (N,)

# Inputs and reference logits go in one archive, labels in another.
np.savez("mnist_val_200_io.npz", input=inputs, logits=logits)
np.savez("mnist_labels_200.npz", label=labels)

In [19]:

# Compute Accuracy
def compute_accuracy(
    labels_npz_path,
    outputs_npz_path,
    output_key="c_outputs_1",
    num_classes=10,
    as_percentage=False
):
    """Compute top-1 accuracy of stored model outputs against stored labels.

    Args:
        labels_npz_path: Path to an .npz archive with an integer array under "label".
        outputs_npz_path: Path to an .npz archive holding the model outputs.
        output_key: Name of the array inside `outputs_npz_path` to score.
        num_classes: Number of scores per sample; the outputs are reshaped to
            (number of labels, num_classes) before taking the argmax.
        as_percentage: If True return a value in [0, 100], otherwise in [0, 1].

    Returns:
        The fraction (or percentage) of samples whose argmax equals the label.

    Raises:
        KeyError: If "label" or `output_key` is missing from its archive.
        ValueError: If the outputs cannot be reshaped to (len(labels), num_classes).
    """
    # np.load on an .npz returns a lazy archive object that keeps the file
    # open; the context managers make sure both archives are closed again.
    with np.load(labels_npz_path) as labels_npz:
        labels = labels_npz["label"].astype(np.int64)

    with np.load(outputs_npz_path) as outputs_npz:
        if output_key not in outputs_npz.files:
            # Same exception type as before, but tell the user what is available.
            raise KeyError(
                f"{output_key!r} not found in {outputs_npz_path}; "
                f"available keys: {list(outputs_npz.files)}"
            )
        raw_outputs = outputs_npz[output_key]

    # Outputs may carry extra singleton axes; flatten to one row per sample.
    logits = raw_outputs.reshape(len(labels), num_classes)
    pred = np.argmax(logits, axis=1)

    acc = (pred == labels).mean()
    return acc * 100 if as_percentage else acc

In [20]:
# Score the outputs stored in network_val_io.npz against the saved labels.
# NOTE(review): network_val_io.npz is not written by this notebook; presumably
# it comes from the STM32 validation run and must be uploaded first - confirm.
acc = compute_accuracy("mnist_labels_200.npz", "network_val_io.npz", as_percentage=True)

print("STM32 accuracy:", acc)

STM32 accuracy: 99.5


### Int8 Pipeline

In [32]:
# Calibration NPZ (inputs only)
def make_calib_npz(test_dataset, N=200, out_path="mnist_calib_200.npz"):
    """Save the first `N` inputs of `test_dataset` to `out_path` under the key "input".

    Samples are read in dataset order, one per batch, cast to float32 and
    stacked along a new leading axis; labels are ignored. If the dataset has
    fewer than `N` items, all of them are saved. Prints the path and the
    stacked shape, then returns `out_path`.
    """
    loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    samples = []
    with torch.no_grad():
        for index, (batch, _label) in enumerate(loader):
            if index >= N:
                break
            # [0] drops the batch axis added by the loader.
            sample = batch.detach().cpu().numpy()[0]
            samples.append(sample.astype(np.float32))

    xs = np.stack(samples, axis=0)  # one row per sample
    np.savez(out_path, input=xs)
    print("Saved calib:", out_path, xs.shape)
    return out_path

In [30]:
# Quantize FP32 ONNX → INT8 ONNX (QDQ)
class CalibReader(CalibrationDataReader):
    """Hands calibration samples from an .npz archive to the quantizer one at a time.

    The archive's "input" array is loaded as float32 and served in order as
    single-sample slices keyed by `input_name`.
    """

    def __init__(self, npz_path, input_name="input"):
        self.i = 0  # index of the next sample to hand out
        self.input_name = input_name
        # Expected to be (N,1,28,28) when the file comes from make_calib_npz.
        self.x = np.load(npz_path)["input"].astype(np.float32)

    def get_next(self):
        """Return {input_name: sample of batch size 1}, or None once all samples are used."""
        idx = self.i
        if idx < len(self.x):
            self.i = idx + 1
            # Slice (not index) so the leading batch axis of size 1 is kept.
            return {self.input_name: self.x[idx:idx + 1]}
        return None

def quantize_int8_qdq(fp32_onnx="mnist_lenet_fp32.onnx",
                      calib_npz="mnist_calib_200.npz",
                      int8_onnx="mnist_lenet_int8_static_qdq.onnx"):
    """Statically quantize `fp32_onnx` and write the result to `int8_onnx`.

    Calibration data is read from `calib_npz` through CalibReader and fed to
    the graph input named "input". Weights and activations both use signed
    8-bit types, the QDQ format is requested, and per-channel quantization
    is enabled. Prints and returns `int8_onnx`.
    """
    quant_options = {
        "quant_format": QuantFormat.QDQ,
        "activation_type": QuantType.QInt8,
        "weight_type": QuantType.QInt8,
        "per_channel": True,
    }

    quantize_static(
        model_input=fp32_onnx,
        model_output=int8_onnx,
        calibration_data_reader=CalibReader(calib_npz, input_name="input"),
        **quant_options,
    )
    print("Saved INT8:", int8_onnx)
    return int8_onnx

In [33]:
# Build the calibration archive from the first 200 samples of the test
# dataset, then quantize the FP32 model with it.
calib_npz = make_calib_npz(test_transform, N=200, out_path="mnist_calib_200.npz")
quantize_int8_qdq(
    fp32_onnx="mnist_lenet_fp32.onnx",
    calib_npz=calib_npz,
    int8_onnx="mnist_lenet_int8_static_qdq.onnx",
)



Saved calib: mnist_calib_200.npz (200, 1, 28, 28)
Saved INT8: mnist_lenet_int8_static_qdq.onnx


'mnist_lenet_int8_static_qdq.onnx'

In [35]:
# Score the stored outputs against the saved labels.
# NOTE(review): this reads the same network_val_io.npz path as the FP32
# section, so the number only reflects the INT8 model if that file was
# regenerated from the quantized network beforehand - confirm.
acc = compute_accuracy("mnist_labels_200.npz", "network_val_io.npz", as_percentage=True)

print("STM32 accuracy:", acc)

STM32 accuracy: 99.5
