In [2]:
import os, sys
from pathlib import Path
import pickle
import torch
import torch.onnx
import numpy as np

from TCN.mnist_pixel.utils import data_generator
from TCN.mnist_pixel.model import TCN

In [3]:
model_path = Path('./TCN/mnist_pixel/models')
data_path = Path('./TCN/mnist_pixel/data/mnist')
model_name = 'aug_k7l6'
batch_size = 1
in_channels = 1
n_classes = 10

args = pickle.load(open(model_path / (model_name+'_args.pkl'), 'rb'))
print(args)
channel_sizes = [args.nhid] * args.levels

print(model_path / (model_name+'.pt'))

_, test_loader = data_generator(data_path, batch_size)
model = TCN(in_channels, n_classes, channel_sizes, kernel_size=args.ksize)
model.load_state_dict(torch.load(model_path / (model_name+'.pt')), strict=False)
model.eval()

print(model.receptive_field)

if torch.cuda.is_available():
    model.cuda()
    
# model.set_fast_inference(batch_size)

Namespace(batch_size=64, clip=-1, cuda=True, dropout=0.05, epochs=50, ksize=7, levels=6, log_interval=100, lr=0.002, modelname='aug_k7l6', nhid=25, optim='Adam', permute=False, savedir=PosixPath('models'), savemodel=True, seed=-1, seq_augment=True)
TCN/mnist_pixel/models/aug_k7l6.pt
757


In [4]:
test_data, test_target = next(iter(test_loader))
test_data = test_data.view(test_data.size()[0], 1, -1).cuda()
test_input = test_data[:,:,0].view(test_data.size()[0], test_data.size()[1], 1)
# test_input = test_data
test_input.size()

torch.Size([1, 1, 1])

In [5]:
test_out = model(test_input)
test_out.shape

torch.Size([1, 10])

In [6]:
# export onnx model

torch.onnx.export(
    model,
    test_input,
    f'{model_path}/{model_name}.onnx',
    export_params=True,
#     do_constant_folding=True,
    keep_initializers_as_inputs=True,
    opset_version=10,
    input_names = ['input'],
    output_names = ['output'],
#     dynamic_axes={
#                  'input' : {0 : 'batch_size'}, 
#                  'output' : {0 : 'batch_size'}
#                  }
)

In [7]:
import onnx
import onnxruntime
from onnx import helper, shape_inference

In [8]:
onnx_model = onnx.load(model_path / (model_name+'.onnx'))
onnx.checker.check_model(onnx_model)
# onnx.helper.printable_graph(onnx_model.graph)

In [9]:
inferred_model = shape_inference.infer_shapes(onnx_model)
onnx.checker.check_model(inferred_model)

In [10]:
onnx_path = model_path / (model_name+'.onnx')
ort_session = onnxruntime.InferenceSession(str(onnx_path))

In [11]:
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(test_input)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(test_out), ort_outs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")

Exported model has been tested with ONNXRuntime, and the result looks good!


In [12]:
print(test_out)
print(ort_outs[0])

tensor([[-2.0328, -2.2221, -2.8353, -3.3883, -2.6010, -2.9401, -2.1489, -3.8109,
         -1.0970, -2.6808]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
[[-2.0327675 -2.2220526 -2.835331  -3.3882763 -2.6010454 -2.9401317
  -2.1488905 -3.8109334 -1.0970094 -2.680756 ]]


In [22]:
dir(onnx_model)

['ByteSize',
 'Clear',
 'ClearExtension',
 'ClearField',
 'CopyFrom',
 'DESCRIPTOR',
 'DiscardUnknownFields',
 'Extensions',
 'FindInitializationErrors',
 'FromString',
 'HasExtension',
 'HasField',
 'IsInitialized',
 'ListFields',
 'MergeFrom',
 'MergeFromString',
 'ParseFromString',
 'RegisterExtension',
 'SerializePartialToString',
 'SerializeToString',
 'SetInParent',
 'UnknownFields',
 'WhichOneof',
 '_CheckCalledFromGeneratedFile',
 '_SetListener',
 '__class__',
 '__deepcopy__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__unicode__',
 '_extensions_by_name',
 '_extensions_by_number',
 'doc_string',
 'domain',
 'graph',
 'ir_version',
 'metadata_props',