In [None]:
# Competition: https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification
# Reference (training): https://www.kaggle.com/ttahara/ranzcr-multi-head-model-training?scriptVersionId=55258318
# Reference (inference): https://www.kaggle.com/rsinda/38th-place-solution-0-972-single-model-5-fold

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'  # 'last_expr'

import os
import sys
import gc
import math
import pickle
import random
import time
import psutil
import pytz
from datetime import datetime
from collections import defaultdict
from contextlib import contextmanager

import warnings
warnings.filterwarnings('ignore')  # warnings.filterwarnings(action='once')

from tqdm.auto import tqdm
from tqdm import tqdm_notebook

import numpy as np
import pandas as pd
_ = np.seterr(divide='ignore', invalid='ignore')

pd.set_option('display.max_columns', None)
# pd.set_option('display.max_columns', 100)
pd.set_option('display.max_rows', None)
# pd.set_option('display.max_rows', 100)

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.style as style
style.use('fivethirtyeight')
import seaborn as sns

# Display an image directly in a cell (jpg, png, jpeg, ... are supported):
# Image('./2.JPG'), or with an explicit size: Image("./2.png",width=900,height=400)
from IPython.display import Image  

import lightgbm as lgb
from sklearn.metrics import roc_auc_score

import tensorflow as tf
from tensorflow import keras

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

def show_process_mem_usage(info_str=''):
    """Print how much memory the current process is using.

    The resident set size of this process is printed in the largest unit
    (GB / MB / KB / B) that it reaches, together with the system-wide memory
    usage percentage and an Asia/Shanghai timestamp.

    Args:
        info_str: Free-form text printed in front of the message.
    """
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    percent = psutil.virtual_memory().percent

    shanghai = pytz.timezone('Asia/Shanghai')
    dt_str = datetime.now(shanghai).strftime("%Y-%m-%d %H:%M:%S")

    # Walk from the largest unit down and stop at the first one that fits.
    for unit_size, unit_name in ((2.**30, 'GB'), (2.**20, 'MB'), (2.**10, 'KB')):
        if rss_bytes >= unit_size:
            amount = f'{rss_bytes/unit_size:.3f} {unit_name}'
            break
    else:
        # Below 1 KB the raw byte count is printed without decimals.
        amount = f'{rss_bytes} B'

    print(f'{info_str} current process memory usage: {amount}, percentage: {percent:.2f}% 【{dt_str}】')

def logging(*info, file_name='./running_log.txt'):
    """Append one line of text to a log file.

    Args:
        *info: Values to log; each is converted with ``str`` and the pieces
            are joined by single spaces.
        file_name: Path of the log file. It is opened in append mode, so it
            is created on first use and never truncated.

    The file is written as UTF-8 explicitly so that non-ASCII messages
    (e.g. the ``【...】`` timestamps used in this notebook) do not depend on
    the platform's default encoding.
    """
    log_line = ' '.join(str(item) for item in info)
    with open(file_name, 'a', encoding='utf-8') as log_file:
        log_file.write(log_line + '\n')

@contextmanager
def trace(trace_msg):
    """Context manager that reports memory change and wall time of a block.

    On exit it prints the process memory (first field of
    ``psutil.Process.memory_info()``) in GB, the signed change since entry,
    the elapsed seconds, ``trace_msg`` and an Asia/Shanghai timestamp.

    The report is emitted from a ``finally`` clause, so it is also printed
    when the wrapped block raises; the exception then propagates unchanged.

    Args:
        trace_msg: Label for the traced block (converted with ``str``).
    """
    t0 = time.time()
    process = psutil.Process(os.getpid())
    mem_before = process.memory_info()[0] / 2. ** 30
    try:
        yield
    finally:
        mem_after = process.memory_info()[0] / 2. ** 30
        delta = mem_after - mem_before
        # Print the sign separately so that a zero delta shows up as "+0.000".
        sign = '+' if delta >= 0 else '-'
        delta = math.fabs(delta)
        trace_msg = str(trace_msg)

        tz = pytz.timezone('Asia/Shanghai')
        dt_str = datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
        print(f"[{mem_after:.3f}GB({sign}{delta:.3f}GB):{time.time() - t0:.3f}sec] {trace_msg} 【{dt_str}】", file=sys.stdout)
    
def get_time_random_seed():
    """Derive a seed from the current time in milliseconds.

    The four lowest bytes of the millisecond timestamp are reversed
    (byte 3 <-> byte 0, byte 2 <-> byte 1) so that the fastest-changing
    byte ends up in the most significant position of the returned value.
    """
    millis = int(time.time() * 1000.0)
    byte3_to_low = (millis & 0xff000000) >> 24
    byte2_to_second = (millis & 0x00ff0000) >> 8
    byte1_to_third = (millis & 0x0000ff00) << 8
    byte0_to_high = (millis & 0x000000ff) << 24
    return byte3_to_low + byte2_to_second + byte1_to_third + byte0_to_high
	
