In [1]:
%%bash
# Install the two packages this notebook imports but installs itself:
#   onnx        - used below to load and validate the exported model files
#   onnxruntime - used below to run inference on the exported models
# -q keeps pip's output quiet.
pip install onnx -q
pip install onnxruntime -q



In [2]:
import os
import sys
import cv2
import json
import onnx
import torch
import random as r
import numpy as np
import pandas as pd
import onnxruntime as ort
import matplotlib.pyplot as plt

from torch import nn
from typing import Union
from torchvision import models
from IPython.display import clear_output

import warnings
# Suppress every Python warning globally from this point on in the notebook.
warnings.filterwarnings("ignore")

# Make sure the local "onnx" output directory exists. exist_ok=True makes the
# call idempotent on re-runs and avoids the check-then-create race.
os.makedirs("onnx", exist_ok=True)
    
# Set onnxruntime's default logger severity threshold to 3
# (NOTE(review): 3 is the ERROR level per the onnxruntime docs, i.e. info and
# warning messages are hidden - confirm against the installed version).
ort.set_default_logger_severity(3)

In [3]:
# Lookup table from class index to uppercase letter: 0 -> "A", ..., 25 -> "Z".
labels: dict = {index: chr(ord("A") + index) for index in range(26)}
    
    
def breaker() -> None:
    """Print a separator: a row of 50 asterisks with a blank line above and below."""
    separator = "*" * 50
    print(f"\n{separator}\n")

    
