In [1]:
%%bash
# Install the packages this notebook imports: timm (EfficientNet backbone),
# onnx (model loading/checking) and onnxruntime (inference on the exported graph).
# NOTE(review): versions are unpinned — TODO pin them for reproducible runs.
pip install timm -q
pip install onnx -q
pip install onnxruntime -q



In [2]:
import os
import re
import timm
import onnx
import torch
import numpy as np
import onnxruntime as ort

from torch import nn

In [3]:
def breaker(num: int = 50, char: str = "*") -> None:
    """Print a horizontal rule made of ``num`` copies of ``char``, surrounded by blank lines."""
    rule = char * num
    print(f"\n{rule}\n")

In [4]:
class CFG:
    """Settings for exporting one checkpoint to ONNX.

    Attributes:
        in_channels: channel count of the dummy example input.
        size: height and width of the dummy example input.
        dummy: random tensor of shape (1, in_channels, size, size); the export
            cells pass it as the example input to ``torch.onnx.export``.
        opset_version: ONNX opset to target.
        path: checkpoint file to load (``None`` when unset).
    """

    def __init__(self,
                 in_channels: int = 3,
                 size: int = 384,
                 opset_version: int = 9,
                 path: str = None):
        self.in_channels, self.size = in_channels, size
        self.dummy = torch.randn(1, in_channels, size, size)
        self.opset_version = opset_version
        self.path = path

In [5]:
class Model(nn.Module):
    """EfficientNet-B4 autoencoder.

    The encoder is a timm ``efficientnet_b4`` with its last child module removed.
    The decoder is seven stages; stage ``i`` is registered as ``DC<i>``
    (ConvTranspose2d), ``AN<i>`` (ReLU) and ``UP<i>`` (2x Upsample), in that order.
    The module names matter: saved checkpoints are keyed by them.
    ``forward`` returns ``(encoded, decoded)``.
    """

    def __init__(self):
        super().__init__()

        backbone = timm.create_model("efficientnet_b4", pretrained=False)
        self.encoder = nn.Sequential(*list(backbone.children())[:-1])

        # (in_channels, out_channels, kernel) per decoder stage; every stage uses
        # stride 1 and padding 1.
        stage_specs = [
            (1792, 512, 5),
            (512, 256, 3),
            (256, 128, 3),
            (128, 64, 3),
            (64, 3, 3),
            (3, 3, 3),
            (3, 3, 3),
        ]
        self.decoder = nn.Sequential()
        for stage, (c_in, c_out, k) in enumerate(stage_specs, start=1):
            self.decoder.add_module(
                f"DC{stage}",
                nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out, kernel_size=(k, k), stride=(1, 1), padding=(1, 1)),
            )
            self.decoder.add_module(f"AN{stage}", nn.ReLU())
            self.decoder.add_module(f"UP{stage}", nn.Upsample(scale_factor=2))

    def forward(self, x):
        encoded = self.encoder(x)
        # Insert two singleton axes at positions 2 and 3 so the decoder sees a 1x1
        # "image". Presumably encoded is (B, 1792) after global pooling — confirm.
        decoded = self.decoder(encoded[:, :, None, None])
        return encoded, decoded

In [6]:
class EncoderModel(nn.Module):
    """Encoder-only counterpart of ``Model``.

    It builds the same truncated ``efficientnet_b4`` under the attribute
    ``encoder``, so its ``encoder.*`` parameter names match those of ``Model``.
    """

    def __init__(self):
        super().__init__()
        backbone = timm.create_model("efficientnet_b4", pretrained=False)
        # Keep every child except the last one (presumably the classification head).
        feature_layers = list(backbone.children())[:-1]
        self.encoder = nn.Sequential(*feature_layers)

    def forward(self, x):
        features = self.encoder(x)
        return features

### **I1T1**: export the encoder of the `fds-en4-ae384-i1t1-s42` checkpoint to ONNX

In [7]:
# Load the full autoencoder checkpoint (strict load validates it against the whole
# architecture), copy its encoder weights into EncoderModel, export that to ONNX,
# then check the exported graph both structurally and numerically.
cfg = CFG(
    in_channels=3, 
    size=384, 
    opset_version=15, 
    path="../input/fds-en4-ae384-i1t1-s42/saves/state.pt",
)
onnx_path = "I1T1-EN384AE.onnx"

model = Model()
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load(cfg.path, map_location=torch.device("cpu"))["model_state_dict"])
model.eval()

# Model and EncoderModel both keep the backbone under `self.encoder`, so the
# "encoder.*" keys transfer without renaming.
encoder_model_state_dict = {
    name: tensor
    for name, tensor in model.state_dict().items()
    if name.startswith("encoder")
}

encoder_model = EncoderModel()
encoder_model.load_state_dict(encoder_model_state_dict)
encoder_model.eval()

torch.onnx.export(
    model=encoder_model, 
    args=cfg.dummy, 
    f=onnx_path, 
    input_names=["input"], 
    output_names=["output"], 
    opset_version=cfg.opset_version,
    export_params=True,
    training=torch.onnx.TrainingMode.EVAL,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
    dynamic_axes={
      "input"  : {0 : "batch_size"},
      "output" : {0 : "batch_size"},
    }
)

onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

# check_model only validates graph structure; also confirm the ONNX graph
# reproduces the PyTorch encoder's output on the export's dummy input.
with torch.no_grad():
    torch_out = encoder_model(cfg.dummy).numpy()