def seed_all(random_seed=42):
    """Seed every random number generator used in this notebook.

    Sets ``PYTHONHASHSEED`` and seeds ``random``, NumPy, TensorFlow and
    PyTorch (CPU and current CUDA device), then switches cuDNN to
    deterministic mode.

    Args:
        random_seed: Seed value passed to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(random_seed)
    seeders = (
        random.seed,
        np.random.seed,
        tf.random.set_seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(random_seed)
    torch.backends.cudnn.deterministic = True

def keepbusy(num=10000):
    """Print a heartbeat once a minute, ``num`` times.

    Each iteration logs the iteration number and the seconds elapsed since
    the call started via ``ftpt`` and then sleeps for 60 seconds.

    Args:
        num: Number of heartbeats to emit.
    """
    began_at = time.time()
    tick = 0
    while tick < num:
        ftpt(f'i: {tick}, taken time: {time.time() - began_at:.7f}')
        time.sleep(60)
        tick += 1

def ftpt(msg = 'having run this cell'):  # foot_print
    """Print ``msg`` prefixed with the current Asia/Shanghai date and time.

    Used at the end of cells as a "footprint" showing when the cell ran.
    """
    shanghai_now = datetime.now(pytz.timezone('Asia/Shanghai'))
    stamp = shanghai_now.strftime("%Y-%m-%d %H:%M:%S")
    print(f'{stamp}: {msg}')
    
# List of primes to pick a seed from:  [7, 53, 97, 317, 577, 997, 7753, 9973, 53113, 99991, 153133, 377171, 515371, 737353, 999983, 5157133, 7757537, 9999991, 99999989, 999999937]
# Seed all RNGs once at start-up. NOTE: RANDOM_SEED is reassigned in the configuration cell further down.
RANDOM_SEED = 53113
seed_all(RANDOM_SEED)
    
ftpt()

# Prepare

## import

In [None]:
import copy
import shutil
import typing as tp
from pathlib import Path

import yaml
import numpy as np
import pandas as pd
from scipy.sparse import coo_matrix

from tqdm import tqdm
from joblib import Parallel, delayed

import cv2
import albumentations
from albumentations.core.transforms_interface import ImageOnlyTransform, DualTransform
from albumentations.pytorch import ToTensorV2

from torch.utils import data
from torchvision import models as torchvision_models

sys.path.append('../input/pytorch-image-models/pytorch-image-models-master')
import timm

sys.path.append('../input/pytorch-pfn-extras/pytorch-pfn-extras-0.3.2')
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import extensions as ppe_extensions

ftpt()

In [None]:
# Directory layout, resolved relative to the parent of the current working directory.
ROOT = Path.cwd().parent
INPUT = ROOT / 'input'
OUTPUT = ROOT / 'output'
DATA = INPUT / "ranzcr-clip-catheter-line-classification"
TRAIN = DATA / 'train'
TEST = DATA/ 'test'

# Directory holding the training images pre-converted to a .npy array.
TRAIN_NPY = INPUT / 'ranzcr-clip-train-numpy'
# Scratch directory for per-fold training outputs (snapshots, logs).
TMP = ROOT / 'tmp'
TMP.mkdir(exist_ok=True)

# Seed used for the fold split below (overrides the start-up RANDOM_SEED).
RANDOM_SEED = 1086
N_CLASSES = 11
# Fold ids that are actually trained in this run, out of N_FOLD folds in total.
FOLDS = [1,]
N_FOLD = 5

# Target label columns of train.csv, in the order used for the label array.
CLASSES = [
    'ETT - Abnormal',
    'ETT - Borderline',
    'ETT - Normal',
    'NGT - Abnormal',
    'NGT - Borderline',
    'NGT - Incompletely Imaged',
    'NGT - Normal',
    'CVC - Abnormal',
    'CVC - Borderline',
    'CVC - Normal',
    'Swan Ganz Catheter Present'
]

ftpt()

## read data

In [None]:
# List what the competition data directory contains.
for p in DATA.iterdir():
    print(p.name)
    
# Training labels/metadata and the submission template.
train = pd.read_csv(DATA / 'train.csv')
smpl_sub = pd.read_csv(DATA / 'sample_submission.csv')
ftpt()

## split fold

In [None]:
def multi_label_stratified_group_k_fold(label_arr: np.array, gid_arr: np.array, n_fold: int, seed: int=42):
    """Greedy k-fold split that keeps groups intact and balances label counts.

    Every group (all rows sharing a value in ``gid_arr``) is assigned to
    exactly one validation fold. Groups are visited in order of decreasing
    standard deviation of their per-class label counts, and each one is put
    into the fold that minimises the mean, over classes, of the standard
    deviation across folds of the per-fold share of that class.

    Args:
        label_arr: 2-D array of shape ``(n_samples, n_classes)`` with label
            counts/indicators per sample.
        gid_arr: 1-D array with one group id per sample.
        n_fold: Number of folds.
        seed: Seed applied to the global ``numpy.random`` and ``random``
            generators (side effect) before the groups are shuffled.

    Yields:
        ``(train_indexs, val_indexs)`` integer index arrays, one pair per fold.
    """
    np.random.seed(seed)
    random.seed(seed)
    start_time = time.time()
    n_train, n_class = label_arr.shape
    gid_unique = sorted(set(gid_arr))
    n_group = len(gid_unique)

    # Map every raw group id to a dense index ("aid") in [0, n_group).
    gid2aid = dict(zip(gid_unique, range(n_group)))
    aid_arr = np.vectorize(lambda x:gid2aid[x])(gid_arr)

    cnts_by_class = label_arr.sum(axis=0)
    # A class without any positive sample would turn the ratio below into
    # NaN; every comparison against NaN is False, so all groups would end up
    # in fold 0. Dividing by 1 instead makes such a class contribute a
    # constant 0 to the balance score and leaves other classes untouched.
    cnts_by_class_safe = np.where(cnts_by_class == 0, 1, cnts_by_class)

    # Sum the label rows per group with a sparse (n_group x n_samples)
    # membership matrix times the label matrix.
    col, row = np.array(sorted(enumerate(aid_arr), key=lambda x: x[1])).T
    cnts_by_group = coo_matrix(
        (np.ones(len(label_arr)), (row, col))
    ).dot(coo_matrix(label_arr)).toarray().astype(int)
    del col
    del row
    cnts_by_fold = np.zeros((n_fold, n_class), int)

    groups_by_fold = [[] for fid in range(n_fold)]
    group_and_cnts = list(enumerate(cnts_by_group))
    # Shuffle first: the sort below is stable, so the shuffle decides the
    # order among groups with equal standard deviation.
    np.random.shuffle(group_and_cnts)
    print('finished preparation', time.time()-start_time)
    for aid, cnt_by_g in sorted(group_and_cnts, key=lambda x: -np.std(x[1])):
        best_fold = None
        min_eval = None
        for fid in range(n_fold):
            # Tentatively add the group to this fold, score, then undo.
            cnts_by_fold[fid] += cnt_by_g
            fold_eval = (cnts_by_fold / cnts_by_class_safe).std(axis=0).mean()
            cnts_by_fold[fid] -= cnt_by_g

            if min_eval is None or fold_eval < min_eval:
                min_eval = fold_eval
                best_fold = fid

        cnts_by_fold[best_fold] += cnt_by_g
        groups_by_fold[best_fold].append(aid)
    print('finished assignment: ', time.time() - start_time)

    gc.collect()
    idx_arr = np.arange(n_train)
    for fid in range(n_fold):
        val_groups = groups_by_fold[fid]

        val_indexs_bool = np.isin(aid_arr, val_groups)
        train_indexs = idx_arr[~val_indexs_bool]
        val_indexs = idx_arr[val_indexs_bool]

        print(f'[fold {fid}]', end=' ')
        print(f'n_group: (train, val) = ({n_group-len(val_groups)}, {len(val_groups)})', end=' ')
        print(f'n_sample: (train, val) = ({len(train_indexs)}, {len(val_indexs)})')

        yield train_indexs, val_indexs
        
ftpt()

In [None]:
# Labels to stratify on and the patient id used as the grouping key, so that
# all rows of one PatientID end up in the same fold.
label_arr = train[CLASSES].values
group_id = train.PatientID.values

# Materialise the generator: one (train_idx, val_idx) pair per fold.
train_val_indexs = list(multi_label_stratified_group_k_fold(label_arr, group_id, N_FOLD, RANDOM_SEED))
ftpt()

In [None]:
# Store, for every row, the id of the fold in which it is a validation sample.
train['fold'] = -1
for fold_id, (trn_idx, val_idx) in enumerate(train_val_indexs):
    train.loc[val_idx, 'fold'] = fold_id
    
# Per-fold positive counts for each class, to eyeball how balanced the split is.
train.groupby('fold')[CLASSES].sum()
ftpt()

# Training

In [None]:
def get_activation(activ_name: str='relu'):
    """Return a freshly constructed activation module for ``activ_name``.

    Supported names are ``'relu'`` (in-place), ``'tanh'``, ``'sigmoid'`` and
    ``'identity'``.

    Raises:
        NotImplementedError: If ``activ_name`` is not one of the names above.
    """
    factories = {
        'relu': lambda: nn.ReLU(inplace=True),
        'tanh': nn.Tanh,
        'sigmoid': nn.Sigmoid,
        'identity': nn.Identity,
    }
    factory = factories.get(activ_name)
    if factory is None:
        raise NotImplementedError
    return factory()
        
class Conv2dBNActiv(nn.Module):
    """Conv2d, optionally followed by BatchNorm2d, followed by an activation.

    The sub-modules are stored in ``self.layers`` (an ``nn.Sequential``) in
    that order; the activation is looked up by name with ``get_activation``.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int=1, padding: int=0,
                 bias: bool=False, use_bn: bool=True, activ: str='relu'):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        norm = [nn.BatchNorm2d(out_channels)] if use_bn else []
        self.layers = nn.Sequential(conv, *norm, get_activation(activ))

    def forward(self, x):
        """Apply the conv / (bn) / activation stack to ``x``."""
        return self.layers(x)
    
