<a href="https://colab.research.google.com/github/rrankawat/stm32/blob/main/Mnist_Pytorch_To_Int8_Onnx.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip -q install onnx onnxruntime onnxscript onnxruntime-tools

In [None]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
# ONNX + ORT quantization
import onnx
import onnxruntime as ort

from onnxruntime.quantization import quantize_dynamic, QuantType, QuantFormat, CalibrationDataReader

In [None]:
# -----------------------
# Model
# -----------------------
class MNISTTinyCNN(nn.Module):
  """Small LeNet-style CNN producing 10 class logits.

  The layer sizes assume a single-channel 28x28 input, i.e. a tensor of shape
  (N, 1, 28, 28): after two unpadded 3x3 convolutions, each followed by a
  2x2 max-pool, the feature map is 16 x 5 x 5, which is what ``fc1`` expects.
  """

  def __init__(self):
    super().__init__()
    # Convolutions use no padding, so each one shrinks the map by 2 pixels.
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3, stride=1)   # 28 -> 26
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3, stride=1)  # 13 -> 11
    # Classifier head on the flattened 16 x 5 x 5 feature map.
    self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
    self.fc2 = nn.Linear(in_features=120, out_features=84)
    self.fc3 = nn.Linear(in_features=84, out_features=10)

  def forward(self, x):
    """Return raw (un-normalised) logits of shape (N, 10)."""
    x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # conv: 28 -> 26, pool: 26 -> 13
    x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # conv: 13 -> 11, pool: 11 -> 5

    # Keep the batch dimension, collapse (16, 5, 5) into 400 features.
    x = x.flatten(start_dim=1)

    for hidden in (self.fc1, self.fc2):
      x = F.relu(hidden(x))

    # No softmax here: the output is logits.
    return self.fc3(x)

In [None]:
# -----------------------
# Utils
# -----------------------
def total_time_minutes(start_time):
  """Return the minutes elapsed since ``start_time``.

  ``start_time`` is a timestamp in seconds as returned by ``time.time()``.
  """
  elapsed_seconds = time.time() - start_time
  return elapsed_seconds / 60

def get_loaders(batch_size=64):
  """Build the MNIST train and test ``DataLoader`` objects.

  The datasets are stored under ``./data`` (and downloaded on request via
  ``download=True``); the only transform applied is ``ToTensor``.

  Args:
    batch_size: Batch size used by both loaders.

  Returns:
    ``(train_loader, test_loader)``; only the training loader shuffles.
  """
  to_tensor = transforms.Compose([transforms.ToTensor()])

  # Training split first, then the test split.
  train_set, test_set = (
      datasets.MNIST(root='./data', train=is_train, download=True, transform=to_tensor)
      for is_train in (True, False)
  )

  return (
      DataLoader(train_set, batch_size=batch_size, shuffle=True),
      DataLoader(test_set, batch_size=batch_size, shuffle=False),
  )

In [None]:
# -----------------------
# Training
# -----------------------
def train(epoch, model, train_loader, criterian, optimizer, log_every=600):
  """Run one training epoch over ``train_loader``.

  Args:
    epoch: Zero-based epoch index (printed as ``epoch + 1``).
    model: Module to optimise; it is switched to training mode.
    train_loader: Iterable of ``(inputs, targets)`` batches that also exposes
      ``.dataset`` and ``len()``.
    criterian: Loss callable invoked as ``criterian(predictions, targets)``.
    optimizer: Optimizer stepped once per batch.
    log_every: A progress line is printed whenever the batch index is a
      multiple of this value.

  Returns:
    ``(last_loss, accuracy)``: the loss of the final batch (``None`` if the
    loader yielded nothing) and the number of correct predictions divided by
    ``len(train_loader.dataset)``. Predictions are those computed before each
    batch's optimizer step.
  """
  model.train()
  correct_so_far = 0
  batch_loss = None

  for step, (inputs, targets) in enumerate(train_loader):
    logits = model(inputs)
    loss = criterian(logits, targets)
    batch_loss = loss.item()

    # Standard update: clear old gradients, backprop, apply the step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    hits = logits.argmax(dim=1) == targets
    correct_so_far += hits.sum().item()

    if step % log_every != 0:
      continue

    # Progress estimate based on the size of the current batch.
    samples_seen = step * len(inputs)
    dataset_size = len(train_loader.dataset)
    percent_done = 100.0 * step / len(train_loader)
    print(f"Epoch {epoch+1} [{samples_seen}/{dataset_size} ({percent_done:.0f}%)]  Loss: {batch_loss:.6f}")

  accuracy = correct_so_far / len(train_loader.dataset)
  return batch_loss, accuracy

