In [7]:
!pip install -q onnx onnxruntime onnxscript

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import onnx
import onnxruntime
from torchvision import models

In [9]:
class MobileNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = models.get_model(name="mobilenet_v3_large")
        self.network.classifier = nn.Sequential(
            nn.Linear(self.network.classifier[0].in_features, 512, bias=True),
            nn.Hardswish(),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(512, 136, bias=True)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.network(x)
        x = x.reshape(x.size(0), 68, 2)
        return x

mobilenetv3 = MobileNet()
# test load state dict
mobilenetv3.load_state_dict(torch.load("../ckpts/mobilenetv3.pth", weights_only=True))
random_input = torch.randn([1, 3, 256, 256])
output = mobilenetv3(random_input)
print(f"\nINPUT SHAPE: {random_input.shape}")
print(f"OUTPUT SHAPE: {output.shape}")


INPUT SHAPE: torch.Size([1, 3, 256, 256])
OUTPUT SHAPE: torch.Size([1, 68, 2])


In [10]:
mobilenetv3.eval()
sample_input = torch.randn(1, 3, 256, 256)

In [11]:
# Export the model to ONNX format
torch.onnx.export(
    mobilenetv3,                   # The model to be exported
    sample_input,            # The sample input tensor
    "../ckpts/mobilenetv3.onnx",            # The output file name
    export_params=True,      # Store the trained parameter weights inside the model file
    opset_version=17,        # The ONNX version to export the model to
    do_constant_folding=True,  # Whether to execute constant folding for optimization
    input_names=['input'],     # The model's input names
    output_names=['output'],   # The model's output names
)

In [12]:
import onnx
import onnxruntime

# Load the ONNX model
onnx_model = onnx.load("../ckpts/mobilenetv3.onnx")

# Check that the model is well-formed
onnx.checker.check_model(onnx_model)

# Run inference using ONNX Runtime
ort_session = onnxruntime.InferenceSession("../ckpts/mobilenetv3.onnx")

# Prepare the input
ort_inputs = {ort_session.get_inputs()[0].name: sample_input.numpy()}

# Run the model
ort_outs = ort_session.run(None, ort_inputs)

print("ONNX model output:", ort_outs)

ONNX model output: [array([[[-0.39168403, -0.18489534],
        [-0.37024066, -0.06099384],
        [-0.33947164,  0.06636669],
        [-0.29212415,  0.1807946 ],
        [-0.21712731,  0.26600364],
        [-0.14818606,  0.33422005],
        [-0.08506522,  0.38997275],
        [-0.03413018,  0.43978712],
        [ 0.03120015,  0.45437065],
        [ 0.10034149,  0.42534643],
        [ 0.16026662,  0.35515857],
        [ 0.2254807 ,  0.2908583 ],
        [ 0.28484723,  0.21786368],
        [ 0.345946  ,  0.13682531],
        [ 0.38384056,  0.03233783],
        [ 0.40671682, -0.07947797],
        [ 0.41877228, -0.19772974],
        [-0.3136756 , -0.29935202],
        [-0.2650919 , -0.32679853],
        [-0.21008322, -0.32956272],
        [-0.1523622 , -0.3112476 ],
        [-0.09870441, -0.28283286],
        [ 0.05771669, -0.26997918],
        [ 0.10144198, -0.29522717],
        [ 0.15228482, -0.30567175],
        [ 0.2023467 , -0.3050224 ],
        [ 0.2558356 , -0.28021204],
        

## Resnet18

In [13]:
class ResNet(nn.Module):
    def __init__(self, model_name: str = "resnet18", weights: str = "DEFAULT"):
        super().__init__()
        self.network = models.get_model(name="resnet18", weights=weights)
        self.network.fc = nn.Linear(self.network.fc.in_features, 136)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.network(x)
        x = x.reshape(x.size(0), 68, 2)
        return x

In [None]:
resnet18 = ResNet()
resnet18.load_state_dict(torch.load("../ckpts/resnet18.pth", weights_only=True))
resnet18.eval()

<All keys matched successfully>

In [16]:
# Export the model to ONNX format
torch.onnx.export(
    resnet18,                   # The model to be exported
    sample_input,            # The sample input tensor
    "../ckpts/resnet18.onnx",            # The output file name
    export_params=True,      # Store the trained parameter weights inside the model file
    opset_version=17,        # The ONNX version to export the model to
    do_constant_folding=True,  # Whether to execute constant folding for optimization
    input_names=['input'],     # The model's input names
    output_names=['output'],   # The model's output names
)

In [17]:
import onnx
import onnxruntime

# Load the ONNX model
onnx_model = onnx.load("../ckpts/resnet18.onnx")

# Check that the model is well-formed
onnx.checker.check_model(onnx_model)

# Run inference using ONNX Runtime
ort_session = onnxruntime.InferenceSession("../ckpts/resnet18.onnx")

# Prepare the input
ort_inputs = {ort_session.get_inputs()[0].name: sample_input.numpy()}

# Run the model
ort_outs = ort_session.run(None, ort_inputs)

print("ONNX model output:", ort_outs)

ONNX model output: [array([[[-0.45445412, -0.06765888],
        [-0.43737787,  0.0323271 ],
        [-0.40834022,  0.14212349],
        [-0.37059885,  0.24048012],
        [-0.30883074,  0.33653918],
        [-0.23823655,  0.41860116],
        [-0.1491706 ,  0.48789018],
        [-0.05199888,  0.52036834],
        [ 0.05197332,  0.5154389 ],
        [ 0.14668252,  0.4961621 ],
        [ 0.23877901,  0.4483102 ],
        [ 0.32559776,  0.38359955],
        [ 0.37981635,  0.30120146],
        [ 0.4051861 ,  0.18851948],
        [ 0.41418374,  0.07769438],
        [ 0.4147988 , -0.04413731],
        [ 0.40104443, -0.16021669],
        [-0.40172106, -0.2363964 ],
        [-0.34704244, -0.2708525 ],
        [-0.28175843, -0.30783334],
        [-0.20375592, -0.30759847],
        [-0.14086582, -0.29197752],
        [-0.03487684, -0.3070017 ],
        [ 0.03634944, -0.34451452],
        [ 0.1126644 , -0.3502441 ],
        [ 0.18725513, -0.32493007],
        [ 0.26198304, -0.2812723 ],
        