# Training-Birdclef-2021-Pytorch-resnest50

## About

This notebook is based on the code published [here](https://www.kaggle.com/ttahara/training-birdsong-baseline-resnest50-fast), which was very well received in the previous "Cornell Birdcall Identification" competition, and has been adapted to train on the "BirdCLEF 2021 - Birdcall Identification" data.
I also made some changes that were not in the original code, such as using the timm library.
These changes were based on the public notebooks [here](https://www.kaggle.com/theoviel/training-a-winning-model) and [here](https://www.kaggle.com/hidehisaarai1213/pytorch-training-birdclef2021-starter).

These public notebooks have helped me a lot.
I would like to thank @ttahara, @theoviel, and @hidehisaarai1213 for making their code public.

If there are any shortcomings, I would appreciate it if you could point them out.

## Prepare

### import libraries

In [None]:
%%bash
pip install ../input/pytorch-pfn-extras/pytorch-pfn-extras-0.4.1/

In [None]:
!pip install timm

In [None]:
!pip install audiomentations

In [None]:
!apt-get -y install sox

In [None]:
import os
import sys

sys.path = [
    '../input/bird-outputs/src/src/',
] + sys.path


In [None]:
import gc
import time
import shutil
import random
import warnings
import typing as tp
from pathlib import Path
from contextlib import contextmanager

import yaml
from joblib import delayed, Parallel

import cv2
import librosa
import audioread
import soundfile as sf

import numpy as np
import pandas as pd

from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
# import resnest.torch as resnest_torch

import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import extensions as ppe_extensions

import timm
# from util import f1
from training.mixup import mixup_data
# from params import NUM_WORKERS, NUM_CLASSES
from training.specaugment import SpecAugmentation

pd.options.display.max_rows = 500
pd.options.display.max_columns = 500

In [None]:
# Ensure the legacy torch checkpoint cache directory exists before any weights are fetched.
os.makedirs("/root/.cache/torch/checkpoints", exist_ok=True)

We use [this dataset](https://www.kaggle.com/theoviel/bird-backgrounds) for background-noise augmentation.
However, its audio files appeared to be corrupted.
Therefore, we convert each wav file to raw format and then convert it back to wav.

In [None]:
# Scratch directories for the wav -> raw -> wav round trip performed in the next cell.
wav2raw_dir = '/kaggle/working/wav2raw'
raw2wav_dir = '/kaggle/working/raw2wav'

for scratch_dir in (wav2raw_dir, raw2wav_dir):
    os.makedirs(scratch_dir, exist_ok=True)

In [None]:
import glob
import subprocess
from subprocess import PIPE

# Re-encode every background-noise wav: wav -> headerless raw -> wav.
tmp_path = Path('../input/bird-backgrounds')
failed_conversions = []
for audio in tmp_path.glob('**/*.wav'):
    raw_file = Path(wav2raw_dir) / audio.with_suffix('.raw').name
    wav_file = Path(raw2wav_dir) / audio.name

    # Argument lists (no shell=True) keep paths containing spaces or shell
    # metacharacters intact.
    sox2raw = ["sox", str(audio), str(raw_file)]
    # NOTE(review): this step reinterprets the raw bytes as signed 16-bit PCM at
    # 32 kHz and does not specify a channel count — confirm this matches the
    # encoding of the source files.
    raw2sox = [
        "sox", "-t", "raw", "-e", "signed-integer", "-b", "16", "-r", "32000",
        str(raw_file), "-t", "wav", str(wav_file),
    ]
    for cmd in (sox2raw, raw2sox):
        proc = subprocess.run(cmd, stdout=PIPE, stderr=PIPE, text=True)
        if proc.returncode != 0:
            # Best-effort: record the failure and continue with the next file
            # instead of silently ignoring it.
            failed_conversions.append((audio.name, proc.stderr.strip()))
            break

if failed_conversions:
    print(f"{len(failed_conversions)} background file(s) failed to convert:")
    for failed_name, failed_err in failed_conversions:
        print(f"  {failed_name}: {failed_err}")


### define utilities

In [None]:
def set_seed(seed: int = 42):
    """Seed the random number generators used by this notebook.

    Seeds Python's ``random`` module and NumPy's global RNG, exports
    ``PYTHONHASHSEED``, and seeds PyTorch's CPU RNG and its RNG for the
    current CUDA device. cuDNN determinism flags are deliberately left alone.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    

@contextmanager
def timer(name: str) -> tp.Iterator[None]:
    """Print a start message, run the wrapped block, then print the elapsed seconds.

    Args:
        name: label printed in brackets in both messages.

    Yields:
        None. The ``done`` message is printed only if the block exits normally.
    """
    # perf_counter is monotonic, so the measurement is unaffected by wall-clock changes.
    t0 = time.perf_counter()
    print("[{}] start".format(name))
    yield
    print("[{}] done in {:.0f} s".format(name, time.perf_counter() - t0))

### read data

In [None]:
# Input locations, resolved relative to the parent of the current working directory.
# NOTE(review): assumes that parent contains an `input/` directory (the Kaggle
# layout) — adjust ROOT when running elsewhere.
ROOT = Path.cwd().parent
INPUT_ROOT = ROOT / "input"
RAW_DATA = INPUT_ROOT / "birdclef-2021"
TRAIN_AUDIO_DIR = RAW_DATA / "train_short_audio"
TEST_AUDIO_DIR = RAW_DATA / "test_soundscapes"

In [None]:
# Per-recording metadata; `primary_label` and `filename` are used below to build labels and file paths.
train = pd.read_csv(RAW_DATA / "train_metadata.csv")

In [None]:
# Quick look at the metadata columns.
train.head(3)

### settings

In [None]:
# Experiment configuration, parsed with yaml.safe_load in the next cell.
# - dataset.params are passed as keyword arguments to SpectrogramDataset.
# - model.name is a timm architecture name; model.params.n_classes sizes the output layer.
# - augmentation.params holds the SpecAugment / mixup probabilities and the mixup alpha.
# NOTE(review): train_loop calls scheduler.step() once per batch, so
# scheduler.params.T_max is measured in iterations, not epochs — confirm this is intended.
settings_str = """
globals:
  seed: 1213
  device: cuda
  num_epochs: 1
  output_dir: /kaggle/training_output/
  use_fold: 0
  target_sr: 32000

dataset:
  name: SpectrogramDataset
  params:
    img_size: 224
    melspectrogram_parameters:
      n_mels: 128
      fmin: 20
      fmax: 16000
    
split:
  name: StratifiedKFold
  params:
    n_splits: 5
    random_state: 42
    shuffle: True

loader:
  train:
    batch_size: 50
    shuffle: True
    num_workers: 2
    pin_memory: True
    drop_last: True
  val:
    batch_size: 100
    shuffle: False
    num_workers: 2
    pin_memory: True
    drop_last: False

model:
  name: resnest50d_1s4x24d
  params:
    pretrained: True
    n_classes: 397

loss:
  name: BCEWithLogitsLoss
  params: {}

optimizer:
  name: Adam
  params:
    lr: 0.001

scheduler:
  name: CosineAnnealingLR
  params:
    T_max: 10
 
augmentation:
  params:
    specaugment_proba: 0.5
    mixup_proba: 0.5
    alpha: 5
"""

In [None]:
# safe_load builds only plain Python objects (dicts, lists, scalars) from the YAML text.
settings = yaml.safe_load(settings_str)

## Definition

### Dataset
* forked from: https://github.com/koukyo1994/kaggle-birdcall-resnet-baseline-training/blob/master/src/dataset.py
* partially modified


In [None]:
# Label -> class index, numbered in order of first appearance in the metadata,
# plus the inverse mapping (class index -> label).
BIRD_CODE = {label: idx for idx, label in enumerate(train.primary_label.unique())}
INV_BIRD_CODE = {idx: label for label, idx in BIRD_CODE.items()}

In [None]:
from audiomentations import AddBackgroundNoise, AddGaussianSNR, Compose

# Background-noise clips produced by the sox re-encoding cell above.
BACKGROUND_PATH = '/kaggle/working/raw2wav'

def get_wav_transforms():
    """Build the waveform augmentation pipeline applied to training samples.

    The pipeline is an audiomentations ``Compose`` of:
      * ``AddGaussianSNR(max_SNR=0.5, p=0.5)`` — Gaussian noise;
      * ``AddBackgroundNoise`` mixing in a clip from ``BACKGROUND_PATH`` at an
        SNR between 0 and 2 dB, with p=0.5.

    Returns:
        A callable taking ``(samples, sample_rate)`` and returning augmented samples.
    """
    # NOTE(review): `max_SNR` is the older AddGaussianSNR keyword; newer
    # audiomentations releases use *_snr_in_db names — pin the version or update.
    transforms = Compose(
        [
            AddGaussianSNR(max_SNR=0.5, p=0.5),
            AddBackgroundNoise(
                sounds_path=BACKGROUND_PATH, min_snr_in_db=0, max_snr_in_db=2, p=0.5
            ),
        ]
    )

    return transforms

In [None]:
PERIOD = 5  # clip length in seconds used by SpectrogramDataset

def mono_to_color(
    X: np.ndarray, mean=None, std=None,
    norm_max=None, norm_min=None, eps=1e-6
):
    """Turn a single-channel array (typically a dB mel-spectrogram) into a 3-channel uint8 image.

    Steps: replicate ``X`` on a new last axis, standardize with ``mean``/``std``,
    clip to ``[norm_min, norm_max]`` and scale linearly to 0..255.

    Args:
        X: input array; the output has shape ``X.shape + (3,)``.
        mean, std: standardization statistics; default to the array's own mean/std.
        norm_max, norm_min: clipping bounds applied after standardization; default
            to the standardized array's max/min.
        eps: added to ``std`` to avoid division by zero, and used as the minimum
            range below which the input is treated as constant.

    Returns:
        ``np.uint8`` array. It is all zeros when the standardized input (or the
        requested normalization range) spans ``eps`` or less.
    """
    # Stack X as [X, X, X]
    X = np.stack([X, X, X], axis=-1)

    # Compare with `is None` so explicit zeros (e.g. mean=0 or norm_min=0) are
    # honored; the previous `value or default` form silently replaced them.
    mean = X.mean() if mean is None else mean
    X = X - mean
    std = X.std() if std is None else std
    Xstd = X / (std + eps)

    _min, _max = Xstd.min(), Xstd.max()
    norm_max = _max if norm_max is None else norm_max
    norm_min = _min if norm_min is None else norm_min

    # Also guard the normalization range itself, which can be degenerate when the
    # caller supplies equal bounds (it would otherwise divide by zero).
    if (_max - _min) > eps and (norm_max - norm_min) > eps:
        V = np.clip(Xstd, norm_min, norm_max)
        V = 255 * (V - norm_min) / (norm_max - norm_min)
        return V.astype(np.uint8)
    # Effectively constant input: return an all-black image.
    return np.zeros_like(Xstd, dtype=np.uint8)

class SpectrogramDataset(data.Dataset):
    """Audio-file dataset yielding ``(image, labels)`` pairs.

    Each item reads one audio file and optionally applies waveform augmentation.
    It then crops or zero-pads the clip to ``PERIOD`` seconds and computes a
    dB-scaled mel-spectrogram. That is rendered as a 3-channel image with
    ``mono_to_color`` and resized so its height equals ``img_size``.

    Args:
        file_list: list of ``[file_path, ebird_code]`` pairs.
        img_size: output image height in pixels; width is scaled proportionally.
        train: if True, crop/pad offsets are random and ``get_wav_transforms()``
            is used by default. If False, offsets are 0 and there is no default
            augmentation.
        waveform_transforms: optional callable ``(samples, sample_rate) -> samples``.
            When None, it defaults to ``get_wav_transforms()`` for training only.
        spectrogram_transforms: stored on the instance but not applied.
        melspectrogram_parameters: extra keyword arguments for
            ``librosa.feature.melspectrogram``.

    Items:
        image: float32 array of shape ``(3, img_size, W)`` with values in [0, 1].
        labels: float32 one-hot vector of length ``len(BIRD_CODE)``.
    """

    def __init__(
        self,
        file_list: tp.List[tp.List[str]], img_size=224, train=True,
        waveform_transforms=None, spectrogram_transforms=None, melspectrogram_parameters=None
    ):
        self.file_list = file_list  # list of list: [file_path, ebird_code]
        self.img_size = img_size
        # __getitem__ reads self.train; it was previously never assigned.
        self.train = train
        # Honor an explicitly supplied transform; otherwise augment only when training.
        if waveform_transforms is None and train:
            waveform_transforms = get_wav_transforms()
        self.waveform_transforms = waveform_transforms
        self.spectrogram_transforms = spectrogram_transforms  # currently unused
        # Copy into a fresh dict per instance (avoids a shared mutable default).
        self.melspectrogram_parameters = dict(melspectrogram_parameters or {})

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx: int):
        wav_path, ebird_code = self.file_list[idx]

        # audiomentations documents float32 input, so read as float32 directly.
        y, sr = sf.read(wav_path, dtype="float32")
        if y.ndim > 1:
            # Down-mix multi-channel audio so the crop/pad logic works on 1-D samples.
            y = y.mean(axis=1)

        if self.waveform_transforms is not None:
            y = self.waveform_transforms(y, sr)

        len_y = len(y)
        effective_length = sr * PERIOD
        if len_y < effective_length:
            new_y = np.zeros(effective_length, dtype=np.float32)
            # randint's upper bound is exclusive; +1 allows the clip to be right-aligned too.
            start = np.random.randint(effective_length - len_y + 1) if self.train else 0
            new_y[start:start + len_y] = y
            y = new_y
        elif len_y > effective_length:
            # Random crop when training; fixed crop for a reproducible validation loss.
            start = np.random.randint(len_y - effective_length + 1) if self.train else 0
            y = y[start:start + effective_length]
        y = y.astype(np.float32, copy=False)

        # Pass y by keyword: it is keyword-only in recent librosa releases.
        melspec = librosa.feature.melspectrogram(y=y, sr=sr, **self.melspectrogram_parameters)
        melspec = librosa.power_to_db(melspec).astype(np.float32)

        image = mono_to_color(melspec)
        height, width, _ = image.shape
        image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
        image = np.moveaxis(image, 2, 0)  # HWC -> CHW
        image = (image / 255.0).astype(np.float32)

        labels = np.zeros(len(BIRD_CODE), dtype="f")
        labels[BIRD_CODE[ebird_code]] = 1

        return image, labels

### Training Utility

In [None]:
def get_loaders_for_training(
    args_dataset: tp.Dict, args_loader: tp.Dict,
    train_file_list: tp.List[str], val_file_list: tp.List[str]
):
    """Build the training and validation DataLoaders.

    Both datasets are ``SpectrogramDataset`` instances built with
    ``args_dataset``; the validation one is created with ``train=False``.
    ``args_loader["train"]`` / ``args_loader["val"]`` are forwarded to the
    corresponding ``DataLoader``.

    Returns:
        ``(train_loader, val_loader)``
    """
    datasets = {
        "train": SpectrogramDataset(train_file_list, **args_dataset),
        "val": SpectrogramDataset(val_file_list, **args_dataset, train=False),
    }
    train_loader, val_loader = (
        data.DataLoader(datasets[split], **args_loader[split]) for split in ("train", "val")
    )
    return train_loader, val_loader

In [None]:
def get_model(args: tp.Dict):
    """Create a timm backbone and replace its ``fc`` head with a 3-layer MLP.

    Args:
        args: ``{"name": <timm model name>,
                 "params": {"pretrained": bool, "n_classes": int}}``.

    Returns:
        The model. Its head is ``Linear(in, 1024) -> ReLU -> Dropout(0.2) ->
        Linear(1024, 1024) -> ReLU -> Dropout(0.2) -> Linear(1024, n_classes)``.
    """
    model = timm.create_model(args["name"], pretrained=args["params"]["pretrained"])
    # Read the head's input width from the backbone instead of hard-coding 2048,
    # so other timm models with a Linear `fc` head also work.
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(in_features, 1024), nn.ReLU(), nn.Dropout(p=0.2),
        nn.Linear(1024, 1024), nn.ReLU(), nn.Dropout(p=0.2),
        nn.Linear(1024, args["params"]["n_classes"]))

    return model

In [None]:
def train_loop(
    manager, args, model, device,
    train_loader, optimizer, scheduler, loss_func
):
    """Run the minibatch training loop until ``manager.stop_trigger`` fires.

    For each batch the loop:
      1. moves inputs and targets to ``device``;
      2. with probability ``specaugment_proba`` applies SpecAugment;
      3. with probability ``mixup_proba`` applies mixup, using the union
         (clamped sum) of both samples' labels as the target;
      4. runs forward/backward, one optimizer step and one scheduler step;
      5. reports ``train/loss`` through ppe.

    Probabilities and mixup ``alpha`` come from ``args["augmentation"]["params"]``.
    """
    aug_params = args["augmentation"]["params"]
    spec_augmenter = SpecAugmentation(
        time_drop_width=16, time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2
    )

    while not manager.stop_trigger:
        # The evaluator switches the model to eval mode, so re-enable training each epoch.
        model.train()
        # `inputs` (not `data`) avoids shadowing the torch.utils.data module alias.
        for inputs, target in train_loader:
            with manager.run_iteration():
                inputs, target = inputs.to(device), target.to(device)

                if np.random.rand() < aug_params["specaugment_proba"]:
                    inputs = spec_augmenter(inputs)

                if np.random.rand() < aug_params["mixup_proba"]:
                    # Tensors are already on `device`; the former `.cuda()` calls were
                    # redundant on GPU and broke any other device setting.
                    inputs, y_a, y_b, _ = mixup_data(inputs, target, alpha=aug_params["alpha"])
                    target = torch.clamp(y_a + y_b, 0, 1)

                optimizer.zero_grad()
                output = model(inputs)

                loss = loss_func(output, target)
                ppe.reporting.report({'train/loss': loss.item()})
                loss.backward()
                optimizer.step()
                # NOTE(review): stepped per iteration, so e.g. CosineAnnealingLR's
                # T_max counts batches rather than epochs — confirm intended.
                scheduler.step()

def eval_for_batch(
    args, model, device,
    data, target, loss_func, eval_func_dict=None
):
    """
    Run evaluation for one validation batch and report the results to ppe.

    This function is applied to each batch of the val loader by the ppe
    ``Evaluator``. It reports ``val/loss`` and, for every ``name -> fn`` in
    ``eval_func_dict``, ``val/<name>`` with value ``fn(output, target).item()``.
    """
    model.eval()
    data, target = data.to(device), target.to(device)
    # No gradients are needed for evaluation.
    with torch.no_grad():
        output = model(data)
    # Final result will be average of averages of the same size
    val_loss = loss_func(output, target).item()
    ppe.reporting.report({'val/loss': val_loss})

    # `eval_func_dict or {}` replaces the former shared mutable default argument.
    for eval_name, eval_func in (eval_func_dict or {}).items():
        eval_value = eval_func(output, target).item()
        # Fixed typo: `eval_aame` raised NameError whenever metrics were supplied.
        ppe.reporting.report({"val/{}".format(eval_name): eval_value})

In [None]:
def set_extensions(
    manager, args, model, device, test_loader, optimizer,
    loss_func, eval_func_dict={}
):
    """Register the ppe extensions used during training on ``manager``.

    Extensions are registered in this order:
      * untriggered: learning-rate observer, ``LogReport``, loss and lr plots
        (``loss.png``, ``lr.png``), and a ``PrintReport`` table;
      * once per epoch: an ``Evaluator`` that runs ``eval_for_batch`` over
        ``test_loader``;
      * on each new minimum of ``val/loss`` (checked per epoch): a model
        snapshot named ``snapshot_epoch_<epoch>.pth``.

    Returns:
        The same ``manager`` object.
    """
    def run_eval_batch(batch_data, batch_target):
        return eval_for_batch(args, model, device, batch_data, batch_target, loss_func, eval_func_dict)

    every_epoch = (1, "epoch")

    plain_extensions = [
        ppe_extensions.observe_lr(optimizer=optimizer),
        ppe_extensions.LogReport(),
        ppe_extensions.PlotReport(['train/loss', 'val/loss'], 'epoch', filename='loss.png'),
        ppe_extensions.PlotReport(['lr',], 'epoch', filename='lr.png'),
        ppe_extensions.PrintReport([
            'epoch', 'iteration', 'lr', 'train/loss', 'val/loss', "elapsed_time"]),
    ]
    evaluator = ppe_extensions.Evaluator(
        test_loader, model, eval_func=run_eval_batch, progress_bar=True)
    best_snapshot = ppe_extensions.snapshot(
        target=model, filename="snapshot_epoch_{.updater.epoch}.pth")
    on_new_best = ppe.training.triggers.MinValueTrigger(key="val/loss", trigger=every_epoch)

    for extension in plain_extensions:
        manager.extend(extension)
    manager.extend(evaluator, trigger=every_epoch)
    manager.extend(best_snapshot, trigger=on_new_best)

    return manager

## Training

### prepare data

#### get wav file path

In [None]:
# Full path of each training clip: <TRAIN_AUDIO_DIR>/<primary_label>/<filename>.
# Zipping the two columns avoids the slow row-by-row DataFrame.iterrows().
audio_list = [
    TRAIN_AUDIO_DIR / label / fname
    for label, fname in zip(train["primary_label"], train["filename"])
]
train_all = train.assign(file_path=audio_list)

In [None]:
# Check that the new file_path column looks right.
train_all.head(3)

#### split data

In [None]:
# Assign every row to a validation fold, stratified by primary_label.
skf = StratifiedKFold(**settings["split"]["params"])

train_all["fold"] = -1
# Address the column by name rather than position -1, so this stays correct
# even if columns are added after "fold".
fold_col = train_all.columns.get_loc("fold")
for fold_id, (_, val_index) in enumerate(skf.split(train_all, train_all["primary_label"])):
    train_all.iloc[val_index, fold_col] = fold_id
assert (train_all["fold"] >= 0).all(), "every row must be assigned to a fold"

# # check the proportion
fold_proportion = pd.pivot_table(train_all, index="primary_label", columns="fold", values="filename", aggfunc=len)
print(fold_proportion.shape)

In [None]:
# fold_proportion

In [None]:
use_fold = settings["globals"]["use_fold"]
# Rows in fold `use_fold` form the validation set; all other folds are used for training.
list_columns = ["file_path", "primary_label"]
in_val_fold = train_all["fold"] == use_fold
train_file_list = train_all.loc[~in_val_fold, list_columns].values.tolist()
val_file_list = train_all.loc[in_val_fold, list_columns].values.tolist()

print("[fold {}] train: {}, val: {}".format(use_fold, len(train_file_list), len(val_file_list)))

## run training

In [None]:
# Seed first: the new head layers created in get_model are randomly initialized,
# so statement order here matters for reproducibility.
set_seed(settings["globals"]["seed"])
device = torch.device(settings["globals"]["device"])
output_dir = Path(settings["globals"]["output_dir"])

# # # get loader (dataset params are forwarded as SpectrogramDataset kwargs)
train_loader, val_loader = get_loaders_for_training(
    settings["dataset"]["params"], settings["loader"], train_file_list, val_file_list)

# # # get model
model = get_model(settings["model"])
model = model.to(device)

# # # get optimizer (class looked up by name in torch.optim)
optimizer = getattr(
    torch.optim, settings["optimizer"]["name"]
)(model.parameters(), **settings["optimizer"]["params"])

# # # get scheduler (class looked up by name in torch.optim.lr_scheduler)
scheduler = getattr(
    torch.optim.lr_scheduler, settings["scheduler"]["name"]
)(optimizer, **settings["scheduler"]["params"])

# # # get loss (class looked up by name in torch.nn)
loss_func = getattr(nn, settings["loss"]["name"])(**settings["loss"]["params"])

# # # create training manager
# NOTE(review): with stop_trigger=None, ppe presumably stops after num_epochs — confirm.
trigger = None

manager = ppe.training.ExtensionsManager(
    model, optimizer, settings["globals"]["num_epochs"],
    iters_per_epoch=len(train_loader),
    stop_trigger=trigger,
    out_dir=output_dir
)

# # # set manager extensions (logging, plots, per-epoch evaluation, best-val snapshots)
manager = set_extensions(
    manager, settings, model, device,
    val_loader, optimizer, loss_func,
)

In [None]:
# Run training; returns once the manager's stop trigger fires.
train_loop(
    manager=manager,
    args=settings,
    model=model,
    device=device,
    train_loader=train_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_func=loss_func,
)

In [None]:
# Drop every reference to the training objects, then run the garbage collector
# so their memory can be reclaimed before the best checkpoint is reloaded.
del train_loader, val_loader, model, optimizer, scheduler, loss_func, manager

gc.collect()

## save results

In [None]:
%%bash
ls /kaggle/training_output

In [None]:
# Copy the ppe log and plots from the training output directory into the working directory.
for artifact in ("log", "loss.png", "lr.png"):
    shutil.copy(output_dir / artifact, artifact)

In [None]:
log = pd.read_json("log")
# Read the epoch number from the row with the lowest val/loss rather than
# assuming that row i corresponds to epoch i + 1.
best_row = log["val/loss"].idxmin()
best_epoch = int(log.loc[best_row, "epoch"])
log.loc[[best_row]]

In [None]:
# The snapshot extension only writes a file on a new val/loss minimum, so the best epoch's file exists.
best_snapshot_path = output_dir / f"snapshot_epoch_{best_epoch}.pth"
shutil.copy(best_snapshot_path, "best_model.pth")

In [None]:
# Rebuild the architecture without downloading pretrained weights: every
# parameter is overwritten by the checkpoint loaded below anyway.
m = get_model({
    'name': settings["model"]["name"],
    'params': {
        'pretrained': False,
        'n_classes': settings["model"]["params"]["n_classes"]
    }})
# map_location="cpu" lets the checkpoint load even where the training GPU is unavailable.
state_dict = torch.load('best_model.pth', map_location="cpu")
m.load_state_dict(state_dict)