<a href="https://colab.research.google.com/github/rrankawat/stm32/blob/main/Mnist_Pytorch_To_Int8_Onnx.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Model Training

In [33]:
!pip -q install onnx onnxruntime onnxscript onnxruntime-tools

In [34]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [35]:
# ONNX + ORT quantization
import onnx
import onnxruntime as ort

from onnxruntime.quantization import quantize_dynamic, quantize_static, QuantType, QuantFormat, CalibrationDataReader

In [36]:
# -----------------------
# Data Loaders
# -----------------------
batch_size = 64

# Preprocessing applied to every image: tensor conversion, then
# normalization with mean 0.1307 and std 0.3081.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# NOTE: despite the "_transform" suffix, these two names hold the MNIST
# dataset objects (train split first, then test split), not transforms.
train_transform, test_transform = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training batches are shuffled; evaluation keeps dataset order.
train_loader = DataLoader(train_transform, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_transform, shuffle=False, batch_size=batch_size)

In [37]:
# -----------------------
# Model
# -----------------------
class MNISTTinyCNN(nn.Module):
  """Small CNN: two conv + 2x2 max-pool stages followed by three linear layers.

  The first linear layer expects 16 * 7 * 7 features, which is what a
  1x28x28 input yields after the two 2x2 poolings (28 -> 14 -> 7).
  """

  def __init__(self):
    super().__init__()
    # Convolutional feature extractor (3x3 kernels, padding keeps H and W).
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3, padding=1)
    # Classifier head: 784 -> 196 -> 49 -> 10.
    self.fc1 = nn.Linear(in_features=16 * 7 * 7, out_features=196)
    self.fc2 = nn.Linear(in_features=196, out_features=49)
    self.fc3 = nn.Linear(in_features=49, out_features=10)

  def forward(self, x):
    """Return the 10 raw class scores (logits) for each sample in the batch."""
    # Stage 1: conv -> ReLU -> pool (28 -> 14)
    features = F.max_pool2d(F.relu(self.conv1(x)), 2)
    # Stage 2: conv -> ReLU -> pool (14 -> 7)
    features = F.max_pool2d(F.relu(self.conv2(features)), 2)

    # Flatten everything except the batch dimension.
    flat = features.view(features.size(0), -1)

    hidden = F.relu(self.fc1(flat))
    hidden = F.relu(self.fc2(hidden))
    # No activation on the last layer: the outputs are logits.
    return self.fc3(hidden)

In [38]:
# -----------------------
# Random Seeds and Model Instance
# -----------------------
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed first so the weights drawn while constructing the model are repeatable.
torch.manual_seed(41)
model = MNISTTinyCNN()
model = model.to(device)

In [39]:
# -----------------------
# Loss & Optimizer
# -----------------------
# Cross-entropy on the raw logits produced by the model.
criterion = nn.CrossEntropyLoss()
# Adam over every model parameter with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [40]:
# -----------------------
# Training
# -----------------------
def train(epoch, model, loader, criterion, optimizer):
  """Run one optimization pass over `loader` and report loss and accuracy.

  Args:
    epoch: Current epoch number. Not used inside the function; kept so
      existing callers keep working.
    model: Module to train. It is switched to training mode.
    loader: Iterable yielding `(inputs, targets)` batches.
    criterion: Callable `criterion(outputs, targets)` returning a scalar loss.
    optimizer: Optimizer that updates the model's parameters.

  Returns:
    Tuple `(avg_loss, accuracy)`: the loss averaged over all samples and the
    fraction of samples whose argmax prediction equals the target.

  Raises:
    ValueError: If `loader` yields no samples.
  """
  model.train() # training mode

  # Follow the device the model lives on instead of a notebook-level global,
  # so the function also works when it is called on its own.
  first_param = next(model.parameters(), None)
  run_device = first_param.device if first_param is not None else torch.device("cpu")

  correct = 0
  total = 0
  running_loss = 0.0

  for X, y in loader:
    X = X.to(run_device)
    y = y.to(run_device)

    # Forward + loss
    outputs = model(X)
    loss = criterion(outputs, y)

    # Backprop
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Metrics. The batch loss is weighted by the batch size so a smaller last
    # batch does not skew the average (assumes the criterion returns a
    # per-batch mean, as nn.CrossEntropyLoss does by default).
    n_samples = y.size(0)
    running_loss += loss.item() * n_samples
    correct += (outputs.argmax(dim=1) == y).sum().item()
    total += n_samples

  if total == 0:
    raise ValueError("train() received a loader that yielded no samples")

  return running_loss / total, correct / total

