# About

I'd like to share an example for [small models](https://www.kaggle.com/c/seti-breakthrough-listen/discussion/242644).

### Experimental Settings

### model
* backbone: mobilenetv2_100 (use the pretrained model provided by [timm](https://github.com/rwightman/pytorch-image-models))
* head classifier: one linear layer
* num of input channels: **1**

### data augmentation
* implemented by [albumentations](https://albumentations.ai/docs/) **except for Mixup**
* Train
  * Resize
  * HorizontalFlip
  * VerticalFlip
  * ShiftScaleRotate
  * RandomResizedCrop
  * Mixup(alpha=1.0)
* Val, Test
  * Resize

### learning settings
* CV Strategy: Stratified KFold (K=5)
* max epochs: 75
* data:
  * input image size: 1x320x320
  * batch size: 64
* loss: [BCEWithLogitsLoss](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html#torch.nn.BCEWithLogitsLoss)
* optimizer: [AdamW](https://pytorch.org/docs/stable/optim.html#torch.optim.AdamW)
  * weight decay: 1.0e-04
* learning rate scheduler: [OneCycleLR](https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.OneCycleLR) 
  * epochs: 75
  * max_lr: 1e-3
  * pct_start: 0.111
  * anneal_strategy: cos
  * div_factor: 1.0e+2
  * final_div_factor: 1
  
### NOTE: I use only on-target ('A') observations

```python
img = np.load(path)[[0, 2, 4]]          # shape: (3, 273, 256)
img = np.vstack(img)                    # shape: (819, 256)
img = img.transpose(1, 0)               # shape: (256, 819)
```

### Submission -> 

# Prepare

## Install

## Import

In [1]:
import os
import gc
import copy
import yaml
import random
import shutil
import typing as tp
from pathlib import Path

import numpy as np
import pandas as pd

from tqdm.notebook import tqdm
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

import torch
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
from torch.cuda import amp

import timm

import albumentations as A
from albumentations.pytorch import ToTensorV2

import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.config import Config
from pytorch_pfn_extras.training import extensions as ppe_exts, triggers as ppe_triggers

In [2]:
# timm.list_models(pretrained=True)

['adv_inception_v3',
 'cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 'cait_xxs24_384',
 'cait_xxs36_224',
 'cait_xxs36_384',
 'coat_lite_mini',
 'coat_lite_small',
 'coat_lite_tiny',
 'coat_mini',
 'coat_tiny',
 'convit_base',
 'convit_small',
 'convit_tiny',
 'cspdarknet53',
 'cspresnet50',
 'cspresnext50',
 'deit_base_distilled_patch16_224',
 'deit_base_distilled_patch16_384',
 'deit_base_patch16_224',
 'deit_base_patch16_384',
 'deit_small_distilled_patch16_224',
 'deit_small_patch16_224',
 'deit_tiny_distilled_patch16_224',
 'deit_tiny_patch16_224',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenetblur121d',
 'dla34',
 'dla46_c',
 'dla46x_c',
 'dla60',
 'dla60_res2net',
 'dla60_res2next',
 'dla60x',
 'dla60x_c',
 'dla102',
 'dla102x',
 'dla102x2',
 'dla169',
 'dm_nfnet_f0',
 'dm_nfnet_f1',
 'dm_nfnet_f2',
 'dm_nfnet_f3',
 'dm_nfnet_f4',
 'dm_nfnet_f5',
 'dm_nfnet_f6',
 'dpn68',
 'dpn

In [3]:
# Directory layout, resolved relative to the current working directory.
ROOT = Path.cwd()
INPUT = ROOT / "../../"
DATA = INPUT / "seti-breakthrough-listen"
TRAIN, TEST = DATA / "train", DATA / "test"
OUTPUT = ROOT / "output"

# Scratch directory for per-fold training artifacts; created on import.
TMP = ROOT / "tmp"
TMP.mkdir(exist_ok=True)

# Experiment-wide constants.
RANDAM_SEED = 1086
CLASSES = ["target"]
N_CLASSES = len(CLASSES)
FOLDS = list(range(5))
N_FOLDS = len(FOLDS)

## Read Data, Split folds

In [4]:
# Training labels; the cells below read its "id" and "target" columns.
train = pd.read_csv(DATA / "train_labels.csv")
# Sample submission file shipped with the competition data.
smpl_sub = pd.read_csv(DATA / "sample_submission.csv")

In [5]:
# Assign every training row to one of N_FOLDS stratified validation folds.
skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RANDAM_SEED)
train["fold"] = -1  # placeholder until a validation fold is assigned below
# `split` yields positional row indices, so write through the positional
# indexer `iloc` instead of the label-based `loc`; this stays correct even
# when `train` does not carry the default RangeIndex.
fold_col = train.columns.get_loc("fold")
for fold_id, (_, val_idx) in enumerate(skf.split(train["id"], train["target"])):
    train.iloc[val_idx, fold_col] = fold_id

In [6]:
# Sanity check: number of samples ("total") and positives ("pos") per fold.
# Aggregators are given by name; passing the builtin `sum` to `agg` emits a
# FutureWarning in recent pandas versions, and "size" is the named
# equivalent of applying `len` to each group.
train.groupby("fold").agg(total=("id", "size"), pos=("target", "sum"))

Unnamed: 0_level_0,total,pos
fold,Unnamed: 1_level_1,Unnamed: 2_level_1
0,12000,1200
1,12000,1200
2,12000,1200
3,12000,1200
4,12000,1200


## Definition of Model, Dataset, Metric

### Model

In [7]:
class BasicImageModel(nn.Module):
    """Image model made of a timm backbone and a fully connected head.

    Parameters
    ----------
    base_name : str
        Name of the backbone; must be an attribute of ``timm.models``.
    dims_head : tp.List[int]
        Layer widths of the head, from input to output. If the first entry
        is ``None`` it is replaced by the backbone's ``num_features``.
        The list passed in is not modified.
    pretrained : bool
        Forwarded to ``timm.create_model``.
    in_channels : int
        Number of input channels, forwarded to timm as ``in_chans``.

    Raises
    ------
    NotImplementedError
        If ``base_name`` is not found in ``timm.models``.
    ValueError
        If ``dims_head`` has fewer than two entries.
    """

    def __init__(
        self, base_name: str, dims_head: tp.List[int],
        pretrained=False, in_channels: int=3
    ):
        """Initialize"""
        # Initialise nn.Module first so attribute assignment is always safe.
        super(BasicImageModel, self).__init__()
        self.base_name = base_name

        # # prepare backbone
        if not hasattr(timm.models, base_name):
            raise NotImplementedError(f"unknown timm model: {base_name}")

        # num_classes=0 is passed through to timm; the backbone's
        # `num_features` is used as the input width of the head.
        base_model = timm.create_model(
            base_name, num_classes=0, pretrained=pretrained, in_chans=in_channels)
        in_features = base_model.num_features
        print("load imagenet pretrained:", pretrained)

        self.backbone = base_model
        print(f"{base_name}: {in_features}")

        # # prepare head classifier
        # Work on a copy so the caller's list (e.g. a shared config value)
        # is left untouched.
        dims = list(dims_head)
        if len(dims) < 2:
            raise ValueError(
                "dims_head needs at least an input and an output width, "
                f"got {dims_head!r}")
        if dims[0] is None:
            dims[0] = in_features

        # Hidden layers: Linear -> ReLU -> Dropout for every pair of widths
        # except the last one, which is a plain Linear output layer.
        layers_list = []
        for in_dim, out_dim in zip(dims[:-2], dims[1:-1]):
            layers_list.extend([
                nn.Linear(in_dim, out_dim),
                nn.ReLU(), nn.Dropout(0.5),
            ])
        layers_list.append(nn.Linear(dims[-2], dims[-1]))
        self.head_cls = nn.Sequential(*layers_list)

    def forward(self, x):
        """Run the backbone, then the head, and return the head's output."""
        h = self.backbone(x)
        h = self.head_cls(h)
        return h

### Dataset

In [8]:
# Type aliases used by the dataset signatures below.
FilePath = tp.Union[str, Path]  # location of one cadence snippet file
Label = tp.Union[int, float, np.ndarray]  # label attached to one snippet


class SetiSimpleDataset(torch.utils.data.Dataset):
    """
    Dataset that stacks all 6 observations of a cadence along the time axis.

    Attributes
    ----------
    paths : tp.Sequence[FilePath]
        Paths of the cadence snippet files
    labels : tp.Sequence[Label]
        Labels, aligned index-by-index with ``paths``
    transform: albumentations.Compose
        Augmentation pipeline applied to every loaded image
    """

    def __init__(
        self,
        paths: tp.Sequence[FilePath],
        labels: tp.Sequence[Label],
        transform: A.Compose,
    ):
        """Store the file paths, their labels and the augmentation pipeline."""
        self.paths, self.labels, self.transform = paths, labels, transform

    def __len__(self):
        """Number of cadence snippets in the dataset."""
        return len(self.paths)

    def __getitem__(self, index: int):
        """Load, augment and return the sample at ``index`` as a dict."""
        path = self.paths[index]
        label = self.labels[index]
        raw = self._read_cadence_array(path)
        augmented = self.transform(image=raw)
        return {"image": augmented["image"], "target": label}

    def _read_cadence_array(self, path: Path):
        """Load one cadence file and turn it into a single-channel image."""
        cadence = np.load(path)  # expected shape: (6, 273, 256)
        stacked = np.vstack(cadence)  # (1638, 256)
        # Swap the axes, cast to float32 and append a trailing channel axis.
        return stacked.T.astype("f")[:, :, None]  # (256, 1638, 1)

    def lazy_init(self, paths=None, labels=None, transform=None):
        """Replace any of the stored members; ``None`` keeps the current one."""
        replacements = {"paths": paths, "labels": labels, "transform": transform}
        for attr_name, new_value in replacements.items():
            if new_value is not None:
                setattr(self, attr_name, new_value)


class SetiAObsDataset(SetiSimpleDataset):
    """Dataset variant that keeps only the on-target observations."""

    def _read_cadence_array(self, path: Path):
        """Load one cadence file and build an image from entries 0, 2 and 4."""
        on_target = np.load(path)[[0, 2, 4]]  # expected shape: (3, 273, 256)
        stacked = np.vstack(on_target)  # (819, 256)
        # Swap the axes, cast to float32 and append a trailing channel axis.
        return stacked.T.astype("f")[:, :, None]  # (256, 819, 1)

### Metric

In [9]:
# Type aliases for the metric wrappers below.
# A batch is either a tuple of tensors or a dict of tensors keyed by name.
Batch = tp.Union[tp.Tuple[torch.Tensor], tp.Dict[str, torch.Tensor]]
# A model output may additionally be a bare tensor.
ModelOut = tp.Union[tp.Tuple[torch.Tensor], tp.Dict[str, torch.Tensor], torch.Tensor]


class ROCAUC(nn.Module):
    """ROC AUC score, computed by ``roc_auc_score``.

    Parameters
    ----------
    average : str
        Forwarded unchanged as the ``average`` argument of ``roc_auc_score``.
    """

    def __init__(self, average="macro") -> None:
        """Initialize."""
        # Initialise nn.Module before setting attributes on the instance.
        super(ROCAUC, self).__init__()
        self.average = average

    def forward(self, y, t) -> float:
        """Return the ROC AUC of predictions ``y`` against labels ``t``.

        Tensors are detached and moved to CPU NumPy arrays first; any other
        input is handed to ``roc_auc_score`` as-is.
        """
        if isinstance(y, torch.Tensor):
            y = y.detach().cpu().numpy()
        if isinstance(t, torch.Tensor):
            t = t.detach().cpu().numpy()

        # roc_auc_score takes the labels first, then the scores.
        return roc_auc_score(t, y, average=self.average)


def micro_average(
    metric_func: nn.Module,
    report_name: str, prefix="val",
    pred_index: int=-1, label_index: int=-1,
    pred_key: str="logit", label_key: str="target",
) -> tp.Callable:
    """Build a metric callback reporting the sample-weighted mean of a metric.

    The returned callable is invoked once per batch. It evaluates
    ``metric_func`` on that batch, weights the value by the batch size and,
    on the last batch, reports the weighted mean under
    ``"{prefix}/{report_name}"`` before clearing its accumulators.
    """
    weighted_total = 0.
    seen = 0

    def wrapper(batch: Batch, model_output: ModelOut, is_last_batch: bool):
        """Accumulate the metric of one batch; report on the last batch."""
        nonlocal weighted_total, seen

        # Locate the labels inside the batch.
        if isinstance(batch, dict):
            target = batch[label_key]
        elif isinstance(batch, tuple):
            target = batch[label_index]
        else:
            raise NotImplementedError

        # Locate the predictions; a bare output is used directly.
        if isinstance(model_output, dict):
            pred = model_output[pred_key]
        elif isinstance(model_output, tuple):
            pred = model_output[pred_index]
        else:
            pred = model_output

        value = metric_func(pred, target).item()
        batch_len = pred.shape[0]
        weighted_total += value * batch_len
        seen += batch_len

        if not is_last_batch:
            return

        ppe.reporting.report({f"{prefix}/{report_name}": weighted_total / seen})
        # Start the next evaluation run from a clean state.
        weighted_total = 0.
        seen = 0

    return wrapper


def calc_across_all_batchs(
    metric_func: nn.Module,
    report_name: str, prefix="val",
    pred_index: int=-1, label_index: int=-1,
    pred_key: str="logit", label_key: str="target",
) -> tp.Callable:
    """
    Return a metric wrapper for metrics calculated on all data at once.

    The wrapper stores the predictions and labels of every batch and, on the
    last batch, concatenates them along axis 0, calls
    ``metric_func(pred, label)`` and reports the result under
    ``"{prefix}/{report_name}"``. The stored batches are cleared afterwards.
    """
    pred_list = []
    label_list = []

    def _as_numpy(value):
        """Convert a tensor (any device, with or without grad) to NumPy."""
        if isinstance(value, torch.Tensor):
            # A bare `.numpy()` raises for CUDA tensors and for tensors that
            # require grad, so detach and move to the CPU first.
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def wrapper(batch: Batch, model_output: ModelOut, is_last_batch: bool):
        """Wrapping metric function for evaluation"""
        if isinstance(batch, tuple):
            t = batch[label_index]
        elif isinstance(batch, dict):
            t = batch[label_key]
        else:
            raise NotImplementedError

        if isinstance(model_output, tuple):
            y = model_output[pred_index]
        elif isinstance(model_output, dict):
            y = model_output[pred_key]
        else:
            y = model_output

        pred_list.append(_as_numpy(y))
        label_list.append(_as_numpy(t))

        if is_last_batch:
            pred = np.concatenate(pred_list, axis=0)
            label = np.concatenate(label_list, axis=0)
            final_metric = metric_func(pred, label)
            ppe.reporting.report({f"{prefix}/{report_name}": final_metric})
            # # reset state (in place, the closure keeps the same lists)
            pred_list[:] = []
            label_list[:] = []

    return wrapper

# Train

## config_types for evaluating configuration

I use [pytorch-pfn-extras](https://github.com/pfnet/pytorch-pfn-extras) for training NNs. This library has useful config systems but requires some preparation.

For more details, see [docs](https://github.com/pfnet/pytorch-pfn-extras/blob/master/docs/config.md).

In [10]:
# Registry handed to `Config(..., types=CONFIG_TYPES)`: each key is a name
# that the YAML configuration may use in a `type:` field, each value is the
# callable or class that entry resolves to.
CONFIG_TYPES = {
    # # utils
    # The YAML calls these with `obj` / `method` entries, so they are kept as
    # lambdas with exactly those parameter names.
    # NOTE(review): presumably the sibling keys of `type:` are passed as
    # keyword arguments — confirm against the pytorch-pfn-extras config docs.
    "__len__": lambda obj: len(obj),
    "method_call": lambda obj, method: getattr(obj, method)(),

    # # Dataset, DataLoader
    "SetiSimpleDataset": SetiSimpleDataset,
    "SetiAObsDataset": SetiAObsDataset,
    "DataLoader": torch.utils.data.DataLoader,

    # # Data Augmentation
    "Compose": A.Compose, "OneOf": A.OneOf,
    "Resize": A.Resize,
    "HorizontalFlip": A.HorizontalFlip, "VerticalFlip": A.VerticalFlip,
    "ShiftScaleRotate": A.ShiftScaleRotate,
    "RandomResizedCrop": A.RandomResizedCrop,
    "Cutout": A.Cutout,
    "ToTensorV2": ToTensorV2,

    # # Model
    "BasicImageModel": BasicImageModel,

    # # Optimizer
    "AdamW": optim.AdamW,

    # # Scheduler
    "OneCycleLR": lr_scheduler.OneCycleLR,

    # # Loss,Metric
    "BCEWithLogitsLoss": nn.BCEWithLogitsLoss,
    "ROCAUC": ROCAUC,

    # # Metric Wrapper
    "micro_average": micro_average,
    "calc_across_all_batchs": calc_across_all_batchs,

    # # PPE Extensions
    "ExtensionsManager": ppe.training.ExtensionsManager,

    "observe_lr": ppe_exts.observe_lr,
    "LogReport": ppe_exts.LogReport,
    "PlotReport": ppe_exts.PlotReport,
    "PrintReport": ppe_exts.PrintReport,
    "PrintReportNotebook": ppe_exts.PrintReportNotebook,
    "ProgressBar": ppe_exts.ProgressBar,
    "ProgressBarNotebook": ppe_exts.ProgressBarNotebook,
    "snapshot": ppe_exts.snapshot,
    "LRScheduler": ppe_exts.LRScheduler, 

    # # PPE Triggers
    "MinValueTrigger": ppe_triggers.MinValueTrigger,
    "MaxValueTrigger": ppe_triggers.MaxValueTrigger,
    "EarlyStoppingTrigger": ppe_triggers.EarlyStoppingTrigger,
}

## configuration

In [11]:
# Un-evaluated training configuration, parsed here into plain Python
# containers. It is later wrapped in `Config(..., types=CONFIG_TYPES)`, and
# every `type:` value below names a key of CONFIG_TYPES.
# `globals.val_fold` and `globals.output_path` are left null on purpose: they
# are filled in per fold before training starts.
# NOTE(review): the "@/..." strings look like references to other nodes of
# this config (pytorch-pfn-extras config syntax) — confirm in the ppe docs.
pre_eval_cfg = yaml.safe_load(
"""
globals:
  seed: 1086
  val_fold: null  # indicate when training
  output_path: null # indicate when training
  device: cuda:1
  enable_amp: False
  max_epoch: 75

model:
  type: BasicImageModel
  dims_head: [null, 1]
  base_name: mobilenetv2_100
  pretrained: True
  in_channels: 1

dataset:
  height: 320
  width: 320
  mixup: {enabled: True, alpha: 1.0}
  train:
    type: SetiAObsDataset
    paths: null  # set by lazy_init
    labels: null  # set by lazy_init
    transform:
      type: Compose
      transforms:
        - {type: Resize, p: 1.0, height: "@/dataset/height", width: "@/dataset/width"}
        - {type: HorizontalFlip, p: 0.5}
        - {type: VerticalFlip, p: 0.5}
        - {type: ShiftScaleRotate, p: 0.5, shift_limit: 0.2, scale_limit: 0.2,
            rotate_limit: 20, border_mode: 0, value: 0, mask_value: 0}
        - {type: RandomResizedCrop, p: 1.0,
            scale: [0.9, 1.0], height: "@/dataset/height", width: "@/dataset/width"}
        - {type: ToTensorV2, always_apply: True}
  val:
    type: SetiAObsDataset
    paths: null  # set by lazy_init
    labels: null  # set by lazy_init
    transform:
      type: Compose
      transforms:
        - {type: Resize, p: 1.0, height: "@/dataset/height", width: "@/dataset/width"}
        - {type: ToTensorV2, always_apply: True}  
  test:
    type: SetiAObsDataset
    paths: null  # set by lazy_init
    labels: null  # set by lazy_init
    transform: "@/dataset/val/transform"

loader:
  train: {type: DataLoader, dataset: "@/dataset/train",
    batch_size: 64, num_workers: 4, shuffle: True, pin_memory: True, drop_last: True}
  val: {type: DataLoader, dataset: "@/dataset/val",
    batch_size: 64, num_workers: 4, shuffle: False, pin_memory: True, drop_last: False}
  test: {type: DataLoader, dataset: "@/dataset/test",
    batch_size: 64, num_workers: 4, shuffle: False, pin_memory: True, drop_last: False}

optimizer:
  type: AdamW
  params: {type: method_call, obj: "@/model", method: parameters}
  lr: 1.0e-05
  weight_decay: 1.0e-04

scheduler:
  type: OneCycleLR
  optimizer: "@/optimizer"
  epochs: "@/globals/max_epoch"
  steps_per_epoch: {type: __len__, obj: "@/loader/train"}
  max_lr: 1.0e-3
  pct_start: 0.111
  anneal_strategy: cos
  div_factor: 1.0e+2
  final_div_factor: 1

loss: {type: BCEWithLogitsLoss}

eval:
  - type: micro_average
    metric_func: {type: BCEWithLogitsLoss}
    report_name: loss
  - type: calc_across_all_batchs
    metric_func: {type: ROCAUC}
    report_name: metric

manager:
  type: ExtensionsManager
  models: "@/model"
  optimizers: "@/optimizer"
  max_epochs: "@/globals/max_epoch"
  iters_per_epoch: {type: __len__, obj: "@/loader/train"}
  out_dir: "@/globals/output_path"
  # stop_trgiger: {type: EarlyStoppingTrigger,
  #   monitor: val/metric, mode: max, patience: 5, verbose: True,
  #   check_trigger: [1, epoch], max_trigger: ["@/globals/max_epoch", epoch]}

extensions:
  # # log
  - {type: observe_lr, optimizer: "@/optimizer"}
  - {type: LogReport}
  - {type: PlotReport, y_keys: lr, x_key: epoch, filename: lr.png}
  - {type: PlotReport, y_keys: [train/loss, val/loss], x_key: epoch, filename: loss.png}
  - {type: PlotReport, y_keys: val/metric, x_key: epoch, filename: metric.png}
  - {type: PrintReport, entries: [
      epoch, iteration, lr, train/loss, val/loss, val/metric, elapsed_time]}
  - {type: ProgressBarNotebook, update_interval: 20}
  # snapshot
  - extension: {type: snapshot, target: "@/model", filename: "snapshot_by_metric_epoch_{.epoch}.pth"}
    trigger: {type: MaxValueTrigger, key: "val/metric", trigger: [1, epoch]}
  # # lr scheduler
  - {type: LRScheduler, scheduler: "@/scheduler", trigger: [1,  iteration]}
"""
)

## functions for training

In [12]:
def set_random_seed(seed: int = 42, deterministic: bool = False):
    """Seed the Python, NumPy and PyTorch RNGs and set cuDNN determinism.

    ``seed`` is also exported as ``PYTHONHASHSEED``; ``deterministic`` is
    written to ``torch.backends.cudnn.deterministic``.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,  # type: ignore
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = deterministic  # type: ignore


def to_device(
    tensors: tp.Union[tp.Tuple[torch.Tensor], tp.Dict[str, torch.Tensor]],
    device: torch.device, *args, **kwargs
):
    """Move a tensor, or a container of tensors, to ``device``.

    Parameters
    ----------
    tensors
        A tuple, list or dict of tensors, or a single tensor.
    device
        Target device.
    *args, **kwargs
        Forwarded unchanged to every ``Tensor.to`` call.

    Returns
    -------
    A container of the same kind (tuple -> tuple, list -> list,
    dict -> dict with the same keys) holding the moved tensors, or the moved
    tensor itself for a bare tensor.
    """
    if isinstance(tensors, tuple):
        # Build a real tuple: a bare generator expression could be consumed
        # only once and supports neither indexing nor len().
        return tuple(t.to(device, *args, **kwargs) for t in tensors)
    if isinstance(tensors, list):
        return [t.to(device, *args, **kwargs) for t in tensors]
    if isinstance(tensors, dict):
        return {
            k: t.to(device, *args, **kwargs) for k, t in tensors.items()}
    return tensors.to(device, *args, **kwargs)

In [13]:
def get_path_label(cfg: Config, train_all: pd.DataFrame):
    """Split ``train_all`` by fold and return paths and labels for each part.

    Rows whose ``fold`` equals ``cfg["/globals/val_fold"]`` form the
    validation part, all other rows the training part. Each part is returned
    as a dict with ``"paths"`` (one ``.npy`` path per id, under a directory
    named after the id's first character) and ``"labels"`` (the CLASSES
    columns as float32), ordered ``(train, val)``.
    """
    val_fold = cfg["/globals/val_fold"]

    def _collect(frame: pd.DataFrame):
        """Build the paths/labels dict for one subset of the rows."""
        ids = frame["id"].values
        return {
            "paths": [TRAIN / f"{img_id[0]}/{img_id}.npy" for img_id in ids],
            "labels": frame[CLASSES].values.astype("f"),
        }

    train_part = _collect(train_all[train_all["fold"] != val_fold])
    val_part = _collect(train_all[train_all["fold"] == val_fold])
    return train_part, val_part


def get_eval_func(cfg, model, device):
    """Create the per-batch inference function used for val or test."""

    def eval_func(**batch):
        """Run the model on one batch and return float32 outputs on the CPU."""
        on_device = to_device(batch, device)
        inputs = on_device["image"]
        # Autocast is switched by the `enable_amp` flag of the config.
        with amp.autocast(cfg["/globals/enable_amp"]):
            outputs = model(inputs)
        # Detached, on CPU and in float32: this is what the metrics receive.
        return outputs.detach().cpu().float()

    return eval_func


def mixup_data(use_mixup, x, t, alpha=1.0, use_cuda=True, device="cuda:0"):
    """Return mixed inputs, the two target sets, and the mixing weight.

    Parameters
    ----------
    use_mixup : bool
        When false, nothing is mixed and ``(x, t, None, None)`` is returned.
    x, t : torch.Tensor
        Batch of inputs and targets; the batch is the first dimension.
    alpha : float
        Parameter of the Beta(alpha, alpha) distribution the weight is drawn
        from. ``alpha <= 0`` fixes the weight at 1, so the output equals ``x``.
    use_cuda, device
        Kept for backward compatibility only and ignored: the permutation
        index always follows the device of ``x``.

    Returns
    -------
    (mixed_x, t_a, t_b, lam) where ``mixed_x = lam * x + (1 - lam) * x[perm]``,
    ``t_a`` is ``t`` and ``t_b`` is ``t[perm]``.
    """
    if not use_mixup:
        return x, t, None, None

    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    batch_size = x.size()[0]
    # Draw the permutation on the CPU (same RNG stream as before), then move
    # it to wherever `x` lives. The previous hard-coded `.cuda(device)` failed
    # on CPU-only machines and when `device` differed from `x.device`.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    t_a, t_b = t, t[index]
    return mixed_x, t_a, t_b, lam


def get_criterion(use_mixup, loss_func):
    """Select the loss wrapper matching the mixup setting.

    Both wrappers share the signature ``(pred, t_a, t_b, lam)`` so the
    training loop can call either one the same way.
    """

    def single_criterion(pred, t_a, t_b, lam):
        """Plain loss against ``t_a``; ``t_b`` and ``lam`` are unused."""
        return loss_func(pred, t_a)

    def mixup_criterion(pred, t_a, t_b, lam):
        """Losses against both target sets, blended with weight ``lam``."""
        loss_a = loss_func(pred, t_a)
        loss_b = loss_func(pred, t_b)
        return lam * loss_a + (1 - lam) * loss_b

    if not use_mixup:
        return single_criterion
    return mixup_criterion

In [14]:
def train_one_fold(cfg, train_all):
    """Train one cross-validation fold as described by ``cfg``.

    Parameters
    ----------
    cfg
        Configuration object. This function reads ``/globals/*``,
        ``/dataset/*``, ``/loader/*``, ``/model``, ``/optimizer``, ``/loss``,
        ``/manager``, ``/extensions`` and ``/eval`` from it.
    train_all : pd.DataFrame
        Training table, forwarded to ``get_path_label``, which splits it by
        the configured validation fold.

    Returns
    -------
    None. Results are produced by the extensions registered on the manager.
    """
    # NOTE(review): cuDNN benchmarking is switched on while
    # `deterministic=True` is requested on the next line; benchmarking may
    # weaken run-to-run reproducibility — confirm this combination is intended.
    torch.backends.cudnn.benchmark = True
    set_random_seed(cfg["/globals/seed"], deterministic=True)
    device = torch.device(cfg["/globals/device"])

    train_path_label, val_path_label = get_path_label(cfg, train_all)
    print("train: {}, val: {}".format(len(train_path_label["paths"]), len(val_path_label["paths"])))

    # The config leaves dataset paths and labels null; fill them in now.
    # NOTE(review): the statement order below is kept exactly as it was.
    # Config entries appear to be built on first access, so seeding and
    # lazy_init presumably have to come before the lookups — confirm.
    cfg["/dataset/train"].lazy_init(**train_path_label)
    cfg["/dataset/val"].lazy_init(**val_path_label)
    train_loader = cfg["/loader/train"]
    val_loader = cfg["/loader/val"]

    model = cfg["/model"]
    model.to(device)
    optimizer = cfg["/optimizer"]
    loss_func = cfg["/loss"]
    loss_func.to(device)

    # Register the configured extensions; dict entries carry keyword
    # arguments for `extend` (e.g. an extension together with its trigger).
    manager = cfg["/manager"]
    for ext in cfg["/extensions"]:
        if isinstance(ext, dict):
            manager.extend(**ext)
        else:
            manager.extend(ext)

    # Validation is run by an Evaluator extension once per epoch.
    evaluator = ppe_exts.Evaluator(
        val_loader, model, eval_func=get_eval_func(cfg, model, device),
        metrics=cfg["/eval"], progress_bar=False)
    manager.extend(evaluator, trigger=(1, "epoch"))

    use_amp = cfg["/globals/enable_amp"]
    scaler = amp.GradScaler(enabled=use_amp)
    use_mixup = cfg["/dataset/mixup/enabled"]
    mixup_alpha = cfg["/dataset/mixup/alpha"]
    # The criterion depends only on values fixed for the whole run, so build
    # it once instead of once per iteration.
    criterion = get_criterion(use_mixup, loss_func)

    while not manager.stop_trigger:
        # Re-enter training mode at the start of every pass over the loader.
        model.train()
        for batch in train_loader:
            with manager.run_iteration():
                batch = to_device(batch, device)
                x, t = batch["image"], batch["target"]
                # # for mixup
                mixed_x, t_a, t_b, lam = mixup_data(use_mixup, x, t, mixup_alpha, device=cfg["/globals/device"])

                optimizer.zero_grad()
                with amp.autocast(use_amp):
                    y = model(mixed_x)
                    loss = criterion(y, t_a, t_b, lam)
                # Backward pass and optimizer step go through the GradScaler.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                ppe.reporting.report({'train/loss': loss.item()})

## run train

In [15]:
# One independent copy of the base configuration per fold, with the
# fold-specific entries of its "globals" section filled in.
pre_eval_cfg_list = []
for fold_id in FOLDS:
    tmp_cfg = copy.deepcopy(pre_eval_cfg)
    tmp_cfg["globals"].update(
        val_fold=fold_id,
        output_path=str(TMP / f"fold{fold_id}"),
    )
    pre_eval_cfg_list.append(tmp_cfg)

In [16]:
# Train the folds one after another, building a fresh Config for each.
# NOTE(review): the loop variable rebinds the module-level `pre_eval_cfg`
# template; after this cell it refers to the last fold's configuration.
for pre_eval_cfg in pre_eval_cfg_list:
    cfg = Config(pre_eval_cfg, types=CONFIG_TYPES)
    print(f"\n[fold {cfg['/globals/val_fold']}]")
    train_one_fold(cfg, train)
    # Empty the CUDA cache on the configured device, then drop this fold's
    # Config and run the garbage collector before the next fold starts.
    with torch.cuda.device(cfg["/globals/device"]):
        torch.cuda.empty_cache()
    del cfg
    gc.collect()


[fold 0]
train: 48000, val: 12000
load imagenet pretrained: True
mobilenetv2_100: 1280


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…

epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           750         4.47487e-05  0.395594    0.327175    0.515819    163.901       




[J2           1500        0.000144287  0.327411    0.324589    0.557651    327.009       




[J3           2250        0.000294602  0.319686    0.29573     0.684397    490.241       




[J4           3000        0.000474535  0.312865    0.274343    0.724752    653.543       




[J5           3750        0.000658757  0.305989    0.265316    0.753737    816.724       




[J6           4500        0.000821334  0.301167    0.260002    0.759107    979.805       




[J7           5250        0.000939381  0.297903    0.257428    0.773312    1143.28       




[J8           6000        0.000996281  0.297038    0.241525    0.784921    1306.59       




[J9           6750        0.00099975  0.296714    0.250565    0.787971    1469.88       




[J10          7500        0.000998459  0.291872    0.250764    0.780725    1633.25       




[J11          8250        0.000996073  0.293153    0.233896    0.801242    1796.43       




[J12          9000        0.000992598  0.290534    0.248379    0.789301    1959.63       




[J13          9750        0.000988039  0.289405    0.239186    0.788981    2122.83       




[J14          10500       0.000982409  0.290826    0.233727    0.815674    2286.07       




[J15          11250       0.000975719  0.289336    0.23357     0.80255     2449.49       




[J16          12000       0.000967984  0.287999    0.233015    0.813204    2612.78       




[J17          12750       0.000959221  0.286854    0.234813    0.822651    2775.92       




[J18          13500       0.000949451  0.28524     0.229626    0.813788    2939.16       




[J19          14250       0.000938693  0.28579     0.224295    0.818219    3102.3        




[J20          15000       0.000926973  0.282279    0.227861    0.810677    3265.43       




[J21          15750       0.000914316  0.283545    0.232081    0.817223    3428.65       




[J22          16500       0.000900751  0.283044    0.224836    0.817678    3591.97       




[J23          17250       0.000886307  0.281674    0.227348    0.818798    3755.26       




[J24          18000       0.000871017  0.280573    0.215096    0.823084    3918.56       




[J25          18750       0.000854915  0.280965    0.228209    0.823872    4081.76       




[J26          19500       0.000838036  0.280448    0.229022    0.819768    4244.88       




[J27          20250       0.000820417  0.281011    0.22044     0.828787    4408.02       




[J28          21000       0.000802098  0.279956    0.225397    0.831024    4571.22       




[J29          21750       0.00078312  0.277742    0.221234    0.824057    4734.51       




[J30          22500       0.000763525  0.276956    0.21469     0.826555    4897.85       




[J31          23250       0.000743356  0.276417    0.219458    0.829336    5061.09       




[J32          24000       0.000722657  0.276447    0.21276     0.828078    5224.43       




[J33          24750       0.000701476  0.275229    0.216722    0.833121    5387.64       




[J34          25500       0.000679858  0.27468     0.214093    0.831864    5550.81       




[J35          26250       0.000657852  0.273863    0.223226    0.829051    5714.05       




[J36          27000       0.000635507  0.27257     0.21232     0.828415    5877.54       




[J37          27750       0.000612873  0.274458    0.220377    0.833952    6040.71       




[J38          28500       0.000589999  0.272166    0.21397     0.837877    6204.04       




[J39          29250       0.000566936  0.273262    0.209533    0.837629    6367.29       




[J40          30000       0.000543736  0.272156    0.214952    0.839402    6530.53       




[J41          30750       0.000520449  0.270969    0.210679    0.832517    6693.77       




[J42          31500       0.000497129  0.272428    0.212274    0.842533    6856.93       




[J43          32250       0.000473826  0.27067     0.214909    0.841115    7020.25       




[J44          33000       0.000450592  0.269553    0.211464    0.838183    7183.32       




[J45          33750       0.000427479  0.271387    0.214065    0.841283    7346.44       




[J46          34500       0.000404537  0.268992    0.205492    0.838966    7509.61       




[J47          35250       0.000381819  0.270607    0.202335    0.841006    7673.02       




[J48          36000       0.000359375  0.268439    0.207274    0.837924    7836.15       




[J49          36750       0.000337253  0.268302    0.20751     0.841229    7999.57       




[J50          37500       0.000315504  0.266941    0.203532    0.840714    8163.16       




[J51          38250       0.000294176  0.267442    0.20024     0.840945    8326.19       




[J52          39000       0.000273315  0.268714    0.205067    0.842946    8489.33       




[J53          39750       0.000252969  0.266991    0.206364    0.84549     8652.64       




[J54          40500       0.000233183  0.265198    0.197238    0.846307    8815.76       




[J55          41250       0.000213999  0.267267    0.206335    0.847158    8978.86       




[J56          42000       0.000195462  0.265523    0.203201    0.844666    9142.14       




[J57          42750       0.000177611  0.263824    0.20417     0.846929    9305.33       




[J58          43500       0.000160488  0.26506     0.207144    0.843963    9468.67       




[J59          44250       0.000144129  0.264563    0.203641    0.845284    9631.81       




[J60          45000       0.000128571  0.261859    0.201428    0.844171    9794.98       




[J61          45750       0.000113848  0.26338     0.204829    0.844811    9958.3        




[J62          46500       9.99941e-05  0.262806    0.199652    0.84622     10121.4       




[J63          47250       8.70389e-05  0.265277    0.204553    0.846783    10284.6       




[J64          48000       7.50115e-05  0.260902    0.199306    0.847341    10447.8       




[J65          48750       6.39385e-05  0.262861    0.200197    0.847131    10611.1       




[J66          49500       5.38446e-05  0.263511    0.196788    0.84731     10775.1       




[J67          50250       4.47521e-05  0.261622    0.197265    0.847597    10938.5       




[J68          51000       3.66811e-05  0.263816    0.199244    0.847723    11101.8       




[J69          51750       2.96497e-05  0.261979    0.198408    0.847493    11265.1       




[J70          52500       2.36735e-05  0.263198    0.201598    0.847471    11428.3       




[J71          53250       1.87656e-05  0.262038    0.203362    0.84766     11591.5       




[J72          54000       1.49371e-05  0.26117     0.198632    0.847476    11754.9       




[J73          54750       1.21963e-05  0.262976    0.200278    0.848155    11918.3       




[J74          55500       1.05494e-05  0.261081    0.197901    0.848252    12081.7       




[J75          56250       1e-05       0.25979     0.199091    0.848729    12245         

[fold 1]
train: 48000, val: 12000
load imagenet pretrained: True
mobilenetv2_100: 1280


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…



epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           750         4.47487e-05  0.396433    0.327648    0.521964    162.27        




[J2           1500        0.000144287  0.32725     0.325862    0.550087    325.661       




[J3           2250        0.000294602  0.319064    0.300241    0.667631    489.064       




[J4           3000        0.000474535  0.308852    0.276343    0.718552    652.589       




[J5           3750        0.000658757  0.303955    0.264782    0.73843     816.028       




[J6           4500        0.000821334  0.30146     0.262993    0.744302    979.326       




[J7           5250        0.000939381  0.297337    0.261833    0.746532    1142.74       




[J8           6000        0.000996281  0.295722    0.253628    0.767793    1306.11       




[J9           6750        0.00099975  0.29534     0.246073    0.770703    1469.6        




[J10          7500        0.000998459  0.293527    0.256962    0.777545    1633.02       




[J11          8250        0.000996073  0.290599    0.25229     0.768891    1796.45       




[J12          9000        0.000992598  0.288722    0.245682    0.771381    1960.36       




[J13          9750        0.000988039  0.289547    0.237586    0.782613    2123.81       




[J14          10500       0.000982409  0.287508    0.238681    0.779822    2287.22       




[J15          11250       0.000975719  0.287231    0.234765    0.790251    2450.55       




[J16          12000       0.000967984  0.285186    0.240803    0.789088    2613.88       




[J17          12750       0.000959221  0.285028    0.239572    0.787416    2777.22       




[J18          13500       0.000949451  0.282994    0.235383    0.791306    2940.59       




[J19          14250       0.000938693  0.284005    0.243833    0.787469    3103.95       




[J20          15000       0.000926973  0.28169     0.2358      0.793101    3267.34       




[J21          15750       0.000914316  0.28253     0.237567    0.797796    3430.72       




[J22          16500       0.000900751  0.280649    0.231179    0.793963    3594.13       




[J23          17250       0.000886307  0.281947    0.227512    0.800311    3757.42       




[J24          18000       0.000871017  0.278665    0.237524    0.797887    3920.8        




[J25          18750       0.000854915  0.280248    0.231541    0.801192    4084.2        




[J26          19500       0.000838036  0.277       0.232633    0.793433    4247.69       




[J27          20250       0.000820417  0.27878     0.241435    0.794351    4411.14       




[J28          21000       0.000802098  0.275867    0.236123    0.797642    4574.44       




[J29          21750       0.00078312  0.277044    0.232572    0.808003    4737.87       




[J30          22500       0.000763525  0.276099    0.23357     0.802391    4901.26       




[J31          23250       0.000743356  0.273181    0.232554    0.80507     5064.74       




[J32          24000       0.000722657  0.275906    0.233339    0.803545    5228.04       




[J33          24750       0.000701476  0.273988    0.227167    0.807494    5391.47       




[J34          25500       0.000679858  0.273712    0.21922     0.817455    5554.84       




[J35          26250       0.000657852  0.273344    0.235148    0.813352    5718.43       




[J36          27000       0.000635507  0.270809    0.222048    0.818709    5881.86       




[J37          27750       0.000612873  0.271917    0.227586    0.810938    6046.72       




[J38          28500       0.000589999  0.271642    0.226473    0.81225     6210.11       




[J39          29250       0.000566936  0.270329    0.224859    0.814614    6373.48       




[J40          30000       0.000543736  0.26944     0.219682    0.818351    6536.85       




[J41          30750       0.000520449  0.270702    0.220813    0.82196     6700.39       




[J42          31500       0.000497129  0.271454    0.224048    0.819359    6863.77       




[J43          32250       0.000473826  0.269401    0.222697    0.826206    7027.29       




[J44          33000       0.000450592  0.267555    0.222852    0.819214    7190.75       




[J45          33750       0.000427479  0.270495    0.222087    0.818787    7354.03       




[J46          34500       0.000404537  0.266913    0.221331    0.819509    7517.32       




[J47          35250       0.000381819  0.268687    0.215796    0.82001     7680.69       




[J48          36000       0.000359375  0.266507    0.217937    0.818449    7844.03       




[J49          36750       0.000337253  0.266932    0.220472    0.821418    8007.51       




[J50          37500       0.000315504  0.264716    0.211847    0.823725    8171.07       




[J51          38250       0.000294176  0.264543    0.21603     0.82226     8334.67       




[J52          39000       0.000273315  0.266216    0.219644    0.826845    8498.21       




[J53          39750       0.000252969  0.266434    0.214936    0.824369    8661.7        




[J54          40500       0.000233183  0.263523    0.211893    0.827351    8825.15       




[J55          41250       0.000213999  0.265509    0.219539    0.828561    8988.6        




[J56          42000       0.000195462  0.263098    0.212686    0.827324    9151.98       




[J57          42750       0.000177611  0.261425    0.21503     0.827715    9315.53       




[J58          43500       0.000160488  0.262345    0.212127    0.827971    9478.92       




[J59          44250       0.000144129  0.263837    0.215076    0.830954    9642.29       




[J60          45000       0.000128571  0.262061    0.212623    0.828828    9805.8        




[J61          45750       0.000113848  0.261971    0.215532    0.828884    9969.22       




[J62          46500       9.99941e-05  0.260924    0.212156    0.827279    10132.8       




[J63          47250       8.70389e-05  0.263298    0.216283    0.827896    10296.3       




[J64          48000       7.50115e-05  0.261331    0.209041    0.831458    10459.9       




[J65          48750       6.39385e-05  0.261472    0.214967    0.828719    10623.5       




[J66          49500       5.38446e-05  0.26223     0.212071    0.82805     10787.1       




[J67          50250       4.47521e-05  0.260335    0.211721    0.830263    10951.1       




[J68          51000       3.66811e-05  0.262164    0.213913    0.830841    11116.2       




[J69          51750       2.96497e-05  0.262662    0.21117     0.830502    11280.3       




[J70          52500       2.36735e-05  0.262561    0.211536    0.830973    11444.3       




[J71          53250       1.87656e-05  0.261779    0.214244    0.831005    11608.6       




[J72          54000       1.49371e-05  0.261001    0.209997    0.8321      11772.8       




[J73          54750       1.21963e-05  0.258968    0.213377    0.83153     11936.8       




[J74          55500       1.05494e-05  0.260165    0.212626    0.831122    12101         




[J75          56250       1e-05       0.259451    0.214515    0.831568    12265         

[fold 2]
train: 48000, val: 12000
load imagenet pretrained: True
mobilenetv2_100: 1280


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…



epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           750         4.47487e-05  0.396845    0.328448    0.50906     163.086       




[J2           1500        0.000144287  0.32707     0.330132    0.553064    327.054       




[J3           2250        0.000294602  0.32073     0.306127    0.649442    491.148       




[J4           3000        0.000474535  0.311929    0.272313    0.709621    655.244       




[J5           3750        0.000658757  0.305883    0.265625    0.748941    819.311       




[J6           4500        0.000821334  0.301145    0.252896    0.766315    983.354       




[J7           5250        0.000939381  0.29892     0.245305    0.778273    1147.43       




[J8           6000        0.000996281  0.29579     0.24869     0.776455    1311.38       




[J9           6750        0.00099975  0.2955      0.246271    0.784561    1475.37       




[J10          7500        0.000998459  0.29221     0.245475    0.77394     1639.44       




[J11          8250        0.000996073  0.291898    0.244779    0.798328    1803.45       




[J12          9000        0.000992598  0.289748    0.24768     0.788432    1967.38       




[J13          9750        0.000988039  0.288362    0.231326    0.792345    2131.45       




[J14          10500       0.000982409  0.289412    0.239474    0.79108     2295.3        




[J15          11250       0.000975719  0.289405    0.232449    0.804906    2459.01       




[J16          12000       0.000967984  0.287108    0.239448    0.788352    2622.69       




[J17          12750       0.000959221  0.286833    0.24249     0.803106    2786.39       




[J18          13500       0.000949451  0.285785    0.226635    0.805168    2950.15       




[J19          14250       0.000938693  0.284839    0.235695    0.812237    3113.78       




[J20          15000       0.000926973  0.283906    0.233453    0.804067    3277.51       




[J21          15750       0.000914316  0.283266    0.229737    0.808507    3441.25       




[J22          16500       0.000900751  0.282095    0.234075    0.801711    3604.88       




[J23          17250       0.000886307  0.283797    0.232975    0.804196    3768.45       




[J24          18000       0.000871017  0.279728    0.222985    0.811685    3932.16       




[J25          18750       0.000854915  0.280983    0.232317    0.806997    4095.81       




[J26          19500       0.000838036  0.27937     0.224133    0.811167    4259.4        




[J27          20250       0.000820417  0.279185    0.229802    0.815926    4422.96       




[J28          21000       0.000802098  0.278037    0.22791     0.814805    4587.59       




[J29          21750       0.00078312  0.278193    0.223419    0.809458    4751.08       




[J30          22500       0.000763525  0.277068    0.217546    0.822801    4914.55       




[J31          23250       0.000743356  0.276356    0.214249    0.815603    5078.19       




[J32          24000       0.000722657  0.275321    0.220409    0.816326    5241.84       




[J33          24750       0.000701476  0.275223    0.21956     0.813298    5405.36       




[J34          25500       0.000679858  0.273803    0.215777    0.818654    5569.02       




[J35          26250       0.000657852  0.273389    0.218817    0.815816    5732.93       




[J36          27000       0.000635507  0.272619    0.214067    0.821882    5896.54       




[J37          27750       0.000612873  0.274208    0.218174    0.818136    6060.15       




[J38          28500       0.000589999  0.271393    0.213325    0.818487    6223.89       




[J39          29250       0.000566936  0.272063    0.216099    0.815288    6387.54       




[J40          30000       0.000543736  0.271095    0.217474    0.822771    6551.16       




[J41          30750       0.000520449  0.271011    0.219119    0.823843    6714.87       




[J42          31500       0.000497129  0.270686    0.21495     0.823506    6878.67       




[J43          32250       0.000473826  0.26925     0.217481    0.823668    7042.24       




[J44          33000       0.000450592  0.268539    0.211316    0.823882    7206          




[J45          33750       0.000427479  0.271925    0.212816    0.827767    7369.55       




[J46          34500       0.000404537  0.266952    0.21135     0.825534    7533.16       




[J47          35250       0.000381819  0.268672    0.208349    0.823381    7696.66       




[J48          36000       0.000359375  0.265314    0.210642    0.826331    7860.34       




[J49          36750       0.000337253  0.268555    0.210045    0.828154    8024.06       




[J50          37500       0.000315504  0.265421    0.210504    0.827421    8187.52       




[J51          38250       0.000294176  0.265487    0.207006    0.828946    8351.15       




[J52          39000       0.000273315  0.267657    0.204844    0.835431    8514.73       




[J53          39750       0.000252969  0.265385    0.208983    0.830775    8678.32       




[J54          40500       0.000233183  0.265764    0.202999    0.831892    8841.78       




[J55          41250       0.000213999  0.266002    0.216289    0.833475    9005.37       




[J56          42000       0.000195462  0.264767    0.206393    0.835322    9169.01       




[J57          42750       0.000177611  0.262705    0.206198    0.833927    9332.56       




[J58          43500       0.000160488  0.262616    0.207774    0.835419    9496.12       




[J59          44250       0.000144129  0.262973    0.211943    0.835072    9659.81       




[J60          45000       0.000128571  0.261356    0.203887    0.834843    9823.41       




[J61          45750       0.000113848  0.262649    0.212009    0.832707    9987.02       




[J62          46500       9.99941e-05  0.262571    0.204547    0.833344    10150.6       




[J63          47250       8.70389e-05  0.264562    0.211741    0.833581    10314.1       




[J64          48000       7.50115e-05  0.2621      0.201323    0.835154    10477.5       




[J65          48750       6.39385e-05  0.261545    0.206678    0.834378    10641.2       




[J66          49500       5.38446e-05  0.262636    0.203506    0.836038    10804.8       




[J67          50250       4.47521e-05  0.260732    0.203684    0.835282    10968.3       




[J68          51000       3.66811e-05  0.262257    0.205969    0.834724    11131.9       




[J69          51750       2.96497e-05  0.261484    0.203484    0.835038    11295.5       




[J70          52500       2.36735e-05  0.262138    0.204615    0.835474    11459.2       




[J71          53250       1.87656e-05  0.26183     0.20583     0.835544    11622.9       




[J72          54000       1.49371e-05  0.262878    0.204276    0.835164    11787.7       




[J73          54750       1.21963e-05  0.26048     0.206251    0.835386    11951.3       




[J74          55500       1.05494e-05  0.2625      0.201593    0.835723    12114.8       




[J75          56250       1e-05       0.260688    0.205347    0.835427    12278.4       

[fold 3]
train: 48000, val: 12000
load imagenet pretrained: True
mobilenetv2_100: 1280


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…



epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           750         4.47487e-05  0.39855     0.327532    0.520066    162.556       




[J2           1500        0.000144287  0.327707    0.322954    0.566708    326.222       




[J3           2250        0.000294602  0.321216    0.291351    0.696447    489.885       




[J4           3000        0.000474535  0.310451    0.270911    0.721735    653.594       




[J5           3750        0.000658757  0.305216    0.262262    0.755839    817.299       




[J6           4500        0.000821334  0.301598    0.259243    0.755643    981.08        




[J7           5250        0.000939381  0.299182    0.250463    0.771719    1144.63       




[J8           6000        0.000996281  0.296417    0.250767    0.779372    1308.25       




[J9           6750        0.00099975  0.296886    0.250766    0.781253    1471.95       




[J10          7500        0.000998459  0.293873    0.244054    0.781165    1635.57       




[J11          8250        0.000996073  0.292944    0.244231    0.789711    1799.2        




[J12          9000        0.000992598  0.290305    0.242322    0.804501    1962.75       




[J13          9750        0.000988039  0.288994    0.251256    0.798824    2126.37       




[J14          10500       0.000982409  0.288932    0.240071    0.796779    2289.93       




[J15          11250       0.000975719  0.2869      0.237484    0.804333    2453.53       




[J16          12000       0.000967984  0.287101    0.234265    0.806771    2617.2        




[J17          12750       0.000959221  0.286894    0.248442    0.810411    2780.83       




[J18          13500       0.000949451  0.284971    0.228665    0.803972    2944.6        




[J19          14250       0.000938693  0.286191    0.226749    0.806386    3108.16       




[J20          15000       0.000926973  0.284179    0.229834    0.814407    3271.57       




[J21          15750       0.000914316  0.284041    0.23548     0.802987    3435.17       




[J22          16500       0.000900751  0.283892    0.220198    0.816106    3598.72       




[J23          17250       0.000886307  0.281561    0.22832     0.811678    3762.25       




[J24          18000       0.000871017  0.281354    0.226933    0.813615    3925.88       




[J25          18750       0.000854915  0.281946    0.221301    0.812815    4089.49       




[J26          19500       0.000838036  0.279592    0.224943    0.820564    4253.13       




[J27          20250       0.000820417  0.278563    0.228017    0.817544    4416.7        




[J28          21000       0.000802098  0.278264    0.227617    0.824218    4580.4        




[J29          21750       0.00078312  0.278523    0.227635    0.817737    4744.15       




[J30          22500       0.000763525  0.277668    0.222435    0.822541    4907.76       




[J31          23250       0.000743356  0.274722    0.225334    0.821127    5071.29       




[J32          24000       0.000722657  0.275782    0.223844    0.826252    5234.91       




[J33          24750       0.000701476  0.275064    0.22273     0.822598    5398.49       




[J34          25500       0.000679858  0.273336    0.212931    0.8222      5561.99       




[J35          26250       0.000657852  0.273453    0.228208    0.82038     5725.63       




[J36          27000       0.000635507  0.271297    0.211226    0.819911    5889.17       




[J37          27750       0.000612873  0.272759    0.225603    0.825816    6052.78       




[J38          28500       0.000589999  0.272674    0.220965    0.828045    6216.44       




[J39          29250       0.000566936  0.272092    0.226719    0.828715    6380.13       




[J40          30000       0.000543736  0.271453    0.224474    0.82704     6543.78       




[J41          30750       0.000520449  0.269889    0.21721     0.826428    6707.27       




[J42          31500       0.000497129  0.26993     0.213295    0.829629    6870.99       




[J43          32250       0.000473826  0.271101    0.211376    0.834631    7034.59       




[J44          33000       0.000450592  0.269956    0.220808    0.827812    7199.6        




[J45          33750       0.000427479  0.270626    0.220083    0.830142    7363.04       




[J46          34500       0.000404537  0.26766     0.217342    0.830946    7526.72       




[J47          35250       0.000381819  0.269295    0.211316    0.83275     7690.34       




[J48          36000       0.000359375  0.266274    0.209871    0.832466    7853.86       




[J49          36750       0.000337253  0.268253    0.213372    0.832801    8017.32       




[J50          37500       0.000315504  0.26714     0.211174    0.842177    8180.72       




[J51          38250       0.000294176  0.26632     0.207477    0.838139    8344.24       




[J52          39000       0.000273315  0.268039    0.209157    0.837922    8507.87       




[J53          39750       0.000252969  0.265872    0.212794    0.839374    8671.47       




[J54          40500       0.000233183  0.265547    0.207334    0.836776    8834.98       




[J55          41250       0.000213999  0.267086    0.206688    0.839479    8998.44       




[J56          42000       0.000195462  0.263935    0.212633    0.836451    9162.11       




[J57          42750       0.000177611  0.261307    0.209174    0.840535    9325.56       




[J58          43500       0.000160488  0.262853    0.207304    0.841176    9489.04       




[J59          44250       0.000144129  0.262936    0.208132    0.842595    9652.63       




[J60          45000       0.000128571  0.261268    0.203079    0.842732    9816.19       




[J61          45750       0.000113848  0.263849    0.211098    0.84316     9979.73       




[J62          46500       9.99941e-05  0.261579    0.200798    0.843981    10143.3       




[J63          47250       8.70389e-05  0.263298    0.211768    0.84192     10307         




[J64          48000       7.50115e-05  0.260767    0.201962    0.843672    10470.7       




[J65          48750       6.39385e-05  0.262008    0.211168    0.842972    10634.2       




[J66          49500       5.38446e-05  0.263566    0.20362     0.842392    10797.9       




[J67          50250       4.47521e-05  0.261414    0.204099    0.8436      10961.6       




[J68          51000       3.66811e-05  0.262157    0.205623    0.844637    11125.2       




[J69          51750       2.96497e-05  0.260724    0.201378    0.845131    11288.7       




[J70          52500       2.36735e-05  0.26315     0.20474     0.845053    11452.5       




[J71          53250       1.87656e-05  0.261737    0.206416    0.843773    11616.1       




[J72          54000       1.49371e-05  0.261236    0.204685    0.844472    11779.7       




[J73          54750       1.21963e-05  0.2612      0.206114    0.845335    11943.3       




[J74          55500       1.05494e-05  0.262161    0.203867    0.844779    12106.9       




[J75          56250       1e-05       0.259778    0.205293    0.844928    12270.6       





[fold 4]
train: 48000, val: 12000
load imagenet pretrained: True
mobilenetv2_100: 1280


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…



epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           750         4.47487e-05  0.39507     0.329933    0.505125    162.509       




[J2           1500        0.000144287  0.327266    0.324592    0.568911    326.166       




[J3           2250        0.000294602  0.319565    0.289088    0.680282    489.845       




[J4           3000        0.000474535  0.309012    0.277113    0.674103    653.612       




[J5           3750        0.000658757  0.305256    0.271718    0.728659    817.341       




[J6           4500        0.000821334  0.302721    0.27242     0.736762    981.12        




[J7           5250        0.000939381  0.299133    0.261263    0.745474    1144.75       




[J8           6000        0.000996281  0.297588    0.269867    0.782933    1308.34       




[J9           6750        0.00099975  0.295166    0.251621    0.77768     1472.07       




[J10          7500        0.000998459  0.291408    0.243832    0.789609    1635.74       




[J11          8250        0.000996073  0.291201    0.240906    0.792164    1799.31       




[J12          9000        0.000992598  0.289553    0.243497    0.794052    1962.95       




[J13          9750        0.000988039  0.289645    0.240667    0.788529    2126.62       




[J14          10500       0.000982409  0.288731    0.244332    0.7883      2290.18       




[J15          11250       0.000975719  0.287241    0.237849    0.800332    2453.89       




[J16          12000       0.000967984  0.286515    0.241237    0.796645    2617.54       




[J17          12750       0.000959221  0.286165    0.237578    0.796752    2781.2        




[J18          13500       0.000949451  0.284113    0.238241    0.803873    2944.91       




[J19          14250       0.000938693  0.283971    0.23835     0.805736    3108.58       




[J20          15000       0.000926973  0.283692    0.2373      0.801027    3272.31       




[J21          15750       0.000914316  0.282207    0.235838    0.812682    3436.03       




[J22          16500       0.000900751  0.281883    0.238183    0.799862    3599.86       




[J23          17250       0.000886307  0.2811      0.232605    0.817475    3763.48       




[J24          18000       0.000871017  0.280297    0.227729    0.817158    3927.26       




[J25          18750       0.000854915  0.280564    0.237626    0.812411    4090.93       




[J26          19500       0.000838036  0.277993    0.243281    0.807523    4254.53       




[J27          20250       0.000820417  0.279253    0.232379    0.821569    4418.19       




[J28          21000       0.000802098  0.278071    0.234578    0.819169    4581.83       




[J29          21750       0.00078312  0.27609     0.228731    0.817333    4745.39       




[J30          22500       0.000763525  0.277365    0.226316    0.821147    4908.92       




[J31          23250       0.000743356  0.274592    0.227254    0.815619    5072.52       




[J32          24000       0.000722657  0.275475    0.228332    0.823239    5236.17       




[J33          24750       0.000701476  0.27478     0.222714    0.825196    5399.84       




[J34          25500       0.000679858  0.272538    0.224538    0.82007     5563.77       




[J35          26250       0.000657852  0.274023    0.22869     0.823971    5727.36       




[J36          27000       0.000635507  0.271017    0.216265    0.827664    5890.98       




[J37          27750       0.000612873  0.272978    0.225054    0.822688    6054.66       




[J38          28500       0.000589999  0.270706    0.226688    0.822243    6218.26       




[J39          29250       0.000566936  0.269924    0.220144    0.827349    6381.73       




[J40          30000       0.000543736  0.271024    0.221082    0.824797    6545.38       




[J41          30750       0.000520449  0.270175    0.221269    0.82625     6709.03       




[J42          31500       0.000497129  0.270666    0.216642    0.829399    6872.5        




[J43          32250       0.000473826  0.270074    0.21833     0.827219    7036.15       




[J44          33000       0.000450592  0.268486    0.224318    0.822541    7199.74       




[J45          33750       0.000427479  0.269717    0.219091    0.829741    7363.29       




[J46          34500       0.000404537  0.268387    0.217693    0.830064    7526.86       




[J47          35250       0.000381819  0.268862    0.213902    0.829432    7690.51       




[J48          36000       0.000359375  0.266837    0.214064    0.832298    7854.08       




[J49          36750       0.000337253  0.268487    0.215453    0.82804     8017.7        




[J50          37500       0.000315504  0.265445    0.215302    0.830131    8181.25       




[J51          38250       0.000294176  0.26529     0.210302    0.835504    8344.79       




[J52          39000       0.000273315  0.267886    0.215116    0.831315    8508.41       




[J53          39750       0.000252969  0.266444    0.212891    0.8339      8672.11       




[J54          40500       0.000233183  0.265562    0.209174    0.835014    8835.66       




[J55          41250       0.000213999  0.266421    0.217417    0.834841    8999.27       




[J56          42000       0.000195462  0.263332    0.214576    0.836428    9162.85       




[J57          42750       0.000177611  0.262864    0.210108    0.836781    9326.5        




[J58          43500       0.000160488  0.262711    0.21003     0.835353    9490.15       




[J59          44250       0.000144129  0.263994    0.210521    0.835704    9653.71       




[J60          45000       0.000128571  0.261256    0.209983    0.835365    9821.84       




[J61          45750       0.000113848  0.264732    0.216767    0.83406     9985.42       




[J62          46500       9.99941e-05  0.260461    0.205385    0.835548    10148.9       




[J63          47250       8.70389e-05  0.263964    0.212416    0.833944    10312.6       




[J64          48000       7.50115e-05  0.262415    0.208765    0.836252    10476.2       




[J65          48750       6.39385e-05  0.261111    0.210282    0.836785    10639.8       




[J66          49500       5.38446e-05  0.262523    0.206726    0.836526    10803.4       




[J67          50250       4.47521e-05  0.260862    0.209262    0.837096    10967.1       




[J68          51000       3.66811e-05  0.262785    0.208489    0.836909    11130.8       




[J69          51750       2.96497e-05  0.260472    0.206833    0.837417    11294.6       




[J70          52500       2.36735e-05  0.261461    0.209929    0.837587    11458.3       




[J71          53250       1.87656e-05  0.262686    0.208248    0.837625    11622         




[J72          54000       1.49371e-05  0.260881    0.206466    0.837223    11785.9       




[J73          54750       1.21963e-05  0.261014    0.208854    0.836984    11949.5       




[J74          55500       1.05494e-05  0.261902    0.208303    0.836958    12113.2       




[J75          56250       1e-05       0.260222    0.208696    0.835664    12276.9       


# Inference

## Copy best models

In [17]:
# For every fold: pick the log row with the highest validation metric, copy the
# matching snapshot into OUTPUT, remove all snapshots from the experiment
# directory, then mirror that directory (plus the pre-eval config) to ./fold{id}.
best_log_list = []
for pre_eval_cfg, fold_id in zip(pre_eval_cfg_list, FOLDS):
    exp_dir_path = TMP / f"fold{fold_id}"
    log = pd.read_json(exp_dir_path / "log")

    # idxmax() yields an index label that is then used positionally via iloc.
    # NOTE(review): presumably the log has a default 0..n-1 index so both agree - confirm.
    best_row = log["val/metric"].idxmax()
    best_log = log.iloc[[best_row]]
    best_epoch = best_log["epoch"].values[0]
    best_log_list.append(best_log)

    # Keep only the snapshot of the best epoch, under a fold-specific name.
    best_model_path = exp_dir_path / f"snapshot_by_metric_epoch_{best_epoch}.pth"
    copy_to = OUTPUT / f"best_metric_model_fold{fold_id}.pth"
    shutil.copy(best_model_path, copy_to)

    # Delete every .pth file so the directory copy below carries no weights.
    for snapshot_path in exp_dir_path.glob("*.pth"):
        snapshot_path.unlink()

    fold_dir = f"./fold{fold_id}"
    shutil.copytree(exp_dir_path, fold_dir)

    # Store the config next to the copied logs; the inference cell reads it back.
    with open(f"{fold_dir}/config.yml", "w") as cfg_file:
        yaml.dump(pre_eval_cfg, cfg_file)

# Last expression of the cell: one row per fold with its best-epoch log entry.
pd.concat(best_log_list, axis=0, ignore_index=True)

Unnamed: 0,train/loss,lr,val/loss,val/metric,epoch,iteration,elapsed_time
0,0.25979,1e-05,0.199091,0.848729,75,56250,12244.994898
1,0.261001,1.5e-05,0.209997,0.8321,72,54000,11772.8482
2,0.262636,5.4e-05,0.203506,0.836038,66,49500,10804.761265
3,0.2612,1.2e-05,0.206114,0.845335,73,54750,11943.315481
4,0.262686,1.9e-05,0.208248,0.837625,71,53250,11622.013902


## Inference OOF & Test

In [18]:
def run_inference_loop(cfg, model, loader, device):
    """Run ``model`` over ``loader`` and return sigmoid outputs as one array.

    Args:
        cfg: Not used in the body; kept so the call signature is unchanged.
        model: Module that is moved to ``device`` and put into eval mode.
        loader: Iterable of batches; ``batch["image"]`` is fed to the model.
        device: Device the model and each input batch are sent to.

    Returns:
        NumPy array obtained by concatenating the per-batch sigmoid outputs
        along the first axis, in loader order.
    """
    model.to(device)
    model.eval()

    # Gradients are not needed for prediction.
    with torch.no_grad():
        batch_probs = [
            torch.sigmoid(model(to_device(batch["image"], device)))
            .detach()
            .cpu()
            .numpy()
            for batch in tqdm(loader)
        ]

    return np.concatenate(batch_probs)

In [19]:
# Run every fold's best model on its validation split (out-of-fold, "OOF")
# and on the test set, collecting predictions and per-fold ROC AUC.
label_arr = train[CLASSES].values
# Zero-initialised buffers: one OOF row per training sample, and one slab of
# test predictions per fold.
oof_pred_arr = np.zeros((len(train), N_CLASSES))
score_list = []  # filled with [fold_id, val_score] pairs
test_pred_arr = np.zeros((N_FOLDS, len(smpl_sub), N_CLASSES))
# Test files live at test/<first character of id>/<id>.npy; the labels are the
# sample-submission values cast to float32.
test_path_label = {
    "paths": [DATA / f"test/{img_id[0]}/{img_id}.npy" for img_id in smpl_sub["id"].values],
    "labels": smpl_sub[CLASSES].values.astype("f")
}

for fold_id in FOLDS:
    print(f"\n[fold {fold_id}]")
    tmp_dir = Path(f"./fold{fold_id}")
    # Rebuild the fold's Config from the config.yml found in ./fold{fold_id}.
    with open(tmp_dir / "config.yml", "r") as fr:
        cfg = Config(yaml.safe_load(fr), types=CONFIG_TYPES)
    device = torch.device(cfg["/globals/device"])
    # NOTE(review): these index labels are used below as positions into NumPy
    # arrays, which presumably relies on `train` having a default RangeIndex —
    # confirm.
    val_idx = train.query("fold == @fold_id").index.values

    # Data loaders: the val/test datasets receive their paths and labels via
    # lazy_init before the loaders are read from cfg.
    # presumably get_path_label picks the validation fold from cfg — verify.
    _, val_path_label = get_path_label(cfg, train)
    cfg["/dataset/val"].lazy_init(**val_path_label)
    cfg["/dataset/test"].lazy_init(**test_path_label)
    val_loader = cfg["/loader/val"]
    test_loader = cfg["/loader/test"]
    
    # Model: take it from cfg and load this fold's best-metric weights.
    model_path = OUTPUT / f"best_metric_model_fold{fold_id}.pth"
    model = cfg["/model"]
    model.load_state_dict(torch.load(model_path, map_location=device))
    
    # Inference: score the validation split, then store its predictions in the
    # OOF buffer at this fold's rows.
    val_pred = run_inference_loop(cfg, model, val_loader, device)
    val_score = roc_auc_score(label_arr[val_idx], val_pred)
    oof_pred_arr[val_idx] = val_pred
    score_list.append([fold_id, val_score])
    
    # fold_id doubles as the row index of test_pred_arr, so fold ids must lie
    # in range(N_FOLDS).
    test_pred_arr[fold_id] = run_inference_loop(cfg, model, test_loader, device)
    
    # Drop per-fold objects and release cached GPU memory before the next fold.
    del cfg, val_idx, val_path_label
    del model, val_loader, test_loader
    torch.cuda.empty_cache()
    gc.collect()
    
    print(f"val score: {val_score:.4f}")


[fold 0]
load imagenet pretrained: True
mobilenetv2_100: 1280


  0%|          | 0/188 [00:00<?, ?it/s]



  0%|          | 0/625 [00:00<?, ?it/s]



val score: 0.8487

[fold 1]
load imagenet pretrained: True
mobilenetv2_100: 1280


  0%|          | 0/188 [00:00<?, ?it/s]



  0%|          | 0/625 [00:00<?, ?it/s]



val score: 0.8321

[fold 2]
load imagenet pretrained: True
mobilenetv2_100: 1280


  0%|          | 0/188 [00:00<?, ?it/s]



  0%|          | 0/625 [00:00<?, ?it/s]



val score: 0.8360

[fold 3]
load imagenet pretrained: True
mobilenetv2_100: 1280


  0%|          | 0/188 [00:00<?, ?it/s]



  0%|          | 0/625 [00:00<?, ?it/s]



val score: 0.8453

[fold 4]
load imagenet pretrained: True
mobilenetv2_100: 1280


  0%|          | 0/188 [00:00<?, ?it/s]



  0%|          | 0/625 [00:00<?, ?it/s]



val score: 0.8376


In [20]:
# ROC AUC over all out-of-fold predictions at once, listed as an extra "oof"
# row after the per-fold scores.
oof_score = roc_auc_score(y_true=label_arr, y_score=oof_pred_arr)
score_list.append(["oof", oof_score])
pd.DataFrame(data=score_list, columns=["fold", "metric"])

Unnamed: 0,fold,metric
0,0,0.848729
1,1,0.8321
2,2,0.836038
3,3,0.845335
4,4,0.837625
5,oof,0.839739


In [21]:
# Save the OOF predictions: a copy of `train` whose CLASSES columns are
# overwritten with the predicted values.
oof_csv_path = OUTPUT / "oof_prediction.csv"
oof_df = train.copy()
oof_df[CLASSES] = oof_pred_arr
oof_df.to_csv(oof_csv_path, index=False)

## Make submission

In [22]:
# The submission is the mean of the test predictions over axis 0 of
# test_pred_arr, written into a copy of the sample submission.
fold_mean_pred = np.mean(test_pred_arr, axis=0)
sub_df = smpl_sub.copy()
sub_df[CLASSES] = fold_mean_pred
sub_df.to_csv(OUTPUT / "submission.csv", index=False)

In [23]:
# Preview the first rows of the submission frame.
sub_df.head(n=5)

Unnamed: 0,id,target
0,000bf832cae9ff1,0.071741
1,000c74cc71a1140,0.084769
2,000f5f9851161d3,0.071275
3,000f7499e95aba6,0.126458
4,00133ce6ec257f9,0.10808


# EOF