In [1]:
%%bash
pip install timm -q
pip install onnx -q
pip install onnxruntime -q



In [2]:
import os
import sys
import cv2
import onnx
import timm
import torch
import random as r
import numpy as np
import pandas as pd
import onnxruntime as ort

from typing import Union
from torchvision import models, transforms
from IPython.display import clear_output

import warnings
warnings.filterwarnings("ignore")

In [3]:
labels: list = {
    0 : "Cat",
    1 : "Cow",
    2 : "Dog",
    3 : "Elephant",
    4 : "Panda",
}

    
def breaker() -> None:
    print("\n" + 50*"*" + "\n")

    
def get_image(path: str, size: int=224) -> np.ndarray:
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    image = cv2.cvtColor(src=image, code=cv2.COLOR_BGR2RGB)
    return cv2.resize(src=image, dsize=(size, size), interpolation=cv2.INTER_AREA)

In [4]:
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        
        self.model = timm.create_model(model_name="efficientnet_b4", pretrained=False)
        self.model.classifier = torch.nn.Linear(in_features=self.model.classifier.in_features, out_features=5)

    def forward(self, x):
        return self.model(x)


class CFG(object):  
    def __init__(self, 
             in_channels: int=3, 
             size: int=256, 
             opset_version: int=9, 
             path: str=None):
        self.in_channels = in_channels
        self.size = size
        self.dummy = torch.randn(1, self.in_channels, self.size, self.size)
        self.opset_version = opset_version
        self.path = path

In [5]:
MODEL_BASE_PATH: str = "../input/aic-en4-a384-e10/saves"
PT_MODEL_NAMES: list = os.listdir(MODEL_BASE_PATH)

if not os.path.exists("onnx"): os.makedirs("onnx")

for model_name in PT_MODEL_NAMES:
    cfg = CFG(
        in_channels=3, 
        size=384, 
        opset_version=13, 
        path=f"{MODEL_BASE_PATH}/{model_name}"
    )

    model = Model()
    model.load_state_dict(torch.load(cfg.path, map_location=torch.device("cpu"))["model_state_dict"])
    model.eval()

    clear_output()

    # param_size = 0
    # for param in model.parameters():
    #     param_size += param.nelement() * param.element_size()

    # buffer_size = 0
    # for buffer in model.buffers():
    #     buffer_size += buffer.nelement() * buffer.element_size()

    # size_all_mb = (param_size + buffer_size) / 1024**2

    # breaker()
    # print(f"Model size: {size_all_mb:.3f} MB")
    # breaker()


    torch.onnx.export(
        model=model, 
        args=cfg.dummy, 
        f=f"onnx/{model_name[:-3].split('_')[0]}_f{model_name[:-3].split('_')[-1]}.onnx", 
        input_names=["input"], 
        output_names=["output"], 
        opset_version=cfg.opset_version,
        export_params=True,
        training=torch.onnx.TrainingMode.EVAL,
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
        dynamic_axes={
          "input"  : {0 : "batch_size"},
          "output" : {0 : "batch_size"},
        }
    )

In [6]:
class OnnxModel(object):
    def __init__(self, path: Union[str, None]=None) -> None:
        self.ort_session = None
        self.size: int = 384
        
        self.mean: list = [0.52556, 0.50756, 0.44324]
        self.std: list  = [0.23221, 0.23140, 0.23923]
        
        self.path: str = path
        ort.set_default_logger_severity(3)
    
    def setup(self) -> None:
        model = onnx.load(self.path)
        onnx.checker.check_model(model)
        self.ort_session = ort.InferenceSession(self.path)
    
    def infer(self, image: np.ndarray, labels: dict) -> np.ndarray:
        h, w, _ = image.shape

        image = image / 255
        image = cv2.resize(src=image, dsize=(self.size, self.size), interpolation=cv2.INTER_AREA).transpose(2, 0, 1)
        for i in range(image.shape[0]): image[i, :, :] = (image[i, :, :] - self.mean[i]) / self.std[i]
        image = np.expand_dims(image, axis=0)
        input = {self.ort_session.get_inputs()[0].name : image.astype("float32")}
        return labels[np.argmax(self.ort_session.run(None, input))].title()

In [7]:
ONNX_MODEL_PATH: str = "onnx"
ONNX_MODEL_FILENAMES: list = sorted(os.listdir(ONNX_MODEL_PATH))
    

df = pd.read_csv("../input/aic-dataframe/test.csv")


for model_filename in ONNX_MODEL_FILENAMES:
    onnx_model = OnnxModel(f"{ONNX_MODEL_PATH}/{model_filename}")
    onnx_model.setup()
    
    breaker()
    print(f"{model_filename}\n")
    
    for i in range(9):
        index = r.randint(0, df.shape[0] - 1)

        filepath = df.iloc[index, 0]
        y_true = df.iloc[index, 1]

        image = get_image(filepath, cfg.size)
        y_pred = onnx_model.infer(image, labels)

        print(f"{i+1}. Actual     : {labels[y_true]}")
        print(f"{i+1}. Prediction : {y_pred}\n")
    
breaker()


**************************************************

bae_f1.onnx

1. Actual     : Elephant
1. Prediction : Elephant

2. Actual     : Cow
2. Prediction : Cow

3. Actual     : Panda
3. Prediction : Panda

4. Actual     : Panda
4. Prediction : Panda

5. Actual     : Elephant
5. Prediction : Elephant

6. Actual     : Cat
6. Prediction : Cat

7. Actual     : Cow
7. Prediction : Cow

8. Actual     : Dog
8. Prediction : Elephant

9. Actual     : Dog
9. Prediction : Dog


**************************************************

bae_f2.onnx

1. Actual     : Panda
1. Prediction : Panda

2. Actual     : Dog
2. Prediction : Dog

3. Actual     : Cat
3. Prediction : Cat

4. Actual     : Cow
4. Prediction : Cow

5. Actual     : Cat
5. Prediction : Cat

6. Actual     : Panda
6. Prediction : Panda

7. Actual     : Cow
7. Prediction : Cow

8. Actual     : Panda
8. Prediction : Panda

9. Actual     : Panda
9. Prediction : Panda


**************************************************

bae_f3.onnx

1. Actual     :