In [41]:
# -----------------------
# Testing
# -----------------------
def test(model, loader, criterion):
  model.eval() # testing mode

  correct = 0
  total = 0
  running_loss = 0

  with torch.no_grad():
    for X, y in loader:
      X = X.to(device)
      y = y.to(device)

      # Forward
      outputs = model(X)

      # Loss
      loss = criterion(outputs, y)

      # Metrics
      batch_size = y.size(0)
      running_loss += loss.item() * batch_size
      preds = outputs.argmax(dim=1)
      correct += (preds == y).sum().item()
      total += batch_size

  avg_loss = running_loss / total
  accuracy = correct / total

  return avg_loss, accuracy

In [42]:
epochs = 5
start_time = time.time()

# One training pass followed by one evaluation pass per epoch.
for epoch in range(1, epochs + 1):
  train_loss, train_acc = train(epoch, model, train_loader, criterion, optimizer)
  test_loss, test_acc = test(model, test_loader, criterion)

  # Accuracies are fractions in [0, 1]; they are shown as percentages.
  train_part = f"Train loss: {train_loss:.4f}, Train acc: {train_acc*100:.2f}%"
  test_part = f"Test loss: {test_loss:.4f}, Test acc: {test_acc*100:.2f}%"
  print(f"Epoch {epoch}/{epochs} | {train_part} | {test_part}")

elapsed_minutes = (time.time() - start_time) / 60
print(f"Time taken: {elapsed_minutes} minutes!")

Epoch 1/5 | Train loss: 0.2322, Train acc: 93.04% | Test loss: 0.0706, Test acc: 97.73%
Epoch 2/5 | Train loss: 0.0648, Train acc: 97.98% | Test loss: 0.0455, Test acc: 98.55%
Epoch 3/5 | Train loss: 0.0451, Train acc: 98.59% | Test loss: 0.0325, Test acc: 98.84%
Epoch 4/5 | Train loss: 0.0323, Train acc: 99.00% | Test loss: 0.0368, Test acc: 98.75%
Epoch 5/5 | Train loss: 0.0264, Train acc: 99.15% | Test loss: 0.0389, Test acc: 98.88%
Time taken: 2.8720115661621093 minutes!


In [43]:
# -----------------------
# Compute Accuracy
# -----------------------

def compute_int8_accuracy(
    labels_npz_path,
    outputs_npz_path,
    output_key="c_outputs_1",
    num_classes=10,
    as_percentage=False
):
    """Compute top-1 accuracy from saved labels and saved model outputs.

    Args:
        labels_npz_path: Path to an .npz archive holding the ground-truth
            class ids under the key "label".
        outputs_npz_path: Path to an .npz archive holding the model outputs.
        output_key: Name of the array inside `outputs_npz_path` to score.
        num_classes: Number of scores per sample; the output array is
            reshaped to `(number_of_labels, num_classes)`.
        as_percentage: When true, return the accuracy scaled to 0-100
            instead of 0-1.

    Returns:
        The fraction (or percentage) of samples whose argmax equals the label.

    Raises:
        KeyError: If `output_key` is not present in the outputs archive.
        ValueError: If the output array cannot be reshaped to
            `(number_of_labels, num_classes)`.
    """
    # np.load on an .npz returns an archive object that keeps the file open;
    # the with-blocks close it as soon as the arrays have been read.
    with np.load(labels_npz_path) as labels_archive:
        labels = labels_archive["label"].astype(np.int64)

    with np.load(outputs_npz_path) as outputs_archive:
        if output_key not in outputs_archive.files:
            raise KeyError(
                f"{output_key!r} not found in {outputs_npz_path}; "
                f"available arrays: {outputs_archive.files}"
            )
        # Collapse any extra singleton dimensions into one row per sample.
        logits = outputs_archive[output_key].reshape(len(labels), num_classes)

    pred = np.argmax(logits, axis=1)

    acc = (pred == labels).mean()
    return acc * 100 if as_percentage else acc

