In [1]:
# Mount Google Drive so the competition archive stored there is reachable from the VM.
from google.colab import drive
drive.mount('/content/drive')

import os
# Work from a scratch directory on the VM's local disk.
os.makedirs('/tmp/project', exist_ok=True)
os.chdir('/tmp/project')
print( os.getcwd() )
# One-time setup: everything below is skipped once train.csv exists in the working directory.
# NOTE(review): in the recorded output the `cp` fails because the archive path does not exist,
# yet the unzip/rm/install commands still run — confirm the Drive path before relying on this cell.
if not os.path.exists('/tmp/project/train.csv'):
    !cp /content/drive/MyDrive/Colab_Notebooks/dacon/2024_저해상도조류이미지/open.zip /tmp/project
    !unzip -o -q open.zip
    !rm open.zip
    # Install additional modules
    !sudo apt-get install -y libmagickwand-dev
    !pip install wandb timm wand

Mounted at /content/drive
/tmp/project
cp: cannot stat '/content/drive/MyDrive/Colab_Notebooks/dacon/2024_저해상도조류이미지/open.zip': No such file or directory
unzip:  cannot find or open open.zip, open.zip.zip or open.zip.ZIP.
rm: cannot remove 'open.zip': No such file or directory
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  fonts-droid-fallback fonts-noto-mono fonts-urw-base35 ghostscript
  gir1.2-freedesktop gir1.2-gdkpixbuf-2.0 gir1.2-rsvg-2.0 gsfonts
  imagemagick-6-common libblkid-dev libblkid1 libcairo-script-interpreter2
  libcairo2-dev libdjvulibre-dev libdjvulibre-text libdjvulibre21 libffi-dev
  libfftw3-double3 libgdk-pixbuf-2.0-dev libgdk-pixbuf2.0-bin libglib2.0-dev
  libglib2.0-dev-bin libgs9 libgs9-common libice-dev libidn12 libijs-0.35
  libjbig2dec0 libjxr-tools libjxr0 liblcms2-dev liblqr-1-0 liblqr-1-0-dev
  liblzo2-2 libmagickcore-6-arch-config libmagickcore-6-heade

In [2]:
import os
import shutil
import gc
import math
import pickle
import re
import sys
import logging
import IPython
from glob import glob
from datetime import datetime

import numpy as np
import pandas as pd
import cv2
from PIL import Image
from wand import image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
from torch.utils.data import Dataset, DataLoader
from torch.optim import swa_utils

import torchvision
from torchvision.transforms import v2

import albumentations as A
from albumentations.pytorch import ToTensorV2

import timm
from transformers import AutoModel

from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.utils.class_weight import compute_class_weight

import wandb
import optuna
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Configure training settings
# Central run configuration; the whole dict is also handed to wandb as the run config.
CFG = {
    'SEED': 42,                     # passed to seed_everything and used as random_state for the splits
    'N_SPLIT': 5,                   # number of StratifiedKFold folds
    'LABEL_SMOOTHING': 0.05,        # label_smoothing of the CrossEntropyLoss built in train()
    'OPTIMIZER': 'AdamW',           # NOTE(review): not read by the visible code; optim.AdamW is hard-coded — confirm
    'INTERPOLATION': 'robidouxsharp',  # resize filter name handed to CustomDataset in the prediction cell
    'PRECISION': '16',              # the string '16' switches mixed precision (use_amp) on
    'MODEL_NAME': "timm/deit3_large_patch16_224.fb_in22k_ft_in1k",
    'IMG_SIZE': 224,                # load size for validation/test images and for the model smoke test
    'BATCH_SIZE': 48,               # training batch size; validation/test loaders use twice this
    # [base LR, minimum LR]: the first entry seeds the optimizer and (doubled) the scheduler's max_lr,
    # the second is the scheduler's min_lr.
    # NOTE(review): the literal 48 presumably mirrors BATCH_SIZE — keep the two in sync.
    'LR': [0.25e-5 * np.sqrt(48), 1e-7],
    'IMG_TRAIN_SIZE': 224,          # load size for training images
}

# Logger setup
logger = logging.getLogger()
logging.basicConfig(handlers=[
    logging.StreamHandler(stream=sys.stdout),
    logging.handlers.RotatingFileHandler(filename='run.log', mode='a', maxBytes=512000, backupCount=4)
])
logging_formatter = logging.Formatter('%(asctime)s [%(levelname)-4.4s] %(message)s', datefmt='%m/%d %H:%M:%S')
_ = [h.setFormatter(logging_formatter) for h in logger.handlers]
logger.setLevel(logging.INFO)

def showtraceback(self, *args, **kwargs):
    """Replacement for IPython's traceback display that sends the active exception to `logger`.

    All positional and keyword arguments are accepted and ignored.
    """
    # logger.exception logs at ERROR level and appends the traceback of the
    # exception currently being handled.
    logger.exception('-------Exception----------')

# Patch the class attribute so the logging version is used for every cell error.
IPython.core.interactiveshell.InteractiveShell.showtraceback = showtraceback
logger.info('program started')

# Seed function
def seed_everything(seed):
    """Seed every RNG used by the notebook and put cuDNN into reproducible mode.

    Args:
        seed: Integer seed applied to `random`, NumPy, PyTorch (CPU and all CUDA
            devices) and exported as PYTHONHASHSEED.
    """
    import random

    logger.info(f'seed_everything : {seed}')
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, which defeats
    # `deterministic = True`; it has to be off for reproducible results.
    torch.backends.cudnn.benchmark = False

seed_everything(CFG['SEED'])

