In [None]:
# Automatically reload edited project modules (e.g. the `src` package imported below)
# before executing each cell, so changes in src/*.py are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from IPython.core.interactiveshell import InteractiveShell
# Display every bare expression in a cell, not only the last one. This is why later
# cells bind throwaway results to `_` to keep them from being echoed.
InteractiveShell.ast_node_interactivity = 'all'

import warnings
# NOTE(review): this silences *all* warnings, including any tracer warnings that
# torch.onnx.export may raise when the traced graph could diverge from eager execution.
# Consider narrowing this filter while validating a new export.
warnings.simplefilter('ignore')

import gc

from os import path
import sys
# Make the parent of the current working directory importable so that `src.*` resolves.
# Guarded so that re-running this cell does not keep appending duplicate entries.
_repo_root = path.abspath('..')
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

In [None]:
import numpy as np
import onnx
from onnxsim import simplify
import torch
import cv2
from PIL import Image
import tensorrt as trt
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt

import pycuda.autoinit

from src.transforms import torch_preprocessing
import src.common as common

In [None]:
def sigmoid(x):
    """Element-wise logistic sigmoid ``1 / (1 + exp(-x))``, computed stably.

    Evaluated as ``exp(-logaddexp(0, -x))``. This equals ``1 / (1 + exp(-x))`` but never
    evaluates ``exp`` of a large positive number, so very negative logits no longer overflow
    (the naive form raises an overflow RuntimeWarning there, which the blanket
    warnings filter in this notebook would hide).

    Args:
        x: NumPy scalar or array of logits.

    Returns:
        NumPy scalar or array with the same shape as ``x``, values in [0, 1].
    """
    # logaddexp(0, -x) == log(1 + exp(-x)), computed without overflow by NumPy.
    return np.exp(-np.logaddexp(0.0, -x))

In [None]:
# Device for the PyTorch reference run and for tracing the ONNX export.
DEVICE = 'cuda:0'

# Every artifact of the PyTorch -> ONNX -> TensorRT pipeline shares the '../models/resnet34' stem.
TORCH_FILE, ONNX_FILE, TRT_FILE, TRT_FILE_FP16 = (
    f'../models/resnet34{suffix}'
    for suffix in ('.pth', '.onnx', '.engine', '_fp16.engine')
)

In [None]:
# U-Net with a ResNet-34 encoder. encoder_weights='imagenet' initialises the encoder from
# ImageNet; per the segmentation_models_pytorch docs the decoder/head start untrained.
# NOTE(review): fine if this notebook only checks the conversion pipeline, not mask quality.
# nn.Module.to() and .eval() both return the module itself, so the calls chain.
model = smp.Unet(encoder_name='resnet34', encoder_weights='imagenet').to(DEVICE).eval()

In [None]:
# Sample image used for the PyTorch reference prediction.
IMAGE_FILE = '../data/dog.jpg'
image_bgr = cv2.imread(IMAGE_FILE)
# cv2.imread returns None (instead of raising) when the file is missing or unreadable;
# fail here with a clear message rather than a TypeError from indexing None.
if image_bgr is None:
    raise FileNotFoundError(f'Could not read image: {IMAGE_FILE}')
# OpenCV decodes to BGR; convert to RGB. Unlike slicing with [..., ::-1], cvtColor returns
# a contiguous array (no negative strides), which is safer to hand to torch/PIL.
image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
print(image.shape)
Image.fromarray(image)

In [None]:
# Apply the project's preprocessing (src.transforms.torch_preprocessing) to the RGB image
# and move the result to DEVICE.
# NOTE(review): assumes torch_preprocessing returns a batched torch.Tensor ready for the
# model (the output is later indexed with [0]) — confirm in src/transforms.
torch_input_tensor = torch_preprocessing(image).to(DEVICE)

In [None]:
# Reference PyTorch forward pass. Autograd is disabled, so the output does not require grad
# and can be converted to NumPy directly.
with torch.no_grad():
    # Keep only the first batch element and bring it to host memory as a NumPy array
    # (presumably shaped (classes, H, W) — the plot cell indexes channel 0).
    torch_output_tensor = model(torch_input_tensor)[0].cpu().numpy()

In [None]:
print(torch_output_tensor.shape)
# Show channel 0 of the PyTorch output after the sigmoid, using the explicit Axes API and a
# title so the figure stands on its own. Results are bound to `_` because every bare
# expression in a cell is displayed in this notebook.
fig, ax = plt.subplots()
_ = ax.imshow(sigmoid(torch_output_tensor)[0])
_ = ax.set_title('PyTorch output: sigmoid, channel 0')
_ = ax.axis('off')

In [None]:
# Save the PyTorch checkpoint. torch.save pickles the whole nn.Module (not just its
# state_dict), so loading it back requires the segmentation_models_pytorch classes.
torch.save(obj=model, f=TORCH_FILE)

## Torch -> ONNX

In [None]:
# (N, C, H, W) shape used to trace the model. No dynamic_axes are passed to the export, so
# the ONNX graph has this fixed input shape (and presumably so do the TensorRT engines
# built from it).
EXPORT_INPUT_SHAPE = (1, 3, 224, 224)
# Tracing records the executed ops, so the dummy tensor only needs the right shape,
# dtype and device; its random values do not matter for this model.
dummy_input = torch.rand(*EXPORT_INPUT_SHAPE, device=DEVICE)
torch.onnx.export(
    model,
    dummy_input,
    ONNX_FILE,
    # The full graph is printed in the "Check ONNX" cell; don't dump it twice.
    verbose=False,
    input_names=['input'],
    output_names=['output'],
)

## Check and simplify ONNX

In [None]:
# Reload the exported graph from disk and run ONNX's model checker on it
# (check_model raises if the model is invalid).
onnx_model = onnx.load(ONNX_FILE)
onnx.checker.check_model(onnx_model)
# Human-readable listing of every node in the graph — expect a long output.
print(onnx.helper.printable_graph(onnx_model.graph))

In [None]:
# Simplify the graph with onnx-simplifier. `check` reports whether onnxsim's own validation
# of the simplified model against the original passed.
onnx_model_simp, check = simplify(onnx_model)
print(check)
# The save below overwrites the exported file in place, so never replace it with a
# model that failed validation.
if not check:
    raise RuntimeError('onnxsim could not validate the simplified model; ONNX file left unchanged')
onnx.save(onnx_model_simp, ONNX_FILE)

## ONNX -> TRT

In [None]:
# FP32 TensorRT engine from the ONNX file. What build_engine does with TRT_FILE
# (presumably serializes the engine there) is defined in src/common.py.
engine = common.build_engine(ONNX_FILE, TRT_FILE)

In [None]:
# FP16 variant: same ONNX input, built with fp16=True and targeted at TRT_FILE_FP16.
# Rebinding `engine` drops this notebook's reference to the previously built engine.
engine = common.build_engine(ONNX_FILE, TRT_FILE_FP16, fp16=True)