In [1]:
%%bash
# Install timm quietly; the model cells below call timm.create_model("efficientnet_b4").
# NOTE(review): the version is not pinned, so a fresh run may pull a different timm release.
pip install timm -q



In [2]:
import os
import re
import cv2
import timm
import torch
import random as r
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from time import time
from typing import Union
from torch import nn, optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as DL
from torchvision import models, transforms

from IPython.display import clear_output

In [3]:
# Passed to CFG as `seed` when `cfg` is built.
# NOTE(review): no RNG is actually seeded with this value in the visible cells.
SEED: int = 42
# Passed to CFG as `size`, overriding CFG's default of 224.
# Presumably the square image side length in pixels — TODO confirm against the data pipeline.
SIZE: int = 384


def breaker(num: int=50, char: str="*") -> None:
    """Print a horizontal rule of ``num`` copies of ``char``.

    The rule is preceded by a newline and followed by one; ``print`` then adds
    its own trailing newline, so the rule is surrounded by blank lines.
    """
    rule = char * num
    # Joining with empty ends yields "\n" + rule + "\n" in one print call.
    print("\n".join(["", rule, ""]))

    
def get_image(path: str, size: int=224) -> np.ndarray:
    """Read an image file, convert it from BGR to RGB and resize it to a square.

    Args:
        path: Location of the image file handed to ``cv2.imread``.
        size: Side length of the returned image; the output is resized to
            ``(size, size)`` with ``cv2.INTER_AREA`` interpolation.

    Returns:
        The resized RGB image as returned by ``cv2.resize``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` could not load ``path``.
    """
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None instead of raising; without
    # this guard the error would surface later as an opaque OpenCV assertion.
    if image is None:
        raise FileNotFoundError(
            f"Could not read image (missing or undecodable file): {path!r}"
        )
    image = cv2.cvtColor(src=image, code=cv2.COLOR_BGR2RGB)
    return cv2.resize(src=image, dsize=(size, size), interpolation=cv2.INTER_AREA)


def show_image(
    image: np.ndarray, 
    cmap: str="gnuplot2", 
    title: Union[str, None]=None
) -> None:
    """Display ``image`` in a new matplotlib figure with the axes hidden.

    Args:
        image: Array handed to ``plt.imshow``.
        cmap: Colormap name forwarded to ``plt.imshow``.
        title: Optional figure title; drawn only when it is not ``None``.
    """
    plt.figure()
    plt.imshow(image, cmap=cmap)
    plt.axis("off")
    # The title argument used to be accepted and then silently dropped.
    if title is not None:
        plt.title(title)
    plt.show()

    
def get_model_size(model) -> float:
    """Return the memory taken by a model's parameters and buffers, in MiB.

    Each tensor contributes ``nelement() * element_size()`` bytes. Parameters
    are totalled first, then buffers, and the combined byte count is divided
    by ``1024**2``.
    """
    def _total_bytes(tensors) -> float:
        # Start from 0.0 so the running total is a float, as in a manual loop.
        return sum((t.nelement() * t.element_size() for t in tensors), 0.0)

    bytes_per_mib = 1024**2
    parameter_bytes = _total_bytes(model.parameters())
    buffer_bytes = _total_bytes(model.buffers())
    return (parameter_bytes + buffer_bytes) / bytes_per_mib

