In [1]:
%%bash
pip install timm -q
pip install onnx -q
pip install onnxruntime -q



In [2]:
import os
import sys
import cv2
import onnx
import timm
import torch
import random as r
import numpy as np
import pandas as pd
import onnxruntime as ort

from torchvision import models, transforms
from IPython.display import clear_output

import warnings
warnings.filterwarnings("ignore")

In [3]:
def breaker() -> None:
    """Print a separator line: a blank line, 50 asterisks, then a blank line."""
    rule = "*" * 50
    print(f"\n{rule}\n")

    
def get_image(path: str, size: int=224) -> np.ndarray:
    """Load an image from disk as an RGB array resized to ``size`` x ``size``.

    Args:
        path: Path to the image file.
        size: Side length, in pixels, of the square output image.

    Returns:
        The resized image in RGB channel order (OpenCV loads BGR; it is
        converted here).

    Raises:
        FileNotFoundError: If OpenCV cannot read an image from ``path``.
    """
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    # cv2.imread reports failure (missing file, unsupported or corrupt format)
    # by returning None rather than raising; without this check the problem
    # only surfaces as an opaque error inside cvtColor.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    image = cv2.cvtColor(src=image, code=cv2.COLOR_BGR2RGB)
    return cv2.resize(src=image, dsize=(size, size), interpolation=cv2.INTER_AREA)

In [4]:
class Model(torch.nn.Module):
    """timm EfficientNet-B4 whose classifier head is replaced by a new linear layer.

    Args:
        num_classes: Number of output logits produced by the replacement head.
            Defaults to 11, the number of labels used by this notebook.
    """

    def __init__(self, num_classes: int = 11):
        super().__init__()

        # `pretrained` is not requested; trained weights are expected to be
        # loaded afterwards with load_state_dict.
        self.model = timm.create_model(model_name="efficientnet_b4")
        # Reassign the same attribute name (`classifier`) so state_dict keys
        # stay `model.classifier.*` and existing checkpoints still load.
        self.model.classifier = torch.nn.Linear(
            in_features=self.model.classifier.in_features,
            out_features=num_classes,
        )

    def forward(self, x):
        """Return the backbone's logits for input batch ``x``."""
        return self.model(x)


class CFG(object):
    """Settings for loading the checkpoint and exporting it to ONNX.

    Attributes:
        in_channels: Number of input image channels.
        size: Height and width of the square model input.
        dummy: Random tensor of shape (1, in_channels, size, size), used as the
            example input when tracing the model for export.
        opset_version: ONNX opset to target.
        path: Checkpoint file path (may be None).
    """

    def __init__(self, 
                 in_channels: int=3, 
                 size: int=256, 
                 opset_version: int=9, 
                 path: str=None):
        self.in_channels = in_channels
        self.size = size
        dummy_shape = (1, in_channels, size, size)
        self.dummy = torch.randn(*dummy_shape)
        self.opset_version = opset_version
        self.path = path

In [5]:
cfg = CFG(
    in_channels=3,
    size=384,
    opset_version=13,
    path="../input/wic-en4-a384/saves/ble_state_fold_1.pt",
)

# Rebuild the architecture, then load the trained weights onto the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model = Model()
model.load_state_dict(
    torch.load(cfg.path, map_location=torch.device("cpu"))["model_state_dict"]
)
model.eval()

clear_output()

# In-memory footprint: total bytes held by every parameter and buffer.
param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
size_all_mb = (param_size + buffer_size) / (1024 ** 2)

breaker()
print(f"Model size: {size_all_mb:.3f} MB")
breaker()


**************************************************

Model size: 67.496 MB

**************************************************



In [6]:
# Export the eval-mode model to ONNX for onnxruntime inference. The dummy input
# fixes the spatial size (cfg.size x cfg.size) in the graph; only the batch
# axis is declared dynamic below.
torch.onnx.export(
    model=model, 
    args=cfg.dummy, 
    f="wic-en4-f1.onnx", 
    input_names=["input"], 
    output_names=["output"], 
    opset_version=cfg.opset_version,
    export_params=True,
    training=torch.onnx.TrainingMode.EVAL,
    # Emit standard ONNX ops only. ONNX_ATEN_FALLBACK would silently emit ATen
    # nodes for unsupported ops, which standard onnxruntime inference builds
    # cannot execute; failing here is clearer than failing at session creation.
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
    dynamic_axes={
      "input"  : {0 : "batch_size"},
      "output" : {0 : "batch_size"},
    }
)

