This is my attempt to refactor and clean [this](https://www.kaggle.com/ttahara/training-birdsong-baseline-resnest50-fast/) excellent notebook. It's a work in progress, and I'll update it as I go.

Type checking, style checking, and formatting are done with `black`, `blackdoc`, `flake8`, `isort`, `mypy`, and `pydocstyle` via [nbQA](https://github.com/nbQA-dev/nbQA).

Configurations used:

`.pre-commit-config.yaml`:

```yaml
  - repo: https://github.com/nbQA-dev/nbQA
    rev: 0.1.20
    hooks:
      - id: nbqa
        args: ['black']
        name: nbqa-black
        additional_dependencies: ['black']
      - id: nbqa
        args: ['flake8']
        name: nbqa-flake8
        additional_dependencies: ['flake8']
      - id: nbqa
        args: ['isort']
        name: nbqa-isort
        additional_dependencies: ['isort']
      - id: nbqa
        args: ['blackdoc']
        name: nbqa-blackdoc
        additional_dependencies: ['blackdoc']
      - id: nbqa
        args: ['mypy']
        name: nbqa-mypy
        additional_dependencies: ['mypy']
      - id: nbqa
        args: ['pydocstyle']
        name: nbqa-pydocstyle
        additional_dependencies: ['pydocstyle']
```

`.nbqa.ini`:

```ini
[black]
addopts = --line-length=96
mutate = 1

[flake8]
config=.flake8

[isort]
addopts = --profile=black
mutate = 1

[blackdoc]
addopts = --line-length=96
mutate = 1

[mypy]
addopts = --ignore-missing-imports --disallow-untyped-defs

[pydocstyle]
addopts = --add-ignore=D100,D101,D105,D103,D107
```

# Birdsong Pytorch Baseline: ResNeSt50-fast (Training)

## About

In this notebook, I try ResNeSt, one of the state-of-the-art architectures in image recognition.

For a fair comparison with @hidehisaarai1213's [great baseline](https://www.kaggle.com/hidehisaarai1213/inference-pytorch-birdcall-resnet-baseline), I used a model with the same depth and kept the experimental settings as close as possible. However, there are some differences, mainly because of GPU resource limitations.

The experimental settings are as follows:

* Randomly crop 5 seconds for each train audio clip each epoch.
* No augmentation.
* Used pretrained weight of _`ResNeSt50-fast-1s1x64d`_ provided by the authors at [their repository](https://github.com/zhanghang1989/ResNeSt).
* Used `BCEWithLogitsLoss`
* Trained for **_50_** epochs and saved the weights with the best validation **_loss_** (the F1 score is not used for this because it depends on thresholds).
* `Adam` optimizer (`lr=0.001`) with `CosineAnnealingLR` (`T_max=10`).
* Used `StratifiedKFold(n_splits=5)` to split the dataset and used only the first fold
* `batch_size`: **_50_**
* melspectrogram parameters
  - `n_mels`: 128
  - `fmin`: 20
  - `fmax`: 16000
* image size: 224x547
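
The 224x547 image size follows from the clip length and the spectrogram parameters listed above. The arithmetic below is a small sketch; it assumes a 32 kHz sample rate for the resampled audio (this value is not stated in the notebook) and librosa's default hop length of 512.

```python
# Assumption: the resampled audio is 32 kHz; this value is not stated in the notebook.
SAMPLE_RATE = 32_000
PERIOD = 5  # seconds cropped from each clip
HOP_LENGTH = 512  # librosa's default hop length
N_MELS = 128
IMG_SIZE = 224

n_samples = SAMPLE_RATE * PERIOD
n_frames = 1 + n_samples // HOP_LENGTH  # time steps in the mel spectrogram
# The spectrogram is resized so that its height equals IMG_SIZE, keeping the aspect ratio.
width = int(n_frames * IMG_SIZE / N_MELS)

print(n_frames, width)  # 313 547
```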

I forked a lot of code, such as the preprocessing, from @hidehisaarai1213's [notebook](https://www.kaggle.com/hidehisaarai1213/inference-pytorch-birdcall-resnet-baseline) and [GitHub repository](https://github.com/koukyo1994/kaggle-birdcall-resnet-baseline-training). Many thanks!!!


### Note

#### about dataset
I prepared a resampled train dataset for this notebook; see more details in:
https://www.kaggle.com/c/birdsong-recognition/discussion/164197


#### about custom packages
In this **_training notebook_**, I used two custom packages: `pytorch-pfn-extras` for training and the authors' official implementation of `ResNeSt` for building the model.  
On the other hand, as stated in the [code requirements](https://www.kaggle.com/c/birdsong-recognition/overview/code-requirements), participants are **not allowed** to use custom packages in the **_submission notebook_**.

If you fork this notebook, keep the above things in mind.


### Reference

#### ResNeSt: Split-Attention Networks
* authors: Hang Zhang, Chongruo Wu, Zhongyue Zhang, Yi Zhu, Zhi Zhang, Haibin Lin, Yue Sun, Tong He, Jonas Mueller, R. Manmatha, Mu Li and Alex Smola
* paper: [arXiv 2004.08955](https://arxiv.org/abs/2004.08955)
* code: [GitHub](https://github.com/zhanghang1989/ResNeSt)

#### pytorch-pfn-extras
* author: Preferred Networks, Inc.
* code: [GitHub](https://github.com/pfnet/pytorch-pfn-extras)

## Prepare

### import libraries

In [None]:
!pip install ../input/pytorch-pfn-extras/pytorch-pfn-extras-0.2.1/
!pip install ../input/resnest50-fast-package/resnest-0.0.6b20200701/resnest/

In [None]:
import gc
import os
import random
import shutil
import typing as tp
from pathlib import Path

import cv2
import librosa
import numpy as np
import pandas as pd
import pytorch_pfn_extras as ppe
import resnest.torch as resnest_torch
import soundfile as sf
import torch
import torch.nn as nn
import torch.utils.data as data
import yaml
from pytorch_pfn_extras.training import extensions as ppe_extensions
from sklearn.model_selection import StratifiedKFold

pd.options.display.max_rows = 500
pd.options.display.max_columns = 500

In [None]:
QUICK_RUN = True

if QUICK_RUN:
    NUM_EPOCHS = 1
    N_SPLITS = 2
else:
    NUM_EPOCHS = 50
    N_SPLITS = 5

In [None]:
# exist_ok=True lets this cell be re-run without raising FileExistsError
Path("/root/.cache/torch/checkpoints").mkdir(parents=True, exist_ok=True)

In [None]:
!cp ../input/resnest50-fast-package/resnest50_fast_*.pth /root/.cache/torch/checkpoints/

### define utilities

In [None]:
def set_seed(seed: int = 42) -> None:
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

### read data

In [None]:
ROOT = Path.cwd().parent
INPUT_ROOT = ROOT / "input"
RAW_DATA = INPUT_ROOT / "birdsong-recognition"
TRAIN_RESAMPLED_AUDIO_DIRS = [
    INPUT_ROOT / f"birdsong-resampled-train-audio-{i:0>2}" for i in range(5)
]

In [None]:
train = pd.read_csv(TRAIN_RESAMPLED_AUDIO_DIRS[0] / "train_mod.csv")

In [None]:
train.head().T

### settings

In [None]:
settings_str = f"""
globals:
  seed: 1213
  device: cuda
  num_epochs: {NUM_EPOCHS}
  output_dir: /kaggle/training_output/
  use_fold: 0

dataset:
  name: SpectrogramDataset
  params:
    img_size: 224
    melspectrogram_parameters:
      n_mels: 128
      fmin: 20
      fmax: 16000

split:
  name: StratifiedKFold
  params:
    n_splits: {N_SPLITS}
    random_state: 42
    shuffle: True

loader:
  train:
    batch_size: 50
    shuffle: True
    num_workers: 2
    pin_memory: True
    drop_last: True
  val:
    batch_size: 100
    shuffle: False
    num_workers: 2
    pin_memory: True
    drop_last: False

model:
  name: resnest50_fast_1s1x64d
  params:
    pretrained: True
    n_classes: 264

loss:
  name: BCEWithLogitsLoss
  params: {{}}

optimizer:
  name: Adam
  params:
    lr: 0.001

scheduler:
  name: CosineAnnealingLR
  params:
    T_max: 10
"""

In [None]:
settings = yaml.safe_load(settings_str)
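
To get a feel for the `scheduler` entry in the settings, here is a sketch of the closed-form cosine annealing curve that PyTorch documents for `CosineAnnealingLR`, using `lr=0.001`, `T_max=10`, and the default `eta_min=0`. The function name `cosine_annealing_lr` is made up for this illustration; it is not part of PyTorch.

```python
import math


def cosine_annealing_lr(
    step: int, base_lr: float = 0.001, t_max: int = 10, eta_min: float = 0.0
) -> float:
    """Closed-form cosine annealing: learning rate after `step` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


schedule = [cosine_annealing_lr(step) for step in range(21)]
for step in (0, 5, 10, 20):
    print(step, f"{schedule[step]:.6f}")
```

The learning rate falls from `0.001` to `0` over `T_max` steps and climbs back over the next `T_max` steps. Note that `T_max` counts calls to `scheduler.step()`, so the shape of the curve over a training run depends on how often the training loop calls it.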

### preprocess audio data

Code is forked from: https://github.com/koukyo1994/kaggle-birdcall-resnet-baseline-training/blob/master/input/birdsong-recognition/prepare.py

I modified it partially.

However, this notebook uses the uploaded resampled audio, because the preprocessing is too heavy for a Kaggle notebook.

## Definition

### Dataset
* forked from: https://github.com/koukyo1994/kaggle-birdcall-resnet-baseline-training/blob/master/src/dataset.py
* modified partially


In [None]:
BIRD_CODE = {
    "aldfly": 0,
    "ameavo": 1,
    "amebit": 2,
    "amecro": 3,
    "amegfi": 4,
    "amekes": 5,
    "amepip": 6,
    "amered": 7,
    "amerob": 8,
    "amewig": 9,
    "amewoo": 10,
    "amtspa": 11,
    "annhum": 12,
    "astfly": 13,
    "baisan": 14,
    "baleag": 15,
    "balori": 16,
    "banswa": 17,
    "barswa": 18,
    "bawwar": 19,
    "belkin1": 20,
    "belspa2": 21,
    "bewwre": 22,
    "bkbcuc": 23,
    "bkbmag1": 24,
    "bkbwar": 25,
    "bkcchi": 26,
    "bkchum": 27,
    "bkhgro": 28,
    "bkpwar": 29,
    "bktspa": 30,
    "blkpho": 31,
    "blugrb1": 32,
    "blujay": 33,
    "bnhcow": 34,
    "boboli": 35,
    "bongul": 36,
    "brdowl": 37,
    "brebla": 38,
    "brespa": 39,
    "brncre": 40,
    "brnthr": 41,
    "brthum": 42,
    "brwhaw": 43,
    "btbwar": 44,
    "btnwar": 45,
    "btywar": 46,
    "buffle": 47,
    "buggna": 48,
    "buhvir": 49,
    "bulori": 50,
    "bushti": 51,
    "buwtea": 52,
    "buwwar": 53,
    "cacwre": 54,
    "calgul": 55,
    "calqua": 56,
    "camwar": 57,
    "cangoo": 58,
    "canwar": 59,
    "canwre": 60,
    "carwre": 61,
    "casfin": 62,
    "caster1": 63,
    "casvir": 64,
    "cedwax": 65,
    "chispa": 66,
    "chiswi": 67,
    "chswar": 68,
    "chukar": 69,
    "clanut": 70,
    "cliswa": 71,
    "comgol": 72,
    "comgra": 73,
    "comloo": 74,
    "commer": 75,
    "comnig": 76,
    "comrav": 77,
    "comred": 78,
    "comter": 79,
    "comyel": 80,
    "coohaw": 81,
    "coshum": 82,
    "cowscj1": 83,
    "daejun": 84,
    "doccor": 85,
    "dowwoo": 86,
    "dusfly": 87,
    "eargre": 88,
    "easblu": 89,
    "easkin": 90,
    "easmea": 91,
    "easpho": 92,
    "eastow": 93,
    "eawpew": 94,
    "eucdov": 95,
    "eursta": 96,
    "evegro": 97,
    "fiespa": 98,
    "fiscro": 99,
    "foxspa": 100,
    "gadwal": 101,
    "gcrfin": 102,
    "gnttow": 103,
    "gnwtea": 104,
    "gockin": 105,
    "gocspa": 106,
    "goleag": 107,
    "grbher3": 108,
    "grcfly": 109,
    "greegr": 110,
    "greroa": 111,
    "greyel": 112,
    "grhowl": 113,
    "grnher": 114,
    "grtgra": 115,
    "grycat": 116,
    "gryfly": 117,
    "haiwoo": 118,
    "hamfly": 119,
    "hergul": 120,
    "herthr": 121,
    "hoomer": 122,
    "hoowar": 123,
    "horgre": 124,
    "horlar": 125,
    "houfin": 126,
    "houspa": 127,
    "houwre": 128,
    "indbun": 129,
    "juntit1": 130,
    "killde": 131,
    "labwoo": 132,
    "larspa": 133,
    "lazbun": 134,
    "leabit": 135,
    "leafly": 136,
    "leasan": 137,
    "lecthr": 138,
    "lesgol": 139,
    "lesnig": 140,
    "lesyel": 141,
    "lewwoo": 142,
    "linspa": 143,
    "lobcur": 144,
    "lobdow": 145,
    "logshr": 146,
    "lotduc": 147,
    "louwat": 148,
    "macwar": 149,
    "magwar": 150,
    "mallar3": 151,
    "marwre": 152,
    "merlin": 153,
    "moublu": 154,
    "mouchi": 155,
    "moudov": 156,
    "norcar": 157,
    "norfli": 158,
    "norhar2": 159,
    "normoc": 160,
    "norpar": 161,
    "norpin": 162,
    "norsho": 163,
    "norwat": 164,
    "nrwswa": 165,
    "nutwoo": 166,
    "olsfly": 167,
    "orcwar": 168,
    "osprey": 169,
    "ovenbi1": 170,
    "palwar": 171,
    "pasfly": 172,
    "pecsan": 173,
    "perfal": 174,
    "phaino": 175,
    "pibgre": 176,
    "pilwoo": 177,
    "pingro": 178,
    "pinjay": 179,
    "pinsis": 180,
    "pinwar": 181,
    "plsvir": 182,
    "prawar": 183,
    "purfin": 184,
    "pygnut": 185,
    "rebmer": 186,
    "rebnut": 187,
    "rebsap": 188,
    "rebwoo": 189,
    "redcro": 190,
    "redhea": 191,
    "reevir1": 192,
    "renpha": 193,
    "reshaw": 194,
    "rethaw": 195,
    "rewbla": 196,
    "ribgul": 197,
    "rinduc": 198,
    "robgro": 199,
    "rocpig": 200,
    "rocwre": 201,
    "rthhum": 202,
    "ruckin": 203,
    "rudduc": 204,
    "rufgro": 205,
    "rufhum": 206,
    "rusbla": 207,
    "sagspa1": 208,
    "sagthr": 209,
    "savspa": 210,
    "saypho": 211,
    "scatan": 212,
    "scoori": 213,
    "semplo": 214,
    "semsan": 215,
    "sheowl": 216,
    "shshaw": 217,
    "snobun": 218,
    "snogoo": 219,
    "solsan": 220,
    "sonspa": 221,
    "sora": 222,
    "sposan": 223,
    "spotow": 224,
    "stejay": 225,
    "swahaw": 226,
    "swaspa": 227,
    "swathr": 228,
    "treswa": 229,
    "truswa": 230,
    "tuftit": 231,
    "tunswa": 232,
    "veery": 233,
    "vesspa": 234,
    "vigswa": 235,
    "warvir": 236,
    "wesblu": 237,
    "wesgre": 238,
    "weskin": 239,
    "wesmea": 240,
    "wessan": 241,
    "westan": 242,
    "wewpew": 243,
    "whbnut": 244,
    "whcspa": 245,
    "whfibi": 246,
    "whtspa": 247,
    "whtswi": 248,
    "wilfly": 249,
    "wilsni1": 250,
    "wiltur": 251,
    "winwre3": 252,
    "wlswar": 253,
    "wooduc": 254,
    "wooscj2": 255,
    "woothr": 256,
    "y00475": 257,
    "yebfly": 258,
    "yebsap": 259,
    "yehbla": 260,
    "yelwar": 261,
    "yerwar": 262,
    "yetvir": 263,
}

INV_BIRD_CODE = {v: k for k, v in BIRD_CODE.items()}
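
As a quick illustration of how these two mappings are used, the sketch below encodes a label as a one-hot vector (as `SpectrogramDataset` does further down) and decodes it back. It uses a three-species toy subset, and decoding with `argmax` is an assumption made for the example rather than the notebook's inference rule.

```python
import numpy as np

# Toy three-species mapping (hypothetical subset of the full table).
bird_code = {"aldfly": 0, "ameavo": 1, "amebit": 2}
inv_bird_code = {index: code for code, index in bird_code.items()}

# Encode a label as a one-hot float vector.
labels = np.zeros(len(bird_code), dtype="f")
labels[bird_code["ameavo"]] = 1

# Decode by taking the highest-scoring index (argmax is an illustrative choice).
decoded = inv_bird_code[int(labels.argmax())]
print(labels, decoded)
```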

In [None]:
PERIOD = 5


def mono_to_color(X: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """
    Make 2D image 3D. Normalise and scale to be between 0 and 255.

    Parameters
    ----------
    X
        2D image.
    eps
        Epsilon, small number to add to std to avoid dividing by zero.

    Returns
    -------
    np.ndarray
        3D image, normalised, scaled.
    """
    original_shape = X.shape
    assert len(X.shape) == 2

    X = np.stack([X, X, X], axis=-1)
    assert X.shape == (*original_shape, 3)

    mean = X.mean()
    X = X - mean
    std = X.std()
    Xstd = X / (std + eps)
    _min, _max = Xstd.min(), Xstd.max()
    if (_max - _min) > eps:
        V = Xstd
        V = 255 * (V - _min) / (_max - _min)
        V = V.astype(np.uint8)
    else:
        V = np.zeros_like(Xstd, dtype=np.uint8)
    return V


class SpectrogramDataset(data.Dataset):
    def __init__(
        self,
        file_list: tp.List[tp.List[str]],
        img_size: int,
        melspectrogram_parameters: tp.Dict[str, int],
    ):
        """
        Initialise.

        Parameters
        ----------
        file_list
            List of pairs of elements, which are [file_path, ebird_code].
        img_size
            Desired height of the image obtained by converting audio;
            the width is scaled to keep the aspect ratio.
        melspectrogram_parameters
            Parameters to be passed on to `librosa.feature.melspectrogram`.
        """
        assert all(len(i) == 2 for i in file_list)

        self.file_list = file_list
        self.img_size = img_size
        self.melspectrogram_parameters = melspectrogram_parameters

    def __len__(self) -> int:
        return len(self.file_list)

    def __getitem__(self, idx: int) -> tp.Tuple[np.ndarray, np.ndarray]:
        """
        Get item.

        Steps are:

        - If audio is too short, pad it with zeros.
          If it's too long, randomly select a subset of it.
        - Compute the mel-scaled (power) spectrogram of the signal
        - Convert the power spectrogram to decibel units.
        - Stack signal so as to get 3D image, and normalise

        Parameters
        ----------
        idx
            Index

        Returns
        -------
        image
            Processed version of audio.
        labels
            Array of zeros, equals one in position corresponding
            to this item's target.
        """
        wav_path, ebird_code = self.file_list[idx]

        audio, sound_rate = sf.read(wav_path)

        len_audio = len(audio)
        effective_length = sound_rate * PERIOD

        if len_audio < effective_length:
            new_audio = np.zeros(effective_length, dtype=audio.dtype)
            start = np.random.randint(effective_length - len_audio)
            new_audio[start : start + len_audio] = audio
            audio = new_audio.astype(np.float32)
        elif len_audio > effective_length:
            start = np.random.randint(len_audio - effective_length)
            audio = audio[start : start + effective_length].astype(np.float32)
        else:
            audio = audio.astype(np.float32)

        melspec = librosa.feature.melspectrogram(
            y=audio, sr=sound_rate, **self.melspectrogram_parameters
        )
        # Note: 512 is the default hop length
        n_mels = self.melspectrogram_parameters["n_mels"]
        assert melspec.shape == (n_mels, 1 + len(audio) // 512)

        melspec = librosa.power_to_db(melspec).astype(np.float32)
        assert melspec.shape == (n_mels, 1 + len(audio) // 512)

        image = mono_to_color(melspec)
        assert image.shape == (n_mels, 1 + len(audio) // 512, 3)

        height, width, _ = image.shape
        image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
        assert image.shape == (self.img_size, int(width / height * self.img_size), 3)

        image = np.moveaxis(image, 2, 0)
        assert image.shape == (3, self.img_size, int(width / height * self.img_size))

        image = (image / 255.0).astype(np.float32)

        labels = np.zeros(len(BIRD_CODE), dtype="f")
        labels[BIRD_CODE[ebird_code]] = 1

        return image, labels
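
The pad-or-crop step at the start of `__getitem__` can be tried in isolation. The sketch below mirrors that logic on toy arrays; the helper name `pad_or_crop` is hypothetical, and it uses a NumPy `Generator` instead of the global `np.random` state so that the example is reproducible.

```python
import numpy as np


def pad_or_crop(audio: np.ndarray, target_len: int, rng: np.random.Generator) -> np.ndarray:
    """Return exactly `target_len` samples: zero-pad short clips, randomly crop long ones."""
    n = len(audio)
    if n < target_len:
        out = np.zeros(target_len, dtype=np.float32)
        start = rng.integers(target_len - n)  # random offset of the clip inside the padding
        out[start : start + n] = audio
        return out
    if n > target_len:
        start = rng.integers(n - target_len)  # random start of the crop
        return audio[start : start + target_len].astype(np.float32)
    return audio.astype(np.float32)


rng = np.random.default_rng(0)
short_clip = np.ones(3, dtype=np.float32)
long_clip = np.arange(10, dtype=np.float32)

padded = pad_or_crop(short_clip, 5, rng)
cropped = pad_or_crop(long_clip, 5, rng)
print(padded.shape, cropped.shape)
```

Because the crop start is drawn afresh on every call, each epoch sees a different 5-second window of the longer recordings.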

### Training Utility

In [None]:
def get_loaders_for_training(
    args_dataset: tp.Dict,
    args_loader: tp.Dict,
    train_file_list: tp.List[tp.List[str]],
    val_file_list: tp.List[tp.List[str]],
) -> tp.Tuple[data.DataLoader, data.DataLoader]:
    """
    Make dataloaders from datasets and filelists.

    Parameters
    ----------
    args_dataset
        Additional arguments to pass to SpectrogramDataset
    args_loader
        Additional arguments to pass to DataLoader
    train_file_list
        List of pairs of elements, which are [file_path, ebird_code], for train.
    val_file_list
        List of pairs of elements, which are [file_path, ebird_code], for validation.
    """
    train_dataset = SpectrogramDataset(train_file_list, **args_dataset)
    val_dataset = SpectrogramDataset(val_file_list, **args_dataset)

    train_loader = data.DataLoader(train_dataset, **args_loader["train"])
    val_loader = data.DataLoader(val_dataset, **args_loader["val"])

    return train_loader, val_loader

In [None]:
def get_model(
    args: tp.Dict[str, tp.Union[str, tp.Dict[str, tp.Union[bool, int]]]]
) -> nn.Module:
    """
    Get pre-trained model and customise head.

    Parameters
    ----------
    args
        Additional arguments to pass to Pytorch model
    """
    name = args["name"]
    assert isinstance(name, str)
    params = args["params"]
    assert isinstance(params, dict)
    pretrained = params["pretrained"]
    assert isinstance(pretrained, bool)
    n_classes = params["n_classes"]
    assert isinstance(n_classes, int)

    model = getattr(resnest_torch, name)(pretrained=pretrained)
    del model.fc
    model.fc = nn.Sequential(
        nn.Linear(2048, 1024),
        nn.ReLU(),
        nn.Dropout(p=0.2),
        nn.Linear(1024, 1024),
        nn.ReLU(),
        nn.Dropout(p=0.2),
        nn.Linear(1024, n_classes),
    )

    return model

In [None]:
def train_loop(
    manager: ppe.training.ExtensionsManager,
    model: nn.Module,
    device: torch.device,
    train_loader: data.DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    loss_func: tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> None:
    """
    Run minibatch training loop.

    Parameters
    ----------
    manager
        Interface to extend training loop
    model
        Model (not yet fine-tuned)
    device
        GPU, CPU, ...
    optimizer
        Adapts learning rate for each weight.
    scheduler
        Decreases learning rate as training progresses
    loss_func
        Loss function
    """
    while not manager.stop_trigger:
        model.train()
        for batch_idx, (data_, target) in enumerate(train_loader):
            with manager.run_iteration():
                data_, target = data_.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data_)
                loss = loss_func(output, target)
                ppe.reporting.report({"train/loss": loss.item()})
                loss.backward()
                optimizer.step()
                scheduler.step()


def eval_for_batch(
    model: nn.Module,
    device: torch.device,
    data_: torch.Tensor,
    target: torch.Tensor,
    loss_func: tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> None:
    """
    Run evaluation on one batch from the validation loader.

    Parameters
    ----------
    model
        Trained model
    device
        GPU, CPU, ...
    data_
        features for batch of val data
    target
        target for batch of val data
    loss_func
        Loss function (here, BCEWithLogits)
    """
    img_size = settings["dataset"]["params"]["img_size"]
    batch_size = settings["loader"]["val"]["batch_size"]
    assert data_.shape[0] <= batch_size
    assert data_.shape[1] == 3
    assert data_.shape[2] == img_size
    # Last dimension is the width (time axis), which depends on the clip length
    assert target.shape[0] <= batch_size
    assert target.shape[1] == len(BIRD_CODE)

    model.eval()
    data_, target = data_.to(device), target.to(device)
    output = model(data_)
    val_loss = loss_func(output, target).item()
    ppe.reporting.report({"val/loss": val_loss})
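
The model head ends in a plain `nn.Linear` layer with no sigmoid, so the loss has to work on raw logits, which is what `BCEWithLogitsLoss` does. The NumPy sketch below shows the idea: a numerically stable formula on logits gives the same value as applying a sigmoid and then binary cross-entropy. It is an illustration of the math, not PyTorch's implementation, and both function names are made up for the example.

```python
import numpy as np


def bce_with_logits(logits: np.ndarray, targets: np.ndarray) -> float:
    """Numerically stable mean binary cross-entropy computed from raw logits."""
    # max(x, 0) - x * z + log(1 + exp(-|x|)) is algebraically equal to
    # -[z * log(sigmoid(x)) + (1 - z) * log(1 - sigmoid(x))]
    loss = np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))
    return float(loss.mean())


def bce_from_probabilities(logits: np.ndarray, targets: np.ndarray) -> float:
    """Naive version: apply the sigmoid first, then binary cross-entropy."""
    probs = 1 / (1 + np.exp(-logits))
    loss = -(targets * np.log(probs) + (1 - targets) * np.log(1 - probs))
    return float(loss.mean())


logits = np.array([[2.0, -1.0, 0.5]])
targets = np.array([[1.0, 0.0, 0.0]])
stable = bce_with_logits(logits, targets)
naive = bce_from_probabilities(logits, targets)
print(stable, naive)
```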

In [None]:
def set_extensions(
    manager: ppe.training.ExtensionsManager,
    model: nn.Module,
    device: torch.device,
    test_loader: data.DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_func: tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> ppe.training.ExtensionsManager:
    """
    Configure extensions manager.

    The extensions manager will:
    - plot the train and val losses as a function of epoch number
    - plot the learning rate
    - print the epoch, iteration, learning rate, train and val losses, and
      elapsed time each time `LogReport` records an entry (once per epoch
      by default)
    - save a snapshot of the model whenever the val loss, checked once per
      epoch, reaches a new minimum

    Parameters
    ----------
    manager
        Interface to extend training loop
    device
        CPU, GPU, ...
    test_loader
        Loader for evaluation data
    optimizer
        Adapt learning rate for each weight
    loss_func
        Loss function
    """
    evaluation_extensions = (
        ppe_extensions.Evaluator(
            test_loader,
            model,
            eval_func=lambda data_, target: eval_for_batch(
                model, device, data_, target, loss_func
            ),
            progress_bar=True,
        ),
        (1, "epoch"),
    )
    snapshot_extensions = (
        ppe_extensions.snapshot(target=model, filename="snapshot_epoch_{.updater.epoch}.pth"),
        ppe.training.triggers.MinValueTrigger(key="val/loss", trigger=(1, "epoch")),
    )

    my_extensions = [
        ppe_extensions.observe_lr(optimizer=optimizer),
        ppe_extensions.LogReport(),
        ppe_extensions.PlotReport(["train/loss", "val/loss"], "epoch", filename="loss.png"),
        ppe_extensions.PlotReport(["lr"], "epoch", filename="lr.png"),
        ppe_extensions.PrintReport(
            ["epoch", "iteration", "lr", "train/loss", "val/loss", "elapsed_time"]
        ),
        evaluation_extensions,
        snapshot_extensions,
    ]

    for ext in my_extensions:
        if isinstance(ext, tuple):
            manager.extend(ext[0], trigger=ext[1])
        else:
            manager.extend(ext)

    return manager

## Training

### prepare data

#### get wav file path

In [None]:
tmp_list = [
    [ebird_dir.name, wav_file.name, wav_file.as_posix()]
    for audio_dir in TRAIN_RESAMPLED_AUDIO_DIRS
    for ebird_dir in audio_dir.iterdir()
    if ebird_dir.name != "train_mod.csv"
    for wav_file in ebird_dir.iterdir()
]

train_wav_path_exist = pd.DataFrame(
    tmp_list, columns=["ebird_code", "resampled_filename", "file_path"]
)

del tmp_list

train_all = pd.merge(
    train,
    train_wav_path_exist,
    on=["ebird_code", "resampled_filename"],
    how="inner",
    validate="1:1",
)

print(train.shape)
print(train_wav_path_exist.shape)
print(train_all.shape)

In [None]:
train_all.head()

#### split data

In [None]:
skf = StratifiedKFold(**settings["split"]["params"])

train_all["fold"] = -1
for fold_id, (train_index, val_index) in enumerate(
    skf.split(train_all, train_all["ebird_code"])
):
    train_all.iloc[val_index, -1] = fold_id

# check the proportion
fold_proportion = pd.pivot_table(
    train_all, index="ebird_code", columns="fold", values="xc_id", aggfunc=len
)
print(fold_proportion.shape)

In [None]:
fold_proportion
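
To see what the stratified split guarantees, here is a toy example with two species and four clips each (made-up data). With `n_splits=2`, every validation fold should receive two clips of each species.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy labels (hypothetical): 4 clips each of two species.
labels = np.array(["aldfly"] * 4 + ["amerob"] * 4)
features = np.zeros((len(labels), 1))  # placeholder; only the labels drive the split

skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=42)
fold = np.full(len(labels), -1)
for fold_id, (_, val_index) in enumerate(skf.split(features, labels)):
    fold[val_index] = fold_id  # each sample lands in exactly one validation fold

counts_per_species = {
    species: np.bincount(fold[labels == species], minlength=2).tolist()
    for species in np.unique(labels)
}
print(counts_per_species)
```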

In [None]:
use_fold = settings["globals"]["use_fold"]
train_file_list = train_all.query("fold != @use_fold")[
    ["file_path", "ebird_code"]
].values.tolist()
val_file_list = train_all.query("fold == @use_fold")[
    ["file_path", "ebird_code"]
].values.tolist()

print("[fold {}] train: {}, val: {}".format(use_fold, len(train_file_list), len(val_file_list)))

### run training

In [None]:
set_seed(settings["globals"]["seed"])
device = torch.device(settings["globals"]["device"])
output_dir = Path(settings["globals"]["output_dir"])

train_loader, val_loader = get_loaders_for_training(
    settings["dataset"]["params"], settings["loader"], train_file_list, val_file_list
)

model = get_model(settings["model"])
model = model.to(device)

optimizer = getattr(torch.optim, settings["optimizer"]["name"])(
    model.parameters(), **settings["optimizer"]["params"]
)

scheduler = getattr(torch.optim.lr_scheduler, settings["scheduler"]["name"])(
    optimizer, **settings["scheduler"]["params"]
)

loss_func = getattr(nn, settings["loss"]["name"])(**settings["loss"]["params"])

trigger = None
manager = ppe.training.ExtensionsManager(
    model,
    optimizer,
    settings["globals"]["num_epochs"],
    iters_per_epoch=len(train_loader),
    stop_trigger=trigger,
    out_dir=output_dir,
)

manager = set_extensions(manager, model, device, val_loader, optimizer, loss_func)

In [None]:
train_loop(manager, model, device, train_loader, optimizer, scheduler, loss_func)

In [None]:
del train_loader
del val_loader
del model
del optimizer
del scheduler
del loss_func
del manager

gc.collect()

### save results

In [None]:
for f_name in ["log", "loss.png", "lr.png"]:
    shutil.copy(output_dir / f_name, f_name)

In [None]:
log = pd.read_json("log")
best_epoch = log["val/loss"].idxmin() + 1  # PPE starts counting epochs at 1
log.iloc[[best_epoch - 1]]
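
An alternative way to find the best epoch is to read the epoch number from the log entry itself rather than from the row index. The sketch below assumes the log is a JSON list of per-epoch entries that each carry an `epoch` and a `val/loss` key; the values shown are invented for illustration.

```python
import json

# Hypothetical log contents; the real file is written to the output directory during training.
log_text = json.dumps(
    [
        {"epoch": 1, "train/loss": 0.030, "val/loss": 0.025},
        {"epoch": 2, "train/loss": 0.020, "val/loss": 0.018},
        {"epoch": 3, "train/loss": 0.015, "val/loss": 0.021},
    ]
)

entries = json.loads(log_text)
# Pick the entry with the lowest validation loss and read its epoch number.
best_entry = min(entries, key=lambda entry: entry["val/loss"])
example_best_epoch = best_entry["epoch"]
snapshot_name = f"snapshot_epoch_{example_best_epoch}.pth"
print(example_best_epoch, snapshot_name)
```

Reading the `epoch` field avoids relying on the log having exactly one row per epoch.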

In [None]:
shutil.copy(output_dir / "snapshot_epoch_{}.pth".format(best_epoch), "best_model.pth")

Here is how you would load this model for inference or to continue training:

In [None]:
m = get_model(
    {"name": settings["model"]["name"], "params": {"pretrained": False, "n_classes": 264}}
)
state_dict = torch.load("best_model.pth")
m.load_state_dict(state_dict)