In [4]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import torch
import torch.nn as nn
import onnx
import onnxruntime as ort
import numpy as np

os.makedirs("../app/models", exist_ok=True)

PATH_TO_WEIGHTS = "../outputs/cnn/fold_4_best_model.pth"
ONNX_PATH = "../app/models/cnn_eye_classifier.onnx"

print(f"PyTorch version: {torch.__version__}")
print(f"ONNX version: {onnx.__version__}")
print(f"ONNX Runtime version: {ort.__version__}")

PyTorch version: 2.5.1
ONNX version: 1.19.1
ONNX Runtime version: 1.22.2


In [5]:
IMG_SIZE = 64

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout2d(p=0.2),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout2d(p=0.3),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout2d(p=0.4),
            nn.MaxPool2d(2),

            nn.Flatten(),
            nn.Dropout(p=0.6),
            nn.Linear(128 * (IMG_SIZE // 8) * (IMG_SIZE // 8), 128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(128, 2)
        )

    def forward(self, x):
        return self.net(x)

print("Model architecture defined")

Model architecture defined


In [6]:
device = torch.device("cpu")
model = CNN().to(device)

print(f"Loading weights from: {PATH_TO_WEIGHTS}")

if not os.path.exists(PATH_TO_WEIGHTS):
    raise FileNotFoundError(f"Model file not found: {PATH_TO_WEIGHTS}")

try:
    state_dict = torch.load(PATH_TO_WEIGHTS, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    print("Weights loaded successfully!")
    
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")
    
except Exception as e:
    print(f"Error loading weights: {e}")
    raise

Loading weights from: ../outputs/cnn/fold_4_best_model.pth
Weights loaded successfully!
Total parameters: 1,142,082


In [7]:
print("Exporting to ONNX...")

dummy_input = torch.randn(1, 1, 64, 64, device=device)

with torch.no_grad():
    torch.onnx.export(
        model,
        dummy_input,
        ONNX_PATH,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )

file_size = os.path.getsize(ONNX_PATH) / (1024 * 1024)
print(f"ONNX model exported: {ONNX_PATH}")
print(f"File size: {file_size:.2f} MB")

Exporting to ONNX...
ONNX model exported: ../app/models/cnn_eye_classifier.onnx
File size: 4.36 MB


In [8]:
print("Verifying ONNX model...")

try:
    onnx_model = onnx.load(ONNX_PATH)
    onnx.checker.check_model(onnx_model)
    print("ONNX model valid")
    
    print(f"\nModel Info:")
    print(f"IR Version: {onnx_model.ir_version}")
    print(f"Producer: {onnx_model.producer_name}")
    print(f"Opset Version: {onnx_model.opset_import[0].version}")
    
    print(f"\nInput:")
    for input_tensor in onnx_model.graph.input:
        print(f"Name: {input_tensor.name}")
        shape = [dim.dim_value if dim.dim_value > 0 else 'dynamic' for dim in input_tensor.type.tensor_type.shape.dim]
        print(f"Shape: {shape}")
    
    print(f"\nOutput:")
    for output_tensor in onnx_model.graph.output:
        print(f"Name: {output_tensor.name}")
        shape = [dim.dim_value if dim.dim_value > 0 else 'dynamic' for dim in output_tensor.type.tensor_type.shape.dim]
        print(f"Shape: {shape}")
        
except Exception as e:
    print(f"ONNX model failed: {e}")
    raise

Verifying ONNX model...
ONNX model valid

Model Info:
IR Version: 7
Producer: pytorch
Opset Version: 13

Input:
Name: input
Shape: ['dynamic', 1, 64, 64]

Output:
Name: output
Shape: ['dynamic', 2]


In [9]:
print("Testing ONNX inference...")

ort_session = ort.InferenceSession(ONNX_PATH)

test_input = torch.randn(1, 1, 64, 64, device=device)

with torch.no_grad():
    pytorch_output = model(test_input).cpu().numpy()

onnx_output = ort_session.run(
    None,
    {"input": test_input.cpu().numpy()}
)[0]

diff = np.abs(pytorch_output - onnx_output).max()
print(f"\nMax difference between PyTorch and ONNX: {diff:.6f}")

if diff < 1e-4:
    print("ONNX model matches PyTorch model")
else:
    print("Outputs differ significantly")

print(f"\nPyTorch output shape: {pytorch_output.shape}")
print(f"ONNX output shape: {onnx_output.shape}")
print(f"PyTorch prediction: {pytorch_output.argmax()}")
print(f"ONNX prediction: {onnx_output.argmax()}")

Testing ONNX inference...

Max difference between PyTorch and ONNX: 0.000000
ONNX model matches PyTorch model

PyTorch output shape: (1, 2)
ONNX output shape: (1, 2)
PyTorch prediction: 0
ONNX prediction: 0


In [10]:
import time

print("Running performance benchmark...")

num_iterations = 100
test_input_np = test_input.cpu().numpy()

model.eval()
start_time = time.time()
with torch.no_grad():
    for _ in range(num_iterations):
        _ = model(test_input)
pytorch_time = (time.time() - start_time) / num_iterations * 1000

ort_session_std = ort.InferenceSession(ONNX_PATH)
start_time = time.time()
for _ in range(num_iterations):
    _ = ort_session_std.run(None, {"input": test_input_np})
onnx_time = (time.time() - start_time) / num_iterations * 1000

print(f"\n{'='*50}")
print(f"INFERENCE PERFORMANCE (Average over {num_iterations} runs)")
print(f"{'='*50}")
print(f"PyTorch Model:       {pytorch_time:.2f} ms")
print(f"ONNX Standard:       {onnx_time:.2f} ms ({pytorch_time/onnx_time:.2f}x)")
print(f"{'='*50}")

Running performance benchmark...

INFERENCE PERFORMANCE (Average over 100 runs)
PyTorch Model:       2.15 ms
ONNX Standard:       0.23 ms (9.35x)