def get_image(path: str, size: int=28) -> np.ndarray:
    """Load an image file as grayscale and resize it to a square model input.

    Args:
        path: Location of the image file on disk.
        size: Side length, in pixels, of the square output image.

    Returns:
        Array of shape (1, size, size): the resized grayscale image with a
        leading axis of length 1 added.

    Raises:
        FileNotFoundError: If OpenCV cannot read an image from ``path``.
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising; without
    # this check the failure would only surface as an opaque cv2.resize error.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path!r}")

    resized = cv2.resize(src=image, dsize=(size, size), interpolation=cv2.INTER_CUBIC)
    return resized.reshape(1, size, size)

In [4]:
class Model(nn.Module):
    """Small CNN classifier with 26 output classes.

    ``features`` is three Conv(3x3, padding 1) -> BatchNorm -> ReLU -> MaxPool(2x2)
    stages; ``classifier`` is a stack of fully connected layers ending in
    ``LogSoftmax``, so ``forward`` returns log-probabilities.

    Sub-module names (``CN1``/``BN1``/..., ``FC1``/``DP1``/``AN1``/...) are part
    of the checkpoint format: they determine the ``state_dict`` keys, so they
    must not be renamed.

    Args:
        filter_sizes: Output channel counts of the three conv stages; indices
            0, 1 and 2 are used.
        HL: Widths of the hidden fully connected layers. An empty list gives a
            single linear layer straight to the 26 outputs. Any number of
            hidden layers is supported.
        DP: Dropout probability. A ``Dropout`` layer is inserted after each
            hidden linear layer only when this is a ``float``.
    """

    def __init__(self, filter_sizes: list, HL: list, DP: Union[float, None]=None):

        super(Model, self).__init__()

        # --- Convolutional feature extractor (single-channel input) ---
        self.features = nn.Sequential()
        in_channels = 1
        for stage in range(1, 4):
            out_channels = filter_sizes[stage - 1]
            self.features.add_module(f"CN{stage}", nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
            self.features.add_module(f"BN{stage}", nn.BatchNorm2d(num_features=out_channels, eps=1e-5))
            self.features.add_module(f"AN{stage}", nn.ReLU())
            self.features.add_module(f"MP{stage}", nn.MaxPool2d(kernel_size=(2, 2)))
            in_channels = out_channels

        # --- Fully connected classifier ---
        # The 3*3 factor is the spatial size left after three 2x2 poolings of a
        # 28x28 input (28 -> 14 -> 7 -> 3); other input sizes need a different value.
        self.classifier = nn.Sequential()
        in_features = filter_sizes[2] * 3 * 3
        for layer, width in enumerate(HL, start=1):
            self.classifier.add_module(f"FC{layer}", nn.Linear(in_features=in_features, out_features=width))
            if isinstance(DP, float):
                self.classifier.add_module(f"DP{layer}", nn.Dropout(p=DP))
            self.classifier.add_module(f"AN{layer}", nn.ReLU())
            in_features = width
        # Output layer: always numbered one past the last hidden layer.
        self.classifier.add_module(f"FC{len(HL) + 1}", nn.Linear(in_features=in_features, out_features=26))
        self.classifier.add_module("Final Activation", nn.LogSoftmax(dim=1))

    def forward(self, x):
        """Return log-probabilities of shape (batch, 26) for input batch ``x``."""
        x = self.features(x)
        # Flatten everything except the batch dimension for the linear layers.
        x = x.view(x.shape[0], -1)
        return self.classifier(x)

In [5]:
class CFG(object):  
    """Settings for one ONNX export run.

    Attributes set by the constructor:
        in_channels: Channel count of the dummy input.
        size: Height and width of the dummy input.
        opset_version: ONNX opset version stored for the export call.
        path: Checkpoint path stored as given (may be None).
        dummy: Random tensor of shape (1, in_channels, size, size).
    """
    def __init__(
        self, 
        in_channels: int=1, 
        size: int=28, 
        opset_version: int=9, 
        path: str=None
    ):
        self.path = path
        self.opset_version = opset_version
        self.size = size
        self.in_channels = in_channels

        # Single-sample batch drawn from a standard normal distribution.
        dummy_shape = (1, in_channels, size, size)
        self.dummy = torch.randn(*dummy_shape)

In [6]:
# Export every saved checkpoint (two variants x five folds) to an ONNX file
# under the local "onnx" directory.
# NOTE(review): the "a" / "l" variants presumably stand for best-accuracy and
# best-loss epochs (file names are b{a,l}e_*) - confirm with the training run.
for variant in ["a", "l"]:
    for fold in range(1, 6):
        cfg = CFG(
            opset_version=13, 
            path=f"/kaggle/input/azhwd-f-64-128-256-h1024-e10/saves/b{variant}e_state_fold_{fold}.pt"
        )

        # The architecture must match the one the checkpoint was trained with.
        model = Model(filter_sizes=[64, 128, 256], HL=[1024])
        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code - only load checkpoints from a trusted source.
        checkpoint = torch.load(cfg.path, map_location=torch.device("cpu"))
        model.load_state_dict(checkpoint["model_state_dict"])
        # Inference mode: BatchNorm uses its running statistics.
        model.eval()

        clear_output()

        torch.onnx.export(
            model=model, 
            args=cfg.dummy, 
            f=f"onnx/b{variant}e_model_f{fold}.onnx", 
            input_names=["input"], 
            output_names=["output"], 
            opset_version=cfg.opset_version,
            export_params=True,
            training=torch.onnx.TrainingMode.EVAL,
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
            # Leave the batch dimension symbolic so the exported model accepts
            # any batch size, not just the dummy input's batch of 1.
            dynamic_axes={
              "input"  : {0 : "batch_size"},
              "output" : {0 : "batch_size"},
            }
        )

In [7]:
class OnnxModel(object):
    """Wrapper around an ONNX Runtime session for single-image letter prediction.

    Attributes:
        size: Expected image side length (28). Informational only: ``infer``
            does not resize its input.
        mean, std: Scalars used to normalise pixel values after scaling to [0, 1].
            NOTE(review): presumably the statistics used at training time - confirm.
        path: Path of the loaded ONNX file.
        ort_session: The ``onnxruntime.InferenceSession`` for ``path``.
    """

    def __init__(self, path: str) -> None:
        """Validate the ONNX file at ``path`` and open an inference session on it."""
        self.size: int = 28

        self.mean: float = 0.15344
        self.std: float  = 0.19509

        self.path: str = path

        # Fail early on a malformed model file before creating the session.
        onnx.checker.check_model(onnx.load(self.path))
        self.ort_session = ort.InferenceSession(self.path)

    def infer(self, image: np.ndarray, labels: dict) -> str:
        """Predict the label of one image.

        Args:
            image: Pixel array with values on a 0-255 scale. A leading batch
                axis is added here, so pass a single un-batched image
                (the callers in this notebook pass shape (1, 28, 28)).
            labels: Mapping from class index to label string.

        Returns:
            The title-cased label of the highest-scoring class.
        """
        scaled = image / 255
        normalized = (scaled - self.mean) / self.std
        batch = np.expand_dims(normalized, axis=0).astype("float32")

        input_name = self.ort_session.get_inputs()[0].name
        # run() returns a list with one array per model output; use the first.
        outputs = self.ort_session.run(None, {input_name: batch})
        class_index = int(np.argmax(outputs[0]))
        return labels[class_index].title()

In [8]:
# Spot-check each exported model on 9 random rows of the CSV dataset,
# printing "true label, predicted label" for each.
df = pd.read_csv("/kaggle/input/az-handwritten-alphabets-in-csv-format/A_Z Handwritten Data.csv")

# Column 0 holds the class index; the remaining columns are reshaped into
# 1x28x28 images. astype() already returns a fresh array, so no explicit copy
# of the (large) pixel frame is needed first.
images  = df.iloc[:, 1:].values.reshape(df.shape[0], 1, 28, 28).astype("uint8")
y_trues = df.iloc[:, 0].copy().values

breaker()
for model_name in sorted(os.listdir("onnx")):
    
    onnx_model = OnnxModel(f"onnx/{model_name}")
    
    print(f"{model_name.upper()}\n")
    
    # Rows are drawn with the unseeded `random` module, so the sampled rows
    # differ from run to run.
    for _ in range(9):
        index = r.randint(0, len(df)-1)

        y_pred = onnx_model.infer(images[index], labels)
        
        print(f"{labels[y_trues[index]].title()}, {y_pred}")

    breaker()


**************************************************

BAE_MODEL_F1.ONNX

R, R
N, N
O, O
T, T
L, L
X, X
U, U
O, O
A, A

**************************************************

BAE_MODEL_F2.ONNX

S, S
P, P
C, C
W, W
L, L
O, O
B, B
T, T
Y, Y

**************************************************

BAE_MODEL_F3.ONNX

V, V
A, A
P, P
Y, Y
A, A
K, K
U, U
U, U
M, M

**************************************************

BAE_MODEL_F4.ONNX

N, N
V, V
A, A
N, N
O, O
R, R
R, R
M, M
C, C

**************************************************

BAE_MODEL_F5.ONNX

U, U
C, C
U, U
C, C
S, S
S, S
O, O
T, T
Z, Z

**************************************************

BLE_MODEL_F1.ONNX

T, T
O, O
S, S
P, P
L, L
N, N
O, O
N, N
O, O

**************************************************

BLE_MODEL_F2.ONNX

S, S
W, W
M, M
A, A
N, N
P, P
D, D
R, R
S, S

**************************************************

BLE_MODEL_F3.ONNX

E, E
L, L
A, A
O, O
O, O
N, N
T, T
S, S
Z, Z

**************************************************

BLE_MODEL_F

In [9]:
# Evaluate each exported model on the hand-made test images. The expected label
# of an image is the first character of its file name; files whose name
# contains "All" are skipped.
test_dir = "/kaggle/input/azhwdtest"
filepaths = sorted([os.path.join(test_dir, name) for name in os.listdir(test_dir) if "All" not in name])

# Fail with a clear message instead of a ZeroDivisionError in the accuracy line.
if not filepaths:
    raise FileNotFoundError(f"No test images found in {test_dir!r}")

breaker()
for model_name in sorted(os.listdir("onnx")):
    
    onnx_model = OnnxModel(f"onnx/{model_name}")
    
    print(f"{model_name.upper()}\n")
    
    count: int = 0
    for filepath in filepaths:
        # basename() is separator-agnostic, unlike splitting on "/".
        y_true = os.path.basename(filepath)[0]

        image = get_image(filepath)
        y_pred = onnx_model.infer(image, labels)
        
        print(f"{y_true}, {y_pred}")
        
        if y_true == y_pred:
            count += 1
   
    print(f"\nAccuracy : {count / len(filepaths):.5f}")
    breaker()


**************************************************

BAE_MODEL_F1.ONNX

A, A
B, B
C, D
D, D
E, E
F, E
G, G
H, H
I, I
J, J
K, K
L, L
M, M
N, N
O, O
P, P
Q, Q
R, R
S, S
T, T
U, U
V, M
W, W
X, X
Y, Y
Z, Z

Accuracy : 0.88462

**************************************************

BAE_MODEL_F2.ONNX

A, A
B, B
C, D
D, D
E, E
F, F
G, Q
H, H
I, I
J, J
K, K
L, L
M, M
N, N
O, O
P, P
Q, Q
R, R
S, S
T, T
U, U
V, M
W, M
X, X
Y, Y
Z, Z

Accuracy : 0.84615

**************************************************

BAE_MODEL_F3.ONNX

A, A
B, B
C, D
D, D
E, E
F, A
G, Q
H, H
I, I
J, J
K, K
L, L
M, A
N, N
O, O
P, P
Q, Q
R, R
S, S
T, T
U, U
V, H
W, M
X, X
Y, Y
Z, Z

Accuracy : 0.76923

**************************************************

BAE_MODEL_F4.ONNX

A, A
B, B
C, D
D, D
E, E
F, R
G, Q
H, H
I, I
J, J
K, K
L, L
M, A
N, N
O, O
P, R
Q, Q
R, R
S, S
T, T
U, U
V, M
W, A
X, X
Y, Y
Z, Z

Accuracy : 0.73077

**************************************************

BAE_MODEL_F5.ONNX

A, A
B, B
C, O
D, D
E, E
F, A
G, Q
H, H