In [7]:
class OnnxModel(object):
    """Single-image weather classifier backed by the exported ONNX model.

    Call ``setup()`` once to validate the ONNX file and open an onnxruntime
    session, then call ``infer()`` once per image.
    """

    def __init__(self) -> None:
        self.ort_session = None
        # Must match the spatial size the model was exported with; only the
        # batch axis of the exported graph is dynamic.
        self.size: int = 384

        # Per-channel RGB normalisation statistics, applied after scaling
        # pixel values by 1/255.
        self.mean: list = [0.51684, 0.52503, 0.50567]
        self.std: list  = [0.19350, 0.18743, 0.19404]

        self.path: str = "wic-en4-f1.onnx"
        # Class index (as a string) -> class name.
        self.labels = {
            "0" : "dew",
            "1" : "fogsmog",
            "2" : "frost",
            "3" : "glaze",
            "4" : "hail",
            "5" : "lightning",
            "6" : "rain",
            "7" : "rainbow",
            "8" : "rime",
            "9" : "sandstorm",
            "10" : "snow",
        }
        # Severity 3 = errors only; this sets onnxruntime's default logger globally.
        ort.set_default_logger_severity(3)

    def setup(self, providers=("CPUExecutionProvider",)) -> None:
        """Validate the ONNX file at ``self.path`` and open an inference session.

        Args:
            providers: onnxruntime execution providers, in priority order.
                Passed explicitly because onnxruntime >= 1.9 raises when it is
                omitted on builds that ship several providers (e.g. GPU builds).
        """
        model = onnx.load(self.path)
        onnx.checker.check_model(model)
        self.ort_session = ort.InferenceSession(self.path, providers=list(providers))

    def infer(self, image: np.ndarray) -> str:
        """Classify one image and return its title-cased label.

        Args:
            image: (H, W, 3) RGB image, assumed to hold 0-255 pixel values (as
                produced by ``get_image``). It is resized to ``self.size``.

        Returns:
            The predicted class name, e.g. ``"Rain"``.

        Raises:
            RuntimeError: If ``setup()`` has not been called yet.
        """
        if self.ort_session is None:
            raise RuntimeError("OnnxModel.setup() must be called before infer().")

        scaled = image / 255
        resized = cv2.resize(src=scaled, dsize=(self.size, self.size), interpolation=cv2.INTER_AREA)
        batch = self._normalize(resized)
        feeds = {self.ort_session.get_inputs()[0].name: batch}
        return self._decode(self.ort_session.run(None, feeds))

    def _normalize(self, image: np.ndarray) -> np.ndarray:
        """Standardise an (H, W, C) float image per channel; return a (1, C, H, W) float32 batch."""
        mean = np.asarray(self.mean, dtype=np.float64)
        std = np.asarray(self.std, dtype=np.float64)
        chw = ((image - mean) / std).transpose(2, 0, 1)
        return np.expand_dims(chw, axis=0).astype("float32")

    def _decode(self, outputs: list) -> str:
        """Map raw session outputs to a title-cased label.

        Reads the first output's first batch row; ``infer`` always sends a
        batch of exactly one image.
        """
        logits = np.asarray(outputs[0])[0]
        return self.labels[str(int(np.argmax(logits)))].title()

    
# Build the classifier wrapper, then load and validate the exported ONNX file
# and open the onnxruntime session used by infer().
onnx_model = OnnxModel()
onnx_model.setup()

In [8]:
df = pd.read_csv("../input/wic-dataframe/data.csv")

# Spot-check the exported model: classify a reproducible random sample of
# images (drawn without replacement) and compare each prediction with the
# image's folder name, which serves as the ground-truth label.
N_SAMPLES = 10
SEED = 42
samples = df.sample(n=min(N_SAMPLES, len(df)), random_state=SEED)

breaker()
for folder_name, filename in zip(samples.folder_names, samples.filenames):
    image = get_image(f"../input/weather-dataset/dataset/{folder_name}/{filename}", cfg.size)

    label = onnx_model.infer(image)

    print(f"Actual     : {folder_name.title()}")
    print(f"Prediction : {label}")
    breaker()


**************************************************

Actual     : Dew
Predcition : Dew

**************************************************

Actual     : Sandstorm
Predcition : Sandstorm

**************************************************

Actual     : Sandstorm
Predcition : Sandstorm

**************************************************

Actual     : Hail
Predcition : Hail

**************************************************

Actual     : Rime
Predcition : Rime

**************************************************

Actual     : Glaze
Predcition : Glaze

**************************************************

Actual     : Rime
Predcition : Rime

**************************************************

Actual     : Frost
Predcition : Frost

**************************************************

Actual     : Rime
Predcition : Rime

**************************************************

Actual     : Frost
Predcition : Frost

**************************************************

