In [1]:
import functools
import os

import albumentations as A
import cv2
import numpy as np
import onnxruntime
import onnxruntime as onnxrt
import torch
from albumentations.pytorch import ToTensorV2



In [2]:
def pre_process(numpy_data):
    """Resize, normalize and tensorize one image for the classifier.

    Args:
        numpy_data: image array handed to albumentations as ``image``;
            presumably HWC RGB uint8 (the caller converts BGR->RGB first) — TODO confirm.

    Returns:
        The ``"image"`` entry produced by the pipeline; after ``ToTensorV2``
        this is a torch tensor with channels first.
    """
    steps = [
        # Fixed 224x224 spatial size; assumed to match the ONNX model's input — TODO confirm.
        A.Resize(height=224, width=224),
        # Standard ImageNet per-channel mean/std, applied after scaling pixels by 1/255.
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.,
            p=1.0,
        ),
        # HWC numpy array -> CHW torch tensor.
        ToTensorV2(),
    ]
    pipeline = A.Compose(steps)
    return pipeline(image=numpy_data)["image"]
    
@functools.lru_cache(maxsize=None)
def _load_classes(labels_path):
    """Read one class label per line from ``labels_path`` (cached per path)."""
    with open(labels_path, encoding="utf-8") as f:
        return tuple(line.strip() for line in f)


@functools.lru_cache(maxsize=None)
def _load_session(model_path):
    """Create an ONNX Runtime session for ``model_path`` (cached per path).

    Building a session parses and optimizes the whole model, so it is done
    once per path instead of once per prediction.
    """
    return onnxrt.InferenceSession(model_path)


def predict(numpy_data,
            model_path=os.path.join("models", "tensorrt_fp16_model", "1", "model.onnx"),
            labels_path="imagenet-idx.txt",
            top_k=10):
    """Classify one image with the ONNX model and return the top-k labels.

    Args:
        numpy_data: image array accepted by ``pre_process``.
        model_path: path to the ``.onnx`` model (portable path, default as before).
        labels_path: text file with one class label per line, in output-index order.
        top_k: number of (label, probability) pairs to return (default 10).

    Returns:
        List of ``(label, probability_percent)`` tuples, most probable first.
        At most ``top_k`` entries, fewer if the model has fewer classes.
    """
    # Add the batch dimension: (C, H, W) -> (1, C, H, W).
    input_batch = pre_process(numpy_data).unsqueeze(0)

    classes = _load_classes(labels_path)
    onnx_session = _load_session(model_path)

    # Use the graph's declared input name rather than assuming "input".
    input_name = onnx_session.get_inputs()[0].name
    # run() returns one array per graph output; the logits are the first one.
    logits = onnx_session.run(None, {input_name: input_batch.numpy()})[0]
    # Cast to float32 (the model may emit fp16) and flatten to (1, num_classes).
    logits = torch.as_tensor(logits, dtype=torch.float32).reshape(1, -1)

    prob = torch.nn.functional.softmax(logits, dim=1)[0] * 100
    k = min(top_k, prob.numel())
    top_prob, top_idx = torch.topk(prob, k)
    return [(classes[idx], p) for idx, p in zip(top_idx.tolist(), top_prob.tolist())]

In [3]:
image_path = os.path.join("images", "car.jpg")
image = cv2.imread(image_path)
# cv2.imread reports failure by returning None instead of raising; without this
# check the error would surface as an opaque assertion inside cvtColor.
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
# OpenCV loads images as BGR; the preprocessing pipeline expects RGB.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
transformed_image = pre_process(image)

In [4]:
# Classify the RGB image loaded above; predict applies pre_process itself.
pred = predict(numpy_data=image)

In [5]:
pred  # top-10 (label, probability in %) pairs, most probable first

[('sports car, sport car', 95.72743225097656),
 ('racer, race car, racing car', 2.1000659465789795),
 ('car wheel', 1.018267035484314),
 ('convertible', 0.9588374495506287),
 ('grille, radiator grille', 0.10270078480243683),
 ('pickup, pickup truck', 0.04620283469557762),
 ('beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
  0.014915943145751953),
 ('cab, hack, taxi, taxicab', 0.005070718936622143),
 ('passenger car, coach, carriage', 0.0028018183074891567),
 ('limousine, limo', 0.001993861049413681)]