### STM32 Pipeline

In [44]:
# -----------------------
# Export To ONNX
# -----------------------
def export_onnx(model, onnx_path, input_shape=(1, 1, 28, 28)):
  """Export `model` to an ONNX file and run the ONNX checker on the result.

  Args:
    model: Module to export. It is switched to eval mode before tracing.
    onnx_path: Destination path of the .onnx file.
    input_shape: Shape of the random dummy input used for tracing. The
      default matches the previously hard-coded (1, 1, 28, 28).
  """
  model.eval()

  # Create the dummy input on the model's device; a CPU tensor would fail
  # when the model was moved to CUDA.
  first_param = next(model.parameters(), None)
  dummy_device = first_param.device if first_param is not None else torch.device("cpu")
  dummy = torch.randn(*input_shape, device=dummy_device)

  torch.onnx.export(
        model,
        dummy,
        onnx_path,
        input_names=["input"],
        output_names=["logits"],
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        # Keys must match input_names/output_names. The output is called
        # "logits", so the previous 'output' key did not refer to any tensor.
        dynamic_axes={'input': {0: 'batch_size'}, 'logits': {0: 'batch_size'}},
        dynamo=False
    )
  onnx.checker.check_model(onnx_path, full_check=False)
  print(f"ONNX model saved to: {onnx_path}")

export_onnx(model, "mnist_lenet_fp32.onnx")

