<a href="https://colab.research.google.com/github/rrankawat/stm32/blob/main/Mnist_Pytorch_To_Int8_Onnx.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [49]:
# !pip -q install onnx onnxruntime onnxscript onnxruntime-tools

In [50]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [51]:
# ONNX + ORT quantization
# import onnx
# import onnxruntime as ort

# from onnxruntime.quantization import quantize_dynamic, QuantType, QuantFormat, CalibrationDataReader

In [52]:
# -----------------------
# Data Loaders
# -----------------------
# Mini-batch size shared by the training and evaluation loaders.
batch_size = 64

# Only a tensor conversion is applied; no further normalization step is composed in.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# NOTE: despite their names, these two objects are the MNIST *datasets*
# (train split and test split); the names are kept because later cells use them.
train_transform = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_transform = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Training batches are reshuffled on every pass; evaluation order stays fixed.
train_loader = DataLoader(
    train_transform,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = DataLoader(
    test_transform,
    batch_size=batch_size,
    shuffle=False,
)

In [53]:
# -----------------------
# Model
# -----------------------
class MNISTTinyCNN(nn.Module):
  """Small CNN (two 3x3 convolutions + three linear layers) emitting 10 logits.

  Spatial sizes for a 1x28x28 input, which is what the `18 * 5 * 5` width of
  `fc1` requires:
    conv1: 28 -> 26, max-pool(2): 26 -> 13
    conv2: 13 -> 11, max-pool(2): 11 -> 5 (floor division)
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(1, 6, 3, 1) # 28 -> 26
    self.conv2 = nn.Conv2d(6, 18, 3, 1) # 13 -> 11
    self.fc1 = nn.Linear(18 * 5 * 5, 150)
    self.fc2 = nn.Linear(150, 50)
    self.fc3 = nn.Linear(50, 10)

  def forward(self, x):
    """Return raw class logits of shape (batch, 10); no softmax is applied.

    Args:
      x: Image tensor shaped (batch, 1, 28, 28). A single unbatched image
        shaped (1, 28, 28) is also accepted and treated as a batch of one.

    Raises:
      RuntimeError: if the spatial size does not reduce to 18 * 5 * 5
        features per sample (raised by `fc1`).
    """
    # Promote an unbatched image to a batch of one so that the per-sample
    # flatten below keeps returning a (1, 10) result for that input.
    if x.dim() == 3:
      x = x.unsqueeze(0)

    x = F.relu(self.conv1(x))
    x = F.max_pool2d(x, 2) # 26 / 2 -> 13

    x = F.relu(self.conv2(x))
    x = F.max_pool2d(x, 2) # 11 // 2 -> 5

    # Flatten everything except the batch dimension. Unlike view(-1, 450),
    # this can never fold a wrongly sized feature map into extra/fewer rows;
    # a size mismatch surfaces as an explicit error in fc1 instead.
    x = torch.flatten(x, 1)

    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x) # logits

    return x

In [54]:
# -----------------------
# Random Seeds and Model Instance
# -----------------------
# Seed PyTorch's global RNG *before* building the model so that the randomly
# initialised weights are the same on every fresh run of the notebook.
torch.manual_seed(41)
# Module-level model instance shared by the training and evaluation cells.
model = MNISTTinyCNN()

In [55]:
# -----------------------
# Loss & Optimizer
# -----------------------
# Cross-entropy on the raw logits returned by the model (forward() applies no softmax).
criterion = nn.CrossEntropyLoss()
# Adam over all model parameters with a fixed learning rate of 0.001.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [56]:
# -----------------------
# Training
# -----------------------
def train(epoch, model, loader, criterion, optimizer, log_every=600):
  """Train `model` for one pass over `loader` and return the epoch metrics.

  Args:
    epoch: Current epoch number. Accepted so existing call sites keep
      working; it is not used inside the function.
    model: Module to optimise; switched to training mode before the loop.
    loader: Iterable of `(X, y)` batches. It does not need to support `len()`.
    criterion: Callable returning a scalar loss tensor for `(outputs, y)`.
    optimizer: Optimizer whose `zero_grad()` / `step()` are called per batch.
    log_every: Kept for backward compatibility; no logging is performed.

  Returns:
    Tuple `(avg_loss, accuracy)`: the mean of the per-batch loss values and
    the fraction of samples whose arg-max prediction equals its label.
    `(0.0, 0.0)` is returned when `loader` yields no batches, instead of
    failing with ZeroDivisionError.
  """
  model.train() # training mode

  correct = 0
  total = 0
  running_loss = 0.0
  num_batches = 0

  for X, y in loader:
    # Forward
    outputs = model(X)

    # Loss
    loss = criterion(outputs, y)

    # Backprop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Metrics (taken from the outputs computed before this step's update)
    running_loss += loss.item()
    preds = outputs.argmax(dim=1)
    correct += (preds == y).sum().item()
    total += y.size(0)
    num_batches += 1

  # Guard both divisions so an empty loader yields zeros rather than crashing.
  avg_loss = running_loss / num_batches if num_batches else 0.0
  accuracy = correct / total if total else 0.0

  return avg_loss, accuracy

In [74]:
# -----------------------
# Testing
# -----------------------
def test(model, loader, criterion):
  model.eval() # testing mode

  correct = 0
  total = 0
  running_loss = 0

  with torch.no_grad():
    for X, y in loader:
      # Forward
      outputs = model(X)

      # Loss
      loss = criterion(outputs, y)

      # Metrics
      running_loss += loss.item()
      preds = outputs.argmax(dim=1)
      correct += (preds == y).sum().item()
      total += y.size(0)

  avg_loss = running_loss / len(loader)
  accuracy = correct / total

  return avg_loss, accuracy

In [73]:
# Number of full passes over the training set.
epochs = 5
start_time = time.time()

# NOTE: this loop optimises the module-level `model` in place, so re-running
# the cell without re-creating the model continues from the current weights.
for epoch in range(1, epochs + 1):
  # One optimisation pass, then an evaluation on the held-out split.
  train_loss, train_acc = train(epoch, model, train_loader, criterion, optimizer)
  test_loss, test_acc = test(model, test_loader, criterion)

  # The three fragments are emitted back to back (sep="") as a single line.
  print(
      f"Epoch {epoch}/{epochs} | ",
      f"Train loss: {train_loss:.4f}, Train acc: {train_acc*100:.2f}% | ",
      f"Test loss: {test_loss:.4f}, Test acc: {test_acc*100:.2f}%",
      sep="",
  )

# Wall-clock duration of the whole loop, reported in minutes.
print(f"Time taken: {(time.time() - start_time) / 60} minutes!")

Epoch 1/5 | Train loss: 0.0379, Train acc: 98.75% | Test loss: 0.0470, Test acc: 98.52%
Epoch 2/5 | Train loss: 0.0322, Train acc: 98.99% | Test loss: 0.0337, Test acc: 98.94%
Epoch 3/5 | Train loss: 0.0270, Train acc: 99.11% | Test loss: 0.0355, Test acc: 98.79%
Epoch 4/5 | Train loss: 0.0238, Train acc: 99.22% | Test loss: 0.0367, Test acc: 98.83%
Epoch 5/5 | Train loss: 0.0197, Train acc: 99.33% | Test loss: 0.0399, Test acc: 98.87%
Time taken: 1.7040016889572143 minutes!


In [59]:
# # -----------------------
# # STM32 Pipeline
# # -----------------------
# def export_onnx(model, onnx_path):
#   model.eval()

#   dummy = torch.randn(1, 1, 28, 28)

#   torch.onnx.export(
#         model,
#         dummy,
#         onnx_path,
#         input_names=["input"],
#         output_names=["logits"],
#         export_params=True,
#         opset_version=18,
#         do_constant_folding=True,
#         dynamic_axes={'input': {0: 'batch_size'}, 'logits': {0: 'batch_size'}},
#         dynamo=False
#     )
#   onnx.checker.check_model(onnx_path, full_check=False)
#   print(f"ONNX model saved to: {onnx_path}")

# export_onnx(model, "mnist_lenet_fp32.onnx")

In [60]:
# # -----------------------
# # 20 MNIST test images
# # -----------------------
# test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transforms.ToTensor())

# N = 20

# # Inputs in NHWC float32 (matches ST log format)
# x = np.zeros((N, 28, 28, 1), dtype=np.float32)
# for i in range(N):
#     img, _ = test_dataset[i]
#     x[i, :, :, 0] = img.squeeze(0).numpy()

# # Save INPUTS ONLY
# np.savez("mnist_20.npz", input=x)

# print("Saved mnist_20.npz with keys:", np.load("mnist_20.npz").files)

### Quantization

In [61]:
# # -----------------------
# # Load MNIST (NO normalization for now; matches what you already ran)
# # -----------------------
# test_loader_quant = DataLoader(test_transform, batch_size=1, shuffle=False)

In [62]:
# # -----------------------
# # Load trained weights
# # -----------------------
# model.eval()

In [63]:
# # -----------------------
# # Collect N samples
# # -----------------------
# N = 200
# inputs_nhwc = []
# logits = []
# labels = []

# with torch.no_grad():
#   for i, (x, y) in enumerate(test_loader_quant):
#     if i >= N:
#       break

#     out = model(x)
#     # convert input to NHWC (1x28x28x1)
#     x_nhwc = x.numpy().transpose(0, 2, 3, 1).astype(np.float32)
#     inputs_nhwc.append(x_nhwc[0])
#     logits.append(out.numpy()[0].astype(np.float32))
#     labels.append(int(y.item()))

# inputs_nhwc = np.stack(inputs_nhwc, axis=0)
# logits = np.stack(logits, axis=0)
# labels = np.array(labels, dtype=np.int32)

# print("Inputs:", inputs_nhwc.shape, inputs_nhwc.dtype)
# print("Logits:", logits.shape, logits.dtype)
# print("Labels:", labels.shape, labels.dtype)

In [64]:
# # -----------------------
# # Save files
# # -----------------------
# # 1) Calibration file (inputs only)
# np.savez("mnist_calib_200.npz", input=inputs_nhwc)

# # 2) Validation file for stedgeai: input + ONE output array
# # Use name "logits" OR "output". If one fails, try the other.
# np.savez("mnist_val_200_io.npz", input=inputs_nhwc, logits=logits)

# # 3) Labels saved separately (for accuracy in Python)
# np.savez("mnist_labels_200.npz", label=labels)

In [65]:
# val = np.load("mnist_val_200_io.npz")
# labels = np.load("mnist_labels_200.npz")["label"]

# out = np.load("network_val_io.npz")
# logits = out["c_outputs_1"].reshape(len(labels), 10)

# acc = (np.argmax(logits, axis=1) == labels).mean()
# print("STM32 accuracy:", acc)

In [66]:
# d = np.load("mnist_calib_200.npz")
# print(d.files)
# x = d["input"]
# print(x.shape, x.dtype, x.min(), x.max())

In [67]:
# inp = r"mnist_lenet_fp32.onnx"
# out = r"mnist_lenet_int8_dynamic.onnx"

# quantize_dynamic(
#     model_input=inp,
#     model_output=out,
#     weight_type=QuantType.QInt8
# )

# print("Saved:", out)


In [68]:
# import numpy as np
# from onnxruntime.quantization import (
#     quantize_static, CalibrationDataReader,
#     QuantFormat, QuantType
# )

# class MNISTCalibReader(CalibrationDataReader):
#     def __init__(self, npz_path, input_name="input"):
#         d = np.load(npz_path)
#         x = d[input_name].astype(np.float32)          # likely (N,28,28,1)

#         # If NHWC, convert to NCHW expected by your ONNX model
#         if x.ndim == 4 and x.shape[-1] == 1:         # (N,28,28,1)
#             x = np.transpose(x, (0, 3, 1, 2))         # -> (N,1,28,28)

#         self.x = x
#         self.input_name = input_name
#         self.i = 0

#     def get_next(self):
#         if self.i >= len(self.x):
#             return None
#         batch = self.x[self.i:self.i+1]              # (1,1,28,28)
#         self.i += 1
#         return {self.input_name: batch}

# inp   = r"mnist_lenet_fp32.onnx"
# calib = r"mnist_calib_200.npz"
# out   = r"mnist_lenet_int8_static_qdq.onnx"

# reader = MNISTCalibReader(calib, input_name="input")

# quantize_static(
#     model_input=inp,
#     model_output=out,
#     calibration_data_reader=reader,
#     quant_format=QuantFormat.QDQ,     # important for compiler friendliness
#     activation_type=QuantType.QInt8,
#     weight_type=QuantType.QInt8,
#     per_channel=True,
# )

# print("Saved:", out)

### Re-quantize

In [69]:
# # Step 1) Create a FLOAT calibration NPZ in NCHW

# import numpy as np

# d = np.load(r"mnist_calib_200.npz")
# x = d["input"].astype(np.float32)  # currently NHWC (N,28,28,1)

# x = np.transpose(x, (0,3,1,2))     # -> NCHW (N,1,28,28)

# np.savez(r"mnist_calib_200_nchw_fp32.npz", input=x)
# print(x.shape, x.dtype, x.min(), x.max())

In [70]:
# # Step 2) Quantize again with QDQ, but don't let ORT change IO types

# import numpy as np
# from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantFormat, QuantType

# class CalibReader(CalibrationDataReader):
#     def __init__(self, npz_path, input_name="input"):
#         self.x = np.load(npz_path)[input_name].astype(np.float32)  # (N,1,28,28)
#         self.input_name = input_name
#         self.i = 0

#     def get_next(self):
#         if self.i >= len(self.x):
#             return None
#         batch = self.x[self.i:self.i+1]  # (1,1,28,28)
#         self.i += 1
#         return {self.input_name: batch}

# inp   = r"mnist_lenet_fp32.onnx"
# calib = r"mnist_calib_200_nchw_fp32.npz"
# out   = r"mnist_lenet_int8_static_qdq_floatio.onnx"

# reader = CalibReader(calib, "input")

# quantize_static(
#     model_input=inp,
#     model_output=out,
#     calibration_data_reader=reader,
#     quant_format=QuantFormat.QDQ,
#     activation_type=QuantType.QInt8,
#     weight_type=QuantType.QInt8,
#     per_channel=True,
# )

# print("Saved:", out)

In [71]:
# # Step 3) Verify the new ONNX input type is float

# import onnx
# m = onnx.load(r"mnist_lenet_int8_static_qdq_floatio.onnx")
# t = m.graph.input[0].type.tensor_type
# print("input elem type:", t.elem_type)  # should correspond to float

In [72]:
# import numpy as np

# # labels from your dataset creation step
# labels = np.load("mnist_labels_200.npz")["label"]  # shape (200,)

# # outputs produced by stedgeai validate
# out = np.load(r"network_val_io.npz")

# # int8 outputs from the board, shape (200,1,1,10)
# c = out["c_outputs_1"].astype(np.float32).reshape(len(labels), 10)

# # dequantize using values printed in your log
# scale = 0.165920198
# zp = 18
# logits = (c - zp) * scale

# pred = np.argmax(logits, axis=1)
# acc = (pred == labels).mean()
# print("STM32 INT8 accuracy:", acc)