### **Setup**

In [1]:
%%bash
pip install timm -q



### **Library Imports**

In [2]:
import os
import re
import cv2
import timm
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy import stats
from torch import nn, optim
from torchvision import models, transforms

### **Helpers**

In [3]:
LABELS: dict = {
    0 : "dew",
    1 : "fogsmog",
    2 : "frost",
    3 : "glaze",
    4 : "hail",
    5 : "lightning",
    6 : "rain",
    7 : "rainbow",
    8 : "rime",
    9 : "sandstorm",
    10 : "snow",
}
IMAGE_BASE_PATH: str = "../input/wictestimages"
MODEL_BASE_PATH: str = "../input/wicen4a384models/saves"


def breaker(num: int=50, char: str="*") -> None:
    print("\n" + num*char + "\n")

    
def get_image(path: str, size: int=224) -> np.ndarray:
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    image = cv2.cvtColor(src=image, code=cv2.COLOR_BGR2RGB)
    return cv2.resize(src=image, dsize=(size, size), interpolation=cv2.INTER_AREA)

### **Configuration**

In [4]:
class CFG(object):
    def __init__(
         self, 
         seed: int = 42,
         size: int = 224,
         n_splits: int = 5,
         batch_size: int = 16,
         epochs: int = 25,
         early_stopping: int = 5,
         lr: float = 1e-4,
         wd: float = 0.0,
         max_lr: float = 1e-3,
         pct_start: float = 0.2,
         steps_per_epoch: int = 100,
         div_factor: int = 1e3, 
         final_div_factor: float = 1e3,
     ):
        self.seed = seed
        self.size = size
        self.n_splits = n_splits
        self.batch_size = batch_size
        self.epochs = epochs
        self.early_stopping = early_stopping
        self.lr = lr
        self.wd = wd
        self.max_lr = max_lr
        self.pct_start = pct_start
        self.steps_per_epoch = steps_per_epoch
        self.div_factor = div_factor
        self.final_div_factor = final_div_factor
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.51684, 0.52503, 0.50567], [0.19350, 0.18743, 0.19404]),
            transforms.RandomAffine(degrees=(-45, 45), translate=(0.15, 0.15), scale=(0.5, 1.5)),
            transforms.RandomHorizontalFlip(p=0.25),
            transforms.RandomVerticalFlip(p=0.25),
        ])
        self.valid_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.51684, 0.52503, 0.50567], [0.19350, 0.18743, 0.19404]),
        ])
                                
        self.save_path = "saves"
        if not os.path.exists(self.save_path): os.makedirs(self.save_path)
    
cfg = CFG(
    seed=42, 
    size=384
)

### **Model**

In [5]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
            
        self.model = timm.create_model(model_name="efficientnet_b4", pretrained=False)
        self.model.classifier = nn.Linear(in_features=self.model.classifier.in_features, out_features=11)

    def forward(self, x):
        return nn.LogSoftmax(dim=1)(self.model(x))

### **Average Final Predictions**

In [6]:
image_filenames: list = sorted(os.listdir(IMAGE_BASE_PATH))
model_filenames: list = os.listdir(MODEL_BASE_PATH)
model_filenames = model_filenames[:1]
label_indexes: list = []

breaker()
for image_filename in image_filenames:
    image = get_image(os.path.join(IMAGE_BASE_PATH, image_filename), cfg.size)
    
    label_index: float = 0.0
    for model_filename in model_filenames:
        model = Model().to(cfg.device)
        model.load_state_dict(torch.load(f=os.path.join(MODEL_BASE_PATH, model_filename), map_location=cfg.device)["model_state_dict"])
        model.eval()
        with torch.no_grad(): output = torch.argmax(model(cfg.valid_transform(image).to(cfg.device).unsqueeze(dim=0)), dim=1)
        label_index += output.item()
    
    print(f"Predicted: {LABELS[int(label_index / len(model_filenames))].title()}")
    print(f"Actual   : {image_filename[:-4]}")
    breaker()


**************************************************

Predicted: Fogsmog
Actual   : Fog

**************************************************

Predicted: Frost
Actual   : Frost

**************************************************

Predicted: Rime
Actual   : Glaze

**************************************************

Predicted: Hail
Actual   : Hail

**************************************************

Predicted: Lightning
Actual   : Lightning

**************************************************

Predicted: Rain
Actual   : Rain

**************************************************

Predicted: Rainbow
Actual   : Rainbow

**************************************************

Predicted: Fogsmog
Actual   : Smog

**************************************************

Predicted: Snow
Actual   : Snow

**************************************************



### **Mode Final Predictions**

In [7]:
image_filenames: list = sorted(os.listdir(IMAGE_BASE_PATH))
model_filenames: list = os.listdir(MODEL_BASE_PATH)
model_filenames = model_filenames[:1]
label_indexes: list = []

