In [1]:
# -----------------------------------------------------------
# Environment & library imports
# -----------------------------------------------------------
import torch
from   torch import nn, optim
from   torchvision import datasets, transforms
from   torch.utils.data import DataLoader
import numpy as np
import time
import random

# ---- Reproducibility (optional) ----------------------------
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.backends.cudnn.deterministic = True      #  use only deterministic convolution kernels (repeatable outputs, slower).
torch.backends.cudnn.benchmark     = False     # disable algorithm auto-tuning (consistent algorithm choice, no profiling overhead).

In [2]:
# -----------------------------------------------------------
#   Dataset & DataLoader
# -----------------------------------------------------------
transform = transforms.ToTensor()          # (0-255) → (0-1) float32
batch_sz  = 64

train_ds = datasets.MNIST(root=".", train=True, download=True,
                          transform=transform)
test_ds  = datasets.MNIST(root=".", train=False, download=True,
                          transform=transform)

train_loader = DataLoader(train_ds, batch_size=batch_sz, shuffle=True)
test_loader  = DataLoader(test_ds , batch_size=256,    shuffle=False)

print(f"Training samples: {len(train_ds)} | Test samples: {len(test_ds)}")


Training samples: 60000 | Test samples: 10000


In [3]:
# -----------------------------------------------------------
#   Model definition 
# -----------------------------------------------------------
class TinyCNN(nn.Module):
    """
    Minimal CNN for MNIST:
        Conv(1→16, 3×3) → ReLU
        Flatten
        Linear(16*26*26 → 10)
    """
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=1,  out_channels=16,
                              kernel_size=3, stride=1, padding=0)
        self.fc   = nn.Linear(16 * 26 * 26, 10)

    def forward(self, x):
        x = torch.relu(self.conv(x))       # shape: [B,16,26,26]
        x = x.view(x.size(0), -1)          # flatten  → [B,16*26*26]
        logits = self.fc(x)                # → [B,10]
        return logits


model = TinyCNN()
print(f"Model params: {sum(p.numel() for p in model.parameters()):,}")


Model params: 108,330


In [4]:
# -----------------------------------------------------------
#   Training loop
# -----------------------------------------------------------
epochs  = 3                        # hits ~95+ % in < 30 s on CPU
lr      = 0.01
moment  = 0.7

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=moment)
criterion  = nn.CrossEntropyLoss()

for ep in range(1, epochs + 1):
    model.train()
    running_loss = 0.0

    for imgs, labels in train_loader:
        optimizer.zero_grad()
        logits = model(imgs)
        loss   = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # quick validation each epoch
    model.eval(); correct = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            preds = model(imgs).argmax(1)
            correct += (preds == labels).sum().item()
    acc = 100 * correct / len(test_ds)
    print(f"Epoch {ep} | loss {running_loss/len(train_loader):.3f} "
          f"| test acc {acc:.2f}%")

torch.save(model.state_dict(), "mnist_cnn.pth")


Epoch 1 | loss 0.390 | test acc 91.90%
Epoch 2 | loss 0.283 | test acc 92.34%
Epoch 3 | loss 0.228 | test acc 94.73%


In [5]:
# -----------------------------------------------------------
# 4 · Export FP32 weights to ONNX
# -----------------------------------------------------------
import torch
import torch.onnx   as onnx_export
import onnx         # runtime checker

model.eval()                            # inference mode
dummy = torch.randn(1, 1, 28, 28)       # (N,C,H,W)  -- batch size 1

onnx_path = "mnist_cnn.onnx"
onnx_export.export(
    model, dummy, onnx_path,
    input_names  = ['input'],
    output_names = ['logits'],
    opset_version=13,                   # compatible with TVM ≥0.14
    dynamic_axes = {
        'input':  {0: 'batch'},         # allow N>1 at runtime
        'logits': {0: 'batch'}
    }
)

# Sanity-check the file
onnx.checker.check_model(onnx_path)
print("✓ ONNX file saved & structurally valid")


✓ ONNX file saved & structurally valid


In [6]:
import netron
# Option A – open in a new browser tab (default port 8080)
netron.start("mnist_cnn.onnx")

