# Library

In [None]:
import os
import copy
import time
import random
import logging
import logging.handlers

import cv2
import glob
import ttach as tta
import timm
import tqdm
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit

import albumentations as A
from albumentations.pytorch import ToTensorV2

from ml_decoder import *
from cutmix import cutmix
from loss import *

# Configs

In [None]:
# Global experiment settings read by the cells below.
CFG = dict(
    # Use the GPU when PyTorch can see one, otherwise run on the CPU.
    device="cuda" if torch.cuda.is_available() else "cpu",
    # timm model identifier. Other candidates kept for reference:
    #   "caformer_b36.sail_in22k_ft_in1k_384"
    #   "tf_efficientnetv2_xl.in21k_ft_in1k"
    model_name="eva02_large_patch14_448.mim_in22k_ft_in22k_in1k",
    # Seed handed to seed_everything() and to the data splitter.
    seed=42,
    # Upper bound on epochs per fold (early stopping may end sooner).
    num_epochs=50,
    # Number of splits produced by the splitter.
    skf_n_splits=5,
    # Learning rate given to the optimizer.
    lr=1e-5,
    # Consecutive non-improving epochs tolerated before stopping a fold.
    early_stop_count=5,
    # DataLoader batch size and worker-process count.
    batch_size=8,
    num_workers=8,
    # Image size the transforms resize to.
    imgsz=(448, 448),
)

# Log

In [None]:
# Configure the root logger so every logger.info(...) call in this notebook is
# written to a size-rotated log file named after the model.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# File handlers keep the absolute path of their target in `baseFilename`,
# so the comparison below has to use the absolute path as well.
log_path = os.path.abspath(f"./{CFG['model_name']}.log")

# Re-running this cell must not attach a second handler for the same file;
# otherwise every message is written once per execution of the cell.
log_fileHandler = next(
    (handler for handler in logger.handlers
     if isinstance(handler, logging.handlers.RotatingFileHandler)
     and handler.baseFilename == log_path),
    None,
)

if log_fileHandler is None:
    log_fileHandler = logging.handlers.RotatingFileHandler(
        filename=f"./{CFG['model_name']}.log",
        maxBytes=1024000,
        backupCount=3,
        mode='a')
    logger.addHandler(log_fileHandler)

# (Re)apply the formatter so a reused handler also picks up the current format.
log_fileHandler.setFormatter(formatter)

# Fix Seed

