## About

It took me three days of struggle to make a successful submission.
I want other Kaggle competitors to find it easy to take part in this competition, so I decided to share my notebook.

I would like to thank [@radek1](https://www.kaggle.com/radek1) for creating [a good starter notebook](https://www.kaggle.com/c/birdsong-recognition/discussion/160222), [@shonenkov](https://www.kaggle.com/shonenkov) for [a nice checking dataset and notebook](https://www.kaggle.com/shonenkov/sample-submission-using-custom-check) and for several discussions that made this competition better, and [@cwthompson](https://www.kaggle.com/cwthompson) for [showing the way to submit](https://www.kaggle.com/cwthompson/birdsong-making-a-prediction) using `test_audio`.

I would also like to thank [@stefankahl](https://www.kaggle.com/stefankahl), [@tomdenton](https://www.kaggle.com/tomdenton), and [@sohier](https://www.kaggle.com/sohier) for hosting a really interesting competition.

In this notebook I make a submission using a ResNet-based model trained on log-melspectrograms. I will create a separate notebook to show how I trained the model, but here is a brief description of my approach.

* Randomly crop 5 seconds from each training audio clip every epoch.
* No augmentation.
* Use the pretrained weights of `torchvision.models.resnet50`.
* Use `BCELoss`.
* Train for 100 epochs and use the weights with the best F1 (at epoch 92).
* `Adam` optimizer (`lr=0.001`) with `CosineAnnealingLR` (`T_max=10`).
* Use `StratifiedKFold(n_splits=5)` to split the dataset and use only the first fold.
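
To make the learning-rate schedule in the list above more concrete, here is a minimal sketch of the closed-form cosine annealing formula evaluated with `lr=0.001` and `T_max=10`. The helper name `cosine_annealing_lr` and the choice `eta_min=0` are assumptions for illustration; this is not the training code itself.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float = 0.001,
                        t_max: int = 10, eta_min: float = 0.0) -> float:
    """Closed form: eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


# Print the learning rate at a few epochs of the first cycle
for epoch in (0, 5, 10, 15, 20):
    print(epoch, f"{cosine_annealing_lr(epoch):.6f}")
```

Under this closed form the rate falls from the base value to its minimum at epoch 10 and is back at the base value by epoch 20, so a 100-epoch run with `T_max=10` passes through several such cycles.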

Here are the parameter details.

* `batch_size`: 100 (on a V100, 100 epochs took 2 to 3 hours)
* melspectrogram parameters
  - `n_mels`: 128
  - `fmin`: 20
  - `fmax`: 16000
* image size: 224 x 541 (I don't remember the exact width)
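
The image width is the number of spectrogram frames, which depends on the clip length, the sample rate and the hop length. As a rough sketch, a centred STFT yields `1 + n_samples // hop_length` frames. The sample rate of 32000 and the hop lengths below are assumptions for illustration (32000 matches the `sr` default of the dataset classes later in this notebook), not the values used for the trained model.

```python
def n_frames(duration_s: float, sr: int, hop_length: int) -> int:
    """Frame count of a centred STFT: 1 + n_samples // hop_length."""
    n_samples = int(duration_s * sr)
    return 1 + n_samples // hop_length


# 5-second clips at an assumed sample rate of 32 kHz, for two assumed hop lengths
print(n_frames(5, 32000, 512))
print(n_frames(5, 32000, 320))
```

A smaller hop length gives a wider image for the same 5-second crop.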

### Future direction

There is a lot left to do to improve on this. It was a big challenge for me to make a successful submission with very little feedback (only messages like `Submission CSV Not Found` or `Notebook Exceeded Allowed Compute`), but this is just the beginning of the real challenge.
As described in https://www.kaggle.com/c/birdsong-recognition/discussion/160222#895234 , data augmentation is key. I worked on [Freesound Audio Tagging 2019](https://www.kaggle.com/c/freesound-audio-tagging-2019) last year, which was also an audio competition (comparatively rare on Kaggle), and at that time data augmentation such as pitch shift or reverb gave us a boost. This competition is about bird song rather than environmental sound, so we need to find out by experiment which augmentations work best on this data. We may even get a boost from applying different augmentations to different audio classes.
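
As one illustration of the kind of waveform augmentation mentioned above, here is a minimal sketch of a crude synthetic reverb: the dry signal is mixed with a copy convolved with exponentially decaying noise. The function name `add_synthetic_reverb` and all parameter values are assumptions; this is not the augmentation used in Freesound Audio Tagging 2019, and whether it helps on bird song would have to be checked by experiment.

```python
import numpy as np


def add_synthetic_reverb(y: np.ndarray, sr: int, decay_s: float = 0.3,
                         wet: float = 0.3, seed: int = 0) -> np.ndarray:
    """Mix the dry signal with a copy convolved with a decaying-noise impulse response."""
    rng = np.random.default_rng(seed)
    n_ir = max(int(decay_s * sr), 1)
    t = np.arange(n_ir) / sr
    # Exponentially decaying noise as a crude stand-in for a room impulse response
    ir = rng.standard_normal(n_ir) * np.exp(-6.0 * t / decay_s)
    ir /= np.sqrt(np.sum(ir ** 2)) + 1e-12     # normalise to unit energy
    wet_sig = np.convolve(y, ir)[: len(y)]      # keep the original length
    out = (1.0 - wet) * y + wet * wet_sig
    peak = np.max(np.abs(out))
    return out / peak if peak > 1.0 else out    # rescale only if it would clip


sr = 8000
t = np.arange(sr) / sr
dry = 0.5 * np.sin(2 * np.pi * 440.0 * t)       # one second of a 440 Hz tone
reverbed = add_synthetic_reverb(dry, sr)
print(dry.shape, reverbed.shape)
```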

I believe Mixup / BC learning, or otherwise mixing different audio classes, may give us a boost, since test clips contain multiple sounds whereas each training clip basically has one class (of course we can use the background sound information to treat the training set as a multilabel problem).
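
A minimal sketch of the mixup idea, assuming spectrograms and one-hot targets stored as NumPy arrays: two examples are blended with a weight drawn from a Beta distribution, and the targets are blended with the same weight. The value `alpha=0.4` and the array shapes are assumptions for illustration only.

```python
import numpy as np


def mixup(x1, y1, x2, y2, alpha=0.4, rng=None):
    """Blend two examples and their targets with a Beta(alpha, alpha) weight."""
    rng = np.random.default_rng() if rng is None else rng
    lam = float(rng.beta(alpha, alpha))
    x = lam * x1 + (1.0 - lam) * x2
    y = lam * y1 + (1.0 - lam) * y2
    return x, y, lam


rng = np.random.default_rng(0)
spec_a = rng.standard_normal((128, 313))   # stand-ins for two log-mel spectrograms
spec_b = rng.standard_normal((128, 313))
target_a = np.eye(264)[3]                  # one-hot targets for two different classes
target_b = np.eye(264)[7]
mixed_spec, mixed_target, lam = mixup(spec_a, target_a, spec_b, target_b, rng=rng)
print(mixed_spec.shape, lam, mixed_target[[3, 7]])
```

The mixed target is a soft multilabel vector, which is compatible with a binary cross-entropy loss.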

The training procedure also plays an important role: we can use a multilabel procedure (by using the background sounds) or a multiclass one. The challenge of this competition can also be treated as a *Domain Adaptation* problem, so we can use techniques from that field.

Model selection is also important. A deeper model may give us a boost, but in my experience, *too deep* models are sometimes beaten by shallower ones in audio classification.

## Libraries

In [None]:
import cv2
import audioread
import logging
import os
import random
import time
import warnings

import librosa
import numpy as np
import pandas as pd
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data

from contextlib import contextmanager
from pathlib import Path
from typing import Optional

from fastprogress import progress_bar
from sklearn.metrics import f1_score
from torchvision import models

## Utilities

In [None]:
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore
    torch.backends.cudnn.benchmark = False  # type: ignore  # autotuned kernels can differ between runs
    
    
def get_logger(out_file=None):
    logger = logging.getLogger()
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    logger.handlers = []
    logger.setLevel(logging.INFO)

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)

    if out_file is not None:
        fh = logging.FileHandler(out_file)
        fh.setFormatter(formatter)
        fh.setLevel(logging.INFO)
        logger.addHandler(fh)
    logger.info("logger set up")
    return logger
    
    
@contextmanager
def timer(name: str, logger: Optional[logging.Logger] = None):
    t0 = time.time()
    msg = f"[{name}] start"
    if logger is None:
        print(msg)
    else:
        logger.info(msg)
    yield

    msg = f"[{name}] done in {time.time() - t0:.2f} s"
    if logger is None:
        print(msg)
    else:
        logger.info(msg)

In [None]:
logger = get_logger("main.log")
set_seed(1213)

## Data Loading

In [None]:
TEST = Path("../input/birdsong-recognition/test_audio").exists()

In [None]:
if TEST:
    DATA_DIR = Path("../input/birdsong-recognition/")
else:
    # dataset created by @shonenkov, thanks!
    DATA_DIR = Path("../input/birdcall-check/")
    

test = pd.read_csv(DATA_DIR / "test.csv")
test_audio = DATA_DIR / "test_audio"


test.head()

In [None]:
sub = pd.read_csv("../input/birdsong-recognition/sample_submission.csv")
sub.to_csv("submission.csv", index=False)  # this will be overwritten if everything goes well

In [None]:
# load the training dataset 
TRAIN = Path('/kaggle/input/birdsong-recognition/train.csv')
train = pd.read_csv(TRAIN)
train.head()


In [None]:
train.info()

## Define Model

### Model

In [None]:
class ResNet(nn.Module):
    def __init__(self, pretrained=False, num_classes=264):
        super().__init__()
        base_model = models.resnet50(
            pretrained=pretrained)
        layers = list(base_model.children())[:-2]
        layers.append(nn.AdaptiveMaxPool2d(1))
        self.encoder = nn.Sequential(*layers)

        in_features = base_model.fc.in_features

        self.classifier = nn.Sequential(
            nn.Linear(in_features, 1024), nn.ReLU(), nn.Dropout(p=0.2),
            nn.Linear(1024, 1024), nn.ReLU(), nn.Dropout(p=0.2),
            nn.Linear(1024, num_classes))

    def forward(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)  # (batch_size, in_features)
        logits = self.classifier(x)
        # Independent per-class probabilities, so several birds can be flagged at once.
        # `torch.sigmoid` replaces the deprecated `F.sigmoid`.
        multilabel_proba = torch.sigmoid(logits)
        return multilabel_proba
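
The `forward` method above returns sigmoid outputs rather than softmax outputs. A small sketch with made-up logits shows why that matters for clips with more than one bird: softmax probabilities compete and sum to 1, while sigmoid scores each class independently. The logits and the threshold of 0.8 are illustrative assumptions.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())   # subtract the max for numerical stability
    return e / e.sum()


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


logits = np.array([4.0, 4.0, -3.0, -3.0])   # made-up logits: two birds calling at once
n_softmax = int((softmax(logits) >= 0.8).sum())
n_sigmoid = int((sigmoid(logits) >= 0.8).sum())
print(softmax(logits).round(3), n_softmax)
print(sigmoid(logits).round(3), n_sigmoid)
```

With softmax the two strong classes share the probability mass, so neither clears the threshold, whereas both clear it with sigmoid.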

## Parameters

In [None]:
weights_path = "../input/birdcall-resnet50-init-weights/best.pth"

In [None]:
BIRD_CODE = {
    'aldfly': 0, 'ameavo': 1, 'amebit': 2, 'amecro': 3, 'amegfi': 4,
    'amekes': 5, 'amepip': 6, 'amered': 7, 'amerob': 8, 'amewig': 9,
    'amewoo': 10, 'amtspa': 11, 'annhum': 12, 'astfly': 13, 'baisan': 14,
    'baleag': 15, 'balori': 16, 'banswa': 17, 'barswa': 18, 'bawwar': 19,
    'belkin1': 20, 'belspa2': 21, 'bewwre': 22, 'bkbcuc': 23, 'bkbmag1': 24,
    'bkbwar': 25, 'bkcchi': 26, 'bkchum': 27, 'bkhgro': 28, 'bkpwar': 29,
    'bktspa': 30, 'blkpho': 31, 'blugrb1': 32, 'blujay': 33, 'bnhcow': 34,
    'boboli': 35, 'bongul': 36, 'brdowl': 37, 'brebla': 38, 'brespa': 39,
    'brncre': 40, 'brnthr': 41, 'brthum': 42, 'brwhaw': 43, 'btbwar': 44,
    'btnwar': 45, 'btywar': 46, 'buffle': 47, 'buggna': 48, 'buhvir': 49,
    'bulori': 50, 'bushti': 51, 'buwtea': 52, 'buwwar': 53, 'cacwre': 54,
    'calgul': 55, 'calqua': 56, 'camwar': 57, 'cangoo': 58, 'canwar': 59,
    'canwre': 60, 'carwre': 61, 'casfin': 62, 'caster1': 63, 'casvir': 64,
    'cedwax': 65, 'chispa': 66, 'chiswi': 67, 'chswar': 68, 'chukar': 69,
    'clanut': 70, 'cliswa': 71, 'comgol': 72, 'comgra': 73, 'comloo': 74,
    'commer': 75, 'comnig': 76, 'comrav': 77, 'comred': 78, 'comter': 79,
    'comyel': 80, 'coohaw': 81, 'coshum': 82, 'cowscj1': 83, 'daejun': 84,
    'doccor': 85, 'dowwoo': 86, 'dusfly': 87, 'eargre': 88, 'easblu': 89,
    'easkin': 90, 'easmea': 91, 'easpho': 92, 'eastow': 93, 'eawpew': 94,
    'eucdov': 95, 'eursta': 96, 'evegro': 97, 'fiespa': 98, 'fiscro': 99,
    'foxspa': 100, 'gadwal': 101, 'gcrfin': 102, 'gnttow': 103, 'gnwtea': 104,
    'gockin': 105, 'gocspa': 106, 'goleag': 107, 'grbher3': 108, 'grcfly': 109,
    'greegr': 110, 'greroa': 111, 'greyel': 112, 'grhowl': 113, 'grnher': 114,
    'grtgra': 115, 'grycat': 116, 'gryfly': 117, 'haiwoo': 118, 'hamfly': 119,
    'hergul': 120, 'herthr': 121, 'hoomer': 122, 'hoowar': 123, 'horgre': 124,
    'horlar': 125, 'houfin': 126, 'houspa': 127, 'houwre': 128, 'indbun': 129,
    'juntit1': 130, 'killde': 131, 'labwoo': 132, 'larspa': 133, 'lazbun': 134,
    'leabit': 135, 'leafly': 136, 'leasan': 137, 'lecthr': 138, 'lesgol': 139,
    'lesnig': 140, 'lesyel': 141, 'lewwoo': 142, 'linspa': 143, 'lobcur': 144,
    'lobdow': 145, 'logshr': 146, 'lotduc': 147, 'louwat': 148, 'macwar': 149,
    'magwar': 150, 'mallar3': 151, 'marwre': 152, 'merlin': 153, 'moublu': 154,
    'mouchi': 155, 'moudov': 156, 'norcar': 157, 'norfli': 158, 'norhar2': 159,
    'normoc': 160, 'norpar': 161, 'norpin': 162, 'norsho': 163, 'norwat': 164,
    'nrwswa': 165, 'nutwoo': 166, 'olsfly': 167, 'orcwar': 168, 'osprey': 169,
    'ovenbi1': 170, 'palwar': 171, 'pasfly': 172, 'pecsan': 173, 'perfal': 174,
    'phaino': 175, 'pibgre': 176, 'pilwoo': 177, 'pingro': 178, 'pinjay': 179,
    'pinsis': 180, 'pinwar': 181, 'plsvir': 182, 'prawar': 183, 'purfin': 184,
    'pygnut': 185, 'rebmer': 186, 'rebnut': 187, 'rebsap': 188, 'rebwoo': 189,
    'redcro': 190, 'redhea': 191, 'reevir1': 192, 'renpha': 193, 'reshaw': 194,
    'rethaw': 195, 'rewbla': 196, 'ribgul': 197, 'rinduc': 198, 'robgro': 199,
    'rocpig': 200, 'rocwre': 201, 'rthhum': 202, 'ruckin': 203, 'rudduc': 204,
    'rufgro': 205, 'rufhum': 206, 'rusbla': 207, 'sagspa1': 208, 'sagthr': 209,
    'savspa': 210, 'saypho': 211, 'scatan': 212, 'scoori': 213, 'semplo': 214,
    'semsan': 215, 'sheowl': 216, 'shshaw': 217, 'snobun': 218, 'snogoo': 219,
    'solsan': 220, 'sonspa': 221, 'sora': 222, 'sposan': 223, 'spotow': 224,
    'stejay': 225, 'swahaw': 226, 'swaspa': 227, 'swathr': 228, 'treswa': 229,
    'truswa': 230, 'tuftit': 231, 'tunswa': 232, 'veery': 233, 'vesspa': 234,
    'vigswa': 235, 'warvir': 236, 'wesblu': 237, 'wesgre': 238, 'weskin': 239,
    'wesmea': 240, 'wessan': 241, 'westan': 242, 'wewpew': 243, 'whbnut': 244,
    'whcspa': 245, 'whfibi': 246, 'whtspa': 247, 'whtswi': 248, 'wilfly': 249,
    'wilsni1': 250, 'wiltur': 251, 'winwre3': 252, 'wlswar': 253, 'wooduc': 254,
    'wooscj2': 255, 'woothr': 256, 'y00475': 257, 'yebfly': 258, 'yebsap': 259,
    'yehbla': 260, 'yelwar': 261, 'yerwar': 262, 'yetvir': 263
}

INV_BIRD_CODE = {v: k for k, v in BIRD_CODE.items()}

In [None]:
# reload the training metadata to inspect the sample rate of the training audio
train_csv = pd.read_csv('/kaggle/input/birdsong-recognition/train.csv')
train_csv.head()  # the columns used below are `ebird_code` and `filename`

In [None]:
train_csv['full_path'] = '/kaggle/input/birdsong-recognition/train_audio/' + train_csv['ebird_code'] + '/' + train_csv['filename']
# pick one 'astfly' recording reproducibly
astfly = train_csv[train_csv['ebird_code'] == 'astfly'].sample(1, random_state=42)['full_path'].values[0]

In [None]:
train_csv.shape

In [None]:
import torchaudio
import torchaudio.transforms as transforms
# get_sample_rate = lambda x: torchaudio.load(x)[-1]
# train_csv['sample_rate'] = train_csv['full_path'].apply(lambda x: torchaudio.load(x)[1] if torchaudio.load(x) else None)
# train_csv.head()
sample_example = []

for index in random.sample(range(0, train_csv.shape[0]), 100):
    _, sample_rate = torchaudio.load(train_csv['full_path'][index])
    sample_example.append(sample_rate)
print(np.average(sample_example)) # 48000

In [None]:
class AudioUtil():
#     def __init__(self)

    def open(self, file_path):
        return torchaudio.load(file_path)

    def resample(self, aud, newsr):
        """Resample to `newsr` and expand to 3 channels, as expected by the ResNet encoder."""
        sig, sr = aud
        if sr != newsr:
            sig = torchaudio.transforms.Resample(sr, newsr)(sig)
        num_channels = sig.shape[0]
        if num_channels == 1:
            # mono: repeat the signal on all 3 channels
            sig = sig.repeat(3, 1)
        elif num_channels == 2:
            # stereo: left, right and their average
            sig = torch.cat([sig, sig.mean(dim=0, keepdim=True)])
        return sig, newsr
    
    def pad_trunc(self, aud, max_s):  # max_s: clip length in seconds, e.g. 5
        sig, sr = aud
        num_rows, sig_len = sig.shape
        max_len = int(sr * max_s)

        if sig_len > max_len:
            # Truncate the signal to the given length
            sig = sig[:, :max_len]
        elif sig_len < max_len:
            # Zero-pad the end of the signal up to the given length
            pad_end = torch.zeros((num_rows, max_len - sig_len))
            sig = torch.cat((sig, pad_end), 1)

        return (sig, sr)
    
    

    def time_shift(self, aud, shift_limit):
        sig, sr = aud
        _, sig_len = sig.shape
        shift_amt = int(random.random() * shift_limit * sig_len)
        return (sig.roll(shift_amt), sr)
    

    def spectro_gram(self, aud, n_mels=64, n_fft=1024, hop_len=None):
        sig, sr = aud
        top_db = 80

        # spec has shape [channel, n_mels, time], where channel is mono, stereo etc
        spec = transforms.MelSpectrogram(sr, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels)(sig)

        # Convert to decibels
        spec = transforms.AmplitudeToDB(top_db=top_db)(spec)
        return (spec)

    def spectro_augment(self, spec, max_mask_pct=0.1, n_freq_masks=1, n_time_masks=1):
        _, n_mels, n_steps = spec.shape
        mask_value = spec.mean().item()  # fill masked regions with the mean, as a plain float
        aug_spec = spec

        freq_mask_param = int(max_mask_pct * n_mels)  # maximum mask width in mel bins
        for _ in range(n_freq_masks):
            aug_spec = transforms.FrequencyMasking(freq_mask_param)(aug_spec, mask_value)

        time_mask_param = int(max_mask_pct * n_steps)  # maximum mask width in frames
        for _ in range(n_time_masks):
            aug_spec = transforms.TimeMasking(time_mask_param)(aug_spec, mask_value)

        return aug_spec
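
`pad_trunc` above always keeps the first seconds of a clip, whereas the approach described at the top of this notebook crops a random 5-second window each epoch. A minimal NumPy sketch of such a crop on a `(channels, samples)` array follows; the function name `random_crop` is hypothetical and the sample rate of 32000 is an assumption.

```python
import numpy as np


def random_crop(sig: np.ndarray, sr: int, duration: float = 5.0, rng=None) -> np.ndarray:
    """Return a random `duration`-second window along the last axis, zero-padding short clips."""
    rng = np.random.default_rng() if rng is None else rng
    target = int(sr * duration)
    n = sig.shape[-1]
    if n <= target:
        # pad only the sample axis, at the end
        pad = [(0, 0)] * (sig.ndim - 1) + [(0, target - n)]
        return np.pad(sig, pad)
    start = int(rng.integers(0, n - target + 1))   # upper bound is exclusive
    return sig[..., start:start + target]


rng = np.random.default_rng(0)
long_clip = rng.standard_normal((3, 32000 * 12))   # 12-second stand-in signal
short_clip = rng.standard_normal((3, 32000 * 2))   # 2-second stand-in signal
print(random_crop(long_clip, 32000, rng=rng).shape)
print(random_crop(short_clip, 32000, rng=rng).shape)
```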
            

In [None]:
class TrainDataset(data.Dataset):  # each item: spectrogram of shape (3, img_size, n_frames) and its class index
    def __init__(self, df, img_size=224, sr=32000, duration=5,
                 melspectrogram_parameters=None):
        self.df = df
        self.img_size = img_size
        # optional dict of mel settings; `__getitem__` passes its settings to `spectro_gram` directly
        self.melspectrogram_parameters = melspectrogram_parameters
        self.sr = sr
        self.duration = duration

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        sample = self.df.iloc[idx]  # positional lookup, independent of the DataFrame index
        file_path = '/kaggle/input/birdsong-recognition/train_audio/' + sample['ebird_code'] + '/' + sample['filename']

        util = AudioUtil()
        aud = util.open(file_path)
        reaud = util.resample(aud, self.sr)  # resample and expand to 3 channels
        dur_aud = util.pad_trunc(reaud, self.duration)
        shift_aud = util.time_shift(dur_aud, shift_limit=0.5)
        # shape (3, 224, 313) with the defaults: the hop length falls back to n_fft // 2 = 512,
        # so 5 s at 32000 Hz gives 1 + 160000 // 512 = 313 frames
        sgram = util.spectro_gram(shift_aud, n_mels=self.img_size, n_fft=1024, hop_len=None)
        aug_sgram = util.spectro_augment(sgram, max_mask_pct=0.1, n_freq_masks=2, n_time_masks=2)

        # the target is the class index of the bird, not the row index
        return aug_sgram, BIRD_CODE[sample['ebird_code']]



In [None]:
class TestDataset(data.Dataset):  # each image has shape (3, img_size, n_frames)
    def __init__(self, df, img_size=224, sr=32000, duration=5,
                 melspectrogram_parameters=None):
        self.df = df
        self.img_size = img_size
        # optional dict of mel settings; `__getitem__` passes its settings to `spectro_gram` directly
        self.melspectrogram_parameters = melspectrogram_parameters
        self.sr = sr
        self.duration = duration

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        sample = self.df.iloc[idx]
        site = sample.site
        row_id = sample.row_id
        # assumption: clips are stored as "<audio_id>.mp3" under the `test_audio` directory
        file_path = str(test_audio / (sample.audio_id + ".mp3"))

        util = AudioUtil()
        sig, sr = util.resample(util.open(file_path), self.sr)  # resample and expand to 3 channels
        chunk_len = self.duration * sr  # samples per 5-second window

        if site == 'site_3':
            # site_3 is scored once per clip: split the whole clip into 5-second chunks
            images = []
            for start in range(0, sig.shape[1], chunk_len):
                # the last chunk may be short, so zero-pad it to the full window
                chunk = util.pad_trunc((sig[:, start:start + chunk_len], sr), self.duration)
                images.append(util.spectro_gram(chunk, n_mels=self.img_size, n_fft=1024, hop_len=None))
            return torch.stack(images), row_id, site  # (n_chunks, 3, 224, 313)

        else:
            # other sites: `seconds` marks the end of the 5-second window to score
            end_index = sr * int(sample.seconds)
            start_index = max(end_index - chunk_len, 0)
            window = util.pad_trunc((sig[:, start_index:end_index], sr), self.duration)
            sgram = util.spectro_gram(window, n_mels=self.img_size, n_fft=1024, hop_len=None)
            return sgram, row_id, site  # (3, 224, 313)


In [None]:
from torch.utils.data import random_split
my_ds = TrainDataset(df = train_csv)
# mydataset.__getitem__(10000)[0].shape

num_items = len(my_ds)
num_train = round(num_items * 0.8)
num_valid = num_items - num_train
train_ds, valid_ds = random_split(my_ds, [num_train, num_valid])


train_dl = torch.utils.data.DataLoader(train_ds, batch_size = 16, shuffle = True)
valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size = 16, shuffle = False)


In [None]:
def valid(model, val_dl):
    correct_prediction = 0
    total_prediction = 0
    device = next(model.parameters()).device  # run on whichever device the model lives on
    model.eval()  # disable dropout and use the running batch-norm statistics

    # Disable gradient tracking
    with torch.no_grad():
        for batch in val_dl:
            # Get the input features and target labels, and put them on the model's device
            inputs, labels = batch[0].to(device), batch[1].to(device)

            # Normalize the inputs
            inputs_m, inputs_s = inputs.mean(), inputs.std()
            inputs = (inputs - inputs_m) / inputs_s

            # Get predictions
            outputs = model(inputs)

            # Get the predicted class with the highest score
            _, prediction = torch.max(outputs, 1)
            # Count of predictions that matched the target label
            correct_prediction += (prediction == labels).sum().item()
            total_prediction += prediction.shape[0]

    acc = correct_prediction / total_prediction
    print(f'Validation Accuracy: {acc:.2f}, Total items: {total_prediction}')
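
The validation loop above reports top-1 accuracy, while the approach section selects the checkpoint by F1. As a rough sketch, a micro-averaged F1 over thresholded multilabel predictions can be computed as below. The helper name `micro_f1`, the threshold and the toy arrays are assumptions for illustration, not the scoring code used for the trained model.

```python
import numpy as np


def micro_f1(y_true: np.ndarray, proba: np.ndarray, threshold: float = 0.5) -> float:
    """Micro-averaged F1: pool true/false positives and false negatives over all cells."""
    y_pred = proba >= threshold
    y_true = y_true.astype(bool)
    tp = np.sum(y_pred & y_true)
    fp = np.sum(y_pred & ~y_true)
    fn = np.sum(~y_pred & y_true)
    denom = 2 * tp + fp + fn
    return float(2 * tp / denom) if denom > 0 else 0.0


# Two clips, three classes: the second clip contains two birds, one of which is missed
y_true = np.array([[1, 0, 0], [0, 1, 1]])
proba = np.array([[0.9, 0.2, 0.1], [0.1, 0.7, 0.3]])
score = micro_f1(y_true, proba)
print(score)
```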



In [None]:
## start train model 

def training(model, train_dl, valid_dl, num_epochs):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Loss Function, Optimizer and Scheduler
    # The model outputs sigmoid probabilities, so BCELoss on one-hot targets is used
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001,
                                                    steps_per_epoch=len(train_dl),
                                                    epochs=num_epochs,
                                                    anneal_strategy='linear')

    # Repeat for each epoch
    for epoch in range(num_epochs):
        model.train()  # validation may have left the model in eval mode
        running_loss = 0.0
        correct_prediction = 0
        total_prediction = 0

        # Repeat for each batch in the training set
        for batch in train_dl:
            # Get the input features and target labels, and put them on the device
            inputs, labels = batch[0].to(device), batch[1].to(device)

            # Normalize the inputs
            inputs_m, inputs_s = inputs.mean(), inputs.std()
            inputs = (inputs - inputs_m) / inputs_s

            # Zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            targets = F.one_hot(labels, num_classes=outputs.size(1)).float()
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Keep stats for Loss and Accuracy
            running_loss += loss.item()

            # Get the predicted class with the highest score
            _, prediction = torch.max(outputs, 1)
            # Count of predictions that matched the target label
            correct_prediction += (prediction == labels).sum().item()
            total_prediction += prediction.shape[0]

        # Print stats at the end of the epoch
        num_batches = len(train_dl)
        avg_loss = running_loss / num_batches
        acc = correct_prediction / total_prediction
        print(f'Epoch: {epoch}, Loss: {avg_loss:.2f}, Accuracy: {acc:.2f}')
        # Run inference on the validation set every 10 epochs
        if epoch % 10 == 0:
            valid(model, valid_dl)

    print('Finished Training')

num_epochs = 20   # Just for demo, adjust this higher.
myModel = ResNet(pretrained=True)
training(myModel, train_dl, valid_dl, num_epochs)

In [None]:
def prediction(test_df: pd.DataFrame,
               test_audio: Path,
               model_config: dict,
               mel_params: dict,
               weights_path: str,
               threshold=0.5):
    model = get_model(model_config, weights_path)
    unique_audio_id = test_df.audio_id.unique()

    warnings.filterwarnings("ignore")
    prediction_dfs = []
    for audio_id in unique_audio_id:        
        test_df_for_audio_id = test_df.query(
            f"audio_id == '{audio_id}'").reset_index(drop=True)
        with timer(f"Prediction on {audio_id}", logger):
            prediction_dict = prediction_for_clip(test_df_for_audio_id,
                                                  model=model,
                                                  mel_params=mel_params,
                                                  threshold=threshold)
        row_id = list(prediction_dict.keys())
        birds = list(prediction_dict.values())
        prediction_df = pd.DataFrame({
            "row_id": row_id,
            "birds": birds
        })
        prediction_dfs.append(prediction_df)
    
    prediction_df = pd.concat(prediction_dfs, axis=0, sort=False).reset_index(drop=True)
    return prediction_df

In [None]:

        
def prediction_for_clip(test_df: pd.DataFrame, 
                        model: ResNet, 
                        mel_params: dict, 
                        threshold=0.5):

    dataset = TestDataset(df=test_df, 
                          img_size=224,
                          melspectrogram_parameters=mel_params)
    loader = data.DataLoader(dataset, batch_size=1, shuffle=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model.eval()
    prediction_dict = {}
    for image, row_id, site in progress_bar(loader):
                    # Normalize the inputs
        inputs_m, inputs_s = image.mean(), image.std()
        image = (image - inputs_m) / inputs_s
        site = site[0]
        row_id = row_id[0]
        if site in {"site_1", "site_2"}:
            image = image.to(device)

            with torch.no_grad():
                prediction = model(image)  # tensor of per-class sigmoid probabilities
                proba = prediction.detach().cpu().numpy().reshape(-1)

            events = proba >= threshold
            labels = np.argwhere(events).reshape(-1).tolist()

        else:
            # to avoid prediction on large batch
            image = image.squeeze(0)
            batch_size = 16
            whole_size = image.size(0)
            if whole_size % batch_size == 0:
                n_iter = whole_size // batch_size
            else:
                n_iter = whole_size // batch_size + 1
                
            all_events = set()
            for batch_i in range(n_iter):
                batch = image[batch_i * batch_size:(batch_i + 1) * batch_size]
                if batch.ndim == 3:
                    batch = batch.unsqueeze(0)

                batch = batch.to(device)
                with torch.no_grad():
                    prediction = model(batch)  # tensor of per-class sigmoid probabilities
                    proba = prediction.detach().cpu().numpy()
                    
                events = proba >= threshold
                for i in range(len(events)):
                    event = events[i, :]
                    labels = np.argwhere(event).reshape(-1).tolist()
                    for label in labels:
                        all_events.add(label)
                        
            labels = list(all_events)
        if len(labels) == 0:
            prediction_dict[row_id] = "nocall"
        else:
            labels_str_list = list(map(lambda x: INV_BIRD_CODE[x], labels))
            label_string = " ".join(labels_str_list)
            prediction_dict[row_id] = label_string
    return prediction_dict
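
The last step of `prediction_for_clip`, turning a probability vector into the submission string, can be looked at in isolation. Here is a small self-contained sketch; the helper name `decode_prediction` is hypothetical, and the three-entry mapping is a shortened stand-in for `INV_BIRD_CODE`.

```python
import numpy as np


def decode_prediction(proba, inv_code, threshold=0.5):
    """Join the codes of all classes at or above the threshold, or return 'nocall'."""
    labels = np.flatnonzero(np.asarray(proba) >= threshold)
    if labels.size == 0:
        return "nocall"
    return " ".join(inv_code[int(i)] for i in labels)


inv_code = {0: "aldfly", 1: "ameavo", 2: "amebit"}   # shortened stand-in mapping
with_birds = decode_prediction([0.9, 0.1, 0.85], inv_code, threshold=0.8)
without_birds = decode_prediction([0.2, 0.1, 0.3], inv_code, threshold=0.8)
print(with_birds)
print(without_birds)
```

A higher threshold trades missed birds for fewer false alarms, so more rows fall back to `nocall`.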

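## Prediction loop

In [None]:
def get_model(config: dict, weights_path: str):
    model = ResNet(**config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()
    return model

In [None]:
# settings taken from the model definition and the parameter details at the top of the notebook
model_config = {"pretrained": False, "num_classes": 264}
melspectrogram_parameters = {"n_mels": 128, "fmin": 20, "fmax": 16000}

submission = prediction(test_df=test,
                        test_audio=test_audio,
                        model_config=model_config,
                        mel_params=melspectrogram_parameters,
                        weights_path=weights_path,
                        threshold=0.8)
submission.to_csv("submission.csv", index=False)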