class SSEBlock(nn.Module):
    """Spatial gating block: scales the input by a one-channel sigmoid map.

    A bias-free 1x1 convolution squeezes the channels down to one, a sigmoid
    turns that into a gate, and the gate is multiplied with the input.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.channel_squeeze = nn.Conv2d(
            in_channels, 1, kernel_size=1, stride=1, padding=0, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by its sigmoid gate map."""
        gate = self.channel_squeeze(x)
        gate = self.sigmoid(gate)
        return x * gate
    
class SpatialAttentionBlock(nn.Module):
    """Stack of 3x3 conv blocks producing a one-channel attention map.

    The blocks are registered as ``conv1`` ... ``conv{n_layers}``. All but
    the last use ReLU; the last one uses a sigmoid and must output a single
    channel. The resulting map is multiplied with the block's input.

    Args:
        in_channels: Channels of the input feature map.
        out_channels_list: Output channels of each conv block; must be
            non-empty and end with ``1``.
    """

    def __init__(self, in_channels: int, out_channels_list: tp.List[int]):
        super().__init__()
        self.n_layers = len(out_channels_list)
        channels_list = [in_channels] + out_channels_list
        assert self.n_layers > 0
        assert channels_list[-1]==1

        # Consecutive channel pairs give the (in, out) sizes of every block.
        pairs = zip(channels_list, channels_list[1:])
        for depth, (in_chs, out_chs) in enumerate(pairs, start=1):
            activ = 'sigmoid' if depth == self.n_layers else 'relu'
            setattr(self, f'conv{depth}',
                    Conv2dBNActiv(in_chs, out_chs, 3, 1, 1, activ=activ))

    def forward(self, x):
        """Return the attention map (computed from ``x``) multiplied by ``x``."""
        attention = x
        for depth in range(1, self.n_layers + 1):
            attention = getattr(self, f'conv{depth}')(attention)
        return attention * x
    
ftpt()