In [None]:
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: default, current CUDA device, then all CUDA devices
    # (the last one matters when several GPUs are used).
    for seed_fn in (torch.manual_seed,
                    torch.cuda.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Seed all RNGs once with the configured seed before transforms, splits and
# models are created in the cells below.
seed_everything(CFG['seed'])

# Augmentations

In [None]:
# Resize target taken from the shared config; both pipelines resize to it.
img_size = CFG['imgsz']
# Per-channel normalisation statistics. These are the commonly used ImageNet
# mean/std values (conventionally listed in RGB order) -- TODO confirm they
# match the preprocessing expected by the chosen pretrained backbone.
mean=(0.485, 0.456, 0.406)
std=(0.229, 0.224, 0.225)

# Training pipeline: resize, augmentations, then normalisation and tensor
# conversion. Augmentations without an explicit `p` use the library default.
train_transform = A.Compose([
    A.Resize(img_size[0], img_size[1], p=1.0),
    A.Rotate(limit=30),
    A.HorizontalFlip(),
    A.RandomBrightnessContrast(),
    # NOTE(review): A.Cutout is deprecated in newer albumentations releases in
    # favour of CoarseDropout -- confirm against the installed version.
    A.Cutout(num_holes=30, max_h_size=20, max_w_size=20),
    A.Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
    ToTensorV2(p=1.0),
], p=1.0)

# Validation pipeline: the same resize / normalise / to-tensor steps with no
# augmentation in between.
val_transform = A.Compose([
    A.Resize(img_size[0], img_size[1], p=1.0),
    A.Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
    ToTensorV2(p=1.0),
], p=1.0)

# Define Categories(Class)

In [None]:
# Load the training annotations; later cells read the 'img_path' and 'label'
# columns of this frame.
train_df = pd.read_csv('./data/train.csv')
# Notebook display of the first rows (has no effect outside an interactive cell).
train_df.head()

In [None]:
# Number of distinct values in the 'label' column (missing values, if any, are
# counted as well); used as the size of the classification head.
CFG['num_classes'] = train_df['label'].nunique(dropna=False)

In [None]:
# Map every distinct label to an integer id, numbered in the order the labels
# are returned by unique().
categories = {label: index
              for index, label in enumerate(train_df['label'].unique())}

categories

# Datasets

In [None]:
class CustomDataset(Dataset):
    """Image-classification dataset backed by a DataFrame.

    Every row must provide ``img_path`` (a path relative to ``./data``) and
    ``label``. Labels are converted to integer ids through the module-level
    ``categories`` mapping.

    Args:
        df: DataFrame with ``img_path`` and ``label`` columns.
        transform: Optional callable invoked as ``transform(image=array)``
            that returns a mapping with an ``'image'`` entry.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(image, class_id)`` for the row at position ``idx``.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        # Fetch the row once instead of once per column.
        row = self.df.iloc[idx]
        full_path = os.path.join('./data', row['img_path'])

        # cv2.imread does not raise on failure; it returns None. Fail here
        # with the offending path instead of deep inside the transform.
        image = cv2.imread(full_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {full_path}")

        # OpenCV decodes to BGR, while the normalisation statistics and the
        # pretrained backbone follow the RGB convention.
        # NOTE: inference code must load images in the same channel order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image=image)['image']

        return image, categories[row['label']]

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.df)

# Custom model with timm

In [None]:
# Use ML-Decoder
class CustomModel(nn.Module):
    """Pretrained timm backbone followed by an ML-Decoder classification head.

    Args:
        model_name: Identifier passed to ``timm.create_model``; pretrained
            weights are requested.
    """

    def __init__(self, model_name):
        super().__init__()
        device = CFG['device']

        backbone = timm.create_model(model_name, pretrained=True)
        self.model = backbone.to(device)
        # NOTE(review): `conv_head` looks like an attribute of EfficientNet-style
        # timm models; on backbones without it this only registers an unused
        # Identity submodule -- confirm for the configured model.
        self.model.conv_head = nn.Identity()

        # The head consumes the backbone's features; its input width is the
        # backbone's `embed_dim`.
        head = MLDecoder(num_classes=CFG['num_classes'],
                         decoder_embedding=768,
                         initial_num_features=self.model.embed_dim)
        self.ml_decoder_head = head.to(device)

    def forward(self, x):
        """Run the backbone's feature extractor, then the ML-Decoder head."""
        features = self.model.forward_features(x)
        return self.ml_decoder_head(features)

# Train (Stratified Shuffle Split)

In [None]:
def _format_duration(total_seconds):
    """Format a duration in seconds as '<h>h <m>m <s.ssss>s' with whole hours/minutes."""
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{int(hours)}h {int(minutes)}m {seconds:.4f}s"


def _snapshot_weights(module):
    """Return a detached CPU copy of ``module.state_dict()``.

    Keeping the snapshot on the CPU avoids holding a second copy of the
    weights in GPU memory, which would also survive ``del model``.
    """
    return {name: tensor.detach().cpu().clone()
            for name, tensor in module.state_dict().items()}


# Wall-clock timer covering everything this cell trains.
whole_start = time.time()

# Stratified random splits; each split holds out 15% of the rows for validation.
sss = StratifiedShuffleSplit(n_splits=CFG['skf_n_splits'],
                             random_state=CFG['seed'],
                             test_size=0.15)

# Alternative splitter kept for reference:
# skf = StratifiedKFold(n_splits=CFG['skf_n_splits'],
#                       random_state=CFG['seed'],
#                       shuffle=True)

for fold_idx, (train_idx, val_idx) in enumerate(sss.split(train_df, train_df['label'])):
    # Only split 3 is trained by this cell: splits 0-2 are skipped here and
    # the loop breaks after the first fold that runs (end of the loop body).
    if fold_idx < 3:
        continue

    print(f'----- Fold {fold_idx} -----')
    logger.info(f'----- Fold {fold_idx} -----')

    # Create Dataset(fold)
    # split() yields positional indices, so rows are selected by position;
    # label-based .loc only agrees when the index is a default RangeIndex.
    X_train_fold = train_df.iloc[train_idx]
    X_val_fold = train_df.iloc[val_idx]

    train_dataset = CustomDataset(X_train_fold, transform=train_transform)
    val_dataset = CustomDataset(X_val_fold, transform=val_transform)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=CFG['batch_size'],
                                  num_workers=CFG['num_workers'],
                                  shuffle=True)

    val_dataloader = DataLoader(val_dataset,
                                batch_size=CFG['batch_size'],
                                num_workers=CFG['num_workers'],
                                shuffle=False)

    model = CustomModel(CFG['model_name'])
    # model = timm.create_model(CFG['model_name'], pretrained=True, num_classes=CFG['num_classes']).to(CFG['device'])
    model = nn.DataParallel(model)

    criterion1 = DiceLoss(mode='multiclass')
    criterion2 = LabelSmoothingLoss(CFG['num_classes'], smoothing=0.3)

    optimizer = optim.AdamW(model.parameters(), lr=CFG['lr'])

    best_score = 0.0
    early_stop_check = 0
    best_epoch = None
    best_model_weights = _snapshot_weights(model)

    start = time.time()
    for epoch in range(CFG['num_epochs']):
        # Train
        model.train()
        train_loss = []

        for inputs, labels in tqdm.tqdm(train_dataloader, leave=True):
            inputs = inputs.to(CFG['device'])
            labels = labels.to(CFG['device'])

            optimizer.zero_grad()

            if np.random.rand() >= 0.5:
                # CutMix batch: the loss is the lam-weighted mix of the
                # label-smoothing loss against both label sets.
                inputs, labels_a, labels_b, lam = cutmix(inputs, labels)
                outputs = model(inputs)
                loss = criterion2(outputs, labels_a) * lam + criterion2(outputs, labels_b) * (1 - lam)

            else:
                # Plain batch: weighted sum of Dice and label-smoothing losses.
                outputs = model(inputs)
                loss = 0.25 * criterion1(outputs, labels) + 0.75 * criterion2(outputs, labels)

            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())

        # Evaluation
        model.eval()
        val_preds = []
        targets = []
        with torch.no_grad():
            for inputs, labels in tqdm.tqdm(val_dataloader, leave=True):
                inputs = inputs.to(CFG['device'])

                outputs = model(inputs)
                preds = torch.argmax(outputs, dim=-1)

                val_preds += preds.cpu().tolist()
                # Labels are only needed on the host for the metric.
                targets += labels.tolist()

        val_score = f1_score(targets, val_preds, average='macro')

        epoch_msg = (f"Epoch [{epoch+1}/{CFG['num_epochs']}] "
                     f"Loss: {np.mean(train_loss):.6f} Val F1 Score: {val_score:.6f}")
        print(epoch_msg)
        logger.info(epoch_msg + "\n")

        # save best model
        if val_score > best_score:
            best_score = val_score

            os.makedirs('./checkpoints', exist_ok=True)

            best_model_weights = _snapshot_weights(model)
            # The state dict comes from the nn.DataParallel wrapper, so the
            # code that loads this checkpoint must expect the same layout.
            torch.save(model.state_dict(), f'./checkpoints/{CFG["model_name"]}_fold{fold_idx}_best.pth')
            print("Save Best Weights...")

            early_stop_check = 0
            best_epoch = epoch + 1

        else:
            early_stop_check += 1

        if early_stop_check == CFG['early_stop_count']:
            print("Early Stopping...")
            break

    elapsed_time = time.time() - start

    fold_summary = [
        f"  - {fold_idx} fold's best epoch is '{best_epoch}'",
        f"  - {fold_idx} fold's best score is {best_score:.6f}",
        f"  - {fold_idx} fold's Training Time is {_format_duration(elapsed_time)}",
    ]

    for line in fold_summary:
        print(line)
    print()

    for line in fold_summary:
        logger.info(line)

    # Memory release
    del model
    torch.cuda.empty_cache()

    # Stop after the first fold that was actually trained.
    break

# end training
whole_end = time.time()
whole_elapsed_time = whole_end - whole_start

logger.info(f"  - Whole Training Time is {_format_duration(whole_elapsed_time)}")

# To restore the best weights, rebuild the (DataParallel-wrapped) model first;
# `model` has been deleted above.
# model.load_state_dict(best_model_weights)