In [48]:
import torch
import torch.nn as nn
import time

import onnxruntime as ort
import numpy as np

print(torch.cuda.is_available())

False


In [None]:

# Define a simple feedforward network
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        # Define layers
        self.fc1 = nn.Linear(2, 2)  # Input to layer 1
        self.fc2 = nn.Linear(2, 3)  # Layer 1 to layer 2
        self.fc3 = nn.Linear(3, 4)  # Layer 2 to layer 3 (output)

    def forward(self, x):
        # Forward pass through the network
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = self.fc3(x)  # Output layer has no activation
        return x

# Instantiate the model
model = SimpleNet()

# Example input for inference
example_input = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) # 3x2x1

# Perform inference
output1 = model(example_input)
print("PyTorch Inference Output:", output1.detach().numpy())

output2 = model(example_input)
print("PyTorch Inference Output:", output2.detach().numpy())

output3 = model(example_input)
print("PyTorch Inference Output:", output3.detach().numpy())

PyTorch Inference Output: [[-0.07739066 -0.6254066   0.47564512 -0.77045244]
 [-0.07206593 -0.6295359   0.4737581  -0.7702721 ]
 [-0.06761487 -0.6332613   0.47209924 -0.769849  ]]
PyTorch Inference Output: [[-0.07739066 -0.6254066   0.47564512 -0.77045244]
 [-0.07206593 -0.6295359   0.4737581  -0.7702721 ]
 [-0.06761487 -0.6332613   0.47209924 -0.769849  ]]
PyTorch Inference Output: [[-0.07739066 -0.6254066   0.47564512 -0.77045244]
 [-0.07206593 -0.6295359   0.4737581  -0.7702721 ]
 [-0.06761487 -0.6332613   0.47209924 -0.769849  ]]


In [56]:


# Specify the path for the ONNX model file
onnx_model_path = "simple_model.onnx"

# Convert the PyTorch model to ONNX
torch.onnx.export(
    model,                          # model being exported
    example_input,                  # model input (or a tuple for multiple inputs)
    onnx_model_path,                # where to save the model (can be a file or file-like object)
    input_names=["input"],          # the model's input names
    output_names=["output"],        # the model's output names
)

print(f"Model successfully converted to ONNX: {onnx_model_path}")



[torch.onnx] Obtain model graph for `SimpleNet([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `SimpleNet([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Model successfully converted to ONNX: simple_model.onnx


In [None]:
# Load the ONNX model
onnx_model_path = "simple_model.onnx"
ort_session = ort.InferenceSession(onnx_model_path)

# Prepare sample input data (same shape as the PyTorch model)
onnx_input = np.array([[1.0, 2.0]], dtype=np.float32)


# Run inference on the ONNX model
onnx_output = ort_session.run(None, {"input": onnx_input})

# Print the ONNX inference result
print("ONNX Inference Output:", onnx_output)

ONNX Inference Output: [array([[ 0.49386343,  0.23764727, -0.32167193,  0.3408503 ]],
      dtype=float32)]


In [25]:
import time

# Warm-up runs (not timed)
example_input = torch.randn(1, 2)
for _ in range(100):
    _ = model(example_input)

# Now benchmark
time_sum = 0
n_iters = 50_000

for i in range(n_iters):
    t0 = time.perf_counter()
    output = model(example_input)
    time_sum += (time.perf_counter() - t0)

print(f"Average time: {time_sum/n_iters:.6f} seconds")
print(f"Average time: {(time_sum/n_iters)*1000:.4f} ms")

Average time: 0.000237 seconds
Average time: 0.2373 ms


In [26]:
onnx_model_path = "simple_model.onnx"
ort_session = ort.InferenceSession(onnx_model_path)

# Prepare input outside the loop
onnx_input = np.random.rand(1, 2).astype(np.float32)

# Warm-up runs (not timed)
for _ in range(100):
    _ = ort_session.run(None, {"input": onnx_input})

# Now benchmark
time_sum = 0
n_iters = 50_000

for i in range(n_iters):
    t0 = time.perf_counter()
    onnx_output = ort_session.run(None, {"input": onnx_input})
    time_sum += (time.perf_counter() - t0)

print(f"Average time: {time_sum/n_iters:.6f} seconds")
print(f"Average time: {(time_sum/n_iters)*1000:.4f} ms")

Average time: 0.000040 seconds
Average time: 0.0397 ms
