<a href="https://colab.research.google.com/github/raki-rankawat/stm32/blob/main/CIFAR10_STM32_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Load Trained CIFAR10 Model


In [4]:
!pip -q install onnx onnxruntime onnxscript onnxruntime-tools

In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import onnx
import onnxruntime as ort

from onnxruntime.quantization import quantize_dynamic, quantize_static, QuantType, QuantFormat, CalibrationDataReader

In [6]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
# Data Loaders
batch_size = 1

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

100%|██████████| 170M/170M [00:02<00:00, 57.2MB/s]


In [8]:
# CNN Model for CIFAR10 & CIFAR10 Pruned
class CIFARConvNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)

        self.fc1 = nn.Linear(128 * 2 * 2, 256)
        self.fc2 = nn.Linear(256, 10)

        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2) # 32 -> 16

        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2) # 16 -> 8

        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(x, 2) # 8 -> 4

        x = F.relu(self.bn4(self.conv4(x)))
        x = F.max_pool2d(x, 2) # 4 -> 2

        x = x.view(x.size(0), -1) # Flatten

        # x = self.dropout(x)
        x = F.relu(self.fc1(x))
        if self.training:
          x = self.dropout(x)
        x = self.fc2(x)

        return x

In [9]:
# CNN Model for CIFAR10 KD
class StudentNet(nn.Module):
  def __init__(self):
    super().__init__()

    self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
    self.bn1 = nn.BatchNorm2d(16)
    self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
    self.bn2 = nn.BatchNorm2d(32)

    self.fc1 = nn.Linear(32 * 8 * 8, 256)
    self.fc2 = nn.Linear(256, 10)

  def forward(self, x):
    x = F.relu(self.bn1(self.conv1(x)))
    x = F.max_pool2d(x, 2) # 32 -> 16

    x = F.relu(self.bn2(self.conv2(x)))
    x = F.max_pool2d(x, 2) # 16 -> 8

    x = x.view(x.size(0), -1) # Flatten

    x = F.relu(self.fc1(x))
    x = self.fc2(x)

    return x

In [10]:
# Load weights for CIFAR10 & CIFAR10 Pruned
# model = CIFARConvNet()
# model.load_state_dict(torch.load("/content/drive/My Drive/Colab Notebooks/stm_cifar10_model.pth", map_location=torch.device('cpu')))
# model.load_state_dict(torch.load("/content/drive/My Drive/Colab Notebooks/stm_cifar10_pruned_model.pth", map_location=torch.device('cpu'))) # Pruned Model

In [12]:
# Load weights for CIFAR10 KD
model = StudentNet()
model.load_state_dict(torch.load("/content/drive/My Drive/Colab Notebooks/stm_cifar10_kd_model.pth", map_location=torch.device('cpu')))

<All keys matched successfully>

In [19]:
# Accuracy on the 200 samples
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for i, (x, y) in enumerate(test_loader):
        if i >= 200:
            break
        out = model(x)
        pred = out.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)

print("PyTorch accuracy on first 200:", 100*correct/total)

PyTorch accuracy on first 200: 72.0


### FP32 Pipeline

In [14]:
# Export To ONNX
def export_onnx(model, onnx_path):
    model.eval()
    dummy = torch.randn(1, 3, 32, 32)  # NCHW

    torch.onnx.export(
        model,
        dummy,
        onnx_path,
        input_names=["input"],
        output_names=["logits"],
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        dynamic_axes={"input": {0: "batch_size"}, "logits": {0: "batch_size"}},
        dynamo=False
    )
    onnx.checker.check_model(onnx_path, full_check=False)
    print(f"ONNX model saved to: {onnx_path}")

