# About

I'd like to share an example for [small models](https://www.kaggle.com/c/seti-breakthrough-listen/discussion/242644).

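### Experimental Settings

Note: the configuration cell further below differs from this list: it uses a `resnext50_32x4d` backbone, 75 epochs, and a batch size of 42.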

### model
* backbone: resnet18d (use the pretrained model provided by [timm](https://github.com/rwightman/pytorch-image-models))
* head classifier: one linear layer
* num of input channels: **1**

### data augmentation
* implemented by [albumentations](https://albumentations.ai/docs/) **except for Mixup**
* Train
  * Resize
  * HorizontalFlip
  * VerticalFlip
  * ShiftScaleRotate
  * RandomResizedCrop
  * Mixup(alpha=1.0)
* Val, Test
  * Resize

### learning settings
* CV Strategy: Stratified KFold (K=5)
* max epochs: 18
* data:
  * input image size: 1x320x320
  * batch size: 64
* loss: [BCEWithLogitsLoss](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html#torch.nn.BCEWithLogitsLoss)
* optimizer: [AdamW](https://pytorch.org/docs/stable/optim.html#torch.optim.AdamW)
  * weight decay: 1.0e-04
* learning rate scheduler: [OneCycleLR](https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.OneCycleLR) 
  * epochs: 18
  * max_lr: 1e-3
  * pct_start: 0.111
  * anneal_strategy: cos
  * div_factor: 1.0e+2
  * final_div_factor: 1
  
### NOTE: I use only on-target ('A') observations

```python
img = np.load(path)[[0, 2, 4]]          # shape: (3, 273, 256)
img = np.vstack(img)                    # shape: (819, 256)
img = img.transpose(1, 0)               # shape: (256, 819)
```

### Submission -> 0.726

# Prepare

## Install

## Import

In [1]:
import os
import gc
import copy
import yaml
import random
import shutil
import typing as tp
from pathlib import Path

import numpy as np
import pandas as pd

from tqdm.notebook import tqdm
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

import torch
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
from torch.cuda import amp

import timm

import albumentations as A
from albumentations.pytorch import ToTensorV2

import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.config import Config
from pytorch_pfn_extras.training import extensions as ppe_exts, triggers as ppe_triggers

In [2]:
# timm.list_models(pretrained=True)

In [3]:
# Directory layout, resolved relative to the notebook's working directory.
ROOT = Path.cwd()
INPUT = ROOT.joinpath("../../")
OUTPUT = ROOT.joinpath("output")
DATA = INPUT.joinpath("seti-breakthrough-listen")
TRAIN, TEST = DATA.joinpath("train"), DATA.joinpath("test")

# Scratch directory; per-fold output directories are placed under it.
TMP = ROOT.joinpath("tmp")
TMP.mkdir(exist_ok=True)

# Experiment constants. ("RANDAM" is a typo kept because later cells use it.)
RANDAM_SEED = 1086
CLASSES = ["target"]
FOLDS = [0, 1, 2, 3, 4]
N_CLASSES, N_FOLDS = len(CLASSES), len(FOLDS)

## Read Data, Split folds

In [4]:
# Training labels (one row per cadence id) and the submission template.
train = pd.read_csv(DATA.joinpath("train_labels.csv"))
smpl_sub = pd.read_csv(DATA.joinpath("sample_submission.csv"))

In [5]:
# Stratified K-fold: every row is assigned the index of the fold in which it
# is used for validation (-1 would mean "never assigned").
skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RANDAM_SEED)
train["fold"] = -1
for fold_id, split in enumerate(skf.split(X=train["id"], y=train["target"])):
    val_idx = split[1]
    train.loc[val_idx, "fold"] = fold_id

In [6]:
# Sanity check of the split: number of samples and positives per fold.
# String aggregations are used instead of the builtins len/sum: "size" is the
# vectorised equivalent of len, and passing builtin sum triggers a
# FutureWarning in recent pandas versions.
train.groupby("fold").agg(total=("id", "size"), pos=("target", "sum"))

Unnamed: 0_level_0,total,pos
fold,Unnamed: 1_level_1,Unnamed: 2_level_1
0,12000,1200
1,12000,1200
2,12000,1200
3,12000,1200
4,12000,1200


## Definition of Model, Dataset, Metric

### Model

In [7]:
class BasicImageModel(nn.Module):
    """Image classifier made of a timm backbone and an MLP head.

    Parameters
    ----------
    base_name : str
        Name of a model exposed by ``timm.models`` (e.g. ``"resnet18d"``).
    dims_head : list of int
        Layer sizes of the head, input size first and output size last.
        A leading ``None`` is replaced by the backbone's ``num_features``.
        Every hidden layer is followed by ReLU and Dropout(0.5); the final
        layer is a plain Linear. The caller's list is not modified.
    pretrained : bool
        Load timm's pretrained weights for the backbone.
    in_channels : int
        Number of input channels, passed to timm as ``in_chans``.

    Raises
    ------
    NotImplementedError
        If ``base_name`` is not found in ``timm.models``.
    """

    def __init__(
        self, base_name: str, dims_head: tp.List[int],
        pretrained=False, in_channels: int=3
    ):
        """Initialize"""
        super(BasicImageModel, self).__init__()
        self.base_name = base_name

        # # prepare backbone
        if not hasattr(timm.models, base_name):
            raise NotImplementedError(f"unknown timm model: {base_name}")
        # num_classes=0 drops timm's classifier so the backbone outputs
        # features of size num_features, which feed the head below.
        base_model = timm.create_model(
            base_name, num_classes=0, pretrained=pretrained, in_chans=in_channels)
        in_features = base_model.num_features
        print("load imagenet pretrained:", pretrained)

        self.backbone = base_model
        print(f"{base_name}: {in_features}")

        # # prepare head classifier
        # Work on a copy so filling in the input size does not mutate the
        # list owned by the caller (e.g. a config object).
        dims_head = list(dims_head)
        if dims_head[0] is None:
            dims_head[0] = in_features

        layers_list = []
        # Hidden layers: consecutive (in, out) pairs, excluding the last layer.
        for in_dim, out_dim in zip(dims_head[:-2], dims_head[1:-1]):
            layers_list.extend([
                nn.Linear(in_dim, out_dim),
                nn.ReLU(), nn.Dropout(0.5),
            ])
        layers_list.append(
            nn.Linear(dims_head[-2], dims_head[-1])
        )
        self.head_cls = nn.Sequential(*layers_list)

    def forward(self, x):
        """Return head logits for the batch of images ``x``."""
        h = self.backbone(x)
        h = self.head_cls(h)
        return h

### Dataset

In [8]:
FilePath = tp.Union[str, Path]
Label = tp.Union[int, float, np.ndarray]


class SetiSimpleDataset(torch.utils.data.Dataset):
    """
    Cadence dataset that stacks all observations of a snippet along axis 0.

    Each ``.npy`` file (commented below as shape ``(6, 273, 256)``) is
    stacked into a 2-D array, transposed, cast to float32 and given a
    trailing channel axis before being passed to ``transform``.

    Attributes
    ----------
    paths : tp.Sequence[FilePath]
        Path of every cadence snippet file.
    labels : tp.Sequence[Label]
        Label of every cadence snippet, aligned with ``paths``.
    transform: albumentations.Compose
        Augmentation pipeline, called as ``transform(image=...)``.
    """

    def __init__(
        self,
        paths: tp.Sequence[FilePath],
        labels: tp.Sequence[Label],
        transform: A.Compose,
    ):
        """Store paths, labels and transform."""
        self.paths, self.labels, self.transform = paths, labels, transform

    def __len__(self):
        """Number of cadence snippets."""
        return len(self.paths)

    def __getitem__(self, index: int):
        """Load snippet ``index``, augment it and pair it with its label."""
        path = self.paths[index]
        label = self.labels[index]
        transformed = self.transform(image=self._read_cadence_array(path))
        return {"image": transformed["image"], "target": label}

    def _read_cadence_array(self, path: Path):
        """Load one cadence file as a float32 (H, W, 1) image."""
        stacked = np.vstack(np.load(path))  # (6, 273, 256) -> (1638, 256)
        # (1638, 256) -> (256, 1638) -> float32 (256, 1638, 1)
        return np.expand_dims(stacked.transpose(1, 0).astype("f"), axis=-1)

    def lazy_init(self, paths=None, labels=None, transform=None):
        """Replace any of paths / labels / transform that is not None."""
        replacements = {"paths": paths, "labels": labels, "transform": transform}
        for attr_name, new_value in replacements.items():
            if new_value is not None:
                setattr(self, attr_name, new_value)


class SetiAObsDataset(SetiSimpleDataset):
    """Use only the on-target ('A') observations of each cadence.

    Keeps observations 0, 2 and 4 of the cadence file, stacks them along
    axis 0 and returns a float32 single-channel image of shape (256, 819, 1).
    """

    def _read_cadence_array(self, path: Path):
        """Read cadence file, keep the on-target observations and reshape."""
        img = np.load(path)[[0, 2, 4]]  # shape: (3, 273, 256)
        img = np.vstack(img)  # shape: (819, 256)
        img = img.transpose(1, 0)  # shape: (256, 819)
        img = img.astype("f")[..., np.newaxis]  # shape: (256, 819, 1)
        return img

### Metric

In [9]:
Batch = tp.Union[tp.Tuple[torch.Tensor], tp.Dict[str, torch.Tensor]]
ModelOut = tp.Union[tp.Tuple[torch.Tensor], tp.Dict[str, torch.Tensor], torch.Tensor]


class ROCAUC(nn.Module):
    """ROC AUC score computed with ``sklearn.metrics.roc_auc_score``.

    Parameters
    ----------
    average : str
        Forwarded unchanged as ``average`` to ``roc_auc_score``.
    """

    def __init__(self, average="macro") -> None:
        """Remember the averaging mode."""
        super(ROCAUC, self).__init__()
        self.average = average

    def forward(self, y, t) -> float:
        """Return the ROC AUC of predictions ``y`` against targets ``t``.

        Tensors are detached and moved to CPU as numpy arrays; other inputs
        are passed through as-is.
        """
        y, t = (
            value.detach().cpu().numpy() if isinstance(value, torch.Tensor) else value
            for value in (y, t)
        )
        return roc_auc_score(t, y, average=self.average)


def _select_pred_and_label(
    batch, model_output, pred_index, label_index, pred_key, label_key,
):
    """Pick the prediction and the label out of a batch and a model output.

    Sequences (tuple or list; PyTorch's default collate returns a list for
    tuple samples) are indexed positionally, dicts by key. A model output
    that is neither is used as the prediction itself.

    Raises
    ------
    NotImplementedError
        If ``batch`` is neither a sequence nor a dict.
    """
    if isinstance(batch, (tuple, list)):
        t = batch[label_index]
    elif isinstance(batch, dict):
        t = batch[label_key]
    else:
        raise NotImplementedError(
            f"unsupported batch type: {type(batch).__name__}")

    if isinstance(model_output, (tuple, list)):
        y = model_output[pred_index]
    elif isinstance(model_output, dict):
        y = model_output[pred_key]
    else:
        y = model_output
    return y, t


def _to_numpy(x):
    """Return ``x`` as a numpy array, detaching/moving tensors to CPU first."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def micro_average(
    metric_func: nn.Module,
    report_name: str, prefix="val",
    pred_index: int=-1, label_index: int=-1,
    pred_key: str="logit", label_key: str="target",
) -> tp.Callable:
    """Return a metric wrapper that averages ``metric_func`` over examples.

    ``metric_func(y, t)`` is evaluated per batch and weighted by the batch
    size; on the last batch the example-weighted mean is reported as
    ``f"{prefix}/{report_name}"`` through ``ppe.reporting.report``. The
    accumulated state is reset after each report (even if reporting fails),
    so the wrapper can be reused across evaluation rounds.
    """
    metric_sum = 0.
    n_examples = 0

    def wrapper(batch: Batch, model_output: ModelOut, is_last_batch: bool):
        """Accumulate one batch; report the average on the last one."""
        nonlocal metric_sum, n_examples
        y, t = _select_pred_and_label(
            batch, model_output, pred_index, label_index, pred_key, label_key)

        # float() accepts both 0-dim tensors and plain Python/numpy numbers.
        metric = float(metric_func(y, t))
        metric_sum += metric * y.shape[0]
        n_examples += y.shape[0]

        if is_last_batch:
            try:
                final_metric = (
                    metric_sum / n_examples if n_examples else float("nan"))
                ppe.reporting.report({f"{prefix}/{report_name}": final_metric})
            finally:
                # # reset state
                metric_sum = 0.
                n_examples = 0

    return wrapper


def calc_across_all_batchs(
    metric_func: nn.Module,
    report_name: str, prefix="val",
    pred_index: int=-1, label_index: int=-1,
    pred_key: str="logit", label_key: str="target",
) -> tp.Callable:
    """
    Return a metric wrapper for metrics calculated on all data at once.

    Predictions and labels of every batch are stored as numpy arrays; on the
    last batch they are concatenated along axis 0, ``metric_func(pred, label)``
    is evaluated once and reported as ``f"{prefix}/{report_name}"``. The
    stored arrays are cleared after each report (even if it fails).
    """
    pred_list = []
    label_list = []

    def wrapper(batch: Batch, model_output: ModelOut, is_last_batch: bool):
        """Store one batch; compute and report the metric on the last one."""
        y, t = _select_pred_and_label(
            batch, model_output, pred_index, label_index, pred_key, label_key)

        pred_list.append(_to_numpy(y))
        label_list.append(_to_numpy(t))

        if is_last_batch:
            try:
                pred = np.concatenate(pred_list, axis=0)
                label = np.concatenate(label_list, axis=0)
                final_metric = metric_func(pred, label)
                ppe.reporting.report({f"{prefix}/{report_name}": final_metric})
            finally:
                # # reset state
                pred_list.clear()
                label_list.clear()

    return wrapper

# Train

## config_types for evaluating configuration

I use [pytorch-pfn-extras](https://github.com/pfnet/pytorch-pfn-extras) for training NNs. This library has a useful config system, but it requires some preparation: every `type` name used in the config must be registered in a lookup table.

For more details, see [docs](https://github.com/pfnet/pytorch-pfn-extras/blob/master/docs/config.md).

In [10]:
# Registry of the names that may appear as ``type:`` values in the YAML config
# below. ppe's Config resolves each ``type`` through this mapping and calls the
# mapped object with the remaining keys of that node as keyword arguments.
CONFIG_TYPES = {
    # # utils
    # Lambdas are needed because the config passes arguments by keyword
    # (e.g. ``{type: __len__, obj: ...}``) and builtin ``len`` accepts no
    # keyword arguments.
    "__len__": lambda obj: len(obj),
    # Calls a zero-argument method by name, e.g. ``model.parameters()``.
    "method_call": lambda obj, method: getattr(obj, method)(),

    # # Dataset, DataLoader
    "SetiSimpleDataset": SetiSimpleDataset,
    "SetiAObsDataset": SetiAObsDataset,
    "DataLoader": torch.utils.data.DataLoader,

    # # Data Augmentation
    "Compose": A.Compose, "OneOf": A.OneOf,
    "Resize": A.Resize,
    "HorizontalFlip": A.HorizontalFlip, "VerticalFlip": A.VerticalFlip,
    "ShiftScaleRotate": A.ShiftScaleRotate,
    "RandomResizedCrop": A.RandomResizedCrop,
    # NOTE(review): A.Cutout is deprecated/removed in recent albumentations
    # releases; confirm the pinned version still provides it.
    "Cutout": A.Cutout,
    "ToTensorV2": ToTensorV2,

    # # Model
    "BasicImageModel": BasicImageModel,

    # # Optimizer
    "AdamW": optim.AdamW,

    # # Scheduler
    "OneCycleLR": lr_scheduler.OneCycleLR,

    # # Loss,Metric
    "BCEWithLogitsLoss": nn.BCEWithLogitsLoss,
    "ROCAUC": ROCAUC,

    # # Metric Wrapper (entries of the ``eval`` list in the config)
    "micro_average": micro_average,
    "calc_across_all_batchs": calc_across_all_batchs,

    # # PPE Extensions
    "ExtensionsManager": ppe.training.ExtensionsManager,

    "observe_lr": ppe_exts.observe_lr,
    "LogReport": ppe_exts.LogReport,
    "PlotReport": ppe_exts.PlotReport,
    "PrintReport": ppe_exts.PrintReport,
    "PrintReportNotebook": ppe_exts.PrintReportNotebook,
    "ProgressBar": ppe_exts.ProgressBar,
    "ProgressBarNotebook": ppe_exts.ProgressBarNotebook,
    "snapshot": ppe_exts.snapshot,
    "LRScheduler": ppe_exts.LRScheduler,

    # # Triggers
    "MinValueTrigger": ppe_triggers.MinValueTrigger,
    "MaxValueTrigger": ppe_triggers.MaxValueTrigger,
    "EarlyStoppingTrigger": ppe_triggers.EarlyStoppingTrigger,
}

## Configuration

In [11]:
# Base (not yet evaluated) configuration shared by all folds, parsed from YAML
# into a plain dict. Strings of the form "@/path" refer to other nodes, and
# ``type:`` names are resolved through CONFIG_TYPES once the dict is wrapped
# in ppe's Config. ``globals.val_fold`` and ``globals.output_path`` are filled
# in per fold by the cell that builds ``pre_eval_cfg_list``.
# NOTE(review): these values (resnext50_32x4d backbone, 75 epochs, batch
# size 42) differ from the settings listed in the notebook introduction.
# NOTE(review): the commented-out ``stop_trgiger`` key is misspelled; if it is
# re-enabled it presumably needs to be ``stop_trigger``.
pre_eval_cfg = yaml.safe_load(
"""
globals:
  seed: 1086
  val_fold: null  # indicate when training
  output_path: null # indicate when training
  device: cuda:0
  enable_amp: False
  max_epoch: 75

model:
  type: BasicImageModel
  dims_head: [null, 1]
  base_name: resnext50_32x4d
  pretrained: True
  in_channels: 1

dataset:
  height: 320
  width: 320
  mixup: {enabled: True, alpha: 1.0}
  train:
    type: SetiAObsDataset
    paths: null  # set by lazy_init
    labels: null  # set by lazy_init
    transform:
      type: Compose
      transforms:
        - {type: Resize, p: 1.0, height: "@/dataset/height", width: "@/dataset/width"}
        - {type: HorizontalFlip, p: 0.5}
        - {type: VerticalFlip, p: 0.5}
        - {type: ShiftScaleRotate, p: 0.5, shift_limit: 0.2, scale_limit: 0.2,
            rotate_limit: 20, border_mode: 0, value: 0, mask_value: 0}
        - {type: RandomResizedCrop, p: 1.0,
            scale: [0.9, 1.0], height: "@/dataset/height", width: "@/dataset/width"}
        - {type: ToTensorV2, always_apply: True}
  val:
    type: SetiAObsDataset
    paths: null  # set by lazy_init
    labels: null  # set by lazy_init
    transform:
      type: Compose
      transforms:
        - {type: Resize, p: 1.0, height: "@/dataset/height", width: "@/dataset/width"}
        - {type: ToTensorV2, always_apply: True}  
  test:
    type: SetiAObsDataset
    paths: null  # set by lazy_init
    labels: null  # set by lazy_init
    transform: "@/dataset/val/transform"

loader:
  train: {type: DataLoader, dataset: "@/dataset/train",
    batch_size: 42, num_workers: 4, shuffle: True, pin_memory: True, drop_last: True}
  val: {type: DataLoader, dataset: "@/dataset/val",
    batch_size: 42, num_workers: 4, shuffle: False, pin_memory: True, drop_last: False}
  test: {type: DataLoader, dataset: "@/dataset/test",
    batch_size: 42, num_workers: 4, shuffle: False, pin_memory: True, drop_last: False}

optimizer:
  type: AdamW
  params: {type: method_call, obj: "@/model", method: parameters}
  lr: 1.0e-05
  weight_decay: 1.0e-04

scheduler:
  type: OneCycleLR
  optimizer: "@/optimizer"
  epochs: "@/globals/max_epoch"
  steps_per_epoch: {type: __len__, obj: "@/loader/train"}
  max_lr: 1.0e-3
  pct_start: 0.111
  anneal_strategy: cos
  div_factor: 1.0e+2
  final_div_factor: 1

loss: {type: BCEWithLogitsLoss}

eval:
  - type: micro_average
    metric_func: {type: BCEWithLogitsLoss}
    report_name: loss
  - type: calc_across_all_batchs
    metric_func: {type: ROCAUC}
    report_name: metric

manager:
  type: ExtensionsManager
  models: "@/model"
  optimizers: "@/optimizer"
  max_epochs: "@/globals/max_epoch"
  iters_per_epoch: {type: __len__, obj: "@/loader/train"}
  out_dir: "@/globals/output_path"
  # stop_trgiger: {type: EarlyStoppingTrigger,
  #   monitor: val/metric, mode: max, patience: 5, verbose: True,
  #   check_trigger: [1, epoch], max_trigger: ["@/globals/max_epoch", epoch]}

extensions:
  # # log
  - {type: observe_lr, optimizer: "@/optimizer"}
  - {type: LogReport}
  - {type: PlotReport, y_keys: lr, x_key: epoch, filename: lr.png}
  - {type: PlotReport, y_keys: [train/loss, val/loss], x_key: epoch, filename: loss.png}
  - {type: PlotReport, y_keys: val/metric, x_key: epoch, filename: metric.png}
  - {type: PrintReport, entries: [
      epoch, iteration, lr, train/loss, val/loss, val/metric, elapsed_time]}
  - {type: ProgressBarNotebook, update_interval: 20}
  # snapshot
  - extension: {type: snapshot, target: "@/model", filename: "snapshot_by_metric_epoch_{.epoch}.pth"}
    trigger: {type: MaxValueTrigger, key: "val/metric", trigger: [1, epoch]}
  # # lr scheduler
  - {type: LRScheduler, scheduler: "@/scheduler", trigger: [1,  iteration]}
"""
)

## functions for training

In [12]:
def set_random_seed(seed: int = 42, deterministic: bool = False):
    """Seed Python, NumPy and PyTorch RNGs and set cuDNN determinism.

    Args:
        seed: value given to every random number generator.
        deterministic: stored in ``torch.backends.cudnn.deterministic``.

    Setting ``PYTHONHASHSEED`` here only affects child processes; the hash
    seed of the running interpreter is fixed when it starts.
    """
    for python_side_seeder in (random.seed, np.random.seed):
        python_side_seeder(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for torch_side_seeder in (torch.manual_seed, torch.cuda.manual_seed):
        torch_side_seeder(seed)  # type: ignore
    torch.backends.cudnn.deterministic = deterministic  # type: ignore


def to_device(
    tensors: tp.Union[
        torch.Tensor, tp.Sequence[torch.Tensor], tp.Dict[str, torch.Tensor]],
    device: torch.device, *args, **kwargs
):
    """Move a tensor, or a tuple/list/dict of tensors, to ``device``.

    Extra positional/keyword arguments are forwarded to ``Tensor.to``
    (e.g. ``non_blocking=True`` or ``dtype=...``). Containers keep their
    kind: a tuple is returned as a tuple (a real, reusable one rather than a
    one-shot generator), a list as a list, a dict as a dict with the same
    keys. Anything else is expected to have a ``.to`` method.
    """
    if isinstance(tensors, tuple):
        return tuple(t.to(device, *args, **kwargs) for t in tensors)
    if isinstance(tensors, list):
        return [t.to(device, *args, **kwargs) for t in tensors]
    if isinstance(tensors, dict):
        return {k: t.to(device, *args, **kwargs) for k, t in tensors.items()}
    return tensors.to(device, *args, **kwargs)

In [13]:
def get_path_label(cfg: Config, train_all: pd.DataFrame):
    """Split ``train_all`` by fold into train/val inputs for the datasets.

    Rows whose ``fold`` equals ``cfg["/globals/val_fold"]`` form the
    validation part; all other rows form the training part. Each part is
    returned as ``{"paths": [...], "labels": float32 array of CLASSES}``,
    where the file of id ``xyz`` is ``TRAIN / "x" / "xyz.npy"``.
    """
    is_val = train_all["fold"] == cfg["/globals/val_fold"]

    def _as_path_label(part: pd.DataFrame) -> dict:
        """Build the ``paths``/``labels`` kwargs for ``lazy_init``."""
        file_paths = [TRAIN / f"{img_id[0]}/{img_id}.npy" for img_id in part["id"].values]
        return {"paths": file_paths, "labels": part[CLASSES].values.astype("f")}

    return _as_path_label(train_all[~is_val]), _as_path_label(train_all[is_val])


def get_eval_func(cfg, model, device):
    """Build the per-batch ``eval_func`` used by ppe's Evaluator.

    The returned function moves the batch to ``device``, runs ``model`` on
    ``batch["image"]`` (under autocast when ``/globals/enable_amp`` is true)
    and returns the output on CPU as float32, which the metric wrappers
    consume.
    """
    use_amp = cfg["/globals/enable_amp"]

    def eval_func(**batch):
        """Run evaluation for val or test. This function is applied to each batch."""
        batch = to_device(batch, device)
        x = batch["image"]
        # Evaluation needs no gradients; disabling autograd avoids building a
        # graph (and holding its activations) for every validation batch.
        with torch.no_grad(), amp.autocast(use_amp):
            y = model(x)
        return y.detach().cpu().to(torch.float32)  # input of metrics

    return eval_func


def mixup_data(use_mixup, x, t, alpha=1.0, use_cuda=True, device="cuda:0"):
    """Return mixed inputs, the pair of targets and the mixing weight.

    When ``use_mixup`` is false, ``(x, t, None, None)`` is returned unchanged.
    Otherwise ``lam`` is drawn from Beta(alpha, alpha) (``lam = 1`` when
    ``alpha <= 0``) and every sample is blended with a randomly permuted
    partner: ``mixed_x = lam * x + (1 - lam) * x[perm]``.

    Returns:
        (mixed_x, t_a, t_b, lam) where ``t_a`` is ``t`` and ``t_b`` is ``t``
        reordered by the same permutation.

    ``use_cuda`` and ``device`` are accepted for backward compatibility only:
    the permutation is created directly on ``x.device``, so it works for CPU
    tensors and always lives on the same device as ``x``.
    """
    if not use_mixup:
        return x, t, None, None

    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    t_a, t_b = t, t[index]
    return mixed_x, t_a, t_b, lam


def get_criterion(use_mixup, loss_func):
    """Return a criterion with signature ``(pred, t_a, t_b, lam) -> loss``.

    With mixup: ``lam * loss_func(pred, t_a) + (1 - lam) * loss_func(pred, t_b)``.
    Without mixup ``t_b`` and ``lam`` are ignored: ``loss_func(pred, t_a)``.
    """
    if not use_mixup:
        def single_criterion(pred, t_a, t_b, lam):
            return loss_func(pred, t_a)

        return single_criterion

    def mixup_criterion(pred, t_a, t_b, lam):
        return lam * loss_func(pred, t_a) + (1 - lam) * loss_func(pred, t_b)

    return mixup_criterion

In [14]:
def train_one_fold(cfg, train_all):
    """Train and validate the fold described by ``cfg``.

    Seeds the RNGs, builds datasets/loaders/model/optimizer/loss from
    ``cfg``, registers the configured ppe extensions plus a per-epoch
    Evaluator, then runs the training loop (optionally with AMP and mixup)
    until the manager's stop trigger fires.
    """
    # NOTE(review): cudnn.benchmark=True lets cuDNN choose convolution
    # algorithms by timing them, which can undermine the determinism requested
    # just below; PyTorch's reproducibility notes recommend disabling it for
    # bit-exact runs. Kept for speed.
    torch.backends.cudnn.benchmark = True
    set_random_seed(cfg["/globals/seed"], deterministic=True)
    device = torch.device(cfg["/globals/device"])

    train_path_label, val_path_label = get_path_label(cfg, train_all)
    print("train: {}, val: {}".format(len(train_path_label["paths"]), len(val_path_label["paths"])))

    # Datasets are created with paths/labels = None by the config and filled
    # here, before the loaders that wrap them are built.
    cfg["/dataset/train"].lazy_init(**train_path_label)
    cfg["/dataset/val"].lazy_init(**val_path_label)
    train_loader = cfg["/loader/train"]
    val_loader = cfg["/loader/val"]

    model = cfg["/model"]
    model.to(device)
    optimizer = cfg["/optimizer"]
    loss_func = cfg["/loss"]
    loss_func.to(device)

    manager = cfg["/manager"]
    for ext in cfg["/extensions"]:
        # Dict entries carry extra options such as a trigger.
        if isinstance(ext, dict):
            manager.extend(**ext)
        else:
            manager.extend(ext)

    evaluator = ppe_exts.Evaluator(
        val_loader, model, eval_func=get_eval_func(cfg, model, device),
        metrics=cfg["/eval"], progress_bar=False)
    manager.extend(evaluator, trigger=(1, "epoch"))

    use_amp = cfg["/globals/enable_amp"]
    scaler = amp.GradScaler(enabled=use_amp)
    use_mixup = cfg["/dataset/mixup/enabled"]
    mixup_alpha = cfg["/dataset/mixup/alpha"]
    device_name = cfg["/globals/device"]
    # The criterion only depends on use_mixup, so build it once.
    criterion = get_criterion(use_mixup, loss_func)

    while not manager.stop_trigger:
        model.train()
        for batch in train_loader:
            with manager.run_iteration():
                batch = to_device(batch, device)
                x, t = batch["image"], batch["target"]
                # # for mixup
                mixed_x, t_a, t_b, lam = mixup_data(
                    use_mixup, x, t, mixup_alpha, device=device_name)

                optimizer.zero_grad()
                with amp.autocast(use_amp):
                    y = model(mixed_x)
                    loss = criterion(y, t_a, t_b, lam)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                ppe.reporting.report({'train/loss': loss.item()})

## run train

In [15]:
# One independent copy of the base config per fold, differing only in the
# validation fold and the output directory.
pre_eval_cfg_list = []
for fold_id in FOLDS:
    tmp_cfg = copy.deepcopy(pre_eval_cfg)
    tmp_cfg["globals"].update(
        val_fold=fold_id,
        output_path=str(TMP / f"fold{fold_id}"),
    )
    pre_eval_cfg_list.append(tmp_cfg)

In [16]:
# Train every fold in turn. The loop variable has its own name so the base
# ``pre_eval_cfg`` is not rebound to the last fold's config.
for fold_pre_eval_cfg in pre_eval_cfg_list:
    cfg = Config(fold_pre_eval_cfg, types=CONFIG_TYPES)
    print(f"\n[fold {cfg['/globals/val_fold']}]")
    train_one_fold(cfg, train)
    device_name = cfg["/globals/device"]
    # Release every reference to this fold's objects (cfg holds the model,
    # optimizer, loaders, ...) and collect cycles first; only then can
    # empty_cache() return their cached CUDA memory.
    del cfg
    gc.collect()
    with torch.cuda.device(device_name):
        torch.cuda.empty_cache()


[fold 0]
train: 48000, val: 12000
load imagenet pretrained: True
resnext50_32x4d: 2048


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           1142        4.47764e-05  0.346007    0.331714    0.513174    558.134       




[J2           2284        0.000144331  0.316077    0.278796    0.705328    1118.24       




[J3           3426        0.000294652  0.3057      0.250141    0.753587    1678.77       




[J4           4568        0.00047458  0.304293    0.267434    0.731956    2239.06       




[J5           5710        0.000658789  0.306119    0.264736    0.757367    2799.5        




[J6           6852        0.000821353  0.306722    0.252767    0.723696    3359.81       




[J7           7994        0.000939388  0.305934    0.278282    0.73585     3919.86       




[J8           9136        0.000996281  0.304767    0.260318    0.744576    4479.78       




[J9           10278       0.00099975  0.301662    0.252941    0.752837    5039.43       




[J10          11420       0.000998459  0.301531    0.252392    0.769962    5599.03       




[J11          12562       0.000996073  0.298031    0.25807     0.744536    6158.55       




[J12          13704       0.000992598  0.296314    0.249629    0.766168    6717.99       




[J13          14846       0.000988039  0.294651    0.244969    0.774843    7277.07       




[J14          15988       0.000982409  0.293351    0.240227    0.777655    7836.28       




[J15          17130       0.000975719  0.293281    0.247062    0.777322    8395.41       




[J16          18272       0.000967984  0.291237    0.238334    0.789724    8954.19       




[J17          19414       0.000959221  0.290106    0.243023    0.769304    9513.14       




[J18          20556       0.000949451  0.288693    0.245436    0.761045    10072         




[J19          21698       0.000938693  0.287708    0.231746    0.797344    10630.6       




[J20          22840       0.000926973  0.286859    0.228698    0.796968    11189.3       




[J21          23982       0.000914316  0.28555     0.232415    0.787784    11748.1       




[J22          25124       0.000900751  0.285269    0.232378    0.805282    12306.6       




[J23          26266       0.000886307  0.285144    0.225422    0.810678    12865         




[J24          27408       0.000871017  0.282186    0.225651    0.809835    13423.8       




[J25          28550       0.000854915  0.283375    0.223127    0.814926    13982.1       




[J26          29692       0.000838036  0.281752    0.226466    0.818961    14540.6       




[J27          30834       0.000820417  0.281225    0.220146    0.819133    15099.2       




[J28          31976       0.000802098  0.280646    0.220072    0.818867    15657.5       




[J29          33118       0.00078312  0.27926     0.223899    0.813154    16215.9       




[J30          34260       0.000763525  0.279614    0.219252    0.817705    16774.4       




[J31          35402       0.000743356  0.278084    0.216064    0.820928    17332.7       




[J32          36544       0.000722657  0.280722    0.218223    0.826336    17891.1       




[J33          37686       0.000701476  0.276735    0.213679    0.811414    18449.6       




[J34          38828       0.000679858  0.276699    0.211801    0.826997    19008.1       




[J35          39970       0.000657852  0.277533    0.21151     0.825206    19566.6       




[J36          41112       0.000635507  0.275656    0.213479    0.820824    20125         




[J37          42254       0.000612873  0.273441    0.212895    0.821821    20683.4       




[J38          43396       0.000589999  0.273098    0.210392    0.82356     21241.7       




[J39          44538       0.000566936  0.271708    0.210183    0.828222    21800         




[J40          45680       0.000543736  0.27275     0.207277    0.829819    22358.7       




[J41          46822       0.000520449  0.271325    0.209173    0.834274    22917.2       




[J42          47964       0.000497129  0.271588    0.210436    0.823412    23475.9       




[J43          49106       0.000473826  0.271325    0.205688    0.837034    24034.5       




[J44          50248       0.000450592  0.270017    0.205527    0.83316     24593.1       




[J45          51390       0.000427479  0.271152    0.204148    0.838377    25151.4       




[J46          52532       0.000404537  0.269392    0.204727    0.838548    25710.2       




[J47          53674       0.000381819  0.269023    0.201918    0.836549    26268.6       




[J48          54816       0.000359375  0.268929    0.201629    0.838852    26826.8       




[J49          55958       0.000337253  0.268274    0.199988    0.840421    27385.3       




[J50          57100       0.000315504  0.267749    0.201197    0.836237    27944.3       




[J51          58242       0.000294176  0.269189    0.201994    0.83866     28502.8       




[J52          59384       0.000273315  0.267278    0.20266     0.841848    29061.4       




[J53          60526       0.000252969  0.266931    0.196692    0.840409    29620         




[J54          61668       0.000233183  0.265613    0.200444    0.837938    30178.7       




[J55          62810       0.000213999  0.267022    0.200896    0.841039    30737.7       




[J56          63952       0.000195462  0.262505    0.198508    0.841833    31296.6       




[J57          65094       0.000177611  0.266359    0.198064    0.840678    31855.5       




[J58          66236       0.000160488  0.265571    0.197329    0.841977    32414.3       




[J59          67378       0.000144129  0.265723    0.199413    0.842779    32973.6       




[J60          68520       0.000128571  0.264142    0.197385    0.843459    33532.5       




[J61          69662       0.000113848  0.263362    0.200303    0.841796    34091.4       




[J62          70804       9.99941e-05  0.263264    0.198692    0.843019    34650.5       




[J63          71946       8.70389e-05  0.263368    0.198616    0.843993    35209.3       




[J64          73088       7.50115e-05  0.262392    0.197623    0.843743    35768.1       




[J65          74230       6.39385e-05  0.261829    0.194629    0.843279    36326.8       




[J66          75372       5.38446e-05  0.262961    0.197877    0.844321    36885.4       




[J67          76514       4.47521e-05  0.261668    0.198462    0.843188    37444.7       




[J68          77656       3.66811e-05  0.261895    0.19782     0.844425    38003.6       




[J69          78798       2.96497e-05  0.262651    0.194774    0.845397    38562.3       




[J70          79940       2.36735e-05  0.262695    0.195458    0.844367    39121.3       




[J71          81082       1.87656e-05  0.26227     0.196296    0.844479    39680.2       




[J72          82224       1.49371e-05  0.262125    0.194849    0.844562    40239         




[J73          83366       1.21963e-05  0.263037    0.194038    0.844786    40797.7       




[J74          84508       1.05494e-05  0.261658    0.193298    0.84523     41356.8       




[J75          85650       1e-05       0.26133     0.194082    0.844767    41915.7       

[fold 1]
train: 48000, val: 12000
load imagenet pretrained: True
resnext50_32x4d: 2048


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…



epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           1142        4.47764e-05  0.346455    0.334367    0.509958    559.982       




[J2           2284        0.000144331  0.319016    0.295386    0.64276     1120.86       




[J3           3426        0.000294652  0.304281    0.296048    0.671929    1681.65       




[J4           4568        0.00047458  0.303532    0.273773    0.727957    2242.36       




[J5           5710        0.000658789  0.303676    0.263439    0.734225    2802.92       




[J6           6852        0.000821353  0.302306    0.266978    0.716461    3363.65       




[J7           7994        0.000939388  0.303988    0.269039    0.732358    3923.86       




[J8           9136        0.000996281  0.301977    0.254235    0.750094    4483.95       




[J9           10278       0.00099975  0.299918    0.270714    0.740181    5044.28       




[J10          11420       0.000998459  0.299548    0.273005    0.753073    5604.17       




[J11          12562       0.000996073  0.296467    0.257205    0.766991    6163.86       




[J12          13704       0.000992598  0.294991    0.246379    0.760962    6724.18       




[J13          14846       0.000988039  0.292723    0.250911    0.764522    7283.47       




[J14          15988       0.000982409  0.291502    0.246395    0.76978     7842.63       




[J15          17130       0.000975719  0.289875    0.243725    0.766687    8401.87       




[J16          18272       0.000967984  0.289583    0.240668    0.773046    8960.94       




[J17          19414       0.000959221  0.286826    0.237786    0.778081    9520.03       




[J18          20556       0.000949451  0.286988    0.236666    0.778911    10079.4       




[J19          21698       0.000938693  0.286671    0.239273    0.785711    10638.2       




[J20          22840       0.000926973  0.286488    0.236324    0.787299    11197.1       




[J21          23982       0.000914316  0.283973    0.235377    0.790009    11756.2       




[J22          25124       0.000900751  0.284174    0.235362    0.79156     12315.2       




[J23          26266       0.000886307  0.283514    0.2317      0.79188     12873.9       




[J24          27408       0.000871017  0.280305    0.231939    0.78695     13432.7       




[J25          28550       0.000854915  0.281425    0.228704    0.793292    13991.1       




[J26          29692       0.000838036  0.280847    0.235013    0.785889    14549.7       




[J27          30834       0.000820417  0.279527    0.223219    0.798705    15108.1       




[J28          31976       0.000802098  0.279818    0.226564    0.794682    15666.6       




[J29          33118       0.00078312  0.277565    0.234707    0.789138    16225         




[J30          34260       0.000763525  0.279101    0.223952    0.797735    16783.4       




[J31          35402       0.000743356  0.275895    0.225096    0.790682    17342         




[J32          36544       0.000722657  0.276559    0.220126    0.805641    17900.7       




[J33          37686       0.000701476  0.275635    0.232142    0.792384    18459.3       




[J34          38828       0.000679858  0.275245    0.225245    0.800608    19017.8       




[J35          39970       0.000657852  0.274886    0.220041    0.800499    19576         




[J36          41112       0.000635507  0.27269     0.221005    0.801046    20134.4       




[J37          42254       0.000612873  0.27084     0.220062    0.80822     20693.5       




[J38          43396       0.000589999  0.273637    0.217943    0.81426     21251.9       




[J39          44538       0.000566936  0.271076    0.218517    0.806037    21810.3       




[J40          45680       0.000543736  0.270507    0.217822    0.807096    22368.6       




[J41          46822       0.000520449  0.271897    0.216682    0.809981    22926.9       




[J42          47964       0.000497129  0.270354    0.217139    0.806975    23485.2       




[J43          49106       0.000473826  0.270357    0.218722    0.812979    24043.6       




[J44          50248       0.000450592  0.268775    0.215476    0.810023    24601.8       




[J45          51390       0.000427479  0.27044     0.214243    0.810352    25160.2       




[J46          52532       0.000404537  0.268292    0.215533    0.815542    25718.8       




[J47          53674       0.000381819  0.268585    0.219656    0.803112    26277.2       




[J48          54816       0.000359375  0.268102    0.214931    0.812367    26835.6       




[J49          55958       0.000337253  0.266372    0.211232    0.816125    27394.1       




[J50          57100       0.000315504  0.265736    0.210573    0.81411     27952.6       




[J51          58242       0.000294176  0.266977    0.211342    0.818107    28510.9       




[J52          59384       0.000273315  0.264626    0.213256    0.816453    29069.6       




[J53          60526       0.000252969  0.265852    0.214313    0.81561     29628.1       




[J54          61668       0.000233183  0.263288    0.210117    0.81849     30186.7       




[J55          62810       0.000213999  0.265021    0.215195    0.816018    30745.6       




[J56          63952       0.000195462  0.261817    0.207297    0.818495    31304.2       




[J57          65094       0.000177611  0.26399     0.210773    0.818681    31862.9       




[J58          66236       0.000160488  0.264115    0.209431    0.821024    32421.8       




[J59          67378       0.000144129  0.264544    0.207373    0.81842     32980.5       




[J60          68520       0.000128571  0.262182    0.208651    0.819776    33539.1       




[J61          69662       0.000113848  0.261436    0.212173    0.818704    34097.6       




[J62          70804       9.99941e-05  0.26211     0.208708    0.819407    34656.4       




[J63          71946       8.70389e-05  0.262157    0.209516    0.817874    35214.9       




[J64          73088       7.50115e-05  0.260019    0.210046    0.816361    35773.7       




[J65          74230       6.39385e-05  0.260993    0.209537    0.815871    36332.5       




[J66          75372       5.38446e-05  0.261801    0.210445    0.819459    36891.2       




[J67          76514       4.47521e-05  0.26108     0.208734    0.821508    37449.8       




[J68          77656       3.66811e-05  0.259434    0.209159    0.820543    38009.7       




[J69          78798       2.96497e-05  0.260928    0.206183    0.821675    38568.4       




[J70          79940       2.36735e-05  0.259738    0.208504    0.820239    39127.2       




[J71          81082       1.87656e-05  0.259616    0.208879    0.822903    39686         




[J72          82224       1.49371e-05  0.26072     0.207415    0.82262     40244.8       




[J73          83366       1.21963e-05  0.260337    0.207747    0.821735    40803.4       




[J74          84508       1.05494e-05  0.258578    0.207588    0.821906    41362.1       




[J75          85650       1e-05       0.259551    0.206323    0.821688    41920.8       

[fold 2]
train: 48000, val: 12000
load imagenet pretrained: True
resnext50_32x4d: 2048


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…



epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           1142        4.47764e-05  0.345715    0.334282    0.519722    559.807       




[J2           2284        0.000144331  0.319282    0.271879    0.670662    1120.7        




[J3           3426        0.000294652  0.304912    0.270005    0.749712    1681.55       




[J4           4568        0.00047458  0.304519    0.264513    0.742713    2242.16       




[J5           5710        0.000658789  0.306158    0.268809    0.725581    2802.72       




[J6           6852        0.000821353  0.306638    0.260343    0.738646    3363.09       




[J7           7994        0.000939388  0.307675    0.270763    0.726423    3923.13       




[J8           9136        0.000996281  0.304004    0.282909    0.724523    4483.16       




[J9           10278       0.00099975  0.302555    0.251408    0.764898    5043.27       




[J10          11420       0.000998459  0.301502    0.255242    0.75074     5603.32       




[J11          12562       0.000996073  0.299006    0.253855    0.767077    6163.14       




[J12          13704       0.000992598  0.297039    0.24695     0.772662    6723.07       




[J13          14846       0.000988039  0.296129    0.241872    0.774867    7282.81       




[J14          15988       0.000982409  0.293675    0.237675    0.77621     7842.39       




[J15          17130       0.000975719  0.291679    0.243627    0.784585    8401.99       




[J16          18272       0.000967984  0.291766    0.229013    0.785922    8961.39       




[J17          19414       0.000959221  0.290349    0.229525    0.791944    9520.62       




[J18          20556       0.000949451  0.28913     0.235685    0.786327    10079.9       




[J19          21698       0.000938693  0.286641    0.230233    0.789574    10639         




[J20          22840       0.000926973  0.286852    0.228324    0.789876    11198.1       




[J21          23982       0.000914316  0.284969    0.229566    0.792808    11757.1       




[J22          25124       0.000900751  0.284869    0.227087    0.793442    12316         




[J23          26266       0.000886307  0.282832    0.230523    0.793764    12874.9       




[J24          27408       0.000871017  0.283281    0.229299    0.7991      13433.7       




[J25          28550       0.000854915  0.281776    0.224444    0.809038    13992.4       




[J26          29692       0.000838036  0.282706    0.220171    0.805011    14551         




[J27          30834       0.000820417  0.281552    0.218768    0.806318    15109.7       




[J28          31976       0.000802098  0.278788    0.221151    0.804076    15670         




[J29          33118       0.00078312  0.27965     0.216284    0.813278    16228.6       




[J30          34260       0.000763525  0.278973    0.211928    0.81044     16787.3       




[J31          35402       0.000743356  0.277933    0.215914    0.80615     17345.7       




[J32          36544       0.000722657  0.279148    0.21837     0.811938    17903.9       




[J33          37686       0.000701476  0.276261    0.218971    0.810179    18462.2       




[J34          38828       0.000679858  0.276593    0.217841    0.815636    19020.4       




[J35          39970       0.000657852  0.274922    0.209421    0.811696    19578.7       




[J36          41112       0.000635507  0.275314    0.211078    0.812881    20136.9       




[J37          42254       0.000612873  0.274756    0.214451    0.816921    20695.1       




[J38          43396       0.000589999  0.273417    0.213164    0.813818    21253.6       




[J39          44538       0.000566936  0.273042    0.208407    0.815344    21811.9       




[J40          45680       0.000543736  0.272536    0.211917    0.812318    22370.4       




[J41          46822       0.000520449  0.272331    0.206783    0.817784    22928.6       




[J42          47964       0.000497129  0.270846    0.212158    0.814542    23487.1       




[J43          49106       0.000473826  0.27069     0.209549    0.812201    24045.7       




[J44          50248       0.000450592  0.271458    0.209827    0.818141    24604         




[J45          51390       0.000427479  0.271362    0.206208    0.819071    25162.6       




[J46          52532       0.000404537  0.270311    0.211572    0.818335    25721.3       




[J47          53674       0.000381819  0.26845     0.207117    0.821379    26279.6       




[J48          54816       0.000359375  0.267979    0.203534    0.821713    26838.1       




[J49          55958       0.000337253  0.268491    0.204754    0.819542    27397         




[J50          57100       0.000315504  0.267411    0.205948    0.820974    27955.6       




[J51          58242       0.000294176  0.268629    0.202952    0.824514    28514.3       




[J52          59384       0.000273315  0.265679    0.206009    0.823008    29073.3       




[J53          60526       0.000252969  0.266114    0.201393    0.826115    29632         




[J54          61668       0.000233183  0.265098    0.20448     0.824757    30190.9       




[J55          62810       0.000213999  0.266416    0.202386    0.827836    30749.9       




[J56          63952       0.000195462  0.263201    0.201095    0.827779    31308.9       




[J57          65094       0.000177611  0.265807    0.20131     0.829137    31867.7       




[J58          66236       0.000160488  0.265437    0.199315    0.830729    32426.9       




[J59          67378       0.000144129  0.263813    0.201836    0.829813    32985.9       




[J60          68520       0.000128571  0.260795    0.198791    0.82975     33544.8       




[J61          69662       0.000113848  0.263701    0.20035     0.828538    34103.9       




[J62          70804       9.99941e-05  0.26301     0.200235    0.831008    34662.7       




[J63          71946       8.70389e-05  0.263801    0.199751    0.830485    35221.7       




[J64          73088       7.50115e-05  0.262876    0.199639    0.831339    35780.7       




[J65          74230       6.39385e-05  0.262356    0.197219    0.832356    36339.7       




[J66          75372       5.38446e-05  0.261936    0.200387    0.831877    36898.6       




[J67          76514       4.47521e-05  0.261766    0.199899    0.830579    37457.5       




[J68          77656       3.66811e-05  0.262099    0.199109    0.832358    38016         




[J69          78798       2.96497e-05  0.261473    0.197114    0.832644    38574.7       




[J70          79940       2.36735e-05  0.261793    0.197398    0.832202    39133.7       




[J71          81082       1.87656e-05  0.261135    0.199741    0.831496    39692.5       




[J72          82224       1.49371e-05  0.260415    0.198435    0.832194    40251.2       




[J73          83366       1.21963e-05  0.26095     0.199293    0.831513    40812         




[J74          84508       1.05494e-05  0.260531    0.196942    0.832099    41370.8       




[J75          85650       1e-05       0.260667    0.197515    0.831957    41929.6       

[fold 3]
train: 48000, val: 12000
load imagenet pretrained: True
resnext50_32x4d: 2048


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…



epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           1142        4.47764e-05  0.34564     0.331511    0.507122    559.731       




[J2           2284        0.000144331  0.319921    0.284124    0.699951    1120.46       




[J3           3426        0.000294652  0.30569     0.260862    0.743437    1681.03       




[J4           4568        0.00047458  0.304887    0.264835    0.750261    2241.82       




[J5           5710        0.000658789  0.305237    0.275345    0.728982    2802.13       




[J6           6852        0.000821353  0.307098    0.264465    0.730922    3362.32       




[J7           7994        0.000939388  0.304689    0.264234    0.740089    3922.52       




[J8           9136        0.000996281  0.302732    0.254455    0.753844    4482.42       




[J9           10278       0.00099975  0.301401    0.255579    0.766309    5042.42       




[J10          11420       0.000998459  0.299453    0.25129     0.770635    5602.53       




[J11          12562       0.000996073  0.298332    0.255566    0.754767    6162.33       




[J12          13704       0.000992598  0.295723    0.243386    0.774042    6721.74       




[J13          14846       0.000988039  0.295314    0.265497    0.762829    7281.4        




[J14          15988       0.000982409  0.292347    0.235457    0.797734    7840.78       




[J15          17130       0.000975719  0.291242    0.244321    0.77039     8400.12       




[J16          18272       0.000967984  0.290147    0.244001    0.785499    8959.36       




[J17          19414       0.000959221  0.289052    0.234588    0.796862    9518.54       




[J18          20556       0.000949451  0.289942    0.259195    0.780604    10077.6       




[J19          21698       0.000938693  0.288677    0.242481    0.793599    10636.5       




[J20          22840       0.000926973  0.287226    0.233477    0.791876    11195.6       




[J21          23982       0.000914316  0.286244    0.226535    0.797415    11754.2       




[J22          25124       0.000900751  0.284756    0.230889    0.799353    12312.9       




[J23          26266       0.000886307  0.28464     0.227287    0.80195     12867.3       




[J24          27408       0.000871017  0.281854    0.225486    0.799369    13421.6       




[J25          28550       0.000854915  0.281777    0.221154    0.805917    13976.4       




[J26          29692       0.000838036  0.281777    0.222501    0.807336    14531.2       




[J27          30834       0.000820417  0.280083    0.221426    0.802985    15086.2       




[J28          31976       0.000802098  0.279554    0.22199     0.805024    15640.7       




[J29          33118       0.00078312  0.280406    0.221238    0.814804    16195.2       




[J30          34260       0.000763525  0.279564    0.221999    0.808694    16749.9       




[J31          35402       0.000743356  0.279759    0.21553     0.815865    17304.4       




[J32          36544       0.000722657  0.277664    0.218692    0.811228    17859.2       




[J33          37686       0.000701476  0.27561     0.214106    0.814496    18413.8       




[J34          38828       0.000679858  0.275407    0.217126    0.814812    18968.6       




[J35          39970       0.000657852  0.276045    0.205541    0.821383    19523.3       




[J36          41112       0.000635507  0.275799    0.213851    0.816956    20078.3       




[J37          42254       0.000612873  0.273277    0.214879    0.820936    20633         




[J38          43396       0.000589999  0.274314    0.218594    0.823601    21187.9       




[J39          44538       0.000566936  0.272131    0.20933     0.824573    21742.8       




[J40          45680       0.000543736  0.271035    0.206444    0.826708    22297.9       




[J41          46822       0.000520449  0.272059    0.211211    0.823296    22852.9       




[J42          47964       0.000497129  0.271463    0.214931    0.820285    23407.9       




[J43          49106       0.000473826  0.271446    0.20654     0.826653    23962.8       




[J44          50248       0.000450592  0.270344    0.209831    0.825363    24519.3       




[J45          51390       0.000427479  0.269834    0.207285    0.822352    25074.3       




[J46          52532       0.000404537  0.269002    0.207143    0.823561    25629.1       




[J47          53674       0.000381819  0.270391    0.208106    0.827596    26183.9       




[J48          54816       0.000359375  0.267989    0.20921     0.823424    26739         




[J49          55958       0.000337253  0.26814     0.199924    0.8309      27293.9       




[J50          57100       0.000315504  0.26935     0.203337    0.828306    27849         




[J51          58242       0.000294176  0.268432    0.209499    0.830063    28404         




[J52          59384       0.000273315  0.265948    0.20696     0.829637    28959.1       




[J53          60526       0.000252969  0.264318    0.203094    0.831376    29514.2       




[J54          61668       0.000233183  0.266918    0.206162    0.828556    30069.5       




[J55          62810       0.000213999  0.266243    0.205413    0.832009    30624.5       




[J56          63952       0.000195462  0.262576    0.198276    0.835182    31179.9       




[J57          65094       0.000177611  0.267129    0.201379    0.835169    31735.2       




[J58          66236       0.000160488  0.263309    0.204125    0.830869    32290.5       




[J59          67378       0.000144129  0.264511    0.203865    0.834916    32845.9       




[J60          68520       0.000128571  0.262128    0.201361    0.836593    33401.3       




[J61          69662       0.000113848  0.262638    0.201067    0.832881    33956.9       




[J62          70804       9.99941e-05  0.264193    0.199781    0.836887    34512.5       




[J63          71946       8.70389e-05  0.263141    0.2024      0.831294    35068.3       




[J64          73088       7.50115e-05  0.262286    0.201585    0.830633    35624         




[J65          74230       6.39385e-05  0.261791    0.196804    0.83618     36179.7       




[J66          75372       5.38446e-05  0.262536    0.201114    0.834921    36735.6       




[J67          76514       4.47521e-05  0.260602    0.199455    0.835547    37291.4       




[J68          77656       3.66811e-05  0.262297    0.199992    0.835665    37847.3       




[J69          78798       2.96497e-05  0.260527    0.199969    0.836807    38403         




[J70          79940       2.36735e-05  0.261141    0.199012    0.836912    38958.8       




[J71          81082       1.87656e-05  0.262241    0.200499    0.834256    39514.8       




[J72          82224       1.49371e-05  0.260938    0.198917    0.834955    40070.7       




[J73          83366       1.21963e-05  0.262909    0.196703    0.838109    40626.8       




[J74          84508       1.05494e-05  0.259975    0.19746     0.837898    41183         




[J75          85650       1e-05       0.261232    0.200271    0.836625    41739         

[fold 4]
train: 48000, val: 12000
load imagenet pretrained: True
resnext50_32x4d: 2048


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…



epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           1142        4.47764e-05  0.346652    0.342028    0.530774    557.331       




[J2           2284        0.000144331  0.314938    0.280878    0.683142    1115.72       




[J3           3426        0.000294652  0.305398    0.276692    0.679652    1674.11       




[J4           4568        0.00047458  0.30567     0.283264    0.697435    2232.24       




[J5           5710        0.000658789  0.306193    0.268356    0.71198     2790.53       




[J6           6852        0.000821353  0.305622    0.258582    0.753719    3348.64       




[J7           7994        0.000939388  0.304352    0.261859    0.757147    3906.46       




[J8           9136        0.000996281  0.302658    0.260142    0.761167    4464.01       




[J9           10278       0.00099975  0.301073    0.252417    0.761025    5021.44       




[J10          11420       0.000998459  0.300007    0.241439    0.775891    5578.64       




[J11          12562       0.000996073  0.296653    0.251307    0.771188    6135.74       




[J12          13704       0.000992598  0.295308    0.25716     0.770297    6692.43       




[J13          14846       0.000988039  0.295226    0.253359    0.77102     7248.94       




[J14          15988       0.000982409  0.293188    0.245974    0.784353    7805.51       




[J15          17130       0.000975719  0.290681    0.241635    0.795193    8362          




[J16          18272       0.000967984  0.290275    0.235682    0.798541    8918.31       




[J17          19414       0.000959221  0.288662    0.241476    0.787939    9474.62       




[J18          20556       0.000949451  0.288574    0.239721    0.784105    10030.6       




[J19          21698       0.000938693  0.287423    0.233864    0.80277     10586.7       




[J20          22840       0.000926973  0.286372    0.237475    0.787803    11142.6       




[J21          23982       0.000914316  0.283344    0.224577    0.808396    11698.5       




[J22          25124       0.000900751  0.282961    0.230046    0.806988    12254.4       




[J23          26266       0.000886307  0.283171    0.240454    0.801512    12810.3       




[J24          27408       0.000871017  0.281616    0.237085    0.802204    13365.9       




[J25          28550       0.000854915  0.280755    0.230117    0.805874    13921.5       




[J26          29692       0.000838036  0.279931    0.223901    0.810939    14476.9       




[J27          30834       0.000820417  0.279583    0.225341    0.810692    15032.6       




[J28          31976       0.000802098  0.278724    0.229919    0.814951    15588.2       




[J29          33118       0.00078312  0.279981    0.235731    0.813937    16143.7       




[J30          34260       0.000763525  0.278237    0.223708    0.816916    16698.9       




[J31          35402       0.000743356  0.278666    0.222039    0.815087    17254.4       




[J32          36544       0.000722657  0.278024    0.225939    0.819134    17809.7       




[J33          37686       0.000701476  0.275613    0.221552    0.817413    18365.1       




[J34          38828       0.000679858  0.277335    0.224086    0.816129    18920.5       




[J35          39970       0.000657852  0.276226    0.219782    0.817761    19475.9       




[J36          41112       0.000635507  0.274353    0.217821    0.820251    20031.1       




[J37          42254       0.000612873  0.273937    0.223213    0.815061    20586.4       




[J38          43396       0.000589999  0.273152    0.215309    0.82313     21141.6       




[J39          44538       0.000566936  0.272349    0.223356    0.821819    21697.1       




[J40          45680       0.000543736  0.273562    0.225075    0.817181    22252.2       




[J41          46822       0.000520449  0.272518    0.217844    0.823534    22807.2       




[J42          47964       0.000497129  0.271052    0.217132    0.822112    23362.3       




[J43          49106       0.000473826  0.270495    0.2202      0.820487    23917.2       




[J44          50248       0.000450592  0.269157    0.211349    0.825523    24472.1       




[J45          51390       0.000427479  0.270432    0.212336    0.82226     25027.1       




[J46          52532       0.000404537  0.269122    0.218317    0.826676    25582         




[J47          53674       0.000381819  0.268782    0.211508    0.826025    26137.2       




[J48          54816       0.000359375  0.268312    0.209629    0.830326    26692.2       




[J49          55958       0.000337253  0.266951    0.211161    0.826864    27247.5       




[J50          57100       0.000315504  0.267146    0.208509    0.828743    27802.7       




[J51          58242       0.000294176  0.269699    0.215636    0.827554    28358.1       




[J52          59384       0.000273315  0.266927    0.210759    0.835429    28913.2       




[J53          60526       0.000252969  0.265621    0.206527    0.834935    29468.4       




[J54          61668       0.000233183  0.26616     0.208288    0.832179    30023.6       




[J55          62810       0.000213999  0.267208    0.208663    0.832927    30578.9       




[J56          63952       0.000195462  0.263329    0.208642    0.831564    31134.2       




[J57          65094       0.000177611  0.265605    0.210295    0.835831    31689.6       




[J58          66236       0.000160488  0.26366     0.206358    0.838157    32245         




[J59          67378       0.000144129  0.264588    0.209128    0.833533    32800.5       




[J60          68520       0.000128571  0.263073    0.20491     0.837323    33355.8       




[J61          69662       0.000113848  0.261486    0.208689    0.836886    33913.3       




[J62          70804       9.99941e-05  0.262023    0.206997    0.838681    34468.8       




[J63          71946       8.70389e-05  0.263017    0.205807    0.837475    35024.6       




[J64          73088       7.50115e-05  0.262606    0.207366    0.83674     35580.1       




[J65          74230       6.39385e-05  0.261665    0.205934    0.838429    36135.8       




[J66          75372       5.38446e-05  0.259798    0.207107    0.837261    36691.6       




[J67          76514       4.47521e-05  0.262379    0.206897    0.838296    37247.3       




[J68          77656       3.66811e-05  0.261369    0.206567    0.838958    37803.1       




[J69          78798       2.96497e-05  0.260558    0.207717    0.838021    38358.9       




[J70          79940       2.36735e-05  0.261478    0.20522     0.838713    38914.5       




[J71          81082       1.87656e-05  0.260488    0.204852    0.840055    39470.1       




[J72          82224       1.49371e-05  0.261006    0.204244    0.838752    40025.9       




[J73          83366       1.21963e-05  0.260518    0.20535     0.83966     40581.7       




[J74          84508       1.05494e-05  0.260228    0.203502    0.839335    41137.1       




[J75          85650       1e-05       0.260462    0.204979    0.837397    41692.8       


# Inference

## Copy best models

In [17]:
# For every fold: find the epoch with the highest validation metric, copy that
# checkpoint to OUTPUT, delete all checkpoints from the experiment directory,
# and stage the remaining fold artefacts (logs) plus the fold config under
# ./fold{fold_id}. The last expression displays one best-epoch row per fold.
best_log_list = []
for pre_eval_cfg, fold_id in zip(pre_eval_cfg_list, FOLDS):
    exp_dir_path = TMP / f"fold{fold_id}"
    log = pd.read_json(exp_dir_path / "log")
    # idxmax() returns an index *label*, so select it label-based with .loc;
    # .iloc (positional) only coincided because read_json gives a RangeIndex.
    best_log = log.loc[[log["val/metric"].idxmax()]]
    best_epoch = best_log["epoch"].values[0]
    best_log_list.append(best_log)

    best_model_path = exp_dir_path / f"snapshot_by_metric_epoch_{best_epoch}.pth"
    copy_to = OUTPUT / f"best_metric_model_fold{fold_id}.pth"
    shutil.copy(best_model_path, copy_to)

    # The best checkpoint is now in OUTPUT; drop every snapshot so the copy
    # below does not duplicate large .pth files.
    for p in exp_dir_path.glob("*.pth"):
        p.unlink()

    # shutil.copytree refuses to write into an existing directory, so clear
    # any leftover ./fold{fold_id} from a previous run of this cell.
    fold_dir = Path(f"./fold{fold_id}")
    if fold_dir.exists():
        shutil.rmtree(fold_dir)
    shutil.copytree(exp_dir_path, fold_dir)

    with open(fold_dir / "config.yml", "w") as fw:
        yaml.dump(pre_eval_cfg, fw)

pd.concat(best_log_list, axis=0, ignore_index=True)

Unnamed: 0,train/loss,lr,val/loss,val/metric,epoch,iteration,elapsed_time
0,0.262651,3e-05,0.194774,0.845397,69,78798,38562.335654
1,0.259616,1.9e-05,0.208879,0.822903,71,81082,39685.997894
2,0.261473,3e-05,0.197114,0.832644,69,78798,38574.711084
3,0.262909,1.2e-05,0.196703,0.838109,73,83366,40626.809673
4,0.260488,1.9e-05,0.204852,0.840055,71,81082,39470.126813


## Inference OOF & Test

In [18]:
def run_inference_loop(cfg, model, loader, device):
    """Return sigmoid probabilities of ``model`` for every batch in ``loader``.

    Args:
        cfg: Not used by this function; kept so existing call sites still work.
        model: Network whose raw outputs are passed through ``sigmoid``; it is
            moved to ``device`` and switched to eval mode first.
        loader: Iterable of dict batches; only the ``"image"`` entry is read.
        device: ``torch.device`` that the model and each input are moved to.

    Returns:
        numpy array built by concatenating the per-batch probabilities along
        axis 0, in the order the loader yields batches.
    """
    model.to(device)
    model.eval()
    # No autograd graph is needed while predicting.
    with torch.no_grad():
        batch_probs = [
            model(to_device(batch["image"], device)).sigmoid().detach().cpu().numpy()
            for batch in tqdm(loader)
        ]
    return np.concatenate(batch_probs)

In [19]:
# Ground-truth labels for every training row. They are later indexed with
# val_idx, which holds train's index *labels* — assumes train has a default
# 0..n-1 index so labels equal positions (TODO confirm).
label_arr = train[CLASSES].values
# Out-of-fold predictions, filled fold by fold at the rows given by val_idx.
oof_pred_arr = np.zeros((len(train), N_CLASSES))
# Rows of [fold_id, validation ROC-AUC].
score_list = []
# Test predictions of each fold's model, stored at index fold_id — assumes
# FOLDS contains exactly 0..N_FOLDS-1 (verify).
test_pred_arr = np.zeros((N_FOLDS, len(smpl_sub), N_CLASSES))
# Test inputs for lazy_init: .npy files under test/<first char of id>/<id>.npy,
# with the sample-submission target values as labels (presumably placeholders).
test_path_label = {
    "paths": [DATA / f"test/{img_id[0]}/{img_id}.npy" for img_id in smpl_sub["id"].values],
    "labels": smpl_sub[CLASSES].values.astype("f")
}

for fold_id in FOLDS:
    print(f"\n[fold {fold_id}]")
    # Rebuild this fold's config from the YAML stored in ./fold{fold_id}.
    tmp_dir = Path(f"./fold{fold_id}")
    with open(tmp_dir / "config.yml", "r") as fr:
        cfg = Config(yaml.safe_load(fr), types=CONFIG_TYPES)
    device = torch.device(cfg["/globals/device"])
    # Index labels of the training rows held out in this fold.
    val_idx = train.query("fold == @fold_id").index.values

    # # get_dataloader
    # Only the second element returned by get_path_label (the validation split,
    # presumably) is used. The datasets are lazy-initialised before the loaders
    # are read from cfg; the loaders likely depend on them, so keep this order.
    _, val_path_label = get_path_label(cfg, train)
    cfg["/dataset/val"].lazy_init(**val_path_label)
    cfg["/dataset/test"].lazy_init(**test_path_label)
    val_loader = cfg["/loader/val"]
    test_loader = cfg["/loader/test"]
    
    # # get model
    # Load the checkpoint saved as best_metric_model_fold{fold_id}.pth in OUTPUT.
    model_path = OUTPUT / f"best_metric_model_fold{fold_id}.pth"
    model = cfg["/model"]
    model.load_state_dict(torch.load(model_path, map_location=device))
    
    # # inference
    # Writing val_pred into oof_pred_arr[val_idx] assumes val_loader yields
    # samples in the same order as val_idx (no shuffling) — TODO confirm.
    val_pred = run_inference_loop(cfg, model, val_loader, device)
    val_score = roc_auc_score(label_arr[val_idx], val_pred)
    oof_pred_arr[val_idx] = val_pred
    score_list.append([fold_id, val_score])
    
    test_pred_arr[fold_id] = run_inference_loop(cfg, model, test_loader, device)
    
    # Release this fold's objects and cached GPU memory before the next fold.
    del cfg, val_idx, val_path_label
    del model, val_loader, test_loader
    torch.cuda.empty_cache()
    gc.collect()
    
    print(f"val score: {val_score:.4f}")


[fold 0]
load imagenet pretrained: True
resnext50_32x4d: 2048


  0%|          | 0/286 [00:00<?, ?it/s]



  0%|          | 0/953 [00:00<?, ?it/s]



val score: 0.8454

[fold 1]
load imagenet pretrained: True
resnext50_32x4d: 2048


  0%|          | 0/286 [00:00<?, ?it/s]



  0%|          | 0/953 [00:00<?, ?it/s]



val score: 0.8229

[fold 2]
load imagenet pretrained: True
resnext50_32x4d: 2048


  0%|          | 0/286 [00:00<?, ?it/s]



  0%|          | 0/953 [00:00<?, ?it/s]



val score: 0.8326

[fold 3]
load imagenet pretrained: True
resnext50_32x4d: 2048


  0%|          | 0/286 [00:00<?, ?it/s]



  0%|          | 0/953 [00:00<?, ?it/s]



val score: 0.8381

[fold 4]
load imagenet pretrained: True
resnext50_32x4d: 2048


  0%|          | 0/286 [00:00<?, ?it/s]



  0%|          | 0/953 [00:00<?, ?it/s]



val score: 0.8401


In [20]:
# Pooled ROC-AUC over all out-of-fold predictions at once (not the mean of the
# per-fold scores), appended as a summary row and displayed with the folds.
oof_score = roc_auc_score(y_true=label_arr, y_score=oof_pred_arr)
score_list += [["oof", oof_score]]
pd.DataFrame(data=score_list, columns=["fold", "metric"])

Unnamed: 0,fold,metric
0,0,0.845397
1,1,0.822903
2,2,0.832644
3,3,0.838109
4,4,0.840055
5,oof,0.835904


In [21]:
# Save OOF predictions next to the training metadata: a copy of `train` whose
# CLASSES column(s) are overwritten with the out-of-fold sigmoid outputs.
oof_df = train.copy(deep=True)
oof_df[CLASSES] = oof_pred_arr
oof_df.to_csv(path_or_buf=OUTPUT / "oof_prediction.csv", index=False)

## Make submission

In [22]:
# Submission: the sample submission with CLASSES replaced by the test
# predictions averaged over axis 0 (the per-fold axis of test_pred_arr).
sub_df = smpl_sub.copy(deep=True)
sub_df[CLASSES] = np.mean(test_pred_arr, axis=0)
sub_df.to_csv(path_or_buf=OUTPUT / "submission.csv", index=False)

In [23]:
# Preview the first rows of the submission.
sub_df.head(n=5)

Unnamed: 0,id,target
0,000bf832cae9ff1,0.070979
1,000c74cc71a1140,0.179661
2,000f5f9851161d3,0.069309
3,000f7499e95aba6,0.111574
4,00133ce6ec257f9,0.088089


# EOF