# About

I'd like to share an example for [small models](https://www.kaggle.com/c/seti-breakthrough-listen/discussion/242644).

### Experimental Settings

### model
* backbone: resnet18d (use the pretrained model provided by [timm](https://github.com/rwightman/pytorch-image-models))
* head classifier: one linear layer
* num of input channels: **1**

### data augmentation
* implemented by [albumentations](https://albumentations.ai/docs/) **except for Mixup**
* Train
  * Resize
  * HorizontalFlip
  * VerticalFlip
  * ShiftScaleRotate
  * RandomResizedCrop
  * Mixup(alpha=1.0)
* Val, Test
  * Resize

### learning settings
* CV Strategy: Stratified KFold (K=5)
* max epochs: 18
* data:
  * input image size: 1x320x320
  * batch size: 64
* loss: [BCEWithLogitsLoss](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html#torch.nn.BCEWithLogitsLoss)
* optimizer: [AdamW](https://pytorch.org/docs/stable/optim.html#torch.optim.AdamW)
  * weight decay: 1.0e-04
* learning rate scheduler: [OneCycleLR](https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.OneCycleLR) 
  * epochs: 18
  * max_lr: 1e-3
  * pct_start: 0.111
  * anneal_strategy: cos
  * div_factor: 1.0e+2
  * final_div_factor: 1
  
### NOTE: I use only on-target ('A') observations

```python
img = np.load(path)[[0, 2, 4]]          # shape: (3, 273, 256)
img = np.vstack(img)                    # shape: (819, 256)
img = img.transpose(1, 0)               # shape: (256, 819)
```

### Submission -> 

# Prepare

## Install

## Import

In [1]:
import os
import gc
import copy
import yaml
import random
import shutil
import typing as tp
from pathlib import Path

import numpy as np
import pandas as pd

from tqdm.notebook import tqdm
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

import torch
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
from torch.cuda import amp

import timm

import albumentations as A
from albumentations.pytorch import ToTensorV2

import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.config import Config
from pytorch_pfn_extras.training import extensions as ppe_exts, triggers as ppe_triggers

In [2]:
# timm.list_models(pretrained=True)

In [3]:
# Directory layout: competition data lives two levels above the notebook.
ROOT = Path.cwd()
INPUT = ROOT / "../../"
OUTPUT = ROOT / "output"
DATA = INPUT / "seti-breakthrough-listen"
TRAIN, TEST = DATA / "train", DATA / "test"

# Scratch directory; each fold writes its outputs to TMP / "fold{k}".
TMP = ROOT / "tmp"
TMP.mkdir(exist_ok=True)

# Experiment constants.
RANDAM_SEED = 1086
CLASSES = ["target"]
N_CLASSES = len(CLASSES)
FOLDS = list(range(5))
N_FOLDS = len(FOLDS)

## Read Data, Split folds

In [4]:
# Labelled training ids (the code below uses its "id" and "target" columns)
# and the submission template; read in that order.
train, smpl_sub = (
    pd.read_csv(DATA / csv_name)
    for csv_name in ("train_labels.csv", "sample_submission.csv")
)

In [5]:
# Stratified K-fold: store, for every row, the fold in which it is used for
# validation (rows start at -1 and each is assigned exactly once).
skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RANDAM_SEED)
train["fold"] = -1
fold_col = train.columns.get_loc("fold")
for fold_id, (_, val_idx) in enumerate(skf.split(train["id"], train["target"])):
    # split() yields positional indices, so assign by position (iloc) rather
    # than by label (loc); the two only coincide for a default RangeIndex.
    train.iloc[val_idx, fold_col] = fold_id

In [6]:
# Sanity check of the split: number of samples and of positives per fold.
# String aggregation names avoid pandas' deprecation of builtin callables.
train.groupby("fold").agg(total=("id", "size"), pos=("target", "sum"))

Unnamed: 0_level_0,total,pos
fold,Unnamed: 1_level_1,Unnamed: 2_level_1
0,12000,1200
1,12000,1200
2,12000,1200
3,12000,1200
4,12000,1200


## Definition of Model, Dataset, Metric

### Model

In [7]:
class BasicImageModel(nn.Module):
    """Image classifier: a timm backbone followed by an MLP head.

    The backbone is created by ``timm.create_model(base_name, num_classes=0, ...)``
    and its ``num_features``-dimensional output is fed to ``head_cls``.

    Parameters
    ----------
    base_name : str
        Architecture name; must be an attribute of ``timm.models``.
    dims_head : list of int
        Widths of the head from input to output, e.g. ``[None, 1]`` for a single
        linear layer. A leading ``None`` is replaced by the backbone's feature
        size. Every layer except the last is followed by ReLU and Dropout(0.5).
        The list passed in is not modified.
    pretrained : bool
        Forwarded to ``timm.create_model``.
    in_channels : int
        Number of input channels, forwarded to timm as ``in_chans``.

    Raises
    ------
    NotImplementedError
        If ``base_name`` is not found in ``timm.models``.
    """

    def __init__(
        self, base_name: str, dims_head: tp.List[int],
        pretrained=False, in_channels: int=3
    ):
        """Initialize"""
        # nn.Module.__init__ must run before any attribute is set on the module.
        super(BasicImageModel, self).__init__()
        self.base_name = base_name

        # # prepare backbone
        if not hasattr(timm.models, base_name):
            raise NotImplementedError(f"unknown timm model: {base_name}")
        base_model = timm.create_model(
            base_name, num_classes=0, pretrained=pretrained, in_chans=in_channels)
        in_features = base_model.num_features
        print("load imagenet pretrained:", pretrained)

        self.backbone = base_model
        print(f"{base_name}: {in_features}")

        # # prepare head classifier
        # Work on a copy so the caller's list (e.g. the config's) is not mutated.
        dims_head = list(dims_head)
        if dims_head[0] is None:
            dims_head[0] = in_features

        layers_list = []
        # Hidden layers: consecutive (in, out) pairs, excluding the final layer.
        for in_dim, out_dim in zip(dims_head[:-2], dims_head[1:-1]):
            layers_list.extend([
                nn.Linear(in_dim, out_dim),
                nn.ReLU(), nn.Dropout(0.5),
            ])
        layers_list.append(nn.Linear(dims_head[-2], dims_head[-1]))
        self.head_cls = nn.Sequential(*layers_list)

    def forward(self, x):
        """Return the head's output for input batch ``x``."""
        h = self.backbone(x)
        h = self.head_cls(h)
        return h

### Dataset

In [8]:
# Type aliases for dataset inputs.
FilePath = tp.Union[str, Path]
Label = tp.Union[int, float, np.ndarray]


class SetiSimpleDataset(torch.utils.data.Dataset):
    """
    Single-channel image dataset built from whole cadence snippets.

    All six observations of a snippet are stacked along the time axis, the
    result is transposed into a (256, 1638, 1) float32 image, and the
    albumentations transform is then applied.

    Attributes
    ----------
    paths : tp.Sequence[FilePath]
        Paths of the ``.npy`` cadence snippet files.
    labels : tp.Sequence[Label]
        Label of each snippet, aligned with ``paths``.
    transform : albumentations.Compose
        Augmentation pipeline, called as ``transform(image=...)``.
    """

    def __init__(
        self,
        paths: tp.Sequence[FilePath],
        labels: tp.Sequence[Label],
        transform: A.Compose,
    ):
        """Store paths, labels and transform (each may later be replaced by ``lazy_init``)."""
        self.paths, self.labels, self.transform = paths, labels, transform

    def __len__(self):
        """Number of snippets in the dataset."""
        return len(self.paths)

    def __getitem__(self, index: int):
        """Load, reshape and augment snippet ``index``; return it with its label."""
        path = self.paths[index]
        label = self.labels[index]
        sample = self._read_cadence_array(path)
        sample = self.transform(image=sample)["image"]
        return {"image": sample, "target": label}

    def _read_cadence_array(self, path: Path):
        """Load a cadence file and turn it into an (H, W, 1) float32 image."""
        cadence = np.load(path)                           # (6, 273, 256)
        stacked = np.concatenate(cadence, axis=0)         # (1638, 256)
        return stacked.T.astype("f")[..., np.newaxis]     # (256, 1638, 1)

    def lazy_init(self, paths=None, labels=None, transform=None):
        """Replace whichever of paths / labels / transform is not ``None``."""
        updates = (("paths", paths), ("labels", labels), ("transform", transform))
        for attr_name, new_value in updates:
            if new_value is not None:
                setattr(self, attr_name, new_value)


class SetiAObsDataset(SetiSimpleDataset):
    """Use only the on-target ('A') observations of each cadence.

    Identical to ``SetiSimpleDataset`` except that only the cadence slots listed
    in ``A_OBS_INDICES`` are stacked along the time axis, giving a
    (256, 819, 1) float32 image.
    """

    # Positions of the on-target observations inside a cadence file.
    A_OBS_INDICES = (0, 2, 4)

    def _read_cadence_array(self, path: Path):
        """Read the on-target observations of a cadence file and reshape them."""
        img = np.load(path)[list(self.A_OBS_INDICES)]  # shape: (3, 273, 256)
        img = np.vstack(img)  # shape: (819, 256)
        img = img.transpose(1, 0)  # shape: (256, 819)
        img = img.astype("f")[..., np.newaxis]  # shape: (256, 819, 1)
        return img

### Metric

In [9]:
# A batch is either a positional tuple of tensors (any length) or a dict of tensors.
Batch = tp.Union[tp.Tuple[torch.Tensor, ...], tp.Dict[str, torch.Tensor]]
# A model may return a tuple of tensors, a dict of tensors, or a single tensor.
ModelOut = tp.Union[tp.Tuple[torch.Tensor, ...], tp.Dict[str, torch.Tensor], torch.Tensor]


class ROCAUC(nn.Module):
    """ROC AUC score computed with ``sklearn.metrics.roc_auc_score``.

    Torch tensors are detached and moved to CPU before scoring.
    """

    def __init__(self, average="macro") -> None:
        """Initialize.

        Parameters
        ----------
        average : str
            Passed through to ``roc_auc_score(average=...)``.
        """
        # nn.Module.__init__ must run before any attribute is set on the module.
        super(ROCAUC, self).__init__()
        self.average = average

    def forward(self, y, t) -> float:
        """Return the ROC AUC of scores ``y`` against labels ``t``.

        Note the argument order: predictions first, labels second (the reverse
        of ``roc_auc_score(y_true, y_score)``).
        """
        if isinstance(y, torch.Tensor):
            y = y.detach().cpu().numpy()
        if isinstance(t, torch.Tensor):
            t = t.detach().cpu().numpy()

        return roc_auc_score(t, y, average=self.average)


def micro_average(
    metric_func: nn.Module,
    report_name: str, prefix="val",
    pred_index: int=-1, label_index: int=-1,
    pred_key: str="logit", label_key: str="target",
) -> tp.Callable:
    """
    Return a metric wrapper reporting the example-weighted mean of a per-batch metric.

    ``metric_func(y, t)`` is evaluated on every batch and weighted by the batch
    size. On the last batch, the weighted mean over all examples is reported to
    ``ppe.reporting`` under ``f"{prefix}/{report_name}"``. The running state is
    then reset so the wrapper can serve the next evaluation.

    Parameters
    ----------
    metric_func : callable
        ``metric_func(pred, label)`` returning a scalar (0-dim tensor or number).
    report_name, prefix : str
        The reported key is ``f"{prefix}/{report_name}"``.
    pred_index, label_index : int
        Positions used when the model output / batch is a tuple or list.
    pred_key, label_key : str
        Keys used when the model output / batch is a dict.
    """
    metric_sum = 0.
    n_examples = 0

    def wrapper(batch: Batch, model_output: ModelOut, is_last_batch: bool):
        """Accumulate one batch; report and reset on the last batch."""
        nonlocal metric_sum, n_examples

        # DataLoader's default collate turns tuple samples into lists, so accept both.
        if isinstance(batch, (tuple, list)):
            t = batch[label_index]
        elif isinstance(batch, dict):
            t = batch[label_key]
        else:
            raise NotImplementedError(f"unsupported batch type: {type(batch)}")

        if isinstance(model_output, (tuple, list)):
            y = model_output[pred_index]
        elif isinstance(model_output, dict):
            y = model_output[pred_key]
        else:
            y = model_output

        # float() accepts both one-element tensors and plain numbers.
        metric = float(metric_func(y, t))
        metric_sum += metric * y.shape[0]
        n_examples += y.shape[0]

        if is_last_batch:
            final_metric = metric_sum / n_examples
            ppe.reporting.report({f"{prefix}/{report_name}": final_metric})
            # # reset state
            metric_sum = 0.
            n_examples = 0

    return wrapper


def calc_across_all_batchs(
    metric_func: nn.Module,
    report_name: str, prefix="val",
    pred_index: int=-1, label_index: int=-1,
    pred_key: str="logit", label_key: str="target",
) -> tp.Callable:
    """
    Return a metric wrapper for metrics that must be computed on all data at once.

    Predictions and labels of every batch are stored. On the last batch they
    are concatenated, ``metric_func(pred, label)`` is computed on the whole set,
    and the result is reported to ``ppe.reporting`` under
    ``f"{prefix}/{report_name}"``. The stored batches are then cleared.

    Parameters
    ----------
    metric_func : callable
        ``metric_func(pred, label)`` taking NumPy arrays.
    report_name, prefix : str
        The reported key is ``f"{prefix}/{report_name}"``.
    pred_index, label_index : int
        Positions used when the model output / batch is a tuple or list.
    pred_key, label_key : str
        Keys used when the model output / batch is a dict.
    """
    pred_list = []
    label_list = []

    def _to_numpy(x):
        """Convert a tensor (possibly on GPU or requiring grad) or array-like to NumPy."""
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def wrapper(batch: Batch, model_output: ModelOut, is_last_batch: bool):
        """Store one batch; compute and report the metric on the last batch."""
        # DataLoader's default collate turns tuple samples into lists, so accept both.
        if isinstance(batch, (tuple, list)):
            t = batch[label_index]
        elif isinstance(batch, dict):
            t = batch[label_key]
        else:
            raise NotImplementedError(f"unsupported batch type: {type(batch)}")

        if isinstance(model_output, (tuple, list)):
            y = model_output[pred_index]
        elif isinstance(model_output, dict):
            y = model_output[pred_key]
        else:
            y = model_output

        pred_list.append(_to_numpy(y))
        label_list.append(_to_numpy(t))

        if is_last_batch:
            pred = np.concatenate(pred_list, axis=0)
            label = np.concatenate(label_list, axis=0)
            final_metric = metric_func(pred, label)
            ppe.reporting.report({f"{prefix}/{report_name}": final_metric})
            # # reset state
            pred_list.clear()
            label_list.clear()

    return wrapper

# Train

## config_types for evaluating configuration

I use [pytorch-pfn-extras](https://github.com/pfnet/pytorch-pfn-extras) for training NNs. This library has useful config systems but requires some preparation.

For more details, see [docs](https://github.com/pfnet/pytorch-pfn-extras/blob/master/docs/config.md).

In [10]:
# Maps each ``type:`` name used in the YAML configuration to the callable that
# pytorch-pfn-extras' ``Config`` invokes for it. Insertion order is kept the same
# as the groups below.
CONFIG_TYPES = dict(
    # # utils. The YAML passes ``obj`` by name (see steps_per_epoch), which the
    # builtin ``len`` would not accept, hence the lambda.
    __len__=lambda obj: len(obj),
    method_call=lambda obj, method: getattr(obj, method)(),

    # # Dataset, DataLoader
    SetiSimpleDataset=SetiSimpleDataset,
    SetiAObsDataset=SetiAObsDataset,
    DataLoader=torch.utils.data.DataLoader,

    # # Data augmentation: albumentations classes registered under their own names
    **{aug_name: getattr(A, aug_name) for aug_name in (
        "Compose", "OneOf", "Resize", "HorizontalFlip", "VerticalFlip",
        "ShiftScaleRotate", "RandomResizedCrop", "Cutout",
    )},
    ToTensorV2=ToTensorV2,

    # # Model
    BasicImageModel=BasicImageModel,

    # # Optimizer and scheduler
    AdamW=optim.AdamW,
    OneCycleLR=lr_scheduler.OneCycleLR,

    # # Loss, metric
    BCEWithLogitsLoss=nn.BCEWithLogitsLoss,
    ROCAUC=ROCAUC,

    # # Metric wrappers
    micro_average=micro_average,
    calc_across_all_batchs=calc_across_all_batchs,

    # # PPE: manager, extensions and triggers registered under their own names
    ExtensionsManager=ppe.training.ExtensionsManager,
    **{ext_name: getattr(ppe_exts, ext_name) for ext_name in (
        "observe_lr", "LogReport", "PlotReport", "PrintReport",
        "PrintReportNotebook", "ProgressBar", "ProgressBarNotebook",
        "snapshot", "LRScheduler",
    )},
    **{trig_name: getattr(ppe_triggers, trig_name) for trig_name in (
        "MinValueTrigger", "MaxValueTrigger", "EarlyStoppingTrigger",
    )},
)

## configuration

In [11]:
# Base configuration shared by all folds, parsed with yaml.safe_load.
# - "type:" entries name keys of CONFIG_TYPES.
# - "@/..." strings refer to other entries of this same config.
# Both are resolved when the dict is wrapped in pytorch-pfn-extras' Config.
# globals.val_fold and globals.output_path are null here; they are filled in on
# a per-fold deep copy before the config is evaluated.
# NOTE(review): globals.device is hard-coded to cuda:1. The settings here differ
# from those quoted in the prose at the top of the notebook:
#   - backbone: efficientnet_b0 here vs resnet18d in the prose
#   - max_epoch: 100 here vs 18 in the prose
#   - batch_size: 48 here vs 64 in the prose
# Confirm which is intended.
pre_eval_cfg = yaml.safe_load(
"""
globals:
  seed: 1086
  val_fold: null  # indicate when training
  output_path: null # indicate when training
  device: cuda:1
  enable_amp: False
  max_epoch: 100

model:
  type: BasicImageModel
  dims_head: [null, 1]
  base_name: efficientnet_b0
  pretrained: True
  in_channels: 1

dataset:
  height: 320
  width: 320
  mixup: {enabled: True, alpha: 1.0}
  train:
    type: SetiAObsDataset
    paths: null  # set by lazy_init
    labels: null  # set by lazy_init
    transform:
      type: Compose
      transforms:
        - {type: Resize, p: 1.0, height: "@/dataset/height", width: "@/dataset/width"}
        - {type: HorizontalFlip, p: 0.5}
        - {type: VerticalFlip, p: 0.5}
        - {type: ShiftScaleRotate, p: 0.5, shift_limit: 0.2, scale_limit: 0.2,
            rotate_limit: 20, border_mode: 0, value: 0, mask_value: 0}
        - {type: RandomResizedCrop, p: 1.0,
            scale: [0.9, 1.0], height: "@/dataset/height", width: "@/dataset/width"}
        - {type: ToTensorV2, always_apply: True}
  val:
    type: SetiAObsDataset
    paths: null  # set by lazy_init
    labels: null  # set by lazy_init
    transform:
      type: Compose
      transforms:
        - {type: Resize, p: 1.0, height: "@/dataset/height", width: "@/dataset/width"}
        - {type: ToTensorV2, always_apply: True}  
  test:
    type: SetiAObsDataset
    paths: null  # set by lazy_init
    labels: null  # set by lazy_init
    transform: "@/dataset/val/transform"

loader:
  train: {type: DataLoader, dataset: "@/dataset/train",
    batch_size: 48, num_workers: 4, shuffle: True, pin_memory: True, drop_last: True}
  val: {type: DataLoader, dataset: "@/dataset/val",
    batch_size: 48, num_workers: 4, shuffle: False, pin_memory: True, drop_last: False}
  test: {type: DataLoader, dataset: "@/dataset/test",
    batch_size: 48, num_workers: 4, shuffle: False, pin_memory: True, drop_last: False}

optimizer:
  type: AdamW
  params: {type: method_call, obj: "@/model", method: parameters}
  lr: 1.0e-05
  weight_decay: 1.0e-04

scheduler:
  type: OneCycleLR
  optimizer: "@/optimizer"
  epochs: "@/globals/max_epoch"
  steps_per_epoch: {type: __len__, obj: "@/loader/train"}
  max_lr: 1.0e-3
  pct_start: 0.111
  anneal_strategy: cos
  div_factor: 1.0e+2
  final_div_factor: 1

loss: {type: BCEWithLogitsLoss}

eval:
  - type: micro_average
    metric_func: {type: BCEWithLogitsLoss}
    report_name: loss
  - type: calc_across_all_batchs
    metric_func: {type: ROCAUC}
    report_name: metric

manager:
  type: ExtensionsManager
  models: "@/model"
  optimizers: "@/optimizer"
  max_epochs: "@/globals/max_epoch"
  iters_per_epoch: {type: __len__, obj: "@/loader/train"}
  out_dir: "@/globals/output_path"
  # stop_trgiger: {type: EarlyStoppingTrigger,
  #   monitor: val/metric, mode: max, patience: 5, verbose: True,
  #   check_trigger: [1, epoch], max_trigger: ["@/globals/max_epoch", epoch]}

extensions:
  # # log
  - {type: observe_lr, optimizer: "@/optimizer"}
  - {type: LogReport}
  - {type: PlotReport, y_keys: lr, x_key: epoch, filename: lr.png}
  - {type: PlotReport, y_keys: [train/loss, val/loss], x_key: epoch, filename: loss.png}
  - {type: PlotReport, y_keys: val/metric, x_key: epoch, filename: metric.png}
  - {type: PrintReport, entries: [
      epoch, iteration, lr, train/loss, val/loss, val/metric, elapsed_time]}
  - {type: ProgressBarNotebook, update_interval: 20}
  # snapshot
  - extension: {type: snapshot, target: "@/model", filename: "snapshot_by_metric_epoch_{.epoch}.pth"}
    trigger: {type: MaxValueTrigger, key: "val/metric", trigger: [1, epoch]}
  # # lr scheduler
  - {type: LRScheduler, scheduler: "@/scheduler", trigger: [1,  iteration]}
"""
)

## functions for training

In [12]:
def set_random_seed(seed: int = 42, deterministic: bool = False):
    """Seed Python, NumPy and PyTorch RNGs and set cuDNN's deterministic flag.

    Exporting ``PYTHONHASHSEED`` only affects interpreters started after this
    call; it cannot change hash randomization of the running process.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)  # type: ignore
    torch.backends.cudnn.deterministic = deterministic  # type: ignore


def to_device(
    tensors: tp.Union[tp.Tuple[torch.Tensor], tp.Dict[str, torch.Tensor]],
    device: torch.device, *args, **kwargs
):
    """Move a tensor, or a (nested) tuple/list/dict of tensors, to ``device``.

    Extra ``*args``/``**kwargs`` are forwarded to ``Tensor.to`` (e.g. ``dtype``,
    ``non_blocking``). The container type is preserved: a tuple returns a
    tuple, a list returns a list (DataLoader's default collate produces lists
    from tuple samples), and a dict returns a dict with the same keys.
    """
    if isinstance(tensors, tuple):
        # Materialize a tuple; a bare generator cannot be indexed or reused.
        return tuple(to_device(t, device, *args, **kwargs) for t in tensors)
    if isinstance(tensors, list):
        return [to_device(t, device, *args, **kwargs) for t in tensors]
    if isinstance(tensors, dict):
        return {
            k: to_device(t, device, *args, **kwargs) for k, t in tensors.items()}
    return tensors.to(device, *args, **kwargs)

In [13]:
def get_path_label(cfg: Config, train_all: pd.DataFrame):
    """Split ``train_all`` into train / validation path-label dicts.

    Rows whose ``fold`` equals ``cfg["/globals/val_fold"]`` form the validation
    part; all other rows form the training part. Each part is returned as
    ``{"paths": [...], "labels": float32 array of CLASSES columns}``. File
    paths follow the layout ``TRAIN/<first char of id>/<id>.npy``.
    """
    is_val = train_all["fold"] == cfg["/globals/val_fold"]

    def _as_path_label(df):
        """Build the dataset keyword arguments for one subset."""
        return {
            "paths": [TRAIN / f"{img_id[0]}/{img_id}.npy" for img_id in df["id"].values],
            "labels": df[CLASSES].values.astype("f"),
        }

    return _as_path_label(train_all[~is_val]), _as_path_label(train_all[is_val])


def get_eval_func(cfg, model, device):
    """Return the per-batch callback handed to the ppe ``Evaluator``.

    The callback receives the collated batch as keyword arguments. It moves the
    batch to ``device`` and runs ``model`` on ``batch["image"]``, using autocast
    according to ``/globals/enable_amp``. It returns the detached output as a
    float32 CPU tensor for the metric wrappers.
    """

    def eval_func(**batch):
        """Run evaluation for val or test. This function is applied to each batch."""
        on_device = to_device(batch, device)
        images = on_device["image"]
        with amp.autocast(cfg["/globals/enable_amp"]):
            logits = model(images)
        # Metrics are computed on CPU in float32, whatever precision AMP used.
        return logits.detach().cpu().to(torch.float32)

    return eval_func


def mixup_data(use_mixup, x, t, alpha=1.0, use_cuda=True, device="cuda:0"):
    """Apply mixup to a batch.

    Returns ``(mixed_x, t_a, t_b, lam)`` with
    ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``t_a = t`` and
    ``t_b = t[perm]``, where ``perm`` is a random permutation of the batch.
    ``lam`` is drawn from Beta(alpha, alpha), or is 1 when ``alpha <= 0``.
    When ``use_mixup`` is false, returns ``(x, t, None, None)`` unchanged.

    ``use_cuda`` and ``device`` are accepted for backward compatibility only.
    The permutation is created on ``x.device``, so mixup also works on
    CPU-only machines and never indexes across devices. ``x`` and ``t`` are
    expected to be tensors.
    """
    if not use_mixup:
        return x, t, None, None

    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    # Index the targets with the permutation moved to their own device.
    t_a, t_b = t, t[index.to(t.device)]
    return mixed_x, t_a, t_b, lam


def get_criterion(use_mixup, loss_func):
    """Return a loss callable with the uniform signature ``(pred, t_a, t_b, lam)``.

    With mixup, the loss is ``lam * loss(pred, t_a) + (1 - lam) * loss(pred, t_b)``.
    Without mixup, only ``loss(pred, t_a)`` is used; ``t_b`` and ``lam`` are
    accepted only so both variants can be called the same way.
    """
    if not use_mixup:
        return lambda pred, t_a, t_b, lam: loss_func(pred, t_a)

    def mixup_criterion(pred, t_a, t_b, lam):
        loss_a = loss_func(pred, t_a)
        loss_b = loss_func(pred, t_b)
        return lam * loss_a + (1 - lam) * loss_b

    return mixup_criterion

In [14]:
def _build_manager(cfg, model, val_loader, device):
    """Create the ExtensionsManager from cfg, attach its extensions and a per-epoch evaluator."""
    manager = cfg["/manager"]
    for ext in cfg["/extensions"]:
        # Entries written as {extension: ..., trigger: ...} become keyword
        # arguments; all other entries are extension objects themselves.
        if isinstance(ext, dict):
            manager.extend(**ext)
        else:
            manager.extend(ext)

    evaluator = ppe_exts.Evaluator(
        val_loader, model, eval_func=get_eval_func(cfg, model, device),
        metrics=cfg["/eval"], progress_bar=False)
    manager.extend(evaluator, trigger=(1, "epoch"))
    return manager


def train_one_fold(cfg, train_all):
    """Train the model on all folds except ``/globals/val_fold`` and validate on it.

    Datasets, loaders, model, optimizer, loss and the extensions manager are
    built from ``cfg`` (a pytorch-pfn-extras ``Config``). Training runs until
    the manager's stop trigger fires, optionally with mixup and AMP.

    Parameters
    ----------
    cfg : Config
        Evaluated configuration for one fold.
    train_all : pd.DataFrame
        Labelled table with ``id``, the CLASSES columns and a ``fold`` column.
    """
    torch.backends.cudnn.benchmark = True
    # NOTE(review): benchmark=True lets cuDNN choose algorithms by timing, which
    # may work against the determinism requested below; confirm this is intended.
    set_random_seed(cfg["/globals/seed"], deterministic=True)
    device = torch.device(cfg["/globals/device"])

    train_path_label, val_path_label = get_path_label(cfg, train_all)
    print("train: {}, val: {}".format(len(train_path_label["paths"]), len(val_path_label["paths"])))

    # The datasets must be populated before the loaders are built. The loaders,
    # and the scheduler/manager that use len(train loader), depend on them.
    cfg["/dataset/train"].lazy_init(**train_path_label)
    cfg["/dataset/val"].lazy_init(**val_path_label)
    train_loader = cfg["/loader/train"]
    val_loader = cfg["/loader/val"]

    model = cfg["/model"]
    model.to(device)
    optimizer = cfg["/optimizer"]
    loss_func = cfg["/loss"]
    loss_func.to(device)

    manager = _build_manager(cfg, model, val_loader, device)

    use_amp = cfg["/globals/enable_amp"]
    scaler = amp.GradScaler(enabled=use_amp)
    use_mixup = cfg["/dataset/mixup/enabled"]
    mixup_alpha = cfg["/dataset/mixup/alpha"]
    # The criterion depends only on use_mixup: build it once, not per batch.
    criterion = get_criterion(use_mixup, loss_func)

    while not manager.stop_trigger:
        model.train()
        for batch in train_loader:
            with manager.run_iteration():
                batch = to_device(batch, device)
                x, t = batch["image"], batch["target"]
                # # for mixup
                mixed_x, t_a, t_b, lam = mixup_data(
                    use_mixup, x, t, mixup_alpha, device=cfg["/globals/device"])

                optimizer.zero_grad()
                with amp.autocast(use_amp):
                    y = model(mixed_x)
                    loss = criterion(y, t_a, t_b, lam)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                ppe.reporting.report({'train/loss': loss.item()})

## run train

In [15]:
# One independent (deep-copied) config dict per fold. The copies differ only in
# the validation fold and the output directory.
pre_eval_cfg_list = []
for fold_id in FOLDS:
    tmp_cfg = copy.deepcopy(pre_eval_cfg)
    tmp_cfg["globals"].update(
        val_fold=fold_id,
        output_path=str(TMP / f"fold{fold_id}"),
    )
    pre_eval_cfg_list.append(tmp_cfg)

In [None]:
# Train every fold with its own freshly evaluated Config. The loop variable has
# its own name so that the module-level base config ``pre_eval_cfg`` is not
# rebound to the last fold's dict.
for fold_pre_eval_cfg in pre_eval_cfg_list:
    cfg = Config(fold_pre_eval_cfg, types=CONFIG_TYPES)
    print(f"\n[fold {cfg['/globals/val_fold']}]")
    train_one_fold(cfg, train)
    # Release cached GPU memory on the training device before the next fold.
    with torch.cuda.device(cfg["/globals/device"]):
        torch.cuda.empty_cache()
    del cfg
    gc.collect()


[fold 0]
train: 48000, val: 12000
load imagenet pretrained: True
efficientnet_b0: 1280


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…

epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           1000        2.96581e-05  0.38121     0.327062    0.523433    228.9         




[J2           2000        8.71463e-05  0.327156    0.323834    0.556312    456.856       




[J3           3000        0.000177889  0.320226    0.276808    0.720033    684.924       




[J4           4000        0.000294665  0.305462    0.2679      0.760171    912.974       




[J5           5000        0.000428181  0.300172    0.240266    0.787446    1140.78       




[J6           6000        0.00056781  0.295007    0.244662    0.770569    1368.68       




[J7           7000        0.00070244  0.29453     0.244512    0.765695    1596.56       




[J8           8000        0.000821358  0.293176    0.237761    0.795275    1824.49       




[J9           9000        0.000915097  0.29156     0.245096    0.794974    2052.58       




[J10          10000       0.0009762   0.291494    0.243507    0.792971    2280.56       




[J11          11000       0.000999802  0.290327    0.239949    0.795811    2508.53       




[J12          12000       0.00099975  0.28921     0.250237    0.796008    2736.46       




[J13          13000       0.000998885  0.287586    0.236807    0.79974     2964.44       




[J14          14000       0.000997403  0.286611    0.236581    0.805622    3192.52       




[J15          15000       0.000995306  0.285857    0.22755     0.785099    3420.88       




[J16          16000       0.000992598  0.284206    0.226907    0.807456    3649.02       




[J17          17000       0.00098928  0.283764    0.228773    0.813927    3876.89       




[J18          18000       0.000985357  0.283272    0.223537    0.81698     4105.01       




[J19          19000       0.000980835  0.282396    0.220484    0.82351     4333.07       




[J20          20000       0.000975719  0.282346    0.22461     0.819397    4561.24       




[J21          21000       0.000970015  0.280417    0.222578    0.829078    4789.35       




[J22          22000       0.00096373  0.281715    0.222999    0.822648    5017.29       




[J23          23000       0.000956872  0.277793    0.216296    0.829534    5245.4        




[J24          24000       0.000949451  0.277687    0.218132    0.829675    5473.45       




[J25          25000       0.000941474  0.276168    0.218226    0.815757    5701.44       




[J26          26000       0.000932952  0.277558    0.21672     0.833642    5929.34       




[J27          27000       0.000923896  0.275308    0.215786    0.82925     6157.29       




[J28          28000       0.000914316  0.276667    0.214682    0.827361    6385.29       




[J29          29000       0.000904226  0.276456    0.218167    0.830275    6613.1        




[J30          30000       0.000893637  0.27603     0.214814    0.831435    6841.2        




[J31          31000       0.000882563  0.274902    0.210724    0.836998    7069.2        




[J32          32000       0.000871017  0.273925    0.213799    0.829698    7297.18       




[J33          33000       0.000859015  0.273633    0.213273    0.824656    7525.32       




[J34          34000       0.00084657  0.273124    0.207918    0.833883    7753.21       




[J35          35000       0.000833699  0.271288    0.209666    0.84041     7981.33       




[J36          36000       0.000820417  0.27206     0.208413    0.840565    8209.72       




[J37          37000       0.000806742  0.271331    0.208923    0.84265     8437.82       




[J38          38000       0.000792689  0.270638    0.208167    0.838087    8665.96       




[J39          39000       0.000778278  0.270778    0.207133    0.840803    8894.27       




[J40          40000       0.000763525  0.269998    0.205762    0.842016    9122.6        




[J41          41000       0.000748449  0.269744    0.210575    0.837923    9350.96       




[J42          42000       0.00073307  0.267662    0.205462    0.845081    9579.02       




[J43          43000       0.000717405  0.267797    0.212472    0.844009    9807.35       




[J44          44000       0.000701476  0.267527    0.195292    0.848638    10035.5       




[J45          45000       0.000685301  0.267175    0.202404    0.845474    10263.8       




[J46          46000       0.000668901  0.266203    0.200177    0.843717    10492.2       




[J47          47000       0.000652296  0.266163    0.199095    0.847065    10720.5       




[J48          48000       0.000635507  0.265908    0.201452    0.849269    10948.8       




[J49          49000       0.000618556  0.266504    0.208434    0.845644    11177.1       




[J50          50000       0.000601462  0.265517    0.199649    0.844655    11405.8       




[J51          51000       0.000584249  0.26564     0.198568    0.846213    11634.2       




[J52          52000       0.000566936  0.265269    0.203025    0.845945    11862.6       




[J53          53000       0.000549546  0.264426    0.202694    0.846966    12091.1       




[J54          54000       0.0005321   0.26421     0.20105     0.842373    12319.5       




[J55          55000       0.00051462  0.264165    0.19858     0.849677    12547.8       




[J56          56000       0.000497129  0.263528    0.200007    0.848992    12776.2       




[J57          57000       0.000479647  0.263219    0.199112    0.852572    13004.3       




[J58          58000       0.000462197  0.264705    0.200578    0.852352    13232.8       




[J59          59000       0.0004448   0.261237    0.201638    0.853656    13461         




[J60          60000       0.000427479  0.262501    0.193597    0.85089     13689.5       




[J61          61000       0.000410254  0.259924    0.192172    0.854661    13917.9       




[J62          62000       0.000393147  0.262236    0.191583    0.85082     14146.2       




[J63          63000       0.000376181  0.260228    0.195587    0.851216    14374.3       




[J64          64000       0.000359375  0.258325    0.200153    0.854405    14602.4       




[J65          65000       0.000342751  0.260507    0.199135    0.851898    14830.5       




[J66          66000       0.000326329  0.259427    0.195154    0.852771    15058.7       




[J67          67000       0.000310131  0.259443    0.200118    0.851389    15287.3       




[J68          68000       0.000294176  0.259403    0.196514    0.851803    15515.6       




[J69          69000       0.000278484  0.257892    0.198194    0.850693    15743.6       




[J70          70000       0.000263075  0.256862    0.19192     0.851326    15971.8       




[J71          71000       0.000247968  0.258982    0.195518    0.851826    16200         




[J72          72000       0.000233183  0.257796    0.198351    0.852545    16428.4       




[J73          73000       0.000218736  0.257657    0.198417    0.851121    16656.7       




[J74          74000       0.000204647  0.255385    0.195283    0.851042    16884.8       




[J75          75000       0.000190933  0.256334    0.197932    0.852326    17113.2       




[J76          76000       0.000177611  0.256605    0.193461    0.855852    17341.4       




[J77          77000       0.000164698  0.256496    0.194662    0.856552    17569.9       




[J78          78000       0.00015221  0.256136    0.19313     0.856379    17798.3       




[J79          79000       0.000140163  0.255287    0.19486     0.85691     18026.5       




[J80          80000       0.000128571  0.256831    0.193374    0.85598     18255         




[J81          81000       0.000117449  0.255461    0.192098    0.85333     18483.3       




[J82          82000       0.000106811  0.255844    0.192549    0.855613    18711.7       




[J83          83000       9.66699e-05  0.255848    0.1894      0.856142    18940         




[J84          84000       8.70389e-05  0.255869    0.191195    0.855666    19168.3       




[J85          85000       7.79299e-05  0.253157    0.190966    0.855882    19396.6       




[J86          86000       6.93541e-05  0.25537     0.191365    0.854418    19624.9       




[J87          87000       6.13224e-05  0.254315    0.190686    0.856149    19853.6       




[J88          88000       5.38446e-05  0.253935    0.192125    0.856608    20082         




[J89          89000       4.69302e-05  0.253957    0.191858    0.857011    20310.2       




[J90          90000       4.05877e-05  0.25373     0.189276    0.856967    20538.6       




[J91          91000       3.48252e-05  0.252982    0.189036    0.856867    20766.8       




[J92          92000       2.96497e-05  0.254857    0.191355    0.856433    20995         




[J93          93000       2.50679e-05  0.253415    0.190379    0.856818    21223.2       




[J94          94000       2.10853e-05  0.256351    0.191658    0.855942    21451.3       




[J95          95000       1.77069e-05  0.252946    0.19212     0.856535    21679.4       




[J96          96000       1.49371e-05  0.252938    0.188426    0.856512    21907.5       




[J97          97000       1.27791e-05  0.25156     0.188004    0.856996    22135.7       




[J98          98000       1.12358e-05  0.252437    0.193301    0.856462    22363.9       




[J99          99000       1.0309e-05  0.255238    0.195475    0.856854    22592.1       




[J100         100000      1e-05       0.251387    0.188843    0.856979    22820.4       

[fold 1]
train: 48000, val: 12000
load imagenet pretrained: True
efficientnet_b0: 1280


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…



epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           1000        2.96581e-05  0.381982    0.327809    0.52821     227.187       




[J2           2000        8.71463e-05  0.326816    0.323926    0.565594    455.442       




[J3           3000        0.000177889  0.317078    0.287238    0.68361     683.635       




[J4           4000        0.000294665  0.303237    0.269739    0.733907    911.781       




[J5           5000        0.000428181  0.298688    0.2485      0.761656    1140.12       




[J6           6000        0.00056781  0.294227    0.275393    0.756855    1368.22       




[J7           7000        0.00070244  0.291839    0.246858    0.774329    1596.38       




[J8           8000        0.000821358  0.291227    0.242419    0.785836    1824.71       




[J9           9000        0.000915097  0.290427    0.247971    0.762891    2052.84       




[J10          10000       0.0009762   0.290488    0.23713     0.786229    2281.06       




[J11          11000       0.000999802  0.28842     0.240142    0.783048    2509.08       




[J12          12000       0.00099975  0.289137    0.25593     0.765353    2737.33       




[J13          13000       0.000998885  0.286033    0.240945    0.79652     2965.6        




[J14          14000       0.000997403  0.284618    0.239061    0.793856    3193.72       




[J15          15000       0.000995306  0.283081    0.235707    0.797509    3422.01       




[J16          16000       0.000992598  0.281093    0.230223    0.793692    3650.24       




[J17          17000       0.00098928  0.282104    0.235867    0.800379    3879.23       




[J18          18000       0.000985357  0.281276    0.229368    0.803524    4107.46       




[J19          19000       0.000980835  0.28069     0.234863    0.79593     4335.57       




[J20          20000       0.000975719  0.278374    0.226944    0.804034    4563.75       




[J21          21000       0.000970015  0.278676    0.236795    0.809813    4791.9        




[J22          22000       0.00096373  0.277609    0.223139    0.807136    5020.19       




[J23          23000       0.000956872  0.277655    0.224126    0.811447    5248.48       




[J24          24000       0.000949451  0.2755      0.230419    0.812294    5476.79       




[J25          25000       0.000941474  0.273327    0.224153    0.805218    5704.97       




[J26          26000       0.000932952  0.275702    0.228842    0.811956    5933.11       




[J27          27000       0.000923896  0.273392    0.224183    0.815399    6161.36       




[J28          28000       0.000914316  0.273176    0.226496    0.811243    6389.64       




[J29          29000       0.000904226  0.27319     0.227709    0.81208     6617.94       




[J30          30000       0.000893637  0.271267    0.237908    0.805099    6846.16       




[J31          31000       0.000882563  0.270861    0.220626    0.815743    7074.3        




[J32          32000       0.000871017  0.269729    0.219551    0.817392    7302.73       




[J33          33000       0.000859015  0.271227    0.225569    0.813604    7530.99       




[J34          34000       0.00084657  0.270683    0.217493    0.815898    7759.43       




[J35          35000       0.000833699  0.270183    0.230608    0.806438    7987.65       




[J36          36000       0.000820417  0.269624    0.223346    0.820278    8215.74       




[J37          37000       0.000806742  0.269099    0.220776    0.818129    8443.96       




[J38          38000       0.000792689  0.268066    0.217281    0.816043    8672.07       




[J39          39000       0.000778278  0.27012     0.220465    0.820357    8900.21       




[J40          40000       0.000763525  0.267244    0.218479    0.819498    9128.58       




[J41          41000       0.000748449  0.267377    0.215423    0.826936    9356.61       




[J42          42000       0.00073307  0.26698     0.215215    0.825432    9584.96       




[J43          43000       0.000717405  0.265533    0.226028    0.816704    9813.09       




[J44          44000       0.000701476  0.265663    0.214122    0.823074    10041.4       




[J45          45000       0.000685301  0.264385    0.217581    0.826651    10269.7       




[J46          46000       0.000668901  0.265302    0.212553    0.830902    10497.9       




[J47          47000       0.000652296  0.264726    0.212987    0.83025     10726.2       




[J48          48000       0.000635507  0.263991    0.219588    0.825465    10954.3       




[J49          49000       0.000618556  0.263299    0.213005    0.830296    11183.4       




[J50          50000       0.000601462  0.26394     0.218257    0.827808    11411.6       




[J51          51000       0.000584249  0.26346     0.211625    0.838128    11639.5       




[J52          52000       0.000566936  0.263375    0.2116      0.838669    11867.6       




[J53          53000       0.000549546  0.262072    0.209655    0.839792    12095.7       




[J54          54000       0.0005321   0.262588    0.212813    0.829378    12324         




[J55          55000       0.00051462  0.262397    0.211384    0.839493    12552.2       




[J56          56000       0.000497129  0.260891    0.207147    0.837375    12780.2       




[J57          57000       0.000479647  0.261817    0.213622    0.834852    13008.3       




[J58          58000       0.000462197  0.261122    0.209243    0.836172    13236.2       




[J59          59000       0.0004448   0.259592    0.214172    0.83683     13464.3       




[J60          60000       0.000427479  0.260057    0.207901    0.836486    13692.5       




[J61          61000       0.000410254  0.258545    0.206783    0.83797     13920.5       




[J62          62000       0.000393147  0.260175    0.206253    0.835953    14148.6       




[J63          63000       0.000376181  0.258172    0.21005     0.834193    14376.5       




[J64          64000       0.000359375  0.256442    0.209198    0.836974    14604.6       




[J65          65000       0.000342751  0.259269    0.210008    0.834673    14832.8       




[J66          66000       0.000326329  0.25622     0.204841    0.838156    15060.8       




[J67          67000       0.000310131  0.258544    0.210951    0.835012    15289.1       




[J68          68000       0.000294176  0.256003    0.211991    0.838941    15517.1       




[J69          69000       0.000278484  0.255458    0.210146    0.837642    15745.3       




[J70          70000       0.000263075  0.254339    0.211577    0.838428    15973.4       




[J71          71000       0.000247968  0.254842    0.209499    0.839337    16201.3       




[J72          72000       0.000233183  0.255342    0.208929    0.839561    16429.5       




[J73          73000       0.000218736  0.255079    0.201859    0.844047    16657.7       




[J74          74000       0.000204647  0.253925    0.202852    0.84113     16886         




[J75          75000       0.000190933  0.256279    0.210478    0.839042    17114         




[J76          76000       0.000177611  0.255957    0.206588    0.843188    17342         




[J77          77000       0.000164698  0.252869    0.205213    0.842353    17570.3       




[J78          78000       0.00015221  0.25254     0.202826    0.840784    17798.5       




[J79          79000       0.000140163  0.253302    0.206852    0.842828    18026.9       




[J80          80000       0.000128571  0.255608    0.205148    0.841805    18255         




[J81          81000       0.000117449  0.252671    0.203646    0.843512    18483.3       




[J82          82000       0.000106811  0.252571    0.204764    0.845369    18711.5       




[J83          83000       9.66699e-05  0.253478    0.200899    0.845556    18939.5       




[J84          84000       8.70389e-05  0.252218    0.203523    0.844052    19167.8       




[J85          85000       7.79299e-05  0.249093    0.202775    0.844657    19396         




[J86          86000       6.93541e-05  0.252665    0.207414    0.842264    19624.3       




[J87          87000       6.13224e-05  0.251953    0.204149    0.842453    19853.4       




[J88          88000       5.38446e-05  0.252124    0.199307    0.843746    20081.6       




[J89          89000       4.69302e-05  0.252369    0.203035    0.842878    20309.9       




[J90          90000       4.05877e-05  0.251385    0.200439    0.844452    20538         




[J91          91000       3.48252e-05  0.252455    0.201868    0.844114    20766.2       




[J92          92000       2.96497e-05  0.252947    0.204222    0.844711    20994.4       




[J93          93000       2.50679e-05  0.250965    0.199669    0.845353    21222.5       




[J94          94000       2.10853e-05  0.254354    0.205573    0.845173    21450.7       




[J95          95000       1.77069e-05  0.250969    0.2032      0.845612    21678.8       




[J96          96000       1.49371e-05  0.251178    0.201682    0.845373    21907         




[J97          97000       1.27791e-05  0.250922    0.201521    0.846122    22135.2       




[J98          98000       1.12358e-05  0.25053     0.205947    0.844397    22363.5       




[J99          99000       1.0309e-05  0.252412    0.205078    0.844721    22591.8       




[J100         100000      1e-05       0.249387    0.200539    0.844781    22819.9       

[fold 2]
train: 48000, val: 12000
load imagenet pretrained: True
efficientnet_b0: 1280


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…



epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           1000        2.96581e-05  0.381596    0.327098    0.520019    227.225       




[J2           2000        8.71463e-05  0.327312    0.324513    0.553381    455.377       




[J3           3000        0.000177889  0.316889    0.273096    0.719351    683.571       




[J4           4000        0.000294665  0.304211    0.253718    0.76851     911.776       




[J5           5000        0.000428181  0.298133    0.253497    0.770271    1139.84       




[J6           6000        0.00056781  0.29448     0.258274    0.771356    1368.13       




[J7           7000        0.00070244  0.294613    0.240757    0.775248    1596.33       




[J8           8000        0.000821358  0.293445    0.245475    0.776801    1824.48       




[J9           9000        0.000915097  0.291684    0.255181    0.774825    2052.66       




[J10          10000       0.0009762   0.291603    0.252834    0.777892    2280.85       




[J11          11000       0.000999802  0.291854    0.242715    0.784077    2509.07       




[J12          12000       0.00099975  0.288541    0.247152    0.787114    2737.4        




[J13          13000       0.000998885  0.287259    0.242034    0.785782    2965.51       




[J14          14000       0.000997403  0.285027    0.235441    0.788927    3193.71       




[J15          15000       0.000995306  0.285463    0.229033    0.79968     3421.95       




[J16          16000       0.000992598  0.285156    0.225343    0.801051    3650.15       




[J17          17000       0.00098928  0.285123    0.228133    0.808621    3878.32       




[J18          18000       0.000985357  0.282775    0.226332    0.810007    4106.37       




[J19          19000       0.000980835  0.284258    0.235312    0.807987    4334.51       




[J20          20000       0.000975719  0.28086     0.225043    0.809201    4562.56       




[J21          21000       0.000970015  0.279936    0.233362    0.809535    4790.7        




[J22          22000       0.00096373  0.279793    0.222007    0.805909    5018.94       




[J23          23000       0.000956872  0.277538    0.223185    0.806442    5246.92       




[J24          24000       0.000949451  0.278064    0.220694    0.810336    5474.99       




[J25          25000       0.000941474  0.27653     0.218164    0.818479    5703.08       




[J26          26000       0.000932952  0.277504    0.219568    0.816292    5931.46       




[J27          27000       0.000923896  0.275677    0.225506    0.81927     6159.57       




[J28          28000       0.000914316  0.276377    0.214782    0.816619    6387.69       




[J29          29000       0.000904226  0.274863    0.224292    0.810251    6615.83       




[J30          30000       0.000893637  0.274801    0.224544    0.813066    6843.91       




[J31          31000       0.000882563  0.273949    0.220674    0.80994     7072.1        




[J32          32000       0.000871017  0.273799    0.215272    0.825292    7300.19       




[J33          33000       0.000859015  0.274055    0.21166     0.821656    7528.47       




[J34          34000       0.00084657  0.272022    0.220827    0.825277    7756.69       




[J35          35000       0.000833699  0.271708    0.215537    0.824329    7984.85       




[J36          36000       0.000820417  0.269991    0.217891    0.821206    8213.23       




[J37          37000       0.000806742  0.27127     0.208621    0.826238    8441.26       




[J38          38000       0.000792689  0.26978     0.211256    0.821772    8671.09       




[J39          39000       0.000778278  0.270889    0.210822    0.825339    8899.33       




[J40          40000       0.000763525  0.269697    0.212323    0.828197    9127.41       




[J41          41000       0.000748449  0.269861    0.212327    0.830713    9355.83       




[J42          42000       0.00073307  0.270433    0.214923    0.832046    9583.91       




[J43          43000       0.000717405  0.267484    0.209589    0.826751    9812.27       




[J44          44000       0.000701476  0.267688    0.207743    0.838109    10040.6       




[J45          45000       0.000685301  0.266532    0.212534    0.827671    10268.7       




[J46          46000       0.000668901  0.267061    0.208629    0.831216    10497.1       




[J47          47000       0.000652296  0.267596    0.211148    0.829986    10725.2       




[J48          48000       0.000635507  0.266071    0.207004    0.830593    10953.5       




[J49          49000       0.000618556  0.266236    0.206397    0.831607    11181.6       




[J50          50000       0.000601462  0.265128    0.205835    0.833031    11409.7       




[J51          51000       0.000584249  0.267383    0.205111    0.835256    11637.9       




[J52          52000       0.000566936  0.266389    0.205885    0.83115     11866         




[J53          53000       0.000549546  0.26449     0.20894     0.832185    12094.3       




[J54          54000       0.0005321   0.264188    0.20151     0.837947    12322.7       




[J55          55000       0.00051462  0.262868    0.19749     0.838014    12550.8       




[J56          56000       0.000497129  0.264136    0.200351    0.837757    12779         




[J57          57000       0.000479647  0.262001    0.199259    0.838507    13007.3       




[J58          58000       0.000462197  0.262729    0.205498    0.837702    13235.7       




[J59          59000       0.0004448   0.262434    0.201517    0.836154    13463.9       




[J60          60000       0.000427479  0.262614    0.199088    0.837993    13692         




[J61          61000       0.000410254  0.260702    0.195891    0.83947     13920.3       




[J62          62000       0.000393147  0.261261    0.19829     0.84136     14148.4       




[J63          63000       0.000376181  0.261148    0.204659    0.835577    14376.7       




[J64          64000       0.000359375  0.258988    0.201186    0.840897    14604.9       




[J65          65000       0.000342751  0.261099    0.201252    0.842217    14833.1       




[J66          66000       0.000326329  0.260177    0.199424    0.842438    15061.4       




[J67          67000       0.000310131  0.25996     0.19765     0.841643    15289.6       




[J68          68000       0.000294176  0.259587    0.199314    0.842381    15517.8       




[J69          69000       0.000278484  0.256871    0.200163    0.846616    15746         




[J70          70000       0.000263075  0.25761     0.195537    0.843459    15974.2       




[J71          71000       0.000247968  0.258995    0.197958    0.845257    16202.4       




[J72          72000       0.000233183  0.25647     0.197616    0.842896    16430.5       




[J73          73000       0.000218736  0.256829    0.199255    0.844477    16658.6       




[J74          74000       0.000204647  0.256443    0.197937    0.845788    16887.1       




[J75          75000       0.000190933  0.257822    0.19777     0.845854    17115.1       




[J76          76000       0.000177611  0.256796    0.192863    0.848315    17343.3       




[J77          77000       0.000164698  0.255629    0.190948    0.846538    17571.5       




[J78          78000       0.00015221  0.254885    0.194632    0.845593    17799.9       




[J79          79000       0.000140163  0.255258    0.20158     0.845793    18028.1       




[J80          80000       0.000128571  0.256064    0.193655    0.845047    18256.2       




[J81          81000       0.000117449  0.255691    0.190392    0.846101    18484.5       




[J82          82000       0.000106811  0.255819    0.194105    0.846398    18712.8       




[J83          83000       9.66699e-05  0.254679    0.193293    0.844044    18941         




[J84          84000       8.70389e-05  0.255061    0.194058    0.847284    19169         




[J85          85000       7.79299e-05  0.253631    0.191245    0.846825    19397.2       




[J86          86000       6.93541e-05  0.254914    0.195473    0.846943    19625.4       




[J87          87000       6.13224e-05  0.254704    0.188317    0.847584    19853.6       




[J88          88000       5.38446e-05  0.254595    0.192059    0.847173    20081.7       




[J89          89000       4.69302e-05  0.253941    0.195997    0.847038    20309.7       




[J90          90000       4.05877e-05  0.253045    0.195542    0.84745     20538         




[J91          91000       3.48252e-05  0.254502    0.193853    0.848732    20766.2       




[J92          92000       2.96497e-05  0.254474    0.1932      0.848902    20994.4       




[J93          93000       2.50679e-05  0.252835    0.194005    0.84803     21222.8       




[J94          94000       2.10853e-05  0.253556    0.193423    0.847513    21450.9       




[J95          95000       1.77069e-05  0.253734    0.192777    0.848124    21679.2       




[J96          96000       1.49371e-05  0.25211     0.192512    0.848563    21910.7       




[J97          97000       1.27791e-05  0.253195    0.1887      0.848507    22138.8       




[J98          98000       1.12358e-05  0.252868    0.195547    0.847679    22366.9       




[J99          99000       1.0309e-05  0.254354    0.193179    0.847038    22595         




[J100         100000      1e-05       0.252501    0.188847    0.848767    22823.3       

[fold 3]
train: 48000, val: 12000
load imagenet pretrained: True
efficientnet_b0: 1280


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…



epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           1000        2.96581e-05  0.380804    0.328055    0.517485    227.179       




[J2           2000        8.71463e-05  0.327044    0.32579     0.535381    455.231       




[J3           3000        0.000177889  0.319523    0.283106    0.708835    683.481       




[J4           4000        0.000294665  0.304733    0.257008    0.770404    911.609       




[J5           5000        0.000428181  0.298418    0.252       0.784298    1139.9        




[J6           6000        0.00056781  0.295151    0.257864    0.775762    1368.09       




[J7           7000        0.00070244  0.293884    0.244061    0.791297    1596.13       




[J8           8000        0.000821358  0.294197    0.246793    0.777457    1824.43       




[J9           9000        0.000915097  0.292637    0.26417     0.793238    2052.57       




[J10          10000       0.0009762   0.290649    0.243834    0.796136    2280.87       




[J11          11000       0.000999802  0.290978    0.244406    0.799135    2509          




[J12          12000       0.00099975  0.288882    0.231203    0.807358    2737.05       




[J13          13000       0.000998885  0.287178    0.233607    0.800508    2965.25       




[J14          14000       0.000997403  0.287567    0.237206    0.801155    3193.35       




[J15          15000       0.000995306  0.284411    0.227541    0.810215    3421.49       




[J16          16000       0.000992598  0.283704    0.221108    0.811682    3649.79       




[J17          17000       0.00098928  0.282328    0.227505    0.809508    3877.84       




[J18          18000       0.000985357  0.281163    0.230993    0.809983    4106.06       




[J19          19000       0.000980835  0.281996    0.227684    0.806288    4334.14       




[J20          20000       0.000975719  0.280976    0.225741    0.806608    4562.29       




[J21          21000       0.000970015  0.280185    0.230129    0.815732    4790.48       




[J22          22000       0.00096373  0.279851    0.219279    0.81759     5018.48       




[J23          23000       0.000956872  0.278726    0.220106    0.816776    5246.69       




[J24          24000       0.000949451  0.276993    0.227557    0.815191    5474.65       




[J25          25000       0.000941474  0.276653    0.223247    0.811894    5702.84       




[J26          26000       0.000932952  0.275832    0.225384    0.815399    5931.01       




[J27          27000       0.000923896  0.27438     0.218919    0.824608    6159          




[J28          28000       0.000914316  0.276127    0.211924    0.824235    6387.14       




[J29          29000       0.000904226  0.274343    0.219786    0.823161    6615.11       




[J30          30000       0.000893637  0.274267    0.218351    0.825952    6843.2        




[J31          31000       0.000882563  0.274652    0.220488    0.820249    7071.24       




[J32          32000       0.000871017  0.273014    0.215012    0.824675    7299.36       




[J33          33000       0.000859015  0.272629    0.211473    0.834011    7527.55       




[J34          34000       0.00084657  0.273674    0.221651    0.830677    7755.55       




[J35          35000       0.000833699  0.271856    0.217725    0.826447    7983.64       




[J36          36000       0.000820417  0.271377    0.2233      0.830917    8211.7        




[J37          37000       0.000806742  0.272529    0.213841    0.819277    8439.82       




[J38          38000       0.000792689  0.270529    0.20614     0.830633    8667.82       




[J39          39000       0.000778278  0.272013    0.205936    0.833676    8895.76       




[J40          40000       0.000763525  0.269553    0.207513    0.838349    9123.82       




[J41          41000       0.000748449  0.268521    0.208751    0.834245    9351.97       




[J42          42000       0.00073307  0.269514    0.207193    0.833619    9580.11       




[J43          43000       0.000717405  0.266965    0.202602    0.834526    9808.23       




[J44          44000       0.000701476  0.267091    0.202937    0.832269    10036.2       




[J45          45000       0.000685301  0.267015    0.20338     0.832368    10264.4       




[J46          46000       0.000668901  0.266666    0.202061    0.838237    10492.3       




[J47          47000       0.000652296  0.266568    0.206802    0.83472     10720.4       




[J48          48000       0.000635507  0.26554     0.204679    0.833306    10948.4       




[J49          49000       0.000618556  0.265872    0.204058    0.840455    11176.4       




[J50          50000       0.000601462  0.264277    0.201505    0.835761    11404.5       




[J51          51000       0.000584249  0.266006    0.199577    0.839816    11632.5       




[J52          52000       0.000566936  0.264757    0.198068    0.838278    11860.7       




[J53          53000       0.000549546  0.262405    0.201382    0.836906    12088.8       




[J54          54000       0.0005321   0.264954    0.202788    0.837719    12316.8       




[J55          55000       0.00051462  0.262892    0.199504    0.843665    12544.9       




[J56          56000       0.000497129  0.264245    0.203091    0.841913    12772.9       




[J57          57000       0.000479647  0.262591    0.198448    0.842177    13001.1       




[J58          58000       0.000462197  0.262862    0.20083     0.847518    13229.1       




[J59          59000       0.0004448   0.261773    0.196666    0.847584    13457.1       




[J60          60000       0.000427479  0.263297    0.198784    0.839129    13687.2       




[J61          61000       0.000410254  0.258975    0.200243    0.842546    13915.2       




[J62          62000       0.000393147  0.260743    0.192069    0.844073    14143.3       




[J63          63000       0.000376181  0.260847    0.194088    0.845325    14371.5       




# Inference

## Copy best models

In [None]:
# For each fold: pick the logged epoch with the highest "val/metric", copy that
# epoch's snapshot to OUTPUT, delete all snapshots from the temp dir, copy the
# remaining experiment artifacts to ./fold{fold_id}, and save the evaluation
# config next to them as config.yml (read back by the inference cell).
best_log_list = []
for pre_eval_cfg, fold_id in zip(pre_eval_cfg_list, FOLDS):
    exp_dir_path = TMP / f"fold{fold_id}"
    log = pd.read_json(exp_dir_path / "log")
    # idxmax() returns an index *label*, so select it with .loc (label-based);
    # .iloc is positional and only coincides with .loc for a default RangeIndex.
    best_idx = log["val/metric"].idxmax()
    best_log = log.loc[[best_idx]]
    best_epoch = best_log.epoch.values[0]
    best_log_list.append(best_log)

    best_model_path = exp_dir_path / f"snapshot_by_metric_epoch_{best_epoch}.pth"
    copy_to = OUTPUT / f"best_metric_model_fold{fold_id}.pth"
    shutil.copy(best_model_path, copy_to)

    # Remove all snapshots so only logs/other artifacts are copied below.
    for p in exp_dir_path.glob("*.pth"):
        p.unlink()

    shutil.copytree(exp_dir_path, f"./fold{fold_id}")

    with open(f"./fold{fold_id}/config.yml", "w") as fw:
        yaml.dump(pre_eval_cfg, fw)

pd.concat(best_log_list, axis=0, ignore_index=True)

## Inference OOF & Test

In [None]:
def run_inference_loop(cfg, model, loader, device):
    """Run ``model`` over every batch of ``loader`` and return sigmoid outputs.

    The model is moved to ``device`` and put in eval mode; gradients are
    disabled for the whole pass.

    Args:
        cfg: unused here; kept for signature compatibility with callers.
        model: module whose output is passed through ``sigmoid``.
        loader: iterable of dict batches; only ``batch["image"]`` is read.
        device: target device for the model and the input images.

    Returns:
        numpy array: per-batch sigmoid outputs concatenated along axis 0.
    """
    model.to(device)
    model.eval()
    with torch.no_grad():
        batch_probs = [
            model(to_device(batch["image"], device)).sigmoid().detach().cpu().numpy()
            for batch in tqdm(loader)
        ]
    return np.concatenate(batch_probs)

In [None]:
# Predict out-of-fold (validation) and test probabilities with each fold's best
# model, recording the per-fold validation ROC-AUC in score_list.
label_arr = train[CLASSES].values
oof_pred_arr = np.zeros((len(train), N_CLASSES))
score_list = []
# One slot per fold actually run. Indexing an N_FOLDS-sized buffer by fold id
# left all-zero rows whenever FOLDS was a subset (or raised IndexError for
# non-0-based ids), and those zero rows diluted the later mean over axis 0.
test_pred_arr = np.zeros((len(FOLDS), len(smpl_sub), N_CLASSES))
test_path_label = {
    "paths": [DATA / f"test/{img_id[0]}/{img_id}.npy" for img_id in smpl_sub["id"].values],
    "labels": smpl_sub[CLASSES].values.astype("f")
}

for fold_pos, fold_id in enumerate(FOLDS):
    print(f"\n[fold {fold_id}]")
    tmp_dir = Path(f"./fold{fold_id}")
    with open(tmp_dir / "config.yml", "r") as fr:
        cfg = Config(yaml.safe_load(fr), types=CONFIG_TYPES)
    device = torch.device(cfg["/globals/device"])
    val_idx = train.query("fold == @fold_id").index.values

    # Build validation / test loaders from the saved per-fold config.
    _, val_path_label = get_path_label(cfg, train)
    cfg["/dataset/val"].lazy_init(**val_path_label)
    cfg["/dataset/test"].lazy_init(**test_path_label)
    val_loader = cfg["/loader/val"]
    test_loader = cfg["/loader/test"]

    # Load the best-metric weights copied to OUTPUT for this fold.
    model_path = OUTPUT / f"best_metric_model_fold{fold_id}.pth"
    model = cfg["/model"]
    model.load_state_dict(torch.load(model_path, map_location=device))

    # Inference: OOF predictions go to this fold's validation rows.
    val_pred = run_inference_loop(cfg, model, val_loader, device)
    val_score = roc_auc_score(label_arr[val_idx], val_pred)
    oof_pred_arr[val_idx] = val_pred
    score_list.append([fold_id, val_score])

    test_pred_arr[fold_pos] = run_inference_loop(cfg, model, test_loader, device)

    # Free GPU/host memory before the next fold.
    del cfg, val_idx, val_path_label
    del model, val_loader, test_loader
    torch.cuda.empty_cache()
    gc.collect()

    print(f"val score: {val_score:.4f}")

In [None]:
# Overall OOF ROC-AUC. Only rows whose fold was inferred (fold in FOLDS) carry
# predictions; rows of other folds are still zeros in oof_pred_arr and would
# bias the score. Identical to scoring all rows when every fold was run.
oof_mask = train["fold"].isin(FOLDS).values
oof_score = roc_auc_score(label_arr[oof_mask], oof_pred_arr[oof_mask])
score_list.append(["oof", oof_score])
pd.DataFrame(score_list, columns=["fold", "metric"])

In [None]:
# Persist OOF predictions: a copy of train with the CLASSES columns replaced
# by the values in oof_pred_arr.
oof_df = train.copy()
oof_df[CLASSES] = oof_pred_arr
oof_path = OUTPUT / "oof_prediction.csv"
oof_df.to_csv(oof_path, index=False)

## Make submission

In [None]:
# Final submission: the sample submission with CLASSES filled by the mean of
# test_pred_arr over its first (per-fold) axis.
fold_mean_pred = test_pred_arr.mean(axis=0)
sub_df = smpl_sub.copy()
sub_df[CLASSES] = fold_mean_pred
sub_df.to_csv(OUTPUT / "submission.csv", index=False)

In [None]:
# Display the first rows of the submission as a quick sanity check.
sub_df.head(n=5)

# EOF