# Function to selectively replace files
def selective_replace(src_folder, dest_folder, replace_ext=None, skip_ext='.jpg'):
    """Overwrite files that already exist under `dest_folder` with their copies from `src_folder`.

    The directory layout of `src_folder` is mirrored (created if missing) below
    `dest_folder`. A file is copied only when a file with the same relative path
    already exists in the destination; new files are never added.

    Args:
        src_folder: Root of the tree to copy from.
        dest_folder: Root of the tree whose existing files are replaced.
        replace_ext: Optional suffix; when given, only files ending with it are
            replaced. `None` (default) replaces every non-skipped file.
        skip_ext: Suffix of files that are never replaced; falsy disables skipping.
    """
    dest_abs = os.path.abspath(dest_folder)
    for root, dirs, files in os.walk(src_folder):
        # When the destination lives inside the source tree, do not descend into it:
        # otherwise the walk mirrors the destination into itself.
        dirs[:] = [d for d in dirs if os.path.abspath(os.path.join(root, d)) != dest_abs]

        relative_path = os.path.relpath(root, src_folder)
        dest_root = os.path.join(dest_folder, relative_path)
        os.makedirs(dest_root, exist_ok=True)

        for file in files:
            if skip_ext and file.endswith(skip_ext):
                print(f"Skipping: {file}")
                continue
            if replace_ext and not file.endswith(replace_ext):
                continue

            dest_file = os.path.join(dest_root, file)
            if not os.path.exists(dest_file):
                # Replace-only semantics: nothing to overwrite.
                continue

            shutil.copy2(os.path.join(root, file), dest_file)
            print(f"Replaced: {file}")

# Custom Dataset
class CustomDataset(Dataset):
    """Image dataset that resizes every image once and caches the result as PNG on disk.

    Items are dicts with key 'pixel_values' (the transformed image) and, when
    labels were supplied, 'label'.

    Args:
        img_path_list: Sequence of image file paths.
        label_list: Sequence of targets aligned with `img_path_list`, or None.
        load_img_size: Side length (pixels) the images are resized to before transforms.
        shuffle: When True the rows are shuffled once at construction time.
        transforms: Optional callable invoked as `transforms(image=ndarray)['image']`.
        interpolation: 'pil_lanczos', 'cv2_lanczos4', or any other string, which is
            passed to wand's `resize` as the filter name.
    """

    def __init__(self, img_path_list, label_list, load_img_size, shuffle=False, transforms=None, interpolation='robidouxsharp'):
        self.df = pd.DataFrame({'img_path_list': img_path_list})
        self.interpolation = interpolation
        self.load_img_size = load_img_size
        logger.info(f'load_img_size={load_img_size}')
        if label_list is not None:
            self.df['label_list'] = label_list
        if shuffle:
            self.df = self.df.sample(frac=1.0).reset_index(drop=True)
        self.transforms = transforms

    def get_interpolated_image(self, img, new_image_size):
        """Resize `img` (PIL image or ndarray) to a square of `new_image_size`; returns a PIL image."""
        if self.interpolation == 'pil_lanczos':
            if isinstance(img, np.ndarray):
                img = Image.fromarray(img)
            return img.resize((new_image_size, new_image_size), Image.LANCZOS)
        elif self.interpolation == 'cv2_lanczos4':
            if not isinstance(img, np.ndarray):
                img = np.array(img)
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            img = cv2.resize(img, (new_image_size, new_image_size), interpolation=cv2.INTER_LANCZOS4)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            return Image.fromarray(img)
        else:
            # Any other name is treated as an ImageMagick (wand) resize filter.
            if not isinstance(img, np.ndarray):
                img = np.array(img)
            with image.Image.from_array(img) as src:
                src.resize(new_image_size, new_image_size, filter=self.interpolation)
                return Image.fromarray(np.array(src))

    def _cache_path(self, index, img_size):
        """Return the PNG cache path for row `index` at `img_size`."""
        img_path = self.df.img_path_list[index]
        # Strip only the file extension. Splitting on the first '.' would truncate any
        # path containing another dot and map many images onto one cache file.
        stem = os.path.splitext(img_path.replace('./', ''))[0]
        return f'img_cached/{img_size}_{self.interpolation}/{stem}.png'

    def get_image_from_index(self, index, img_size):
        """Load the resized RGB image for row `index`, creating the cache entry on first use."""
        full_fname = self._cache_path(index, img_size)
        if os.path.exists(full_fname):
            with Image.open(full_fname) as cached:
                # convert() loads the pixel data, so the file can be closed right away.
                return cached.convert('RGB')

        os.makedirs(os.path.dirname(full_fname), exist_ok=True)
        with Image.open(self.df.img_path_list[index]) as src:
            # Force 3 channels so grayscale/RGBA sources match the 3-channel Normalize.
            img = self.get_interpolated_image(src.convert('RGB'), img_size)
        img.save(full_fname)
        return img

    def __getitem__(self, index):
        # Local is named `img` so it does not shadow the wand `image` module.
        img = self.get_image_from_index(index, self.load_img_size)
        if self.transforms is not None:
            img = self.transforms(image=np.array(img))['image']
        if 'label_list' in self.df.columns:
            return {'pixel_values': img, 'label': self.df.label_list[index]}
        return {'pixel_values': img}

    def __len__(self):
        return len(self.df)

