In [None]:
# Notebook-wide switches read by the cells below.
kaggle = True        # True: Kaggle input/working paths; False: local paths
mode = 'training'    # 'training' trains every fold; any other value (e.g. 'inference') loads saved weights
debug = True         # quick run: 100-row training sample, fewer folds and epochs

In [None]:
import sys

# Make vendored packages importable. On Kaggle they come from attached input datasets
# (presumably timm, the pretrained EfficientNet weights, pycm and the `losses` helpers);
# locally, helper modules such as `losses` are expected under ./utils.
sys.path.extend(
    [
        '/kaggle/input/pytorch-image-models/pytorch-image-models-master',
        '/kaggle/input/timm-pretrained-efficientnet/efficientnet',
        '/kaggle/input/pycm-master/pycm-master',
        '/kaggle/input/utilities',
    ]
    if kaggle
    else ['./utils']
)

In [None]:
import os
import pandas as pd
import numpy as np
import cv2
import random
import time
import gc
import json
import numbers
import copy
import matplotlib.pyplot as plt
import seaborn as sns

from functools import partial
from collections import OrderedDict
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from pycm import ConfusionMatrix

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn import metrics
from sklearn.model_selection import StratifiedKFold
from losses import LabelSmoothing, MyCrossEntropyLoss

In [None]:
# Root directories: competition data, a writable cache, and attached inputs (saved weights).
if kaggle:
    data_dir, cache_dir, input_dir = (
        '/kaggle/input/cassava-leaf-disease-classification',
        '/kaggle/working',
        '/kaggle/input',
    )
else:
    data_dir, cache_dir, input_dir = '../../../data/cassava', './', './cache'

train_images_dir = Path(data_dir, 'train_images')
test_images_dir = Path(data_dir, 'test_images')
save_weights_dir = Path(cache_dir, 'weights')   # fold checkpoints written during training
load_weights_dir = Path(input_dir, 'weights')   # checkpoints read for warm start / inference

save_weights_dir.mkdir(parents=True, exist_ok=True)

In [None]:
class GlobalConfig:
    """Experiment-wide settings shared by DataConfig, ModelConfig and FitterConfig."""
    device_ids = [0, 1]
    device = torch.device(f'cuda:{device_ids[0]}' if torch.cuda.is_available() else "cpu")
    image_model_name = 'tf_efficientnet_b4_ns'
    load_weights = False      # training mode: warm-start each fold from load_weights_dir
    use_multi_gpus = False
    num_workers = 8
    batch_size = 16
    num_folds = 2 if debug else 5
    num_epochs = 3 if debug else 20
    seed = 42
    # StratifiedKFold is built with random_state=seed; scikit-learn >= 0.24 raises
    # ValueError when random_state is set while shuffle=False, so shuffling must be on.
    shuffle = True
    onehot_cols = ['label_0', 'label_1', 'label_2', 'label_3', 'label_4']
    num_targets = len(onehot_cols)
    use_tta = True
    weights_mode = 'loss'     # checkpoint Fitter keeps: 'acc' (best accuracy), otherwise best loss

class DataConfig:
    """Data-pipeline settings: loader sizes (taken from GlobalConfig) and transform parameters."""
    batch_size, num_workers, onehot_cols = (
        GlobalConfig.batch_size,
        GlobalConfig.num_workers,
        GlobalConfig.onehot_cols,
    )
    scale_size = (512, 512)   # resize target before the centre crop (valid/test)
    input_size = (512, 512)   # network input size
    # Augmentation / normalisation parameters consumed by the *_transforms builders.
    trans_params = dict(
        interpolation='BILINEAR',
        random_resized_crop_scale=(0.08, 1.0),
        random_resized_crop_ratio=(0.75, 1.3333333333333333),
        rgb_mean=(0.485, 0.456, 0.406),
        rgb_std=(0.229, 0.224, 0.225),
        hue_shift_limit=(-10, 10),
        sat_shift_limit=(-15, 15),
        val_shift_limit=(-10, 10),
        brightness_limit=(-0.05, 0.05),
        contrast_limit=(-0.05, 0.05),
    )

class ModelConfig:
    """Backbone settings; on Kaggle also resolves the local pretrained-weights file."""
    image_model_name = GlobalConfig.image_model_name
    num_targets = GlobalConfig.num_targets
    if kaggle:
        # index.json maps model names to weight file names inside the attached dataset.
        json_path = "/kaggle/input/timm-pretrained-efficientnet/index.json"
        with open(json_path, mode="r") as f:
            weights_dict = json.load(f)
        weight_name = weights_dict['efficientnet'][str(image_model_name)]
        path = "/kaggle/input/timm-pretrained-efficientnet/efficientnet/" + weight_name

class FitterConfig:
    """Training-loop settings consumed by Fitter."""
    device = GlobalConfig.device
    # Multi-GPU settings read by Fitter.__init__; without them that lookup raised
    # AttributeError whenever `kaggle` was False.
    device_ids = GlobalConfig.device_ids
    use_multi_gpus = GlobalConfig.use_multi_gpus
    num_epochs = GlobalConfig.num_epochs
    batch_size = GlobalConfig.batch_size
    onehot_cols = GlobalConfig.onehot_cols
    use_tta = GlobalConfig.use_tta
    weights_mode = GlobalConfig.weights_mode
    num_targets = len(onehot_cols)
    num_augs_tta = 3
    early_stopping_start = 0 if debug else 8
    early_stopping_rounds = 4
    iters_to_accumulate = 2
    change_device_config = True  # strip DataParallel 'module.' prefixes when loading weights
    # criterion
    train_criterion = LabelSmoothing(smoothing=0.05)
#    train_criterion = MyCrossEntropyLoss()
    valid_criterion = MyCrossEntropyLoss()
    # optimizer
    optimizer = partial(torch.optim.Adam, lr=1e-4, weight_decay=1e-6)
    # scheduler: stepped with the validation *loss* (scheduler_update_by_loss=True), so a
    # decrease is the improvement -> mode='min'. With mode='max' the LR was cut exactly
    # while the loss kept going down.
    scheduler = partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode='min', patience=2, factor=0.2)
#    scheduler = partial(torch.optim.lr_scheduler.CosineAnnealingWarmRestarts, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1)
    step_scheduler = False  # do scheduler.step after optimizer.step
    valid_scheduler = True  # do scheduler.step after validation stage loss
    scheduler_update_by_loss = True