In [None]:
# -----------------------
# Testing
# -----------------------
def test(model, test_loader, criterian):
  model.eval()
  test_corr = 0
  total_loss = 0
  total = 0

  for X_test, y_test in test_loader:
    y_val = model(X_test)
    loss = criterian(y_val, y_test)

    total_loss += loss.item() * y_test.size(0)
    total += y_test.size(0)

    predicted = y_val.argmax(dim=1)
    test_corr += (predicted == y_test).sum().item()

  test_loss = total_loss / total
  test_acc = test_corr / total

  return test_loss, test_acc

In [None]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# -----------------------
# ONNX + INT8
# -----------------------
def export_onnx(model, onnx_path):
  """Export ``model`` to an ONNX file and validate the result.

  The model is traced with a dummy ``(1, 1, 28, 28)`` float input; the graph
  input is named ``"input"``, the output ``"logits"``, and dimension 0 of both
  is marked dynamic (``"batch"``).

  Args:
    model: Module to export; it is switched to evaluation mode.
    onnx_path: Destination path of the ``.onnx`` file.
  """
  import inspect  # local import: only needed to probe the exporter's signature

  model.eval()

  dummy = torch.randn(1, 1, 28, 28)
  export_kwargs = dict(
      input_names=["input"],
      output_names=["logits"],
      opset_version=13,
      do_constant_folding=True,
      dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},
  )

  # The recorded run of this notebook shows the default exporter replacing the
  # requested opset 13 with 18 and then failing the automatic down-conversion,
  # after which the pipeline ends in a ShapeInferenceError. Ask for the legacy
  # (non-dynamo) exporter explicitly so the requested opset is the one emitted.
  # torch versions whose `export` has no `dynamo` parameter are left untouched.
  if "dynamo" in inspect.signature(torch.onnx.export).parameters:
    export_kwargs["dynamo"] = False

  torch.onnx.export(model, dummy, onnx_path, **export_kwargs)

  # Structural validation of the file that was just written.
  onnx.checker.check_model(onnx_path)
  print(f"ONNX model saved to: {onnx_path}")

def quantize_onnx_dynamic(fp32_path, int8_path):
  """Apply ONNX Runtime dynamic quantization to an ONNX model file.

  Args:
    fp32_path: Path of the float model to read.
    int8_path: Path the quantized model is written to.

  Weights are quantized as unsigned 8-bit (``QuantType.QUInt8``) and the
  ``DisableShapeInference`` extra option is set.

  NOTE(review): the recorded notebook output ends with a ShapeInferenceError
  right after the export step - presumably raised from this call; confirm.
  """
  quant_args = {
      "model_input": fp32_path,
      "model_output": int8_path,
      "weight_type": QuantType.QUInt8,
      "extra_options": {"DisableShapeInference": True},
  }
  quantize_dynamic(**quant_args)
  print(f"INT8 model saved to: {int8_path}")

def ort_sanity_check(onnx_path):
  """Smoke-test an ONNX model by running it once in ONNX Runtime on CPU.

  A random float32 array of shape (1, 1, 28, 28) is fed to the model's first
  input, and the shape of the first output is printed.

  Args:
    onnx_path: Path of the ONNX model to load.
  """
  session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
  first_input = session.get_inputs()[0]

  # Values are irrelevant here; only a successful run and the shape matter.
  sample = np.random.rand(1, 1, 28, 28).astype(np.float32)

  outputs = session.run(None, {first_input.name: sample})
  print("ORT run output shape: ", outputs[0].shape)