# Option B – embed as an <iframe> inside JupyterLab
from IPython.display import IFrame
netron_port = 10080
netron.start("mnist_cnn.onnx")
IFrame(src=f"http://localhost:{netron_port}", width="100%", height=600)


Serving 'mnist_cnn.onnx' at http://localhost:8080
Serving 'mnist_cnn.onnx' at http://localhost:8081


In [12]:
# -----------------------------------------------------------
# 5 · Bring the ONNX graph into TVM Relay IR
# -----------------------------------------------------------
import tvm
from   tvm import relay

onnx_model = onnx.load(onnx_path)
shape_dict = {"input": (1, 1, 28, 28)}  # batch=1 placeholder

mod_fp32, params_fp32 = relay.frontend.from_onnx(
    onnx_model,
    shape =shape_dict,
    freeze_params=False                 # weights baked into graph
)

print(mod.astext(show_meta_data=False)[:600], "...\n")
print(f"Relay graph imported | params: {len(params_fp32)} tensors")


Using injective.cpu for less based on highest priority (10)
Using injective.cpu for take based on highest priority (10)
Using injective.cpu for expand_dims based on highest priority (10)
Using concatenate.cpu for concatenate based on highest priority (10)


#[version = "0.0.5"]
def @main(%input: Tensor[(1, 1, 28, 28), float32] /* ty=Tensor[(1, 1, 28, 28), float32] span=/conv/Conv.input:0:0 */) -> Tensor[(1, 10), float32] {
  %0 = nn.conv2d(%input, meta[relay.Constant][0] /* ty=Tensor[(16, 1, 3, 3), float32] span=/conv/Conv.conv.weight:0:0 */, padding=[0, 0, 0, 0], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 26, 26), float32] span=/conv/Conv:0:0 */;
  %1 = nn.bias_add(%0, meta[relay.Constant][1] /* ty=Tensor[(16), float32] span=/conv/Conv.conv.bias:0:0 */) /* ty=Tensor[(1, 16, 26, 26), float32] span=/conv/Conv:0:0 */;
  %2 = nn.relu(%1)  ...

Relay graph imported | params: 4 tensors


In [20]:
# -----------------------------------------------------------
# 6 · Fast INT8 quantisation (global scale method)
# -----------------------------------------------------------
from tvm import relay

with relay.quantize.qconfig(
        calibrate_mode   = 'global_scale',
        global_scale     = 8.0,
        weight_scale     = 'max',
        skip_conv_layers = [],     # ← quantise every conv
        skip_dense_layer = False   # ← quantise the FC too
):
    mod_int8 = relay.quantize.quantize(mod_fp32, params_fp32)

# Make sure types are inferred **after** quantisation
mod_int8 = relay.transform.InferType()(mod_int8)

print(mod_int8.astext(show_meta_data=False)[:1000], "...")


Using injective.cpu for multiply based on highest priority (10)
Using injective.cpu for multiply based on highest priority (10)
Using injective.cpu for multiply based on highest priority (10)
Using injective.cpu for round based on highest priority (10)
Using injective.cpu for clip based on highest priority (10)
Using injective.cpu for cast based on highest priority (10)
Using injective.cpu for multiply based on highest priority (10)
Using injective.cpu for round based on highest priority (10)
Using injective.cpu for clip based on highest priority (10)
Using injective.cpu for cast based on highest priority (10)
Using injective.cpu for multiply based on highest priority (10)
Using injective.cpu for round based on highest priority (10)
Using injective.cpu for clip based on highest priority (10)
Using injective.cpu for cast based on highest priority (10)
Using injective.cpu for fixed_point_multiply based on highest priority (10)
Using injective.cpu for cast based on highest priority (10)