In [4]:
class CFG(object):
    """Run configuration: hyperparameters, compute device and image transforms.

    Args:
        seed: Stored as ``self.seed``; this class does not seed any RNG with it.
        size: Stored as ``self.size`` (presumably the square image side length
            in pixels — TODO confirm against the data-loading code).
        num_samples, n_splits, batch_size, epochs, early_stopping: Stored
            unchanged for use by code outside this class.
        lr, wd: Stored unchanged (presumably optimizer learning rate and
            weight decay — TODO confirm).
        max_lr, pct_start, steps_per_epoch, div_factor, final_div_factor:
            Stored unchanged. The names match ``OneCycleLR`` arguments, but no
            scheduler is built here — TODO confirm how they are consumed.

    Attributes:
        device: ``cuda`` when ``torch.cuda.is_available()``, otherwise ``cpu``.
        train_transform_1 .. train_transform_4: ``ToTensor`` -> ``Normalize``
            -> random affine -> random horizontal / vertical flips.
        valid_transform_1 .. valid_transform_4: ``ToTensor`` -> ``Normalize``.
        save_path: ``"saves"``; the directory is created on construction.
    """

    # (mean, std) pairs handed to ``transforms.Normalize``; entry ``i`` feeds
    # ``train_transform_{i+1}`` and ``valid_transform_{i+1}``.
    # Presumably per-dataset channel statistics — TODO confirm their origin.
    _NORMALIZE_STATS = (
        ((0.28106, 0.31696, 0.30282), (0.26783, 0.27980, 0.27595)),
        ((0.33171, 0.40140, 0.42093), (0.23583, 0.24294, 0.24042)),
        ((0.35717, 0.39790, 0.37881), (0.23648, 0.24145, 0.23653)),
        ((0.41605, 0.47578, 0.45165), (0.20690, 0.20804, 0.19967)),
    )

    def __init__(
        self, 
        seed: int = 42,
        size: int = 224,
        num_samples: int=10000,
        n_splits: int = 5,
        batch_size: int = 16,
        epochs: int = 25,
        early_stopping: int = 5,
        lr: float = 1e-4,
        wd: float = 0.0,
        max_lr: float = 1e-3,
        pct_start: float = 0.2,
        steps_per_epoch: int = 100,
        div_factor: float = 1e3, 
        final_div_factor: float = 1e3,
    ):
        self.seed = seed
        self.size = size
        self.num_samples = num_samples
        self.n_splits = n_splits
        self.batch_size = batch_size
        self.epochs = epochs
        self.early_stopping = early_stopping
        self.lr = lr
        self.wd = wd
        self.max_lr = max_lr
        self.pct_start = pct_start
        self.steps_per_epoch = steps_per_epoch
        self.div_factor = div_factor
        self.final_div_factor = final_div_factor
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # The four pipeline pairs differ only in their normalization statistics.
        stats_1, stats_2, stats_3, stats_4 = self._NORMALIZE_STATS
        self.train_transform_1, self.valid_transform_1 = self._build_transforms(*stats_1)
        self.train_transform_2, self.valid_transform_2 = self._build_transforms(*stats_2)
        self.train_transform_3, self.valid_transform_3 = self._build_transforms(*stats_3)
        self.train_transform_4, self.valid_transform_4 = self._build_transforms(*stats_4)

        self.save_path = "saves"
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.save_path, exist_ok=True)

    @staticmethod
    def _build_transforms(mean, std):
        """Build the ``(train, valid)`` pipeline pair for one mean/std setting."""
        # Each Normalize receives its own list so no pipeline shares mutable state.
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(list(mean), list(std)),
            transforms.RandomAffine(degrees=(-45, 45), translate=(0.15, 0.15), scale=(0.5, 1.5)),
            transforms.RandomHorizontalFlip(p=0.25),
            transforms.RandomVerticalFlip(p=0.25),
        ])
        valid_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(list(mean), list(std)),
        ])
        return train_transform, valid_transform
    
# Notebook-wide configuration object; constructing it also creates the "saves" directory.
cfg = CFG(seed=SEED, size=SIZE)

In [5]:
class Model(nn.Module):
    """Autoencoder: an EfficientNet-B4 encoder followed by a transposed-conv decoder.

    The encoder is every child of ``timm``'s ``efficientnet_b4`` except the last
    one (presumably the classification head — confirm against timm). The decoder
    stacks seven ``ConvTranspose2d -> ReLU -> Upsample(x2)`` stages registered as
    ``DC<i>`` / ``AN<i>`` / ``UP<i>``; these names appear in checkpoint keys.
    """

    def __init__(self):
        super().__init__()

        backbone = timm.create_model("efficientnet_b4", pretrained=False)
        self.encoder = nn.Sequential(*list(backbone.children())[:-1])

        # (in_channels, out_channels, kernel) per stage; stride and padding are
        # always 1. On a 1x1 input the first stage (kernel 5) yields 3x3, the
        # kernel-3 stages keep the size, and the seven x2 upsamplings give 384x384.
        stage_specs = (
            (1792, 512, 5),
            (512, 256, 3),
            (256, 128, 3),
            (128, 64, 3),
            (64, 3, 3),
            (3, 3, 3),
            (3, 3, 3),
        )

        self.decoder = nn.Sequential()
        for stage, (c_in, c_out, kernel) in enumerate(stage_specs, start=1):
            deconv = nn.ConvTranspose2d(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=(kernel, kernel),
                stride=(1, 1),
                padding=(1, 1),
            )
            self.decoder.add_module(f"DC{stage}", deconv)
            self.decoder.add_module(f"AN{stage}", nn.ReLU())
            self.decoder.add_module(f"UP{stage}", nn.Upsample(scale_factor=2))

    def freeze(self):
        """Turn off gradient tracking for every parameter of the model."""
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, x):
        """Return ``(encoded, decoded)`` for the input batch ``x``."""
        encoded = self.encoder(x)
        # Append two singleton axes so the decoder's 2-D transposed convolutions
        # see a (batch, channels, 1, 1) map.
        latent_map = encoded.unsqueeze(2).unsqueeze(3)
        decoded = self.decoder(latent_map)
        return encoded, decoded