onnx_out = ort_session.run(None, {"input": cfg.dummy.numpy()})[0]
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-4)

### **I2T1**: export the encoder of the `fds-en4-ae384-i2t1-s42` checkpoint to ONNX

In [8]:
# Load the full autoencoder checkpoint (strict load validates it against the whole
# architecture), copy its encoder weights into EncoderModel, export that to ONNX,
# then check the exported graph both structurally and numerically.
cfg = CFG(
    in_channels=3, 
    size=384, 
    opset_version=15, 
    path="../input/fds-en4-ae384-i2t1-s42/saves/state.pt",
)
onnx_path = "I2T1-EN384AE.onnx"

model = Model()
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load(cfg.path, map_location=torch.device("cpu"))["model_state_dict"])
model.eval()

# Model and EncoderModel both keep the backbone under `self.encoder`, so the
# "encoder.*" keys transfer without renaming.
encoder_model_state_dict = {
    name: tensor
    for name, tensor in model.state_dict().items()
    if name.startswith("encoder")
}

encoder_model = EncoderModel()
encoder_model.load_state_dict(encoder_model_state_dict)
encoder_model.eval()

torch.onnx.export(
    model=encoder_model, 
    args=cfg.dummy, 
    f=onnx_path, 
    input_names=["input"], 
    output_names=["output"], 
    opset_version=cfg.opset_version,
    export_params=True,
    training=torch.onnx.TrainingMode.EVAL,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
    dynamic_axes={
      "input"  : {0 : "batch_size"},
      "output" : {0 : "batch_size"},
    }
)

onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

# check_model only validates graph structure; also confirm the ONNX graph
# reproduces the PyTorch encoder's output on the export's dummy input.
with torch.no_grad():
    torch_out = encoder_model(cfg.dummy).numpy()
onnx_out = ort_session.run(None, {"input": cfg.dummy.numpy()})[0]
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-4)

### **I3T1**: export the encoder of the `fds-en4-ae384-i3t1-s42` checkpoint to ONNX

In [9]:
# Load the full autoencoder checkpoint (strict load validates it against the whole
# architecture), copy its encoder weights into EncoderModel, export that to ONNX,
# then check the exported graph both structurally and numerically.
cfg = CFG(
    in_channels=3, 
    size=384, 
    opset_version=15, 
    path="../input/fds-en4-ae384-i3t1-s42/saves/state.pt",
)
onnx_path = "I3T1-EN384AE.onnx"

model = Model()
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load(cfg.path, map_location=torch.device("cpu"))["model_state_dict"])
model.eval()

# Model and EncoderModel both keep the backbone under `self.encoder`, so the
# "encoder.*" keys transfer without renaming.
encoder_model_state_dict = {
    name: tensor
    for name, tensor in model.state_dict().items()
    if name.startswith("encoder")
}

encoder_model = EncoderModel()
encoder_model.load_state_dict(encoder_model_state_dict)
encoder_model.eval()

torch.onnx.export(
    model=encoder_model, 
    args=cfg.dummy, 
    f=onnx_path, 
    input_names=["input"], 
    output_names=["output"], 
    opset_version=cfg.opset_version,
    export_params=True,
    training=torch.onnx.TrainingMode.EVAL,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
    dynamic_axes={
      "input"  : {0 : "batch_size"},
      "output" : {0 : "batch_size"},
    }
)

onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

# check_model only validates graph structure; also confirm the ONNX graph
# reproduces the PyTorch encoder's output on the export's dummy input.
with torch.no_grad():
    torch_out = encoder_model(cfg.dummy).numpy()
onnx_out = ort_session.run(None, {"input": cfg.dummy.numpy()})[0]
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-4)

### **I4T1**: export the encoder of the `fds-en4-ae384-i4t1-s42` checkpoint to ONNX

In [10]:
# Load the full autoencoder checkpoint (strict load validates it against the whole
# architecture), copy its encoder weights into EncoderModel, export that to ONNX,
# then check the exported graph both structurally and numerically.
cfg = CFG(
    in_channels=3, 
    size=384, 
    opset_version=15, 
    path="../input/fds-en4-ae384-i4t1-s42/saves/state.pt",
)
onnx_path = "I4T1-EN384AE.onnx"

model = Model()
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load(cfg.path, map_location=torch.device("cpu"))["model_state_dict"])
model.eval()

# Model and EncoderModel both keep the backbone under `self.encoder`, so the
# "encoder.*" keys transfer without renaming.
encoder_model_state_dict = {
    name: tensor
    for name, tensor in model.state_dict().items()
    if name.startswith("encoder")
}

encoder_model = EncoderModel()
encoder_model.load_state_dict(encoder_model_state_dict)
encoder_model.eval()

torch.onnx.export(
    model=encoder_model, 
    args=cfg.dummy, 
    f=onnx_path, 
    input_names=["input"], 
    output_names=["output"], 
    opset_version=cfg.opset_version,
    export_params=True,
    training=torch.onnx.TrainingMode.EVAL,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
    dynamic_axes={
      "input"  : {0 : "batch_size"},
      "output" : {0 : "batch_size"},
    }
)

onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

# check_model only validates graph structure; also confirm the ONNX graph
# reproduces the PyTorch encoder's output on the export's dummy input.
with torch.no_grad():
    torch_out = encoder_model(cfg.dummy).numpy()
onnx_out = ort_session.run(None, {"input": cfg.dummy.numpy()})[0]
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-4)