#[version = "0.0.5"]
def @main(%input: Tensor[(1, 1, 28, 28), float32] /* ty=Tensor[(1, 1, 28, 28), float32] span=/conv/Conv.input:0:0 */) -> Tensor[(1, 10), float32] {
  %0 = multiply(%input, 16f /* ty=float32 */) /* ty=Tensor[(1, 1, 28, 28), float32] */;
  %1 = round(%0) /* ty=Tensor[(1, 1, 28, 28), float32] */;
  %2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 1, 28, 28), float32] */;
  %3 = cast(%2, dtype="int8") /* ty=Tensor[(1, 1, 28, 28), int8] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][0] /* ty=Tensor[(16, 1, 3, 3), int8] */, padding=[0, 0, 0, 0], channels=16, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 16, 26, 26), int32] */;
  %5 = cast(%4, dtype="int64") /* ty=Tensor[(1, 16, 26, 26), int64] */;
  %6 = fixed_point_multiply(%5, multiplier=1455275264, shift=-7) /* ty=Tensor[(1, 16, 26, 26), int64] */;
  %7 = clip(%6, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 16, 26, 26), int64] */;
  %8 = cast(%7, dtype="int32") /* ty=Tensor[(1, 16, 26, 26), int32] */;
  

In [21]:
# -----------------------------------------------------------
# 7 · Compile INT8 graph and time it
# -----------------------------------------------------------
target = "llvm -mcpu=core-avx2"      # adjust if you’re on ARM / Apple Silicon
dev    = tvm.cpu()

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod_int8, target=target, params=params_fp32)

module = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))

import numpy as np, time
data = np.random.rand(1,1,28,28).astype("float32")
module.set_input("input", data)

# warm-up
for _ in range(10):
    module.run()

# timing
t = module.module.time_evaluator("run", dev, repeat=100)
tvm_ms = np.median(t().results) * 1000
print(f"TVM INT8 latency: {tvm_ms:.3f} ms  (batch 1)")


Download pre-tuned parameters package from https://raw.githubusercontent.com/tlc-pack/tophub/main/tophub/llvm_v0.04.log
Downloading from url https://raw.githubusercontent.com/tlc-pack/tophub/main/tophub/llvm_v0.04.log to /home/andres/.tvm/tophub/llvm_v0.04.log
Using pad.generic for nn.pad based on highest priority (10)
Using injective.cpu for cast based on highest priority (10)
Using reduce.cpu for sum based on highest priority (10)
Using injective.cpu for expand_dims based on highest priority (10)
Using injective.cpu for multiply based on highest priority (10)
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
Using conv2d_nchw_int8.x86 for nn.conv2d based on highest priority (10)
Using dense_pack.x86 for nn.dense based on highest priority (10)
Using layout_transform.generic for layout_transform based on highest priority (10)
Using injective.cpu for expand_dims based on highest priority (10)
Using layo

TVM INT8 latency: 0.158 ms  (batch 1)


In [25]:
import time, torch

model.eval()
with torch.no_grad():
    # warm-up
    for _ in range(10):
        _ = model(torch.from_numpy(data))

    start = time.time()
    for _ in range(100):
        _ = model(torch.from_numpy(data))
    dt = time.time() - start          # total time for 100 runs

pt_ms = dt / 100 * 1000              # ← divide, not multiply
print(f"PyTorch FP32 latency: {pt_ms:.3f} ms")

print(f"Speed-up: {pt_ms / tvm_ms:.2f}×  (INT8 vs. FP32)")

PyTorch FP32 latency: 0.239 ms
Speed-up: 1.51×  (INT8 vs. FP32)


## ───────────────────────────────────────────
## Variant 1 : Conv + ReLU + MaxPool (Batch-1)
## ───────────────────────────────────────────


In [26]:
# -----------------------------------------------------------
#   TinyCNN-Pool: Conv → ReLU → MaxPool → Flatten → Dense
# -----------------------------------------------------------
class TinyCNNPool(nn.Module):
    """
    Conv(1→16, 3×3) → ReLU → MaxPool(2×2)
    → Flatten → Dense(13*13*16 → 10)
    """
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 16, 3)          # out: 26×26×16
        self.pool = nn.MaxPool2d(2, 2)           # out: 13×13×16
        self.fc   = nn.Linear(16*13*13, 10)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

model = TinyCNNPool()
print(f"Param count: {sum(p.numel() for p in model.parameters()):,}")


Param count: 27,210


In [27]:
# same loaders as before
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion  = nn.CrossEntropyLoss()

# training loop for 3 epochs
for ep in range(3):
    model.train()
    for imgs, labels in train_loader:
        optimizer.zero_grad()
        criterion(model(imgs), labels).backward()
        optimizer.step()

    # val
    model.eval(); correct = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            correct += (model(imgs).argmax(1) == labels).sum().item()
    print(f"Epoch {ep+1}: {100*correct/len(test_ds):.2f}%")

