In [1]:
%%bash
pip install timm -q
pip install onnx -q
pip install onnxruntime -q



In [2]:
import os
import sys
import cv2
import onnx
import timm
import torch
import random as r
import numpy as np
import pandas as pd
import onnxruntime as ort
import matplotlib.pyplot as plt

from typing import Union
from torchvision import models
from IPython.display import clear_output

import warnings
warnings.filterwarnings("ignore")

if not os.path.exists("onnx"): os.makedirs("onnx")

In [3]:
labels: dict = {
    0 : "Anime",
    1 : "Cartoon",
    2 : "Human"
}
    
    
def breaker() -> None:
    print("\n" + 50*"*" + "\n")


def get_image(path: str, size: int=224) -> np.ndarray:
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    image = cv2.cvtColor(src=image, code=cv2.COLOR_BGR2RGB)
    return cv2.resize(src=image, dsize=(size, size), interpolation=cv2.INTER_AREA)


def show_image(image: np.ndarray, cmap: str="gnuplot2", title: Union[str, None]=None) -> None:
    plt.figure()
    plt.imshow(image, cmap=cmap)
    plt.axis("off")
    if title: plt.title(title)
    plt.show()

In [4]:
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        
        self.model = timm.create_model(model_name="efficientnet_b4", pretrained=False)
        self.model.classifier = torch.nn.Linear(in_features=self.model.classifier.in_features, out_features=3)

    def forward(self, x):
        return self.model(x)

    
    
class CFG(object):  
    def __init__(
        self, 
        in_channels: int=3, 
        size: int=256, 
        opset_version: int=9, 
        path: str=None
    ):
        self.in_channels = in_channels
        self.size = size
        self.dummy = torch.randn(1, self.in_channels, self.size, self.size)
        self.opset_version = opset_version
        self.path = path

In [5]:
MODEL_BASE_PATH: str = "/kaggle/input/ach-en4-a384-e10/saves"
MODEL_NAMES: list = os.listdir(MODEL_BASE_PATH)
    
for model_name in MODEL_NAMES:

    cfg = CFG(
        in_channels=3, 
        size=384, 
        opset_version=13, 
        path=f"/kaggle/input/ach-en4-a384-e10/saves/{model_name}"
    )
    
    model = Model()
    model.load_state_dict(torch.load(cfg.path, map_location=torch.device("cpu"))["model_state_dict"])
    model.eval()

    clear_output()
    
#     param_size = 0
#     for param in model.parameters():
#         param_size += param.nelement() * param.element_size()

#     buffer_size = 0
#     for buffer in model.buffers():
#         buffer_size += buffer.nelement() * buffer.element_size()

#     size_all_mb = (param_size + buffer_size) / 1024**2

#     breaker()
#     print(f"Model size: {size_all_mb:.3f} MB")
#     breaker()

    torch.onnx.export(
        model=model, 
        args=cfg.dummy, 
        f=f"onnx/ach-en4-{model_name.split('_')[0]}-f{model_name.split('_')[-1]}.onnx", 
        input_names=["input"], 
        output_names=["output"], 
        opset_version=cfg.opset_version,
        export_params=True,
        training=torch.onnx.TrainingMode.EVAL,
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
        dynamic_axes={
          "input"  : {0 : "batch_size"},
          "output" : {0 : "batch_size"},
        }
    )


In [6]:
class OnnxModel(object):
    def __init__(self, path: str) -> None:
        self.ort_session = None
        self.size: int = 384
        
        self.mean: list = [0.59090, 0.51873, 0.48801]
        self.std: list  = [0.26181, 0.25198, 0.25229]
        
        self.path: str = path

        ort.set_default_logger_severity(3)
    
    def setup(self) -> None:
        model = onnx.load(self.path)
        onnx.checker.check_model(model)
        self.ort_session = ort.InferenceSession(self.path)
    
    def infer(self, image: np.ndarray, labels: dict) -> np.ndarray:
        h, w, _ = image.shape

        image = image / 255
        image = cv2.resize(src=image, dsize=(self.size, self.size), interpolation=cv2.INTER_AREA).transpose(2, 0, 1)
        for i in range(image.shape[0]): image[i, :, :] = (image[i, :, :] - self.mean[i]) / self.std[i]
        image = np.expand_dims(image, axis=0)
        input = {self.ort_session.get_inputs()[0].name : image.astype("float32")}
        return labels[np.argmax(self.ort_session.run(None, input))].title()

In [7]:
TEST_BASE_PATH: str = "/kaggle/input/anime-and-cartoon-image-classification/Training Data"

    
breaker()
for model_name in os.listdir("onnx"):
    
    print(f"{model_name}\n")
    
    onnx_model = OnnxModel(f"onnx/{model_name}")
    onnx_model.setup()
    
    for i in range(9):
        
        folder_1 = os.listdir(TEST_BASE_PATH)[r.randint(0, len(os.listdir(TEST_BASE_PATH)) - 1)]
        folder_2 = os.listdir(f"{TEST_BASE_PATH}/{folder_1}")[r.randint(0, len(os.listdir(f"{TEST_BASE_PATH}/{folder_1}")) - 1)]
        filename = os.listdir(f"{TEST_BASE_PATH}/{folder_1}/{folder_2}")[r.randint(0, len(os.listdir(f"{TEST_BASE_PATH}/{folder_1}/{folder_2}")) - 1)]

        filepath = f"{TEST_BASE_PATH}/{folder_1}/{folder_2}/{filename}"

        image = get_image(filepath, cfg.size)

        y_pred = onnx_model.infer(image, labels)

        print(f"{i+1}. Actual     : {folder_1}") 
        print(f"{i+1}. Prediction : {y_pred}")
        print("-")

    breaker()


**************************************************

ach-en4-ble-f4.pt.onnx

1. Actual     : Cartoon
1. Prediction : Cartoon
-
2. Actual     : Anime
2. Prediction : Cartoon
-
3. Actual     : Anime
3. Prediction : Cartoon
-
4. Actual     : Anime
4. Prediction : Cartoon
-
5. Actual     : Cartoon
5. Prediction : Cartoon
-
6. Actual     : Anime
6. Prediction : Cartoon
-
7. Actual     : Cartoon
7. Prediction : Cartoon
-
8. Actual     : Cartoon
8. Prediction : Cartoon
-
9. Actual     : Cartoon
9. Prediction : Cartoon
-

**************************************************

ach-en4-bae-f1.pt.onnx

1. Actual     : Cartoon
1. Prediction : Cartoon
-
2. Actual     : Cartoon
2. Prediction : Cartoon
-
3. Actual     : Cartoon
3. Prediction : Cartoon
-
4. Actual     : Cartoon
4. Prediction : Cartoon
-
5. Actual     : Cartoon
5. Prediction : Cartoon
-
6. Actual     : Cartoon
6. Prediction : Cartoon
-
7. Actual     : Cartoon
7. Prediction : Cartoon
-
8. Actual     : Anime
8. Prediction : Cartoon
-
9. Act