# CosineAnnealingWarmupRestarts Scheduler
class CosineAnnealingWarmupRestarts(_LRScheduler):
    """Cosine-annealing learning-rate schedule with linear warm-up and restarts.

    Within a cycle the LR rises linearly from `min_lr` to the current `max_lr`
    over `warmup_steps`, then follows a half cosine back down to `min_lr`.
    After each cycle the peak is multiplied by `gamma` and the non-warm-up part
    of the cycle length by `cycle_mult`.

    Args:
        optimizer: Wrapped optimizer whose param-group LRs are overwritten.
        first_cycle_steps: Number of `step()` calls in the first cycle.
        cycle_mult: Growth factor of the cycle length after each restart.
        max_lr: Peak LR of the first cycle.
        min_lr: Floor LR; also the LR set on every param group at construction.
        warmup_steps: Linear warm-up length; must be smaller than `first_cycle_steps`.
        gamma: Multiplicative decay of the peak LR per completed cycle.
        last_epoch: Index of the last step already taken (-1 for a fresh run).
    """
    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 first_cycle_steps: int,
                 cycle_mult: float = 1.,
                 max_lr: float = 1e-5,
                 min_lr: float = 1e-10,
                 warmup_steps: int = 0,
                 gamma: float = 1.,
                 last_epoch: int = -1):
        assert warmup_steps < first_cycle_steps
        self.first_cycle_steps = first_cycle_steps
        self.cycle_mult = cycle_mult
        self.base_max_lr = max_lr          # peak of the first cycle, never modified
        self.max_lr = max_lr               # peak of the current cycle
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.gamma = gamma
        self.cur_cycle_steps = first_cycle_steps
        self.cycle = 0
        self.step_in_cycle = last_epoch
        # NOTE(review): the base-class constructor presumably performs an initial step();
        # init_lr() afterwards forces every param group back to min_lr — confirm against
        # the installed PyTorch version.
        super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)
        self.init_lr()

    def init_lr(self):
        """Reset `base_lrs` and every param group's LR to `min_lr`."""
        self.base_lrs = [self.min_lr for _ in self.optimizer.param_groups]
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr

    def get_lr(self):
        """Return one LR per param group for the current position in the cycle."""
        if self.step_in_cycle == -1:
            return self.base_lrs
        elif self.step_in_cycle < self.warmup_steps:
            # Linear ramp from base_lr up to max_lr.
            return [(self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs]
        else:
            # Half cosine from max_lr (at the end of warm-up) down to base_lr (at cycle end).
            return [base_lr + (self.max_lr - base_lr) *
                    (1 + math.cos(math.pi * (self.step_in_cycle - self.warmup_steps) /
                                 (self.cur_cycle_steps - self.warmup_steps))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule by one step, or jump to absolute position `epoch`, and apply the LRs."""
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle += 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # Restart: begin the next cycle and rescale its non-warm-up length.
                self.cycle += 1
                self.step_in_cycle -= self.cur_cycle_steps
                self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
        else:
            if epoch >= self.first_cycle_steps:
                if self.cycle_mult == 1.:
                    # Equal-length cycles: position and cycle index follow from division.
                    self.step_in_cycle = epoch % self.first_cycle_steps
                    self.cycle = epoch // self.first_cycle_steps
                else:
                    # Geometric cycle lengths: find the cycle index n in closed form.
                    n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                    self.cycle = n
                    self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
                    self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** n
            else:
                self.cur_cycle_steps = self.first_cycle_steps
                self.step_in_cycle = epoch
        # Peak LR decays by gamma for every completed cycle.
        self.max_lr = self.base_max_lr * (self.gamma ** self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

# Custom Model with GroupNorm and Dropout
class CustomModel(nn.Module):
    """Backbone followed by a GroupNorm -> ReLU -> Dropout -> Linear classification head.

    Args:
        model: Feature extractor. Its forward must return either a 2-D feature
            tensor or an object exposing `pooler_output`. The feature width is read
            from `model.num_features` or, failing that, `model.config.hidden_size`,
            and must be divisible by 32 (GroupNorm uses 32 groups).
        dropout_rate: Dropout probability applied before the final linear layer.
        num_classes: Number of output logits. Defaults to 25, the value that was
            previously hard-coded.
    """

    def __init__(self, model, dropout_rate=0.5, num_classes=25):
        super().__init__()
        self.model = model
        hidden_dim = model.num_features if hasattr(model, 'num_features') else model.config.hidden_size
        # Kept as an attribute (and reused inside `clf`) so existing attribute access still works.
        self.dropout = nn.Dropout(dropout_rate)
        # Layer order is unchanged so state_dict keys (clf.0.*, clf.3.*) stay compatible
        # with previously saved checkpoints.
        self.clf = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=hidden_dim),
            nn.ReLU(),
            self.dropout,
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        features = self.model(x)
        if not isinstance(features, torch.Tensor):
            # Non-tensor outputs are expected to carry the pooled features here.
            features = features.pooler_output
        return self.clf(features)

# Test-Time Augmentation
# Random augmentation pipeline intended for test-time augmentation: each stage fires with
# probability 0.5, then the image is normalised with ImageNet statistics and converted to a tensor.
# NOTE(review): because the pipeline ends with Normalize + ToTensorV2 it presumably expects a
# single raw HWC image, not an already-normalised tensor batch — confirm how callers feed it.
tta_transforms = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=15, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

def prediction_with_tta(model, test_loader, device, n_augmentations=5):
    """Predict class probabilities, averaging over flipped views of every batch.

    The loader already yields normalised (N, C, H, W) tensors, so augmentation is
    done directly on the tensors. Running the batch through a per-image
    albumentations pipeline would normalise it a second time and misread the
    layout. Views cycle deterministically through: identity, horizontal flip,
    vertical flip, both flips.

    Args:
        model: Classifier returning logits of shape (N, num_classes).
        test_loader: Iterable of dicts holding a 'pixel_values' tensor batch.
        device: Device the model and the batches are moved to.
        n_augmentations: Number of views averaged per batch (>= 1).

    Returns:
        np.ndarray of shape (num_samples, num_classes) with averaged softmax scores.

    Raises:
        ValueError: If `n_augmentations` is smaller than 1.
    """
    if n_augmentations < 1:
        raise ValueError('n_augmentations must be >= 1')

    model = model.to(device)
    model.eval()
    # Dimensions to flip per view: H is dim 2, W is dim 3 of an NCHW batch.
    flip_dims = [(), (3,), (2,), (2, 3)]
    preds = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="TTA Prediction"):
            pixel_values = batch['pixel_values'].to(device)
            prob_sum = None
            for aug_idx in range(n_augmentations):
                dims = flip_dims[aug_idx % len(flip_dims)]
                view = torch.flip(pixel_values, dims=dims) if dims else pixel_values
                probs = F.softmax(model(view), dim=1)
                prob_sum = probs if prob_sum is None else prob_sum + probs
            preds.append((prob_sum / n_augmentations).cpu().numpy())
    return np.concatenate(preds, axis=0)

# Early Stopping
class EarlyStopping:
    """Signals a stop once a higher-is-better score has stalled for `patience` calls.

    A call counts as a stall when the new score is below `best_score + delta`.
    Any other score becomes the new best and resets the stall counter.
    `early_stop` turns True when the counter reaches `patience`.
    """

    def __init__(self, patience=5, verbose=False, delta=0):
        self.patience, self.delta, self.verbose = patience, delta, verbose
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, score):
        # First observation simply becomes the reference.
        if self.best_score is None:
            self.best_score = score
            return

        stalled = score < self.best_score + self.delta
        if not stalled:
            self.best_score = score
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

# Create Model Function
def create_model(model_name, freeze_layers=True):
    """Build a `CustomModel` on top of a pretrained timm or Hugging Face backbone.

    Args:
        model_name: Model identifier. Names without a '/' get the 'timm/' prefix;
            names starting with 'timm/' are loaded through timm, everything else
            through `AutoModel.from_pretrained`.
        freeze_layers: When True only the classification head (`clf`) stays trainable.

    Returns:
        The assembled model in eval mode, after one dummy forward pass at
        `CFG['IMG_SIZE']` that verifies the backbone/head shapes fit together.
    """
    logger.info(f'create_model: {model_name}')
    if '/' not in model_name:
        model_name = 'timm/' + model_name
    if model_name.startswith('timm/'):
        # num_classes=0 drops timm's own classifier so the backbone returns pooled features
        # of width `num_features`, which is what CustomModel sizes its head for.
        base_model = timm.create_model(model_name, pretrained=True, num_classes=0)
    else:
        base_model = AutoModel.from_pretrained(model_name)
    model = CustomModel(base_model)
    if freeze_layers:
        for param in model.model.parameters():
            param.requires_grad = False
        for param in model.clf.parameters():
            param.requires_grad = True
    model.eval()
    # Shape smoke test only; no autograd graph is needed for it.
    with torch.no_grad():
        model(torch.rand((1, 3, CFG['IMG_SIZE'], CFG['IMG_SIZE']), dtype=torch.float32))
    return model

# Load Data
from google.colab import drive
drive.mount('/content/drive')

%cd /content

!unzip -qq "/content/drive/MyDrive/Colab Notebooks/open.zip"

current_path = os.getcwd()
print(current_path)

# Add selective_replace function call here to selectively replace files in the dataset
selective_replace('/content', '/content/drive/MyDrive/Colab Notebooks/open


SyntaxError: unterminated string literal (detected at line 343) (<ipython-input-2-2216c75b6137>, line 343)

In [None]:
train_df = pd.read_csv('/content/train.csv')
le = LabelEncoder()
# Integer-encoded target used by the loss; `le` is reused later to decode predictions.
train_df['class'] = le.fit_transform(train_df['label'])

# Fail early, with a readable message, when the CSV and the image folder disagree.
n_train_files = len(os.listdir('/content/train'))
if len(train_df) != n_train_files:
    raise ValueError(
        f'train.csv lists {len(train_df)} rows but /content/train contains {n_train_files} files'
    )

# Create Sample Dataset: a stratified 10% subset for quick experiments.
train_subset_df, holdout_df = train_test_split(
    train_df,
    stratify=train_df['label'],
    test_size=0.9,
    random_state=CFG['SEED']
)

# The frames must exist before they are pickled (they were previously dumped before being defined).
# NOTE(review): the original never defined val_fold_df; a stratified, fold-sized sample of the
# rows not used in the subset is assumed here — confirm the intended validation set.
_, val_fold_df = train_test_split(
    holdout_df,
    stratify=holdout_df['label'],
    test_size=len(train_df) // CFG['N_SPLIT'],
    random_state=CFG['SEED']
)

# Save Preprocessed Data
with open('train_subset.pkl', 'wb') as f:
    pickle.dump(train_subset_df, f)

with open('val_fold.pkl', 'wb') as f:
    pickle.dump(val_fold_df, f)

skf = StratifiedKFold(n_splits=CFG['N_SPLIT'], random_state=CFG['SEED'], shuffle=True)

image_size = CFG['IMG_SIZE']

# Advanced Data Augmentation with MixUp and CutMix
def get_mixup_cutmix_augmentation(alpha=1.0, prob=0.5):
    """Return an `A.OneOf` block that picks between a CutMix and a MixUp transform.

    Both candidates, and the enclosing `OneOf`, use probability `prob`; `alpha`
    is forwarded unchanged to each candidate.

    NOTE(review): this relies on `A.CutMix` and `A.MixUp` accepting `(p, alpha)`
    in the installed albumentations release — confirm they exist there. Label
    mixing is not handled by this per-image pipeline.
    """
    candidates = [
        A.CutMix(p=prob, alpha=alpha),
        A.MixUp(p=prob, alpha=alpha),
    ]
    return A.OneOf(candidates, p=prob)

# Training pipeline: random crop/flip/shift-scale-rotate geometry, one of three noise/blur
# corruptions, photometric jitter, the MixUp/CutMix block, then ImageNet normalisation.
train_transform = A.Compose([
    A.RandomResizedCrop(image_size, image_size),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.ShiftScaleRotate(p=0.5),
    # Exactly one of the three corruptions is applied, half of the time.
    A.OneOf([
        A.GaussNoise(var_limit=(10.0, 50.0)),
        A.GaussianBlur(blur_limit=(3, 7)),
        A.MotionBlur(blur_limit=3)
    ], p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.RandomGamma(p=0.5),
    A.HueSaturationValue(p=0.5),
    get_mixup_cutmix_augmentation(alpha=1.0, prob=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# Evaluation pipeline: deterministic resize plus the same normalisation as training.
test_transform = A.Compose([
    A.Resize(image_size, image_size),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# Create DataLoaders
# Small training set built from the stratified subset.
train_dataset_small = CustomDataset(
    img_path_list=train_subset_df['img_path'].values,
    # Use the integer-encoded target: the cross-entropy loss and the `.to(device)` call in
    # the training loop need a tensor of class indices, not the raw string labels.
    label_list=train_subset_df['class'].values,
    load_img_size=CFG['IMG_TRAIN_SIZE'],
    shuffle=True,
    transforms=train_transform
)

train_loader_small = DataLoader(
    train_dataset_small,
    batch_size=CFG['BATCH_SIZE'],
    shuffle=True,
    num_workers=8,
    pin_memory=True
)

# val_fold.pkl is written by this notebook; pickle.load must only be used on such trusted files.
with open('val_fold.pkl', 'rb') as f:
    val_fold_df = pickle.load(f)

val_dataset = CustomDataset(
    img_path_list=val_fold_df['img_path'].values,
    label_list=val_fold_df['class'].values,
    load_img_size=CFG['IMG_SIZE'],
    shuffle=False,
    transforms=test_transform
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CFG['BATCH_SIZE']*2,
    shuffle=False,
    num_workers=8,
    pin_memory=True
)

# Compute Class Weights
# `device` is used by the training and prediction cells but was not defined anywhere visible.
# It is kept as a plain string because it is also passed to torch.autocast(device_type=...).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 'balanced' weights, ordered by the sorted unique label strings.
# NOTE(review): this presumably matches the LabelEncoder's class order — confirm.
class_weight = torch.FloatTensor(
    compute_class_weight('balanced', classes=np.unique(train_df.label), y=train_df.label)
).to(device)

# Training Function
def train(model, optimizer, train_loader, val_loader, scheduler, device, validation_steps=10, logging_steps=10, use_amp=True, filename='', num_epochs=3):
    """Train `model` with periodic validation, EMA tracking, checkpointing and early stopping.

    Args:
        model: Model to optimise (moved to `device`).
        optimizer: Optimizer over the model parameters.
        train_loader: Loader yielding dicts with 'pixel_values' and 'label'.
        val_loader: Loader passed to `validation`.
        scheduler: Optional LR scheduler stepped after every batch; may be None.
        device: Target device, e.g. 'cuda'.
        validation_steps: Validate every N batches. A non-int value is interpreted
            as a fraction of the loader length. Clamped to [1, len(train_loader)].
        logging_steps: Log the running loss / LR every N batches.
        use_amp: Enable float16 autocast and gradient scaling.
        filename: Checkpoint path template formatted with `epoch`, `val_loss`
            and `val_score`; empty disables checkpoint files.
        num_epochs: Number of passes over the loader (default 3, the value that
            was previously hard-coded).

    Returns:
        `model` with the weights of its best validation F1 restored, or None when
        no validation ever scored above 0.
    """
    logger.info(f'use_amp={use_amp}')

    model.to(device)
    # torch.autocast wants the device *type*; this also accepts 'cuda:0' or a torch.device.
    device_type = torch.device(device).type
    best_score = 0
    best_model = None
    best_state = None  # CPU snapshot of the best weights; `best_model` only aliases `model`
    MAX_PATIENCE = 5
    loss_fn = nn.CrossEntropyLoss(weight=class_weight, label_smoothing=CFG['LABEL_SMOOTHING'], reduction='mean').to(device)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    checkpoint_filenames = []
    early_stopping = EarlyStopping(patience=MAX_PATIENCE, verbose=True)

    max_steps = len(train_loader)
    if not isinstance(validation_steps, int):
        validation_steps = int(max_steps * validation_steps)
    # Guard against 0 (division by zero) and against an interval longer than the loader,
    # which would round max_steps down to 0 and silently skip all training.
    validation_steps = max(1, min(validation_steps, max_steps))
    # Only whole validation intervals are trained in each epoch.
    max_steps = (max_steps // validation_steps) * validation_steps

    # EMA half-life equals the early-stopping horizon (validation_steps * MAX_PATIENCE batches).
    ema_decay = np.power(np.e, np.log(0.5)/(validation_steps*MAX_PATIENCE))
    # The averaging function must be passed by keyword: the second positional
    # parameter of AveragedModel is `device`.
    ema_model = swa_utils.AveragedModel(model, multi_avg_fn=swa_utils.get_ema_multi_avg_fn(ema_decay))

    # Per-validation metrics. NOTE(review): collected but not returned, to keep the return type.
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_f1': []
    }

    for epoch in range(1, num_epochs + 1):
        model.train()
        train_loss = []
        pbar_postfix = {}

        pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
        for i, batch in enumerate(pbar):
            if i >= max_steps:
                # Stop instead of draining the remaining batches without using them.
                break
            steps = i + 1

            optimizer.zero_grad()
            # With use_amp=False autocast and the scaler are pass-throughs, so one code
            # path serves both precisions.
            with torch.autocast(device_type=device_type, dtype=torch.float16, enabled=use_amp):
                output = model(batch['pixel_values'].to(device))
                loss = loss_fn(output, batch['label'].to(device))
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
            scaler.step(optimizer)
            scaler.update()

            if scheduler is not None:
                scheduler.step()

            train_loss.append(loss.item())
            # Drop references early to release GPU memory before validation.
            loss = None
            output = None
            batch = None

            ema_model.update_parameters(model)

            if steps % logging_steps == 0:
                current_lr = optimizer.param_groups[0]["lr"]
                pbar_postfix.update({
                    't_loss0': train_loss[-1],
                    'lr': current_lr
                })
                pbar.set_postfix(pbar_postfix)
                wandb.log({
                    "epoch": epoch,
                    "step": steps,
                    "train_loss": train_loss[-1],
                    "lr": current_lr
                }, step=(epoch-1)*max_steps + steps)

            if steps % validation_steps == 0:
                _val_loss, _val_score = validation(model, loss_fn, val_loader, device, use_amp)
                _train_loss = np.mean(train_loss)
                history['train_loss'].append(_train_loss)
                history['val_loss'].append(_val_loss)
                history['val_f1'].append(_val_score)

                logger.info(f'eps={epoch}, lr={optimizer.param_groups[0]["lr"]:.3g}, t_loss={_train_loss:.4f}, v_loss={_val_loss:.4f}, v_f1={_val_score:.4f}')
                wandb.log({
                    "epoch": epoch,
                    "step": steps,
                    "train_avg_loss": _train_loss,
                    "valid_loss": _val_loss,
                    "valid_f1": _val_score,
                    "lr": optimizer.param_groups[0]["lr"]
                }, step=(epoch-1)*max_steps + steps)

                early_stopping(_val_score)
                if early_stopping.early_stop:
                    logger.info("Early stopping triggered")
                    # NOTE(review): update_bn iterates the loader only when the model has
                    # BatchNorm layers; the dict batches used here would presumably not be
                    # understood by it in that case — confirm before using a BN backbone.
                    swa_utils.update_bn(train_loader, ema_model, device)
                    ema_val_loss, ema_val_score = validation(ema_model, loss_fn, val_loader, device, use_amp)
                    logger.info(f'EMA ::: ema_v_loss={ema_val_loss:.4f}, ema_v_f1={ema_val_score:.4f}')
                    wandb.log({'ema_v_loss': ema_val_loss, 'ema_v_f1': ema_val_score})

                    if filename:
                        save_filename = filename.format(epoch=epoch, val_loss=ema_val_loss, val_score=ema_val_score) + '-ema.ckpt'
                        os.makedirs(os.path.dirname(save_filename) or '.', exist_ok=True)
                        torch.save({"model": ema_model.state_dict()}, save_filename)
                        logger.info(f'{save_filename} : (ema) saved.')

                    # The best checkpoint is already on disk, written when the score improved.
                    # Re-saving here would label the file with the current, worse metrics and
                    # delete the correctly named one. Just restore the best weights.
                    if best_state is not None:
                        model.load_state_dict(best_state)
                    pbar.close()
                    return best_model

                if _val_score > best_score:
                    best_score = _val_score
                    best_model = model
                    # Snapshot the weights: `model` keeps training, so the reference alone
                    # would always hold the latest weights rather than the best.
                    best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                    if filename:
                        save_path = filename.format(epoch=epoch, val_loss=_val_loss, val_score=_val_score) + '.ckpt'
                        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
                        torch.save({"model": model.state_dict()}, save_path)
                        logger.info(f'{save_path} : saved.')
                        previous = checkpoint_filenames[-1] if checkpoint_filenames else None
                        checkpoint_filenames.append(save_path)
                        # Keep only the newest best checkpoint. Metrics are rounded in the
                        # name, so never delete a file that was just overwritten in place.
                        if previous is not None and previous != save_path:
                            os.remove(previous)
        pbar.close()

    if best_state is not None:
        model.load_state_dict(best_state)
    return best_model

# Validation Function
def validation(model, loss_fn, val_loader, device, use_amp):
    model = model.to(device)
    save_training = model.training
    model.eval()

    val_loss = []
    preds, true_labels = [], []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            true_labels += batch['label'].detach().cpu().numpy().tolist()
            with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
                pred = model(batch['pixel_values'].to(device))
                loss = loss_fn(pred, batch['label'].to(device))
            preds += pred.detach().argmax(1).cpu().numpy().tolist()
            val_loss.append(loss.item())

        _val_loss = np.mean(val_loss)
        _val_score = f1_score(true_labels, preds, average='macro')

    if save_training:
        model.train()
    return _val_loss, _val_score

# Prediction Function
def prediction(model, test_loader, device):
    """Run `model` over `test_loader` and return softmax scores as a list of per-sample lists."""
    model = model.to(device)
    model.eval()
    all_probs = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Prediction"):
            logits = model(batch['pixel_values'].to(device))
            probs = F.softmax(logits, dim=1)
            all_probs.extend(probs.detach().cpu().numpy().tolist())
    return all_probs

# Initialize WandB
def init_wandb(fold_idx, model_name, dt_str):
    """Start an offline wandb run named `fold<k>_<model>_<timestamp>` (fold shown 1-based) and return it."""
    run_name = f'fold{fold_idx+1}_{model_name}_' + dt_str
    return wandb.init(name=run_name, config=CFG, reinit=True, mode='offline')

# Plot Training History
def plot_training_history(history):
    """Plot loss curves and validation F1 from a training history.

    Args:
        history: Dict with equally long lists under 'train_loss', 'val_loss' and
            'val_f1', one entry per validation run.

    Nothing is drawn (only a log message is emitted) when the history is empty.
    """
    n_points = len(history['train_loss'])
    if n_points == 0:
        logger.info('plot_training_history: empty history, nothing to plot')
        return

    # Entries are recorded once per validation run, not once per epoch.
    checkpoints = range(1, n_points + 1)
    fig, (loss_ax, f1_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Plot Loss
    loss_ax.plot(checkpoints, history['train_loss'], 'b', label='Training loss')
    loss_ax.plot(checkpoints, history['val_loss'], 'r', label='Validation loss')
    loss_ax.set(title='Training and Validation Loss', xlabel='Validation checkpoint', ylabel='Loss')
    loss_ax.legend()

    # Plot F1 Score
    f1_ax.plot(checkpoints, history['val_f1'], 'g', label='Validation F1')
    f1_ax.set(title='Validation F1 Score', xlabel='Validation checkpoint', ylabel='Macro F1')
    f1_ax.legend()

    plt.show()

# Main Training Loop
dt_str = datetime.now().strftime('%m%d%H%M')
# Short model tag shared by the wandb run name and the checkpoint file names.
model_tag = CFG["MODEL_NAME"].split("/")[1].split("-")[0]

skf = StratifiedKFold(n_splits=CFG['N_SPLIT'], random_state=CFG['SEED'], shuffle=True)

for fold_idx, (train_index, val_index) in enumerate(skf.split(train_df, train_df['class'])):
    gc.collect()
    torch.cuda.empty_cache()

    logger.info(f'fold_idx={fold_idx} started')
    run = init_wandb(fold_idx, model_tag, dt_str)

    # skf.split yields positional indices, so select rows by position.
    train_fold_df = train_df.iloc[train_index]
    val_fold_df = train_df.iloc[val_index]

    train_dataset = CustomDataset(
        img_path_list=train_fold_df['img_path'].values,
        label_list=train_fold_df['class'].values,
        load_img_size=CFG['IMG_TRAIN_SIZE'],
        shuffle=True,
        transforms=train_transform
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=CFG['BATCH_SIZE'],
        shuffle=True,
        num_workers=8,
        pin_memory=True
    )
    val_dataset = CustomDataset(
        img_path_list=val_fold_df['img_path'].values,
        label_list=val_fold_df['class'].values,
        load_img_size=CFG['IMG_SIZE'],
        shuffle=False,
        transforms=test_transform
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=CFG['BATCH_SIZE']*2,
        shuffle=False,
        num_workers=8,
        pin_memory=True
    )

    model = create_model(CFG['MODEL_NAME'])

    optimizer = optim.AdamW(
        model.parameters(),
        lr=CFG['LR'][0],
        weight_decay=0.001,
    )
    scheduler = CosineAnnealingWarmupRestarts(
        optimizer,
        first_cycle_steps=int(len(train_loader)) // 4,
        cycle_mult=1.0,
        max_lr=CFG['LR'][0] * 2,
        min_lr=CFG['LR'][1],
        warmup_steps=0,
        gamma=0.93,
    )

    model = train(
        model, optimizer, train_loader, val_loader, scheduler,
        device, validation_steps=10, logging_steps=10,
        use_amp=(CFG['PRECISION'] == '16'),
        filename=f'./ckpt/{model_tag}-fold_idx={fold_idx}-epoch={{epoch:02d}}-val_loss={{val_loss:.4f}}-val_score={{val_score:.4f}}'
    )

    # NOTE(review): train() returns only the model, so there is no per-fold history to plot
    # here and the lists stay empty — confirm whether this call is still wanted.
    plot_training_history({
        'train_loss': [],
        'val_loss': [],
        'val_f1': []
    })

    # Save Final Checkpoint
    # This must happen before the reference is dropped below: after the loop `model` is None,
    # so `model.state_dict()` would raise. The file ends up holding the last fold that
    # produced a model.
    if model is not None:
        torch.save({'model': model.state_dict()}, 'checkpoint_epoch_3.pth')

    model = None
    gc.collect()
    torch.cuda.empty_cache()
    logger.info(f'fold_idx={fold_idx} finished')
    run.finish()

    # Best-effort notification: a failure is logged but never stops the fold loop.
    try:
        ckpt_files = glob('./ckpt/*')
        if ckpt_files:
            newest_ckpt = os.path.basename(max(ckpt_files, key=os.path.getmtime))
            # Drop the '.ckpt' suffix and the leading model tag; keep fold/epoch/metrics.
            last_chpt_info = ','.join(newest_ckpt[:-5].split('-')[1:])
            os.system(f'python ~/send_telegram.py "{last_chpt_info}"')
    except Exception:
        logger.exception('telegram notification failed')

print(os.listdir('./ckpt'))

In [None]:
# Final Training on Sample Dataset
model = create_model(CFG['MODEL_NAME'])
optimizer = optim.AdamW(model.parameters(), lr=CFG['LR'][0])
scheduler = CosineAnnealingWarmupRestarts(
    optimizer,
    first_cycle_steps=int(len(train_loader_small)) // 4,
    cycle_mult=1.0,
    max_lr=CFG['LR'][0] * 2,
    min_lr=CFG['LR'][1],
    warmup_steps=0,
    gamma=0.93
)

# train() formats the template with epoch/val_loss/val_score only, so a leftover
# '{fold_idx}' placeholder would raise KeyError at the first checkpoint.
# NOTE(review): a literal, non-numeric tag is used for this run — confirm the naming.
model = train(
    model, optimizer, train_loader_small, val_loader, scheduler,
    device, validation_steps=10, logging_steps=10,
    use_amp=(CFG['PRECISION'] == '16'),
    filename=f'./ckpt/{CFG["MODEL_NAME"].split("/")[1].split("-")[0]}-fold_idx=sample-epoch={{epoch:02d}}-val_loss={{val_loss:.4f}}-val_score={{val_score:.4f}}'
)

In [None]:
# Prediction and Ensemble
test_df = pd.read_csv('./test.csv')


def _parse_ckpt_field(text, pattern, cast, default):
    """Return `cast(group 1)` of the first `pattern` match in `text`, or `default` without a match."""
    match = re.search(pattern, text)
    return cast(match.group(1)) if match else default


# Index the checkpoints on disk; all metadata is recovered from the file names.
ckpt_df = pd.DataFrame({'fname': glob('./ckpt/*.ckpt')})
if ckpt_df.empty:
    raise FileNotFoundError('no checkpoints found under ./ckpt')
ckpt_df['mtime'] = ckpt_df.fname.apply(lambda x: int(os.stat(x).st_mtime))
ckpt_df['model_name'] = ckpt_df.fname.apply(lambda x: _parse_ckpt_field(x, r'./ckpt/(.*?)-fold', str, None))
# The 'patch0_0' suffix guarantees a match, so names without a patch/size tag yield 0.
ckpt_df['img_size'] = ckpt_df.fname.apply(lambda x: _parse_ckpt_field(x + 'patch0_0', r'patch[0-9]+_([0-9]+)', int, 0))
ckpt_df['is_ema'] = ckpt_df.fname.str.endswith('ema.ckpt').astype(int)
ckpt_df['fold_idx'] = ckpt_df.fname.apply(lambda x: _parse_ckpt_field(x, r'fold_idx=([0-9]+)-', int, -1))
# Accept any non-negative decimal: a cross-entropy loss is often >= 1, and the previous
# '0\.' prefix silently turned such values (and a score of 1.0000) into None.
ckpt_df['val_loss'] = ckpt_df.fname.apply(lambda x: _parse_ckpt_field(x, r'val_loss=([0-9]+\.[0-9]+)', float, None))
ckpt_df['val_score'] = ckpt_df.fname.apply(lambda x: _parse_ckpt_field(x, r'val_score=([0-9]+\.[0-9]+)', float, None))

ckpt_df = ckpt_df[(ckpt_df.img_size != 0) & (ckpt_df.is_ema == 0)]
ckpt_df = ckpt_df.sort_values('mtime', ascending=False).reset_index(drop=True)
# Newest first: each row of the highest fold marks the start of one run's block of folds.
ckpt_indexes = ckpt_df[ckpt_df.fold_idx == ckpt_df.fold_idx.max()].index[:4]

preds = []
preds_score = []

for ckpt_start_index in ckpt_indexes:
    logger.info(f'{ckpt_df.fname[ckpt_start_index]} loading')
    CFG['IMG_SIZE'] = ckpt_df.img_size[ckpt_start_index]
    assert CFG['IMG_SIZE'] in (196, 224)
    logger.info(CFG['IMG_SIZE'])

    test_dataset = CustomDataset(
        test_df['img_path'].values, None,
        interpolation=CFG['INTERPOLATION'], load_img_size=CFG['IMG_SIZE'],
        shuffle=False, transforms=test_transform
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=CFG['BATCH_SIZE']*2,
        shuffle=False,
        num_workers=8,
        pin_memory=True
    )

    model_name = ckpt_df.model_name[ckpt_start_index]
    model = create_model(model_name)
    if ckpt_df.is_ema[ckpt_start_index]:
        model = swa_utils.AveragedModel(model)

    # Never index past the end of the table when a run holds fewer checkpoints than folds.
    last_index = min(ckpt_start_index + skf.get_n_splits(), len(ckpt_df))
    for i in range(ckpt_start_index, last_index):
        checkpoint_path = ckpt_df.fname[i]
        logger.info(f'{checkpoint_path} loading')
        # Load onto CPU first; the prediction helper moves the model to `device`.
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model'])

        preds_score.append(ckpt_df.val_score[i])
        preds.append(prediction_with_tta(model, test_loader, device))

preds = np.array(preds)
preds_score = np.array(preds_score)

# Weighted Averaging Ensemble: each checkpoint votes in proportion to its validation F1.
weights = preds_score / preds_score.sum()
ensemble_preds = np.tensordot(weights, preds, axes=([0], [0]))
preds_labels = le.inverse_transform(ensemble_preds.argmax(-1))
print(preds_labels)


In [None]:
# Save Submission
submit = pd.read_csv('./sample_submission.csv')
submit['label'] = preds_labels
dt_str = datetime.now().strftime('%Y%m%d_%H%M')
submit.to_csv(f'./basslibrary_submit_{dt_str}.csv', index=False)
logger.info(f'./basslibrary_submit_{dt_str}.csv saved')

# Send Telegram Notification (best effort: a failure is logged, never raised).
try:
    os.system(f'python ~/send_telegram.py "basslibrary_submit_{dt_str}.csv saved"')
except Exception:
    logger.exception('telegram notification failed')

# Kept as the last expression of the cell so the notebook renders the label distribution.
submit.label.value_counts()