torch.save(model.state_dict(), "mnist_cnn_pool.pth")


Epoch 1: 94.54%
Epoch 2: 96.89%
Epoch 3: 97.58%


In [29]:
# Export to ONNX
dummy = torch.randn(1, 1, 28, 28)
torch.onnx.export(
    model, dummy, "mnist_cnn_pool.onnx",
    input_names=['input'], output_names=['logits'],
    opset_version=13,
    dynamic_axes={'input':{0:'batch'}, 'logits':{0:'batch'}}
)


In [32]:
# Re-import and quantise

onnx_model = onnx.load("mnist_cnn_pool.onnx")
shape = {"input": (1,1,28,28)}

mod_fp32, params_fp32 = relay.frontend.from_onnx(
    onnx_model, shape=shape, freeze_params=False
)

with relay.quantize.qconfig(
        calibrate_mode='global_scale', global_scale=8.0,
        weight_scale='max', skip_conv_layers=[], skip_dense_layer=False):
    mod_int8 = relay.quantize.quantize(mod_fp32, params_fp32)

mod_int8 = relay.transform.InferType()(mod_int8)
print(mod_int8.astext(show_meta_data=False)[:1000], "...")


Using injective.cpu for less based on highest priority (10)
Using injective.cpu for take based on highest priority (10)
Using injective.cpu for expand_dims based on highest priority (10)
Using concatenate.cpu for concatenate based on highest priority (10)
Using injective.cpu for multiply based on highest priority (10)
Using injective.cpu for multiply based on highest priority (10)
Using injective.cpu for multiply based on highest priority (10)
Using injective.cpu for round based on highest priority (10)
Using injective.cpu for clip based on highest priority (10)
Using injective.cpu for cast based on highest priority (10)
Using injective.cpu for multiply based on highest priority (10)
Using injective.cpu for round based on highest priority (10)
Using injective.cpu for clip based on highest priority (10)
Using injective.cpu for cast based on highest priority (10)
Using injective.cpu for multiply based on highest priority (10)
Using injective.cpu for round based on highest priority (10)
U

#[version = "0.0.5"]
def @main(%input: Tensor[(1, 1, 28, 28), float32] /* ty=Tensor[(1, 1, 28, 28), float32] span=/conv/Conv.input:0:0 */) -> Tensor[(1, 10), float32] {
  %0 = multiply(%input, 16f /* ty=float32 */) /* ty=Tensor[(1, 1, 28, 28), float32] */;
  %1 = round(%0) /* ty=Tensor[(1, 1, 28, 28), float32] */;
  %2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 1, 28, 28), float32] */;
  %3 = cast(%2, dtype="int8") /* ty=Tensor[(1, 1, 28, 28), int8] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][0] /* ty=Tensor[(16, 1, 3, 3), int8] */, padding=[0, 0, 0, 0], channels=16, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 16, 26, 26), int32] */;
  %5 = cast(%4, dtype="int64") /* ty=Tensor[(1, 16, 26, 26), int64] */;
  %6 = fixed_point_multiply(%5, multiplier=1261564800, shift=-6) /* ty=Tensor[(1, 16, 26, 26), int64] */;
  %7 = clip(%6, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 16, 26, 26), int64] */;
  %8 = cast(%7, dtype="int32") /* ty=Tensor[(1, 16, 26, 26), int32] */;
  

In [34]:
# Build and benchmark

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod_int8, target=target, params=params_fp32)

module = tvm.contrib.graph_executor.GraphModule(lib['default'](dev))
for _ in range(10):
    module.set_input('input', np.random.rand(1,1,28,28).astype('float32'))
    module.run()
ft = module.module.time_evaluator("run", dev, repeat=100)
tvm_ms = np.median(ft().results)*1000
print(f"TVM INT8 + MaxPool: {tvm_ms:.3f} ms")

# PyTorch baseline
model.eval(); data = torch.randn(1,1,28,28)
with torch.no_grad():
    start = time.time()
    for _ in range(100):
        _ = model(data)
pt_ms = (time.time()-start)/100*1000
print(f"PyTorch FP32 + MaxPool: {pt_ms:.3f} ms")
print(f"Speed-up: {pt_ms/tvm_ms:.2f}×")


