In [None]:
# Enable the IPython autoreload extension in mode 2, so imported modules are
# re-imported before each cell runs and edits to them are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every top-level expression in a cell, not only the last
# one. This is why later cells assign unwanted return values to `_`.
InteractiveShell.ast_node_interactivity = 'all'

import warnings
# Silence all Python warnings for the rest of the session.
warnings.simplefilter('ignore')

import gc

from os import path
import sys
# Add the parent directory to the import path (presumably so that the `src`
# package imported below can be found - confirm against the repo layout).
sys.path.append(path.abspath('..'))

In [None]:
import numpy as np
import onnx
import torch
import cv2
from PIL import Image
import tensorrt as trt
import matplotlib.pyplot as plt
from timm import create_model

import pycuda.autoinit

from src.transforms import torch_preprocessing, trt_preprocessing
import src.common as common

In [None]:
def softmax(x, axis=None):
    """Numerically stable softmax.

    Args:
        x: array-like of logits (any shape).
        axis: axis (or tuple of axes) along which to normalise. The default,
            ``None``, normalises over *all* elements, which matches the
            original single-vector behaviour. Pass ``axis=-1`` to get an
            independent distribution per row of a batch of logits.

    Returns:
        ``numpy.ndarray`` with the same shape as ``x`` whose entries sum to 1
        along ``axis``.
    """
    x = np.asarray(x)
    # Subtracting the maximum keeps np.exp from overflowing for large logits;
    # keepdims=True keeps the reduced axes so the result broadcasts against x.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_shifted = np.exp(shifted)
    return exp_shifted / exp_shifted.sum(axis=axis, keepdims=True)

In [None]:
# Device that the PyTorch model and its input tensors are moved to.
DEVICE = 'cuda:0'
# Destination of the PyTorch checkpoint written with torch.save below.
TORCH_FILE = '../models/gernet_l.pth'

# Artifacts for the fixed-batch (static shape) pipeline:
# the ONNX export plus the fp32 / fp16 / int8 TensorRT engine files.
ONNX_FILE_STATIC = '../models/gernet_l_static.onnx'
TRT_FILE_STATIC = '../models/gernet_l_static.engine'
TRT_FILE_FP16_STATIC = '../models/gernet_l_fp16_static.engine'
TRT_FILE_INT8_STATIC = '../models/gernet_l_int8_static.engine'

# Artifacts for the dynamic-batch pipeline (same layout as above).
ONNX_FILE_DYNAMIC = '../models/gernet_l_dynamic.onnx'
TRT_FILE_DYNAMIC = '../models/gernet_l_dynamic.engine'
TRT_FILE_FP16_DYNAMIC = '../models/gernet_l_fp16_dynamic.engine'
TRT_FILE_INT8_DYNAMIC = '../models/gernet_l_int8_dynamic.engine'

In [None]:
# Create the pretrained `gernet_l` model, move it to DEVICE and switch it to
# eval mode. `.to()` and `.eval()` both return the module itself, so chaining
# them leaves `model` bound to the same object and displays nothing.
model = create_model('gernet_l', pretrained=True).to(DEVICE).eval()

In [None]:
# Load the sample image. cv2.imread does not raise when the file is missing or
# unreadable - it returns None - so check explicitly instead of failing later
# with "'NoneType' object is not subscriptable".
image = cv2.imread('../data/dog.jpg')
if image is None:
    raise FileNotFoundError("Could not read image '../data/dog.jpg'")
# OpenCV decodes to BGR channel order; reverse the last axis to get RGB.
image = image[..., ::-1]
print(image.shape)
Image.fromarray(image)

In [None]:
# Run the project's PyTorch preprocessing on the RGB image, then move the
# result to DEVICE.
torch_input_tensor = torch_preprocessing(image)
torch_input_tensor = torch_input_tensor.to(DEVICE)

In [None]:
# Forward pass without autograd bookkeeping.
with torch.no_grad():
    torch_output_tensor = model(torch_input_tensor)
# Bring the result to host memory as a NumPy array and keep the first item.
torch_output_tensor = torch_output_tensor.detach().cpu().numpy()[0]
# Index of the highest-probability entry.
print(np.argmax(softmax(torch_output_tensor)))

In [None]:
# Save the PyTorch checkpoint. The module object itself (not
# model.state_dict()) is passed to torch.save, so the file holds the pickled
# full model.
torch.save(model, TORCH_FILE)

## Статический размер батча

### Torch -> ONNX

In [None]:
# Export the model to a static-shape ONNX file, traced with a single random
# 1x3x224x224 input living on DEVICE.
dummy_input = torch.rand((1, 3, 224, 224), device=DEVICE)
torch.onnx.export(
    model, dummy_input, ONNX_FILE_STATIC,
    input_names=['input'],
    output_names=['output'],
    verbose=True,
)

