In [1]:
%%bash
# Install the model-zoo (timm), ONNX serialization (onnx) and ONNX inference
# (onnxruntime) packages used by the cells below; -q keeps pip output short.
# NOTE(review): versions are unpinned, so a re-run may pick up different
# releases than the ones this notebook was written against — consider pinning.
pip install timm -q
pip install onnx -q
pip install onnxruntime -q



In [2]:
import onnx
import timm
import torch
import random as r
import numpy as np
import onnxruntime as ort

from typing import Union
from IPython.display import clear_output

import warnings
# Silence every Python warning for the rest of the kernel session.
# NOTE(review): this also hides warnings emitted by torch.onnx.export (e.g.
# tracer or fallback warnings); consider narrowing the filter if an export
# misbehaves.
warnings.filterwarnings("ignore")

In [3]:
def breaker() -> None:
    """Print a 50-asterisk separator line surrounded by blank lines."""
    separator = "*" * 50
    print(f"\n{separator}\n")

In [4]:
class Model(torch.nn.Module):
    """Image classifier built on a timm backbone with a small linear head.

    The backbone is stored as ``self.model``, so state-dict keys carry a
    ``model.`` prefix (e.g. ``model.classifier.weight`` for EfficientNet);
    checkpoints loaded into this class must use the same layout.

    Args:
        model_name: timm architecture name. Defaults to ``"efficientnet_b4"``,
            the backbone used for the checkpoints exported in this notebook.
        out_features: width of the final linear head (1 = a single logit).
        pretrained: whether timm should fetch its own pretrained weights.
            Defaults to ``False`` because the notebook overwrites all weights
            from a checkpoint right after construction.
    """

    def __init__(
        self,
        model_name: str = "efficientnet_b4",
        out_features: int = 1,
        pretrained: bool = False,
    ):
        super().__init__()
        # Let timm build the head at the requested width instead of patching
        # `.classifier` afterwards: that attribute only exists on some
        # architectures, while `num_classes` is accepted by all timm models.
        self.model = timm.create_model(
            model_name=model_name,
            pretrained=pretrained,
            num_classes=out_features,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and head on a batch of images and return the raw logits."""
        return self.model(x)

    
    
class CFG:
    """Settings for one ONNX export run.

    Attributes:
        in_channels: number of image channels of the example input.
        size: height and width (square) of the example input.
        opset_version: ONNX opset passed to the exporter.
        path: checkpoint file to load, or None.
        dummy: random float tensor of shape (1, in_channels, size, size),
            created once at construction and used as the tracing input.
    """

    def __init__(
        self,
        in_channels: int = 3,
        size: int = 256,
        opset_version: int = 9,
        path: Union[str, None] = None,
    ):
        self.in_channels, self.size = in_channels, size
        # Single-sample example input; built eagerly, so it will not follow
        # later changes to `size` or `in_channels`.
        self.dummy = torch.randn(1, in_channels, size, size)
        self.opset_version, self.path = opset_version, path

In [5]:
cfg = CFG(
    in_channels=3, 
    size=224, 
    opset_version=13, 
    path=f"/kaggle/input/scv-en4-a224-e10/saves/ble_state_fold_4.pt"
)

model = Model()
model.load_state_dict(torch.load(cfg.path, map_location=torch.device("cpu"))["model_state_dict"])
model.eval()

clear_output()

torch.onnx.export(
    model=model, 
    args=cfg.dummy, 
    f=f"model-224.onnx", 
    input_names=["input"], 
    output_names=["output"], 
    opset_version=cfg.opset_version,
    export_params=True,
    training=torch.onnx.TrainingMode.EVAL,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
    dynamic_axes={
      "input"  : {0 : "batch_size"},
      "output" : {0 : "batch_size"},
    }
)

In [6]:
cfg = CFG(
    in_channels=3, 
    size=384, 
    opset_version=13, 
    path=f"/kaggle/input/scv-en4-a384-e10/saves/ble_state_fold_4.pt"
)

model = Model()
model.load_state_dict(torch.load(cfg.path, map_location=torch.device("cpu"))["model_state_dict"])
model.eval()

clear_output()

torch.onnx.export(
    model=model, 
    args=cfg.dummy, 
    f=f"model-384.onnx", 
    input_names=["input"], 
    output_names=["output"], 
    opset_version=cfg.opset_version,
    export_params=True,
    training=torch.onnx.TrainingMode.EVAL,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
    dynamic_axes={
      "input"  : {0 : "batch_size"},
      "output" : {0 : "batch_size"},
    }
)