In [None]:
class SingleHeadModel(nn.Module):
    """timm backbone followed by a single two-layer classification head.

    Args:
        base_name: Model name passed to ``timm.create_model``.
        out_dim: Number of output logits.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, base_name: str='resnext50_32x4d', out_dim: int=11, pretrained=False):
        super().__init__()
        self.base_name = base_name

        backbone = timm.create_model(base_name, pretrained=pretrained)
        feat_dim = backbone.num_features
        # Drop timm's own classifier; our head consumes the backbone features.
        backbone.reset_classifier(0)
        self.backbone = backbone

        self.head_fc = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(feat_dim, out_dim),
        )

    def forward(self, x):
        """Return the logits for a batch of images ``x``."""
        return self.head_fc(self.backbone(x))
    
class MultiHeadModel(nn.Module):
    """timm backbone with one spatial-attention head per label group.

    The backbone's global pooling and classifier are removed so that every
    head receives the unpooled feature map. Each head (``head_0``,
    ``head_1``, ...) applies a ``SpatialAttentionBlock``, global average
    pooling and a two-layer MLP; the head outputs are concatenated along
    dimension 1.

    Args:
        base_name: Model name passed to ``timm.create_model``.
        out_dims_head: Number of logits produced by each head.
        pretrained: Forwarded to ``timm.create_model``. (It used to be
            ignored: the backbone was always created with
            ``pretrained=False``.)
    """

    def __init__(self, base_name: str='resnext50_32x4d',
                 out_dims_head: tp.List[int]=[3, 4, 3, 1], pretrained=False):
        self.base_name = base_name
        self.n_heads = len(out_dims_head)
        super().__init__()

        base_model = timm.create_model(
            self.base_name, num_classes=sum(out_dims_head), pretrained=pretrained)
        in_features = base_model.num_features

        ## remove global pooling and head classifier
        base_model.reset_classifier(0, '')

        self.backbone = base_model

        for i, out_dim in enumerate(out_dims_head):
            layer_name = f'head_{i}'
            layer = nn.Sequential(
                SpatialAttentionBlock(in_features, [64, 32, 16, 1]),
                nn.AdaptiveAvgPool2d(output_size=1),
                nn.Flatten(start_dim=1),
                nn.Linear(in_features, in_features),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5),
                nn.Linear(in_features, out_dim)
            )
            setattr(self, layer_name, layer)

    def forward(self, x):
        """Return the concatenated logits of all heads for images ``x``."""
        h = self.backbone(x)
        hs = [
            getattr(self, f'head_{i}')(h) for i in range(self.n_heads)
        ]
        y = torch.cat(hs, axis=1)
        return y
    
class MultiHeadResNet200D(nn.Module):
    """``resnet200d_320`` backbone with one spatial-attention head per label group.

    The backbone's global pooling and classifier are removed so that every
    head receives the unpooled feature map. Each head (``head_0``,
    ``head_1``, ...) applies a ``SpatialAttentionBlock``, global average
    pooling and a two-layer MLP; the head outputs are concatenated along
    dimension 1.

    Args:
        out_dims_head: Number of logits produced by each head.
        pretrained: If true, load the backbone weights from the checkpoint
            at ``../input/startingpointschestx/resnet200d_320_chestx.pth``.
    """

    def __init__(self, out_dims_head: tp.List[int]=[3, 4, 3, 1], pretrained=False):
        self.base_name = 'resnet200d_320'
        self.n_heads = len(out_dims_head)
        super().__init__()

        base_model = timm.create_model(self.base_name, num_classes=sum(out_dims_head), pretrained=False)
        in_features = base_model.num_features

        if pretrained:
            pretrained_model_path = '../input/startingpointschestx/resnet200d_320_chestx.pth'
            key_prefix = 'model.'
            state_dict = dict()
            # NOTE(review): torch.load unpickles the file; only use checkpoints
            # from a trusted source.
            for k, v in torch.load(pretrained_model_path, map_location='cpu')['model'].items():
                if k.startswith(key_prefix):
                    # Strip only the leading prefix. str.replace would also
                    # remove any later 'model.' inside the key.
                    k = k[len(key_prefix):]
                state_dict[k] = v
            base_model.load_state_dict(state_dict)

        ## remove global pooling and head classifier
        base_model.reset_classifier(0, '')

        self.backbone = base_model

        for i, out_dim in enumerate(out_dims_head):
            layer_name = f'head_{i}'
            layer = nn.Sequential(
                SpatialAttentionBlock(in_features, [64, 32, 16, 1]),
                nn.AdaptiveAvgPool2d(output_size=1),
                nn.Flatten(start_dim=1),
                nn.Linear(in_features, in_features),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5),
                nn.Linear(in_features, out_dim)
            )
            setattr(self, layer_name, layer)

    def forward(self, x):
        """Return the concatenated logits of all heads for images ``x``."""
        h = self.backbone(x)
        hs = [
            getattr(self, f'head_{i}')(h) for i in range(self.n_heads)
        ]
        y = torch.cat(hs, axis=1)
        return y


# Smoke test: build the model with its pretrained checkpoint and push one
# random batch through it to check the output shape.
m = MultiHeadResNet200D([3, 4, 3, 1], True)
m = m.eval()

x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    y = m(x)
print('[forward test]')
print(f'input:\t{x.shape}\noutput:\t:{y.shape}')

# Release the test model and tensors before the real training starts.
del m, x, y
gc.collect()

ftpt()

## dataset

In [None]:
class LabeledImageDatasetNumpy(data.Dataset):
    """Dataset over in-memory ``(image, label)`` pairs.

    Args:
        file_list: Sequence of ``(image_array, label)`` pairs.
        transform_list: Augmentation spec handed to ``ImageTransformForCls``.
        copy_in_channels: If true, repeat the image along its channel axis
            (axis 2) ``in_channels`` times.
        in_channels: Repeat count used when ``copy_in_channels`` is true.
    """

    def __init__(self,
                 file_list: tp.List[
                     tp.Tuple[np.ndarray, tp.Union[int, float, np.ndarray]]],
                 transform_list: tp.List[tp.Dict],
                 copy_in_channels=True, in_channels=3,):
        self.file_list = file_list
        self.transform = ImageTransformForCls(transform_list)
        self.copy_in_channels = copy_in_channels
        self.in_channels = in_channels

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        """Return the transformed ``(image, label)`` pair at ``index``."""
        img, label = self.file_list[index]
        # A 2-D (H, W) image gets a trailing channel axis. The check must be
        # on the number of dimensions: testing ``shape[-1] == 2`` missed
        # ordinary 2-D images and wrongly expanded (H, W, 2) ones.
        if img.ndim == 2:
            img = img[..., None]

        if self.copy_in_channels:
            img = np.repeat(img, self.in_channels, axis=2)

        img, label = self.transform((img, label))
        return img, label
    
def get_file_list_with_array(stgs, train_all):
    """Build the train/validation ``(image, label)`` lists for one fold.

    Rows whose ``fold`` equals ``stgs['globals']['val_fold']`` form the
    validation list, all others the training list (cut to the first 1/20 in
    debug mode). Images come from ``TRAIN_NPY/<dataset_name>.npy``, opened
    memory-mapped read-only, and each one gets a trailing axis appended.

    Returns:
        ``(train_file_list, val_file_list)``.
    """
    globals_cfg = stgs['globals']
    use_fold = globals_cfg['val_fold']

    train_idx = train_all[train_all['fold'] != use_fold].index.values
    if globals_cfg['debug']:
        train_idx = train_idx[:len(train_idx) // 20]
    val_idx = train_all[train_all['fold'] == use_fold].index.values

    train_data_path = TRAIN_NPY / '{}.npy'.format(globals_cfg['dataset_name'])
    print(train_data_path)

    train_data_arr = np.load(train_data_path, mmap_mode='r')
    label_arr = train_all[CLASSES].values.astype('f')
    print(train_data_arr.shape, label_arr.shape)

    def pairs_for(indices):
        # One (image with trailing axis, label row) tuple per row index.
        return [(train_data_arr[i][..., None], label_arr[i]) for i in indices]

    train_file_list = pairs_for(train_idx)
    val_file_list = pairs_for(val_idx)
    return train_file_list, val_file_list

def get_dataloaders_cls(
    stgs: tp.Dict,
    train_file_list: tp.List[tp.List],
    val_file_list: tp.List[tp.List],
    dataset_class: data.Dataset
):
    """Create the train and validation ``DataLoader`` objects.

    For each split, ``dataset_class`` is instantiated with the file list and
    ``stgs['dataset'][split]`` as keyword arguments, and wrapped in a
    ``DataLoader`` configured by ``stgs['loader'][split]``.

    Returns:
        ``(train_loader, val_loader)``; an entry is ``None`` when the
        corresponding file list is ``None``.
    """
    def build(split, file_list):
        # A missing file list simply means "no loader for this split".
        if file_list is None:
            return None
        dataset = dataset_class(file_list, **stgs['dataset'][split])
        return data.DataLoader(dataset, **stgs['loader'][split])

    train_loader = build('train', train_file_list)
    val_loader = build('val', val_file_list)
    return train_loader, val_loader

ftpt()

## image transform

In [None]:
class ImageTransformBase:
    """Base class that turns an augmentation spec into an albumentations pipeline.

    Args:
        data_augmentations: Sequence of ``(name, params)`` pairs. Each name
            is resolved to a transform class, instantiated with ``**params``,
            and the instances are combined with ``albumentations.Compose``
            into ``self.data_aug``.
    """

    def __init__(self, data_augmentations: tp.List[tp.Tuple[str, tp.Dict]]):
        augmentations_list = []
        for aug_name, params in data_augmentations:
            aug_class = self._get_augmentation(aug_name)
            augmentations_list.append(aug_class(**params))
        self.data_aug = albumentations.Compose(augmentations_list)

    def __call__(self, pair: tp.Tuple[np.ndarray]) -> tp.Tuple[np.ndarray]:
        """Apply the pipeline; must be implemented by subclasses."""
        raise NotImplementedError

    def _get_augmentation(self, aug_name: str) -> tp.Tuple[ImageOnlyTransform, DualTransform]:
        """Resolve ``aug_name`` to a transform class.

        Attributes of the ``albumentations`` package take precedence; any
        other name is evaluated as a Python expression.
        """
        if not hasattr(albumentations, aug_name):
            # NOTE(review): eval() runs arbitrary code from the settings;
            # only use with trusted configuration.
            return eval(aug_name)
        return getattr(albumentations, aug_name)
        
class ImageTransformForCls(ImageTransformBase):
    """Augmentation pipeline for classification: transforms the image only."""

    def __init__(self, data_augmentations: tp.List[tp.Tuple[str, tp.Dict]]):
        super().__init__(data_augmentations)

    def __call__(self, in_arrs: tp.Tuple[np.ndarray]) -> tp.Tuple[np.ndarray]:
        """Return ``(augmented_image, label)`` for an ``(image, label)`` pair.

        The label is passed through untouched.
        """
        img, label = in_arrs
        return self.data_aug(image=img)['image'], label
    
ftpt()

## metric

In [None]:
class EvalFuncManager(nn.Module):
    """Feeds every validation batch to a set of metrics and reports once per epoch.

    Args:
        iters_per_epoch: Number of calls after which the metrics are
            computed, reported and reset.
        evalfunc_dict: Mapping of report name to metric object; each metric
            must provide ``reset``, ``update(y, t)`` and ``compute``.
        prefix: Prefix of the reported keys (``'<prefix>/<name>'``).
    """

    def __init__(self, iters_per_epoch: int, evalfunc_dict: tp.Dict[str, nn.Module],
                 prefix: str='val') -> None:
        super().__init__()
        self.iters_per_epoch = iters_per_epoch
        self.prefix = prefix
        self.tmp_iter = 0
        self.metric_names = []
        for metric_name, metric in evalfunc_dict.items():
            # setattr registers nn.Module metrics as sub-modules.
            setattr(self, metric_name, metric)
            self.metric_names.append(metric_name)
        self.reset()

    def reset(self) -> None:
        """Restart the iteration counter and reset every metric."""
        self.tmp_iter = 0
        for metric_name in self.metric_names:
            getattr(self, metric_name).reset()

    def __call__(self, y: torch.Tensor, t: torch.Tensor) -> None:
        """Update all metrics with one batch; report and reset at epoch end."""
        for metric_name in self.metric_names:
            getattr(self, metric_name).update(y, t)
        self.tmp_iter += 1

        if self.tmp_iter != self.iters_per_epoch:
            return

        observation = {}
        for metric_name in self.metric_names:
            key = '{}/{}'.format(self.prefix, metric_name)
            observation[key] = getattr(self, metric_name).compute()
        ppe.reporting.report(observation)
        self.reset()
            
            
class MeanLoss(nn.Module):
    """Running example-weighted mean of a per-batch loss.

    Subclasses are expected to set ``self.loss_func``; this base class does
    not define it.
    """

    def __init__(self):
        super().__init__()
        self.loss_sum = 0
        self.n_examples = 0

    def forward(self, y: torch.Tensor, t: torch.Tensor):
        """Return ``self.loss_func(y, t)`` for one batch."""
        criterion = self.loss_func
        return criterion(y, t)

    def reset(self):
        """Clear the accumulated loss and example count."""
        self.loss_sum, self.n_examples = 0, 0

    def update(self, y: torch.Tensor, t: torch.Tensor):
        """Add one batch, weighting its loss by the batch size."""
        batch_size = y.shape[0]
        self.loss_sum += self(y, t).item() * batch_size
        self.n_examples += batch_size

    def compute(self):
        """Return the accumulated loss divided by the number of examples."""
        return self.loss_sum / self.n_examples
    
class MyLogLoss(MeanLoss):
    """``MeanLoss`` whose per-batch loss is ``nn.BCEWithLogitsLoss``.

    Args:
        **params: Keyword arguments forwarded to ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, **params):
        super(MyLogLoss, self).__init__()
        criterion = nn.BCEWithLogitsLoss(**params)
        self.loss_func = criterion
        
