In [1]:
%%bash
pip install timm -q
pip install onnx -q
pip install onnxruntime -q



In [2]:
import os
import sys
import cv2
import onnx
import timm
import torch
import random as r
import numpy as np
import pandas as pd
import onnxruntime as ort
import matplotlib.pyplot as plt

from typing import Union
from torchvision import models
from IPython.display import clear_output

import warnings
warnings.filterwarnings("ignore")

if not os.path.exists("onnx"): os.makedirs("onnx")

In [3]:
labels: dict = {
    0 : "Anime",
    1 : "Cartoon",
    2 : "Human"
}
    
    
def breaker() -> None:
    print("\n" + 50*"*" + "\n")

    
def sigmoid(x: Union[float, list, tuple, np.ndarray]):
    if isinstance(x, float):
        return 1 / (1 + np.exp(-x))
    else:
        return np.array([1 / (1 + np.exp(-y)) for y in x]).squeeze()


def get_image(path: str, size: int=224) -> np.ndarray:
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    image = cv2.cvtColor(src=image, code=cv2.COLOR_BGR2RGB)
    return cv2.resize(src=image, dsize=(size, size), interpolation=cv2.INTER_AREA)


def show_image(image: np.ndarray, cmap: str="gnuplot2", title: Union[str, None]=None) -> None:
    plt.figure()
    plt.imshow(image, cmap=cmap)
    plt.axis("off")
    if title: plt.title(title)
    plt.show()

In [4]:
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        
        self.model = timm.create_model(model_name="efficientnet_b4", pretrained=False)
        self.model.classifier = torch.nn.Linear(in_features=self.model.classifier.in_features, out_features=3)

    def forward(self, x):
        return self.model(x)

    
    
class CFG(object):  
    def __init__(
        self, 
        in_channels: int=3, 
        size: int=256, 
        opset_version: int=9, 
        path: str=None
    ):
        self.in_channels = in_channels
        self.size = size
        self.dummy = torch.randn(1, self.in_channels, self.size, self.size)
        self.opset_version = opset_version
        self.path = path

In [5]:
cfg = CFG(
    in_channels=3, 
    size=384, 
    opset_version=13, 
    path=f"/kaggle/input/ach-en4-a384-e10/saves/ble_state_fold_4.pt"
)

model = Model()
model.load_state_dict(torch.load(cfg.path, map_location=torch.device("cpu"))["model_state_dict"])
model.eval()

clear_output()
    
param_size = 0
for param in model.parameters():
    param_size += param.nelement() * param.element_size()

buffer_size = 0
for buffer in model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

size_all_mb = (param_size + buffer_size) / 1024**2

breaker()
print(f"Model size: {size_all_mb:.3f} MB")
breaker()

torch.onnx.export(
    model=model, 
    args=cfg.dummy, 
    f=f"ach-en4-ble-f4.onnx", 
    input_names=["input"], 
    output_names=["output"], 
    opset_version=cfg.opset_version,
    export_params=True,
    training=torch.onnx.TrainingMode.EVAL,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
    dynamic_axes={
      "input"  : {0 : "batch_size"},
      "output" : {0 : "batch_size"},
    }
)


**************************************************

Model size: 67.442 MB

**************************************************



In [6]:
class OnnxModel(object):
    def __init__(self, path: str) -> None:
        self.ort_session = None
        self.size: int = 384
        
        self.mean: list = [0.59090, 0.51873, 0.48801]
        self.std: list  = [0.26181, 0.25198, 0.25229]
        
        self.path: str = path

        ort.set_default_logger_severity(3)
    
    def setup(self) -> None:
        model = onnx.load(self.path)
        onnx.checker.check_model(model)
        self.ort_session = ort.InferenceSession(self.path)
    
    def infer(self, image: np.ndarray, labels: dict) -> np.ndarray:
        h, w, _ = image.shape

        image = image / 255
        image = cv2.resize(src=image, dsize=(self.size, self.size), interpolation=cv2.INTER_AREA).transpose(2, 0, 1)
        for i in range(image.shape[0]): image[i, :, :] = (image[i, :, :] - self.mean[i]) / self.std[i]
        image = np.expand_dims(image, axis=0)
        inputs = {self.ort_session.get_inputs()[0].name : image.astype("float32")}
        result = self.ort_session.run(None, inputs)
        return result, sigmoid(result), labels[np.argmax(result)].title()

    
onnx_model = OnnxModel("ach-en4-ble-f4.onnx")
onnx_model.setup()

In [7]:
TEST_BASE_PATH: str = "/kaggle/input/anime-and-cartoon-image-classification/Training Data"

    
breaker()
for i in range(25):

    folder_1 = os.listdir(TEST_BASE_PATH)[r.randint(0, len(os.listdir(TEST_BASE_PATH)) - 1)]
    folder_2 = os.listdir(f"{TEST_BASE_PATH}/{folder_1}")[r.randint(0, len(os.listdir(f"{TEST_BASE_PATH}/{folder_1}")) - 1)]
    filename = os.listdir(f"{TEST_BASE_PATH}/{folder_1}/{folder_2}")[r.randint(0, len(os.listdir(f"{TEST_BASE_PATH}/{folder_1}/{folder_2}")) - 1)]

    filepath = f"{TEST_BASE_PATH}/{folder_1}/{folder_2}/{filename}"

    image = get_image(filepath, cfg.size)

    result, probs, y_pred = onnx_model.infer(image, labels)
    
    print(f"{i+1}")
    print(f"p(Anime)   : {probs[0]:.5f}")
    print(f"p(Cartoon) : {probs[1]:.5f}")
    print(f"p(Human)   : {probs[2]:.5f}")
    print("-")

breaker()


**************************************************

1
p(Anime)   : 0.00003
p(Cartoon) : 1.00000
p(Human)   : 0.00003
-
2
p(Anime)   : 0.00058
p(Cartoon) : 1.00000
p(Human)   : 0.00047
-
3
p(Anime)   : 0.00014
p(Cartoon) : 1.00000
p(Human)   : 0.00014
-
4
p(Anime)   : 0.00294
p(Cartoon) : 0.99998
p(Human)   : 0.00081
-
5
p(Anime)   : 0.00018
p(Cartoon) : 1.00000
p(Human)   : 0.00025
-
6
p(Anime)   : 0.00328
p(Cartoon) : 0.99999
p(Human)   : 0.00070
-
7
p(Anime)   : 0.00015
p(Cartoon) : 1.00000
p(Human)   : 0.00006
-
8
p(Anime)   : 0.00127
p(Cartoon) : 0.99999
p(Human)   : 0.00058
-
9
p(Anime)   : 0.00001
p(Cartoon) : 1.00000
p(Human)   : 0.00001
-
10
p(Anime)   : 0.00002
p(Cartoon) : 1.00000
p(Human)   : 0.00007
-
11
p(Anime)   : 0.00012
p(Cartoon) : 1.00000
p(Human)   : 0.00021
-
12
p(Anime)   : 0.00010
p(Cartoon) : 1.00000
p(Human)   : 0.00009
-
13
p(Anime)   : 0.00084
p(Cartoon) : 1.00000
p(Human)   : 0.00034
-
14
p(Anime)   : 0.00012
p(Cartoon) : 1.00000
p(Human)   : 0.00007
-
15
p