export_onnx(model, "cifar10_convnet_fp32.onnx")

  torch.onnx.export(


ONNX model saved to: cifar10_convnet_fp32.onnx


In [17]:
# Collect N samples (SAVE NCHW for ST to avoid any internal NHWC->NCHW conversion)
model.eval()

N = 200
inputs_nchw = []
logits = []
labels = []

with torch.no_grad():
    for i, (x, y) in enumerate(test_loader):
        if i >= N:
            break

        out = model(x)

        # Keep NCHW (1,3,32,32) -> store (3,32,32)
        x_nchw = x.detach().cpu().numpy().astype(np.float32)   # (1,3,32,32)
        out_np = out.detach().cpu().numpy()[0].astype(np.float32)  # (10,)

        inputs_nchw.append(x_nchw[0])   # (3,32,32)
        logits.append(out_np)           # (10,)
        labels.append(int(y.item()))

inputs_nchw = np.stack(inputs_nchw, axis=0)       # (N,3,32,32)
logits = np.stack(logits, axis=0)                 # (N,10)
labels = np.array(labels, dtype=np.int32)         # (N,)

np.savez("cifar10_val_200_io.npz", input=inputs_nchw, logits=logits)
np.savez("cifar10_labels_200.npz", label=labels)

print("Saved input shape:", inputs_nchw.shape, "min/max:", inputs_nchw.min(), inputs_nchw.max())

Saved input shape: (200, 3, 32, 32) min/max: -1.0 1.0


In [20]:
# Compute Accuracy
def compute_accuracy(
    labels_npz_path,
    outputs_npz_path,
    output_key="c_outputs_1",
    num_classes=10,
    as_percentage=False
):
    labels = np.load(labels_npz_path)["label"].astype(np.int64)
    out = np.load(outputs_npz_path)

    logits = out[output_key].reshape(len(labels), num_classes)
    pred = np.argmax(logits, axis=1)

    acc = (pred == labels).mean()
    return acc * 100 if as_percentage else acc

In [21]:
acc = compute_accuracy(
    "cifar10_labels_200.npz",
    "network_val_io.npz",
    as_percentage=True
)

print("STM32 accuracy:", acc)

STM32 accuracy: 72.0


### Int8 Pipeline

In [22]:
# Calibration NPZ (inputs only)
def make_calib_npz(test_dataset, N=200, out_path="cifar10_calib_200.npz"):
    loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    xs = []
    with torch.no_grad():
        for i, (x, y) in enumerate(loader):
            if i >= N:
                break
            xs.append(x.detach().cpu().numpy()[0].astype(np.float32))

    xs = np.stack(xs, axis=0)
    np.savez(out_path, input=xs)
    print("Saved calib:", out_path, xs.shape)
    return out_path

In [23]:
# Quantize FP32 ONNX → INT8 ONNX (QDQ)
class CalibReader(CalibrationDataReader):
    def __init__(self, npz_path, input_name="input"):
        self.x = np.load(npz_path)["input"].astype(np.float32)
        self.input_name = input_name
        self.i = 0

    def get_next(self):
        if self.i >= len(self.x):
            return None
        batch = self.x[self.i:self.i+1]
        self.i += 1
        return {self.input_name: batch}

def quantize_int8_qdq(fp32_onnx="cifar10_convnet_fp32.onnx",
                      calib_npz="cifar10_calib_200.npz",
                      int8_onnx="cifar10_lenet_int8_static_qdq.onnx"):
    reader = CalibReader(calib_npz, input_name="input")

    quantize_static(
        model_input=fp32_onnx,
        model_output=int8_onnx,
        calibration_data_reader=reader,
        quant_format=QuantFormat.QDQ,
        activation_type=QuantType.QInt8,
        weight_type=QuantType.QInt8,
        per_channel=True,
    )
    print("Saved INT8:", int8_onnx)
    return int8_onnx

In [24]:
calib_npz = make_calib_npz(test_dataset, N=200, out_path="cifar10_calib_200.npz")
quantize_int8_qdq("cifar10_convnet_fp32.onnx", calib_npz, "cifar10_convnet_int8_static_qdq.onnx")



Saved calib: cifar10_calib_200.npz (200, 3, 32, 32)




Saved INT8: cifar10_convnet_int8_static_qdq.onnx


'cifar10_convnet_int8_static_qdq.onnx'

In [25]:
# d = np.load("network_val_io.npz")
# print("keys:", d.files)
# for k in d.files:
#     print(k, d[k].shape, d[k].dtype)

acc = compute_accuracy(
    "cifar10_labels_200.npz",
    "network_val_io.npz",
    as_percentage=True
)

print("STM32 accuracy:", acc)

STM32 accuracy: 73.5
