In [7]:
import torch
import onnx
import onnxruntime as ort
import numpy as np

def load_pytorch_model(model_path):
    """Load a fully pickled PyTorch model from disk and put it in eval mode.

    Args:
        model_path: Path to a checkpoint written with ``torch.save(model, ...)``,
            i.e. one that contains the whole ``nn.Module`` rather than only
            its ``state_dict``.

    Returns:
        The deserialized ``torch.nn.Module``, mapped to the CPU and switched
        to evaluation mode.

    Raises:
        TypeError: If the checkpoint does not contain an ``nn.Module`` (for
            example a bare ``state_dict``).
    """
    # SECURITY: weights_only=False runs the full pickle machinery, which can
    # execute arbitrary code embedded in the file. Only load checkpoints that
    # come from a trusted source.
    model = torch.load(model_path, map_location="cpu", weights_only=False)

    # Without this check a state_dict checkpoint fails later with an opaque
    # "object has no attribute 'eval'" error.
    if not isinstance(model, torch.nn.Module):
        raise TypeError(
            f"{model_path!r} contains a {type(model).__name__}, not a torch.nn.Module. "
            "If it is a state_dict, build the model class first and call "
            "load_state_dict() on it."
        )

    model.eval()
    return model

def convert_to_onnx(model, onnx_path="model.onnx", input_shape=(1, 3, 224, 224)):
    """Export ``model`` to an ONNX file.

    Args:
        model: The PyTorch model handed to ``torch.onnx.export``.
        onnx_path: Destination path of the exported ONNX file.
        input_shape: Shape of the random example tensor passed to the
            exporter alongside the model.

    The graph is exported with a single input named ``"input"`` and a single
    output named ``"output"``; axis 0 of both is declared dynamic under the
    name ``"batch_size"``. Opset 12 is requested.
    """
    export_options = {
        "input_names": ["input"],
        "output_names": ["output"],
        # Axis 0 is marked dynamic so the exported graph is not pinned to
        # the batch size of the example tensor.
        "dynamic_axes": {
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        },
        "opset_version": 12,
    }

    example = torch.randn(*input_shape)
    torch.onnx.export(model, example, onnx_path, **export_options)

    print(f"✅ Model converted to ONNX: {onnx_path}")

def verify_onnx_model(onnx_path):
    """Load the ONNX file at ``onnx_path`` and run the ONNX checker on it.

    A confirmation line is printed once ``onnx.checker.check_model`` returns.
    Nothing is caught here, so any exception raised while loading or checking
    propagates to the caller.
    """
    onnx.checker.check_model(onnx.load(onnx_path))
    print("✅ ONNX model integrity verified!")

def run_onnx_inference(onnx_path, input_data, input_name="input"):
    """Run one forward pass of an ONNX model with ONNX Runtime.

    Args:
        onnx_path: Path to the ONNX model file.
        input_data: The model input, either a ``torch.Tensor`` (on any
            device, with or without autograd tracking) or anything
            ``np.asarray`` accepts. The dtype is passed through unchanged,
            so it must already match what the model expects.
        input_name: Name of the graph input to feed. Defaults to ``"input"``,
            the name ``convert_to_onnx`` assigns at export time.

    Returns:
        The list of output arrays returned by ``InferenceSession.run`` when
        all outputs are requested.
    """
    session = ort.InferenceSession(onnx_path)

    if isinstance(input_data, torch.Tensor):
        # A bare .numpy() raises for tensors that require grad or live on a
        # non-CPU device; detach() and cpu() make the conversion safe.
        input_array = input_data.detach().cpu().numpy()
    else:
        input_array = np.asarray(input_data)

    # Passing None as the output list asks the session for every output.
    outputs = session.run(None, {input_name: input_array})
    print("🔥 ONNX model inference successful!")
    return outputs

In [8]:
# Load the checkpoint from a path relative to the notebook's working directory.
# `model` is a notebook-level name consumed by the export cell below.
model = load_pytorch_model('./data/models/retrained_mobilenet_v2.pt')

In [9]:
# Export with the function defaults: writes "model.onnx" to the working
# directory using a (1, 3, 224, 224) example input.
convert_to_onnx(model)

✅ Model converted to ONNX: model.onnx


In [10]:
# Check the file written by the export cell ('./model.onnx' is the same
# working-directory path as the export default "model.onnx").
verify_onnx_model('./model.onnx')

✅ ONNX model integrity verified!


In [11]:
# Smoke-test the exported file with a random input of the same shape as the
# export default. torch.randn is not seeded here, so the displayed values
# change from run to run.
dummy_input = torch.randn(1, 3, 224, 224)
# Last expression of the cell: the returned list of output arrays is displayed.
run_onnx_inference("model.onnx", dummy_input)

🔥 ONNX model inference successful!


[array([[-5.032561  ,  2.4403944 , -0.34752664, -6.0373116 , -4.012481  ,
          0.5832955 ,  1.0356275 , -0.17992587, -1.1850585 ,  4.4787087 ,
          5.5268545 , -7.0498238 , -2.5604897 , -2.1431074 ,  0.4872554 ,
         -2.7128272 ,  2.8310237 , -0.39761472, -2.7208414 , -1.6594896 ,
         -4.2011743 , -3.994329  , -1.5888466 ]], dtype=float32)]