class MyROCAUC(nn.Module):
    """ROC-AUC over all batches collected since the last reset.

    Args:
        average: ``average`` argument passed to ``roc_auc_score``.
    """

    def __init__(self, average='macro') -> None:
        super().__init__()
        self.average = average
        self._pred_list = []
        self._true_list = []

    def reset(self) -> None:
        """Forget all collected predictions and targets."""
        self._pred_list, self._true_list = [], []

    def update(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> None:
        """Store one batch of predictions and targets as NumPy arrays."""
        for store, tensor in ((self._pred_list, y_pred), (self._true_list, y_true)):
            store.append(tensor.detach().cpu().numpy())

    def compute(self) -> float:
        """Return ``roc_auc_score`` over everything collected so far."""
        predictions = np.concatenate(self._pred_list, axis=0)
        targets = np.concatenate(self._true_list, axis=0)
        return roc_auc_score(targets, predictions, average=self.average)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
        """Score a single batch on its own (discards collected state first)."""
        self.reset()
        self.update(y_pred, y_true)
        return self.compute()
    
ftpt()

## training utils

In [None]:
def set_random_seed(seed: int=42, deterministic: bool=True):
    """Seed ``random``, NumPy and PyTorch and configure cuDNN determinism.

    Args:
        seed: Seed applied to ``PYTHONHASHSEED``, ``random``, NumPy and
            PyTorch (CPU and current CUDA device).
        deterministic: Value assigned to
            ``torch.backends.cudnn.deterministic``. The flag used to be
            ignored and cuDNN was always switched to deterministic mode; the
            default is therefore ``True`` so that calls which omit the
            argument behave as before, while ``deterministic=False`` is now
            honoured.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = deterministic
    
    
def get_stepper(manager, stgs, scheduler):
    """Pick how the LR scheduler is stepped, based on its configured name.

    Args:
        manager: Object exposing ``epoch_detail`` (read on every step of
            ``CosineAnnealingWarmRestarts``).
        stgs: Settings dict; only ``stgs['scheduler']['name']`` is read.
        scheduler: Scheduler whose ``step`` method is called.

    Returns:
        ``(step_by_epoch, step_by_iter)`` zero-argument callables, one of
        which (or both, when the name is ``None``) is a no-op:

        * ``None``: never step.
        * ``'CosineAnnealingWarmRestarts'``: step every iteration with
          ``manager.epoch_detail``.
        * ``'OneCycleLR'``: step every iteration.
        * anything else: step once per epoch.
    """
    def dummy_step():
        pass

    def step():
        scheduler.step()

    def step_with_epoch_detail():
        scheduler.step(manager.epoch_detail)

    scheduler_name = stgs['scheduler']['name']
    # Identity comparison is the idiomatic (PEP 8) test for None.
    if scheduler_name is None:
        return dummy_step, dummy_step
    elif scheduler_name == 'CosineAnnealingWarmRestarts':
        return dummy_step, step_with_epoch_detail
    elif scheduler_name == 'OneCycleLR':
        return dummy_step, step
    else:
        return step, dummy_step
    
def run_train_loop(manager, stgs, model, device, train_loader, optimizer, scheduler, loss_func):
    """Train ``model`` until ``manager.stop_trigger`` becomes true.

    Every batch runs inside ``manager.run_iteration()`` and reports
    ``'train/loss'``. The scheduler is stepped per iteration and/or per
    epoch as decided by ``get_stepper``. With
    ``stgs['globals']['use_amp']`` the forward pass runs under CUDA autocast
    and the backward/optimizer step goes through a gradient scaler.

    Args:
        manager: Training manager providing ``stop_trigger`` and
            ``run_iteration()``.
        stgs: Settings dict (``globals.use_amp`` and ``scheduler.name`` are read).
        model: Model to train.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(x, t)`` batches.
        optimizer: Optimizer to step.
        scheduler: LR scheduler handed to ``get_stepper``.
        loss_func: Callable ``loss_func(y, t)`` returning the loss tensor.
    """
    step_scheduler_by_epoch, step_scheduler_by_iter = get_stepper(manager, stgs, scheduler)

    if stgs['globals']['use_amp']:
        # Create the scaler once for the whole run. Building a new one every
        # epoch threw away the loss-scale state it had adapted so far.
        scaler = torch.cuda.amp.GradScaler()
        while not manager.stop_trigger:
            model.train()
            for x, t in train_loader:
                with manager.run_iteration():
                    x, t = x.to(device), t.to(device)
                    optimizer.zero_grad()
                    with torch.cuda.amp.autocast():
                        y = model(x)
                        loss = loss_func(y, t)
                    ppe.reporting.report({'train/loss': loss.item()})
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    step_scheduler_by_iter()
            step_scheduler_by_epoch()
    else:
        while not manager.stop_trigger:
            model.train()
            for x, t in train_loader:
                with manager.run_iteration():
                    x, t = x.to(device), t.to(device)
                    optimizer.zero_grad()
                    y = model(x)
                    loss = loss_func(y, t)
                    ppe.reporting.report({'train/loss': loss.item()})
                    loss.backward()
                    optimizer.step()
                    step_scheduler_by_iter()
            step_scheduler_by_epoch()
            
        
def run_eval(stgs, model, device, batch, eval_manager):
    """Evaluate one validation batch and hand the result to ``eval_manager``.

    The model is switched to eval mode, the batch is moved to ``device``,
    and ``eval_manager(y, t)`` is called with the model output and the
    targets. With ``stgs['globals']['use_amp']`` both the forward pass and
    the metric update run under CUDA autocast.
    """
    model.eval()
    inputs, targets = batch

    if not stgs['globals']['use_amp']:
        eval_manager(model(inputs.to(device)), targets.to(device))
        return

    with torch.cuda.amp.autocast():
        outputs = model(inputs.to(device))
        eval_manager(outputs, targets.to(device))
        
ftpt()

In [None]:
def set_extensions(manager, args, model, device, val_loader,
                   optimizer, eval_manager, print_progress: bool=False):
    """Attach logging, evaluation and snapshot extensions to ``manager``.

    Registers LR observation, log/plot/print reports, an optional progress
    bar, an evaluator that runs ``run_eval`` over ``val_loader`` every
    epoch, and a model snapshot that is written whenever ``'val/metric'``
    reaches a new maximum.

    Args:
        manager: Extensions manager to extend (also returned).
        args: Settings dict forwarded to ``run_eval``.
        model: Model to evaluate and snapshot.
        device: Device used during evaluation.
        val_loader: Validation data loader.
        optimizer: Optimizer whose learning rate is observed.
        eval_manager: Metric manager; its ``metric_names`` define the
            ``val/<name>`` columns of the printed report.
        print_progress: If true, also show a progress bar.

    Returns:
        The same ``manager`` instance.
    """
    eval_names = [f'val/{name}' for name in eval_manager.metric_names]

    log_extensions = [
        ppe_extensions.observe_lr(optimizer=optimizer),
        ppe_extensions.LogReport(),
        ppe_extensions.PlotReport(['train/loss', 'val/loss'], 'epoch', filename='loss.png'),
        ppe_extensions.PlotReport(['lr'], 'epoch', filename='lr.png'),
        ppe_extensions.PrintReport([
            'epoch', 'iteration', 'lr', 'train/loss', *eval_names, 'elapsed_time'])
    ]
    if print_progress:
        # The keyword is ``update_interval``; the former misspelling
        # ``update_iterval`` raised TypeError as soon as this branch ran.
        log_extensions.append(ppe_extensions.ProgressBar(update_interval=20))

    for ext in log_extensions:
        manager.extend(ext)

    manager.extend(
        ppe_extensions.Evaluator(
            val_loader, model,
            eval_func=lambda *batch: run_eval(args, model, device, batch, eval_manager)),
        trigger = (1, 'epoch'))

    manager.extend(
        ppe_extensions.snapshot(target=model, filename='snapshot_epoch_{.epoch}.pth'),
        trigger=ppe.training.triggers.MaxValueTrigger(key='val/metric', trigger=(1, 'epoch')))

    return manager
    
def train_one_fold(settings, train_all, output_path, print_progress=False):
    """Train ``MultiHeadResNet200D`` on one fold as described by ``settings``.

    Builds the data loaders, model, optimizer, scheduler, loss and metric
    manager from ``settings``, wires them into a ``pytorch_pfn_extras``
    extensions manager with early stopping on ``'val/metric'`` (mode max),
    and runs the training loop. Outputs are written under ``output_path``.

    Args:
        settings: Settings dict (see the YAML in the "Train" section).
            NOTE: for ``OneCycleLR`` the keys ``epochs`` and
            ``steps_per_epoch`` are written into
            ``settings['scheduler']['params']`` in place.
        train_all: Training DataFrame with a ``fold`` column.
        output_path: Output directory for the extensions manager.
        print_progress: If true, show a progress bar.
    """
    torch.backends.cudnn.benchmark = True
    set_random_seed(settings['globals']['seed'])

    train_file_list, val_file_list = get_file_list_with_array(settings, train_all)
    print(f'train: {len(train_file_list)}, val: {len(val_file_list)}')

    device = torch.device(settings['globals']['device'])
    train_loader, val_loader = get_dataloaders_cls(
        settings, train_file_list, val_file_list, LabeledImageDatasetNumpy)

    model = MultiHeadResNet200D(**settings['model']['params'])
    model.to(device)

    ## get optimizer
    optimizer = getattr(
        torch.optim, settings['optimizer']['name']
    )(model.parameters(), **settings['optimizer']['params'])

    ## get scheduler
    scheduler_name = settings['scheduler']['name']
    if scheduler_name is None:
        # get_stepper() treats a None name as "never step", so no scheduler
        # object is needed (getattr(..., None) would raise TypeError).
        scheduler = None
    else:
        if scheduler_name == 'OneCycleLR':
            settings['scheduler']['params']['epochs'] = settings['globals']['max_epoch']
            # OneCycleLR's keyword is ``steps_per_epoch`` (it was misspelled
            # ``step_per_epoch``, which the scheduler rejects).
            settings['scheduler']['params']['steps_per_epoch'] = len(train_loader)
        scheduler = getattr(
            torch.optim.lr_scheduler, scheduler_name
        )(optimizer, **settings['scheduler']['params'])

    # NOTE(review): eval() below resolves class names from the settings;
    # only use with trusted configuration.
    if hasattr(nn, settings['loss']['name']):
        loss_func = getattr(nn, settings['loss']['name'])(**settings['loss']['params'])
    else:
        loss_func = eval(settings['loss']['name'])(**settings['loss']['params'])
    loss_func.to(device)

    eval_manager = EvalFuncManager(
        len(val_loader), {
            metric['report_name']: eval(metric['name'])(**metric['params'])
            for metric in settings['eval']
        })
    eval_manager.to(device)

    trigger = ppe.training.triggers.EarlyStoppingTrigger(
        check_trigger = (1, 'epoch'),
        monitor = 'val/metric', mode='max',
        patience = settings['globals']['patience'], verbose=False,
        max_trigger = (settings['globals']['max_epoch'], 'epoch'),
    )
    manager = ppe.training.ExtensionsManager(
        model, optimizer, settings['globals']['max_epoch'],
        iters_per_epoch=len(train_loader),
        stop_trigger=trigger, out_dir=output_path,
    )
    manager = set_extensions(
        manager, settings, model, device, val_loader,
        optimizer, eval_manager, print_progress)

    run_train_loop(manager, settings, model, device, train_loader,
                   optimizer, scheduler, loss_func)
    
ftpt()

## Train

In [None]:
# Training settings as YAML. ``globals.val_fold`` is only a placeholder here:
# it is overwritten per fold when ``stgs_list`` is built in the next cell.
stgs_str = """
globals:
  seed: 1086
  device: cuda
  max_epoch: 16
  patience: 3
  dataset_name: train_512x512
  use_amp: True
  val_fold: 0
  debug: False

dataset:
  name: LabeledImageDatasetNumpy
  train:
    transform_list:
      - [HorizontalFlip, {p: 0.5}]
      - [ShiftScaleRotate, {
          p: 0.5, shift_limit: 0.2, scale_limit: 0.2,
          rotate_limit: 20, border_mode: 0, value: 0, mask_value: 0}]
      - [RandomResizedCrop, {height: 512, width: 512, scale: [0.9, 1.0]}]
      - [Cutout, {max_h_size: 51, max_w_size: 51, num_holes: 5, p: 0.5}]
      - [Normalize, {
          always_apply: True, max_pixel_value: 255.0,
          mean: [0.4887381077884414], std: [0.23064819430546407]}]
      - [ToTensorV2, {always_apply: True}]
  val:
    transform_list:
      - [Normalize, {
          always_apply: True, max_pixel_value: 255.0,
          mean: [0.4887381077884414], std: [0.23064819430546407]}]
      - [ToTensorV2, {always_apply: True}]

loader:
  train: {batch_size: 16, shuffle: True, num_workers: 2, pin_memory: True, drop_last: True}
  val: {batch_size: 32, shuffle: False, num_workers: 2, pin_memory: True, drop_last: False}

model:
  name: MultiHeadResNet200D
  params:
    # base_name: resnet200D_320
    out_dims_head: [3, 4, 3, 1]
    pretrained: True

loss: {name: BCEWithLogitsLoss, params: {}}

eval:
  - {name: MyLogLoss, report_name: loss, params: {}}
  - {name: MyROCAUC, report_name: metric, params: {average: macro}}

optimizer:
    name: Adam
    params:
      lr: 2.5e-04

scheduler:
  name: CosineAnnealingWarmRestarts
  params:
    T_0: 16
    T_mult: 1
"""
stgs = yaml.safe_load(stgs_str)

# Debug runs train for a single epoch only.
if stgs["globals"]["debug"]:
    stgs["globals"]["max_epoch"] = 1
    
ftpt()

In [None]:
# One independent copy of the settings per trained fold, differing only in
# the validation fold id.
stgs_list = []
for fold_id in FOLDS:
    tmp_stgs = copy.deepcopy(stgs)
    tmp_stgs['globals']['val_fold'] = fold_id
    stgs_list.append(tmp_stgs)
    
ftpt()

In [None]:
# Start from a clean slate: release cached CUDA memory and collect garbage.
torch.cuda.empty_cache()
gc.collect()

# Train every configured fold; outputs go to TMP/fold<id>.
train_start_t = time.time()
for fold_id, tmp_stgs in zip(FOLDS, stgs_list):
    print('fold_id: ', fold_id)
    train_one_fold(tmp_stgs, train, TMP / f'fold{fold_id}', False)
    # Free memory between folds.
    torch.cuda.empty_cache()
    gc.collect()
    
print(f'train finished, total cost time: {time.time()-train_start_t:.4f}')

ftpt()

## Inference OOF

In [None]:
ls ../tmp

In [None]:
# For every trained fold: find the epoch with the highest 'val/metric' in the
# training log, keep that snapshot, and copy the remaining outputs and the
# settings into the working directory.
best_log_list = []
for fold_id, tmp_stgs in zip(FOLDS, stgs_list):
    exp_dir_path = TMP / f'fold{fold_id}'
    log = pd.read_json(exp_dir_path / 'log')
    # Double brackets keep the selection a one-row DataFrame.
    best_log = log.iloc[[log['val/metric'].idxmax()],]
    best_epoch = best_log.epoch.values[0]
    best_log_list.append(best_log)
    
    best_model_path = exp_dir_path / f'snapshot_epoch_{best_epoch}.pth'
    copy_to = f'./best_model_fold{fold_id}.pth'
    shutil.copy(best_model_path, copy_to)
    
    # Delete all snapshots from the experiment directory (the best one was
    # copied above) before the directory itself is copied.
    for p in exp_dir_path.glob('*.pth'):
        p.unlink()
        
    # NOTE(review): copytree raises if ./fold<id> already exists, so this cell
    # cannot be re-run without removing that directory first.
    shutil.copytree(exp_dir_path, f'./fold{fold_id}')
    with open(f'./fold{fold_id}/settings.yml', 'w') as fw:
        yaml.dump(tmp_stgs, fw)
        
# Summary table: the best-epoch log row of every fold.
pd.concat(best_log_list, axis=0, ignore_index=True)

ftpt()