### Check ONNX

In [None]:
# Load the exported static-shape ONNX file, run the ONNX checker on it and
# print a text rendering of its graph.
onnx_model = onnx.load(ONNX_FILE_STATIC)
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))

### ONNX -> TensorRT

In [None]:
# FP32 engine from the static ONNX file: no precision flags are passed.
engine = common.build_engine(ONNX_FILE_STATIC, TRT_FILE_STATIC)

In [None]:
# FP16 engine from the static ONNX file.
engine = common.build_engine(ONNX_FILE_STATIC, TRT_FILE_FP16_STATIC, fp16=True)

In [None]:
# Load and prepare the image used for INT8 calibration. cv2.imread returns
# None (it does not raise) when the file is missing or unreadable, so fail
# early with a clear message.
image = cv2.imread('../data/cat.jpeg')
if image is None:
    raise FileNotFoundError("Could not read image '../data/cat.jpeg'")
# OpenCV decodes to BGR channel order; reverse the last axis to get RGB.
image = image[..., ::-1]
trt_input_tensor = trt_preprocessing(image)
print(image.shape)
Image.fromarray(image)

In [None]:
# Create the calibrator from the preprocessed calibration image.
# NOTE(review): the second argument looks like the calibration cache path -
# confirm in src/common.py.
calibrator = common.EntropyCalibrator(trt_input_tensor, '../models/calibrator')

# INT8 engine from the static ONNX file, calibrated with the object above.
engine = common.build_engine(
    ONNX_FILE_STATIC, TRT_FILE_INT8_STATIC,
    int8_calibrator=calibrator,
    int8=True,
)

## Динамический размер батча

### Torch -> ONNX

In [None]:
# Export the model to ONNX again, this time marking axis 0 of both the input
# and the output as dynamic (variable batch size).
dummy_input = torch.rand((1, 3, 224, 224), device=DEVICE)
torch.onnx.export(
    model, dummy_input, ONNX_FILE_DYNAMIC,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': [0], 'output': [0]},  # dynamic batch
    verbose=True,
)

### Check ONNX

In [None]:
# Load the exported dynamic-batch ONNX file, run the ONNX checker on it and
# print a text rendering of its graph.
onnx_model = onnx.load(ONNX_FILE_DYNAMIC)
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))

### ONNX -> TensorRT

In [None]:
# FP32 engine from the dynamic ONNX file. Batch sizes 1..5 are declared, with
# the opt shape equal to the max shape (batch 5).
# NOTE(review): the shape arguments are presumably forwarded to a TensorRT
# optimization profile - confirm in src/common.py.
engine = common.build_engine(
    ONNX_FILE_DYNAMIC, TRT_FILE_DYNAMIC,
    min_shape=(1, 3, 224, 224),
    opt_shape=(5, 3, 224, 224),
    max_shape=(5, 3, 224, 224),
    max_batch_size=5,
)

In [None]:
# FP16 engine from the dynamic ONNX file, with the same 1..5 batch range as
# the FP32 build.
engine = common.build_engine(
    ONNX_FILE_DYNAMIC, TRT_FILE_FP16_DYNAMIC,
    fp16=True,
    min_shape=(1, 3, 224, 224),
    opt_shape=(5, 3, 224, 224),
    max_shape=(5, 3, 224, 224),
    max_batch_size=5,
)

In [None]:
# Cat picture used for INT8 calibration. cv2.imread returns None (it does not
# raise) when the file is missing or unreadable, so fail early with a clear
# message.
image = cv2.imread('../data/cat.jpeg')
if image is None:
    raise FileNotFoundError("Could not read image '../data/cat.jpeg'")
# OpenCV decodes to BGR channel order; reverse the last axis to get RGB.
image = image[..., ::-1]
trt_input_tensor = trt_preprocessing(image)
print(image.shape)
Image.fromarray(image)

In [None]:
# Fresh calibrator for the dynamic-batch INT8 build.
# NOTE(review): this reuses the same '../models/calibrator' path as the static
# build; if that is a calibration cache, the earlier cache may be picked up
# here - confirm in src/common.py.
# NOTE(review): calibration data comes from a single image while opt_shape
# uses batch 5 - verify the calibrator's batch size is compatible with the
# profile used during calibration.
calibrator = common.EntropyCalibrator(trt_input_tensor, '../models/calibrator')

# INT8 engine from the dynamic ONNX file, with the same 1..5 batch range as
# the FP32 / FP16 builds.
engine = common.build_engine(
    ONNX_FILE_DYNAMIC, TRT_FILE_INT8_DYNAMIC,
    int8_calibrator=calibrator,
    int8=True,
    min_shape=(1, 3, 224, 224),
    opt_shape=(5, 3, 224, 224),
    max_shape=(5, 3, 224, 224),
    max_batch_size=5,
)