<a href="https://colab.research.google.com/github/rrankawat/stm32/blob/main/CIFAR10_STM32_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Pytorch To ONNX

In [39]:
!pip -q install onnx onnxruntime onnxscript onnxruntime-tools

In [40]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import onnx
import onnxruntime as ort

from onnxruntime.quantization import quantize_dynamic, quantize_static, QuantType, QuantFormat, CalibrationDataReader

In [41]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [42]:
# Data Loaders
batch_size = 1

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [43]:
# CNN Model
class CIFARConvNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)

        self.fc1 = nn.Linear(128 * 2 * 2, 256)
        self.fc2 = nn.Linear(256, 10)

        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2) # 32 -> 16

        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2) # 16 -> 8

        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(x, 2) # 8 -> 4

        x = F.relu(self.bn4(self.conv4(x)))
        x = F.max_pool2d(x, 2) # 4 -> 2

        x = x.view(x.size(0), -1) # Flatten

        # x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

In [44]:
# Load weights
model = CIFARConvNet()
model.load_state_dict(torch.load("/content/drive/My Drive/Colab Notebooks/stm_cifar10_model.pth", map_location=torch.device('cpu')))

<All keys matched successfully>

In [45]:
# Export To ONNX
def export_onnx(model, onnx_path):
  model.eval()

  dummy = torch.randn(1, 3, 32, 32)  # 3 channels and 32x32 for CIFAR-10

  torch.onnx.export(
        model,
        dummy,
        onnx_path,
        input_names=["input"],
        output_names=["logits"],
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
        dynamo=False
    )
  onnx.checker.check_model(onnx_path, full_check=False)
  print(f"ONNX model saved to: {onnx_path}")

export_onnx(model, "cifar10_fp32.onnx")

  torch.onnx.export(


ONNX model saved to: cifar10_fp32.onnx


### FP32 Pipeline

In [46]:
# Collect N samples
N = 200
inputs_nhwc = []
logits = []
labels = []

with torch.no_grad():
  for i, (x, y) in enumerate(test_loader):
    model.eval()

    if i >= N:
      break

    out = model(x)
    # convert input to NHWC (1x28x28x1)
    x_nhwc = x.numpy().transpose(0, 2, 3, 1).astype(np.float32)
    inputs_nhwc.append(x_nhwc[0])
    logits.append(out.numpy()[0].astype(np.float32))
    labels.append(int(y.item()))

inputs_nhwc = np.stack(inputs_nhwc, axis=0)
logits = np.stack(logits, axis=0)
labels = np.array(labels, dtype=np.int32)

# Save files
np.savez("cifar10_calib_200.npz", input=inputs_nhwc) # For quantization later
np.savez("cifar10_val_200_io.npz", input=inputs_nhwc, logits=logits)
np.savez("cifar10_labels_200.npz", label=labels)

In [48]:
# Compute Accuracy
def compute_accuracy(
    labels_npz_path,
    outputs_npz_path,
    output_key="c_outputs_1",
    num_classes=10,
    as_percentage=False
):
    labels = np.load(labels_npz_path)["label"].astype(np.int64)
    out = np.load(outputs_npz_path)

    logits = out[output_key].reshape(len(labels), num_classes)
    pred = np.argmax(logits, axis=1)

    acc = (pred == labels).mean()
    return acc * 100 if as_percentage else acc

In [49]:
acc = compute_accuracy(
    "cifar10_labels_200.npz",
    "network_val_io.npz",
    as_percentage=True
)

print("STM32 FP32 accuracy:", acc)

STM32 FP32 accuracy: 14.000000000000002