In [None]:
train_df = pd.read_csv(Path(data_dir) / 'train.csv')
test_df = pd.read_csv(Path(data_dir) / 'sample_submission.csv')

if debug:
    # Fixed random_state: seed_everything() only runs later (inside the CV loop), so without
    # it every debug run trained on a different subset. min() avoids a crash on tiny files.
    train_df = (train_df
                .sample(n=min(100, len(train_df)), random_state=GlobalConfig.seed)
                .reset_index(drop=True))

In [None]:
def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) RNGs.

    Also makes cuDNN deterministic and disables its autotuner (``benchmark``), which can
    otherwise select different, non-deterministic kernels from run to run.

    Note: setting PYTHONHASHSEED here only affects child processes (e.g. DataLoader
    workers started with spawn), not hashing in the already-running interpreter.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
def feature_engineering(df, config=GlobalConfig):
    """One-hot encode ``label`` into ``config.onehot_cols`` and keep the raw ``label`` too.

    Every column in ``config.onehot_cols`` is always produced (0 when the class is
    absent), so a small/debug subsample that misses a class still has the columns
    CassavaDataset reads. Column order: other columns, one-hot columns, ``label``.
    """
    onehot = pd.get_dummies(df['label'], prefix='label', dtype=np.uint8)
    onehot = onehot.reindex(columns=list(config.onehot_cols), fill_value=0)
    df_ = pd.concat([df.drop(columns=['label']), onehot], axis=1)
    df_['label'] = df['label']
    return df_

In [None]:
# https://github.com/knjcode/pytorch-finetuner
def custom_four_crop(img, size):
    """Return four deterministic TTA views of a PIL image, each ``size`` = (crop_h, crop_w).

    Order: centre crop, whole image resized, then the same two views of the
    horizontally flipped image.

    Raises:
        ValueError: if the crop is larger than the image in either dimension.
    """
    width, height = img.size
    crop_h, crop_w = size
    if crop_h > height or crop_w > width:
        raise ValueError("Requested crop size {} is bigger than input size {}".format(size, (height, width)))

    tvf = transforms.functional
    target = (crop_h, crop_w)
    views = []
    for source in (img, tvf.hflip(img)):
        views.append(tvf.center_crop(source, target))
        views.append(tvf.resize(source, target))
    return tuple(views)

class CustomFourCrop(object):
    """Callable transform applying ``custom_four_crop`` with a fixed size.

    ``size`` may be a number (square crop) or an (h, w) pair.
    """
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
        self.size = size

    def __call__(self, img):
        return custom_four_crop(img, self.size)

    def __repr__(self):
        return '{}(size={})'.format(self.__class__.__name__, self.size)

class RandomAugs(object):
    """Produce ``num_augs`` independently augmented copies of one PIL image (heavy TTA).

    A fresh torchvision pipeline is obtained from ``_transforms`` for each copy.
    """
    def __init__(self, config, num_augs=3):
        self.input_size = config.input_size
        self.trans_params = config.trans_params
        self.num_augs = num_augs

    def __call__(self, image):
        return self._random_augs(image)

    def __repr__(self):
        return '{}(size={})'.format(type(self).__name__, self.input_size)

    def _random_augs(self, image):
        return tuple(self._transforms()(image) for _ in range(self.num_augs))

    def _transforms(self):
        params = self.trans_params
        # Fallback 2 is the integer code the original passed when the name is missing.
        resample = getattr(Image, params['interpolation'], 2)
        crop = transforms.RandomResizedCrop(size=self.input_size,
                                            scale=params['random_resized_crop_scale'],
                                            ratio=params['random_resized_crop_ratio'],
                                            interpolation=resample)
        affine = transforms.RandomAffine(degrees=45., translate=(0.0625, 0.0625), scale=(0.9, 1.1), shear=10.)
        # Upper limits of the albumentations-style ranges, rescaled by 1/255 for hue/saturation.
        jitter = transforms.ColorJitter(brightness=params['brightness_limit'][1],
                                        contrast=params['contrast_limit'][1],
                                        saturation=params['sat_shift_limit'][1] / 255.,
                                        hue=params['hue_shift_limit'][1] / 255.)
        return transforms.Compose([crop, transforms.RandomHorizontalFlip(), affine, jitter])

