In [1]:
## 1. Install the required dependencies

!pip install onnx
!pip install onnxscript


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.2.2[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.2.2[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
## 2. Use existing model
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # Define layers
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

        # Load the pre-trained weights
        self.load_pretrained_weights()

    def load_pretrained_weights(self):
        # Load the existing model weights
        model_path = "./best_metric_model.pth"
        pretrained_dict = torch.load(model_path)

        # Initialize the model's state dictionary with pre-trained weights
        model_dict = self.state_dict()

        # Filter out unnecessary keys
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

        # Update the model's state dictionary
        model_dict.update(pretrained_dict)

        # Load the updated state dictionary into the model
        self.load_state_dict(model_dict)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Create an instance of MyModel
model = MyModel()

# Optionally, you can print the model to verify if the weights are loaded correctly
print(model)




MyModel(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [3]:
## 3.Export the model to ONNX format

torch_input = torch.randn(1, 1, 32, 32)
output_path = "my_model.onnx"
torch.onnx.export(model, torch_input, output_path, verbose=True)

Exported graph: graph(%input.1 : Float(1, 1, 32, 32, strides=[1024, 1024, 32, 1], requires_grad=0, device=cpu),
      %conv1.weight : Float(6, 1, 5, 5, strides=[25, 25, 5, 1], requires_grad=1, device=cpu),
      %conv1.bias : Float(6, strides=[1], requires_grad=1, device=cpu),
      %conv2.weight : Float(16, 6, 5, 5, strides=[150, 25, 5, 1], requires_grad=1, device=cpu),
      %conv2.bias : Float(16, strides=[1], requires_grad=1, device=cpu),
      %fc1.weight : Float(120, 400, strides=[400, 1], requires_grad=1, device=cpu),
      %fc1.bias : Float(120, strides=[1], requires_grad=1, device=cpu),
      %fc2.weight : Float(84, 120, strides=[120, 1], requires_grad=1, device=cpu),
      %fc2.bias : Float(84, strides=[1], requires_grad=1, device=cpu),
      %fc3.weight : Float(10, 84, strides=[84, 1], requires_grad=1, device=cpu),
      %fc3.bias : Float(10, strides=[1], requires_grad=1, device=cpu)):
  %/conv1/Conv_output_0 : Float(1, 6, 28, 28, strides=[4704, 784, 28, 1], requires_grad=0,

In [4]:
## 4. Save the ONNX model in a file
import os

# Rename the ONNX file if needed
output_path = "spl.onnx"

# Check if the file exists
if os.path.exists("my_model.onnx"):
    # Rename the file
    os.rename("my_model.onnx", output_path)
    print(f"ONNX model saved as {output_path}")
else:
    print("Error: The ONNX file does not exist.")


ONNX model saved as spl.onnx


In [5]:
## load the ONNX file back into memory and check if it is well formed with the following code:


import onnx
onnx_model = onnx.load("spl.onnx")
onnx.checker.check_model(onnx_model)

In [6]:
## 6. Execute the ONNX model with ONNX Runtime

import onnxruntime
import numpy as np

# Assuming `torch_input` is a torch.Tensor
torch_input = torch.randn(1, 1, 32, 32)

# Assuming `onnx_model` is already loaded
onnx_model = onnx.load("spl.onnx")

# Create ONNX runtime session
ort_session = onnxruntime.InferenceSession("spl.onnx")

# Adapt torch inputs to ONNX inputs
input_name = ort_session.get_inputs()[0].name
onnx_input = {input_name: torch_input.numpy()}  # Convert torch tensor to numpy array

# Run inference
onnxruntime_outputs = ort_session.run(None, onnx_input)

# Print results
print("Output shape:", [output.shape for output in onnxruntime_outputs])


Output shape: [(1, 10)]


In [7]:
import torch
import onnxruntime
import numpy as np

# Assuming `torch_input` is a torch.Tensor
torch_input = torch.randn(1, 1, 32, 32)

# Assuming `torch_model` is your PyTorch model
torch_model = MyModel()

# Run inference with PyTorch model
torch_outputs = torch_model(torch_input)

# Create ONNX runtime session
ort_session = onnxruntime.InferenceSession("spl.onnx")

# Adapt torch outputs to ONNX outputs (not required)
# No need to adapt torch outputs, they can be directly compared with ONNX outputs

# Run inference with ONNX runtime
onnxruntime_inputs = {ort_session.get_inputs()[0].name: torch_input.detach().numpy()}
onnxruntime_outputs = ort_session.run(None, onnxruntime_inputs)

# Convert ONNX runtime outputs to torch tensors for comparison
onnxruntime_outputs = [torch.tensor(output) for output in onnxruntime_outputs]


assert len(torch_outputs) == len(onnxruntime_outputs)
for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
    # Ignore batch dimension mismatch if present
    if torch_output.shape[0] == 1:
        torch_output = torch_output.squeeze(0)
    
    # Reshape PyTorch output to match ONNX Runtime output shape
    torch_output = torch_output.reshape(onnxruntime_output.shape)

    # Print values of both outputs for comparison
    print("PyTorch output:", torch_output)
    print("ONNX Runtime output:", onnxruntime_output)

    # Convert numpy array back to PyTorch tensor
    onnx_val_tensor = torch.tensor(onnxruntime_output).detach()  # Cloning and detaching the tensor

    # Check for element-wise closeness
    for index, torch_val in np.ndenumerate(torch_output.detach().numpy()):
        onnx_val = onnx_val_tensor[index]  # Access the corresponding value from ONNX Runtime output
        if not torch.isclose(torch.tensor(torch_val), onnx_val, atol=1e-05, rtol=1e-06):
            print(f"Difference found at index {index}:")
            print(f"PyTorch value: {torch_val}")
            print(f"ONNX Runtime value: {onnx_val}")

print("PyTorch and ONNX Runtime output matched!")
print(f"Output length: {len(onnxruntime_outputs)}")






PyTorch output: tensor([[ 0.0562,  0.1270,  0.0226, -0.0884, -0.0753,  0.0391,  0.0653, -0.0704,
          0.1106, -0.0061]], grad_fn=<ReshapeAliasBackward0>)
ONNX Runtime output: tensor([[-0.0338, -0.0471,  0.0242,  0.0872,  0.1257,  0.1178,  0.0555,  0.1461,
          0.0130, -0.0281]])
Difference found at index (0, 0):
PyTorch value: 0.056172024458646774
ONNX Runtime value: -0.03375621885061264
Difference found at index (0, 1):
PyTorch value: 0.12696225941181183
ONNX Runtime value: -0.04705527424812317
Difference found at index (0, 2):
PyTorch value: 0.022590460255742073
ONNX Runtime value: 0.024227982386946678
Difference found at index (0, 3):
PyTorch value: -0.08837180584669113
ONNX Runtime value: 0.08721882104873657
Difference found at index (0, 4):
PyTorch value: -0.07527130097150803
ONNX Runtime value: 0.12569761276245117
Difference found at index (0, 5):
PyTorch value: 0.03910406678915024
ONNX Runtime value: 0.11777838319540024
Difference found at index (0, 6):
PyTorch value: 

  onnx_val_tensor = torch.tensor(onnxruntime_output).detach()  # Cloning and detaching the tensor
