[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb)

In [None]:
# to make onnx export work
!pip install onnx onnxruntime

See the complete tutorial in the PyTorch docs:
 - https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

In [1]:
import onnx
import onnxruntime
import numpy as np

import torch
import segmentation_models_pytorch as smp

### Create a model (or load your own)

In [2]:
# U-Net with a ResNet-34 encoder initialised from ImageNet weights and a single
# output channel. eval() returns the module itself and switches it to inference
# mode (dropout off, batch-norm uses running statistics), which is what should
# be traced for export.
model = smp.Unet(
    "resnet34",
    encoder_weights="imagenet",
    classes=1,
).eval()

### Export the model to ONNX

In [3]:
# Export the model to ONNX by tracing it with a dummy input.
#
# dynamic_axes names the dimensions that may change at inference time. Here the
# batch size, height and width are all dynamic, so the exported graph accepts
# inputs of a different size than the dummy one. To keep only the batch size
# variable, use {0: "batch_size"} instead.
# NOTE(review): smp encoders presumably still require height/width compatible
# with the encoder's downsampling (e.g. divisible by 32) — confirm for yours.
dynamic_axes = {0: "batch_size", 2: "height", 3: "width"}

onnx_model_name = "unet_resnet34.onnx"
dummy_input = torch.randn(1, 3, 224, 224)

# torch.onnx.export writes the model to `onnx_model_name`; its return value is
# not the ONNX ModelProto, so it is deliberately not bound to a name (load the
# file with onnx.load() to inspect the graph). The trailing ';' keeps the
# notebook from displaying whatever object the exporter returns.
torch.onnx.export(
    model,  # model being run (in eval mode)
    dummy_input,  # example input used to trace the graph
    onnx_model_name,  # where to save the model (can be a file or file-like object)
    export_params=True,  # store the trained parameter weights inside the model file
    opset_version=17,  # ONNX operator-set version to target
    do_constant_folding=True,  # pre-compute constant sub-expressions for optimization
    input_names=["input"],  # the model's input names
    output_names=["output"],  # the model's output names
    dynamic_axes={  # variable-size axes, keyed by input/output name
        "input": dynamic_axes,
        "output": dynamic_axes,
    },
);

In [4]:
# Before running anything, load the exported file back from disk and let the
# ONNX checker validate it; check_model raises if the model is malformed.
onnx_model = onnx.load(onnx_model_name)
onnx.checker.check_model(onnx_model)

### Run with ONNX Runtime

In [5]:
# Build a sample whose batch size, height and width all differ from the dummy
# input used at export time, so the dynamic axes are actually exercised.
sample = torch.randn(2, 3, 512, 512)

# Open the exported file in an ONNX Runtime session restricted to the CPU.
ort_session = onnxruntime.InferenceSession(
    onnx_model_name,
    providers=["CPUExecutionProvider"],
)

# ONNX Runtime takes numpy arrays keyed by the input names chosen at export
# time; output_names=None requests every model output, returned as a list.
ort_inputs = dict(input=sample.numpy())
ort_outputs = ort_session.run(output_names=None, input_feed=ort_inputs)
ort_outputs

[array([[[[-1.41701847e-01, -4.63768840e-03,  1.21411584e-01, ...,
            5.22197843e-01,  3.40217263e-01,  8.52423906e-02],
          [-2.29843616e-01,  2.19401851e-01,  3.53053480e-01, ...,
            2.79466838e-01,  3.20288718e-01, -2.22393833e-02],
          [-3.12503517e-01, -3.66358161e-02,  1.19251609e-02, ...,
           -5.48991561e-02,  3.71140465e-02, -1.82842150e-01],
          ...,
          [-3.02772015e-01, -4.22928065e-01, -1.49621412e-01, ...,
           -1.42241001e-01, -9.90390778e-02, -1.33311331e-01],
          [-1.08293816e-01, -1.28070369e-01, -5.43620177e-02, ...,
           -8.64556879e-02, -1.74177170e-01,  6.03154302e-03],
          [-1.29619062e-01, -2.96604559e-02, -2.86361389e-03, ...,
           -1.91345289e-01, -1.82653710e-01,  1.17175849e-02]]],
 
 
        [[[-6.16237633e-02,  1.12350248e-01,  1.59193069e-01, ...,
            4.03313845e-01,  2.26862252e-01,  7.33022243e-02],
          [-1.60109222e-01,  1.21696621e-01,  1.84655115e-01, ...,
  

### Verify the outputs match the PyTorch model

In [6]:
# Reference output: run the original PyTorch model on the very same sample.
# no_grad() skips autograd bookkeeping, so the result can go straight to .numpy().
with torch.no_grad():
    torch_out = model(sample)

# Both runtimes must agree within float tolerance (shapes included);
# assert_allclose raises an AssertionError describing any mismatch.
np.testing.assert_allclose(
    actual=torch_out.numpy(),
    desired=ort_outputs[0],
    rtol=1e-03,
    atol=1e-05,
)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

Exported model has been tested with ONNXRuntime, and the result looks good!