In [None]:
def train_transforms(config):
    """Albumentations training pipeline ending in a normalised CHW tensor.

    Stages: random resized crop -> transpose/flip/shift-scale-rotate -> HSV and
    brightness/contrast jitter -> normalise -> CoarseDropout and Cutout -> ToTensorV2.
    All limits come from ``config.trans_params``.
    """
    params = config.trans_params
    # NOTE(review): input_size is read as (width, height) here, while the torchvision TTA
    # path (CustomFourCrop, RandomResizedCrop) treats it as (h, w); they only agree for
    # square sizes such as the current (512, 512).
    crop_w, crop_h = config.input_size[0], config.input_size[1]
    geometric = [
        A.RandomResizedCrop(width=crop_w, height=crop_h,
                            scale=params['random_resized_crop_scale'],
                            ratio=params['random_resized_crop_ratio']),
        A.Transpose(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(p=0.5),
    ]
    photometric = [
        A.HueSaturationValue(hue_shift_limit=params['hue_shift_limit'],
                             sat_shift_limit=params['sat_shift_limit'],
                             val_shift_limit=params['val_shift_limit'], p=0.5),
        A.RandomBrightnessContrast(brightness_limit=params['brightness_limit'],
                                   contrast_limit=params['contrast_limit'], p=0.5),
    ]
    finishing = [
        A.Normalize(mean=params['rgb_mean'], std=params['rgb_std'], max_pixel_value=255.0),
        A.CoarseDropout(p=0.5),
        A.Cutout(p=0.5),
        ToTensorV2(),
    ]
    return A.Compose(geometric + photometric + finishing)

def valid_transforms(config):
    """Deterministic validation pipeline: resize to ``scale_size``, centre-crop to
    ``input_size``, normalise, convert to a CHW tensor."""
    params = config.trans_params
    steps = [
        A.Resize(width=config.scale_size[0], height=config.scale_size[1]),
        A.CenterCrop(width=config.input_size[0], height=config.input_size[1]),
        A.Normalize(mean=params['rgb_mean'], std=params['rgb_std'], max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose(steps)

def test_transforms(config):
    """Deterministic pipeline for test images (currently the same steps as validation)."""
    mean, std = config.trans_params['rgb_mean'], config.trans_params['rgb_std']
    scale_w, scale_h = config.scale_size[0], config.scale_size[1]
    crop_w, crop_h = config.input_size[0], config.input_size[1]
    return A.Compose([
        A.Resize(width=scale_w, height=scale_h),
        A.CenterCrop(width=crop_w, height=crop_h),
        A.Normalize(mean=mean, std=std, max_pixel_value=255.0),
        ToTensorV2(),
    ])

def tta_transforms(config, num_augs=3, mode='normal'):
    """Torchvision TTA pipeline: PIL image -> float tensor of shape (n_views, C, H, W).

    ``mode='heavy'`` draws ``num_augs`` random augmentations (RandomAugs). Any other mode
    uses the four deterministic views of CustomFourCrop, and ``num_augs`` is ignored.
    Each view is converted to a tensor and normalised with ``config.trans_params`` mean/std.
    """
    trans_params = config.trans_params
    if mode == 'heavy':
        main_transforms = RandomAugs(config=config, num_augs=num_augs)
    else:
        main_transforms = CustomFourCrop(size=config.input_size)
    # Built once, rather than instantiating ToTensor/Normalize for every view of every
    # image. The previously computed but unused `interpolation` lookup is dropped.
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(trans_params['rgb_mean'], trans_params['rgb_std']),
    ])
    return transforms.Compose([
        main_transforms,
        transforms.Lambda(lambda views: torch.stack([to_normalized_tensor(view) for view in views])),
    ])

In [None]:
class CassavaDataset(Dataset):
    """Cassava leaf images read with cv2 (converted to RGB), plus one-hot targets for train/valid.

    Phases:
        'train' / 'valid': albumentations ``transforms``; items are (image_id, image, target).
        'tta': ``transforms`` is a torchvision callable applied to a PIL image; items are
               (image_id, tta_images).
        anything else (e.g. 'test'): albumentations ``transforms``; items are (image_id, image).

    ``indices`` are positional row numbers into ``df`` (e.g. from StratifiedKFold.split).
    By default every row is used.
    """
    def __init__(self, config, df, phase, transforms, indices=None):
        self.image_ids = df.image_id.values
        self.data_transform = transforms
        self.phase = phase
        if phase in ['train', 'valid']:
            self.targets = df[config.onehot_cols].values.astype(np.float32)
            images_dir = train_images_dir
        else:
            images_dir = test_images_dir
        self.images_dir = images_dir
        if indices is None:
            # Positions 0..len(df)-1. The old np.int alias is gone since NumPy 1.24, and
            # df.index labels are only valid positions for a default RangeIndex.
            indices = torch.arange(len(df))
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        index = self.indices[index]
        image_id = self.image_ids[index]
        image_path = str(Path(self.images_dir) / image_id)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.phase in ['train', 'valid']:
            image = self.data_transform(image=image)['image']
            target = self.targets[index]
            return image_id, image, target
        elif self.phase == 'tta':
            # `transforms` here is the torchvision module (the __init__ parameter is local there).
            image = transforms.ToPILImage()(image)
            tta_images = self.data_transform(image)
            return image_id, tta_images
        else:
            image = self.data_transform(image=image)['image']
            return image_id, image

In [None]:
def check_dataset(config=GlobalConfig, index=(0, 1, 2)):
    """Visual sanity check: show a few augmented training images with their targets.

    train_transforms normalises with ``DataConfig.trans_params`` mean/std, so images are
    de-normalised before display. Clipping the normalised values straight to [0, 1]
    distorted the colours.
    """
    train_df_ = feature_engineering(train_df.copy(), config=config)
    train_dataset = CassavaDataset(config=DataConfig, df=train_df_, phase='train',
                                   transforms=train_transforms(DataConfig))
    mean = np.asarray(DataConfig.trans_params['rgb_mean'])
    std = np.asarray(DataConfig.trans_params['rgb_std'])
    for i in index:
        image_id, image, target = train_dataset[i]
        # CHW tensor -> HWC array, then undo Normalize channel-wise.
        image = image.numpy().transpose((1, 2, 0)) * std + mean
        image = np.clip(image, 0, 1)
        print(f"image_id: {image_id}, target: {target}")
        plt.imshow(image)
        plt.show()

check_dataset()

In [None]:
class DataLoaderMaker():
    """Factory for the train / valid / test / TTA DataLoaders of one experiment.

    Train and valid loaders read ``train_df`` rows at the given positional indices.
    Test and TTA loaders cover all of ``test_df``.
    """

    def __init__(self, train_df, test_df, config=DataConfig):
        self.config = config
        self.batch_size, self.num_workers = config.batch_size, config.num_workers
        self.train_df, self.test_df = train_df, test_df

    def _loader(self, dataset, is_train):
        # Training shuffles and drops the ragged last batch; evaluation keeps order and every sample.
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers,
                          drop_last=is_train, shuffle=is_train)

    def _train_dataset(self, indices):
        return CassavaDataset(config=self.config, df=self.train_df, phase='train',
                              transforms=train_transforms(self.config), indices=indices)

    def train_dataloader(self, train_index):
        return self._loader(self._train_dataset(train_index), is_train=True)

    def _valid_dataset(self, indices):
        return CassavaDataset(config=self.config, df=self.train_df, phase='valid',
                              transforms=valid_transforms(self.config), indices=indices)

    def valid_dataloader(self, valid_index):
        return self._loader(self._valid_dataset(valid_index), is_train=False)

    def _test_dataset(self):
        return CassavaDataset(config=self.config, df=self.test_df, phase='test',
                              transforms=test_transforms(self.config))

    def test_dataloader(self):
        return self._loader(self._test_dataset(), is_train=False)

    def _tta_dataset(self, num_augs):
        return CassavaDataset(config=self.config, df=self.test_df, phase='tta',
                              transforms=tta_transforms(self.config, num_augs))

    def tta_dataloader(self, num_augs=3):
        return self._loader(self._tta_dataset(num_augs), is_train=False)

