In [1]:
import os, sys
from pathlib import Path
import pickle
import torch
import torch.onnx
import numpy as np

from TCN.mnist_pixel.utils import data_generator
from TCN.mnist_pixel.model import TCN

In [2]:
model_path = Path('./TCN/mnist_pixel/models')
data_path = Path('./TCN/mnist_pixel/data/mnist')
model_name = 'aug_k7l6'
batch_size = 1
in_channels = 1
n_classes = 10

args = pickle.load(open(model_path / (model_name+'_args.pkl'), 'rb'))
print(args)
channel_sizes = [args.nhid] * args.levels

print(model_path / (model_name+'.pt'))

_, test_loader = data_generator(data_path, batch_size)
model = TCN(in_channels, n_classes, channel_sizes, kernel_size=args.ksize, inference=True)
model.load_state_dict(torch.load(model_path / (model_name+'.pt')), strict=False)
model.eval()

print(model.receptive_field)

script_model = torch.jit.script(model)

if torch.cuda.is_available():
    model.cuda()
    script_model.cuda()


# model.set_fast_inference(batch_size)

Namespace(batch_size=64, clip=-1, cuda=True, dropout=0.05, epochs=50, ksize=7, levels=6, log_interval=100, lr=0.002, modelname='aug_k7l6', nhid=25, optim='Adam', permute=False, savedir=PosixPath('models'), savemodel=True, seed=-1, seq_augment=True)
TCN/mnist_pixel/models/aug_k7l6.pt
757


In [3]:
test_data, test_target = next(iter(test_loader))
test_data = test_data.view(test_data.size()[0], 1, -1).cuda()
test_input = test_data[:,:,0].view(test_data.size()[0], test_data.size()[1], 1)
# test_input = test_data
test_input.size()

torch.Size([1, 1, 1])

In [9]:
test_out = script_model(test_input)
test_out.shape

torch.Size([1, 10])

In [5]:
trace_model = torch.jit.trace(model, test_input)

With rtol=1e-05 and atol=1e-05, found 10 element(s) (out of 10) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.05174589157104492 (-2.222052574157715 vs. -2.2737984657287598), which occurred at index (0, 1).
  _module_class,


In [18]:
# export onnx model

torch.onnx.export(
    script_model,
    test_input,
    f'{model_path}/{model_name}.onnx',
    export_params=True,
#     do_constant_folding=True,
    keep_initializers_as_inputs=True,
    opset_version=12,
    input_names = ['input'],
    output_names = ['output'],
#     dynamic_axes={
#                  'input' : {0 : 'batch_size'}, 
#                  'output' : {0 : 'batch_size'}
#                  }
    example_outputs=test_out
)

RuntimeError: 
temporary: the only valid use of a module is looking up an attribute but found  = prim::SetAttr[name="cache"](%844, %cache_update.1)
:


In [11]:
import onnx
import onnxruntime
from onnx import helper, shape_inference

In [12]:
onnx_model = onnx.load(model_path / (model_name+'.onnx'))
onnx.checker.check_model(onnx_model)
# onnx.helper.printable_graph(onnx_model.graph)

In [13]:
inferred_model = shape_inference.infer_shapes(onnx_model)
onnx.checker.check_model(inferred_model)

In [14]:
onnx_path = model_path / (model_name+'.onnx')
ort_session = onnxruntime.InferenceSession(str(onnx_path))

In [15]:
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(test_input)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(test_out), ort_outs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")

AssertionError: 
Not equal to tolerance rtol=0.001, atol=1e-05

Mismatched elements: 9 / 10 (90%)
Max absolute difference: 0.53876376
Max relative difference: 0.26503953
 x: array([[-2.571531, -2.101018, -2.800727, -3.385211, -2.661237, -3.102572,
        -2.47572 , -3.681555, -0.921515, -2.469629]], dtype=float32)
 y: array([[-2.032768, -2.222053, -2.835331, -3.388276, -2.601045, -2.940132,
        -2.14889 , -3.810934, -1.097009, -2.680756]], dtype=float32)

