In [1]:
import torch
from TCN.mnist_pixel.utils import data_generator
from TCN.mnist_pixel.model import TCN
import numpy as np
import pickle
import time
from pathlib import Path

In [2]:
model_path = Path('./TCN/mnist_pixel/models')
data_path = Path('./TCN/mnist_pixel/data/mnist')
model_name = 'aug_k7l6'
batch_size = 1
in_channels = 1
n_classes = 10

# Torch Model

In [5]:
# Load Model

args = pickle.load(open(model_path / (model_name+'_args.pkl'), 'rb'))
print(args)
channel_sizes = [args.nhid] * args.levels

_, test_loader = data_generator(data_path, batch_size)
model = TCN(in_channels, n_classes, channel_sizes, kernel_size=args.ksize, trt=True)
model.load_state_dict(torch.load(model_path / (model_name+'.pt')), strict=False)
model.eval()

if torch.cuda.is_available():
    model.cuda()

Namespace(batch_size=64, clip=-1, cuda=True, dropout=0.05, epochs=50, ksize=7, levels=6, log_interval=100, lr=0.002, modelname='aug_k7l6', nhid=25, optim='Adam', permute=False, savedir=PosixPath('models'), savemodel=True, seed=-1, seq_augment=True)


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [24]:
num_samples = 10
runtimes = []

data_queue = torch.zeros((1,1,28*28))
data_queue = data_queue.cuda() if torch.cuda.is_available() else data_queue

for s, (data, target) in enumerate(test_loader):
    if s >= num_samples:
        break
    print(f'Running sample image {s+1}')
    
    if torch.cuda.is_available():
        data, target = data.cuda(), target.cuda()
    
    data = data.view(data.size()[0], 1, -1)
    
    for i in range(data.size()[2]):
        data_queue = data_queue.roll(-1, 2)
        data_queue[:,:,-1] = data[:,:,i]
        
        start = time.time()
        output = model(data_queue)
        pred = output.max(1, keepdim=True)[1]
        runtimes.append(time.time() - start)
        
print(f'Average runtime per sample: {np.mean(runtimes)*1000} ms')

Running sample image 1
Running sample image 2
Running sample image 3
Running sample image 4
Running sample image 5
Running sample image 6
Running sample image 7
Running sample image 8
Running sample image 9
Running sample image 10
Average runtime per sample: 2.2639483213424683 ms


# ONNX GPU Runtime

In [18]:
import onnx
import onnxruntime as ort
from onnx import helper, shape_inference

In [19]:
# Check Model
onnx_model = onnx.load(model_path / (model_name+'.onnx'))
onnx.checker.check_model(onnx_model)
inferred_model = shape_inference.infer_shapes(onnx_model)
onnx.checker.check_model(inferred_model)

# Load Session
onnx_path = model_path / (model_name+'.onnx')
ort_session = ort.InferenceSession(str(onnx_path))

In [25]:
num_samples = 10
runtimes = []

data_queue = torch.zeros((1,1,28*28))
data_queue = data_queue.cuda() if torch.cuda.is_available() else data_queue

for s, (data, target) in enumerate(test_loader):
    if s >= num_samples:
        break
    print(f'Running sample image {s+1}')
    
    if torch.cuda.is_available():
        data, target = data.cuda(), target.cuda()
    
    data = data.view(data.size()[0], 1, -1)
    
    for i in range(data.size()[2]):
        data_queue = data_queue.roll(-1, 2)
        data_queue[:,:,-1] = data[:,:,i]
        
        start = time.time()
        ort_inputs = {ort_session.get_inputs()[0].name: data_queue.cpu().numpy()}
        ort_outs = ort_session.run(None, ort_inputs)
        pred = ort_outs[0].argmax(axis=1)[0]
        runtimes.append(time.time() - start)
        
print(f'Average runtime per sample: {np.mean(runtimes)*1000} ms')

Running sample image 1
Running sample image 2
Running sample image 3
Running sample image 4
Running sample image 5
Running sample image 6
Running sample image 7
Running sample image 8
Running sample image 9
Running sample image 10
Average runtime per sample: 0.4884302920224715 ms


# TensorRT Model

In [3]:
from torch2trt import TRTModule

In [4]:
# Load Model
model_trt = TRTModule()
model_trt.load_state_dict(torch.load(model_path / (model_name+'_trt.pt')))

<All keys matched successfully>