In [1]:
%%bash
pip install timm onnx onnxruntime -q

bash: /opt/conda/lib/libtinfo.so.6: no version information available (required by bash)


In [2]:
import os
import sys
import cv2
import json
import onnx
import timm
import torch
import random as r
import numpy as np
import pandas as pd
import onnxruntime as ort
import matplotlib.pyplot as plt

from typing import Union
from torchvision import models
from IPython.display import clear_output

import warnings
warnings.filterwarnings("ignore")

if not os.path.exists("onnx"): os.makedirs("onnx")
    
ort.set_default_logger_severity(3)

In [3]:
labels: dict = json.load(open("/kaggle/input/ct-dataframe/labels.json", "r"))
    
    
def breaker() -> None:
    print("\n" + 50*"*" + "\n")

    
def sigmoid(x) -> np.ndarray:
    return 1 / (1 + np.exp(-x))


def get_image(path: str) -> np.ndarray:
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    image = cv2.cvtColor(src=image, code=cv2.COLOR_BGR2RGB)
    return image

In [4]:
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        
        self.model = timm.create_model(model_name="efficientnet_b4", pretrained=False)
        self.model.classifier = torch.nn.Linear(in_features=self.model.classifier.in_features, out_features=1)

    def forward(self, x):
        return self.model(x)
    
    
class CFG(object):  
    def __init__(
        self, 
        in_channels: int=3, 
        size: int=256, 
        opset_version: int=9, 
        path: str=None
    ):
        self.in_channels = in_channels
        self.size = size
        self.dummy = torch.randn(1, self.in_channels, self.size, self.size)
        self.opset_version = opset_version
        self.path = path

In [5]:
for v in ["a", "l"]:
    for i in range(1, 6):
        cfg = CFG(
            in_channels=3, 
            size=256, 
            opset_version=15, 
            path=f"/kaggle/input/ct-en4-a256-e10/saves/b{v}e_state_fold_{i}.pt"
        )
        
        model = Model()
        model.load_state_dict(torch.load(cfg.path, map_location=torch.device("cpu"))["model_state_dict"])
        model.eval()

        clear_output()
        
        torch.onnx.export(
            model=model, 
            args=cfg.dummy, 
            f=f"onnx/b{v}e_model_f{i}.onnx", 
            input_names=["input"], 
            output_names=["output"], 
            opset_version=cfg.opset_version,
            export_params=True,
            training=torch.onnx.TrainingMode.EVAL,
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
            dynamic_axes={
              "input"  : {0 : "batch_size"},
              "output" : {0 : "batch_size"},
            }
        )

In [6]:
class OnnxModel(object):
    def __init__(self, path: str) -> None:
        self.size: int = 256
        
        self.mean: list = [0.45546, 0.44390, 0.43109]
        self.std: list  = [0.25527, 0.25103, 0.25482]
        
        self.path: str = path
    
        model = onnx.load(self.path)
        onnx.checker.check_model(model)
        self.ort_session = ort.InferenceSession(self.path)
    
    def infer(self, image: np.ndarray, labels: dict) -> str:
        image = image / 255
        image = cv2.resize(src=image, dsize=(self.size, self.size), interpolation=cv2.INTER_AREA).transpose(2, 0, 1)
        for i in range(image.shape[0]): image[i, :, :] = (image[i, :, :] - self.mean[i]) / self.std[i]
        image = np.expand_dims(image, axis=0)
        inputs = {self.ort_session.get_inputs()[0].name : image.astype("float32")}
        prob = sigmoid(self.ort_session.run(None, inputs)[0][0])
        if prob <= 0.5:
            return "Cars", 1-prob
        else:
            return "Tanks", prob

In [7]:
df = pd.read_csv("/kaggle/input/ct-dataframe/test.csv")


breaker()
for model_name in sorted(os.listdir("onnx")):
    
    onnx_model = OnnxModel(f"onnx/{model_name}")
    print(f"{model_name.upper()}\n")
    
    for i in range(9):
        index = r.randint(0, len(df)-1)

        filepath = df.iloc[index, 0]
        y_true   = df.iloc[index, 1]

        image = get_image(filepath)

        y_pred, _ = onnx_model.infer(image, labels)
        
        print(f"{labels[str(y_true)].title()}, {y_pred}")

    breaker()


**************************************************

BAE_MODEL_F1.ONNX

Tanks, Tanks
Cars, Cars
Tanks, Tanks
Cars, Cars
Tanks, Tanks
Tanks, Tanks
Tanks, Tanks
Cars, Cars
Tanks, Tanks

**************************************************

BAE_MODEL_F2.ONNX

Tanks, Tanks
Cars, Cars
Cars, Cars
Tanks, Tanks
Cars, Cars
Cars, Cars
Cars, Cars
Cars, Cars
Cars, Cars

**************************************************

BAE_MODEL_F3.ONNX

Cars, Cars
Cars, Cars
Tanks, Tanks
Tanks, Tanks
Tanks, Tanks
Tanks, Tanks
Cars, Cars
Tanks, Tanks
Cars, Cars

**************************************************

BAE_MODEL_F4.ONNX

Cars, Cars
Tanks, Tanks
Tanks, Tanks
Tanks, Tanks
Tanks, Tanks
Tanks, Tanks
Cars, Cars
Cars, Cars
Cars, Cars

**************************************************

BAE_MODEL_F5.ONNX

Cars, Cars
Cars, Cars
Tanks, Tanks
Tanks, Tanks
Tanks, Tanks
Cars, Cars
Tanks, Tanks
Tanks, Tanks
Tanks, Tanks

**************************************************

BLE_MODEL_F1.ONNX

Tanks, Tanks
Cars, Tanks