In [None]:
# -----------------------
# Main
# -----------------------
def main():
  """Train the MNIST CNN, then export FP32 and INT8 ONNX models.

  Steps:
    1. Seed torch, build the data loaders, the model, the loss and Adam.
    2. Train for 5 epochs, evaluating on the test split after each one and
       printing a one-line summary.
    3. Export the trained model to ``mnist_lenet_fp32.onnx``, quantize it to
       ``mnist_lenet_int8.onnx`` and run the quantized file once in ONNX
       Runtime as a smoke test.
  """
  torch.manual_seed(41)

  train_loader, test_loader = get_loaders(batch_size=64)

  model = MNISTTinyCNN()
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

  epochs = 5
  start_time = time.time()

  for epoch in range(epochs):
    # train() reports the loss of its final batch, not an epoch average.
    train_loss, train_acc = train(epoch, model, train_loader, criterion, optimizer)
    test_loss, test_acc = test(model, test_loader, criterion)

    print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")

  print(f"Time taken: {total_time_minutes(start_time)} minutes!")

  # Export + quantize for STM32 pipeline
  fp32_onnx = "mnist_lenet_fp32.onnx"
  # Spelled "mnist" to match the FP32 artifact (was misspelled "minst").
  int8_onnx = "mnist_lenet_int8.onnx"

  export_onnx(model, fp32_onnx)
  quantize_onnx_dynamic(fp32_onnx, int8_onnx)
  ort_sanity_check(int8_onnx)

  print("STM32 pipeline is ready", fp32_onnx, int8_onnx)

In [None]:
# Entry point: run the full train -> export -> quantize pipeline when this code
# is executed directly rather than imported as a module.
if __name__ == "__main__":
  main()

Epoch 1/5 | Train Loss: 0.1874 | Train Acc: 0.8975 | Test Loss: 0.1055 | Test Acc: 0.9670
Epoch 2/5 | Train Loss: 0.1431 | Train Acc: 0.9706 | Test Loss: 0.0688 | Test Acc: 0.9779
Epoch 3/5 | Train Loss: 0.1530 | Train Acc: 0.9784 | Test Loss: 0.0572 | Test Acc: 0.9815
Epoch 4/5 | Train Loss: 0.0505 | Train Acc: 0.9832 | Test Loss: 0.0486 | Test Acc: 0.9826


  torch.onnx.export(
W1226 23:32:09.919000 641 torch/onnx/_internal/exporter/_compat.py:114] Setting ONNX exporter to use operator set version 18 because the requested opset_version 13 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsuccessful, the opset version of the exported model will be kept at 18. Please consider setting opset_version >=18 to leverage latest ONNX features


Epoch 5/5 | Train Loss: 0.0027 | Train Acc: 0.9862 | Test Loss: 0.0396 | Test Acc: 0.9876
Time taken: 1.9165327350298564 minutes!
[torch.onnx] Obtain model graph for `MNISTTinyCNN([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `MNISTTinyCNN([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...


Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py", line 127, in call
    converted_proto = _c_api_utils.call_onnx_api(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/_c_api_utils.py", line 65, in call_onnx_api
    result = func(proto)
             ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py", line 122, in _partial_convert_version
    return onnx.version_converter.convert_version(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnx/version_converter.py", line 39, in convert_version
    converted_model_str = C.convert_version(model_str, target_version)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: /github/workspace/onnx/version_converter/BaseConverter.h:68: adapter_lookup: Assertion `false`

[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 1 of general pattern rewrite rules.
ONNX model saved to: mnist_lenet_fp32.onnx


InferenceError: [ShapeInferenceError] Inferred shape and existing shape differ in dimension 0: (400) vs (120)