ONNX model saved to: mnist_lenet_fp32.onnx


  torch.onnx.export(


In [45]:
# -----------------------
# Single-sample loader for calibration / validation export
# -----------------------
# Samples come from `test_transform`, the MNIST test dataset built with the
# ToTensor + Normalize pipeline, so the inputs collected from it ARE normalized.
# Batch size 1 and no shuffling: samples arrive one at a time, in dataset order.
test_loader_quant = DataLoader(test_transform, shuffle=False, batch_size=1)

In [46]:
# -----------------------
# Collect N samples
# -----------------------
N = 200
inputs_nhwc = []
logits = []
labels = []

# Switching to eval mode once is enough; it does not need to happen per sample.
model.eval()

with torch.no_grad():
  # zip() against range(N) stops after N batches without fetching an extra one.
  for i, (x, y) in zip(range(N), test_loader_quant):
    # Run the model on its own device, then bring the result back to the CPU
    # so it can be converted to NumPy even when the model lives on CUDA.
    out = model(x.to(device)).cpu()

    # convert input to NHWC (1x28x28x1)
    x_nhwc = x.numpy().transpose(0, 2, 3, 1).astype(np.float32)
    inputs_nhwc.append(x_nhwc[0])
    logits.append(out.numpy()[0].astype(np.float32))
    labels.append(int(y.item()))

# Turn the per-sample lists into arrays with the sample index as first axis.
inputs_nhwc = np.stack(inputs_nhwc, axis=0)
logits = np.stack(logits, axis=0)
labels = np.array(labels, dtype=np.int32)

In [47]:
# -----------------------
# Save files
# -----------------------
# Each entry maps an output file to the named arrays stored in it:
#   1) calibration inputs (for quantization later)
#   2) validation file for stedgeai: input + ONE output array
#   3) labels, kept separately (for accuracy in Python)
npz_payloads = {
    "mnist_calib_200.npz": {"input": inputs_nhwc},
    "mnist_val_200_io.npz": {"input": inputs_nhwc, "logits": logits},
    "mnist_labels_200.npz": {"label": labels},
}

for npz_path, arrays in npz_payloads.items():
    np.savez(npz_path, **arrays)

#### Run classifier for accuracy (Optional):

```
& "C:\Users\rakes\STM32Cube\Repository\Packs\STMicroelectronics\X-CUBE-AI\10.2.0\Utilities\windows\stedgeai.exe" validate `
  --target stm32n6 `
  --name network `
  -m C:\Users\rakes\Downloads\mnist_lenet_fp32.onnx `
  --st-neural-art n6-allmems-O3@C:\Users\rakes\STM32Cube\Repository\Packs\STMicroelectronics\X-CUBE-AI\10.2.0\scripts\N6_scripts\user_neuralart.json `
  --workspace C:\Users\rakes\AppData\Local\Temp\mxAI_ws_val `
  --output C:\Users\rakes\.stm32cubemx\network_output `
  --mode target `
  --valinput C:\Users\rakes\Downloads\mnist_val_200_io.npz `
  --classifier `
  --desc serial:COM3:921600
```



In [48]:
# The device outputs are read from "network_val_io.npz", which this notebook
# never writes itself (presumably it comes from the optional
# `stedgeai validate` run described above - TODO confirm). Skip with a message
# instead of failing when the file is absent, so Run All still completes.
stm32_outputs_npz = "network_val_io.npz"

if os.path.exists(stm32_outputs_npz):
    acc = compute_int8_accuracy(
        "mnist_labels_200.npz",
        stm32_outputs_npz,
        as_percentage=True
    )
    print("STM32 INT8 accuracy:", acc)
else:
    print(f"Skipping accuracy check: {stm32_outputs_npz} not found.")

STM32 INT8 accuracy: 99.5


### Int8 Pipeline

INT8 quantization matters because embedded hardware has brutal constraints:

* Limited RAM
* Limited flash
* Power budget
* Real-time deadlines

INT8 is how we make deep learning (DL) models fit and run fast on microcontrollers (MCUs).

In [49]:
class MNISTCalibReader(CalibrationDataReader):
    """Serve calibration samples from an .npz archive, one sample per call.

    Args:
        npz_path: Path to the .npz archive with the calibration inputs.
        input_name: Name of the ONNX model input; used as the key of every
            feed dict returned by `get_next`.
        npz_key: Name of the array inside the archive. Defaults to
            `input_name`, which is the previous behavior.
    """

    def __init__(self, npz_path, input_name="input", npz_key=None):
        array_key = input_name if npz_key is None else npz_key

        # Close the archive as soon as the array has been read.
        with np.load(npz_path) as archive:
            x = archive[array_key].astype(np.float32)      # likely (N,28,28,1)

        # A 4-D array whose last axis has size 1 is treated as NHWC and
        # converted to NCHW: (N,28,28,1) -> (N,1,28,28).
        if x.ndim == 4 and x.shape[-1] == 1:
            x = np.transpose(x, (0, 3, 1, 2))

        self.x = x
        self.input_name = input_name
        self.i = 0  # index of the next sample to hand out

    def get_next(self):
        """Return `{input_name: sample}` with a batch of one, or None when exhausted."""
        if self.i >= len(self.x):
            return None
        batch = self.x[self.i:self.i+1]              # slice keeps the batch axis
        self.i += 1
        return {self.input_name: batch}

# Source FP32 model, calibration samples, and destination INT8 model.
inp, calib, out = (
    "mnist_lenet_fp32.onnx",
    "mnist_calib_200.npz",
    "mnist_lenet_int8_static_qdq.onnx",
)

reader = MNISTCalibReader(calib, input_name="input")

# Signed 8-bit activations and weights, per-channel, in QDQ format (the
# original author notes QDQ is important for compiler friendliness).
quant_options = dict(
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QInt8,
    weight_type=QuantType.QInt8,
    per_channel=True,
)

quantize_static(
    model_input=inp,
    model_output=out,
    calibration_data_reader=reader,
    **quant_options,
)

print("Saved:", out)



Saved: mnist_lenet_int8_static_qdq.onnx


In [50]:
# NOTE(review): this reads the same "network_val_io.npz" as the earlier
# accuracy cell. The notebook never writes that file itself, so the number
# only reflects the quantized model if the external validation was re-run
# with it - TODO confirm. Skip with a message instead of failing when the
# file is absent, so Run All still completes.
stm32_outputs_npz = "network_val_io.npz"

if os.path.exists(stm32_outputs_npz):
    acc = compute_int8_accuracy(
        "mnist_labels_200.npz",
        stm32_outputs_npz,
        as_percentage=True
    )
    print("STM32 INT8 accuracy:", acc)
else:
    print(f"Skipping accuracy check: {stm32_outputs_npz} not found.")

STM32 INT8 accuracy: 99.5
