# Part 1: pytorch2onnx

### Load the dataset sample
The sample data provided in the repo was generated by `train.py`.

In [12]:
import numpy as np

# Directory containing the saved test arrays (X_test.npy / Y_test.npy).
DATA_PATH = "../data/MNIST/"

# Read both arrays in one pass; each file is resolved as "<DATA_PATH>/<stem>.npy".
X_test, Y_test = (
    np.load(f"{DATA_PATH}/{stem}.npy") for stem in ("X_test", "Y_test")
)

# Show the array shapes as the cell output.
X_test.shape, Y_test.shape

((10000, 1, 28, 28), (10000,))

### Convert to onnx using PyTorch onnx library
The sample model provided in the repo was generated by `train.py`.

In [13]:
import torch

PKL_PATH = "../models/lenet_mnist.pkl"
ONNX_PATH = "../models/lenet_mnist.onnx"

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints that come from a trusted source.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True and
# will refuse a fully pickled model; pass weights_only=False there if needed.
# map_location="cpu" keeps the model on the same device as the CPU tensor built
# from X_test below, even when the checkpoint was saved from a GPU.
model = torch.load(PKL_PATH, map_location="cpu")
model.eval()  # switch to evaluation mode before exporting

# A single sample is enough to trace the model: axis 0 is declared dynamic
# below, so the exported graph accepts any batch size. Passing the full test
# set would only make the export slower and more memory hungry.
dummy_input = torch.as_tensor(X_test[:1], dtype=torch.float32)

# ONNX export
torch.onnx.export(
  model,                            # model being run
  dummy_input,                      # model input (or a tuple for multiple inputs)
  ONNX_PATH,                        # where to save the model
  export_params=True,               # store the trained parameter weights in model file
  do_constant_folding=True,         # whether to execute constant folding for optimization
  input_names = ["input"],          # model's input names
  output_names = ["output"],        # model's output names
  dynamic_axes={"input" : {0 : "batch_size"},    # variable length axes
                "output" : {0 : "batch_size"}})
print("Converted to onnx")

verbose: False, log level: Level.ERROR

Converted to onnx


### Compare ONNX Runtime and PyTorch results

`np.testing.assert_allclose(actual, desired, rtol, atol)` raises an `AssertionError` unless the following holds element-wise. See https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html for details.
* absolute(actual - desired) <= atol + rtol * absolute(desired)

In [14]:
import onnxruntime

# Compare both runtimes on one sample; X_test[[0]] keeps the leading batch axis.
input_tensor = torch.as_tensor(X_test[[0]], dtype=torch.float32)

# Call the module itself (not .forward) so registered hooks run, and skip
# autograd bookkeeping since this is inference only.
with torch.no_grad():
    y_torch = model(input_tensor)

# Name the execution provider explicitly: onnxruntime builds that ship with
# non-CPU providers require it, and the PyTorch reference result is computed
# from a CPU input tensor anyway.
ort_session = onnxruntime.InferenceSession(ONNX_PATH, providers=["CPUExecutionProvider"])
ort_inputs = {ort_session.get_inputs()[0].name: input_tensor.numpy()}
y_onnx = ort_session.run(None, ort_inputs)[0]  # first (and only named) output

# Passes when |y_torch - y_onnx| <= atol + rtol * |y_onnx| element-wise.
np.testing.assert_allclose(y_torch.numpy(), y_onnx, rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")

Exported model has been tested with ONNXRuntime, and the result looks good!
