# About

I'd like to share an example for [small models](https://www.kaggle.com/c/seti-breakthrough-listen/discussion/242644).

### Experimental Settings

### model
* backbone: resnet18d (use the pretrained model provided by [timm](https://github.com/rwightman/pytorch-image-models))
* head classifier: one linear layer
* num of input channels: **1**

### data augmentation
* implemented by [albumentations](https://albumentations.ai/docs/) **except for Mixup**
* Train
  * Resize
  * HorizontalFlip
  * VerticalFlip
  * ShiftScaleRotate
  * RandomResizedCrop
  * Mixup(alpha=1.0)
* Val, Test
  * Resize

### learning settings
* CV Strategy: Stratified KFold (K=5)
* max epochs: 18
* data:
  * input image size: 1x320x320
  * batch size: 64
* loss: [BCEWithLogitsLoss](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html#torch.nn.BCEWithLogitsLoss)
* optimizer: [AdamW](https://pytorch.org/docs/stable/optim.html#torch.optim.AdamW)
  * weight decay: 1.0e-04
* learning rate scheduler: [OneCycleLR](https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.OneCycleLR) 
  * epochs: 18
  * max_lr: 1e-3
  * pct_start: 0.111
  * anneal_strategy: cos
  * div_factor: 1.0e+2
  * inal_div_factor: 1
  
### NOTE: I use only on-target ('A') observations

```python
img = np.load(path)[[0, 2, 4]]          # shape: (3, 273, 256)
img = np.vstack(img)                    # shape: (819, 256)
img = img.transpose(1, 0)               # shape: (256, 819)
```

### Submission -> 0.747

# Prapere

## Install

## Import

In [1]:
import os
import gc
import copy
import yaml
import random
import shutil
import typing as tp
from pathlib import Path

import numpy as np
import pandas as pd

from tqdm.notebook import tqdm
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

import torch
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
from torch.cuda import amp

import timm

import albumentations as A
from albumentations.pytorch import ToTensorV2

import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.config import Config
from pytorch_pfn_extras.training import extensions as ppe_exts, triggers as ppe_triggers

In [2]:
# timm.list_models(pretrained=True)

In [3]:
ROOT = Path.cwd()
INPUT = ROOT / "../../"
OUTPUT = ROOT / "output"
DATA = INPUT / "seti-breakthrough-listen"
TRAIN = DATA / "train"
TEST = DATA / "test"

TMP = ROOT / "tmp"
TMP.mkdir(exist_ok=True)

RANDAM_SEED = 1086
CLASSES = ["target"]
N_CLASSES = len(CLASSES)
FOLDS = [0, 1, 2, 3, 4]
N_FOLDS = len(FOLDS)

## Read Data, Split folds

In [4]:
train = pd.read_csv(DATA / "train_labels.csv")
smpl_sub = pd.read_csv(DATA / "sample_submission.csv")

In [5]:
skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RANDAM_SEED)
train["fold"] = -1
for fold_id, (_, val_idx) in enumerate(skf.split(train["id"], train["target"])):
    train.loc[val_idx, "fold"] = fold_id

In [6]:
train.groupby("fold").agg(total=("id", len), pos=("target", sum))

Unnamed: 0_level_0,total,pos
fold,Unnamed: 1_level_1,Unnamed: 2_level_1
0,12000,1200
1,12000,1200
2,12000,1200
3,12000,1200
4,12000,1200


## Definition of Model, Dataset, Metric

### Model

In [7]:
class BasicImageModel(nn.Module):
    
    def __init__(
        self, base_name: str, dims_head: tp.List[int],
        pretrained=False, in_channels: int=3
    ):
        """Initialize"""
        self.base_name = base_name
        super(BasicImageModel, self).__init__()
        
        # # prepare backbone
        if hasattr(timm.models, base_name):
            base_model = timm.create_model(
                base_name, num_classes=0, pretrained=pretrained, in_chans=in_channels)
            in_features = base_model.num_features
            print("load imagenet pretrained:", pretrained)
        else:
            raise NotImplementedError

        self.backbone = base_model
        print(f"{base_name}: {in_features}")
        
        # # prepare head clasifier
        if dims_head[0] is None:
            dims_head[0] = in_features

        layers_list = []
        for i in range(len(dims_head) - 2):
            in_dim, out_dim = dims_head[i: i + 2]
            layers_list.extend([
                nn.Linear(in_dim, out_dim),
                nn.ReLU(), nn.Dropout(0.5),
            ])
        layers_list.append(
            nn.Linear(dims_head[-2], dims_head[-1])
        )
        self.head_cls = nn.Sequential(*layers_list)

    def forward(self, x):
        """Forward"""
        h = self.backbone(x)
        h = self.head_cls(h)
        return h

### Dataset

In [8]:
FilePath = tp.Union[str, Path]
Label = tp.Union[int, float, np.ndarray]


class SetiSimpleDataset(torch.utils.data.Dataset):
    """
    Dataset using 6 channels by stacking them along time-axis

    Attributes
    ----------
    paths : tp.Sequence[FilePath]
        Sequence of path to cadence snippet file
    labels : tp.Sequence[Label]
        Sequence of label for cadence snippet file
    transform: albumentations.Compose
        composed data augmentations for data
    """

    def __init__(
        self,
        paths: tp.Sequence[FilePath],
        labels: tp.Sequence[Label],
        transform: A.Compose,
    ):
        """Initialize"""
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return num of cadence snippets"""
        return len(self.paths)

    def __getitem__(self, index: int):
        """Return transformed image and label for given index."""
        path, label = self.paths[index], self.labels[index]
        img = self._read_cadence_array(path)
        img = self.transform(image=img)["image"]
        return {"image": img, "target": label}

    def _read_cadence_array(self, path: Path):
        """Read cadence file and reshape"""
        img = np.load(path)  # shape: (6, 273, 256)
        img = np.vstack(img)  # shape: (1638, 256)
        img = img.transpose(1, 0)  # shape: (256, 1638)
        img = img.astype("f")[..., np.newaxis]  # shape: (256, 1638, 1)
        return img

    def lazy_init(self, paths=None, labels=None, transform=None):
        """Reset Members"""
        if paths is not None:
            self.paths = paths
        if labels is not None:
            self.labels = labels
        if transform is not None:
            self.transform = transform


class SetiAObsDataset(SetiSimpleDataset):
    """Use only on-target observation"""

    def _read_cadence_array(self, path: Path):
        """Read cadence file and reshape"""
        img = np.load(path)[[0, 2, 4]]  # shape: (3, 273, 256)
        img = np.vstack(img)  # shape: (819, 256)
        img = img.transpose(1, 0)  # shape: (256, 819)
        img = img.astype("f")[..., np.newaxis]  # shape: (819, 256, 1)
        return img

### Metric

In [9]:
Batch = tp.Union[tp.Tuple[torch.Tensor], tp.Dict[str, torch.Tensor]]
ModelOut = tp.Union[tp.Tuple[torch.Tensor], tp.Dict[str, torch.Tensor], torch.Tensor]


class ROCAUC(nn.Module):
    """ROC AUC score"""

    def __init__(self, average="macro") -> None:
        """Initialize."""
        self.average = average
        super(ROCAUC, self).__init__()

    def forward(self, y, t) -> float:
        """Forward."""
        if isinstance(y, torch.Tensor):
            y = y.detach().cpu().numpy()
        if isinstance(t, torch.Tensor):
            t = t.detach().cpu().numpy()

        return roc_auc_score(t, y, average=self.average)


def micro_average(
    metric_func: nn.Module,
    report_name: str, prefix="val",
    pred_index: int=-1, label_index: int=-1,
    pred_key: str="logit", label_key: str="target",
) -> tp.Callable:
    """Return Metric Wrapper for Simple Mean Metric"""
    metric_sum = [0.]
    n_examples = [0]
    
    def wrapper(batch: Batch, model_output: ModelOut, is_last_batch: bool):
        """Wrapping metric function for evaluation"""
        if isinstance(batch, tuple): 
            t = batch[label_index]
        elif isinstance(batch, dict):
            t = batch[label_key]
        else:
            raise NotImplementedError

        if isinstance(model_output, tuple):
            y = model_output[pred_index]
        elif isinstance(model_output, dict):
            y = model_output[pred_key]
        else:
            y = model_output

        metric = metric_func(y, t).item()
        metric_sum[0] += metric * y.shape[0]
        n_examples[0] += y.shape[0]

        if is_last_batch:
            final_metric = metric_sum[0] / n_examples[0]
            ppe.reporting.report({f"{prefix}/{report_name}": final_metric})
            # # reset state
            metric_sum[0] = 0.
            n_examples[0] = 0

    return wrapper


def calc_across_all_batchs(
    metric_func: nn.Module,
    report_name: str, prefix="val",
    pred_index: int=-1, label_index: int=-1,
    pred_key: str="logit", label_key: str="target",
) -> tp.Callable:
    """
    Return Metric Wrapper for Metrics caluculated on all data
    
    storing predictions and labels of evry batch, finally calculating metric on them.
    """
    pred_list = []
    label_list = []
    
    def wrapper(batch: Batch, model_output: ModelOut, is_last_batch: bool):
        """Wrapping metric function for evaluation"""
        if isinstance(batch, tuple):
            t = batch[label_index]
        elif isinstance(batch, dict):
            t = batch[label_key]
        else:
            raise NotImplementedError

        if isinstance(model_output, tuple):
            y = model_output[pred_index]
        elif isinstance(model_output, dict):
            y = model_output[pred_key]
        else:
            y = model_output

        pred_list.append(y.numpy())
        label_list.append(t.numpy())

        if is_last_batch:
            pred = np.concatenate(pred_list, axis=0)
            label = np.concatenate(label_list, axis=0)
            final_metric = metric_func(pred, label)
            ppe.reporting.report({f"{prefix}/{report_name}": final_metric})
            # # reset state
            pred_list[:] = []
            label_list[:] = []

    return wrapper

# Train

## config_types for evaluating configuration

I use [pytorch-pfn-extras](https://github.com/pfnet/pytorch-pfn-extras) for training NNs. This library has useful config systems but requires some preparation.

For more details, see [docs](https://github.com/pfnet/pytorch-pfn-extras/blob/master/docs/config.md).

In [10]:
CONFIG_TYPES = {
    # # utils
    "__len__": lambda obj: len(obj),
    "method_call": lambda obj, method: getattr(obj, method)(),

    # # Dataset, DataLoader
    "SetiSimpleDataset": SetiSimpleDataset,
    "SetiAObsDataset": SetiAObsDataset,
    "DataLoader": torch.utils.data.DataLoader,

    # # Data Augmentation
    "Compose": A.Compose, "OneOf": A.OneOf,
    "Resize": A.Resize,
    "HorizontalFlip": A.HorizontalFlip, "VerticalFlip": A.VerticalFlip,
    "ShiftScaleRotate": A.ShiftScaleRotate,
    "RandomResizedCrop": A.RandomResizedCrop,
    "Cutout": A.Cutout,
    "ToTensorV2": ToTensorV2,

    # # Model
    "BasicImageModel": BasicImageModel,

    # # Optimizer
    "AdamW": optim.AdamW,

    # # Scheduler
    "OneCycleLR": lr_scheduler.OneCycleLR,

    # # Loss,Metric
    "BCEWithLogitsLoss": nn.BCEWithLogitsLoss,
    "ROCAUC": ROCAUC,

    # # Metric Wrapper
    "micro_average": micro_average,
    "calc_across_all_batchs": calc_across_all_batchs,

    # # PPE Extensions
    "ExtensionsManager": ppe.training.ExtensionsManager,

    "observe_lr": ppe_exts.observe_lr,
    "LogReport": ppe_exts.LogReport,
    "PlotReport": ppe_exts.PlotReport,
    "PrintReport": ppe_exts.PrintReport,
    "PrintReportNotebook": ppe_exts.PrintReportNotebook,
    "ProgressBar": ppe_exts.ProgressBar,
    "ProgressBarNotebook": ppe_exts.ProgressBarNotebook,
    "snapshot": ppe_exts.snapshot,
    "LRScheduler": ppe_exts.LRScheduler, 

    "MinValueTrigger": ppe_triggers.MinValueTrigger,
    "MaxValueTrigger": ppe_triggers.MaxValueTrigger,
    "EarlyStoppingTrigger": ppe_triggers.EarlyStoppingTrigger,
}

## configration

In [11]:
pre_eval_cfg = yaml.safe_load(
"""
globals:
  seed: 1086
  val_fold: null  # indicate when training
  output_path: null # indicate when training
  device: cuda:0
  enable_amp: False
  max_epoch: 100

model:
  type: BasicImageModel
  dims_head: [null, 1]
  base_name: efficientnetv2_rw_s
  pretrained: True
  in_channels: 1

dataset:
  height: 320
  width: 320
  mixup: {enabled: True, alpha: 1.0}
  train:
    type: SetiAObsDataset
    paths: null  # set by lazy_init
    labels: null  # set by lazy_init
    transform:
      type: Compose
      transforms:
        - {type: Resize, p: 1.0, height: "@/dataset/height", width: "@/dataset/width"}
        - {type: HorizontalFlip, p: 0.5}
        - {type: VerticalFlip, p: 0.5}
        - {type: ShiftScaleRotate, p: 0.5, shift_limit: 0.2, scale_limit: 0.2,
            rotate_limit: 20, border_mode: 0, value: 0, mask_value: 0}
        - {type: RandomResizedCrop, p: 1.0,
            scale: [0.9, 1.0], height: "@/dataset/height", width: "@/dataset/width"}
        - {type: ToTensorV2, always_apply: True}
  val:
    type: SetiAObsDataset
    paths: null  # set by lazy_init
    labels: null  # set by lazy_init
    transform:
      type: Compose
      transforms:
        - {type: Resize, p: 1.0, height: "@/dataset/height", width: "@/dataset/width"}
        - {type: ToTensorV2, always_apply: True}  
  test:
    type: SetiAObsDataset
    paths: null  # set by lazy_init
    labels: null  # set by lazy_init
    transform: "@/dataset/val/transform"

loader:
  train: {type: DataLoader, dataset: "@/dataset/train",
    batch_size: 25, num_workers: 4, shuffle: True, pin_memory: True, drop_last: True}
  val: {type: DataLoader, dataset: "@/dataset/val",
    batch_size: 25, num_workers: 4, shuffle: False, pin_memory: True, drop_last: False}
  test: {type: DataLoader, dataset: "@/dataset/test",
    batch_size: 25, num_workers: 4, shuffle: False, pin_memory: True, drop_last: False}

optimizer:
  type: AdamW
  params: {type: method_call, obj: "@/model", method: parameters}
  lr: 1.0e-05
  weight_decay: 1.0e-04

scheduler:
  type: OneCycleLR
  optimizer: "@/optimizer"
  epochs: "@/globals/max_epoch"
  steps_per_epoch: {type: __len__, obj: "@/loader/train"}
  max_lr: 1.0e-3
  pct_start: 0.111
  anneal_strategy: cos
  div_factor: 1.0e+2
  final_div_factor: 1

loss: {type: BCEWithLogitsLoss}

eval:
  - type: micro_average
    metric_func: {type: BCEWithLogitsLoss}
    report_name: loss
  - type: calc_across_all_batchs
    metric_func: {type: ROCAUC}
    report_name: metric

manager:
  type: ExtensionsManager
  models: "@/model"
  optimizers: "@/optimizer"
  max_epochs: "@/globals/max_epoch"
  iters_per_epoch: {type: __len__, obj: "@/loader/train"}
  out_dir: "@/globals/output_path"
  #  stop_trigger: {type: EarlyStoppingTrigger,
  #    monitor: val/metric, mode: max, patience: 5, verbose: True,
  #    check_trigger: [1, epoch], max_trigger: ["@/globals/max_epoch", epoch]}

extensions:
  # # log
  - {type: observe_lr, optimizer: "@/optimizer"}
  - {type: LogReport}
  - {type: PlotReport, y_keys: lr, x_key: epoch, filename: lr.png}
  - {type: PlotReport, y_keys: [train/loss, val/loss], x_key: epoch, filename: loss.png}
  - {type: PlotReport, y_keys: val/metric, x_key: epoch, filename: metric.png}
  - {type: PrintReport, entries: [
      epoch, iteration, lr, train/loss, val/loss, val/metric, elapsed_time]}
  - {type: ProgressBarNotebook, update_interval: 20}
  # snapshot
  - extension: {type: snapshot, target: "@/model", filename: "snapshot_by_metric_epoch_{.epoch}.pth"}
    trigger: {type: MaxValueTrigger, key: "val/metric", trigger: [1, epoch]}
  # # lr scheduler
  - {type: LRScheduler, scheduler: "@/scheduler", trigger: [1,  iteration]}
"""
)

## functions for training

In [12]:
def set_random_seed(seed: int = 42, deterministic: bool = False):
    """Set seeds"""
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.backends.cudnn.deterministic = deterministic  # type: ignore


def to_device(
    tensors: tp.Union[tp.Tuple[torch.Tensor], tp.Dict[str, torch.Tensor]],
    device: torch.device, *args, **kwargs
):
    if isinstance(tensors, tuple):
        return (t.to(device, *args, **kwargs) for t in tensors)
    elif isinstance(tensors, dict):
        return {
            k: t.to(device, *args, **kwargs) for k, t in tensors.items()}
    else:
        return tensors.to(device, *args, **kwargs)

In [13]:
def get_path_label(cfg: Config, train_all: pd.DataFrame):
    """Get file path and target info."""
    use_fold = cfg["/globals/val_fold"]

    train_df = train_all[train_all["fold"] != use_fold]
    val_df = train_all[train_all["fold"] == use_fold]
    
    train_path_label = {
        "paths": [TRAIN / f"{img_id[0]}/{img_id}.npy" for img_id in train_df["id"].values],
        "labels": train_df[CLASSES].values.astype("f")}
    val_path_label = {
        "paths": [TRAIN / f"{img_id[0]}/{img_id}.npy" for img_id in val_df["id"].values],
        "labels": val_df[CLASSES].values.astype("f")
    }
    return train_path_label, val_path_label


def get_eval_func(cfg, model, device):
    
    def eval_func(**batch):
        """Run evaliation for val or test. This function is applied to each batch."""
        batch = to_device(batch, device)
        x = batch["image"]
        with amp.autocast(cfg["/globals/enable_amp"]): 
            y = model(x)
        return y.detach().cpu().to(torch.float32)  # input of metrics

    return eval_func


def mixup_data(use_mixup, x, t, alpha=1.0, use_cuda=True, device="cuda:0"):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if not use_mixup:
        return x, t, None, None
    
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda(device)
    else:
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    t_a, t_b = t, t[index]
    return mixed_x, t_a, t_b, lam


def get_criterion(use_mixup, loss_func):

    def mixup_criterion(pred, t_a, t_b, lam):
        return lam * loss_func(pred, t_a) + (1 - lam) * loss_func(pred, t_b)

    def single_criterion(pred, t_a, t_b, lam):
        return loss_func(pred, t_a)
    
    if use_mixup:
        return mixup_criterion
    else:
        return single_criterion

In [14]:
def train_one_fold(cfg, train_all):
    """Main"""
    torch.backends.cudnn.benchmark = True
    set_random_seed(cfg["/globals/seed"], deterministic=True)
    device = torch.device(cfg["/globals/device"])
    
    train_path_label, val_path_label = get_path_label(cfg, train_all)
    print("train: {}, val: {}".format(len(train_path_label["paths"]), len(val_path_label["paths"])))
   
    cfg["/dataset/train"].lazy_init(**train_path_label)
    cfg["/dataset/val"].lazy_init(**val_path_label)
    train_loader = cfg["/loader/train"]
    val_loader = cfg["/loader/val"]

    model = cfg["/model"]
    model.to(device)
    optimizer = cfg["/optimizer"]
    loss_func = cfg["/loss"]
    loss_func.to(device)
    
    manager = cfg["/manager"]
    for ext in cfg["/extensions"]:
        if isinstance(ext, dict):
            manager.extend(**ext)
        else:
            manager.extend(ext)

    evaluator = ppe_exts.Evaluator(
        val_loader, model, eval_func=get_eval_func(cfg, model, device),
        metrics=cfg["/eval"], progress_bar=False)
    manager.extend(evaluator, trigger=(1, "epoch"))

    use_amp = cfg["/globals/enable_amp"]
    scaler = amp.GradScaler(enabled=use_amp)
    use_mixup = cfg["/dataset/mixup/enabled"]
    mixup_alpha = cfg["/dataset/mixup/alpha"]
    
    while not manager.stop_trigger:
        model.train()
        for batch in train_loader:
            with manager.run_iteration():
                batch = to_device(batch, device)
                x, t = batch["image"], batch["target"]
                # # for mixup
                mixed_x, t_a, t_b, lam = mixup_data(use_mixup, x, t, mixup_alpha, device=cfg["/globals/device"])
                criterion = get_criterion(use_mixup, loss_func)
                
                optimizer.zero_grad()
                with amp.autocast(use_amp):
                    y = model(mixed_x)
                    loss = criterion(y, t_a, t_b, lam)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                
                ppe.reporting.report({'train/loss': loss.item()})

## run train

In [15]:
pre_eval_cfg_list = []
for fold_id in FOLDS:
    tmp_cfg = copy.deepcopy(pre_eval_cfg)
    tmp_cfg["globals"]["val_fold"] = fold_id
    tmp_cfg["globals"]["output_path"] = str(TMP / f"fold{fold_id}")
    pre_eval_cfg_list.append(tmp_cfg)

In [16]:
for pre_eval_cfg in pre_eval_cfg_list:
    cfg = Config(pre_eval_cfg, types=CONFIG_TYPES)
    print(f"\n[fold {cfg['/globals/val_fold']}]")
    train_one_fold(cfg, train)
    with torch.cuda.device(cfg["/globals/device"]):
        torch.cuda.empty_cache()
    del cfg
    gc.collect()


[fold 0]
train: 48000, val: 12000
load imagenet pretrained: True
efficientnetv2_rw_s: 1792


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…

epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           1920        2.96752e-05  0.351755    0.368666    0.577105    561.232       




[J2           3840        8.71758e-05  0.308377    0.260624    0.757121    1125.4        




[J3           5760        0.000177926  0.293959    0.24584     0.784233    1689.59       




[J4           7680        0.000294704  0.288816    0.254956    0.781617    2254.24       




[J5           9600        0.000428217  0.289676    0.241164    0.796126    2818.77       




[J6           11520       0.000567841  0.290633    0.22913     0.802372    3383.46       




[J7           13440       0.000702463  0.290717    0.264974    0.759294    3948.34       




[J8           15360       0.000821372  0.291541    0.237573    0.790133    4513.15       




[J9           17280       0.000915105  0.292948    0.240325    0.785345    5077.64       




[J10          19200       0.000976202  0.291698    0.235459    0.796163    5642.29       




[J11          21120       0.000999802  0.290859    0.237104    0.798526    6206.58       




[J12          23040       0.00099975  0.288778    0.236622    0.789414    6771.32       




[J13          24960       0.000998885  0.28558     0.223824    0.808853    7335.75       




[J14          26880       0.000997403  0.285702    0.22564     0.796309    7900.52       




[J15          28800       0.000995306  0.284829    0.224706    0.809563    8465.34       




[J16          30720       0.000992598  0.283286    0.216788    0.817306    9030.92       




[J17          32640       0.00098928  0.28187     0.222634    0.822022    9597.07       




[J18          34560       0.000985357  0.280623    0.21997     0.815497    10163.6       




[J19          36480       0.000980835  0.280411    0.219111    0.820924    10729.5       




[J20          38400       0.000975719  0.279878    0.223109    0.812077    11295.9       




[J21          40320       0.000970015  0.279478    0.213626    0.819472    11862         




[J22          42240       0.00096373  0.276883    0.21233     0.824068    12428.6       




[J23          44160       0.000956872  0.278363    0.204401    0.830382    12994.8       




[J24          46080       0.000949451  0.276266    0.213779    0.825946    13561.6       




[J25          48000       0.000941474  0.274737    0.210323    0.834491    14127.5       




[J26          49920       0.000932952  0.273153    0.207555    0.835176    14693.7       




[J27          51840       0.000923896  0.275094    0.208211    0.836547    15259.1       




[J28          53760       0.000914316  0.274191    0.202399    0.840739    15825.3       




[J29          55680       0.000904226  0.272641    0.206784    0.838385    16391.3       




[J30          57600       0.000893637  0.272933    0.205575    0.832327    16957.2       




[J31          59520       0.000882563  0.27171     0.204972    0.840284    17523.2       




[J32          61440       0.000871017  0.270403    0.200325    0.845145    18089.4       




[J33          63360       0.000859015  0.270115    0.201708    0.843905    18656.1       




[J34          65280       0.00084657  0.270016    0.200269    0.841674    19222         




[J35          67200       0.000833699  0.270344    0.203225    0.836753    19788.6       




[J36          69120       0.000820417  0.269167    0.202028    0.836386    20354.9       




[J37          71040       0.000806742  0.268528    0.196989    0.844952    20921         




[J38          72960       0.000792689  0.268213    0.202268    0.842397    21487.4       




[J39          74880       0.000778278  0.267151    0.199721    0.847202    22053.5       




[J40          76800       0.000763525  0.265627    0.199301    0.848159    22619.7       




[J41          78720       0.000748449  0.267576    0.198412    0.846281    23186.2       




[J42          80640       0.00073307  0.266818    0.19943     0.844821    23751.6       




[J43          82560       0.000717405  0.266719    0.191038    0.850171    24317.5       




[J44          84480       0.000701476  0.26516     0.192286    0.850817    24883.6       




[J45          86400       0.000685301  0.264391    0.196699    0.846094    25448.7       




[J46          88320       0.000668901  0.264107    0.190705    0.850363    26013.8       




[J47          90240       0.000652296  0.2643      0.193779    0.848659    26580.1       




[J48          92160       0.000635507  0.263353    0.190363    0.851367    27145.6       




[J49          94080       0.000618556  0.263905    0.192392    0.851356    27711.3       




[J50          96000       0.000601462  0.263183    0.194866    0.85227     28276.3       




[J51          97920       0.000584249  0.26232     0.18714     0.853285    28841.6       




[J52          99840       0.000566936  0.261574    0.191432    0.851081    29407.5       




[J53          101760      0.000549546  0.261845    0.189553    0.851833    29974         




[J54          103680      0.0005321   0.260775    0.186983    0.855213    30540         




[J55          105600      0.00051462  0.259568    0.188488    0.859646    31105.9       




[J56          107520      0.000497129  0.26064     0.189463    0.855088    31670.9       




[J57          109440      0.000479647  0.260094    0.192693    0.855874    32235.8       




[J58          111360      0.000462197  0.257459    0.190527    0.854168    32801.1       




[J59          113280      0.0004448   0.259161    0.187098    0.854818    33366.2       




[J60          115200      0.000427479  0.25915     0.183834    0.860655    33931.4       




[J61          117120      0.000410254  0.25845     0.185283    0.856481    34496.6       




[J62          119040      0.000393147  0.256756    0.185343    0.859101    35061.8       




[J63          120960      0.000376181  0.257409    0.180879    0.860705    35626.4       




[J64          122880      0.000359375  0.257412    0.182283    0.86009     36191.4       




[J65          124800      0.000342751  0.257239    0.181384    0.86071     36756.1       




[J66          126720      0.000326329  0.255755    0.183383    0.859426    37321         




[J67          128640      0.000310131  0.256359    0.181872    0.861665    37885.8       




[J68          130560      0.000294176  0.256206    0.184367    0.860452    38451.5       




[J69          132480      0.000278484  0.25493     0.180379    0.861647    39016.2       




[J70          134400      0.000263075  0.254356    0.179028    0.861829    39581.3       




[J71          136320      0.000247968  0.256071    0.182715    0.857917    40145.9       




[J72          138240      0.000233183  0.253631    0.182741    0.861434    40710.8       




[J73          140160      0.000218736  0.253189    0.180068    0.86057     41275.8       




[J74          142080      0.000204647  0.253705    0.181046    0.862901    41840.9       




[J75          144000      0.000190933  0.253518    0.181528    0.863468    42406         




[J76          145920      0.000177611  0.253932    0.178177    0.864199    42971.8       




[J77          147840      0.000164698  0.252846    0.177719    0.864396    43537.1       




[J78          149760      0.00015221  0.253153    0.181165    0.862317    44102.3       




[J79          151680      0.000140163  0.251638    0.179942    0.862707    44667.2       




[J80          153600      0.000128571  0.251999    0.179312    0.864211    45231.9       




[J81          155520      0.000117449  0.251223    0.180418    0.863693    45797.2       




[J82          157440      0.000106811  0.25097     0.182092    0.864651    46362         




[J83          159360      9.66699e-05  0.250541    0.181817    0.866219    46926.9       




[J84          161280      8.70389e-05  0.250641    0.181045    0.862958    47491.8       




[J85          163200      7.79299e-05  0.250474    0.177978    0.866736    48057.3       




[J86          165120      6.93541e-05  0.250602    0.176485    0.867049    48622.5       




[J87          167040      6.13224e-05  0.248718    0.176885    0.867071    49187.5       




[J88          168960      5.38446e-05  0.251822    0.175794    0.868366    49752.4       




[J89          170880      4.69302e-05  0.250322    0.177668    0.867344    50318         




[J90          172800      4.05877e-05  0.248615    0.177552    0.868879    50882.8       




[J91          174720      3.48252e-05  0.25016     0.177112    0.8671      51448.3       




[J92          176640      2.96497e-05  0.249449    0.179389    0.867565    52013.1       




[J93          178560      2.50679e-05  0.248913    0.175857    0.868209    52578.1       




[J94          180480      2.10853e-05  0.250701    0.174678    0.867238    53142.7       




[J95          182400      1.77069e-05  0.249962    0.176246    0.867509    53708.6       




[J96          184320      1.49371e-05  0.249798    0.175248    0.86756     54273.5       




[J97          186240      1.27791e-05  0.248987    0.175175    0.867792    54838.3       




[J98          188160      1.12358e-05  0.248571    0.178229    0.867629    55403.1       




[J99          190080      1.0309e-05  0.249796    0.176304    0.867872    55967.7       




[J100         192000      1e-05       0.250206    0.175772    0.86805     56532.7       

[fold 1]
train: 48000, val: 12000
load imagenet pretrained: True
efficientnetv2_rw_s: 1792


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…



epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           1920        2.96752e-05  0.352073    0.329914    0.569657    563.823       




[J2           3840        8.71758e-05  0.308112    0.270538    0.727613    1128.74       




[J3           5760        0.000177926  0.293804    0.244363    0.766193    1693.68       




[J4           7680        0.000294704  0.28763     0.246545    0.775053    2259.1        




[J5           9600        0.000428217  0.286901    0.238275    0.769538    2824.33       




[J6           11520       0.000567841  0.28921     0.264982    0.780397    3389.27       




[J7           13440       0.000702463  0.289438    0.245115    0.77195     3954.15       




[J8           15360       0.000821372  0.290819    0.255631    0.746785    4519.35       




[J9           17280       0.000915105  0.290346    0.259375    0.758322    5083.67       




[J10          19200       0.000976202  0.290312    0.243243    0.770911    5648.59       




[J11          21120       0.000999802  0.289729    0.244101    0.774737    6213.22       




[J12          23040       0.00099975  0.288883    0.244635    0.782046    6778.06       




[J13          24960       0.000998885  0.285611    0.23663     0.771652    7343.56       




[J14          26880       0.000997403  0.284948    0.251742    0.775484    7908.72       




[J15          28800       0.000995306  0.283862    0.235031    0.780348    8473.46       




[J16          30720       0.000992598  0.282471    0.233247    0.794564    9038.21       




[J17          32640       0.00098928  0.280462    0.230199    0.792181    9603.33       




[J18          34560       0.000985357  0.280057    0.233387    0.800358    10168.3       




[J19          36480       0.000980835  0.279097    0.225595    0.796687    10733.6       




[J20          38400       0.000975719  0.277793    0.22555     0.803815    11298         




[J21          40320       0.000970015  0.279142    0.226335    0.796652    11863.4       




[J22          42240       0.00096373  0.275784    0.22427     0.803807    12428.2       




[J23          44160       0.000956872  0.275678    0.221094    0.798226    12993.2       




[J24          46080       0.000949451  0.274753    0.223081    0.805593    13558.4       




[J25          48000       0.000941474  0.274073    0.217308    0.812051    14123.6       




[J26          49920       0.000932952  0.27305     0.220046    0.811528    14688.7       




[J27          51840       0.000923896  0.273237    0.225271    0.800639    15253.5       




[J28          53760       0.000914316  0.272074    0.213647    0.817265    15818.2       




[J29          55680       0.000904226  0.270693    0.21362     0.808076    16383.4       




[J30          57600       0.000893637  0.271056    0.217109    0.812905    16948.4       




[J31          59520       0.000882563  0.269909    0.211038    0.811792    17513.8       




[J32          61440       0.000871017  0.268491    0.211295    0.818128    18078.8       




[J33          63360       0.000859015  0.269689    0.20992     0.820028    18644.1       




[J34          65280       0.00084657  0.269025    0.209464    0.823642    19209.5       




[J35          67200       0.000833699  0.26908     0.214139    0.814423    19774.1       




[J36          69120       0.000820417  0.268147    0.215855    0.816397    20338.9       




[J37          71040       0.000806742  0.2664      0.206264    0.828841    20903.8       




[J38          72960       0.000792689  0.266166    0.207672    0.824396    21469.7       




[J39          74880       0.000778278  0.26536     0.209175    0.821713    22035.1       




[J40          76800       0.000763525  0.266496    0.207178    0.829884    22600         




[J41          78720       0.000748449  0.263251    0.205032    0.827315    23164.8       




[J42          80640       0.00073307  0.264441    0.200268    0.83411     23729.9       




[J43          82560       0.000717405  0.263176    0.207684    0.8329      24294.9       




[J44          84480       0.000701476  0.263813    0.20325     0.830628    24859.8       




[J45          86400       0.000685301  0.263733    0.203845    0.832526    25424.8       




[J46          88320       0.000668901  0.262131    0.206058    0.826515    25989.5       




[J47          90240       0.000652296  0.263283    0.201613    0.83435     26554.5       




[J48          92160       0.000635507  0.262418    0.202755    0.82963     27119.5       




[J49          94080       0.000618556  0.260637    0.200928    0.832026    27684.3       




[J50          96000       0.000601462  0.260025    0.207347    0.824569    28248.6       




[J51          97920       0.000584249  0.259394    0.2013      0.837368    28813.3       




[J52          99840       0.000566936  0.2604      0.201133    0.833095    29377.8       




[J53          101760      0.000549546  0.26097     0.202273    0.833835    29942.3       




[J54          103680      0.0005321   0.258184    0.198289    0.836213    30507.2       




[J55          105600      0.00051462  0.259405    0.201146    0.838283    31072.2       




[J56          107520      0.000497129  0.258205    0.198733    0.836099    31637.3       




[J57          109440      0.000479647  0.261064    0.199186    0.839601    32202.3       




[J58          111360      0.000462197  0.254793    0.195851    0.841131    32767.7       




[J59          113280      0.0004448   0.256336    0.196368    0.841683    33333         




[J60          115200      0.000427479  0.25663     0.196113    0.844005    33898         




[J61          117120      0.000410254  0.256235    0.195165    0.841163    34463.2       




[J62          119040      0.000393147  0.255942    0.192985    0.843339    35027.8       




[J63          120960      0.000376181  0.255115    0.193828    0.84397     35592.7       




[J64          122880      0.000359375  0.25404     0.193561    0.839233    36157.4       




[J65          124800      0.000342751  0.254071    0.195641    0.842083    36722.4       




[J66          126720      0.000326329  0.25594     0.196399    0.842801    37287.2       




[J67          128640      0.000310131  0.254711    0.19579     0.841184    37853.5       




[J68          130560      0.000294176  0.253716    0.196855    0.837733    38418.3       




[J69          132480      0.000278484  0.253665    0.194356    0.842761    38983.4       




[J70          134400      0.000263075  0.252087    0.194787    0.842717    39548.3       




[J71          136320      0.000247968  0.252318    0.193914    0.842643    40113.5       




[J72          138240      0.000233183  0.25311     0.193523    0.843331    40678.6       




[J73          140160      0.000218736  0.252151    0.19222     0.845579    41243.8       




[J74          142080      0.000204647  0.252883    0.192972    0.844729    41808.8       




[J75          144000      0.000190933  0.25151     0.195928    0.84016     42374.3       




[J76          145920      0.000177611  0.251702    0.190564    0.846558    42939.5       




[J77          147840      0.000164698  0.250824    0.18939     0.848506    43505.4       




[J78          149760      0.00015221  0.252363    0.191825    0.847432    44071.2       




[J79          151680      0.000140163  0.249667    0.191896    0.849147    44637         




[J80          153600      0.000128571  0.251143    0.190232    0.848664    45202.6       




[J81          155520      0.000117449  0.249075    0.191018    0.846987    45768.1       




[J82          157440      0.000106811  0.249551    0.190485    0.849102    46333         




[J83          159360      9.66699e-05  0.250375    0.192529    0.848182    46898.1       




[J84          161280      8.70389e-05  0.248805    0.189974    0.848106    47463.8       




[J85          163200      7.79299e-05  0.248021    0.188937    0.849768    48029.1       




[J86          165120      6.93541e-05  0.249561    0.187803    0.851133    48594.7       




[J87          167040      6.13224e-05  0.247888    0.190255    0.847464    49160         




[J88          168960      5.38446e-05  0.248464    0.188593    0.850877    49725.4       




[J89          170880      4.69302e-05  0.249932    0.188551    0.851685    50290.2       




[J90          172800      4.05877e-05  0.24814     0.191477    0.850442    50855.6       




[J91          174720      3.48252e-05  0.2472      0.189166    0.850947    51420.8       




[J92          176640      2.96497e-05  0.247309    0.189019    0.851987    51986.3       




[J93          178560      2.50679e-05  0.248847    0.188987    0.85183     52551.6       




[J94          180480      2.10853e-05  0.247894    0.188051    0.851681    53116.8       




[J95          182400      1.77069e-05  0.249574    0.189467    0.851562    53681.8       




[J96          184320      1.49371e-05  0.24625     0.19018     0.85138     54247.3       




[J97          186240      1.27791e-05  0.246207    0.189249    0.85137     54812.4       




[J98          188160      1.12358e-05  0.245502    0.189872    0.851575    55377.4       




[J99          190080      1.0309e-05  0.246458    0.189923    0.850563    55942.4       




[J100         192000      1e-05       0.246039    0.18945     0.851484    56507.7       

[fold 2]
train: 48000, val: 12000
load imagenet pretrained: True
efficientnetv2_rw_s: 1792


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…



epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           1920        2.96752e-05  0.3518      0.350975    0.562719    564.319       




[J2           3840        8.71758e-05  0.309551    0.452824    0.724688    1129.01       




[J3           5760        0.000177926  0.293519    0.254398    0.760232    1684.64       




[J4           7680        0.000294704  0.289376    0.236979    0.778478    2236.24       




[J5           9600        0.000428217  0.28875     0.257227    0.762425    2787.88       




[J6           11520       0.000567841  0.289427    0.247649    0.769504    3339.55       




[J7           13440       0.000702463  0.290981    0.237786    0.778165    3891.18       




[J8           15360       0.000821372  0.291946    0.238142    0.788256    4442.9        




[J9           17280       0.000915105  0.291557    0.231331    0.799589    4994.48       




[J10          19200       0.000976202  0.290887    0.245159    0.783484    5546.31       




[J11          21120       0.000999802  0.289391    0.234618    0.792881    6097.66       




[J12          23040       0.00099975  0.288986    0.247806    0.7837      6649.06       




[J13          24960       0.000998885  0.287835    0.246792    0.762692    7200.55       




[J14          26880       0.000997403  0.286889    0.233433    0.787638    7752.14       




[J15          28800       0.000995306  0.286332    0.23708     0.791069    8303.74       




[J16          30720       0.000992598  0.283097    0.227756    0.797873    8855.15       




[J17          32640       0.00098928  0.283754    0.223371    0.804294    9406.71       




[J18          34560       0.000985357  0.28254     0.224498    0.807651    9958.42       




[J19          36480       0.000980835  0.280938    0.229405    0.810951    10510         




[J20          38400       0.000975719  0.281015    0.22186     0.801772    11061.6       




[J21          40320       0.000970015  0.280389    0.218649    0.810723    11613.2       




[J22          42240       0.00096373  0.27781     0.215959    0.810666    12164.7       




[J23          44160       0.000956872  0.277169    0.218415    0.805041    12716.4       




[J24          46080       0.000949451  0.274555    0.212282    0.816461    13268.1       




[J25          48000       0.000941474  0.275454    0.213124    0.813065    13819.5       




[J26          49920       0.000932952  0.27324     0.220571    0.804065    14370.9       




[J27          51840       0.000923896  0.275005    0.214011    0.814712    14922.5       




[J28          53760       0.000914316  0.274859    0.20962     0.817008    15474         




[J29          55680       0.000904226  0.272539    0.206295    0.822968    16025.7       




[J30          57600       0.000893637  0.27184     0.203622    0.822239    16577.3       




[J31          59520       0.000882563  0.27287     0.211742    0.817869    17130.3       




[J32          61440       0.000871017  0.271752    0.208283    0.823656    17681.8       




[J33          63360       0.000859015  0.270564    0.202186    0.828412    18233.4       




[J34          65280       0.00084657  0.269889    0.201907    0.823771    18785.2       




[J35          67200       0.000833699  0.270098    0.203706    0.827198    19336.8       




[J36          69120       0.000820417  0.268721    0.21094     0.825411    19888.3       




[J37          71040       0.000806742  0.267524    0.206883    0.821807    20439.9       




[J38          72960       0.000792689  0.269682    0.204267    0.829894    20991.6       




[J39          74880       0.000778278  0.266657    0.202986    0.823435    21543.3       




[J40          76800       0.000763525  0.265901    0.201771    0.829325    22094.8       




[J41          78720       0.000748449  0.267405    0.203611    0.832484    22646.5       




[J42          80640       0.00073307  0.264453    0.206884    0.827889    23198.5       




[J43          82560       0.000717405  0.266781    0.197203    0.833701    23750.1       




[J44          84480       0.000701476  0.264853    0.200238    0.830586    24301.9       




[J45          86400       0.000685301  0.264328    0.19884     0.832061    24853.5       




[J46          88320       0.000668901  0.263933    0.195212    0.833703    25405.3       




[J47          90240       0.000652296  0.263724    0.197837    0.837859    25957.2       




[J48          92160       0.000635507  0.262386    0.190604    0.832707    26509         




[J49          94080       0.000618556  0.2635      0.196745    0.835479    27060.6       




[J50          96000       0.000601462  0.262468    0.19798     0.831494    27612.2       




[J51          97920       0.000584249  0.263056    0.192446    0.837176    28163.7       




[J52          99840       0.000566936  0.262065    0.192781    0.834623    28715.3       




[J53          101760      0.000549546  0.262056    0.194683    0.836539    29267         




[J54          103680      0.0005321   0.261036    0.197838    0.830911    29818.7       




[J55          105600      0.00051462  0.259728    0.19335     0.837677    30370.4       




[J56          107520      0.000497129  0.259952    0.19225     0.834156    30922.3       




[J57          109440      0.000479647  0.259833    0.190858    0.843077    31474.1       




[J58          111360      0.000462197  0.258087    0.189872    0.842218    32026.2       




[J59          113280      0.0004448   0.259851    0.190156    0.843038    32577.8       




[J60          115200      0.000427479  0.259742    0.189898    0.843589    33129.5       




[J61          117120      0.000410254  0.258203    0.188693    0.844913    33681.3       




[J62          119040      0.000393147  0.258283    0.187347    0.846224    34233.1       




[J63          120960      0.000376181  0.257667    0.187913    0.844501    34784.9       




[J64          122880      0.000359375  0.255823    0.185716    0.84651     35336.4       




[J65          124800      0.000342751  0.25624     0.183423    0.845459    35888.1       




[J66          126720      0.000326329  0.25694     0.188418    0.848355    36439.9       




[J67          128640      0.000310131  0.256552    0.185986    0.847294    36991.7       




[J68          130560      0.000294176  0.255275    0.187915    0.846782    37543.3       




[J69          132480      0.000278484  0.254698    0.185446    0.848162    38095         




[J70          134400      0.000263075  0.253505    0.188401    0.847556    38646.8       




[J71          136320      0.000247968  0.25553     0.187106    0.850742    39198.5       




[J72          138240      0.000233183  0.253666    0.187195    0.849612    39750.3       




[J73          140160      0.000218736  0.254477    0.183649    0.847677    40301.9       




[J74          142080      0.000204647  0.254339    0.183732    0.846798    40853.5       




[J75          144000      0.000190933  0.251421    0.182901    0.850292    41405.1       




[J76          145920      0.000177611  0.25589     0.18362     0.850066    41958.4       




[J77          147840      0.000164698  0.253254    0.180662    0.851451    42510.1       




[J78          149760      0.00015221  0.253477    0.184151    0.851807    43062         




[J79          151680      0.000140163  0.251661    0.183047    0.85283     43613.8       




[J80          153600      0.000128571  0.250761    0.179962    0.852829    44166         




[J81          155520      0.000117449  0.250704    0.181984    0.850603    44717.8       




[J82          157440      0.000106811  0.25172     0.183086    0.851054    45269.5       




[J83          159360      9.66699e-05  0.250883    0.181066    0.852152    45821.1       




[J84          161280      8.70389e-05  0.250847    0.182015    0.851175    46372.8       




[J85          163200      7.79299e-05  0.249786    0.178739    0.853363    46924.5       




[J86          165120      6.93541e-05  0.250477    0.180948    0.852969    47476.4       




[J87          167040      6.13224e-05  0.248994    0.179986    0.852137    48028.1       




[J88          168960      5.38446e-05  0.250251    0.181565    0.851676    48579.9       




[J89          170880      4.69302e-05  0.251246    0.182598    0.853012    49131.6       




[J90          172800      4.05877e-05  0.249778    0.181128    0.853815    49683.2       




[J91          174720      3.48252e-05  0.250313    0.181449    0.852193    50234.9       




[J92          176640      2.96497e-05  0.249991    0.181495    0.85229     50786.6       




[J93          178560      2.50679e-05  0.248741    0.178884    0.853122    51338.1       




[J94          180480      2.10853e-05  0.249331    0.17798     0.853911    51889.7       




[J95          182400      1.77069e-05  0.250706    0.179929    0.852821    52441.4       




[J96          184320      1.49371e-05  0.249743    0.179524    0.85322     52993.2       




[J97          186240      1.27791e-05  0.249021    0.179687    0.852754    53544.8       




[J98          188160      1.12358e-05  0.249522    0.180921    0.852765    54096.6       




[J99          190080      1.0309e-05  0.24815     0.184766    0.852689    54651.8       




[J100         192000      1e-05       0.248098    0.17841     0.852534    55204.1       

[fold 3]
train: 48000, val: 12000
load imagenet pretrained: True
efficientnetv2_rw_s: 1792


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…



epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           1920        2.96752e-05  0.352486    0.999022    0.536006    553.125       




[J2           3840        8.71758e-05  0.310061    0.368392    0.730989    1116.58       




[J3           5760        0.000177926  0.292434    0.257249    0.761372    1681.53       




[J4           7680        0.000294704  0.287824    0.241778    0.783202    2246.6        




[J5           9600        0.000428217  0.289057    0.234818    0.793635    2811.87       




[J6           11520       0.000567841  0.290291    0.238503    0.790193    3377.03       




[J7           13440       0.000702463  0.289673    0.246402    0.792055    3941.84       




[J8           15360       0.000821372  0.289974    0.244847    0.786376    4506.8        




[J9           17280       0.000915105  0.291528    0.247044    0.784069    5071.64       




[J10          19200       0.000976202  0.292201    0.244731    0.792212    5636.01       




[J11          21120       0.000999802  0.288971    0.235829    0.794459    6201.24       




[J12          23040       0.00099975  0.28855     0.233753    0.789365    6765.73       




[J13          24960       0.000998885  0.286835    0.233763    0.791396    7330.59       




[J14          26880       0.000997403  0.285621    0.227827    0.801945    7895.64       




[J15          28800       0.000995306  0.285144    0.233048    0.790831    8460.97       




[J16          30720       0.000992598  0.284006    0.227091    0.803455    9025.75       




[J17          32640       0.00098928  0.282204    0.232301    0.805265    9590.66       




[J18          34560       0.000985357  0.281411    0.220297    0.813269    10156.1       




[J19          36480       0.000980835  0.280128    0.222355    0.811462    10721.2       




[J20          38400       0.000975719  0.278044    0.218292    0.809628    11286         




[J21          40320       0.000970015  0.280273    0.214881    0.814741    11850.8       




[J22          42240       0.00096373  0.276303    0.211339    0.824371    12416.4       




[J23          44160       0.000956872  0.276965    0.218191    0.811233    12981.2       




[J24          46080       0.000949451  0.275187    0.210164    0.824631    13546         




[J25          48000       0.000941474  0.274981    0.215053    0.813997    14110.9       




[J26          49920       0.000932952  0.275462    0.210332    0.816922    14676         




[J27          51840       0.000923896  0.276521    0.223021    0.801635    15241.1       




[J28          53760       0.000914316  0.274077    0.223134    0.819039    15806.6       




[J29          55680       0.000904226  0.272569    0.209938    0.826334    16371.7       




[J30          57600       0.000893637  0.272268    0.210845    0.824431    16937.1       




[J31          59520       0.000882563  0.272499    0.2161      0.823972    17502.1       




[J32          61440       0.000871017  0.270956    0.207422    0.833385    18067.1       




[J33          63360       0.000859015  0.271146    0.206422    0.828171    18633.7       




[J34          65280       0.00084657  0.270783    0.205756    0.829358    19201.4       




[J35          67200       0.000833699  0.270089    0.210595    0.825294    19767.5       




[J36          69120       0.000820417  0.269088    0.202079    0.833407    20334.4       




[J37          71040       0.000806742  0.26798     0.209279    0.83517     20900.8       




[J38          72960       0.000792689  0.267485    0.202705    0.83303     21467.5       




[J39          74880       0.000778278  0.266627    0.206463    0.83023     22034.3       




[J40          76800       0.000763525  0.267149    0.19532     0.835617    22600.5       




[J41          78720       0.000748449  0.267519    0.198329    0.835653    23166.2       




[J42          80640       0.00073307  0.264761    0.196518    0.838637    23732.8       




[J43          82560       0.000717405  0.266485    0.194999    0.839426    24300.2       




[J44          84480       0.000701476  0.265662    0.195008    0.835645    24867.8       




[J45          86400       0.000685301  0.265032    0.191587    0.840882    25434.5       




[J46          88320       0.000668901  0.264448    0.191434    0.842863    26001.6       




[J47          90240       0.000652296  0.264967    0.201383    0.83878     26570.7       




[J48          92160       0.000635507  0.264879    0.191013    0.847332    27137.5       




[J49          94080       0.000618556  0.262745    0.191731    0.84646     27704.1       




[J50          96000       0.000601462  0.262401    0.192117    0.85151     28270.8       




[J51          97920       0.000584249  0.260246    0.194715    0.852992    28837.3       




[J52          99840       0.000566936  0.260902    0.192875    0.845122    29404.1       




[J53          101760      0.000549546  0.261434    0.1975      0.841925    29970.9       




[J54          103680      0.0005321   0.26138     0.196655    0.848134    30537.7       




[J55          105600      0.00051462  0.260368    0.191085    0.846434    31104.5       




[J56          107520      0.000497129  0.259999    0.195011    0.843695    31670.9       




[J57          109440      0.000479647  0.260569    0.192458    0.845128    32237.8       




[J58          111360      0.000462197  0.258392    0.193717    0.846666    32804.2       




[J59          113280      0.0004448   0.260108    0.191492    0.850258    33370.7       




[J60          115200      0.000427479  0.258207    0.186572    0.85282     33937.6       




[J61          117120      0.000410254  0.257211    0.185885    0.85149     34504.2       




[J62          119040      0.000393147  0.256673    0.186777    0.849158    35070.7       




[J63          120960      0.000376181  0.257028    0.188668    0.853635    35637.3       




[J64          122880      0.000359375  0.256281    0.184947    0.853117    36203.7       




[J65          124800      0.000342751  0.258339    0.183553    0.855036    36770.5       




[J66          126720      0.000326329  0.255502    0.188379    0.85087     37337.6       




[J67          128640      0.000310131  0.256524    0.185676    0.853892    37904.2       




[J68          130560      0.000294176  0.255044    0.183823    0.854352    38470.7       




[J69          132480      0.000278484  0.255236    0.185385    0.853515    39037.3       




[J70          134400      0.000263075  0.25442     0.185265    0.855917    39604.1       




[J71          136320      0.000247968  0.255981    0.185438    0.854948    40170.2       




[J72          138240      0.000233183  0.25456     0.182945    0.856483    40735.5       




[J73          140160      0.000218736  0.255274    0.181285    0.85604     41302         




[J74          142080      0.000204647  0.253662    0.183323    0.857784    41868.7       




[J75          144000      0.000190933  0.254137    0.182883    0.858163    42435.2       




[J76          145920      0.000177611  0.253657    0.184739    0.856664    43000.8       




[J77          147840      0.000164698  0.251103    0.179464    0.857097    43566.8       




[J78          149760      0.00015221  0.253831    0.184852    0.85569     44133.4       




[J79          151680      0.000140163  0.250811    0.181165    0.855208    44699.5       




[J80          153600      0.000128571  0.252797    0.179651    0.856316    45265.5       




[J81          155520      0.000117449  0.251017    0.181037    0.857412    45831.4       




[J82          157440      0.000106811  0.251428    0.180131    0.859142    46397.4       




[J83          159360      9.66699e-05  0.250255    0.180661    0.85801     46962.7       




[J84          161280      8.70389e-05  0.249631    0.18022     0.857372    47528.6       




[J85          163200      7.79299e-05  0.250867    0.179998    0.857972    48094.6       




[J86          165120      6.93541e-05  0.250423    0.180548    0.859231    48659.4       




[J87          167040      6.13224e-05  0.247893    0.179609    0.859112    49224.8       




[J88          168960      5.38446e-05  0.251709    0.179103    0.859241    49790.8       




[J89          170880      4.69302e-05  0.25056     0.17832     0.858874    50355.8       




[J90          172800      4.05877e-05  0.248807    0.17777     0.859898    50920.8       




[J91          174720      3.48252e-05  0.249992    0.178262    0.859703    51487         




[J92          176640      2.96497e-05  0.248669    0.177039    0.859633    52053.2       




[J93          178560      2.50679e-05  0.248974    0.176406    0.859093    52619.6       




[J94          180480      2.10853e-05  0.249492    0.176902    0.859154    53185.9       




[J95          182400      1.77069e-05  0.25048     0.17726     0.8588      53751.7       




[J96          184320      1.49371e-05  0.248874    0.176978    0.859111    54317.9       




[J97          186240      1.27791e-05  0.249089    0.177892    0.859427    54884         




[J98          188160      1.12358e-05  0.248496    0.178791    0.859768    55449.1       




[J99          190080      1.0309e-05  0.249576    0.179594    0.860026    56015.2       




[J100         192000      1e-05       0.248632    0.176534    0.860121    56581.9       

[fold 4]
train: 48000, val: 12000
load imagenet pretrained: True
efficientnetv2_rw_s: 1792


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…



epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           1920        2.96752e-05  0.351984    2.39856     0.588945    564.097       




[J2           3840        8.71758e-05  0.307155    18.9532     0.739597    1129.43       




[J3           5760        0.000177926  0.293189    0.67278     0.767676    1694.74       




[J4           7680        0.000294704  0.287506    0.259241    0.786381    2259.81       




[J5           9600        0.000428217  0.286942    0.249052    0.774226    2826.2        




[J6           11520       0.000567841  0.289535    0.247581    0.781536    3392.09       




[J7           13440       0.000702463  0.29006     0.248005    0.769612    3958.48       




[J8           15360       0.000821372  0.291596    0.25073     0.777271    4524.96       




[J9           17280       0.000915105  0.292205    0.251578    0.777119    5091.16       




[J10          19200       0.000976202  0.291097    0.254007    0.781581    5653.17       




[J11          21120       0.000999802  0.28968     0.233082    0.792663    6204.55       




[J12          23040       0.00099975  0.289753    0.238434    0.793241    6756.15       




[J13          24960       0.000998885  0.287268    0.234832    0.794099    7307.91       




[J14          26880       0.000997403  0.284354    0.242515    0.798732    7859.54       




[J15          28800       0.000995306  0.284225    0.239634    0.799892    8411.16       




[J16          30720       0.000992598  0.28251     0.228356    0.802276    8962.92       




[J17          32640       0.00098928  0.280629    0.229208    0.804045    9514.64       




[J18          34560       0.000985357  0.281364    0.227828    0.804759    10066.4       




[J19          36480       0.000980835  0.280349    0.225481    0.811495    10618.1       




[J20          38400       0.000975719  0.278837    0.22045     0.819675    11169.9       




[J21          40320       0.000970015  0.278606    0.221198    0.813155    11721.5       




[J22          42240       0.00096373  0.275867    0.22433     0.814678    12273.1       




[J23          44160       0.000956872  0.274795    0.219377    0.808012    12824.4       




[J24          46080       0.000949451  0.275492    0.219918    0.81851     13375.9       




[J25          48000       0.000941474  0.275346    0.213887    0.824143    13927.4       




[J26          49920       0.000932952  0.274553    0.215245    0.825365    14479.1       




[J27          51840       0.000923896  0.274072    0.21581     0.818854    15030.8       




[J28          53760       0.000914316  0.274476    0.206856    0.823968    15582.4       




[J29          55680       0.000904226  0.27133     0.217229    0.81301     16133.9       




[J30          57600       0.000893637  0.273402    0.215406    0.826417    16685.4       




[J31          59520       0.000882563  0.272186    0.217181    0.824627    17237.2       




[J32          61440       0.000871017  0.26985     0.209791    0.831115    17788.7       




[J33          63360       0.000859015  0.269905    0.208587    0.824823    18340.3       




[J34          65280       0.00084657  0.269934    0.212673    0.826918    18891.8       




[J35          67200       0.000833699  0.268098    0.206081    0.829259    19443.3       




[J36          69120       0.000820417  0.267833    0.204526    0.833154    19994.8       




[J37          71040       0.000806742  0.266614    0.205716    0.833286    20546.4       




[J38          72960       0.000792689  0.267686    0.205067    0.833839    21098         




[J39          74880       0.000778278  0.266774    0.212061    0.831119    21649.5       




[J40          76800       0.000763525  0.265651    0.205367    0.831279    22201.2       




[J41          78720       0.000748449  0.266683    0.205635    0.836636    22752.7       




[J42          80640       0.00073307  0.265272    0.204602    0.830837    23304.4       




[J43          82560       0.000717405  0.266971    0.207294    0.832834    23855.7       




[J44          84480       0.000701476  0.264581    0.200624    0.834793    24407.2       




[J45          86400       0.000685301  0.264046    0.200304    0.836248    24958.6       




[J46          88320       0.000668901  0.263055    0.198913    0.838484    25510.1       




[J47          90240       0.000652296  0.263313    0.206047    0.835599    26061.7       




[J48          92160       0.000635507  0.263627    0.203696    0.838119    26613         




[J49          94080       0.000618556  0.262565    0.197092    0.8385      27164.5       




[J50          96000       0.000601462  0.262172    0.20446     0.839156    27716.1       




[J51          97920       0.000584249  0.261135    0.200684    0.835218    28267.6       




[J52          99840       0.000566936  0.260671    0.198088    0.843366    28818.9       




[J53          101760      0.000549546  0.259576    0.19747     0.838387    29370.3       




[J54          103680      0.0005321   0.260796    0.200673    0.842429    29921.6       




[J55          105600      0.00051462  0.259346    0.1947      0.842246    30472.9       




[J56          107520      0.000497129  0.259979    0.196059    0.842327    31024.2       




[J57          109440      0.000479647  0.259825    0.198233    0.847432    31575.4       




[J58          111360      0.000462197  0.256512    0.195816    0.844181    32126.8       




[J59          113280      0.0004448   0.25943     0.191956    0.845199    32678.3       




[J60          115200      0.000427479  0.258389    0.193664    0.844912    33229.6       




[J61          117120      0.000410254  0.25754     0.193946    0.842949    33781.1       




[J62          119040      0.000393147  0.257397    0.193642    0.840634    34332.7       




[J63          120960      0.000376181  0.257026    0.192145    0.844242    34884.3       




[J64          122880      0.000359375  0.255342    0.191842    0.846926    35435.6       




[J65          124800      0.000342751  0.253796    0.191145    0.846599    35989.6       




[J66          126720      0.000326329  0.257215    0.189999    0.847521    36541.1       




[J67          128640      0.000310131  0.254759    0.19112     0.845687    37092.6       




[J68          130560      0.000294176  0.2548      0.189437    0.848593    37643.8       




[J69          132480      0.000278484  0.256099    0.190961    0.850721    38195.4       




[J70          134400      0.000263075  0.25497     0.192218    0.84763     38746.9       




[J71          136320      0.000247968  0.255261    0.190068    0.85019     39298.2       




[J72          138240      0.000233183  0.25377     0.190733    0.847934    39849.6       




[J73          140160      0.000218736  0.252951    0.18722     0.850551    40401         




[J74          142080      0.000204647  0.254513    0.190307    0.852336    40952.2       




[J75          144000      0.000190933  0.25229     0.191748    0.848905    41503.7       




[J76          145920      0.000177611  0.252952    0.189683    0.85146     42055.2       




[J77          147840      0.000164698  0.251268    0.186104    0.850093    42606.6       




[J78          149760      0.00015221  0.251768    0.19284     0.846607    43158.1       




[J79          151680      0.000140163  0.250329    0.190404    0.850562    43709.5       




[J80          153600      0.000128571  0.250823    0.186612    0.85249     44260.7       




[J81          155520      0.000117449  0.2505      0.188644    0.852086    44812.2       




[J82          157440      0.000106811  0.251091    0.189076    0.850972    45363.7       




[J83          159360      9.66699e-05  0.249736    0.186372    0.853142    45915.3       




[J84          161280      8.70389e-05  0.250498    0.186585    0.853309    46466.7       




[J85          163200      7.79299e-05  0.249359    0.187027    0.85177     47018.1       




[J86          165120      6.93541e-05  0.250008    0.185673    0.853188    47569.9       




[J87          167040      6.13224e-05  0.247118    0.185568    0.852621    48121.5       




[J88          168960      5.38446e-05  0.248552    0.187006    0.852803    48673         




[J89          170880      4.69302e-05  0.250286    0.185341    0.853601    49224.6       




[J90          172800      4.05877e-05  0.249177    0.185984    0.853056    49776.3       




[J91          174720      3.48252e-05  0.248973    0.185577    0.853052    50327.7       




[J92          176640      2.96497e-05  0.248118    0.183621    0.855137    50879.1       




[J93          178560      2.50679e-05  0.248605    0.182863    0.855307    51430.5       




[J94          180480      2.10853e-05  0.249012    0.18394     0.85476     51982.5       




[J95          182400      1.77069e-05  0.248597    0.184739    0.854014    52534.2       




[J96          184320      1.49371e-05  0.248227    0.184071    0.855568    53085.5       




[J97          186240      1.27791e-05  0.247841    0.185707    0.854166    53637.1       




[J98          188160      1.12358e-05  0.248874    0.184768    0.854179    54188.4       




[J99          190080      1.0309e-05  0.246507    0.184867    0.855722    54739.8       




[J100         192000      1e-05       0.248636    0.183981    0.85426     55291.5       


# Inference

## Copy best models

In [24]:
best_log_list = []
for pre_eval_cfg, fold_id in zip(pre_eval_cfg_list, FOLDS):
    exp_dir_path = TMP / f"fold{fold_id}"
    log = pd.read_json(exp_dir_path / "log")
    best_log = log.iloc[[log["val/metric"].idxmax()],]
    best_epoch = best_log.epoch.values[0]
    best_log_list.append(best_log)
    
    best_model_path = exp_dir_path / f"snapshot_by_metric_epoch_{best_epoch}.pth"
    copy_to = OUTPUT / f"best_metric_model_fold{fold_id}.pth"
    shutil.copy(best_model_path, copy_to)
    
    for p in exp_dir_path.glob("*.pth"):
        p.unlink()
    
    shutil.copytree(exp_dir_path, f"./fold{fold_id}")
    
    with open(f"./fold{fold_id}/config.yml", "w") as fw:
        yaml.dump(pre_eval_cfg, fw)
    
pd.concat(best_log_list, axis=0, ignore_index=True)

FileNotFoundError: [Errno 2] No such file or directory: '/home/yamaguchi-milkcocholate/SETI/notebooks/efficientnetv2s_100epoch/tmp/fold0/snapshot_by_metric_epoch_90.pth'

## Inference OOF & Test

In [18]:
def run_inference_loop(cfg, model, loader, device):
    model.to(device)
    model.eval()
    pred_list = []
    with torch.no_grad():
        for batch in tqdm(loader):
            x = to_device(batch["image"], device)
            y = model(x)
            pred_list.append(y.sigmoid().detach().cpu().numpy())
        
    pred_arr = np.concatenate(pred_list)
    del pred_list
    return pred_arr

In [19]:
label_arr = train[CLASSES].values
oof_pred_arr = np.zeros((len(train), N_CLASSES))
score_list = []
test_pred_arr = np.zeros((N_FOLDS, len(smpl_sub), N_CLASSES))
test_path_label = {
    "paths": [DATA / f"test/{img_id[0]}/{img_id}.npy" for img_id in smpl_sub["id"].values],
    "labels": smpl_sub[CLASSES].values.astype("f")
}

for fold_id in FOLDS:
    print(f"\n[fold {fold_id}]")
    tmp_dir = Path(f"./fold{fold_id}")
    with open(tmp_dir / "config.yml", "r") as fr:
        cfg = Config(yaml.safe_load(fr), types=CONFIG_TYPES)
    device = torch.device(cfg["/globals/device"])
    val_idx = train.query("fold == @fold_id").index.values

    # # get_dataloader
    _, val_path_label = get_path_label(cfg, train)
    cfg["/dataset/val"].lazy_init(**val_path_label)
    cfg["/dataset/test"].lazy_init(**test_path_label)
    val_loader = cfg["/loader/val"]
    test_loader = cfg["/loader/test"]
    
    # # get model
    model_path = OUTPUT / f"best_metric_model_fold{fold_id}.pth"
    model = cfg["/model"]
    model.load_state_dict(torch.load(model_path, map_location=device))
    
    # # inference
    val_pred = run_inference_loop(cfg, model, val_loader, device)
    val_score = roc_auc_score(label_arr[val_idx], val_pred)
    oof_pred_arr[val_idx] = val_pred
    score_list.append([fold_id, val_score])
    
    test_pred_arr[fold_id] = run_inference_loop(cfg, model, test_loader, device)
    
    del cfg, val_idx, val_path_label
    del model, val_loader, test_loader
    torch.cuda.empty_cache()
    gc.collect()
    
    print(f"val score: {val_score:.4f}")


[fold 0]
load imagenet pretrained: True
efficientnetv2_rw_s: 1792


  0%|          | 0/480 [00:00<?, ?it/s]



  0%|          | 0/1600 [00:00<?, ?it/s]



val score: 0.8689

[fold 1]
load imagenet pretrained: True
efficientnetv2_rw_s: 1792


  0%|          | 0/480 [00:00<?, ?it/s]



  0%|          | 0/1600 [00:00<?, ?it/s]



val score: 0.8520

[fold 2]
load imagenet pretrained: True
efficientnetv2_rw_s: 1792


  0%|          | 0/480 [00:00<?, ?it/s]



  0%|          | 0/1600 [00:00<?, ?it/s]



val score: 0.8539

[fold 3]
load imagenet pretrained: True
efficientnetv2_rw_s: 1792


  0%|          | 0/480 [00:00<?, ?it/s]



  0%|          | 0/1600 [00:00<?, ?it/s]



val score: 0.8601

[fold 4]
load imagenet pretrained: True
efficientnetv2_rw_s: 1792


  0%|          | 0/480 [00:00<?, ?it/s]



  0%|          | 0/1600 [00:00<?, ?it/s]



val score: 0.8557


In [20]:
oof_score = roc_auc_score(label_arr, oof_pred_arr)
score_list.append(["oof", oof_score])
pd.DataFrame(score_list, columns=["fold", "metric"])

Unnamed: 0,fold,metric
0,0,0.868879
1,1,0.851987
2,2,0.853911
3,3,0.860121
4,4,0.855722
5,oof,0.857497


In [21]:
oof_df = train.copy()
oof_df[CLASSES] = oof_pred_arr
oof_df.to_csv(OUTPUT / "oof_prediction.csv", index=False)

## Make submission

In [22]:
sub_df = smpl_sub.copy()
sub_df[CLASSES] = test_pred_arr.mean(axis=0)
sub_df.to_csv(OUTPUT / "submission.csv", index=False)

In [23]:
sub_df.head()

Unnamed: 0,id,target
0,000bf832cae9ff1,0.054728
1,000c74cc71a1140,0.08241
2,000f5f9851161d3,0.051909
3,000f7499e95aba6,0.119952
4,00133ce6ec257f9,0.064661


# EOF