# 1. Configuration

In [None]:
import os
import gc
import cv2
import copy
import random
import numpy as np
import pandas as pd
from pathlib import Path
import typing as tp
from io import BytesIO
from PIL import Image
import h5py
import timm
from time import time
from tqdm.notebook import tqdm
import shutil

import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedGroupKFold

import torch
import torchvision
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
from torch.cuda import amp
from torch.utils.data import Dataset
import torch.nn.functional as F
import albumentations as A
from albumentations.pytorch import ToTensorV2

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)


In [None]:
class CFG:
    """Experiment-wide settings, read everywhere as plain class attributes."""

    exp = 23

    # Run-mode switches: `is_kaggle` selects which path set is defined below,
    # `is_infer` gates the training-only cells of the notebook.
    is_kaggle = False
    is_infer = False

    if not is_kaggle:
        # Local workstation layout. OUTPUT_LOG and the *_ADD / *_COMBINED
        # paths are defined only in this branch.
        OUTPUT_DIR = Path(f'/root/Development/Kaggle/ISIC2024/main/models/experiments/exp{exp}/outputs')
        OUTPUT_LOG = Path(f'/root/Development/Kaggle/ISIC2024/main/models/experiments/exp{exp}/log.txt')
        TRAIN_DIR = Path('/root/Development/Kaggle/ISIC2024/data/raw/train-image/image')
        TRAIN_HDF5 = Path('/root/Development/Kaggle/ISIC2024/data/raw/train-image.hdf5')
        TEST_HDF5 = Path('/root/Development/Kaggle/ISIC2024/data/raw/test-image.hdf5')
        TRAIN_META = Path('/root/Development/Kaggle/ISIC2024/data/raw/train-metadata.csv')
        TEST_META = Path('/root/Development/Kaggle/ISIC2024/data/raw/test-metadata.csv')
        SAMPLE_SUB = Path('/root/Development/Kaggle/ISIC2024/data/raw/sample_submission.csv')
        TRAIN_META_ADD = Path('/root/Development/Kaggle/ISIC2024/data/external/All/metadata.csv')
        TRAIN_HDF5_ADD = Path('/root/Development/Kaggle/ISIC2024/data/external/All/image.hdf5')
        TRAIN_HDF5_COMBINED = Path('/root/Development/Kaggle/ISIC2024/data/processed/train-image-combined.hdf5')
        # Empty string disables loading of previously trained weights.
        PRETRAINED_MODEL = ''
        # PRETRAINED_MODEL = Path('/root/Development/Kaggle/ISIC2024/main/models/experiments/exp10/outputs/averaged_model.pth')
    else:
        # Kaggle kernel layout.
        OUTPUT_DIR = Path('/kaggle/input/isic2024-baseline')
        TRAIN_DIR = Path('/kaggle/input/isic-2024-challenge/train-image/image')
        TRAIN_HDF5 = Path('/kaggle/input/isic-2024-challenge/train-image.hdf5')
        TEST_HDF5 = Path('/kaggle/input/isic-2024-challenge/test-image.hdf5')
        TRAIN_META = Path('/kaggle/input/isic-2024-challenge/train-metadata.csv')
        TEST_META = Path('/kaggle/input/isic-2024-challenge/test-metadata.csv')
        SAMPLE_SUB = Path('/kaggle/input/isic-2024-challenge/sample_submission.csv')
        PRETRAINED_MODEL = ''

    # --- training schedule / cross-validation ---
    max_epoch = 9
    n_folds = 5
    n_classes = 2
    random_seed = 42
    es_patience = 5
    batch_size = 128

    # --- behaviour flags ---
    deterministic = True
    enable_amp = True
    view = True
    change_dataset = True
    standardization = True
    oversampling = False
    oversampling_percent = 3

    # --- backbone ---
    model_name = 'eva02_small_patch14_224.mim_in22k'
    # Feature width fed to the linear head, looked up by backbone name.
    output_dim_models = {
        "resnet18.a1_in1k": 512,
        "efficientnet_b0.ra_in1k": 320,
        "tf_efficientnet_b0.ns_jft_in1k": 1280,
        # 'tf_efficientnet_b5.ns_jft_in1k': 1280,
        'tf_efficientnet_b3.ns_jft_in1k': 1536,
        'eva02_small_patch14_224.mim_in22k': 384,
    }
    pretrained = True
    img_size = 224
    # img_size = 384
    interpolation = cv2.INTER_LINEAR

    # --- optimiser ---
    lr = 1.0e-03
    # lr = 1.0e-04
    weight_decay = 1.0e-02 # default

    device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:

# Load the competition metadata tables and report their row counts.
train_meta = pd.read_csv(CFG.TRAIN_META)
test_meta = pd.read_csv(CFG.TEST_META)
print(len(train_meta), len(test_meta), sep='\n')

# Preview: rows whose target exceeds 0.5 first, then the head of each table.
display(train_meta.loc[train_meta['target'] > 0.5].head(10))
display(train_meta.head())
display(test_meta.head())


