In [1]:
%%bash
pip install timm -q
pip install onnx -q
pip install onnxruntime -q



In [2]:
import os
import sys
import cv2
import json
import onnx
import timm
import torch
import random as r
import numpy as np
import pandas as pd
import onnxruntime as ort
import matplotlib.pyplot as plt

from typing import Union
from torchvision import models
from IPython.display import clear_output

import warnings
warnings.filterwarnings("ignore")

if not os.path.exists("onnx"): os.makedirs("onnx")
    
ort.set_default_logger_severity(3)

In [3]:
labels: dict = json.load(open("/kaggle/input/mmnist-dataframe/labels.json", "r"))
    
    
def breaker() -> None:
    print("\n" + 50*"*" + "\n")


def get_image(path: str) -> np.ndarray:
     return cv2.imread(path, cv2.IMREAD_GRAYSCALE)

In [4]:
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        
        self.model = timm.create_model(model_name="efficientnet_b4", pretrained=False)
        self.model.conv_stem = torch.nn.Conv2d(
            in_channels=1,
            out_channels=self.model.conv_stem.out_channels,
            kernel_size=self.model.conv_stem.kernel_size,
            stride=self.model.conv_stem.stride,
            padding=self.model.conv_stem.padding
        )
        self.model.classifier = torch.nn.Linear(in_features=self.model.classifier.in_features, out_features=6)

    def forward(self, x):
        return self.model(x)

    
    
class CFG(object):  
    def __init__(
        self, 
        in_channels: int=3, 
        size: int=256, 
        opset_version: int=9, 
        path: str=None
    ):
        self.in_channels = in_channels
        self.size = size
        self.dummy = torch.randn(1, self.in_channels, self.size, self.size)
        self.opset_version = opset_version
        self.path = path

In [5]:
for v in ["a", "l"]:
    for i in range(1, 6):
        cfg = CFG(
            in_channels=1, 
            size=64, 
            opset_version=13, 
            path=f"/kaggle/input/mmnist-en4-a64-e10-gray/saves/b{v}e_state_fold_{i}.pt"
        )

        model = Model()
        model.load_state_dict(torch.load(cfg.path, map_location=torch.device("cpu"))["model_state_dict"])
        model.eval()

        clear_output()

        # param_size: float = 0
        # for param in model.parameters():
        #     param_size += param.nelement() * param.element_size()

        # buffer_size: float = 0
        # for buffer in model.buffers():
        #     buffer_size += buffer.nelement() * buffer.element_size()

        # size_all_mb: float = (param_size + buffer_size) / 1024**2

        # breaker()
        # print(f"Model size: {size_all_mb:.3f} MB")
        # breaker()

        torch.onnx.export(
            model=model, 
            args=cfg.dummy, 
            f=f"onnx/b{v}e_model_f{i}.onnx", 
            input_names=["input"], 
            output_names=["output"], 
            opset_version=cfg.opset_version,
            export_params=True,
            training=torch.onnx.TrainingMode.EVAL,
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
            dynamic_axes={
              "input"  : {0 : "batch_size"},
              "output" : {0 : "batch_size"},
            }
        )

In [6]:
class OnnxModel(object):
    def __init__(self, path: str) -> None:
        self.ort_session = None
        self.size: int = 64
        
        self.mean: float = 0.35835
        self.std: float  = 0.20144
        
        self.path: str = path
    
        model = onnx.load(self.path)
        onnx.checker.check_model(model)
        self.ort_session = ort.InferenceSession(self.path)
    
    def infer(self, image: np.ndarray, labels: dict) -> np.ndarray:
        h, w = image.shape

        image = image / 255
        if image.shape != (64, 64):
            image = cv2.resize(src=image, dsize=(self.size, self.size), interpolation=cv2.INTER_AREA)
        image = (image - self.mean) / self.std
        image = np.expand_dims(np.expand_dims(image, axis=0), axis=0)
        inputs = {self.ort_session.get_inputs()[0].name : image.astype("float32")}
        result = self.ort_session.run(None, inputs)
        odds = np.exp(np.max(result))
        return labels[str(np.argmax(result))].title(), odds / (1 + odds)

In [7]:
df = pd.read_csv("/kaggle/input/mmnist-dataframe/dataframe.csv")


breaker()
for model_name in sorted(os.listdir("onnx")):
    
    onnx_model = OnnxModel(f"onnx/{model_name}")
    
    print(f"{model_name.upper()}\n")
    
    for i in range(9):
        index = r.randint(0, len(df)-1)

        filepath = df.iloc[index, 0]
        y_true   = df.iloc[index, 1]

        image = get_image(filepath)

        y_pred = onnx_model.infer(image, labels)
        
        print(f"{labels[str(y_true)].title()}, {y_pred}")

    breaker()


**************************************************

BAE_MODEL_F1.ONNX

Abdomenct, ('Abdomenct', 0.9999999999990807)
Cxr, ('Cxr', 0.9999995469187701)
Headct, ('Headct', 0.9999999999995612)
Chestct, ('Chestct', 0.9999999999999301)
Breastmri, ('Breastmri', 1.0)
Chestct, ('Chestct', nan)
Headct, ('Headct', 0.9999998598491304)
Chestct, ('Chestct', nan)
Abdomenct, ('Abdomenct', 0.9999999887763059)

**************************************************

BAE_MODEL_F2.ONNX

Headct, ('Headct', 0.9999841437409523)
Cxr, ('Cxr', 0.9999999999683192)
Chestct, ('Chestct', 0.9999999998017854)
Abdomenct, ('Abdomenct', 0.9999998538547348)
Hand, ('Hand', 0.9999989838097392)
Chestct, ('Chestct', 1.0)
Abdomenct, ('Abdomenct', 0.9999999998329404)
Abdomenct, ('Abdomenct', 0.9999998970192956)
Abdomenct, ('Abdomenct', 0.9999999999318855)

**************************************************

BAE_MODEL_F3.ONNX

Breastmri, ('Breastmri', 0.9999999964273248)
Cxr, ('Cxr', 0.9999999903002726)
Abdomenct, ('Abdomenct', 0.9