breaker()
for image_filename in image_filenames:
    image = get_image(os.path.join(IMAGE_BASE_PATH, image_filename), cfg.size)
    
    label_indexes: list = []
    for model_filename in model_filenames:
        model = Model().to(cfg.device)
        model.load_state_dict(torch.load(f=os.path.join(MODEL_BASE_PATH, model_filename), map_location=cfg.device)["model_state_dict"])
        model.eval()
        with torch.no_grad(): output = torch.argmax(model(cfg.valid_transform(image).to(cfg.device).unsqueeze(dim=0)), dim=1)
        label_indexes.append(output.item())
    
    label_index = stats.mode(np.array(label_indexes), axis=0)[0]
    print(f"Predicted: {LABELS[int(label_index)].title()}")
    print(f"Actual   : {image_filename[:-4]}")
    breaker()


**************************************************

Predicted: Fogsmog
Actual   : Fog

**************************************************

Predicted: Frost
Actual   : Frost

**************************************************

Predicted: Rime
Actual   : Glaze

**************************************************

Predicted: Hail
Actual   : Hail

**************************************************

Predicted: Lightning
Actual   : Lightning

**************************************************

Predicted: Rain
Actual   : Rain

**************************************************

Predicted: Rainbow
Actual   : Rainbow

**************************************************

Predicted: Fogsmog
Actual   : Smog

**************************************************

Predicted: Snow
Actual   : Snow

**************************************************



### **Summing State Dicts**

In [8]:
for model_filename in model_filenames:
    model = Model().to(cfg.device)
    model.load_state_dict(torch.load(f=os.path.join(MODEL_BASE_PATH, model_filename), map_location=cfg.device)["model_state_dict"])
    model.eval()
    
    sum_state_dict: dict = dict()
    for names, params in model.state_dict().items():
        if names in sum_state_dict.keys():
            sum_state_dict[names] += params
        else:
            sum_state_dict[names] = params

In [9]:
image_filenames: list = sorted(os.listdir(IMAGE_BASE_PATH))
model_filenames: list = os.listdir(MODEL_BASE_PATH)
model_filenames = model_filenames[:1]
label_indexes: list = []
    
model = Model().to(cfg.device)
model.load_state_dict(sum_state_dict)

breaker()
for image_filename in image_filenames:
    image = get_image(os.path.join(IMAGE_BASE_PATH, image_filename), cfg.size)
    with torch.no_grad(): label_index = torch.argmax(model(cfg.valid_transform(image).to(cfg.device).unsqueeze(dim=0)), dim=1)
    print(f"Predicted: {LABELS[int(label_index.item())].title()}")
    print(f"Actual   : {image_filename[:-4]}")
    breaker()


**************************************************

Predicted: Rime
Actual   : Fog

**************************************************

Predicted: Snow
Actual   : Frost

**************************************************

Predicted: Snow
Actual   : Glaze

**************************************************

Predicted: Snow
Actual   : Hail

**************************************************

Predicted: Snow
Actual   : Lightning

**************************************************

Predicted: Snow
Actual   : Rain

**************************************************

Predicted: Snow
Actual   : Rainbow

**************************************************

Predicted: Snow
Actual   : Smog

**************************************************

Predicted: Snow
Actual   : Snow

**************************************************



### **Averaging State Dicts**

In [10]:
for model_filename in model_filenames:
    model = Model().to(cfg.device)
    model.load_state_dict(torch.load(f=os.path.join(MODEL_BASE_PATH, model_filename), map_location=cfg.device)["model_state_dict"])
    model.eval()
    
    avg_state_dict: dict = dict()
    for names, params in model.state_dict().items():
        if names in avg_state_dict.keys():
            avg_state_dict[names] += params
        else:
            avg_state_dict[names] = params
        
for k, v in avg_state_dict.items(): avg_state_dict[k] = v / len(model_filenames)

In [11]:
image_filenames: list = sorted(os.listdir(IMAGE_BASE_PATH))
model_filenames: list = os.listdir(MODEL_BASE_PATH)
model_filenames = model_filenames[:1]
label_indexes: list = []
    
model = Model().to(cfg.device)
model.load_state_dict(avg_state_dict)

breaker()
for image_filename in image_filenames:
    image = get_image(os.path.join(IMAGE_BASE_PATH, image_filename), cfg.size)
    with torch.no_grad(): label_index = torch.argmax(model(cfg.valid_transform(image).to(cfg.device).unsqueeze(dim=0)), dim=1)
    print(f"Predicted: {LABELS[int(label_index.item())].title()}")
    print(f"Actual   : {image_filename[:-4]}")
    breaker()


**************************************************

Predicted: Rime
Actual   : Fog

**************************************************

Predicted: Snow
Actual   : Frost

**************************************************

Predicted: Snow
Actual   : Glaze

**************************************************

Predicted: Snow
Actual   : Hail

**************************************************

Predicted: Snow
Actual   : Lightning

**************************************************

Predicted: Snow
Actual   : Rain

**************************************************

Predicted: Snow
Actual   : Rainbow

**************************************************

Predicted: Snow
Actual   : Smog

**************************************************

Predicted: Snow
Actual   : Snow

**************************************************