# The external metadata is only read for training runs.
if not CFG.is_infer:
    train_meta_add = pd.read_csv(CFG.TRAIN_META_ADD)
    print(len(train_meta_add))
    display(train_meta_add.head())

# 2. Preprocessing

### process additional data

In [None]:
if not CFG.is_infer:

    # Keep only rows labelled benign or malignant. `.copy()` detaches the
    # filtered frame so the column assignments below do not write into a
    # slice of the previous frame (avoids SettingWithCopyWarning).
    train_meta_add = train_meta_add[
        train_meta_add['benign_malignant'].isin(['benign', 'malignant'])
    ].copy()
    # Binary target: 1 for malignant, 0 for benign.
    train_meta_add['target'] = train_meta_add['benign_malignant'].eq('malignant').astype(int)
    # Rows without a patient id get a placeholder derived from their position,
    # so each of them ends up with a distinct id.
    train_meta_add['patient_id'] = [
        f"example_{pos + 1}" if pd.isna(pid) else pid
        for pos, pid in enumerate(train_meta_add['patient_id'])
    ]

    # Flag marking rows that come from the external dataset.
    train_meta_add['additional'] = 1

    display(train_meta_add.head())


### process original data

In [None]:
if not CFG.is_infer:

    # Images showing clothes or other unwanted content; removed by isic_id.
    ids_to_drop = ['ISIC_0573025', 'ISIC_1443812', 'ISIC_5374420', 'ISIC_2611119', 'ISIC_2691718', 'ISIC_9689783', 'ISIC_9520696', 'ISIC_8651165', 'ISIC_9385142', 'ISIC_9680590']
    _rows_to_drop = train_meta.index[train_meta['isic_id'].isin(ids_to_drop)]
    train_meta = train_meta.drop(index=_rows_to_drop)

    # Keep only rows that carry a lesion_id (treated as the strongly labelled
    # subset) and renumber the index from zero.
    train_meta = train_meta.loc[train_meta['lesion_id'].notna()].reset_index(drop=True)
    # Flag marking rows that come from the original competition data.
    train_meta['additional'] = 0

    # Size of the remaining set and the share of positive targets.
    _n_rows = len(train_meta)
    _n_malignant = len(train_meta[train_meta['target'] == 1])
    print(_n_rows)
    print(_n_malignant)
    print(_n_malignant / _n_rows * 100, '%')
    display(train_meta.head())

### concat both data

In [None]:
if not CFG.is_infer:

    # Stack the original and the external metadata row-wise. The result is
    # stored under its own name; `train_meta` itself is not replaced because
    # the reassignment below is disabled.
    concatenated = pd.concat([train_meta, train_meta_add], axis='index')
    # train_meta = concatenated.reset_index()
    # display(train_meta.head())
    # display(train_meta.tail())

# 3. Split to fold

In [None]:
if not CFG.is_infer:


    def split_fold(df:pd.DataFrame):
        """Add a ``fold`` column using stratified, patient-grouped K-fold.

        Rows are stratified on ``target`` and grouped by ``patient_id`` so that
        one patient never appears in two folds. The frame is modified in place
        and also returned.

        Args:
            df: Metadata with ``target`` and ``patient_id`` columns.

        Returns:
            The same frame with an integer ``fold`` column in
            ``[0, CFG.n_folds)``.
        """
        skf = StratifiedGroupKFold(n_splits=CFG.n_folds, shuffle=True, random_state=CFG.random_seed)

        # `skf.split` yields POSITIONAL row numbers. Collect them in a plain
        # array and attach it as a column afterwards; writing them through the
        # label-based `.loc` is only correct when the index is 0..n-1.
        folds = np.full(len(df), -1, dtype=np.int64)
        for i, (_, test_index) in enumerate(skf.split(df, df['target'], df['patient_id'])):
            folds[test_index] = i
        df['fold'] = folds

        return df

    train_meta = split_fold(train_meta)
    display(train_meta.head().T)

    # Sanity check: class balance inside each fold.
    if CFG.view:
        print(train_meta.groupby('fold')['target'].value_counts().head(300))

### OverSampling

In [None]:
if not CFG.is_infer and CFG.oversampling:

    percent = CFG.oversampling_percent

    # Append one more copy of the malignant rows per round until their share
    # of the table reaches `percent` (checked after each append), giving up
    # after 100 rounds.
    train_meta_mali = train_meta.loc[train_meta['target'] == 1]
    for _ in range(100):
        train_meta = pd.concat([train_meta, train_meta_mali], axis=0)
        _share = len(train_meta[train_meta['target'] == 1]) / len(train_meta) * 100
        if _share >= percent:
            print(f'over {percent}% malignant')
            break

    # Report the resulting size and class share.
    _n_rows = len(train_meta)
    _n_malignant = len(train_meta[train_meta['target'] == 1])
    print(_n_rows)
    print(_n_malignant)
    print(_n_malignant / _n_rows * 100, '%')

    del percent

    if CFG.view:
        print(train_meta.groupby('fold')['target'].value_counts().head(300))
        display(train_meta.head())