In [6]:
class EncoderModel(nn.Module):
    """Encoder-only network: ``efficientnet_b4`` from timm minus its last child.

    The kept layers are wrapped in ``nn.Sequential`` under the attribute
    ``encoder``, so state-dict keys take the form ``encoder.<index>...``.
    """

    def __init__(self):
        super().__init__()

        backbone = timm.create_model("efficientnet_b4", pretrained=False)
        # Drop the final child (presumably the classification head — confirm
        # against timm) and chain the remaining children.
        kept_layers = list(backbone.children())[:-1]
        self.encoder = nn.Sequential(*kept_layers)

    def forward(self, x):
        """Run ``x`` through the encoder and return its output unchanged."""
        features = self.encoder(x)
        return features

### **I1T1**

In [7]:
# Rebuild the full autoencoder and restore its weights from the
# fds-en4-ae384-i1t1-s42 checkpoint.
model = Model()
model.load_state_dict(
    torch.load(
        "../input/fds-en4-ae384-i1t1-s42/saves/state.pt",
        map_location=cfg.device,
    )["model_state_dict"]
)
model.eval()

breaker()
print(f"Full Model Size    : {get_model_size(model):.2f} MB")

# Keep only the entries whose key begins with "encoder" (case-insensitive),
# i.e. the weights EncoderModel can load.
encoder_model_state_dict = {
    names: params
    for names, params in model.state_dict().items()
    if re.match(r"encoder", names, re.IGNORECASE)
}

encoder_model = EncoderModel()
encoder_model.load_state_dict(encoder_model_state_dict)
encoder_model.eval()

breaker()
print(f"Encoder Model Size : {get_model_size(encoder_model):.2f} MB")

breaker()


**************************************************

Full Model Size    : 160.84 MB

**************************************************

Encoder Model Size : 67.42 MB

**************************************************



### **I2T1**

In [8]:
# Rebuild the full autoencoder and restore its weights from the
# fds-en4-ae384-i2t1-s42 checkpoint.
model = Model()
model.load_state_dict(
    torch.load(
        "../input/fds-en4-ae384-i2t1-s42/saves/state.pt",
        map_location=cfg.device,
    )["model_state_dict"]
)
model.eval()

breaker()
print(f"Full Model Size    : {get_model_size(model):.2f} MB")

# Keep only the entries whose key begins with "encoder" (case-insensitive),
# i.e. the weights EncoderModel can load.
encoder_model_state_dict = {
    names: params
    for names, params in model.state_dict().items()
    if re.match(r"encoder", names, re.IGNORECASE)
}

encoder_model = EncoderModel()
encoder_model.load_state_dict(encoder_model_state_dict)
encoder_model.eval()

breaker()
print(f"Encoder Model Size : {get_model_size(encoder_model):.2f} MB")

breaker()


**************************************************

Full Model Size    : 160.84 MB

**************************************************

Encoder Model Size : 67.42 MB

**************************************************



### **I3T1**

In [9]:
# Rebuild the full autoencoder and restore its weights from the
# fds-en4-ae384-i3t1-s42 checkpoint.
model = Model()
model.load_state_dict(
    torch.load(
        "../input/fds-en4-ae384-i3t1-s42/saves/state.pt",
        map_location=cfg.device,
    )["model_state_dict"]
)
model.eval()

breaker()
print(f"Full Model Size    : {get_model_size(model):.2f} MB")

# Keep only the entries whose key begins with "encoder" (case-insensitive),
# i.e. the weights EncoderModel can load.
encoder_model_state_dict = {
    names: params
    for names, params in model.state_dict().items()
    if re.match(r"encoder", names, re.IGNORECASE)
}

encoder_model = EncoderModel()
encoder_model.load_state_dict(encoder_model_state_dict)
encoder_model.eval()

breaker()
print(f"Encoder Model Size : {get_model_size(encoder_model):.2f} MB")

breaker()


**************************************************

Full Model Size    : 160.84 MB

**************************************************

Encoder Model Size : 67.42 MB

**************************************************



### **I4T1**

In [10]:
# Rebuild the full autoencoder and restore its weights from the
# fds-en4-ae384-i4t1-s42 checkpoint.
model = Model()
model.load_state_dict(
    torch.load(
        "../input/fds-en4-ae384-i4t1-s42/saves/state.pt",
        map_location=cfg.device,
    )["model_state_dict"]
)
model.eval()

breaker()
print(f"Full Model Size    : {get_model_size(model):.2f} MB")

# Keep only the entries whose key begins with "encoder" (case-insensitive),
# i.e. the weights EncoderModel can load.
encoder_model_state_dict = {
    names: params
    for names, params in model.state_dict().items()
    if re.match(r"encoder", names, re.IGNORECASE)
}

encoder_model = EncoderModel()
encoder_model.load_state_dict(encoder_model_state_dict)
encoder_model.eval()

breaker()
print(f"Encoder Model Size : {get_model_size(encoder_model):.2f} MB")

breaker()


**************************************************

Full Model Size    : 160.84 MB

**************************************************

Encoder Model Size : 67.42 MB

**************************************************

