In [1]:
import time

import torch
import torch.onnx

from models.estimator import Estimator

# deterministic input data
torch.manual_seed(0)

<torch._C.Generator at 0x7ffaa6aa1df0>

Export pytorch to onnx model

In [2]:
import onnx
import onnxruntime as ort

print("onnx version:", onnx.__version__)
print("onnxruntime version:", ort.__version__)

onnx_version = onnx.__version__.split('.')[1]

onnx version: 1.12.0
onnxruntime version: 1.4.0


In [3]:
device = 'cuda'
batch_size = 4
total_samples = 1000

In [4]:
# input example necessary to export onnx model
torch_input = torch.randn(batch_size, 3, 512, 512, requires_grad=False).to(device)

model_path = "/home/ramon/Git/adroit/vision_foliage_density/weights/foliage_density_v3/density_model_reg.pth"

model = Estimator((512, 512), model_path).model

torch.onnx.export(
    model, # model being run
    torch_input, # model input (or a tuple for multiple inputs)
    "density.onnx", # where to save the model (can be a file or file-like object)
    opset_version=12,
    export_params=True, # store the trained parameter weights inside the model file
    input_names=['input'], # the model's input names
    output_names=['output']) # the model's output names
    #dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}) # variable length axes

try:
    # print human readable representation of the graph if exist
    print(onnx.helper.printable_graph(model.graph))
except AttributeError as error:
    print(error)

torch.cuda.empty_cache()

'VGGReg' object has no attribute 'graph'


Test latency inference output with ONNX Runtime

In [5]:
# turn tensor into numpy array 
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# create a onnx runtime session
session = ort.InferenceSession("density.onnx")

In [6]:
start = time.time()
for i in range(total_samples//batch_size):
    onnx_input = {"input": to_numpy(torch_input)}
    onnx_output = session.run(None, onnx_input)
end = time.time()

onnx_time = end - start
print(f"ONNX inference time = {onnx_time}")

ONNX inference time = 83.61464428901672
