In [1]:
%%bash
pip install onnx -q
pip install onnxruntime -q



In [2]:
import os
import sys
import cv2
import json
import onnx
import torch
import random as r
import numpy as np
import pandas as pd
import onnxruntime as ort
import matplotlib.pyplot as plt

from torch import nn
from typing import Union
from torchvision import models
from IPython.display import clear_output

import warnings
warnings.filterwarnings("ignore")

if not os.path.exists("onnx"): os.makedirs("onnx")
    
ort.set_default_logger_severity(3)

In [3]:
labels: dict = {}
for i in range(26): labels[i] = chr(i + 65)
    
def breaker() -> None:
    print("\n" + 50*"*" + "\n")

In [4]:
class Model(nn.Module):
    def __init__(self, filter_sizes: list, HL: list, DP: Union[float, None]=None):
        
        super(Model, self).__init__()
        
        self.features = nn.Sequential()
        self.features.add_module("CN1", nn.Conv2d(in_channels=1, out_channels=filter_sizes[0], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
        self.features.add_module("BN1", nn.BatchNorm2d(num_features=filter_sizes[0], eps=1e-5))
        self.features.add_module("AN1", nn.ReLU())
        self.features.add_module("MP1", nn.MaxPool2d(kernel_size=(2, 2)))
        self.features.add_module("CN2", nn.Conv2d(in_channels=filter_sizes[0], out_channels=filter_sizes[1], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
        self.features.add_module("BN2", nn.BatchNorm2d(num_features=filter_sizes[1], eps=1e-5))
        self.features.add_module("AN2", nn.ReLU())
        self.features.add_module("MP2", nn.MaxPool2d(kernel_size=(2, 2)))
        self.features.add_module("CN3", nn.Conv2d(in_channels=filter_sizes[1], out_channels=filter_sizes[2], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
        self.features.add_module("BN3", nn.BatchNorm2d(num_features=filter_sizes[2], eps=1e-5))
        self.features.add_module("AN3", nn.ReLU())
        self.features.add_module("MP3", nn.MaxPool2d(kernel_size=(2, 2))) 
        
        self.classifier = nn.Sequential()
        if len(HL) == 0:
            self.classifier.add_module("FC1", nn.Linear(in_features=filter_sizes[2]*3*3, out_features=26))
        elif len(HL) == 1:
            self.classifier.add_module("FC1", nn.Linear(in_features=filter_sizes[2]*3*3, out_features=HL[0]))
            if isinstance(DP, float):
                self.classifier.add_module("DP1", nn.Dropout(p=DP))
            self.classifier.add_module("AN1", nn.ReLU())
            self.classifier.add_module("FC2", nn.Linear(in_features=HL[0], out_features=26))
        elif len(HL) == 2:
            self.classifier.add_module("FC1", nn.Linear(in_features=filter_sizes[2]*3*3, out_features=HL[0]))
            if isinstance(DP, float):
                self.classifier.add_module("DP1", nn.Dropout(p=DP))
            self.classifier.add_module("AN1", nn.ReLU())
            self.classifier.add_module("FC2", nn.Linear(in_features=HL[0], out_features=HL[1]))
            if isinstance(DP, float):
                self.classifier.add_module("DP2", nn.Dropout(p=DP))
            self.classifier.add_module("AN2", nn.ReLU())
            self.classifier.add_module("FC3", nn.Linear(in_features=HL[1], out_features=26))
        self.classifier.add_module("Final Activation", nn.LogSoftmax(dim=1))
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.shape[0], -1)
        return self.classifier(x)

In [5]:
class CFG(object):  
    def __init__(
        self, 
        in_channels: int=1, 
        size: int=28, 
        opset_version: int=9, 
        path: str=None
    ):
        self.in_channels = in_channels
        self.size = size
        self.dummy = torch.randn(1, self.in_channels, self.size, self.size)
        self.opset_version = opset_version
        self.path = path

In [6]:
for v in ["a", "l"]:
    for i in range(1, 6):
        cfg = CFG(
            opset_version=13, 
            path=f"/kaggle/input/azhwd-f-64-128-256-h1024-e10/saves/b{v}e_state_fold_{i}.pt"
        )

        model = Model(filter_sizes=[64, 128, 256], HL=[1024])
        model.load_state_dict(torch.load(cfg.path, map_location=torch.device("cpu"))["model_state_dict"])
        model.eval()

        clear_output()
        
        # param_size: float = 0
        # for param in model.parameters():
        #     param_size += param.nelement() * param.element_size()

        # buffer_size: float = 0
        # for buffer in model.buffers():
        #     buffer_size += buffer.nelement() * buffer.element_size()

        # size_all_mb: float = (param_size + buffer_size) / 1024**2

        # breaker()
        # print(f"Model size: {size_all_mb:.3f} MB")
        # breaker()
        
        torch.onnx.export(
            model=model, 
            args=cfg.dummy, 
            f=f"onnx/b{v}e_model_f{i}.onnx", 
            input_names=["input"], 
            output_names=["output"], 
            opset_version=cfg.opset_version,
            export_params=True,
            training=torch.onnx.TrainingMode.EVAL,
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
            dynamic_axes={
              "input"  : {0 : "batch_size"},
              "output" : {0 : "batch_size"},
            }
        )

In [7]:
class OnnxModel(object):
    def __init__(self, path: str) -> None:
        self.ort_session = None
        self.size: int = 28
        
        self.mean: list = 0.15344
        self.std: list  = 0.19509
        
        self.path: str = path
    
        model = onnx.load(self.path)
        onnx.checker.check_model(model)
        self.ort_session = ort.InferenceSession(self.path)
    
    def infer(self, image: np.ndarray, labels: dict) -> np.ndarray:
        h, w, _ = image.shape
        
        image = image / 255
        # image = cv2.resize(src=image, dsize=(self.size, self.size), interpolation=cv2.INTER_AREA)
        image = (image - self.mean) / self.std
        image = np.expand_dims(image, axis=0)
        inputs = {self.ort_session.get_inputs()[0].name : image.astype("float32")}
        return labels[np.argmax(self.ort_session.run(None, inputs))].title()

In [8]:
df = pd.read_csv("/kaggle/input/az-handwritten-alphabets-in-csv-format/A_Z Handwritten Data.csv")

images  = df.iloc[:, 1:].copy().values.reshape(df.shape[0], 1, 28, 28).astype("uint8")
y_trues = df.iloc[:, 0].copy().values

breaker()
for model_name in sorted(os.listdir("onnx")):
    
    onnx_model = OnnxModel(f"onnx/{model_name}")
    
    print(f"{model_name.upper()}\n")
    
    for i in range(9):
        index = r.randint(0, len(df)-1)

        y_pred = onnx_model.infer(images[index], labels)
        
        print(f"{labels[y_trues[index]].title()}, {y_pred}")

    breaker()


**************************************************

BAE_MODEL_F1.ONNX

O, O
P, P
U, U
W, W
S, S
O, O
R, R
V, V
S, S

**************************************************

BAE_MODEL_F2.ONNX

C, C
C, C
U, U
J, J
U, U
O, O
P, P
Z, Z
W, W

**************************************************

BAE_MODEL_F3.ONNX

M, M
N, N
W, W
S, S
P, P
N, N
U, N
N, N
Q, Q

**************************************************

BAE_MODEL_F4.ONNX

L, L
S, S
P, P
R, R
T, T
O, O
Y, Y
C, C
U, U

**************************************************

BAE_MODEL_F5.ONNX

C, C
C, C
T, T
O, O
C, C
B, B
S, S
S, S
Z, Z

**************************************************

BLE_MODEL_F1.ONNX

O, O
E, E
O, D
S, S
S, S
O, O
C, C
L, L
O, O

**************************************************

BLE_MODEL_F2.ONNX

E, E
R, R
W, W
U, U
E, E
S, S
R, R
A, A
K, K

**************************************************

BLE_MODEL_F3.ONNX

J, J
S, S
W, W
A, A
S, S
I, I
P, P
S, S
C, C

**************************************************

BLE_MODEL_F