In [11]:
print(test_out)
print(ort_outs[0])

tensor([[-2.0328, -2.2221, -2.8353, -3.3883, -2.6010, -2.9401, -2.1489, -3.8109,
         -1.0970, -2.6808]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
[[-2.010258  -2.2737985 -2.8612118 -3.4314961 -2.5628037 -2.965137
  -2.1034932 -3.8295255 -1.1005102 -2.6809928]]


In [16]:
import matplotlib.pyplot as plt
from IPython import display
import ipywidgets as widgets
from PIL import Image
import io
import cv2
import time

In [19]:
num_samples = 0
# fig, ax = plt.subplots(1,1)

im_queue = [0 for i in range(28*28)]
curr_im = np.array(im_queue, dtype=np.uint8).reshape((28,28))
_, encoded_image = cv2.imencode('.png', curr_im)
im_bytes = encoded_image.tobytes()
im_disp = widgets.Image(value=im_bytes, width=200, height=200)

true_val = widgets.Label(value=f'True Label: N/A')
pred_val = widgets.Label(value=f'Predicted Label: N/A')
label_disp = widgets.VBox((true_val, pred_val))

display.display(im_disp)
display.display(label_disp)

for data, target in test_loader:
    im = data.squeeze().cpu().detach().numpy()
    rows, cols = im.shape
    im = (im - im.min())
    im = (im/im.max() * 255).astype('uint8')
    curr_im = np.ones(im.shape, dtype=np.uint8)*255
    
    true_val.value = f'True Label: {target.item()}'

    if torch.cuda.is_available():
        data, target = data.cuda(), target.cuda()
    
    data = data.view(data.size()[0], 1, -1)
    
    for i in range(data.size()[2]):
        num_samples += 1
        curr_row = i // cols
        curr_col = i % cols
        
        im_queue.append(im[curr_row,curr_col])
        im_queue = im_queue[1:]
        
        curr_im = np.array(im_queue, dtype=np.uint8).reshape((28,28))
        _, encoded_image = cv2.imencode('.png', curr_im)
        im_bytes = encoded_image.tobytes()
        im_disp.value = im_bytes
        
#         inp = data[:,:,i].view(data.size()[0], data.size()[1], 1)
#         ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(inp)}
#         ort_outs = ort_session.run(None, ort_inputs)
    
        output = script_model(data[:,:,i].view(data.size()[0], data.size()[1], 1))
        pred_orig = output.max(1, keepdim=True)[1] #max returns values and indices
        
#         pred = ort_outs[0].argmax(axis=1)[0] #max returns values and indices
        
        if num_samples > 500:
            pred_val.value = f'Predicted Label: {pred_orig.item()}'
            
#         print(ort_outs[0])
#         print(output)
#         time.sleep(0.5)
        
    time.sleep(2)
    
    
    for i in np.zeros(np.random.randint(50, 200)):
        im_queue.append(0)
        im_queue = im_queue[1:]
        curr_im = np.array(im_queue, dtype=np.uint8).reshape((28,28)) 
        _, encoded_image = cv2.imencode('.png', curr_im)
        im_bytes = encoded_image.tobytes()
        im_disp.value = im_bytes
        true_val.value = f'True Label: N/A'
        num_samples += 1
        
#         inp = torch.tensor([i], dtype=torch.float).cuda().view(1, 1, 1)
#         ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(inp)}
#         ort_outs = ort_session.run(None, ort_inputs)
        
# #         output = model(torch.tensor([i], dtype=torch.float).cuda().view(1, 1, 1))
#         pred = ort_outs[0].argmax(axis=1)[0] #max returns values and indices
        if num_samples > model.receptive_field:
            pred_val.value = f'Predicted Label: {pred.item()}'
        

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x1c\x00\x00\x00\x1c\x08\x00\x00\x00\x00Wf\x80H\x…

VBox(children=(Label(value='True Label: N/A'), Label(value='Predicted Label: N/A')))

KeyboardInterrupt: 