In [1]:
import torch
import numpy as np
import pickle
import time
from pathlib import Path
from TCN.mnist_pixel.utils import data_generator, average_runtime
from TCN.mnist_pixel.model import TCN

import onnx
import onnxruntime as ort
from onnx import helper, shape_inference

from torch2trt import TRTModule

In [2]:
model_path = Path('./TCN/mnist_pixel/models')
trt_model_path = Path('./TCN/mnist_pixel/models_trt')
data_path = Path('./TCN/mnist_pixel/data/mnist')
model_name = 'aug_k7l6'
batch_size = 1
in_channels = 1
n_classes = 10

_, test_loader = data_generator(data_path, batch_size)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


# Torch Model

In [4]:
# Load Model

args = pickle.load(open(model_path / (model_name+'_args.pkl'), 'rb'))
print(args)
channel_sizes = [args.nhid] * args.levels

model = TCN(in_channels, n_classes, channel_sizes, kernel_size=args.ksize, trt=True, apply_max=True)
model.load_state_dict(torch.load(model_path / (model_name+'.pt')), strict=False)
model.eval()

if torch.cuda.is_available():
    model.cuda()

Namespace(batch_size=64, clip=-1, cuda=True, dropout=0.05, epochs=50, ksize=7, levels=6, log_interval=100, lr=0.002, modelname='aug_k7l6', nhid=25, optim='Adam', permute=False, savedir=PosixPath('models'), savemodel=True, seed=-1, seq_augment=True)


In [5]:
torch_runtime = average_runtime(model, num_samples=10000)
print(f'average runtime: {torch_runtime} ms')

Finished running 0 samples
Finished running 1000 samples
Finished running 2000 samples
Finished running 3000 samples
Finished running 4000 samples
Finished running 5000 samples
Finished running 6000 samples
Finished running 7000 samples
Finished running 8000 samples
Finished running 9000 samples
average runtime: 2.247088265419006 ms


# ONNX GPU Runtime

In [5]:
# Check Model
onnx_model = onnx.load(model_path / (model_name+'.onnx'))
onnx.checker.check_model(onnx_model)
inferred_model = shape_inference.infer_shapes(onnx_model)
onnx.checker.check_model(inferred_model)

# Load Session
onnx_path = model_path / (model_name+'.onnx')
ort_session = ort.InferenceSession(str(onnx_path))

In [6]:
num_samples = 10
runtimes = []

data_queue = torch.zeros((1,1,28*28))
data_queue = data_queue.cuda() if torch.cuda.is_available() else data_queue

for s, (data, target) in enumerate(test_loader):
    if s >= num_samples:
        break
    print(f'Running sample image {s+1}')
    
    if torch.cuda.is_available():
        data, target = data.cuda(), target.cuda()
    
    data = data.view(data.size()[0], 1, -1)
    
    for i in range(data.size()[2]):
        data_queue = data_queue.roll(-1, 2)
        data_queue[:,:,-1] = data[:,:,i]
        
        start = time.time()
        ort_inputs = {ort_session.get_inputs()[0].name: data_queue.cpu().numpy()}
        ort_outs = ort_session.run(None, ort_inputs)
        runtimes.append(time.time() - start)

onnx_runtimes = np.mean(runtimes)*1000
print(f'Average runtime per sample: {onnx_runtimes} ms')

Running sample image 1
Running sample image 2
Running sample image 3
Running sample image 4
Running sample image 5
Running sample image 6
Running sample image 7
Running sample image 8
Running sample image 9
Running sample image 10
Average runtime per sample: 1.212620218189395 ms


# TensorRT Models

## torch2trt using ONNX

In [3]:
# Load Model
model_trt_onnx = TRTModule()
model_trt_onnx.load_state_dict(torch.load(trt_model_path / (model_name+'_trt_onnx_amax.pt')))

<All keys matched successfully>

In [6]:
trt_onnx_runtime = average_runtime(model_trt_onnx, num_samples=10000)
print(f'average runtime: {trt_onnx_runtime} ms')

Finished running 0 samples
Finished running 1000 samples
Finished running 2000 samples
Finished running 3000 samples
Finished running 4000 samples
Finished running 5000 samples
Finished running 6000 samples
Finished running 7000 samples
Finished running 8000 samples
Finished running 9000 samples
average runtime: 0.4672294855117798 ms


## torch2trt using converters

In [7]:
# Load Model
model_trt_no_onnx = TRTModule()
model_trt_no_onnx.load_state_dict(torch.load(trt_model_path / (model_name+'_trt_amax.pt')))

<All keys matched successfully>

In [9]:
trt_no_onnx_runtime = average_runtime(model_trt_no_onnx, num_samples=10000)
print(f'average runtime: {trt_no_onnx_runtime} ms')

Finished running 0 samples
Finished running 1000 samples
Finished running 2000 samples
Finished running 3000 samples
Finished running 4000 samples
Finished running 5000 samples
Finished running 6000 samples
Finished running 7000 samples
Finished running 8000 samples
Finished running 9000 samples
average runtime: 0.48139767646789555 ms


## torch2trt with ONNX - FP16

In [10]:
# Load Model
model_trt_onnx_fp16 = TRTModule()
model_trt_onnx_fp16.load_state_dict(torch.load(trt_model_path / (model_name+'_trt_onnx_amax_fp16.pt')))

<All keys matched successfully>

In [11]:
trt_onnx_fp16_runtime = average_runtime(model_trt_onnx_fp16, num_samples=10000, fp16=True)
print(f'average runtime: {trt_onnx_fp16_runtime} ms')

Finished running 0 samples
Finished running 1000 samples
Finished running 2000 samples
Finished running 3000 samples
Finished running 4000 samples
Finished running 5000 samples
Finished running 6000 samples
Finished running 7000 samples
Finished running 8000 samples
Finished running 9000 samples
average runtime: 0.5889959812164306 ms


## torch2trt with converters - FP16

In [15]:
# Load Model
model_trt_no_onnx_fp16 = TRTModule()
model_trt_no_onnx_fp16.load_state_dict(torch.load(trt_model_path / (model_name+'_trt_amax_fp16.pt')))

<All keys matched successfully>

In [16]:
trt_no_onnx_fp16_runtime = average_runtime(model_trt_no_onnx_fp16, num_samples=10000, fp16=True)
print(f'average runtime: {trt_no_onnx_fp16_runtime} ms')

Finished running 0 samples
Finished running 1000 samples
Finished running 2000 samples
Finished running 3000 samples
Finished running 4000 samples
Finished running 5000 samples
Finished running 6000 samples
Finished running 7000 samples
Finished running 8000 samples
Finished running 9000 samples
average runtime: 0.47009398937225344 ms


In [25]:
print(f'TensorRT (using onnx) Runtime is {torch_runtime/trt_onnx_runtime} times faster than regular PyTorch')
print(f'TensorRT (no onnx) Runtime is {torch_runtime/trt_no_onnx_runtime} times faster than regular PyTorch')

print(f'TensorRT (using onnx) with fp16 mode is {trt_onnx_runtime/ trt_onnx_fp16_runtime} times faster than full precision')
print(f'TensorRT (no onnx) with fp16 mode is {trt_no_onnx_runtime/ trt_no_onnx_fp16_runtime} times faster than full precision')


TensorRT (using onnx) Runtime is 4.409675921520733 times faster than regular PyTorch
TensorRT (no onnx) Runtime is 3.4072912494022343 times faster than regular PyTorch
TensorRT (using onnx) with fp16 mode is 0.9841077976509873 times faster than full precision
TensorRT (no onnx) with fp16 mode is 1.267474802359857 times faster than full precision