In [None]:
class AverageMeter(object):
    """Computes and stores the average and current value.

    ``avg`` is the n-weighted mean of all updates, or NaN before the first update
    (previously a ZeroDivisionError).
    """
    def __init__(self):
        self._reset()

    def _reset(self):
        self.val = 0    # most recent value passed to update()
        self.sum = 0
        self.count = 0

    def reset(self):
        """Public alias of ``_reset``."""
        self._reset()

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples (e.g. a batch loss and batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count else float('nan')

class RocAucMeter(object):
    """Accumulates softmax probabilities and reports multi-class one-vs-one ROC AUC.

    The score is computed lazily, on the first ``avg``/``score`` access after new data
    arrives. Recomputing it after every batch over the ever-growing arrays made an
    epoch quadratic in its length.
    """
    def __init__(self, n_class=2):
        self.n_class = n_class
        self.reset()

    def reset(self):
        # to avoid sklearn method crushing, when batch has only one class:
        # seed with one uniform-probability row per class.
        self.y_true = np.arange(self.n_class)
        self.y_pred = np.full((self.n_class, self.n_class), 1. / self.n_class)
        self._score = 0
        self._stale = False

    def update(self, y_true, y_pred):
        """Append one batch: ``y_true`` one-hot targets, ``y_pred`` raw logits (torch tensors)."""
        y_true = torch.argmax(y_true, dim=-1).cpu().numpy().reshape(-1)
        y_pred = nn.functional.softmax(y_pred, dim=-1).data.cpu().numpy()
        self.y_true = np.concatenate([self.y_true, y_true], axis=0)
        self.y_pred = np.concatenate([self.y_pred, y_pred], axis=0)
        self._stale = True

    @property
    def score(self):
        if self._stale:
            # Renormalise rows: low-precision softmax outputs may not sum to exactly 1,
            # which roc_auc_score rejects in multi-class mode.
            self.y_pred = self.y_pred / self.y_pred.sum(axis=1, keepdims=True)
            self._score = metrics.roc_auc_score(self.y_true, self.y_pred, multi_class="ovo")
            self._stale = False
        return self._score

    @property
    def avg(self):
        return self.score

class AccMeter(object):
    """Running classification accuracy from one-hot targets and logits.

    ``avg`` is the fraction of correct argmax predictions over all updates, or NaN
    before the first one. Correct predictions are counted per batch instead of
    re-scoring every accumulated sample after each batch.
    """
    def __init__(self, n_class):
        self.n_class = n_class
        self.reset()

    def reset(self):
        # to avoid sklearn method crushing, when batch has only one class:
        # one dummy, correctly "predicted" sample per class. They are kept in y_true/y_pred
        # (and in correct_count) but excluded from the score.
        self.y_true = np.arange(self.n_class)
        self.y_pred = np.arange(self.n_class)
        self.correct_count = self.n_class
        self.score = float('nan')

    def update(self, y_true, y_pred):
        """Add one batch: ``y_true`` one-hot targets, ``y_pred`` logits (torch tensors)."""
        y_true = torch.argmax(y_true, dim=-1).cpu().numpy().reshape(-1)
        y_pred = torch.argmax(y_pred, dim=-1).data.cpu().numpy().reshape(-1)
        self.y_true = np.hstack((self.y_true, y_true))
        self.y_pred = np.hstack((self.y_pred, y_pred))
        self.correct_count += int(np.count_nonzero(y_true == y_pred))
        seen = self.y_true.shape[0] - self.n_class
        self.score = (self.correct_count - self.n_class) / seen if seen else float('nan')

    @property
    def avg(self):
        return self.score

class FoldLogger(object):
    """Collects per-epoch train/valid loss and accuracy and plots them side by side."""
    def __init__(self, save_fig=False):
        self.save_fig = save_fig
        self.train_loss, self.train_acc = [], []
        self.valid_loss, self.valid_acc = [], []

    def update_train(self, train_loss=None, train_acc=None):
        self.train_loss.append(train_loss)
        self.train_acc.append(train_acc)

    def update_valid(self, valid_loss=None, valid_acc=None):
        self.valid_loss.append(valid_loss)
        self.valid_acc.append(valid_acc)

    @staticmethod
    def _drop_none(values):
        return [v for v in values if v is not None]

    def get_graph(self, title=None, fig_title=None):
        """Plot accuracy (left) and loss (right) curves; optionally save and title the figure.

        Note: the stored history lists are filtered of ``None`` entries in place.
        """
        self.train_loss = self._drop_none(self.train_loss)
        self.train_acc = self._drop_none(self.train_acc)
        self.valid_loss = self._drop_none(self.valid_loss)
        self.valid_acc = self._drop_none(self.valid_acc)

        fig, (acc_ax, loss_ax) = plt.subplots(ncols=2, figsize=(12, 3))

        for series, label, color in ((self.train_acc, "train acc", 'b'),
                                     (self.valid_acc, "valid acc", 'r')):
            if series:
                acc_ax.plot(np.arange(0, len(series)), series,
                            linestyle="solid", label=label, color=color)
        acc_ax.legend(bbox_to_anchor=(1, 0), loc='lower right', borderaxespad=1)
        acc_ax.set_xlabel("epochs")
        acc_ax.set_ylabel("acc")
        acc_ax.grid(True)

        for series, label, color in ((self.train_loss, "train loss", 'b'),
                                     (self.valid_loss, "valid loss", 'r')):
            if series:
                loss_ax.plot(np.arange(0, len(series)), series,
                             linestyle="solid", label=label, color=color)
        loss_ax.legend(bbox_to_anchor=(1, 1), loc='upper right', borderaxespad=1)
        loss_ax.set_xlabel("epochs")
        loss_ax.set_ylabel("loss")
        loss_ax.grid(True)

        if title is not None:
            fig.suptitle(title)

        if self.save_fig and fig_title is not None:
            fig.savefig(fig_title)

        plt.show()

In [None]:
class CassavaNet(nn.Module):
    """timm backbone whose ``classifier`` head is replaced by a fresh linear layer with
    ``config.num_targets`` outputs; ``forward`` returns that layer's raw outputs (logits)."""
    def __init__(self, config=ModelConfig):
        super().__init__()
        self.image_model = self._get_image_model(config)
        self.image_model.classifier = nn.Linear(in_features=self.image_model.classifier.in_features,
                                                out_features=config.num_targets)

    def _get_image_model(self, config):
        """Create the backbone: on Kaggle from the local weight file ``config.path``,
        otherwise with timm's own pretrained download."""
        if kaggle:
            image_model = timm.create_model(config.image_model_name, pretrained=False)
            # map_location='cpu' lets a checkpoint saved from a GPU load on any machine;
            # the model is moved to the training device later.
            image_model.load_state_dict(torch.load(config.path, map_location='cpu'))
        else:
            image_model = timm.create_model(config.image_model_name, pretrained=True)
        return image_model

    def forward(self, image):
        x = self.image_model(image)
        return x

In [None]:
def fix_model_state_dict(state_dict):
    """Return a copy of ``state_dict`` with one leading ``module.`` removed from each key.

    nn.DataParallel prefixes parameter names with ``module.``; stripping it lets the
    weights load into an unwrapped model. Key order is preserved.
    """
    prefix = 'module.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

In [None]:
class Fitter(nn.Module):
    """Trains, validates and predicts for one cross-validation fold.

    Holds the model, optimizer, LR scheduler, AMP ``GradScaler`` and the fold's
    dataloaders. ``fit()`` keeps deep copies of the weights at the best validation
    loss and at the best validation accuracy. ``weights_mode`` ('acc', otherwise loss)
    picks which one becomes ``best_weights`` for prediction and saving.
    """

    def __init__(self, model, dataloader_maker, train_index=None, valid_index=None, config=FitterConfig):
        """
        Args:
            model: network to train; moved to ``config.device``.
            dataloader_maker: provides train/valid/test/tta dataloader factories.
            train_index, valid_index: NumPy arrays of positional row indices into the
                training frame (e.g. from ``StratifiedKFold.split``).
            config: a FitterConfig-like class.
        """
        super(Fitter, self).__init__()
        self.device = config.device
        self.train_criterion = config.train_criterion.to(config.device)
        self.valid_criterion = config.valid_criterion.to(config.device)
        self.train_index = torch.from_numpy(train_index)
        self.valid_index = torch.from_numpy(valid_index)
        # np.Inf was removed in NumPy 2.0; np.inf works on every version.
        self.best_loss = np.inf
        self.best_acc = -np.inf
        # Set by fit() or load_model_weights(); required by prediction().
        self.best_weights = None
        self.train_dataloader = dataloader_maker.train_dataloader(self.train_index)
        self.valid_dataloader = dataloader_maker.valid_dataloader(self.valid_index)
        self.test_dataloader = dataloader_maker.test_dataloader()
        self.tta_dataloader = dataloader_maker.tta_dataloader(config.num_augs_tta)
        # The config may not define the multi-GPU settings; reading them directly raised
        # AttributeError whenever `kaggle` was False. Default to a single device.
        use_multi_gpus = getattr(config, 'use_multi_gpus', False)
        device_ids = getattr(config, 'device_ids', None)
        if not kaggle and use_multi_gpus and torch.cuda.device_count() > 1:
            model = nn.DataParallel(model, device_ids=device_ids)
        self.model = model.to(config.device)
        self.optimizer = config.optimizer(params=self.model.parameters())
        self.scheduler = config.scheduler(optimizer=self.optimizer)
        self.scaler = GradScaler()
        self.num_epochs = config.num_epochs
        self.batch_size = config.batch_size
        self.early_stopping_start = config.early_stopping_start
        self.early_stopping_rounds = config.early_stopping_rounds
        self.num_targets = config.num_targets
        self.step_scheduler = config.step_scheduler
        self.valid_scheduler = config.valid_scheduler
        self.scheduler_update_by_loss = config.scheduler_update_by_loss
        self.iters_to_accumulate = config.iters_to_accumulate
        self.onehot_cols = config.onehot_cols
        self.change_device_config = config.change_device_config
        self.use_tta = config.use_tta
        self.weights_mode = config.weights_mode

    def fit(self):
        """Train for up to ``num_epochs`` epochs, validating after each one.

        Early stopping counts epochs (after ``early_stopping_start``) without a new best
        validation loss. The counter is also reset when validation accuracy improves.

        Returns:
            (loss, acc, roc_auc, valid_preds_df) of the epoch whose weights became
            ``best_weights``. ``valid_preds_df`` holds that same epoch's validation
            logits, so out-of-fold predictions match the saved weights. Previously the
            last epoch's predictions were returned.

        Raises:
            RuntimeError: if no epoch was run (``num_epochs`` < 1).
        """
        self.fold_logger = FoldLogger()
        wait = 0
        # Each holds (loss, acc, roc_auc, valid_preds_df) of the epoch that set the record.
        at_best_acc = None
        at_best_loss = None
        last_epoch = None
        for epoch in range(self.num_epochs):
            t = time.time()
            loss, acc, roc_auc = self._train_one_epoch()
            print(f"[RESULT]: Train. Epoch: {epoch}, acc: {acc.avg:.9f}, roc_auc: {roc_auc.avg:.9f}"
                  + f", loss: {loss.avg:.9f}, time: {(time.time() - t):.5f}")
            self.fold_logger.update_train(train_loss=loss.avg, train_acc=acc.avg)

            t = time.time()
            loss, acc, roc_auc, valid_preds_df = self._valid_one_epoch()
            print(f"[RESULT]: Valid. Epoch: {epoch}, acc: {acc.avg:.9f}, roc_auc: {roc_auc.avg:.9f}"
                  + f", loss: {loss.avg:.9f}, time: {(time.time() - t):.5f}")
            self.fold_logger.update_valid(valid_loss=loss.avg, valid_acc=acc.avg)
            last_epoch = (loss.avg, acc.avg, roc_auc.avg, valid_preds_df)

            if self.valid_scheduler:
                if self.scheduler_update_by_loss:
                    self.scheduler.step(loss.avg)
                else:
                    self.scheduler.step()

            current = acc.avg
            if current > self.best_acc:
                print(f'Validation acc increased ({self.best_acc:.9f} --> {current:.9f}). Saving model ...')
                wait = 0
                self.best_acc = current
                at_best_acc = last_epoch
                self.weights_at_best_acc = copy.deepcopy(self.model.state_dict())
                if self.weights_mode == 'acc':
                    self.best_weights = self.weights_at_best_acc

            current = loss.avg
            if current < self.best_loss:
                print(f'Validation loss decreased ({self.best_loss:.9f} --> {current:.9f}). Saving model ...')
                wait = 0
                self.best_loss = current
                at_best_loss = last_epoch
                self.weights_at_best_loss = copy.deepcopy(self.model.state_dict())
                # Any mode other than 'acc' selects by loss (matches the return below).
                if self.weights_mode != 'acc':
                    self.best_weights = self.weights_at_best_loss
            elif epoch > self.early_stopping_start:
                wait += 1
                print(f'EarlyStopping counter: {wait} out of {self.early_stopping_rounds}')
                if (self.early_stopping_rounds > 0) and (wait >= self.early_stopping_rounds):
                    print('Epoch %05d: early stopping' % (epoch))
                    break

        if last_epoch is None:
            raise RuntimeError('fit() ran no epochs; num_epochs must be >= 1.')
        best = at_best_acc if self.weights_mode == 'acc' else at_best_loss
        if best is None:
            # The tracked metric never improved (e.g. NaN loss): fall back to the final epoch
            # instead of failing later with an unbound local / missing best_weights.
            print(f"No improvement of '{self.weights_mode}' observed; keeping last-epoch weights.")
            self.best_weights = copy.deepcopy(self.model.state_dict())
            best = last_epoch
        return best

    # https://pytorch.org/docs/stable/notes/amp_examples.html
    # https://www.kaggle.com/khyeh0719/pytorch-efficientnet-baseline-train-amp-aug
    def _train_one_epoch(self):
        """One training epoch with mixed precision and gradient accumulation.

        Returns:
            (AverageMeter loss, AccMeter, RocAucMeter) for the epoch.
        """
        self.model.train()
        epoch_loss = AverageMeter()
        epoch_acc = AccMeter(n_class=self.num_targets)
        epoch_roc_auc = RocAucMeter(n_class=self.num_targets)
        num_steps = len(self.train_dataloader)
        for step, (image_ids, images, targets) in enumerate(tqdm(self.train_dataloader)):
            images = images.to(self.device, dtype=torch.float)
            targets = targets.to(self.device, dtype=torch.float)
            batch_size = len(image_ids)
            with autocast():
                preds = self.model(images)
                loss = self.train_criterion(preds, targets)
            # Divide so the accumulated gradient is the mean over the accumulated batches.
            self.scaler.scale(loss / self.iters_to_accumulate).backward()
            if (step + 1) % self.iters_to_accumulate == 0 or step + 1 == num_steps:
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
                # step_scheduler means "step after optimizer.step"; it used to also fire
                # on accumulation-only batches where no optimizer step happened.
                if self.step_scheduler:
                    self.scheduler.step()
            epoch_loss.update(loss.item(), batch_size)
            preds = preds.detach()
            epoch_acc.update(targets, preds)
            epoch_roc_auc.update(targets, preds)
        return epoch_loss, epoch_acc, epoch_roc_auc

    def _build_preds_df(self, image_ids_list, pred_batches):
        """Concatenate per-batch prediction arrays into an ``image_id`` + class-column frame.

        Collecting batches and concatenating once avoids quadratic ``np.append`` copying
        (and the ``np.float`` alias removed in NumPy 1.24). Values are float64.
        """
        if pred_batches:
            preds = np.concatenate(pred_batches, axis=0).astype(np.float64)
        else:
            preds = np.empty((0, self.num_targets), dtype=np.float64)
        return pd.concat([pd.DataFrame(image_ids_list, columns=['image_id']),
                          pd.DataFrame(preds, columns=self.onehot_cols)], axis=1)

    def _valid_one_epoch(self):
        """Evaluate on the validation split without gradients.

        Returns:
            (AverageMeter loss, AccMeter, RocAucMeter, valid_preds_df) where
            ``valid_preds_df`` has ``image_id`` plus one raw-logit column per class.
        """
        self.model.eval()
        epoch_loss = AverageMeter()
        epoch_acc = AccMeter(n_class=self.num_targets)
        epoch_roc_auc = RocAucMeter(n_class=self.num_targets)
        image_ids_list = []
        pred_batches = []
        for image_ids, images, targets in tqdm(self.valid_dataloader):
            images = images.to(self.device, dtype=torch.float)
            targets = targets.to(self.device, dtype=torch.float)
            batch_size = len(image_ids)
            with torch.no_grad():
                preds = self.model(images)
                loss = self.valid_criterion(preds, targets)
            epoch_loss.update(loss.item(), batch_size)
            epoch_acc.update(targets, preds)
            epoch_roc_auc.update(targets, preds)
            image_ids_list += list(image_ids)
            pred_batches.append(preds.cpu().numpy())
        valid_preds_df = self._build_preds_df(image_ids_list, pred_batches)
        return epoch_loss, epoch_acc, epoch_roc_auc, valid_preds_df

    def validation(self):
        """Validate the current weights; returns (loss, acc, roc_auc, valid_preds_df)."""
        loss, acc, roc_auc, valid_preds_df = self._valid_one_epoch()
        return loss.avg, acc.avg, roc_auc.avg, valid_preds_df

    def _use_best_weights(self):
        """Load ``best_weights`` into the model, failing clearly if none exist yet."""
        if self.best_weights is None:
            raise RuntimeError('No best weights available; call fit() or load_model_weights() first.')
        self.set_model_weights(self.best_weights)

    def _prediction(self):
        """Plain test-set prediction with ``best_weights``; returns an image_id + logits frame."""
        self._use_best_weights()
        self.model.eval()
        image_ids_list = []
        pred_batches = []
        for image_ids, images in tqdm(self.test_dataloader):
            images = images.to(self.device, dtype=torch.float)
            with torch.no_grad():
                preds = self.model(images)
            image_ids_list += list(image_ids)
            pred_batches.append(preds.cpu().numpy())
        return self._build_preds_df(image_ids_list, pred_batches)

    def _prediction_with_tta(self):
        """Test-set prediction averaging the logits of every TTA view per image."""
        self._use_best_weights()
        self.model.eval()
        image_ids_list = []
        pred_batches = []
        for image_ids, tta_images in tqdm(self.tta_dataloader):
            # (batch, n_views, C, H, W) -> (n_views, batch, C, H, W): each view is a normal batch.
            tta_images = tta_images.to(self.device, dtype=torch.float).transpose(0, 1)
            n_augs, batch_size, _ch, _height, _width = tta_images.size()
            with torch.no_grad():
                preds_ = torch.zeros(batch_size, self.num_targets, device=self.device)
                for view in tta_images:
                    preds_ += self.model(view)
                preds_ /= n_augs
            image_ids_list += list(image_ids)
            pred_batches.append(preds_.cpu().numpy())
        return self._build_preds_df(image_ids_list, pred_batches)

    def prediction(self):
        """Test-set logits with ``best_weights`` (TTA-averaged when ``use_tta``)."""
        if self.use_tta:
            test_preds_df = self._prediction_with_tta()
        else:
            test_preds_df = self._prediction()
        return test_preds_df

    def get_best_model_weights(self):
        return self.best_weights

    def set_model_weights(self, weights):
        self.model.load_state_dict(weights)

    def save_model_weights(self, path):
        self.model.eval()
        torch.save(self.model.state_dict(), path)

    def save_best_model_weights(self, path):
        torch.save(self.best_weights, path)

    def load_model_weights(self, path):
        """Load a checkpoint into the model and make it ``best_weights``.

        ``map_location`` lets a checkpoint saved on GPU load on a CPU-only machine.
        """
        state_dict = torch.load(path, map_location=self.device)
        if self.change_device_config:
            # Drop the 'module.' prefix written by nn.DataParallel checkpoints.
            state_dict = fix_model_state_dict(state_dict)
        self.best_weights = state_dict
        self.model.load_state_dict(self.best_weights)
        print(f"weights: {path} loaded.")

    def get_fold_log(self, fold=None):
        if fold is not None:
            title = f'convergence check of fold {fold}'
            fig_title = f'fold{fold}.png'
            self.fold_logger.get_graph(title=title, fig_title=fig_title)
        else:
            title = 'convergence check'
            self.fold_logger.get_graph(title=title)

    def delete(self):
        del self.model, self.optimizer, self.scheduler, self.scaler

In [None]:
def get_loss(preds, targets, config=FitterConfig):
    """Validation-criterion loss for NumPy ``preds``/``targets``, computed without gradients."""
    def as_device_float(array):
        return torch.from_numpy(array).to(config.device).float()

    preds_t, targets_t = as_device_float(preds), as_device_float(targets)
    with torch.no_grad():
        return config.valid_criterion(preds_t, targets_t)

In [None]:
def get_label(df, config=GlobalConfig):
    """Return a copy of ``df`` with ``label`` = argmax over ``config.onehot_cols``.

    Uses positional assignment rather than ``concat(axis=1)``. The concat misaligned
    rows (producing NaN) whenever ``df`` lacked a default RangeIndex. Columns are cast
    to float64 first because frames grown from an empty DataFrame can hold object dtype.
    An existing ``label`` column is overwritten.
    """
    labels = np.argmax(df[config.onehot_cols].to_numpy(dtype=np.float64), axis=-1)
    return df.assign(label=labels)

In [None]:
def get_probability(df, config=GlobalConfig):
    """Return a copy of ``df`` whose ``config.onehot_cols`` logits are replaced by softmax probabilities.

    The columns are cast to float64 before ``torch.from_numpy``. That call rejects
    object-dtype arrays, which frames grown from an empty DataFrame can contain.
    The probability frame reuses ``df.index`` so the concat cannot misalign rows.
    """
    logits = torch.from_numpy(df[config.onehot_cols].to_numpy(dtype=np.float64)).to(config.device)
    with torch.no_grad():
        probs = nn.functional.softmax(logits, dim=-1).cpu().numpy()
    probs_df = pd.DataFrame(probs, columns=config.onehot_cols, index=df.index)
    return pd.concat([df.drop(columns=config.onehot_cols), probs_df], axis=1)

In [None]:
def get_confusion_matrix(preds_list, targets_list):
    """Build a pycm ConfusionMatrix with readable class names, save it, and return it.

    Args:
        preds_list: predicted class ids (0-4).
        targets_list: ground-truth class ids (0-4).

    Returns:
        The relabelled ConfusionMatrix (e.g. for ``plot_cm``). It is also written via
        pycm's ``save_obj`` under ``cache_dir`` with base name ``cm``. Returning it is
        new and backward-compatible: callers that ignored the result are unaffected.
    """
    cm = ConfusionMatrix(targets_list, preds_list)
    cm.relabel(mapping={0: 'cbb', 1: 'cbsd', 2: 'cgm', 3: 'cmd', 4: 'healthy'})
    cm.save_obj(os.path.join(cache_dir, 'cm'))
    return cm

In [None]:
# https://tech-blog.optim.co.jp/entry/2020/12/08/100000
def plot_cm(cm, normalize=False, title='Confusion matrix', annot=True, fmt='d', cmap='YlGnBu'):
    """Draw a pycm ConfusionMatrix as a seaborn heatmap (y axis: actual, x axis: predicted).

    With ``normalize=True`` the normalised matrix is shown with 3 decimals and the title
    gets a '(Normalized)' suffix.
    """
    if normalize:
        data, title, fmt = cm.normalized_matrix, title + '(Normalized)', '.3f'
    else:
        data = cm.matrix
    frame = pd.DataFrame(data).T.fillna(0)
    ax = sns.heatmap(frame, annot=annot, cmap=cmap, fmt=fmt)
    ax.set_title(title)
    ax.set(xlabel='Predict', ylabel='Actual')

In [None]:
def _fold_weights_path(weights_dir, fold, config):
    """Return the checkpoint path ``<weights_dir>/<image_model_name>_<fold>.pth``."""
    return Path(weights_dir) / f'{config.image_model_name}_{str(fold)}.pth'


def evaluate_model_with_cv(train_df, test_df, config=GlobalConfig):
    """Run stratified K-fold training (or weight loading + validation) and collect OOF results.

    For every fold a fresh ``CassavaNet`` is wrapped in a ``Fitter``. In ``'training'``
    mode the model is fitted (optionally warm-started from ``load_weights_dir``) and its
    best weights are saved to ``save_weights_dir``. In any other mode, saved weights are
    loaded from ``load_weights_dir`` and only validation runs. Each fold's model also
    produces test predictions via ``Fitter.prediction``.

    Args:
        train_df: training frame; must contain ``image_id`` and ``label`` columns.
        test_df: test frame handed to ``DataLoaderMaker``.
        config: configuration class (seed, num_folds, shuffle, onehot_cols, ...).

    Returns:
        ``(oof_preds_df, oof_valid_preds_df)``:
        - ``oof_preds_df``: test predictions with ``onehot_cols`` averaged over folds.
        - ``oof_valid_preds_df``: out-of-fold validation predictions, one row per
          ``train_df`` row, in ``train_df`` order.
    """
    seed_everything(config.seed)

    dataloader_maker = DataLoaderMaker(train_df, test_df, config=DataConfig)

    # sklearn (>= 0.24) raises ValueError when random_state is set while shuffle=False.
    skf = StratifiedKFold(n_splits=config.num_folds,
                          random_state=config.seed if config.shuffle else None,
                          shuffle=config.shuffle)

    valid_cols = ['image_id'] + config.onehot_cols
    fold_valid_frames = []
    fold_best_losses = []
    for fold, (train_index, valid_index) in enumerate(skf.split(train_df.values, train_df.label.values)):

        t = time.time()
        print(f'----- fold {fold}/{config.num_folds - 1} Started. -----')

        model = CassavaNet()

        fitter = Fitter(model, dataloader_maker, train_index, valid_index)

        if mode == 'training':
            if config.load_weights:
                fitter.load_model_weights(str(_fold_weights_path(load_weights_dir, fold, config)))
            fold_best_loss, fold_acc, fold_roc_auc, fold_valid_preds_df = fitter.fit()
            fitter.save_best_model_weights(str(_fold_weights_path(save_weights_dir, fold, config)))
            fitter.get_fold_log(fold)
        else:
            fitter.load_model_weights(str(_fold_weights_path(load_weights_dir, fold, config)))
            fold_best_loss, fold_acc, fold_roc_auc, fold_valid_preds_df = fitter.validation()

        print(f"[fold {fold} best]: Valid. acc: {fold_acc:.9f}, roc_auc: {fold_roc_auc:.9f}"
              + f", loss: {fold_best_loss:.9f}, time: {(time.time() - t):.5f}")

        fold_best_losses.append(fold_best_loss)
        fold_valid_frames.append(fold_valid_preds_df)

        fold_preds_df = fitter.prediction()

        if fold == 0:
            oof_preds_df = fold_preds_df.copy()
        else:
            oof_preds_df[config.onehot_cols] += fold_preds_df[config.onehot_cols]

        fitter.delete()

        del model, fitter
        del fold_preds_df, fold_valid_preds_df
        gc.collect()

    # Concatenate once rather than growing an empty, untyped template frame every fold
    # (which left the score columns as dtype object). reindex keeps the template's
    # column order: image_id, the score columns, then any extra Fitter columns.
    valid_preds = pd.concat(fold_valid_frames, sort=False, ignore_index=True)
    extra_cols = [c for c in valid_preds.columns if c not in valid_cols]
    valid_preds = valid_preds.reindex(columns=valid_cols + extra_cols)

    # how='left' preserves train_df's row order, so the argmax below lines up with
    # train_df.label. how='outer' sorts keys lexicographically, which silently
    # misaligns predictions and labels unless train_df is already sorted by image_id.
    oof_valid_preds_df = pd.merge(train_df[['image_id']], valid_preds, on='image_id', how='left')
    assert len(oof_valid_preds_df) == len(train_df), 'duplicate image_id in out-of-fold predictions'
    oof_preds_df[config.onehot_cols] = oof_preds_df[config.onehot_cols] / config.num_folds

    oof_labels = train_df.label.values
    oof_valid_preds = np.argmax(oof_valid_preds_df[config.onehot_cols].values, axis=-1)
    oof_valid_accuracy = metrics.accuracy_score(oof_labels, oof_valid_preds, normalize=True)

    get_confusion_matrix(preds_list=oof_valid_preds.tolist(), targets_list=oof_labels.tolist())

    print(f"[OOF]: acc: {oof_valid_accuracy:.9f}, loss (CV): {np.mean(fold_best_losses):.9f}"
          + f", std of fold best: {np.std(fold_best_losses):.9f}")

    return oof_preds_df, oof_valid_preds_df

In [None]:
def run_everything():
    """Run the full CV pipeline on the module-level ``train_df`` / ``test_df``.

    ``feature_engineering`` is applied to a copy of the training frame only. Results
    are written to ``cache_dir`` as valid_preds.csv, preds.csv, submission.csv and
    probs.csv (all without the index), and the total runtime is printed.
    """
    started = time.time()

    print(f"device: {GlobalConfig.device}")

    train_input, test_input = train_df.copy(), test_df.copy()
    train_input = feature_engineering(train_input)

    test_preds, valid_preds = evaluate_model_with_cv(train_input, test_input)

    valid_preds = get_label(valid_preds)
    test_preds = get_label(test_preds)
    submission = test_preds.drop(columns=GlobalConfig.onehot_cols).copy()

    out_dir = Path(cache_dir)
    for frame, file_name in (
        (valid_preds, 'valid_preds.csv'),
        (test_preds, 'preds.csv'),
        (submission, 'submission.csv'),
    ):
        frame.to_csv(str(out_dir / file_name), index=False)

    get_probability(test_preds).to_csv(str(out_dir / 'probs.csv'), index=False)

    print(f'All Completed. total time: {(time.time() - started):.5f}')

In [None]:
# A notebook kernel executes cells with __name__ == '__main__', so this cell runs the
# full pipeline (training or inference, depending on `mode`).
if __name__ == '__main__':
    run_everything()

In [None]:
# Out-of-fold validation predictions written by run_everything(): per-class scores
# for the training images plus the argmax `label` column added by get_label().
valid_preds_df = pd.read_csv(str(Path(cache_dir) / 'valid_preds.csv'))
valid_preds_df

In [None]:
# Reload the confusion matrix saved by get_confusion_matrix() (save_obj(.../'cm')
# produces 'cm.obj'). A `with` block closes the file handle once pycm has parsed it;
# the previous bare open() leaked it.
with open(os.path.join(cache_dir, 'cm.obj'), 'r') as cm_file:
    cm = ConfusionMatrix(file=cm_file)

In [None]:
# Raw out-of-fold counts: rows are actual classes, columns are predicted classes.
plot_cm(cm)

In [None]:
# Same matrix using pycm's normalized_matrix, annotated with 3 decimals.
plot_cm(cm, normalize=True)

In [None]:
# Test-set scores averaged over all folds, plus the argmax `label` column.
preds_df = pd.read_csv(str(Path(cache_dir) / 'preds.csv'))
preds_df

In [None]:
# Softmax of the fold-averaged test scores (written via get_probability()).
probs_df = pd.read_csv(str(Path(cache_dir) / 'probs.csv'))
probs_df

In [None]:
# Submission file: the test predictions with the per-class score columns dropped
# (presumably just image_id and label - confirm against Fitter.prediction()).
submission_df = pd.read_csv(str(Path(cache_dir) / 'submission.csv'))
submission_df