# About

I'd like to share an example for [small models](https://www.kaggle.com/c/seti-breakthrough-listen/discussion/242644).

### Experimental Settings

### model
* backbone: efficientnetv2_rw_s (use the pretrained model provided by [timm](https://github.com/rwightman/pytorch-image-models))
* head classifier: one linear layer
* num of input channels: **1**

### data augmentation
* implemented by [albumentations](https://albumentations.ai/docs/) **except for Mixup**
* Train
  * Resize
  * HorizontalFlip
  * VerticalFlip
  * ShiftScaleRotate
  * RandomResizedCrop
  * Mixup(alpha=1.0)
* Val, Test
  * Resize

### learning settings
* CV Strategy: Stratified KFold (K=5)
* max epochs: 100
* data:
  * input image size: 1x320x320
  * batch size: 25
* loss: [BCEWithLogitsLoss](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html#torch.nn.BCEWithLogitsLoss)
* optimizer: [AdamW](https://pytorch.org/docs/stable/optim.html#torch.optim.AdamW)
  * weight decay: 1.0e-04
* learning rate scheduler: [OneCycleLR](https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.OneCycleLR) 
  * epochs: 100
  * max_lr: 1e-3
  * pct_start: 0.111
  * anneal_strategy: cos
  * div_factor: 1.0e+2
  * final_div_factor: 1
  
### NOTE: I use only on-target ('A') observations

```python
img = np.load(path)[[0, 2, 4]]          # shape: (3, 273, 256)
img = np.vstack(img)                    # shape: (819, 256)
img = img.transpose(1, 0)               # shape: (256, 819)
```

### Submission -> 

# Prepare

## Install

## Import

In [1]:
import os
import gc
import copy
import yaml
import random
import shutil
import typing as tp
from pathlib import Path

import numpy as np
import pandas as pd

from tqdm.notebook import tqdm
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

import torch
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
from torch.cuda import amp

import timm

import albumentations as A
from albumentations.pytorch import ToTensorV2

import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.config import Config
from pytorch_pfn_extras.training import extensions as ppe_exts, triggers as ppe_triggers

In [2]:
# timm.list_models(pretrained=True)

In [3]:
# Directory layout; every path is derived from the current working directory.
ROOT = Path.cwd()
INPUT = ROOT / "../../"
OUTPUT = ROOT / "output"
DATA = INPUT / "seti-breakthrough-listen"
TRAIN = DATA / "train"
TEST = DATA / "test"

# Scratch directory, created eagerly; the per-fold output paths are placed under it.
TMP = ROOT / "tmp"
TMP.mkdir(exist_ok=True)

# Seed for the fold split ("RANDAM" is a typo of "RANDOM"; kept because the name is used elsewhere).
RANDAM_SEED = 1086
CLASSES = ["target"]  # label column(s) selected from the training table
N_CLASSES = len(CLASSES)
FOLDS = [0, 1, 2, 3, 4]  # fold ids that get trained
N_FOLDS = len(FOLDS)

## Read Data, Split folds

In [4]:
# Training labels (the "id" and "target" columns are used below) and the sample submission file.
train = pd.read_csv(DATA / "train_labels.csv")
smpl_sub = pd.read_csv(DATA / "sample_submission.csv")

In [5]:
# Give every row a validation fold id, stratified on "target" and shuffled with a fixed seed.
skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RANDAM_SEED)
train["fold"] = -1  # placeholder; overwritten for the rows of each validation split below
for fold_id, (_, val_idx) in enumerate(skf.split(train["id"], train["target"])):
    # val_idx is positional; using it with .loc assumes `train` has the default RangeIndex
    train.loc[val_idx, "fold"] = fold_id

In [6]:
# Sanity check: number of samples and number of positives in each fold.
train.groupby("fold").agg(total=("id", len), pos=("target", sum))

fold,total,pos
0,12000,1200
1,12000,1200
2,12000,1200
3,12000,1200
4,12000,1200


## Definition of Model, Dataset, Metric

### Model

In [7]:
class BasicImageModel(nn.Module):
    """Image classifier made of a timm backbone and a fully connected head.

    The backbone is created with ``num_classes=0`` so it outputs pooled features;
    the head maps those features to ``dims_head[-1]`` outputs.
    """

    def __init__(
        self, base_name: str, dims_head: tp.List[int],
        pretrained=False, in_channels: int=3
    ):
        """Initialize

        Parameters
        ----------
        base_name : str
            Name of a model available as an attribute of ``timm.models``.
        dims_head : tp.List[int]
            Layer widths of the head, input first. ``dims_head[0]`` may be ``None``,
            in which case the backbone's ``num_features`` is used. Every layer but the
            last is followed by ReLU and Dropout(0.5). The list is not modified.
        pretrained : bool
            Forwarded to ``timm.create_model``.
        in_channels : int
            Number of input channels, forwarded to ``timm.create_model`` as ``in_chans``.

        Raises
        ------
        NotImplementedError
            If ``base_name`` is not an attribute of ``timm.models``.
        ValueError
            If ``dims_head`` has fewer than two entries.
        """
        super(BasicImageModel, self).__init__()
        self.base_name = base_name

        # # prepare backbone
        if not hasattr(timm.models, base_name):
            raise NotImplementedError

        base_model = timm.create_model(
            base_name, num_classes=0, pretrained=pretrained, in_chans=in_channels)
        in_features = base_model.num_features
        print("load imagenet pretrained:", pretrained)

        self.backbone = base_model
        print(f"{base_name}: {in_features}")

        # # prepare head classifier
        # work on a copy: the caller's list (e.g. a shared config entry) must stay untouched
        dims = list(dims_head)
        if len(dims) < 2:
            raise ValueError(
                f"dims_head needs at least an input and an output size, got {dims_head!r}")
        if dims[0] is None:
            dims[0] = in_features

        layers_list = []
        # hidden layers: every consecutive pair except the last one
        for in_dim, out_dim in zip(dims[:-2], dims[1:-1]):
            layers_list.extend([
                nn.Linear(in_dim, out_dim),
                nn.ReLU(), nn.Dropout(0.5),
            ])
        # output layer: no activation, raw scores are returned
        layers_list.append(nn.Linear(dims[-2], dims[-1]))
        self.head_cls = nn.Sequential(*layers_list)

    def forward(self, x):
        """Forward: backbone features followed by the head."""
        h = self.backbone(x)
        h = self.head_cls(h)
        return h

### Dataset

In [8]:
# Anything accepted as the location of a cadence snippet file.
FilePath = tp.Union[str, Path]
# A single label: a scalar or an array.
Label = tp.Union[int, float, np.ndarray]


class SetiSimpleDataset(torch.utils.data.Dataset):
    """
    Dataset that turns a whole cadence file into one single-channel image.

    All snippets stored in a file are stacked along the time axis.

    Attributes
    ----------
    paths : tp.Sequence[FilePath]
        Location of each cadence snippet file
    labels : tp.Sequence[Label]
        Label of each cadence snippet file, aligned with ``paths``
    transform: albumentations.Compose
        Augmentation pipeline, called as ``transform(image=...)``
    """

    def __init__(
        self,
        paths: tp.Sequence[FilePath],
        labels: tp.Sequence[Label],
        transform: A.Compose,
    ):
        """Store the members without touching any file."""
        self.paths, self.labels, self.transform = paths, labels, transform

    def __len__(self):
        """Number of cadence files."""
        return len(self.paths)

    def __getitem__(self, index: int):
        """Load, transform and return the sample at ``index`` as an image/target dict."""
        path = self.paths[index]
        label = self.labels[index]
        raw = self._read_cadence_array(path)
        return {"image": self.transform(image=raw)["image"], "target": label}

    def _read_cadence_array(self, path: Path):
        """Load a cadence file and lay it out as a float32 image with a trailing channel axis."""
        cadence = np.load(path)  # expected shape: (6, 273, 256)
        stacked = np.vstack(cadence)  # (1638, 256)
        image = np.transpose(stacked, (1, 0))  # (256, 1638)
        return image.astype(np.float32)[..., None]  # (256, 1638, 1)

    def lazy_init(self, paths=None, labels=None, transform=None):
        """Replace any member for which a non-None value is given."""
        updates = (("paths", paths), ("labels", labels), ("transform", transform))
        for attr_name, new_value in updates:
            if new_value is not None:
                setattr(self, attr_name, new_value)


class SetiAObsDataset(SetiSimpleDataset):
    """Variant that keeps only snippets 0, 2 and 4 (the on-target observations)."""

    def _read_cadence_array(self, path: Path):
        """Load a cadence file, keep the on-target snippets and stack them into one image."""
        on_target = np.load(path)[[0, 2, 4]]  # expected shape: (3, 273, 256)
        stacked = np.vstack(on_target)  # (819, 256)
        image = np.transpose(stacked, (1, 0))  # (256, 819)
        return image.astype(np.float32)[..., None]  # (256, 819, 1)

### Metric

In [9]:
# What the metric wrappers accept as a batch: a tuple of tensors or a dict of tensors.
Batch = tp.Union[tp.Tuple[torch.Tensor], tp.Dict[str, torch.Tensor]]
# What the metric wrappers accept as model output: a tuple, a dict or a bare tensor.
ModelOut = tp.Union[tp.Tuple[torch.Tensor], tp.Dict[str, torch.Tensor], torch.Tensor]


class ROCAUC(nn.Module):
    """Module wrapper around ``roc_auc_score``."""

    def __init__(self, average="macro") -> None:
        """Remember the averaging mode passed on to ``roc_auc_score``."""
        super().__init__()
        self.average = average

    def forward(self, y, t) -> float:
        """Return the ROC AUC of scores ``y`` against labels ``t``.

        Tensors are detached and moved to CPU as NumPy arrays first;
        any other input is passed through unchanged.
        """
        scores, labels = (
            value.detach().cpu().numpy() if isinstance(value, torch.Tensor) else value
            for value in (y, t)
        )
        return roc_auc_score(labels, scores, average=self.average)


def micro_average(
    metric_func: nn.Module,
    report_name: str, prefix="val",
    pred_index: int=-1, label_index: int=-1,
    pred_key: str="logit", label_key: str="target",
) -> tp.Callable:
    """
    Build a metric callback that reports a sample-weighted mean of a per-batch metric.

    The per-batch value of ``metric_func`` is weighted by the batch size; on the
    last batch the mean is reported as ``{prefix}/{report_name}`` and the totals
    are cleared.
    """
    # running totals, mutated in place by the closure
    totals = {"weighted_sum": 0., "count": 0}

    def wrapper(batch: Batch, model_output: ModelOut, is_last_batch: bool):
        """Accumulate the metric of one batch; report on the last one."""
        # label: taken by key from a dict, by position from a tuple
        if isinstance(batch, dict):
            target = batch[label_key]
        elif isinstance(batch, tuple):
            target = batch[label_index]
        else:
            raise NotImplementedError

        # prediction: a bare value is used as it is
        if isinstance(model_output, dict):
            pred = model_output[pred_key]
        elif isinstance(model_output, tuple):
            pred = model_output[pred_index]
        else:
            pred = model_output

        batch_value = metric_func(pred, target).item()
        totals["weighted_sum"] += batch_value * pred.shape[0]
        totals["count"] += pred.shape[0]

        if not is_last_batch:
            return

        mean_value = totals["weighted_sum"] / totals["count"]
        ppe.reporting.report({f"{prefix}/{report_name}": mean_value})
        # # start from scratch for the next evaluation run
        totals["weighted_sum"] = 0.
        totals["count"] = 0

    return wrapper


def calc_across_all_batchs(
    metric_func: nn.Module,
    report_name: str, prefix="val",
    pred_index: int=-1, label_index: int=-1,
    pred_key: str="logit", label_key: str="target",
) -> tp.Callable:
    """
    Return Metric Wrapper for Metrics calculated on all data

    Predictions and labels of every batch are stored; on the last batch they are
    concatenated along axis 0, ``metric_func(pred, label)`` is evaluated once on the
    full arrays, the result is reported as ``{prefix}/{report_name}`` and the
    buffers are cleared.

    Parameters
    ----------
    metric_func : callable
        Called as ``metric_func(pred, label)`` with NumPy arrays.
    report_name, prefix : str
        Parts of the reported key.
    pred_index, label_index : int
        Position of prediction / label when model output / batch is a tuple.
    pred_key, label_key : str
        Key of prediction / label when model output / batch is a dict.
    """
    pred_list = []
    label_list = []

    def to_numpy(value):
        """Convert a tensor (any device, with or without grad) or array-like to ndarray."""
        if isinstance(value, torch.Tensor):
            # a bare .numpy() raises for CUDA tensors and for tensors requiring grad
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def wrapper(batch: "Batch", model_output: "ModelOut", is_last_batch: bool):
        """Wrapping metric function for evaluation"""
        if isinstance(batch, tuple):
            t = batch[label_index]
        elif isinstance(batch, dict):
            t = batch[label_key]
        else:
            raise NotImplementedError

        if isinstance(model_output, tuple):
            y = model_output[pred_index]
        elif isinstance(model_output, dict):
            y = model_output[pred_key]
        else:
            y = model_output

        pred_list.append(to_numpy(y))
        label_list.append(to_numpy(t))

        if is_last_batch:
            pred = np.concatenate(pred_list, axis=0)
            label = np.concatenate(label_list, axis=0)
            final_metric = metric_func(pred, label)
            ppe.reporting.report({f"{prefix}/{report_name}": final_metric})
            # # reset state
            pred_list[:] = []
            label_list[:] = []

    return wrapper

# Train

## config_types for evaluating configuration

I use [pytorch-pfn-extras](https://github.com/pfnet/pytorch-pfn-extras) for training NNs. This library has useful config systems but requires some preparation.

For more details, see [docs](https://github.com/pfnet/pytorch-pfn-extras/blob/master/docs/config.md).

In [10]:
# Registry of names usable as `type:` in the YAML configuration; it is handed to
# `Config(..., types=CONFIG_TYPES)` when a fold's configuration is evaluated.
CONFIG_TYPES = {
    # # utils
    "__len__": lambda obj: len(obj),
    # call a zero-argument method by name (used to obtain the model's parameters)
    "method_call": lambda obj, method: getattr(obj, method)(),

    # # Dataset, DataLoader
    "SetiSimpleDataset": SetiSimpleDataset,
    "SetiAObsDataset": SetiAObsDataset,
    "DataLoader": torch.utils.data.DataLoader,

    # # Data Augmentation
    "Compose": A.Compose, "OneOf": A.OneOf,
    "Resize": A.Resize,
    "HorizontalFlip": A.HorizontalFlip, "VerticalFlip": A.VerticalFlip,
    "ShiftScaleRotate": A.ShiftScaleRotate,
    "RandomResizedCrop": A.RandomResizedCrop,
    # NOTE(review): Cutout is not referenced by the configuration in this file and may be
    # absent from recent albumentations releases - confirm before upgrading.
    "Cutout": A.Cutout,
    "ToTensorV2": ToTensorV2,

    # # Model
    "BasicImageModel": BasicImageModel,

    # # Optimizer
    "AdamW": optim.AdamW,

    # # Scheduler
    "OneCycleLR": lr_scheduler.OneCycleLR,

    # # Loss,Metric
    "BCEWithLogitsLoss": nn.BCEWithLogitsLoss,
    "ROCAUC": ROCAUC,

    # # Metric Wrapper
    "micro_average": micro_average,
    "calc_across_all_batchs": calc_across_all_batchs,

    # # PPE Extensions
    "ExtensionsManager": ppe.training.ExtensionsManager,

    "observe_lr": ppe_exts.observe_lr,
    "LogReport": ppe_exts.LogReport,
    "PlotReport": ppe_exts.PlotReport,
    "PrintReport": ppe_exts.PrintReport,
    "PrintReportNotebook": ppe_exts.PrintReportNotebook,
    "ProgressBar": ppe_exts.ProgressBar,
    "ProgressBarNotebook": ppe_exts.ProgressBarNotebook,
    "snapshot": ppe_exts.snapshot,
    "LRScheduler": ppe_exts.LRScheduler, 

    # # PPE Triggers
    "MinValueTrigger": ppe_triggers.MinValueTrigger,
    "MaxValueTrigger": ppe_triggers.MaxValueTrigger,
    "EarlyStoppingTrigger": ppe_triggers.EarlyStoppingTrigger,
}

## configuration

In [11]:
# Un-evaluated training configuration: a plain dict parsed from the YAML text below.
# `globals.val_fold` and `globals.output_path` are null here and are filled in per fold
# before the dict is wrapped in `Config`; "@/..." strings point at other entries of this
# configuration and `type:` names are looked up in CONFIG_TYPES.
pre_eval_cfg = yaml.safe_load(
"""
globals:
  seed: 1086
  val_fold: null  # indicate when training
  output_path: null # indicate when training
  device: cuda:0
  enable_amp: False
  max_epoch: 100

model:
  type: BasicImageModel
  dims_head: [null, 1]
  base_name: efficientnetv2_rw_s
  pretrained: True
  in_channels: 1

dataset:
  height: 320
  width: 320
  mixup: {enabled: True, alpha: 1.0}
  train:
    type: SetiAObsDataset
    paths: null  # set by lazy_init
    labels: null  # set by lazy_init
    transform:
      type: Compose
      transforms:
        - {type: Resize, p: 1.0, height: "@/dataset/height", width: "@/dataset/width"}
        - {type: HorizontalFlip, p: 0.5}
        - {type: VerticalFlip, p: 0.5}
        - {type: ShiftScaleRotate, p: 0.5, shift_limit: 0.2, scale_limit: 0.2,
            rotate_limit: 20, border_mode: 0, value: 0, mask_value: 0}
        - {type: RandomResizedCrop, p: 1.0,
            scale: [0.9, 1.0], height: "@/dataset/height", width: "@/dataset/width"}
        - {type: ToTensorV2, always_apply: True}
  val:
    type: SetiAObsDataset
    paths: null  # set by lazy_init
    labels: null  # set by lazy_init
    transform:
      type: Compose
      transforms:
        - {type: Resize, p: 1.0, height: "@/dataset/height", width: "@/dataset/width"}
        - {type: ToTensorV2, always_apply: True}  
  test:
    type: SetiAObsDataset
    paths: null  # set by lazy_init
    labels: null  # set by lazy_init
    transform: "@/dataset/val/transform"

loader:
  train: {type: DataLoader, dataset: "@/dataset/train",
    batch_size: 25, num_workers: 4, shuffle: True, pin_memory: True, drop_last: True}
  val: {type: DataLoader, dataset: "@/dataset/val",
    batch_size: 25, num_workers: 4, shuffle: False, pin_memory: True, drop_last: False}
  test: {type: DataLoader, dataset: "@/dataset/test",
    batch_size: 25, num_workers: 4, shuffle: False, pin_memory: True, drop_last: False}

optimizer:
  type: AdamW
  params: {type: method_call, obj: "@/model", method: parameters}
  lr: 1.0e-05
  weight_decay: 1.0e-04

scheduler:
  type: OneCycleLR
  optimizer: "@/optimizer"
  epochs: "@/globals/max_epoch"
  steps_per_epoch: {type: __len__, obj: "@/loader/train"}
  max_lr: 1.0e-3
  pct_start: 0.111
  anneal_strategy: cos
  div_factor: 1.0e+2
  final_div_factor: 1

loss: {type: BCEWithLogitsLoss}

eval:
  - type: micro_average
    metric_func: {type: BCEWithLogitsLoss}
    report_name: loss
  - type: calc_across_all_batchs
    metric_func: {type: ROCAUC}
    report_name: metric

manager:
  type: ExtensionsManager
  models: "@/model"
  optimizers: "@/optimizer"
  max_epochs: "@/globals/max_epoch"
  iters_per_epoch: {type: __len__, obj: "@/loader/train"}
  out_dir: "@/globals/output_path"
  #  stop_trigger: {type: EarlyStoppingTrigger,
  #    monitor: val/metric, mode: max, patience: 5, verbose: True,
  #    check_trigger: [1, epoch], max_trigger: ["@/globals/max_epoch", epoch]}

extensions:
  # # log
  - {type: observe_lr, optimizer: "@/optimizer"}
  - {type: LogReport}
  - {type: PlotReport, y_keys: lr, x_key: epoch, filename: lr.png}
  - {type: PlotReport, y_keys: [train/loss, val/loss], x_key: epoch, filename: loss.png}
  - {type: PlotReport, y_keys: val/metric, x_key: epoch, filename: metric.png}
  - {type: PrintReport, entries: [
      epoch, iteration, lr, train/loss, val/loss, val/metric, elapsed_time]}
  - {type: ProgressBarNotebook, update_interval: 20}
  # snapshot
  - extension: {type: snapshot, target: "@/model", filename: "snapshot_by_metric_epoch_{.epoch}.pth"}
    trigger: {type: MaxValueTrigger, key: "val/metric", trigger: [1, epoch]}
  # # lr scheduler
  - {type: LRScheduler, scheduler: "@/scheduler", trigger: [1,  iteration]}
"""
)

## functions for training

In [12]:
def set_random_seed(seed: int = 42, deterministic: bool = False):
    """Seed the Python, NumPy and PyTorch (CPU and current CUDA device) generators.

    Also exports the seed as ``PYTHONHASHSEED`` and sets the cuDNN
    ``deterministic`` flag to the given value.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,  # type: ignore
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = deterministic  # type: ignore


def to_device(
    tensors: tp.Union[tp.Tuple[torch.Tensor], tp.Dict[str, torch.Tensor]],
    device: torch.device, *args, **kwargs
):
    """Move a tensor, or every tensor of a tuple / list / dict, to ``device``.

    Extra positional and keyword arguments are forwarded to ``Tensor.to``.

    Returns
    -------
    The same kind of container that was passed in (tuple, list or dict with the
    same keys), or a single tensor for a bare tensor input.
    """
    if isinstance(tensors, tuple):
        # materialise the result: a generator could be neither indexed nor iterated twice
        return tuple(t.to(device, *args, **kwargs) for t in tensors)
    if isinstance(tensors, list):
        return [t.to(device, *args, **kwargs) for t in tensors]
    if isinstance(tensors, dict):
        return {
            k: t.to(device, *args, **kwargs) for k, t in tensors.items()}
    return tensors.to(device, *args, **kwargs)

In [13]:
def get_path_label(cfg: Config, train_all: pd.DataFrame):
    """Split the table by the configured validation fold and build dataset inputs.

    Returns two dicts (train, val), each with the file ``paths`` and the float32
    ``labels`` of its rows.
    """
    val_fold = cfg["/globals/val_fold"]

    def build(part: pd.DataFrame):
        # files live in a sub-directory named after the first character of the id
        file_paths = [
            TRAIN / f"{img_id[0]}/{img_id}.npy" for img_id in part["id"].values]
        return {"paths": file_paths, "labels": part[CLASSES].values.astype("f")}

    train_part = train_all[train_all["fold"] != val_fold]
    val_part = train_all[train_all["fold"] == val_fold]
    return build(train_part), build(val_part)


def get_eval_func(cfg, model, device):
    """Build the per-batch inference function used for val or test."""

    def eval_func(**batch):
        """Run the model on one batch and return float32 CPU predictions for the metrics."""
        inputs = to_device(batch, device)["image"]
        with amp.autocast(cfg["/globals/enable_amp"]):
            logits = model(inputs)
        # detached, on CPU and in full precision: this is what the metrics receive
        return logits.detach().cpu().to(torch.float32)

    return eval_func


def mixup_data(use_mixup, x, t, alpha=1.0, use_cuda=True, device="cuda:0"):
    """Returns mixed inputs, pairs of targets, and lambda

    Parameters
    ----------
    use_mixup : bool
        When false, ``(x, t, None, None)`` is returned unchanged.
    x, t : torch.Tensor
        Inputs and targets, batch on the first axis.
    alpha : float
        Beta distribution parameter; ``lam`` is fixed to 1 when ``alpha <= 0``.
    use_cuda, device
        Kept for backward compatibility and ignored: the permutation index is
        always placed on the device of ``x``.

    Returns
    -------
    mixed_x, t_a, t_b, lam
        ``lam * x + (1 - lam) * x[perm]``, the original targets, the permuted
        targets and the mixing coefficient.
    """
    if not use_mixup:
        return x, t, None, None

    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # Draw the permutation on CPU (same random stream as before) and move it to
    # wherever x lives, so CPU inputs and non-default GPUs work as well.
    index = torch.randperm(x.size(0)).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    t_a, t_b = t, t[index]
    return mixed_x, t_a, t_b, lam


def get_criterion(use_mixup, loss_func):
    """Pick the loss callable used in the training loop.

    Both variants share the signature ``(pred, t_a, t_b, lam)``; without mixup
    only ``pred`` and ``t_a`` are used.
    """
    if not use_mixup:
        def single_criterion(pred, t_a, t_b, lam):
            return loss_func(pred, t_a)

        return single_criterion

    def mixup_criterion(pred, t_a, t_b, lam):
        # loss against both target sets, weighted by the mixing coefficient
        return lam * loss_func(pred, t_a) + (1 - lam) * loss_func(pred, t_b)

    return mixup_criterion

In [14]:
def train_one_fold(cfg, train_all):
    """Train and validate the fold described by an evaluated configuration.

    Parameters
    ----------
    cfg : Config
        Evaluated configuration. Entries read here: ``/globals/seed``,
        ``/globals/device``, ``/globals/enable_amp``, ``/dataset/train``,
        ``/dataset/val``, ``/dataset/mixup/*``, ``/loader/train``, ``/loader/val``,
        ``/model``, ``/optimizer``, ``/loss``, ``/manager``, ``/extensions``, ``/eval``.
    train_all : pd.DataFrame
        Full training table; it is split into train / val by ``get_path_label``.
    """
    # NOTE(review): cuDNN benchmark mode is switched on while the seeding call below asks
    # for deterministic cuDNN - confirm that this combination is intended.
    torch.backends.cudnn.benchmark = True
    set_random_seed(cfg["/globals/seed"], deterministic=True)
    device = torch.device(cfg["/globals/device"])

    train_path_label, val_path_label = get_path_label(cfg, train_all)
    print("train: {}, val: {}".format(len(train_path_label["paths"]), len(val_path_label["paths"])))

    # the datasets are created from the config without files; fill them in now
    cfg["/dataset/train"].lazy_init(**train_path_label)
    cfg["/dataset/val"].lazy_init(**val_path_label)
    train_loader = cfg["/loader/train"]
    val_loader = cfg["/loader/val"]

    model = cfg["/model"]
    model.to(device)
    optimizer = cfg["/optimizer"]
    loss_func = cfg["/loss"]
    loss_func.to(device)

    manager = cfg["/manager"]
    for ext in cfg["/extensions"]:
        # a dict entry carries the extension together with extra `extend` arguments
        if isinstance(ext, dict):
            manager.extend(**ext)
        else:
            manager.extend(ext)

    # validation is registered as an extension that fires once per epoch
    evaluator = ppe_exts.Evaluator(
        val_loader, model, eval_func=get_eval_func(cfg, model, device),
        metrics=cfg["/eval"], progress_bar=False)
    manager.extend(evaluator, trigger=(1, "epoch"))

    use_amp = cfg["/globals/enable_amp"]
    scaler = amp.GradScaler(enabled=use_amp)
    use_mixup = cfg["/dataset/mixup/enabled"]
    mixup_alpha = cfg["/dataset/mixup/alpha"]

    # loop-invariant pieces, resolved once instead of on every iteration
    criterion = get_criterion(use_mixup, loss_func)
    mixup_on_cuda = device.type == "cuda"  # lets the mixup helper run on a CPU device too

    while not manager.stop_trigger:
        # re-enter train mode at the start of every pass over the training data
        model.train()
        for batch in train_loader:
            with manager.run_iteration():
                batch = to_device(batch, device)
                x, t = batch["image"], batch["target"]
                # # for mixup
                mixed_x, t_a, t_b, lam = mixup_data(
                    use_mixup, x, t, mixup_alpha,
                    use_cuda=mixup_on_cuda, device=cfg["/globals/device"])

                optimizer.zero_grad()
                with amp.autocast(use_amp):
                    y = model(mixed_x)
                    loss = criterion(y, t_a, t_b, lam)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                ppe.reporting.report({'train/loss': loss.item()})

## run train

In [15]:
# One independent copy of the un-evaluated configuration per fold, with the
# validation fold id and the output directory filled in.
pre_eval_cfg_list = []
for fold_id in FOLDS:
    tmp_cfg = copy.deepcopy(pre_eval_cfg)  # deep copy so that folds do not share nested dicts
    tmp_cfg["globals"]["val_fold"] = fold_id
    tmp_cfg["globals"]["output_path"] = str(TMP / f"fold{fold_id}")
    pre_eval_cfg_list.append(tmp_cfg)

In [None]:
# Train the folds one after another, releasing GPU memory in between.
# NOTE(review): the loop variable rebinds the global `pre_eval_cfg`, so after this cell it
# refers to the last fold's dict instead of the template - consider a different name.
for pre_eval_cfg in pre_eval_cfg_list:
    cfg = Config(pre_eval_cfg, types=CONFIG_TYPES)
    print(f"\n[fold {cfg['/globals/val_fold']}]")
    train_one_fold(cfg, train)
    # free cached CUDA memory on the configured device before the next fold
    with torch.cuda.device(cfg["/globals/device"]):
        torch.cuda.empty_cache()
    del cfg
    gc.collect()


[fold 0]
train: 48000, val: 12000
load imagenet pretrained: True
efficientnetv2_rw_s: 1792


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…

epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           1920        2.96752e-05  0.351755    0.368666    0.577105    561.232       




[J2           3840        8.71758e-05  0.308377    0.260624    0.757121    1125.4        




[J3           5760        0.000177926  0.293959    0.24584     0.784233    1689.59       




[J4           7680        0.000294704  0.288816    0.254956    0.781617    2254.24       




[J5           9600        0.000428217  0.289676    0.241164    0.796126    2818.77       




[J6           11520       0.000567841  0.290633    0.22913     0.802372    3383.46       




[J7           13440       0.000702463  0.290717    0.264974    0.759294    3948.34       




[J8           15360       0.000821372  0.291541    0.237573    0.790133    4513.15       




[J9           17280       0.000915105  0.292948    0.240325    0.785345    5077.64       




[J10          19200       0.000976202  0.291698    0.235459    0.796163    5642.29       




[J11          21120       0.000999802  0.290859    0.237104    0.798526    6206.58       




[J12          23040       0.00099975  0.288778    0.236622    0.789414    6771.32       




[J13          24960       0.000998885  0.28558     0.223824    0.808853    7335.75       




[J14          26880       0.000997403  0.285702    0.22564     0.796309    7900.52       




[J15          28800       0.000995306  0.284829    0.224706    0.809563    8465.34       




[J16          30720       0.000992598  0.283286    0.216788    0.817306    9030.92       




[J17          32640       0.00098928  0.28187     0.222634    0.822022    9597.07       




[J18          34560       0.000985357  0.280623    0.21997     0.815497    10163.6       




[J19          36480       0.000980835  0.280411    0.219111    0.820924    10729.5       




[J20          38400       0.000975719  0.279878    0.223109    0.812077    11295.9       




[J21          40320       0.000970015  0.279478    0.213626    0.819472    11862         




[J22          42240       0.00096373  0.276883    0.21233     0.824068    12428.6       




[J23          44160       0.000956872  0.278363    0.204401    0.830382    12994.8       




[J24          46080       0.000949451  0.276266    0.213779    0.825946    13561.6       




[J25          48000       0.000941474  0.274737    0.210323    0.834491    14127.5       




[J26          49920       0.000932952  0.273153    0.207555    0.835176    14693.7       




[J27          51840       0.000923896  0.275094    0.208211    0.836547    15259.1       




[J28          53760       0.000914316  0.274191    0.202399    0.840739    15825.3       




[J29          55680       0.000904226  0.272641    0.206784    0.838385    16391.3       




[J30          57600       0.000893637  0.272933    0.205575    0.832327    16957.2       




[J31          59520       0.000882563  0.27171     0.204972    0.840284    17523.2       




[J32          61440       0.000871017  0.270403    0.200325    0.845145    18089.4       




[J33          63360       0.000859015  0.270115    0.201708    0.843905    18656.1       




[J34          65280       0.00084657  0.270016    0.200269    0.841674    19222         




[J35          67200       0.000833699  0.270344    0.203225    0.836753    19788.6       




[J36          69120       0.000820417  0.269167    0.202028    0.836386    20354.9       




[J37          71040       0.000806742  0.268528    0.196989    0.844952    20921         




[J38          72960       0.000792689  0.268213    0.202268    0.842397    21487.4       




[J39          74880       0.000778278  0.267151    0.199721    0.847202    22053.5       




[J40          76800       0.000763525  0.265627    0.199301    0.848159    22619.7       




[J41          78720       0.000748449  0.267576    0.198412    0.846281    23186.2       




[J42          80640       0.00073307  0.266818    0.19943     0.844821    23751.6       




[J43          82560       0.000717405  0.266719    0.191038    0.850171    24317.5       




[J44          84480       0.000701476  0.26516     0.192286    0.850817    24883.6       




[J45          86400       0.000685301  0.264391    0.196699    0.846094    25448.7       




[J46          88320       0.000668901  0.264107    0.190705    0.850363    26013.8       




[J47          90240       0.000652296  0.2643      0.193779    0.848659    26580.1       




[J48          92160       0.000635507  0.263353    0.190363    0.851367    27145.6       




[J49          94080       0.000618556  0.263905    0.192392    0.851356    27711.3       




[J50          96000       0.000601462  0.263183    0.194866    0.85227     28276.3       




[J51          97920       0.000584249  0.26232     0.18714     0.853285    28841.6       




[J52          99840       0.000566936  0.261574    0.191432    0.851081    29407.5       




[J53          101760      0.000549546  0.261845    0.189553    0.851833    29974         




[J54          103680      0.0005321   0.260775    0.186983    0.855213    30540         




[J55          105600      0.00051462  0.259568    0.188488    0.859646    31105.9       




[J56          107520      0.000497129  0.26064     0.189463    0.855088    31670.9       




[J57          109440      0.000479647  0.260094    0.192693    0.855874    32235.8       




[J58          111360      0.000462197  0.257459    0.190527    0.854168    32801.1       




[J59          113280      0.0004448   0.259161    0.187098    0.854818    33366.2       




[J60          115200      0.000427479  0.25915     0.183834    0.860655    33931.4       




[J61          117120      0.000410254  0.25845     0.185283    0.856481    34496.6       




[J62          119040      0.000393147  0.256756    0.185343    0.859101    35061.8       




[J63          120960      0.000376181  0.257409    0.180879    0.860705    35626.4       




[J64          122880      0.000359375  0.257412    0.182283    0.86009     36191.4       




[J65          124800      0.000342751  0.257239    0.181384    0.86071     36756.1       




[J66          126720      0.000326329  0.255755    0.183383    0.859426    37321         




[J67          128640      0.000310131  0.256359    0.181872    0.861665    37885.8       




[J68          130560      0.000294176  0.256206    0.184367    0.860452    38451.5       




[J69          132480      0.000278484  0.25493     0.180379    0.861647    39016.2       




[J70          134400      0.000263075  0.254356    0.179028    0.861829    39581.3       




[J71          136320      0.000247968  0.256071    0.182715    0.857917    40145.9       




[J72          138240      0.000233183  0.253631    0.182741    0.861434    40710.8       




[J73          140160      0.000218736  0.253189    0.180068    0.86057     41275.8       




[J74          142080      0.000204647  0.253705    0.181046    0.862901    41840.9       




[J75          144000      0.000190933  0.253518    0.181528    0.863468    42406         




[J76          145920      0.000177611  0.253932    0.178177    0.864199    42971.8       




[J77          147840      0.000164698  0.252846    0.177719    0.864396    43537.1       




[J78          149760      0.00015221  0.253153    0.181165    0.862317    44102.3       




[J79          151680      0.000140163  0.251638    0.179942    0.862707    44667.2       




[J80          153600      0.000128571  0.251999    0.179312    0.864211    45231.9       




[J81          155520      0.000117449  0.251223    0.180418    0.863693    45797.2       




[J82          157440      0.000106811  0.25097     0.182092    0.864651    46362         




[J83          159360      9.66699e-05  0.250541    0.181817    0.866219    46926.9       




[J84          161280      8.70389e-05  0.250641    0.181045    0.862958    47491.8       




[J85          163200      7.79299e-05  0.250474    0.177978    0.866736    48057.3       




[J86          165120      6.93541e-05  0.250602    0.176485    0.867049    48622.5       




[J87          167040      6.13224e-05  0.248718    0.176885    0.867071    49187.5       




[J88          168960      5.38446e-05  0.251822    0.175794    0.868366    49752.4       




[J89          170880      4.69302e-05  0.250322    0.177668    0.867344    50318         




[J90          172800      4.05877e-05  0.248615    0.177552    0.868879    50882.8       




[J91          174720      3.48252e-05  0.25016     0.177112    0.8671      51448.3       




[J92          176640      2.96497e-05  0.249449    0.179389    0.867565    52013.1       




[J93          178560      2.50679e-05  0.248913    0.175857    0.868209    52578.1       




[J94          180480      2.10853e-05  0.250701    0.174678    0.867238    53142.7       




[J95          182400      1.77069e-05  0.249962    0.176246    0.867509    53708.6       




[J96          184320      1.49371e-05  0.249798    0.175248    0.86756     54273.5       




[J97          186240      1.27791e-05  0.248987    0.175175    0.867792    54838.3       




[J98          188160      1.12358e-05  0.248571    0.178229    0.867629    55403.1       




[J99          190080      1.0309e-05  0.249796    0.176304    0.867872    55967.7       




[J100         192000      1e-05       0.250206    0.175772    0.86805     56532.7       

[fold 1]
train: 48000, val: 12000
load imagenet pretrained: True
efficientnetv2_rw_s: 1792


VBox(children=(HBox(children=(FloatProgress(value=0.0, bar_style='info', description='total', max=1.0), HTML(v…



epoch       iteration   lr          train/loss  val/loss    val/metric  elapsed_time




[J1           1920        2.96752e-05  0.352073    0.329914    0.569657    563.823       




[J2           3840        8.71758e-05  0.308112    0.270538    0.727613    1128.74       




[J3           5760        0.000177926  0.293804    0.244363    0.766193    1693.68       




[J4           7680        0.000294704  0.28763     0.246545    0.775053    2259.1        




[J5           9600        0.000428217  0.286901    0.238275    0.769538    2824.33       




[J6           11520       0.000567841  0.28921     0.264982    0.780397    3389.27       




[J7           13440       0.000702463  0.289438    0.245115    0.77195     3954.15       




[J8           15360       0.000821372  0.290819    0.255631    0.746785    4519.35       




[J9           17280       0.000915105  0.290346    0.259375    0.758322    5083.67       




[J10          19200       0.000976202  0.290312    0.243243    0.770911    5648.59       




[J11          21120       0.000999802  0.289729    0.244101    0.774737    6213.22       




[J12          23040       0.00099975  0.288883    0.244635    0.782046    6778.06       




[J13          24960       0.000998885  0.285611    0.23663     0.771652    7343.56       




[J14          26880       0.000997403  0.284948    0.251742    0.775484    7908.72       




[J15          28800       0.000995306  0.283862    0.235031    0.780348    8473.46       




[J16          30720       0.000992598  0.282471    0.233247    0.794564    9038.21       




[J17          32640       0.00098928  0.280462    0.230199    0.792181    9603.33       




[J18          34560       0.000985357  0.280057    0.233387    0.800358    10168.3       




[J19          36480       0.000980835  0.279097    0.225595    0.796687    10733.6       




[J20          38400       0.000975719  0.277793    0.22555     0.803815    11298         




[J21          40320       0.000970015  0.279142    0.226335    0.796652    11863.4       




[J22          42240       0.00096373  0.275784    0.22427     0.803807    12428.2       




[J23          44160       0.000956872  0.275678    0.221094    0.798226    12993.2       




[J24          46080       0.000949451  0.274753    0.223081    0.805593    13558.4       




[J25          48000       0.000941474  0.274073    0.217308    0.812051    14123.6       




[J26          49920       0.000932952  0.27305     0.220046    0.811528    14688.7       




[J27          51840       0.000923896  0.273237    0.225271    0.800639    15253.5       




[J28          53760       0.000914316  0.272074    0.213647    0.817265    15818.2       




[J29          55680       0.000904226  0.270693    0.21362     0.808076    16383.4       




[J30          57600       0.000893637  0.271056    0.217109    0.812905    16948.4       




[J31          59520       0.000882563  0.269909    0.211038    0.811792    17513.8       




[J32          61440       0.000871017  0.268491    0.211295    0.818128    18078.8       




[J33          63360       0.000859015  0.269689    0.20992     0.820028    18644.1       




[J34          65280       0.00084657  0.269025    0.209464    0.823642    19209.5       




[J35          67200       0.000833699  0.26908     0.214139    0.814423    19774.1       




[J36          69120       0.000820417  0.268147    0.215855    0.816397    20338.9       




[J37          71040       0.000806742  0.2664      0.206264    0.828841    20903.8       




[J38          72960       0.000792689  0.266166    0.207672    0.824396    21469.7       




[J39          74880       0.000778278  0.26536     0.209175    0.821713    22035.1       




[J40          76800       0.000763525  0.266496    0.207178    0.829884    22600         




[J41          78720       0.000748449  0.263251    0.205032    0.827315    23164.8       




[J42          80640       0.00073307  0.264441    0.200268    0.83411     23729.9       




[J43          82560       0.000717405  0.263176    0.207684    0.8329      24294.9       




[J44          84480       0.000701476  0.263813    0.20325     0.830628    24859.8       




[J45          86400       0.000685301  0.263733    0.203845    0.832526    25424.8       




[J46          88320       0.000668901  0.262131    0.206058    0.826515    25989.5       




[J47          90240       0.000652296  0.263283    0.201613    0.83435     26554.5       




# Inference

## Copy best models

In [None]:
# For every fold: locate the epoch with the highest validation metric, copy
# that snapshot into OUTPUT, remove all snapshots from the experiment
# directory, and export what is left (plus the pre-evaluation config) to
# ./fold{fold_id}.
# NOTE: this cell is destructive -- every *.pth under TMP/fold{fold_id} is
# deleted, so it cannot simply be re-run after it has succeeded once.
best_log_list = []
for pre_eval_cfg, fold_id in zip(pre_eval_cfg_list, FOLDS):
    fold_dir = TMP / f"fold{fold_id}"

    # One-row DataFrame describing the best epoch of this fold.
    history = pd.read_json(fold_dir / "log")
    best_row = history.iloc[[history["val/metric"].idxmax()]]
    best_log_list.append(best_row)

    # Copy the matching snapshot under a fold-specific file name.
    best_epoch = best_row["epoch"].values[0]
    shutil.copy(
        fold_dir / f"snapshot_by_metric_epoch_{best_epoch}.pth",
        OUTPUT / f"best_metric_model_fold{fold_id}.pth",
    )

    # Drop every snapshot so that only logs/artifacts get exported below.
    for snapshot_path in fold_dir.glob("*.pth"):
        snapshot_path.unlink()

    shutil.copytree(fold_dir, f"./fold{fold_id}")

    # Store the config next to the exported logs; the inference cell reads
    # ./fold{fold_id}/config.yml back.
    with open(f"./fold{fold_id}/config.yml", "w") as cfg_file:
        yaml.dump(pre_eval_cfg, cfg_file)

# Summary table (one row per fold), shown as the cell output.
pd.concat(best_log_list, axis=0, ignore_index=True)

## Inference OOF & Test

In [None]:
def run_inference_loop(cfg, model, loader, device):
    """Run `model` over every batch of `loader` and collect sigmoid outputs.

    Args:
        cfg: Unused by this function; kept for interface compatibility.
        model: Network to evaluate. It is moved to `device` and switched to
            eval mode in place.
        loader: Iterable of batches; each batch must provide an "image" entry.
        device: Device the model and the input images are moved to.

    Returns:
        A NumPy array with the per-batch sigmoid outputs concatenated along
        the first axis.
    """
    model.to(device)
    model.eval()

    batch_probs = []
    # Gradients are not needed for inference.
    with torch.no_grad():
        for batch in tqdm(loader):
            images = to_device(batch["image"], device)
            logits = model(images)
            probs = logits.sigmoid()
            batch_probs.append(probs.detach().cpu().numpy())

    return np.concatenate(batch_probs)

In [None]:
# Run every fold's best model over its validation split (out-of-fold, OOF)
# and over the whole test set, collecting predictions and per-fold ROC AUC.
label_arr = train[CLASSES].values
oof_pred_arr = np.zeros((len(train), N_CLASSES))
score_list = []  # rows of [fold_id, validation ROC AUC]
test_pred_arr = np.zeros((N_FOLDS, len(smpl_sub), N_CLASSES))
# Test files are addressed as test/<first character of id>/<id>.npy under DATA.
# The "labels" entry is just the sample-submission values cast to float32
# (presumably placeholders the dataset requires -- verify against the dataset class).
test_path_label = {
    "paths": [DATA / f"test/{img_id[0]}/{img_id}.npy" for img_id in smpl_sub["id"].values],
    "labels": smpl_sub[CLASSES].values.astype("f")
}

for fold_id in FOLDS:
    print(f"\n[fold {fold_id}]")
    # Reload the config that was exported to ./fold{fold_id}/config.yml.
    tmp_dir = Path(f"./fold{fold_id}")
    with open(tmp_dir / "config.yml", "r") as fr:
        cfg = Config(yaml.safe_load(fr), types=CONFIG_TYPES)
    device = torch.device(cfg["/globals/device"])
    # Index labels of the rows belonging to this fold. They are used below as
    # positional indices into label_arr / oof_pred_arr, which assumes `train`
    # has a default 0..n-1 index -- TODO confirm.
    val_idx = train.query("fold == @fold_id").index.values

    # Build the data loaders. Only the second value returned by get_path_label
    # is used (presumably the validation paths/labels for the fold configured
    # in cfg -- verify against get_path_label).
    _, val_path_label = get_path_label(cfg, train)
    cfg["/dataset/val"].lazy_init(**val_path_label)
    cfg["/dataset/test"].lazy_init(**test_path_label)
    val_loader = cfg["/loader/val"]
    test_loader = cfg["/loader/test"]
    
    # Build the model from the config and load this fold's best checkpoint.
    model_path = OUTPUT / f"best_metric_model_fold{fold_id}.pth"
    model = cfg["/model"]
    model.load_state_dict(torch.load(model_path, map_location=device))
    
    # Validation inference: score the fold and fill its rows of the OOF buffer.
    val_pred = run_inference_loop(cfg, model, val_loader, device)
    val_score = roc_auc_score(label_arr[val_idx], val_pred)
    oof_pred_arr[val_idx] = val_pred
    score_list.append([fold_id, val_score])
    
    # fold_id is used directly as the row index of test_pred_arr, so fold ids
    # must be valid indices for an array with N_FOLDS rows.
    test_pred_arr[fold_id] = run_inference_loop(cfg, model, test_loader, device)
    
    # Drop references to the per-fold objects and release cached CUDA memory
    # before the next fold is loaded.
    del cfg, val_idx, val_path_label
    del model, val_loader, test_loader
    torch.cuda.empty_cache()
    gc.collect()
    
    # val_score is not deleted above, so it is still available here.
    print(f"val score: {val_score:.4f}")

In [None]:
# ROC AUC over the full OOF prediction buffer, appended below the per-fold
# scores and shown as a table.
# NOTE: re-running this cell appends another "oof" row to score_list.
oof_score = roc_auc_score(y_true=label_arr, y_score=oof_pred_arr)
score_list.extend([["oof", oof_score]])
pd.DataFrame(data=score_list, columns=["fold", "metric"])

In [None]:
# Save the OOF predictions: a copy of `train` whose CLASSES columns are
# overwritten with the predicted values.
oof_csv_path = OUTPUT / "oof_prediction.csv"
oof_df = train.copy()
oof_df[CLASSES] = oof_pred_arr
oof_df.to_csv(oof_csv_path, index=False)

## Make submission

In [None]:
# Build the submission from the sample submission, replacing the CLASSES
# columns with the test predictions averaged over the first (fold) axis.
submission_csv_path = OUTPUT / "submission.csv"
sub_df = smpl_sub.copy()
sub_df[CLASSES] = np.mean(test_pred_arr, axis=0)
sub_df.to_csv(submission_csv_path, index=False)

In [None]:
# Preview the first rows of the submission as the cell output.
sub_df.head(n=5)

# EOF