Using pad.generic for nn.pad based on highest priority (10)
Using injective.cpu for cast based on highest priority (10)
Using reduce.cpu for sum based on highest priority (10)
Using injective.cpu for expand_dims based on highest priority (10)
Using injective.cpu for multiply based on highest priority (10)
Using conv2d_nchw_int8.x86 for nn.conv2d based on highest priority (10)
Using dense_pack.x86 for nn.dense based on highest priority (10)
Using layout_transform.generic for layout_transform based on highest priority (10)
Using injective.cpu for expand_dims based on highest priority (10)
Using layout_transform.generic for layout_transform based on highest priority (10)
Using injective.cpu for expand_dims based on highest priority (10)
Using injective.cpu for expand_dims based on highest priority (10)
Using layout_transform.generic for layout_transform based on highest priority (10)
Using layout_transform.generic for layout_transform based on highest priority (10)
Using injective.cpu for

TVM INT8 + MaxPool: 0.064 ms
PyTorch FP32 + MaxPool: 0.112 ms
Speed-up: 1.74×


## ───────────────────────────────────────────
## Batch-size 8 timing
## ───────────────────────────────────────────

In [37]:
# 1 · Re-import ONNX with batch 8 shape
from tvm import relay
onnx_model = onnx.load("mnist_cnn_pool.onnx")

shape8 = {"input": (8, 1, 28, 28)}
mod8, params8 = relay.frontend.from_onnx(
    onnx_model, shape=shape8, freeze_params=False
)

# 2 · Quantise exactly as before
with relay.quantize.qconfig(calibrate_mode='global_scale',
                            global_scale=8.0,
                            weight_scale='max'):
    mod_int8_8 = relay.quantize.quantize(mod8, params8)
mod_int8_8 = relay.transform.InferType()(mod_int8_8)

# 3 · Build
with tvm.transform.PassContext(opt_level=3):
    lib8 = relay.build(mod_int8_8, target=target, params=params8)

module8 = tvm.contrib.graph_executor.GraphModule(lib8['default'](dev))

# 4 · Benchmark
batch8 = np.random.rand(8,1,28,28).astype('float32')
module8.set_input('input', batch8)
for _ in range(5): module8.run()        # warm-up
t8 = module8.module.time_evaluator("run", dev, repeat=100)
tvm_ms8 = np.median(t8().results)*1000
print(f"TVM INT8 batch-8 : {tvm_ms8:.3f} ms  → {tvm_ms8/8:.3f} ms / img")


Using injective.cpu for less based on highest priority (10)
Using injective.cpu for take based on highest priority (10)
Using injective.cpu for expand_dims based on highest priority (10)
Using concatenate.cpu for concatenate based on highest priority (10)
Using conv2d_nchw.x86 for nn.conv2d based on highest priority (10)
Using dense_pack.x86 for nn.dense based on highest priority (10)
Using layout_transform.generic for layout_transform based on highest priority (10)
Using injective.cpu for expand_dims based on highest priority (10)
Using injective.cpu for expand_dims based on highest priority (10)
Using layout_transform.generic for layout_transform based on highest priority (10)
Using layout_transform.generic for layout_transform based on highest priority (10)
Using injective.cpu for expand_dims based on highest priority (10)
Using layout_transform.generic for layout_transform based on highest priority (10)
Using conv2d_NCHWc.x86 for nn.contrib_conv2d_NCHWc based on highest priority (1

TVM INT8 batch-8 : 0.229 ms  → 0.029 ms / img


In [38]:
data8 = torch.randn(8,1,28,28)
model.eval()
with torch.no_grad():
    for _ in range(5): _ = model(data8)            # warm-up
    start = time.time()
    for _ in range(100):
        _ = model(data8)
pt_ms8 = (time.time()-start)/100*1000
print(f"PyTorch FP32 batch-8 : {pt_ms8:.3f} ms  → {pt_ms8/8:.3f} ms / img")
print(f"Throughput speed-up  : {(pt_ms8/8)/(tvm_ms8/8):.2f}×")


PyTorch FP32 batch-8 : 0.755 ms  → 0.094 ms / img
Throughput speed-up  : 3.29×