# 4. Dataset

In [None]:
# setting seed in each env
def set_random_seed(seed: int = 42, deterministic: bool = False):
    """Seed the Python, NumPy and PyTorch RNGs and set the cuDNN determinism flag.

    Args:
        seed: Value handed to every seeding call and exported as
            ``PYTHONHASHSEED``.
        deterministic: Assigned to ``torch.backends.cudnn.deterministic``.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)  # type: ignore
    torch.backends.cudnn.deterministic = deterministic  # type: ignore

# function to set tensor to device
def to_device(
    tensors: tp.Union[tp.Tuple[torch.Tensor], tp.Dict[str, torch.Tensor]],
    device: torch.device, *args, **kwargs
):
    """Move a tensor, or a tuple/list/dict of tensors, to ``device``.

    Args:
        tensors: A single tensor, a tuple or list of tensors, or a dict whose
            values are tensors.
        device: Target device.
        *args, **kwargs: Forwarded unchanged to ``Tensor.to``.

    Returns:
        The moved tensor(s) in the same kind of container: tuple -> tuple,
        list -> list, dict -> dict with the same keys, tensor -> tensor.
    """
    if isinstance(tensors, tuple):
        # Build a real tuple: a generator expression here could be consumed
        # only once and supports neither len() nor indexing.
        return tuple(t.to(device, *args, **kwargs) for t in tensors)
    elif isinstance(tensors, list):
        return [t.to(device, *args, **kwargs) for t in tensors]
    elif isinstance(tensors, dict):
        return {
            k: t.to(device, *args, **kwargs) for k, t in tensors.items()}
    else:
        return tensors.to(device, *args, **kwargs)

In [None]:

def get_transforms(include_light=False):
    """Build the albumentations pipelines used by the datasets.

    No ``A.Normalize`` step is included in any pipeline; every pipeline ends
    with a resize to ``CFG.img_size`` followed by ``ToTensorV2``.

    Args:
        include_light: When true, also return the lighter training pipeline.

    Returns:
        ``(train, test)`` or, with ``include_light``, ``(train, test, train_light)``.
    """
    side = CFG.img_size

    def flips():
        # Geometric flips shared by both training pipelines.
        return [
            A.Transpose(p=0.5),
            A.VerticalFlip(p=0.5),
            A.HorizontalFlip(p=0.5),
        ]

    def one_blur(p):
        # Exactly one of three blur types, applied with probability `p`.
        return A.OneOf([
            A.MotionBlur(blur_limit=5, p=0.5),
            A.MedianBlur(blur_limit=5, p=0.5),
            A.GaussianBlur(blur_limit=5, p=0.5),
        ], p=p)

    def shift_scale_rotate(p):
        return A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=p)

    def tail():
        # Common ending: fixed-size resize, then conversion to a tensor.
        return [A.Resize(side, side), ToTensorV2(p=1)]

    augmentations_train = A.Compose([
        *flips(),
        A.ColorJitter(brightness=0.2, contrast=0.2, p=0.3),
        one_blur(0.2),
        A.GaussNoise(var_limit=(5.0, 30.0), p=0.1),
        shift_scale_rotate(0.3),
        A.CoarseDropout(max_holes=20, min_holes=10, p=0.3),
        *tail(),
    ])

    augmentations_train_light = A.Compose([
        *flips(),
        A.ColorJitter(brightness=0.2, contrast=0.2, p=0.1),
        one_blur(0.1),
        shift_scale_rotate(0.2),
        A.CoarseDropout(max_holes=30, min_holes=20, p=0.2),
        *tail(),
    ])

    augmentations_test = A.Compose(tail())

    pipelines = (augmentations_train, augmentations_test)
    if include_light:
        pipelines += (augmentations_train_light,)
    return pipelines

# def get_transforms():
#     data_transforms = {
#         "train": A.Compose([
#             A.Resize(CFG.img_size, CFG.img_size),
#             A.RandomRotate90(p=0.5),
#             A.Flip(p=0.5),
#             A.Downscale(p=0.25),
#             A.ShiftScaleRotate(shift_limit=0.1, 
#                             scale_limit=0.15, 
#                             rotate_limit=60, 
#                             p=0.5),
#             A.HueSaturationValue(
#                     hue_shift_limit=0.2, 
#                     sat_shift_limit=0.2, 
#                     val_shift_limit=0.2, 
#                     p=0.5
#                 ),
#             A.RandomBrightnessContrast(
#                     brightness_limit=(-0.1,0.1), 
#                     contrast_limit=(-0.1, 0.1), 
#                     p=0.5
#                 ),
#             A.Normalize(
#                     mean=[0.485, 0.456, 0.406], 
#                     std=[0.229, 0.224, 0.225], 
#                     max_pixel_value=255.0, 
#                     p=1.0
#                 ),
#             ToTensorV2()], p=1.),
        
#         "valid": A.Compose([
#             A.Resize(CFG.img_size, CFG.img_size),
#             A.Normalize(
#                     mean=[0.485, 0.456, 0.406], 
#                     std=[0.229, 0.224, 0.225], 
#                     max_pixel_value=255.0, 
#                     p=1.0
#                 ),
#             ToTensorV2()], p=1.)
#     }
#     return data_transforms["train"], data_transforms["valid"]

# def get_transforms():
#     train_transforms = A.Compose([
#         A.Transpose(p=0.5),
#         A.VerticalFlip(p=0.5),
#         A.HorizontalFlip(p=0.5),
#         A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.75),
#         A.OneOf([
#             A.MotionBlur(blur_limit=5),
#             A.MedianBlur(blur_limit=5),
#             A.GaussianBlur(blur_limit=5),
#             A.GaussNoise(var_limit=(5.0, 30.0)),
#         ], p=0.7),

#         A.OneOf([
#             A.OpticalDistortion(distort_limit=1.0),
#             A.GridDistortion(num_steps=5, distort_limit=1.),
#             A.ElasticTransform(alpha=3),
#         ], p=0.7),

#         A.CLAHE(clip_limit=4.0, p=0.7),
#         A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
#         A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85),
#         A.Resize(CFG.img_size, CFG.img_size),
#         A.CoarseDropout(max_height=int(CFG.img_size * 0.375), max_width=int(CFG.img_size * 0.375), max_holes=1, min_holes=1, p=0.7),    
#         A.Normalize()
#     ])

#     val_transforms = A.Compose([
#         A.Resize(CFG.img_size, CFG.img_size),
#         A.Normalize()
#     ])

#     return train_transforms, val_transforms


In [None]:

class ISICDataset(Dataset):
    """Image dataset backed by an HDF5 file of encoded images keyed by ``isic_id``.

    Each item is ``(image, target)``. ``target`` is the row's ``target`` value
    when the frame has such a column, otherwise an empty list.

    The HDF5 file is opened lazily on first access instead of in ``__init__``.
    When the dataset is handed to a multi-worker ``DataLoader`` each worker
    process then opens its own handle rather than sharing one inherited from
    the parent process.

    Args:
        df: Metadata with an ``isic_id`` column and optionally ``target``.
        fp_hdf: Path of the HDF5 file holding the encoded images.
        transform: Optional albumentations pipeline applied to each image.
    """

    def __init__(self, 
                df: pd.DataFrame,
                fp_hdf: str|Path,
                transform: A.Compose=None,
                ):
        self.df = df
        # Frames without a `target` column (e.g. test data) yield empty targets.
        self.is_training = 'target' in df.columns
        if self.is_training:
            self.targets = df['target'].values
        self._hdf_path = fp_hdf
        self._hdf_file = None
        self.isic_ids = df['isic_id'].values
        self.transform = transform

    @property
    def fp_hdf(self):
        """Read-only HDF5 handle, opened on first use in the current process."""
        if self._hdf_file is None:
            self._hdf_file = h5py.File(self._hdf_path, mode="r")
        return self._hdf_file

    def close(self):
        """Close the HDF5 handle if it is open; it reopens on the next access."""
        if self._hdf_file is not None:
            self._hdf_file.close()
            self._hdf_file = None

    def __len__(self):
        return len(self.isic_ids)
    
    def __getitem__(self, index):
        isic_id = self.isic_ids[index]
        encoded = self.fp_hdf[isic_id][()]
        # Force three channels: the model concatenates an HSV copy of the
        # input and therefore needs RGB, never grayscale or RGBA.
        image = np.array(Image.open(BytesIO(encoded)).convert("RGB"))
        target = self.targets[index] if self.is_training else []

        if self.transform:
            return (self._apply_transform(image), target)
        return (image, target)
    
    def _apply_transform(self, img:np.ndarray):
        """Run the albumentations pipeline and return its ``"image"`` entry."""
        transformed = self.transform(image=img)
        return transformed["image"]

# 5. Model

In [None]:

def F_rgb2hsv(rgb: torch.Tensor) -> torch.Tensor:
    """Convert a batch of RGB images to HSV, channel-wise along dim 1.

    Args:
        rgb: 4-D tensor whose dim 1 holds the R, G, B channels in that order.

    Returns:
        Tensor of the same shape with H, S, V along dim 1. Hue is divided by 6
        (so it lies in [0, 1)), saturation is ``delta / max`` (0 where the max
        is 0) and value is the per-pixel channel maximum.
    """
    red, green, blue = rgb[:, 0:1], rgb[:, 1:2], rgb[:, 2:3]

    value, top_channel = rgb.max(dim=1, keepdim=True)
    lowest = rgb.min(dim=1, keepdim=True).values
    chroma = value - lowest

    # Index 3 is a sentinel for pixels whose channels are all equal; their hue
    # is forced to 0 so the division by a zero chroma is never selected.
    top_channel[chroma == 0] = 3
    red_top, green_top, blue_top, achromatic = (top_channel == k for k in range(4))

    hue = torch.empty_like(rgb[:, 0:1, :, :])
    hue[red_top] = (((green - blue) / chroma) % 6)[red_top]
    hue[green_top] = (((blue - red) / chroma) + 2)[green_top]
    hue[blue_top] = (((red - green) / chroma) + 4)[blue_top]
    hue[achromatic] = 0.
    hue /= 6.

    zero = torch.tensor(0.).type_as(rgb)
    saturation = torch.where(value == 0, zero, chroma / value)

    return torch.cat([hue, saturation, value], dim=1)


class timmModel(nn.Module):
    """timm backbone followed by a single-logit head with multi-sample dropout.

    The forward pass concatenates an HSV copy of the input to the RGB input
    along the channel axis, so ``in_channels`` must count both (6 for RGB).

    Args:
        model_name: timm model identifier. It also decides whether the
            backbone output is pooled here (non-"eva" names) or used as
            returned (names containing "eva").
        pretrained: Passed to ``timm.create_model``.
        in_channels: Passed to ``timm.create_model`` as ``in_chans``.
        num_classes: Passed to ``timm.create_model``.
        is_training: When true, the logit is the mean over five dropout
            branches; otherwise the head is applied once without dropout.
    """

    def __init__(
            self,
            model_name: str,
            pretrained: bool,
            in_channels: int,
            num_classes: int,
            is_training: bool=False
        ):
        super().__init__()
        # Decide from the model actually being built, not from CFG.model_name,
        # so the class stays correct when another name is passed in.
        self.model_name = model_name
        self.is_eva = 'eva' in model_name

        backbone_kwargs = dict(
            model_name=model_name,
            pretrained=pretrained,
            in_chans=in_channels,
            num_classes=num_classes,
        )
        if not self.is_eva:
            # Pooling is disabled for non-eva backbones; forward() pools itself.
            backbone_kwargs['global_pool'] = ''
        self.model = timm.create_model(**backbone_kwargs)

        self.output_type = ['infer', 'loss']
        self.is_training = is_training

        # Width of the features fed to the head. The CFG table is used when it
        # lists the model; otherwise fall back to the backbone's own
        # `num_features` (NOTE(review): assumed to match the pooled width).
        dim = CFG.output_dim_models.get(model_name)
        if dim is None:
            dim = self.model.num_features

        self.dropout = nn.ModuleList([
            nn.Dropout(0.5) for i in range(5)
        ])
        self.target=nn.Linear(dim, 1)

    def forward(self, x, target=None):
        """Return a dict with ``'target'`` and/or ``'bce_loss'``.

        Args:
            x: Image batch; an HSV copy is appended along dim 1 before it is
                passed to the backbone.
            target: Labels for the BCE loss. When omitted, the loss entry is
                skipped even if ``'loss'`` is listed in ``self.output_type``.

        Returns:
            ``{'bce_loss': ...}`` if ``'loss'`` is requested and a target is
            given, plus ``{'target': sigmoid(logit)}`` if ``'infer'`` is
            requested.
        """
        batch_size = len(x)

        x = torch.cat([x, F_rgb2hsv(x)],1) # RGB + HSV channels
        features = self.model(x)

        if not self.is_eva:
            # Unpooled backbone output: global-average-pool and flatten it.
            features = F.adaptive_avg_pool2d(features, 1).reshape(batch_size, -1)

        if self.is_training:
            # Multi-sample dropout: average the head over several dropout masks.
            logit = 0
            for drop in self.dropout:
                logit += self.target(drop(features))
            logit = logit / len(self.dropout)
        else:
            logit = self.target(features)

        output = {}
        # Without a target there is nothing to compare against; skip the loss
        # instead of failing on `None.dim()`.
        if 'loss' in self.output_type and target is not None:
            if target.dim() == 1:
                target = target.view(-1, 1)
            output['bce_loss'] = F.binary_cross_entropy_with_logits(logit.float(), target.float())

        if 'infer' in self.output_type:
            output['target'] = torch.sigmoid(logit.float())

        return output

In [None]:
if not CFG.is_infer:


    if CFG.view:
        def show_batch(ds, row=3, col=3, color='rgb'):
            """Plot a ``row`` x ``col`` grid of randomly drawn dataset samples.

            Args:
                ds: Dataset returning ``(image, label)`` pairs.
                row: Number of grid rows.
                col: Number of grid columns.
                color: With ``'rgb'`` only the first three channels of a
                    tensor image are shown.
            """
            fig = plt.figure(figsize=(10, 10))
            # randint's upper bound is exclusive: use len(ds) so the last
            # sample can be drawn and a one-item dataset does not raise.
            img_index = np.random.randint(0, len(ds), row*col)

            for i, sample_idx in enumerate(img_index):
                img, label = ds[sample_idx]

                if isinstance(img, torch.Tensor):
                    # Tensor images are channel-first: keep the RGB channels,
                    # then move channels last for matplotlib.
                    if color == 'rgb':
                        img = img[:3, :, :]
                    img = img.detach().numpy()
                    if img.shape[0] == 3:
                        img = np.transpose(img, (1, 2, 0))
                # Non-tensor images (dataset without transform) are shown
                # as returned by the dataset.

                ax = fig.add_subplot(row, col, i + 1, xticks=[], yticks=[])
                ax.imshow(img)
                ax.set_title(f'ID: {sample_idx}; Target: {label}')

            plt.tight_layout()
            plt.show()

        _train_transform, _ = get_transforms()
        _dataset = ISICDataset(df=train_meta, fp_hdf=CFG.TRAIN_HDF5, transform=_train_transform)
        show_batch(_dataset)
        # show_batch(_dataset, color='hsv')


# 6. Training

In [None]:
if not CFG.is_infer:


    def train_one_fold(val_fold: int, 
                    train: pd.DataFrame,
                    output_path: str|Path
                    ):
        """Train on every fold except ``val_fold`` and validate on ``val_fold``.

        Each time the mean validation BCE loss improves, the weights are saved
        as ``snapshot_epoch_{epoch}.pth`` inside ``output_path``. Training
        stops early once more than ``CFG.es_patience`` epochs pass without
        improvement.

        Args:
            val_fold: Value of the ``fold`` column held out for validation.
            train: Metadata with ``isic_id``, ``target`` and ``fold`` columns.
            output_path: Directory receiving the snapshots. It must already
                exist; it is not created here.

        Returns:
            Tuple ``(val_fold, best_epoch, best_val_loss)``.
        """
        # Accept plain strings too: the snapshot path is built with `/`.
        output_path = Path(output_path)

        # NOTE(review): cuDNN benchmark mode is enabled even when
        # CFG.deterministic is set; confirm that combination is intended.
        torch.backends.cudnn.benchmark = True
        set_random_seed(CFG.random_seed, deterministic=CFG.deterministic)
        device = torch.device(CFG.device)

        def scale_pixels(x):
            # With CFG.standardization the whole batch is min-max rescaled to
            # [0, 255]; otherwise it is divided by 255.
            # NOTE(review): the two branches give different value ranges.
            if CFG.standardization:
                return (x - x.min()) / (x.max() - x.min() +1e-6) * 255
            return x.float()/255

        train_transform, val_transform = get_transforms()

        train_rows = train[train['fold']!=val_fold]
        val_rows = train[train['fold']==val_fold]
        train_dataset = ISICDataset(train_rows, CFG.TRAIN_HDF5, train_transform)
        # Same training rows, but with the validation (no-augmentation) pipeline.
        train_dataset_noaugment = ISICDataset(train_rows, CFG.TRAIN_HDF5, val_transform)
        val_dataset = ISICDataset(val_rows, CFG.TRAIN_HDF5, val_transform)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=CFG.batch_size, num_workers=4, shuffle=True, drop_last=True)
        train_loader_noaugment = torch.utils.data.DataLoader(
            train_dataset_noaugment, batch_size=CFG.batch_size, num_workers=4, shuffle=True, drop_last=True)
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=CFG.batch_size, num_workers=4, shuffle=False, drop_last=False)

        model = timmModel(
            model_name=CFG.model_name, 
            pretrained=CFG.pretrained, 
            num_classes=0, # no classification head
            in_channels=6, # RGB+HSV
            is_training=True
        )
        # Optionally start from previously trained weights.
        if CFG.PRETRAINED_MODEL:
            model.load_state_dict(torch.load(CFG.PRETRAINED_MODEL, map_location=device))

        model = model.to(device)

        optimizer = optim.AdamW(params=model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
        # The scheduler is stepped once per batch (see the training loop).
        scheduler = lr_scheduler.OneCycleLR(
            optimizer=optimizer, epochs=CFG.max_epoch,
            pct_start=0.0, steps_per_epoch=len(train_loader),
            max_lr=CFG.lr, div_factor=25, final_div_factor=1e4
        )

        use_amp = CFG.enable_amp
        scaler = amp.GradScaler(enabled=use_amp)

        best_val_loss = 1.0e+09
        best_epoch = 0

        # From this epoch on, training uses the un-augmented loader.
        switch_epoch = (CFG.max_epoch + 1) * 0.6

        for epoch in range(1, CFG.max_epoch + 1):
            # `and`, not `&`: the bitwise operator binds tighter than `>=`, so
            # `flag & epoch >= x` compared (flag & epoch) with x and the
            # switch could never trigger.
            if CFG.change_dataset and epoch >= switch_epoch:
                train_loader = train_loader_noaugment

            epoch_start = time()
            train_loss = 0
            val_loss = 0

            model.train()
            for batch in tqdm(train_loader):
                x, t = batch
                x = to_device(scale_pixels(x), device)
                t = to_device(t, device)

                optimizer.zero_grad()
                with amp.autocast(use_amp):
                    output = model(x, t)
                    loss = output['bce_loss']

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                train_loss += loss.item()
                scheduler.step()

            train_loss /= len(train_loader)

            model.eval()
            for batch in tqdm(val_loader):
                x, t = batch
                x = to_device(scale_pixels(x), device)
                t = to_device(t, device)
                with torch.no_grad(), amp.autocast(use_amp):
                    output = model(x, t)
                loss = output['bce_loss']
                val_loss += loss.item()
            val_loss /= len(val_loader)

            if val_loss < best_val_loss:
                best_epoch = epoch
                best_val_loss = val_loss
                torch.save(model.state_dict(), str(output_path / f'snapshot_epoch_{epoch}.pth'))

            elapsed_time = time() - epoch_start
            print(
                f"[epoch {epoch}] train loss: {train_loss: .6f}, val loss: {val_loss: .6f}, elapsed_time: {elapsed_time: .3f}")

            if epoch - best_epoch > CFG.es_patience:
                print("Early Stopping!")
                break

        return val_fold, best_epoch, best_val_loss

### Train  

In [None]:
if not CFG.is_infer:

    # Train every fold in turn and collect
    # (fold id, best epoch, best validation loss) per fold.
    score_list = []
    print('='*20)
    print('experiment: ', CFG.exp)
    print('='*20)
    for fold_id in range(CFG.n_folds):
        output_path = CFG.OUTPUT_DIR / f"fold{fold_id}"
        # parents=True: a fresh experiment has no OUTPUT_DIR yet, and a plain
        # mkdir would raise FileNotFoundError for the missing parent.
        output_path.mkdir(parents=True, exist_ok=True)
        print(f"[fold{fold_id}]")
        score_list.append(train_one_fold(fold_id, train_meta, output_path))

# 7. Validation

In [None]:
if not CFG.is_infer:

    # Keep a second reference to the per-fold training results for the log
    # file written later, then show them.
    score_list_log = score_list
    print(score_list_log)

In [None]:
if not CFG.is_infer:


    # For every fold: copy the best snapshot next to the fold directories,
    # then remove all snapshots left inside the fold directory.
    best_log_list = []
    for fold_id, best_epoch, _ in score_list:

        exp_dir_path = CFG.OUTPUT_DIR.joinpath(f"fold{fold_id}")
        best_model_path = exp_dir_path.joinpath(f"snapshot_epoch_{best_epoch}.pth")
        copy_to = CFG.OUTPUT_DIR.joinpath(f"./best_model_fold{fold_id}.pth")
        shutil.copy(best_model_path, copy_to)

        # The copy above is safe, so every .pth in the fold directory can go.
        for p in exp_dir_path.glob("*.pth"):
            p.unlink()

In [None]:
# Function for inference
def run_inference_loop(model, loader, device):
    """Predict over every batch of ``loader`` and return the stacked outputs.

    The model is moved to ``device``, put in eval mode and its
    ``output_type`` is set to ``['infer']``; none of this is undone afterwards.

    Args:
        model: Model whose forward pass returns a dict with a ``'target'`` entry.
        loader: Iterable of batches whose first element is the image tensor.
        device: Device used for the forward passes.

    Returns:
        NumPy array of all batch predictions concatenated along axis 0.
    """
    model.to(device)
    model.eval()
    model.output_type = ['infer']

    batch_preds = []
    with torch.no_grad():
        for batch in tqdm(loader):
            x = to_device(batch[0], device)
            # Same pixel scaling as used in the training loop.
            x = (
                (x - x.min()) / (x.max() - x.min() +1e-6) * 255
                if CFG.standardization
                else x.float()/255
            )
            y = model(x)['target']
            batch_preds.append(y.detach().cpu().numpy())

    # Stack the per-batch arrays into one array.
    return np.concatenate(batch_preds)

### Prediction

In [None]:
if not CFG.is_infer:


    # Out-of-fold prediction on the training data (cross-validation only,
    # not the test data).

    # Keep one row per image. drop=True discards the old index instead of
    # inserting it as an `index` column, which would make a re-run of this
    # cell fail because the column already exists.
    column_to_check = 'isic_id'
    train_meta = train_meta.drop_duplicates(subset=[column_to_check], keep='first').reset_index(drop=True)

    oof_pred_arr = np.zeros((len(train_meta), CFG.n_classes-1))

    # Identical for every fold, so created once.
    device = torch.device(CFG.device)
    _, val_transform = get_transforms()

    for fold_id in range(CFG.n_folds):
        print(f"\n[fold {fold_id}]")

        # Rows held out when this fold's model was trained.
        fold_rows = train_meta[train_meta["fold"] == fold_id]
        val_dataset = ISICDataset(df=fold_rows, fp_hdf=CFG.TRAIN_HDF5, transform=val_transform)
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=CFG.batch_size, num_workers=4, shuffle=False, drop_last=False)

        # Rebuild the network and load this fold's best weights.
        model_path = CFG.OUTPUT_DIR / f"best_model_fold{fold_id}.pth"
        model = timmModel(
            model_name=CFG.model_name, 
            pretrained=False, 
            in_channels=6,
            num_classes=0,
            is_training=False
        )
        model.load_state_dict(torch.load(model_path, map_location=device))

        # After the reset above the index labels equal row positions, so they
        # can address rows of the OOF array directly.
        val_pred = run_inference_loop(model, val_loader, device)
        val_idx = fold_rows.index.values
        oof_pred_arr[val_idx] = val_pred

        del val_idx
        del model, val_loader
        torch.cuda.empty_cache()
        gc.collect()

### Metric

In [None]:
if not CFG.is_infer:

    def comp_score(solution: pd.DataFrame, submission: pd.DataFrame, row_id_column_name: str, min_tpr: float=0.80):
        """Partial AUC computed from inverted labels and scores.

        Labels and predictions are both flipped (``|y - 1|`` and ``1 - p``),
        ``roc_auc_score`` is evaluated with ``max_fpr = |1 - min_tpr|`` and the
        standardized result is mapped from ``[0.5, 1.0]`` back to
        ``[0.5 * max_fpr**2, max_fpr]``.

        Args:
            solution: Ground-truth labels.
            submission: Predicted scores, same layout as ``solution``.
            row_id_column_name: Not used by this implementation.
            min_tpr: Used only to derive ``max_fpr``.

        Returns:
            The rescaled partial AUC.
        """
        flipped_truth = abs(np.asarray(solution.values) - 1)
        flipped_pred = 1.0 - np.asarray(submission.values)
        max_fpr = abs(1 - min_tpr)

        partial_auc_scaled = roc_auc_score(flipped_truth, flipped_pred, max_fpr=max_fpr)

        # Undo the standardization: shift from [0.5, 1.0] to [lower, upper].
        # https://math.stackexchange.com/questions/914823/shift-numbers-into-a-different-range
        lower = 0.5 * max_fpr**2
        upper = max_fpr
        return lower + (upper - lower) / (1.0 - 0.5) * (partial_auc_scaled - 0.5)


In [None]:
if not CFG.is_infer:

    # Ground truth as a one-column frame.
    label_arr = train_meta['target']
    ture_arr = label_arr.to_frame()

    # Out-of-fold predictions in the same row order.
    oof = pd.DataFrame(oof_pred_arr)

    micro_roc_auc_ovr = comp_score(solution=ture_arr, submission=oof, row_id_column_name="")

    print(f"CV: Micro-averaged One-vs-Rest ROC AUC score:\n{micro_roc_auc_ovr:.10f}")

In [None]:
if not CFG.is_infer:

    # Append one record per run: experiment id, per-fold training results and
    # the CV score, each on its own line after a leading newline.
    _log_lines = [
        '',
        f'exp: {CFG.exp}',
        str(score_list_log),
        f"CV: Micro-averaged One-vs-Rest ROC AUC score: {micro_roc_auc_ovr:.10f}",
    ]
    with open(CFG.OUTPUT_LOG, 'a') as f:
        f.write('\n'.join(_log_lines))

In [None]:
if not CFG.is_infer:

    # Show predictions and labels side by side: first rows, then last rows.
    for _preview in (oof.head(), ture_arr.head(), oof.tail(), ture_arr.tail()):
        display(_preview)

# 8. Inference

In [None]:
# predict for test data with metrix
if CFG.is_kaggle and CFG.is_infer:

    device = torch.device(CFG.device)

    # The test loader does not depend on the fold, so build it once.
    _, val_transform = get_transforms()
    val_dataset = ISICDataset(df=test_meta, fp_hdf=CFG.TEST_HDF5, transform=val_transform)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=CFG.batch_size, num_workers=2, shuffle=False, drop_last=False)

    fold_preds = []
    for fold_id in range(CFG.n_folds):
        print(f"\n[fold {fold_id}]")

        # Rebuild the network and load this fold's best weights.
        model_path = CFG.OUTPUT_DIR / f"best_model_fold{fold_id}.pth"
        model = timmModel(
            model_name=CFG.model_name, 
            pretrained=False, 
            in_channels=6,
            num_classes=0,
            is_training=False
        )
        model.load_state_dict(torch.load(model_path, map_location=device))

        fold_preds.append(run_inference_loop(model, val_loader, device))

        del model
        torch.cuda.empty_cache()
        gc.collect()

    # Average over folds. Always reduce, also for a single fold: otherwise
    # `pred_arr` would stay a Python list and `.shape` below would raise.
    print(len(fold_preds))
    pred_arr = np.mean(fold_preds, axis=0)
    print(pred_arr.shape)

    pred_df = pd.DataFrame(pred_arr)
    display(pred_df)

    # Submission. Assign a flat array instead of a DataFrame so that a row
    # count mismatch raises instead of being index-aligned.
    # NOTE(review): assumes the sample submission lists rows in the same
    # order as test_meta; confirm.
    df_sub = pd.read_csv(CFG.SAMPLE_SUB)
    df_sub["target"] = pred_arr.reshape(-1)
    display(df_sub.head())
    df_sub.to_csv('submission.csv', index=False)