In [1]:
%matplotlib inline
import torch
import torch.nn.functional as F
import os, sys
from pathlib import Path
from TCN.mnist_pixel.utils import data_generator
from TCN.mnist_pixel.model import TCN
import numpy as np
import matplotlib.pyplot as plt
import pickle
from IPython import display
import ipywidgets as widgets
from PIL import Image
import io
import cv2
import time

from torch2trt import torch2trt

In [2]:
# Paths and fixed settings for the pixel-MNIST TCN checkpoint being deployed.
model_path = Path('./TCN/mnist_pixel/models')
data_path = Path('./TCN/mnist_pixel/data/mnist')
model_name = 'aug_k7l6'
batch_size = 1
in_channels = 1
n_classes = 10

# Training-time hyper-parameters (an argparse Namespace) pickled next to the weights.
# NOTE: pickle.load can execute arbitrary code -- only load files you produced yourself.
# The `with` block closes the file handle (the bare open() call leaked it).
with open(model_path / (model_name + '_args.pkl'), 'rb') as args_file:
    args = pickle.load(args_file)
print(args)
channel_sizes = [args.nhid] * args.levels

weights_path = model_path / (model_name + '.pt')
print(weights_path)

_, test_loader = data_generator(data_path, batch_size)
model = TCN(in_channels, n_classes, channel_sizes, kernel_size=args.ksize, apply_max=True)
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine; the
# module is moved to the GPU below when one is available.
# strict=False tolerates key mismatches, so report them instead of silently running
# with randomly initialised weights.
incompatible = model.load_state_dict(torch.load(weights_path, map_location='cpu'), strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print('state_dict mismatch:', incompatible)
model.eval()

print(model.receptive_field)

if torch.cuda.is_available():
    model.cuda()

Namespace(batch_size=64, clip=-1, cuda=True, dropout=0.05, epochs=50, ksize=7, levels=6, log_interval=100, lr=0.002, modelname='aug_k7l6', nhid=25, optim='Adam', permute=False, savedir=PosixPath('models'), savemodel=True, seed=-1, seq_augment=True)
TCN/mnist_pixel/models/aug_k7l6.pt


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


757


In [3]:
# Random sample input of shape (1, 1, 784), i.e. one flattened 28x28 image.
# It is also the example input that torch2trt traces in the conversion cells.
x = torch.rand(1, 1, 28 * 28).cuda()
model(x)

tensor([0], device='cuda:0')

In [4]:
# ONNX-based TensorRT conversion of `model`, traced with the sample input `x`.
# The commented lines are the FP32 recipe; they write '<model_name>_trt.pt',
# which is the file the TRTModule reload cell further down loads.
# model_traced = torch.jit.trace(model, x)
# model_trt = torch2trt(model, [x],  strict_type_constraints=True, use_onnx=True)
# torch.save(model_trt.state_dict(), model_path / (model_name+'_trt.pt'))

# torch2trt's FP16 switch is `fp16_mode`. An `fp16=` keyword is not a torch2trt
# option: it ends up in **kwargs and the engine is silently built in FP32 even
# though the saved file name says "fp16".
model_trt = torch2trt(model, [x], strict_type_constraints=True, use_onnx=True, fp16_mode=True)
torch.save(model_trt.state_dict(), model_path / (model_name+'_trt_fp16.pt'))

In [7]:
# import tensorrt as trt
# from torch2trt import tensorrt_converter
# from torch2trt.torch2trt import add_missing_trt_tensors

# @tensorrt_converter("torch.Tensor.topk")
# @tensorrt_converter("torch.topk")
# def convert_topk(ctx):
#     input_tensor = ctx.method_args[0]
#     # The K in "topk"
#     top_k = ctx.method_args[1]
#     output_val = ctx.method_return[0]
#     output_idx = ctx.method_return[1]

#     # Handle optional axis argument
#     has_axis = len(ctx.method_args) > 2
#     if has_axis:
#         axis = ctx.method_args[2]
#     else:
#         axis = 0

#     # input_trt = broadcast_trt_tensors(ctx.network, [input_trt], len(output[0].shape) - 1)
#     # input_layer = ctx.get
#     layer = ctx.network.add_topk(input_tensor._trt, trt.TopKOperation.MAX, top_k, axis)
   
#     # Has two outputs: (Tensor, LongTensor). 
#     # LongTensor is the corresponding indices to the topk operation
#     output_val._trt, output_idx._trt = layer.get_output(0), layer.get_output(1)


import tensorrt as trt
from torch2trt import tensorrt_converter
from torch2trt.torch2trt import add_missing_trt_tensors


@tensorrt_converter('TCN.mnist_pixel.model.trt_argmax')
def convert_argmax(ctx):
    """torch2trt converter mapping ``trt_argmax`` onto a TensorRT TopK(MAX, k=1) layer.

    The index output of TopK with k=1 is the argmax. TopK keeps the reduced axis
    with size 1.
    NOTE(review): confirm this matches the shape ``trt_argmax`` returns in PyTorch.

    Axis handling assumes torch2trt's implicit-batch mode. In that mode TensorRT
    tensors omit the batch dimension, so torch dim ``d`` maps to TensorRT axis
    ``d - 1``. For dims 1 and 2 this produces the same bitmask the previous code
    passed.

    Raises:
        ValueError: if asked to reduce over the (implicit) batch dimension.
    """
    input_tensor = ctx.method_args[0]
    # `dim` may arrive by keyword or positionally; when absent, reduce over the flattened input.
    if 'dim' in ctx.method_kwargs:
        dim = ctx.method_kwargs['dim']
    elif len(ctx.method_args) > 1:
        dim = ctx.method_args[1]
    else:
        dim = None
    output = ctx.method_return

    input_trt = add_missing_trt_tensors(ctx.network, [input_tensor])[0]

    if dim is None:
        # Flatten inside the TensorRT graph. Calling torch.flatten here would give a
        # plain torch tensor with no TRT counterpart. add_missing_trt_tensors would
        # then embed it as a constant, freezing the sample input into the engine.
        shuffle = ctx.network.add_shuffle(input_trt)
        shuffle.reshape_dims = (-1,)  # per-sample flatten (batch dim is implicit)
        input_trt = shuffle.get_output(0)
        trt_axis = 0
    else:
        if dim < 0:
            dim += input_tensor.dim()
        trt_axis = dim - 1  # drop the implicit batch dimension
        if trt_axis < 0:
            raise ValueError('argmax over the batch dimension is not supported')

    # add_topk takes a *bitmask* of axes, not an axis index.
    layer = ctx.network.add_topk(input_trt, trt.TopKOperation.MAX, 1, 1 << trt_axis)

    # layer.get_output(0) would give the max value; output 1 holds its index.
    output._trt = layer.get_output(1)

In [8]:
# Native (non-ONNX) torch2trt conversion of a second copy of the network.
# `trt=True` presumably routes the model's argmax through `trt_argmax`, the op the
# custom converter above is registered for -- verify in TCN.mnist_pixel.model.
model2 = TCN(in_channels, n_classes, channel_sizes, kernel_size=args.ksize, trt=True, apply_max=True)
# map_location='cpu' lets a GPU-saved checkpoint load anywhere; moved to GPU below.
model2.load_state_dict(torch.load(model_path / (model_name+'.pt'), map_location='cpu'), strict=False)
model2.eval()

if torch.cuda.is_available():
    model2.cuda()

# torch2trt's FP16 switch is `fp16_mode`. An `fp16=` keyword is swallowed by
# **kwargs and the engine would silently be built in FP32.
model_trt_no_onnx = torch2trt(model2, [x], fp16_mode=True)
torch.save(model_trt_no_onnx.state_dict(), model_path / (model_name+'_trt_no_onnx_fp16.pt'))


In [1]:
# Reload a serialized TensorRT engine into a TRTModule.
# Depends on `torch`, `model_path` and `model_name` from the setup cells; running it
# in a fresh kernel is what raised the NameError recorded below.
# NOTE(review): '<model_name>_trt.pt' is only written by a commented-out line in the
# ONNX conversion cell; the live conversions write '..._trt_fp16.pt' and
# '..._trt_no_onnx_fp16.pt'.
from torch2trt import TRTModule

trt_state_path = model_path / f'{model_name}_trt.pt'
model_trt = TRTModule()
model_trt.load_state_dict(torch.load(trt_state_path))

NameError: name 'torch' is not defined

In [5]:
# Run the TensorRT engine held in `model_trt` on the sample input `x`. With apply_max=True
# the int32 result is presumably the predicted class index -- confirm in TCN.
model_trt(x)

tensor([[0]], device='cuda:0', dtype=torch.int32)

In [9]:
# Live demo: stream MNIST test digits into the TensorRT model one pixel at a time.
# The widget shows the sliding 28x28 window the model currently sees and the model's
# current prediction.
from collections import deque

SEQ_LEN = 28 * 28        # window length = one flattened 28x28 image
PAD_RANGE = (50, 200)    # [low, high) number of blank samples streamed between digits
FRAME_DELAY_S = 0.001    # pause after each pixel so the widget can redraw

num_samples = 0  # total samples streamed so far (not reset between digits)

# Sliding windows over the last SEQ_LEN samples: uint8 pixels for display, raw values
# for the model. deque(maxlen=...) drops the oldest entry on append in O(1) instead of
# re-slicing a 784-element list for every pixel.
im_queue = deque([0] * SEQ_LEN, maxlen=SEQ_LEN)
data_queue = torch.zeros((1, 1, SEQ_LEN)).cuda()


def _window_png(queue):
    """Encode the display window as PNG bytes for the Image widget."""
    frame = np.fromiter(queue, dtype=np.uint8, count=SEQ_LEN).reshape((28, 28))
    _, encoded = cv2.imencode('.png', frame)
    return encoded.tobytes()


def _stream_sample(display_px, model_px, n_seen):
    """Push one sample into both windows, redraw, run the model, refresh the label.

    The prediction label is only updated once more than `model.receptive_field`
    samples have been streamed in total.
    """
    im_queue.append(display_px)
    # Shift left by one in place (the global tensor object is kept) and append.
    data_queue.copy_(data_queue.roll(-1, 2))
    data_queue[:, :, -1] = model_px
    im_disp.value = _window_png(im_queue)

    pred = model_trt_no_onnx(data_queue)
    if n_seen > model.receptive_field:
        pred_val.value = f'Predicted Label: {pred.item()}'


im_disp = widgets.Image(value=_window_png(im_queue), width=200, height=200)
true_val = widgets.Label(value='True Label: N/A')
pred_val = widgets.Label(value='Predicted Label: N/A')
label_disp = widgets.VBox((true_val, pred_val))

display.display(im_disp)
display.display(label_disp)

for data, target in test_loader:
    im = data.squeeze().cpu().detach().numpy()
    _, cols = im.shape
    # Rescale to 0..255 for display; guard against a constant image (max == 0),
    # which would otherwise divide by zero and produce NaN pixels.
    im = im - im.min()
    peak = im.max()
    if peak > 0:
        im = im / peak * 255
    im = im.astype('uint8')

    true_val.value = f'True Label: {target.item()}'

    if torch.cuda.is_available():
        data, target = data.cuda(), target.cuda()
    data = data.view(data.size(0), 1, -1)

    for i in range(data.size(2)):
        num_samples += 1
        _stream_sample(im[i // cols, i % cols], data[:, :, i], num_samples)
        time.sleep(FRAME_DELAY_S)

    # Blank gap between digits; there is no true label while it streams.
    true_val.value = 'True Label: N/A'
    for _ in range(np.random.randint(*PAD_RANGE)):
        num_samples += 1
        _stream_sample(0, 0.0, num_samples)

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x1c\x00\x00\x00\x1c\x08\x00\x00\x00\x00Wf\x80H\x…

VBox(children=(Label(value='True Label: N/A'), Label(value='Predicted Label: N/A')))

KeyboardInterrupt: 

In [5]:
# Deterministic all-ones probe of shape (1, 1, 784); note this rebinds `x`.
# NOTE(review): the recorded output below is a float score vector, unlike the int32
# output of `model_trt` shown earlier -- it likely comes from an older session.
x = torch.ones(1, 1, 28 * 28).cuda()
model_trt(x)

tensor([[[ -0.1782, -12.2473,  -4.7862,  -8.8309,  -8.3275,  -8.6724,  -4.0078,
          -10.8096,  -1.9996,  -7.1733]]], device='cuda:0')

In [2]:
from TCN.tcn_trt import TCNTRTNet

In [3]:
# Hyper-parameters for the TCNTRTNet alternative kept in the comment below; the live
# line builds TCN with its default constructor arguments plus trt=True.
num_channels = [25] * 8
kernel_size = 7
# test_model = TCNTRTNet(1, num_channels, kernel_size=kernel_size).eval().cuda()
test_model = TCN(trt=True)
test_model = test_model.eval().cuda()

In [4]:
# Fresh random (1, 1, 784) input; also the sample input for the conversion below.
x = torch.rand(1, 1, 28 * 28).cuda()
test_model(x)

tensor([[0.0968, 0.0824, 0.1307, 0.1124, 0.1008, 0.0800, 0.1120, 0.0733, 0.1091,
         0.1024]], device='cuda:0', grad_fn=<SoftmaxBackward>)

In [5]:
# Convert test_model to TensorRT, tracing with the sample input `x`. No FP16 flag is
# passed, so the engine uses torch2trt's default precision.
sample_inputs = [x]
test_model_trt = torch2trt(test_model, sample_inputs, strict_type_constraints=True)

In [6]:
# Run the converted engine on the same `x`; compare with the PyTorch output of
# `test_model(x)` recorded a few cells above.
test_model_trt(x)

tensor([[0.0968, 0.0824, 0.1307, 0.1124, 0.1008, 0.0800, 0.1120, 0.0733, 0.1091,
         0.1024]], device='cuda:0')