# Birdsong Pytorch Baseline: ResNeSt50-fast (Inference)

This is my attempt to refactor and clean up [this](https://www.kaggle.com/ttahara/inference-birdsong-baseline-resnest50-fast) excellent notebook. It's a work in progress, and I'll update it as I go. I'm doing this mainly for my own understanding, and am sharing it in case it's of use to anyone else.

Type checking, style checking, and formatting are done with `black`, `blackdoc`, `flake8`, `isort`, `mypy`, and `pydocstyle` via [nbQA](https://github.com/nbQA-dev/nbQA).

Configurations used:

`.pre-commit-config.yaml`:

```yaml
repos:
  - repo: https://github.com/nbQA-dev/nbQA
    rev: 0.1.22
    hooks:
      - id: nbqa
        args: ['black']
        name: nbqa-black
        additional_dependencies: ['black']
      - id: nbqa
        args: ['flake8']
        name: nbqa-flake8
        additional_dependencies: ['flake8']
        alias: nbqa-flake8
      - id: nbqa
        args: ['isort']
        additional_dependencies: ['isort']
      - id: nbqa
        args: ['blackdoc']
        name: nbqa-blackdoc
        additional_dependencies: ['blackdoc']
      - id: nbqa
        args: ['mypy']
        name: nbqa-mypy
        alias: nbqa-mypy
        additional_dependencies: ['mypy']
      - id: nbqa
        args: ['pydocstyle']
        name: nbqa-pydocstyle
        additional_dependencies: ['pydocstyle']
```

`.nbqa.ini`:

```ini
[black]
addopts = --line-length=96
mutate = 1

[flake8]
config=.flake8

[isort]
addopts = --profile=black
mutate = 1

[blackdoc]
addopts = --line-length=96
mutate = 1

[mypy]
addopts = --ignore-missing-imports --disallow-untyped-defs

[pydocstyle]
addopts = --add-ignore=D100,D101,D105,D103,D107
```

## About

In this notebook, I try ResNeSt, one of the state-of-the-art architectures for image recognition.

This is a notebook for **_inference & submission_**. I shared the training process in a separate notebook:  
https://www.kaggle.com/ttahara/training-birdsong-baseline-resnest50-fast  
See it for the experimental details.

Most of this notebook is based on the [great baseline](https://www.kaggle.com/hidehisaarai1213/inference-pytorch-birdcall-resnet-baseline) shared by @hidehisaarai1213.  
Thank you for sharing!

## Prepare

### import libraries

In [None]:
!pip install ../input/resnest50-fast-package/resnest-0.0.6b20200701/resnest/

In [None]:
import os
import random
import time
import typing as tp
import warnings
from contextlib import contextmanager
from pathlib import Path

import cv2
import librosa
import numpy as np
import pandas as pd
import resnest.torch as resnest_torch
import torch
import torch.nn as nn
import torch.utils.data as data
from fastprogress import progress_bar

pd.options.display.max_rows = 500
pd.options.display.max_columns = 500

### define utilities

In [None]:
def set_seed(seed: int = 42) -> None:
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


@contextmanager
def timer(name: str) -> tp.Iterator[None]:
    """Timer util."""
    t0 = time.time()
    print("[{}] start".format(name))
    yield
    print("[{}] done in {:.0f} s".format(name, time.time() - t0))

In [None]:
set_seed(1213)

### read data

In [None]:
ROOT = Path.cwd().parent
INPUT_ROOT = ROOT / "input"
RAW_DATA = INPUT_ROOT / "birdsong-recognition"
TEST_AUDIO_DIR = RAW_DATA / "test_audio"

In [None]:
train = pd.read_csv(RAW_DATA / "train.csv")

In [None]:
if not TEST_AUDIO_DIR.exists():
    TEST_AUDIO_DIR = INPUT_ROOT / "birdcall-check" / "test_audio"
    test = pd.read_csv(INPUT_ROOT / "birdcall-check" / "test.csv")
else:
    test = pd.read_csv(RAW_DATA / "test.csv")

In [None]:
train.head()

In [None]:
test.head()

In [None]:
# Write the sample submission first; it is overwritten if everything goes well.
sub = pd.read_csv(RAW_DATA / "sample_submission.csv")
sub.to_csv("submission.csv", index=False)

### set parameters

In [None]:
TARGET_SR = 32000
model_config: tp.Dict[str, tp.Union[str, int, bool]] = {
    "base_model_name": "resnest50_fast_1s1x64d",
    "pretrained": False,
    "num_classes": 264,
    "trained_weights": "../input/training-birdsong-baseline-resnest50-fast/best_model.pth",
}

melspectrogram_parameters = {"n_mels": 128, "fmin": 20, "fmax": 16000}
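As a rough sanity check of what these parameters produce, the arithmetic below works out the image shape for one 5-second window. It assumes librosa's defaults `hop_length=512` and `center=True` for `melspectrogram` (neither is overridden in `melspectrogram_parameters`), and mirrors the resize in `_get_image_from_audio` with `img_size=224` as used in `prediction_for_clip`. The helper `expected_image_shape` is hypothetical and not part of the notebook.

```python
from typing import Tuple


def expected_image_shape(
    sr: int, seconds: int, hop_length: int, n_mels: int, img_size: int
) -> Tuple[int, int, int]:
    """Hypothetical helper: (channels, height, width) of the image for one window."""
    n_samples = sr * seconds
    # With center=True (assumed librosa default), the signal is padded so that
    # the spectrogram has 1 + n_samples // hop_length frames.
    n_frames = 1 + n_samples // hop_length
    # The spectrogram is (n_mels, n_frames); the resize sets the height to
    # img_size and scales the width by the same factor.
    width = int(n_frames * img_size / n_mels)
    return (3, img_size, width)


# 32 kHz, 5 s, assumed hop_length=512, n_mels=128, img_size=224
print(expected_image_shape(32000, 5, 512, 128, 224))  # (3, 224, 547)
```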

## Definition

### Dataset

For `site_1` and `site_2`, each row of the test set corresponds to a 5-second window ending at `seconds`, so I crop that window out of the clip and predict on it.
For `site_3`, I use the same procedure, except that I crop consecutive 5-second chunks from the start to the end of the clip (discarding the last incomplete chunk), predict on each chunk, and then aggregate the predictions over all chunks.
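The `site_3` procedure can be sketched in isolation with NumPy. This is a simplified illustration rather than the notebook's code: `split_into_chunks` and `aggregate_chunk_probas` are hypothetical helpers, and aggregating by each bird's maximum probability over chunks is one way to implement "a bird is present if any chunk detects it" (it yields the same labels as taking the union of per-chunk detections at the same threshold).

```python
import numpy as np

SR = 32000
CHUNK_SECONDS = 5


def split_into_chunks(y: np.ndarray, sr: int = SR, seconds: int = CHUNK_SECONDS) -> np.ndarray:
    """Split a 1D signal into full-length chunks, discarding the last incomplete one."""
    chunk_len = sr * seconds
    n_chunks = len(y) // chunk_len
    return y[: n_chunks * chunk_len].reshape(n_chunks, chunk_len)


def aggregate_chunk_probas(probas: np.ndarray, threshold: float) -> np.ndarray:
    """Indices of classes whose maximum probability over chunks reaches the threshold."""
    # probas has shape (n_chunks, n_classes); a class counts if any chunk detects it.
    return np.flatnonzero(probas.max(axis=0) >= threshold)


# 12 seconds of audio -> two full 5-second chunks; the trailing 2 seconds are dropped.
clip = np.zeros(SR * 12, dtype=np.float32)
print(split_into_chunks(clip).shape)  # (2, 160000)

# Made-up probabilities for 2 chunks and 3 classes.
probas = np.array([[0.7, 0.1, 0.2], [0.3, 0.05, 0.65]])
print(aggregate_chunk_probas(probas, threshold=0.6))  # [0 2]
```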

In [None]:
BIRD_CODE = {
    "aldfly": 0,
    "ameavo": 1,
    "amebit": 2,
    "amecro": 3,
    "amegfi": 4,
    "amekes": 5,
    "amepip": 6,
    "amered": 7,
    "amerob": 8,
    "amewig": 9,
    "amewoo": 10,
    "amtspa": 11,
    "annhum": 12,
    "astfly": 13,
    "baisan": 14,
    "baleag": 15,
    "balori": 16,
    "banswa": 17,
    "barswa": 18,
    "bawwar": 19,
    "belkin1": 20,
    "belspa2": 21,
    "bewwre": 22,
    "bkbcuc": 23,
    "bkbmag1": 24,
    "bkbwar": 25,
    "bkcchi": 26,
    "bkchum": 27,
    "bkhgro": 28,
    "bkpwar": 29,
    "bktspa": 30,
    "blkpho": 31,
    "blugrb1": 32,
    "blujay": 33,
    "bnhcow": 34,
    "boboli": 35,
    "bongul": 36,
    "brdowl": 37,
    "brebla": 38,
    "brespa": 39,
    "brncre": 40,
    "brnthr": 41,
    "brthum": 42,
    "brwhaw": 43,
    "btbwar": 44,
    "btnwar": 45,
    "btywar": 46,
    "buffle": 47,
    "buggna": 48,
    "buhvir": 49,
    "bulori": 50,
    "bushti": 51,
    "buwtea": 52,
    "buwwar": 53,
    "cacwre": 54,
    "calgul": 55,
    "calqua": 56,
    "camwar": 57,
    "cangoo": 58,
    "canwar": 59,
    "canwre": 60,
    "carwre": 61,
    "casfin": 62,
    "caster1": 63,
    "casvir": 64,
    "cedwax": 65,
    "chispa": 66,
    "chiswi": 67,
    "chswar": 68,
    "chukar": 69,
    "clanut": 70,
    "cliswa": 71,
    "comgol": 72,
    "comgra": 73,
    "comloo": 74,
    "commer": 75,
    "comnig": 76,
    "comrav": 77,
    "comred": 78,
    "comter": 79,
    "comyel": 80,
    "coohaw": 81,
    "coshum": 82,
    "cowscj1": 83,
    "daejun": 84,
    "doccor": 85,
    "dowwoo": 86,
    "dusfly": 87,
    "eargre": 88,
    "easblu": 89,
    "easkin": 90,
    "easmea": 91,
    "easpho": 92,
    "eastow": 93,
    "eawpew": 94,
    "eucdov": 95,
    "eursta": 96,
    "evegro": 97,
    "fiespa": 98,
    "fiscro": 99,
    "foxspa": 100,
    "gadwal": 101,
    "gcrfin": 102,
    "gnttow": 103,
    "gnwtea": 104,
    "gockin": 105,
    "gocspa": 106,
    "goleag": 107,
    "grbher3": 108,
    "grcfly": 109,
    "greegr": 110,
    "greroa": 111,
    "greyel": 112,
    "grhowl": 113,
    "grnher": 114,
    "grtgra": 115,
    "grycat": 116,
    "gryfly": 117,
    "haiwoo": 118,
    "hamfly": 119,
    "hergul": 120,
    "herthr": 121,
    "hoomer": 122,
    "hoowar": 123,
    "horgre": 124,
    "horlar": 125,
    "houfin": 126,
    "houspa": 127,
    "houwre": 128,
    "indbun": 129,
    "juntit1": 130,
    "killde": 131,
    "labwoo": 132,
    "larspa": 133,
    "lazbun": 134,
    "leabit": 135,
    "leafly": 136,
    "leasan": 137,
    "lecthr": 138,
    "lesgol": 139,
    "lesnig": 140,
    "lesyel": 141,
    "lewwoo": 142,
    "linspa": 143,
    "lobcur": 144,
    "lobdow": 145,
    "logshr": 146,
    "lotduc": 147,
    "louwat": 148,
    "macwar": 149,
    "magwar": 150,
    "mallar3": 151,
    "marwre": 152,
    "merlin": 153,
    "moublu": 154,
    "mouchi": 155,
    "moudov": 156,
    "norcar": 157,
    "norfli": 158,
    "norhar2": 159,
    "normoc": 160,
    "norpar": 161,
    "norpin": 162,
    "norsho": 163,
    "norwat": 164,
    "nrwswa": 165,
    "nutwoo": 166,
    "olsfly": 167,
    "orcwar": 168,
    "osprey": 169,
    "ovenbi1": 170,
    "palwar": 171,
    "pasfly": 172,
    "pecsan": 173,
    "perfal": 174,
    "phaino": 175,
    "pibgre": 176,
    "pilwoo": 177,
    "pingro": 178,
    "pinjay": 179,
    "pinsis": 180,
    "pinwar": 181,
    "plsvir": 182,
    "prawar": 183,
    "purfin": 184,
    "pygnut": 185,
    "rebmer": 186,
    "rebnut": 187,
    "rebsap": 188,
    "rebwoo": 189,
    "redcro": 190,
    "redhea": 191,
    "reevir1": 192,
    "renpha": 193,
    "reshaw": 194,
    "rethaw": 195,
    "rewbla": 196,
    "ribgul": 197,
    "rinduc": 198,
    "robgro": 199,
    "rocpig": 200,
    "rocwre": 201,
    "rthhum": 202,
    "ruckin": 203,
    "rudduc": 204,
    "rufgro": 205,
    "rufhum": 206,
    "rusbla": 207,
    "sagspa1": 208,
    "sagthr": 209,
    "savspa": 210,
    "saypho": 211,
    "scatan": 212,
    "scoori": 213,
    "semplo": 214,
    "semsan": 215,
    "sheowl": 216,
    "shshaw": 217,
    "snobun": 218,
    "snogoo": 219,
    "solsan": 220,
    "sonspa": 221,
    "sora": 222,
    "sposan": 223,
    "spotow": 224,
    "stejay": 225,
    "swahaw": 226,
    "swaspa": 227,
    "swathr": 228,
    "treswa": 229,
    "truswa": 230,
    "tuftit": 231,
    "tunswa": 232,
    "veery": 233,
    "vesspa": 234,
    "vigswa": 235,
    "warvir": 236,
    "wesblu": 237,
    "wesgre": 238,
    "weskin": 239,
    "wesmea": 240,
    "wessan": 241,
    "westan": 242,
    "wewpew": 243,
    "whbnut": 244,
    "whcspa": 245,
    "whfibi": 246,
    "whtspa": 247,
    "whtswi": 248,
    "wilfly": 249,
    "wilsni1": 250,
    "wiltur": 251,
    "winwre3": 252,
    "wlswar": 253,
    "wooduc": 254,
    "wooscj2": 255,
    "woothr": 256,
    "y00475": 257,
    "yebfly": 258,
    "yebsap": 259,
    "yehbla": 260,
    "yelwar": 261,
    "yerwar": 262,
    "yetvir": 263,
}

INV_BIRD_CODE = {v: k for k, v in BIRD_CODE.items()}
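To show how the inverse mapping is used at prediction time, here is a small sketch that turns a vector of per-class probabilities into the submission's `birds` string. It uses a hypothetical three-species excerpt of `INV_BIRD_CODE` and made-up probabilities; the `>=` threshold and the `"nocall"` fallback follow `prediction_for_clip` further down, while the helper name `probas_to_birds` is an assumption for illustration.

```python
from typing import Dict

import numpy as np

# Hypothetical excerpt of INV_BIRD_CODE (the real mapping has 264 entries).
inv_bird_code: Dict[int, str] = {0: "aldfly", 1: "ameavo", 2: "amebit"}


def probas_to_birds(proba: np.ndarray, inv_code: Dict[int, str], threshold: float) -> str:
    """Space-separated codes of birds at or above the threshold, or "nocall" if none."""
    labels = np.flatnonzero(proba >= threshold)
    if len(labels) == 0:
        return "nocall"
    return " ".join(inv_code[int(label)] for label in labels)


print(probas_to_birds(np.array([0.9, 0.2, 0.6]), inv_bird_code, threshold=0.6))  # aldfly amebit
print(probas_to_birds(np.array([0.1, 0.2, 0.3]), inv_bird_code, threshold=0.6))  # nocall
```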

In [None]:
SR = 32000


def mono_to_color(X: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """
    Make 2D image 3D. Normalise and scale to be between 0 and 255.

    Parameters
    ----------
    X
        2D image.
    eps
        Epsilon, small number to add to std to avoid dividing by zero.

    Returns
    -------
    np.ndarray
        3D image, normalised, scaled.
    """
    original_shape = X.shape
    assert len(X.shape) == 2

    X = np.stack([X, X, X], axis=-1)
    assert X.shape == (*original_shape, 3)

    mean = X.mean()
    X = X - mean
    std = X.std()
    Xstd = X / (std + eps)
    _min, _max = Xstd.min(), Xstd.max()
    if (_max - _min) > eps:
        V = Xstd
        V = 255 * (V - _min) / (_max - _min)
        V = V.astype(np.uint8)
    else:
        V = np.zeros_like(Xstd, dtype=np.uint8)
    return V


def _get_image_from_audio(
    y: np.ndarray, melspectrogram_parameters: tp.Dict[str, int], img_size: int
) -> np.ndarray:
    """
    Get 3D image from audio.

    Steps are:
    - Compute the mel-scaled (power) spectrogram of the signal.
    - Convert the power spectrogram to decibel units.
    - Stack the spectrogram into 3 channels and normalise it (see ``mono_to_color``).
    - Resize to the desired height and scale values to [0, 1].

    Parameters
    ----------
    y
        Audio signal.
    melspectrogram_parameters
        Arguments to be passed on to melspectrogram.
    img_size
        Desired height of the image, in pixels; the width is scaled to keep the
        aspect ratio.

    Returns
    -------
    np.ndarray
        Image of shape (3, img_size, width), with values in [0, 1].
    """
    melspec = librosa.feature.melspectrogram(y=y, sr=SR, **melspectrogram_parameters)
    melspec = librosa.power_to_db(melspec).astype(np.float32)

    image = mono_to_color(melspec)
    height, width, _ = image.shape
    image = cv2.resize(image, (int(width * img_size / height), img_size))
    image = np.moveaxis(image, 2, 0)
    image = (image / 255.0).astype(np.float32)
    return image


class TestDataset(data.Dataset):
    def __init__(
        self,
        df: pd.DataFrame,
        clip: np.ndarray,
        img_size: int,
        melspectrogram_parameters: tp.Dict[str, int],
    ) -> None:
        """
        Initialise test dataset.

        Parameters
        ----------
        df
            Rows of test DataFrame corresponding to given audio clip.
        clip
            Audio clip as time series.
        img_size
            Desired height of the images, in pixels; the width is scaled to keep
            the aspect ratio.
        melspectrogram_parameters
            Arguments to be passed on to melspectrogram.
        """
        self.df = df
        self.clip = clip
        self.img_size = img_size
        self.melspectrogram_parameters = melspectrogram_parameters

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> tp.Tuple[np.ndarray, str, str]:
        """
        Preprocess row of data.

        If the row comes from site 1 or site 2, we convert the 5-second
        window ending at ``sample.seconds`` to a single 3D image. Otherwise
        (site 3), we split the whole clip into 5-second chunks (discarding
        the last incomplete one) and transform each to an image, so that
        predictions can be aggregated over all chunks.

        Parameters
        ----------
        idx
            Index of element to pre-process.

        Returns
        -------
        images
            Stack of 3D images with shape (n_chunks, 3, img_size, width);
            n_chunks is 1 for site 1 and site 2.
        row_id
            Row id corresponding to element.
        site
            Site where recording was taken (site 1, site 2, or site 3).
        """
        sample = self.df.loc[idx, :]
        site = sample.site
        row_id = sample.row_id
        chunk_len = SR * 5

        if site == "site_3":
            y = self.clip.astype(np.float32)
            if len(y) < chunk_len:
                # Pad clips shorter than 5 s with silence so that there is
                # always at least one chunk to predict on.
                y = np.pad(y, (0, chunk_len - len(y)))
            n_chunks = len(y) // chunk_len  # last incomplete chunk is discarded
            images = np.stack(
                [
                    _get_image_from_audio(
                        y[i * chunk_len : (i + 1) * chunk_len],
                        self.melspectrogram_parameters,
                        self.img_size,
                    )
                    for i in range(n_chunks)
                ]
            )
        else:
            end_seconds = int(sample.seconds)
            start_seconds = end_seconds - 5

            y = self.clip[SR * start_seconds : SR * end_seconds].astype(np.float32)
            image = _get_image_from_audio(y, self.melspectrogram_parameters, self.img_size)
            images = image[np.newaxis]

        return images, row_id, site

In [None]:
def get_model(args: tp.Dict) -> nn.Module:
    """
    Build model architecture, replace its head, and load trained weights.

    Parameters
    ----------
    args
        Model configuration: base model name, whether to use pretrained
        weights, number of classes, and path to the trained weights.
    """
    model = getattr(resnest_torch, args["base_model_name"])(pretrained=args["pretrained"])
    del model.fc
    model.fc = nn.Sequential(
        nn.Linear(2048, 1024),
        nn.ReLU(),
        nn.Dropout(p=0.2),
        nn.Linear(1024, 1024),
        nn.ReLU(),
        nn.Dropout(p=0.2),
        nn.Linear(1024, args["num_classes"]),
    )

    # Fall back to CPU so the notebook still runs without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state_dict = torch.load(args["trained_weights"], map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    return model

## Prediction loop

In [None]:
def prediction_for_clip(
    test_df: pd.DataFrame,
    clip: np.ndarray,
    model: nn.Module,
    mel_params: tp.Dict[str, int],
    threshold: float = 0.5,
    chunk_batch_size: int = 16,
) -> tp.Dict[str, str]:
    """
    Make prediction for single audio clip.

    Audio clip may correspond to multiple rows from the dataset
    (e.g. seconds 0-5, then seconds 5-10, ...).

    Parameters
    ----------
    test_df
        Portion of test dataset corresponding to this audio clip.
    clip
        Audio as floating point time series.
    model
        Model (already trained).
    mel_params
        Parameters for melspectrogram.
    threshold
        Predict all targets whose sigmoid probabilities are at or above this value.
    chunk_batch_size
        Maximum number of 5-second chunks passed through the model at once
        (only matters for site 3, where a row covers the whole clip).

    Returns
    -------
    Dict
        Keys are row_ids, values are space-separated codes of the birds
        detected (or "nocall").
    """
    assert clip.ndim == 1
    pd.testing.assert_index_equal(
        test_df.columns, pd.Index(["site", "row_id", "seconds", "audio_id"])
    )

    dataset = TestDataset(
        df=test_df, clip=clip, img_size=224, melspectrogram_parameters=mel_params
    )
    loader = data.DataLoader(dataset, batch_size=1, shuffle=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()
    prediction_dict: tp.Dict[str, str] = {}
    for images, row_id, site in progress_bar(loader):
        # Shape is (batch, n_chunks, channels, frequency, time): the height is
        # the frequency axis and the last dimension (width) is time.
        assert images.ndim == 5
        assert images.shape[0] == 1 and images.shape[2:4] == (3, 224)
        assert len(row_id) == 1
        assert len(site) == 1

        (row_id,) = row_id
        chunks = images.squeeze(0)

        probas: tp.List[np.ndarray] = []
        with torch.no_grad():
            # Feed the chunks in small batches to keep GPU memory bounded.
            for batch in torch.split(chunks, chunk_batch_size):
                prediction = torch.sigmoid(model(batch.to(device)))
                probas.append(prediction.cpu().numpy())
        # A bird counts as present if any chunk detects it, i.e. if its
        # maximum probability over chunks reaches the threshold.
        proba = np.concatenate(probas).max(axis=0)
        assert proba.shape == (len(INV_BIRD_CODE),)

        labels = np.argwhere(proba >= threshold).reshape(-1).tolist()

        if len(labels) == 0:
            prediction_dict[row_id] = "nocall"
        else:
            prediction_dict[row_id] = " ".join(INV_BIRD_CODE[label] for label in labels)
    return prediction_dict

In [None]:
def prediction(
    test_df: pd.DataFrame,
    test_audio: Path,
    model_config: tp.Dict[str, tp.Union[str, int, bool]],
    mel_params: tp.Dict[str, int],
    target_sr: int,
    threshold: float = 0.5,
) -> pd.DataFrame:
    """
    Get predictions.

    For each unique audio ID, predict which birds are present in each 5-second window.

    Parameters
    ----------
    test_df
        Test DataFrame (read from test.csv).
    test_audio
        Directory containing audio recordings for test dataset.
    model_config
        Configs to load model with.
    mel_params
        Parameters for melspectrogram (used in data loader).
    target_sr
        Target sampling rate.
    threshold
        Predict all targets whose sigmoid probabilities are at or above this value.

    Returns
    -------
    DataFrame
        Predictions for each row of the input data.
    """
    model = get_model(model_config)
    unique_audio_id = test_df.audio_id.unique()

    prediction_dfs = []
    for audio_id in unique_audio_id:
        with timer(f"Loading {audio_id}"), warnings.catch_warnings():
            # libsndfile doesn't handle mp3
            # see https://github.com/librosa/librosa/issues/1015
            warnings.simplefilter("ignore", UserWarning)
            clip, _ = librosa.load(
                test_audio / (audio_id + ".mp3"),
                sr=target_sr,
                mono=True,
                res_type="kaiser_fast",
            )

        test_df_for_audio_id = test_df[test_df.audio_id == audio_id].reset_index(drop=True)
        with timer(f"Prediction on {audio_id}"):
            prediction_dict = prediction_for_clip(
                test_df_for_audio_id,
                clip=clip,
                model=model,
                mel_params=mel_params,
                threshold=threshold,
            )
        row_id = list(prediction_dict.keys())
        birds = list(prediction_dict.values())
        prediction_df = pd.DataFrame({"row_id": row_id, "birds": birds})
        prediction_dfs.append(prediction_df)

    prediction_df = pd.concat(prediction_dfs, axis=0, sort=False).reset_index(drop=True)
    assert len(prediction_df) == len(test_df)
    return prediction_df

## Prediction

In [None]:
submission = prediction(
    test_df=test,
    test_audio=TEST_AUDIO_DIR,
    model_config=model_config,
    mel_params=melspectrogram_parameters,
    target_sr=TARGET_SR,
    threshold=0.6,
)
submission.to_csv("submission.csv", index=False)

In [None]:
submission

## EOF