In [None]:
import sys
sys.path.insert(0, "../input/resnest/")

import os
import gc
import time
import math
import shutil
import random
import warnings
import typing as tp
from pathlib import Path
from contextlib import contextmanager

import yaml
from joblib import delayed, Parallel

import cv2
import librosa
import audioread
import soundfile as sf

import numpy as np
import pandas as pd

from fastprogress import progress_bar
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import resnest.torch as resnest_torch

# Let DataFrame reprs show up to 500 rows / columns.
pd.set_option("display.max_rows", 500)
pd.set_option("display.max_columns", 500)

import gc  # already imported above; kept so this cell also runs on its own
gc.enable()

# Enumerate GPUs in PCI-bus order (see issue #152), then expose only GPU 0.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})
device = torch.device("cuda", 0)

rand_seed = 1120
debug_mode = False

In [None]:
# ---- image / label space --------------------------------------------------
image_size = 224   # output height of every spectrogram image
n_classes = 264    # bird classes predicted by the main models (no "nocall")
use_fold = 0

# ---- checkpoints ------------------------------------------------------------
# Five-fold "ext-v1" ensemble; fold k lives in bird-custom-ext-v1-fold{k}.
model_files = [
    ("resnest50_fast_2s2x40d", f"../input/bird-custom-ext-v1-fold{fold}/{ckpt}")
    for fold, ckpt in enumerate([
        "F1_snapshot_epoch_26.pth",
        "F1_snapshot_epoch_30.pth",
        "F1_snapshot_epoch_29.pth",
        "F1_snapshot_epoch_30.pth",
        "snapshot_epoch_25.pth",
    ])
]

public_model_file = (
    "resnest50_fast_1s1x64d",
    "../input/training-birdsong-baseline-resnest50-fast/best_model.pth",
)

resnest_200_model_files = [
    ("resnest200", "../input/resnest200-v2-fold0/resnest200_v2_epoch_44.pth"),
    ("resnest200", "../input/resnest200-fold0/resnest200_epoch_50.pth"),
]

# Two snapshots (epochs 4 and 10) of the nocall-aware model, fold 1.
nocall_model_files = [
    ("resnest50_fast_2s2x40d",
     f"../input/bird-custom-nocall-v2-fold1/F1_snapshot_epoch_{epoch}.pth")
    for epoch in (4, 10)
]

# ---- inference ----------------------------------------------------------------
infer_batch_size = 1024
prob_threshold = 0.3   # other values tried: 0.5, 0.6
num_workers = 4

# ---- audio front-end --------------------------------------------------------
TARGET_SR = 32000
melspectrogram_parameters = dict(n_mels=128, fmin=20, fmax=16000)

### Utilities

In [None]:
def set_seed(seed: int = 42):
    """Seed every RNG the notebook touches and force deterministic cuDNN.

    Covers Python's ``random``, NumPy's global RNG, torch (CPU and the current
    CUDA device) and ``PYTHONHASHSEED``. That last one only influences Python
    processes started after this call.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False  # autotuning picks kernels non-deterministically


@contextmanager
def timer(name: str) -> tp.Iterator[None]:
    """Context manager that prints a start line and the elapsed wall-clock time.

    The "done" line is printed only when the ``with`` body finishes without
    raising. The elapsed time is rounded to whole seconds.
    """
    started_at = time.time()
    print(f"[{name}] start")
    yield
    elapsed = time.time() - started_at
    print(f"[{name}] done in {elapsed:.0f} s")


def print_model_info(model):
    """Print the total, trainable and frozen parameter counts of ``model``.

    A parameter counts as trainable when ``requires_grad`` is set. Numbers are
    printed with thousands separators.
    """
    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
    trainable = sum(size for size, needs_grad in sizes if needs_grad)
    non_trainable = sum(size for size, needs_grad in sizes if not needs_grad)
    total = trainable + non_trainable

    for label, count in (("model", total),
                         ("trainable", trainable),
                         ("non-trainable", non_trainable)):
        print(f"Total number of {label} parameters: {count:,}")


set_seed(seed=rand_seed)  # make the whole run reproducible before any data / model work

### Read data

In [None]:
# All inputs live under <parent of the working directory>/input.
ROOT = Path.cwd().parent
INPUT_ROOT = ROOT.joinpath("input")
RAW_DATA = INPUT_ROOT.joinpath("birdsong-recognition")
TRAIN_AUDIO_DIR, TEST_AUDIO_DIR = (RAW_DATA.joinpath(sub) for sub in ("train_audio", "test_audio"))

# ROOT = Path("/workspace/Kaggle/Bird")
# TRAIN_AUDIO_DIR = Path(
#     "/workspace/Kaggle/Bird/birdsong-recognition/train_audio")
# TRAIN_RESAMPLED_AUDIO_DIR = ROOT / "birdsong-resampled-train-audio"
# TEST_AUDIO_DIR = Path("/workspace/Kaggle/Bird/birdsong-recognition/test_audio")

In [None]:
train = pd.read_csv(RAW_DATA / "train.csv")

# Use the competition test set when its audio folder exists; otherwise
# switch both the audio folder and the CSV to the birdcall-check copy.
if TEST_AUDIO_DIR.exists():
    test = pd.read_csv(RAW_DATA / "test.csv")
else:
    TEST_AUDIO_DIR = INPUT_ROOT.joinpath("birdcall-check", "test_audio")
    test = pd.read_csv(INPUT_ROOT.joinpath("birdcall-check", "test.csv"))

# train = pd.read_csv(TRAIN_RESAMPLED_AUDIO_DIR / "train_mod.csv")

# if not TEST_AUDIO_DIR.exists():
#     TEST_AUDIO_DIR = ROOT / "birdcall-check" / "test_audio"
#     test = pd.read_csv(ROOT / "birdcall-check" / "test.csv")
# else:
#     test = pd.read_csv(ROOT / "birdsong-recognition" / "test.csv")

In [None]:
train.head(n=5)  # peek at the training metadata

In [None]:
test.head(n=5)  # peek at the rows we must predict

In [None]:
# Sample submission from the competition data. Its shape and head are only
# displayed for reference; the real submission is written by `prediction`.
# Reuse RAW_DATA rather than hard-coding the dataset root a second time.
sub = pd.read_csv(RAW_DATA / "sample_submission.csv")
print(sub.shape)
print(sub.head())
# sub.to_csv("submission.csv", index=False)  # this will be overwritten if everything goes well

## Definition

### Dataset

For `site_3`, I use the same procedure as for `site_1` and `site_2`: crop a 5-second window out of the clip and predict on that short clip.
The only difference is that for `site_3` I cut the whole clip, from start to end, into consecutive 5-second chunks, predict on every chunk, and then aggregate the per-chunk predictions by taking the union of all predicted labels. A trailing chunk shorter than 5 seconds is dropped.

In [None]:
# Species code -> class index. Indices 0..263 are the n_classes bird classes
# predicted by every model group. Index 264 ("nocall") is an extra class that
# only the nocall models output (they are built with n_classes + 1 outputs).
# The string "nocall" is also what gets written for a frame where no class
# passes the threshold.
BIRD_CODE = {
    'aldfly': 0, 'ameavo': 1, 'amebit': 2, 'amecro': 3, 'amegfi': 4,
    'amekes': 5, 'amepip': 6, 'amered': 7, 'amerob': 8, 'amewig': 9,
    'amewoo': 10, 'amtspa': 11, 'annhum': 12, 'astfly': 13, 'baisan': 14,
    'baleag': 15, 'balori': 16, 'banswa': 17, 'barswa': 18, 'bawwar': 19,
    'belkin1': 20, 'belspa2': 21, 'bewwre': 22, 'bkbcuc': 23, 'bkbmag1': 24,
    'bkbwar': 25, 'bkcchi': 26, 'bkchum': 27, 'bkhgro': 28, 'bkpwar': 29,
    'bktspa': 30, 'blkpho': 31, 'blugrb1': 32, 'blujay': 33, 'bnhcow': 34,
    'boboli': 35, 'bongul': 36, 'brdowl': 37, 'brebla': 38, 'brespa': 39,
    'brncre': 40, 'brnthr': 41, 'brthum': 42, 'brwhaw': 43, 'btbwar': 44,
    'btnwar': 45, 'btywar': 46, 'buffle': 47, 'buggna': 48, 'buhvir': 49,
    'bulori': 50, 'bushti': 51, 'buwtea': 52, 'buwwar': 53, 'cacwre': 54,
    'calgul': 55, 'calqua': 56, 'camwar': 57, 'cangoo': 58, 'canwar': 59,
    'canwre': 60, 'carwre': 61, 'casfin': 62, 'caster1': 63, 'casvir': 64,
    'cedwax': 65, 'chispa': 66, 'chiswi': 67, 'chswar': 68, 'chukar': 69,
    'clanut': 70, 'cliswa': 71, 'comgol': 72, 'comgra': 73, 'comloo': 74,
    'commer': 75, 'comnig': 76, 'comrav': 77, 'comred': 78, 'comter': 79,
    'comyel': 80, 'coohaw': 81, 'coshum': 82, 'cowscj1': 83, 'daejun': 84,
    'doccor': 85, 'dowwoo': 86, 'dusfly': 87, 'eargre': 88, 'easblu': 89,
    'easkin': 90, 'easmea': 91, 'easpho': 92, 'eastow': 93, 'eawpew': 94,
    'eucdov': 95, 'eursta': 96, 'evegro': 97, 'fiespa': 98, 'fiscro': 99,
    'foxspa': 100, 'gadwal': 101, 'gcrfin': 102, 'gnttow': 103, 'gnwtea': 104,
    'gockin': 105, 'gocspa': 106, 'goleag': 107, 'grbher3': 108, 'grcfly': 109,
    'greegr': 110, 'greroa': 111, 'greyel': 112, 'grhowl': 113, 'grnher': 114,
    'grtgra': 115, 'grycat': 116, 'gryfly': 117, 'haiwoo': 118, 'hamfly': 119,
    'hergul': 120, 'herthr': 121, 'hoomer': 122, 'hoowar': 123, 'horgre': 124,
    'horlar': 125, 'houfin': 126, 'houspa': 127, 'houwre': 128, 'indbun': 129,
    'juntit1': 130, 'killde': 131, 'labwoo': 132, 'larspa': 133, 'lazbun': 134,
    'leabit': 135, 'leafly': 136, 'leasan': 137, 'lecthr': 138, 'lesgol': 139,
    'lesnig': 140, 'lesyel': 141, 'lewwoo': 142, 'linspa': 143, 'lobcur': 144,
    'lobdow': 145, 'logshr': 146, 'lotduc': 147, 'louwat': 148, 'macwar': 149,
    'magwar': 150, 'mallar3': 151, 'marwre': 152, 'merlin': 153, 'moublu': 154,
    'mouchi': 155, 'moudov': 156, 'norcar': 157, 'norfli': 158, 'norhar2': 159,
    'normoc': 160, 'norpar': 161, 'norpin': 162, 'norsho': 163, 'norwat': 164,
    'nrwswa': 165, 'nutwoo': 166, 'olsfly': 167, 'orcwar': 168, 'osprey': 169,
    'ovenbi1': 170, 'palwar': 171, 'pasfly': 172, 'pecsan': 173, 'perfal': 174,
    'phaino': 175, 'pibgre': 176, 'pilwoo': 177, 'pingro': 178, 'pinjay': 179,
    'pinsis': 180, 'pinwar': 181, 'plsvir': 182, 'prawar': 183, 'purfin': 184,
    'pygnut': 185, 'rebmer': 186, 'rebnut': 187, 'rebsap': 188, 'rebwoo': 189,
    'redcro': 190, 'redhea': 191, 'reevir1': 192, 'renpha': 193, 'reshaw': 194,
    'rethaw': 195, 'rewbla': 196, 'ribgul': 197, 'rinduc': 198, 'robgro': 199,
    'rocpig': 200, 'rocwre': 201, 'rthhum': 202, 'ruckin': 203, 'rudduc': 204,
    'rufgro': 205, 'rufhum': 206, 'rusbla': 207, 'sagspa1': 208, 'sagthr': 209,
    'savspa': 210, 'saypho': 211, 'scatan': 212, 'scoori': 213, 'semplo': 214,
    'semsan': 215, 'sheowl': 216, 'shshaw': 217, 'snobun': 218, 'snogoo': 219,
    'solsan': 220, 'sonspa': 221, 'sora': 222, 'sposan': 223, 'spotow': 224,
    'stejay': 225, 'swahaw': 226, 'swaspa': 227, 'swathr': 228, 'treswa': 229,
    'truswa': 230, 'tuftit': 231, 'tunswa': 232, 'veery': 233, 'vesspa': 234,
    'vigswa': 235, 'warvir': 236, 'wesblu': 237, 'wesgre': 238, 'weskin': 239,
    'wesmea': 240, 'wessan': 241, 'westan': 242, 'wewpew': 243, 'whbnut': 244,
    'whcspa': 245, 'whfibi': 246, 'whtspa': 247, 'whtswi': 248, 'wilfly': 249,
    'wilsni1': 250, 'wiltur': 251, 'winwre3': 252, 'wlswar': 253, 'wooduc': 254,
    'wooscj2': 255, 'woothr': 256, 'y00475': 257, 'yebfly': 258, 'yebsap': 259,
    'yehbla': 260, 'yelwar': 261, 'yerwar': 262, 'yetvir': 263, 'nocall': 264
}

# Reverse lookup (class index -> species code), used to turn predicted
# indices back into the label strings written to the submission.
INV_BIRD_CODE = {v: k for k, v in BIRD_CODE.items()}

In [None]:
def mono_to_color(X: np.ndarray,
                  mean=None,
                  std=None,
                  norm_max=None,
                  norm_min=None,
                  eps=1e-6):
    """Turn a single-channel array (e.g. a dB mel spectrogram) into a uint8 "RGB" image.

    Code from https://www.kaggle.com/daisukelab/creating-fat2019-preprocessed-data

    ``X`` is copied into three identical channels along a new last axis. It is
    then standardized, clipped to ``[norm_min, norm_max]`` and linearly
    rescaled to ``[0, 255]``.

    Args:
        X: input array.
        mean, std: standardization statistics. ``None`` means "compute from X".
            An explicit ``0`` is honoured.
        norm_max, norm_min: clipping bounds in standardized units. ``None``
            means "use the observed max / min".
        eps: added to ``std`` before dividing, and used as the threshold for
            treating the standardized input as constant.

    Returns:
        uint8 array of shape ``X.shape + (3,)``. It is all zeros when the
        standardized input is (nearly) constant.
    """
    X = np.stack([X, X, X], axis=-1)

    # `is None` instead of `value or default`: a caller-supplied 0 (very common
    # for norm_min / mean) is a legitimate value and must not be replaced.
    if mean is None:
        mean = X.mean()
    X = X - mean
    if std is None:
        std = X.std()
    Xstd = X / (std + eps)

    _min, _max = Xstd.min(), Xstd.max()
    if norm_max is None:
        norm_max = _max
    if norm_min is None:
        norm_min = _min

    if (_max - _min) > eps:
        V = np.clip(Xstd, norm_min, norm_max)
        V = 255 * (V - norm_min) / (norm_max - norm_min)
        return V.astype(np.uint8)

    # (Nearly) constant input: rescaling would divide by ~0, so return black.
    return np.zeros_like(Xstd, dtype=np.uint8)


class TestDataset(data.Dataset):
    """Dataset turning rows of one audio file's test DataFrame into spectrogram images.

    Each row must have ``site`` and ``row_id`` columns; non-``site_3`` rows also
    need ``seconds``, the end time (in seconds) of the 5 s window to score.

    * ``site_3`` rows: the whole clip is cut into consecutive, non-overlapping
      5 s chunks. A trailing chunk shorter than 5 s is dropped. The item is an
      array of shape ``(n_chunks, 3, img_size, W)``, or ``(0,)`` if the clip is
      shorter than 5 s.
    * other rows: the window ``[seconds - 5, seconds)`` gives one image of shape
      ``(3, img_size, W)``.

    Items are ``(image(s), row_id, site)``. Pixel values are float32 in [0, 1].

    Args:
        df: rows for a single audio file (indexed 0..len-1, used with ``.loc``).
        clip: 1-D waveform of that file, sampled at ``sr``.
        img_size: output image height; the width keeps the spectrogram's aspect ratio.
        melspectrogram_parameters: extra kwargs for ``librosa.feature.melspectrogram``.
        sr: sample rate of ``clip`` (defaults to the previously hard-coded 32000).
    """

    def __init__(self,
                 df: pd.DataFrame,
                 clip: np.ndarray,
                 img_size=image_size,
                 melspectrogram_parameters=None,
                 sr: int = 32000):
        self.df = df
        self.clip = clip
        self.img_size = img_size
        # None instead of a shared mutable `{}` default.
        self.melspectrogram_parameters = (
            {} if melspectrogram_parameters is None else melspectrogram_parameters)
        self.sr = sr

    def __len__(self):
        return len(self.df)

    def _to_image(self, y: np.ndarray) -> np.ndarray:
        """Convert one waveform chunk into a float32 (3, img_size, W) image in [0, 1]."""
        # `y=` as a keyword: it is keyword-only in current librosa releases.
        melspec = librosa.feature.melspectrogram(
            y=y, sr=self.sr, **self.melspectrogram_parameters)
        melspec = librosa.power_to_db(melspec).astype(np.float32)
        image = mono_to_color(melspec)
        height, width, _ = image.shape
        image = cv2.resize(
            image, (int(width * self.img_size / height), self.img_size))
        image = np.moveaxis(image, 2, 0)  # HWC -> CHW
        return (image / 255.0).astype(np.float32)

    def __getitem__(self, idx: int):
        sample = self.df.loc[idx, :]
        site = sample.site
        row_id = sample.row_id
        chunk_len = self.sr * 5

        if site == "site_3":
            y = self.clip.astype(np.float32)
            # Start offsets of every complete 5 s chunk; a partial tail is skipped.
            images = [self._to_image(y[start:start + chunk_len])
                      for start in range(0, len(y) - chunk_len + 1, chunk_len)]
            return np.asarray(images), row_id, site

        end_seconds = int(sample.seconds)
        start_seconds = end_seconds - 5
        y = self.clip[self.sr * start_seconds:self.sr * end_seconds].astype(np.float32)
        return self._to_image(y), row_id, site

In [None]:
def init_layer(layer):
    """Xavier-uniform ``layer.weight`` and zero ``layer.bias`` when they exist.

    Modules without these attributes (activations, dropout, ...) are left
    untouched, and so is a bias that is ``None``.
    """
    if hasattr(layer, 'weight'):
        nn.init.xavier_uniform_(layer.weight)

    bias = getattr(layer, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.)

def init_bn(bn):
    """Reset a BatchNorm layer's affine parameters: weight to 1, bias to 0."""
    for param, value in ((bn.weight, 1.), (bn.bias, 0.)):
        param.data.fill_(value)


class CustomResNeSt(nn.Module):
    """ResNeSt backbone whose final ``fc`` is replaced by a 3-layer MLP head.

    Args:
        pretrained_model: name of a constructor in ``resnest.torch`` (e.g.
            ``"resnest50_fast_2s2x40d"``), built with ``pretrained=False``.
        n_classes: number of output logits.
    """

    def __init__(self, pretrained_model, n_classes):
        super().__init__()

        self.backbone = getattr(resnest_torch, pretrained_model)(pretrained=False)

        # Size the head from the backbone's own classifier instead of hard-coding
        # 2048, so variants with a different feature width also work.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Linear(in_features, 1024), nn.ReLU(), nn.Dropout(p=0.2),
            nn.Linear(1024, 1024), nn.ReLU(), nn.Dropout(p=0.2),
            nn.Linear(1024, n_classes))

        self.init_weight()

    def init_weight(self):
        """Initialise every layer of the MLP head with ``init_layer``."""
        for layer in self.backbone.fc:
            init_layer(layer)

    def forward(self, input):
        """Return raw logits of shape (batch, n_classes).

        ``input`` goes straight into the backbone. In this notebook it is a batch
        of (3, H, W) spectrogram images produced by ``TestDataset``.
        """
        # The previous training-mode branch called `self.backbone(x)` with an
        # undefined `x` (NameError). Train and eval behave the same here.
        return self.backbone(input)

In [None]:
def get_model(pretrained_model, model_file_path):
    """Build a ``CustomResNeSt`` with ``n_classes`` outputs and load its weights.

    The model is moved to CUDA and put in eval mode. Parameter counts are
    printed, and cached GPU memory is released before returning.

    Args:
        pretrained_model: ``resnest.torch`` constructor name for the backbone.
        model_file_path: path of a ``state_dict`` checkpoint for this architecture.

    Returns:
        The ready-to-use model.
    """
    model = CustomResNeSt(pretrained_model, n_classes)

    print(f"Loading model states from {model_file_path} ......")
    # map_location="cpu": without it torch.load restores tensors onto the GPU
    # they were saved from, which fails if that GPU index is not visible here
    # (only CUDA_VISIBLE_DEVICES="0" is exposed). It also avoids a second
    # GPU-side copy of the weights. The model is moved to the GPU below.
    state_dict = torch.load(model_file_path, map_location="cpu")
    model.load_state_dict(state_dict)
    del state_dict

    device = torch.device("cuda")
    model.to(device)
    model.eval()

    print_model_info(model)

    gc.collect()
    torch.cuda.empty_cache()

    return model

In [None]:
def get_resnest_model(pretrained_model, model_file_path):
    """Build a plain ``resnest.torch`` model with the baseline MLP head and load its weights.

    Unlike ``get_model``, the head is attached directly as ``model.fc``, so the
    checkpoint's keys carry no ``backbone.`` prefix. The model is moved to CUDA
    and put in eval mode, and its parameter counts are printed.

    Args:
        pretrained_model: ``resnest.torch`` constructor name (e.g. ``"resnest200"``).
        model_file_path: path of a ``state_dict`` checkpoint for this layout.

    Returns:
        The ready-to-use model with ``n_classes`` outputs.
    """
    model = getattr(resnest_torch, pretrained_model)(pretrained=False)

    # Same head as the baseline notebook, sized from the original classifier.
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(in_features, 1024), nn.ReLU(), nn.Dropout(p=0.2),
        nn.Linear(1024, 1024), nn.ReLU(), nn.Dropout(p=0.2),
        nn.Linear(1024, n_classes))

    print(f"Loading model states from {model_file_path} ......")
    # Deserialize on CPU so the checkpoint's original GPU index does not matter.
    state_dict = torch.load(model_file_path, map_location="cpu")
    model.load_state_dict(state_dict)
    del state_dict

    device = torch.device("cuda")
    model.to(device)
    model.eval()

    print_model_info(model)

    gc.collect()
    torch.cuda.empty_cache()

    return model

In [None]:
def get_nocall_resnest_model(pretrained_model, model_file_path):
    """Build a ``CustomResNeSt`` with ``n_classes + 1`` outputs and load its weights.

    The extra output is the "nocall" class (index 264 in ``BIRD_CODE``). The
    model is moved to CUDA and put in eval mode, and its parameter counts are
    printed.

    Args:
        pretrained_model: ``resnest.torch`` constructor name for the backbone.
        model_file_path: path of a ``state_dict`` checkpoint for this architecture.

    Returns:
        The ready-to-use model.
    """
    model = CustomResNeSt(pretrained_model, n_classes + 1)

    print(f"Loading model states from {model_file_path} ......")
    # Deserialize on CPU so the checkpoint's original GPU index does not matter.
    state_dict = torch.load(model_file_path, map_location="cpu")
    model.load_state_dict(state_dict)
    del state_dict

    device = torch.device("cuda")
    model.to(device)
    model.eval()

    print_model_info(model)

    gc.collect()
    torch.cuda.empty_cache()

    return model

## Prediction loop

In [None]:
def _ensemble_probs(batch, models, public_model, resnest_models, nocall_models):
    """Blend sigmoid probabilities of the four model groups for one image batch.

    Members are averaged within each group, then the four group results are
    averaged with equal weight. Returns a float64 array of shape
    ``(batch_size, n_classes + 1)``. The last column ("nocall") is filled only
    by ``nocall_models``; the other groups contribute 0 there.
    """
    n_rows = batch.size(0)

    def group_mean(group, has_nocall_column):
        # Average one group's sigmoid outputs into an (n_rows, n_classes + 1) array.
        acc = np.zeros((n_rows, n_classes + 1))
        for m in group:
            proba = torch.sigmoid(m(batch)).detach().cpu().numpy()
            if has_nocall_column:
                acc += proba
            else:
                acc[:, :-1] += proba
        return acc / len(group)

    with torch.no_grad():
        preds = group_mean(models, False)
        public_preds = group_mean([public_model], False)
        resnest_preds = group_mean(resnest_models, False)
        nocall_preds = group_mean(nocall_models, True)

    # NOTE(review): only nocall_preds fills the nocall column, so after the /4
    # its value is at most 0.25 and can never reach a threshold above 0.25
    # (e.g. the configured 0.3).
    return (preds + public_preds + resnest_preds + nocall_preds) / 4


def _labels_to_birds(labels):
    """Convert predicted class indices into the space-separated submission string.

    Returns "nocall" when there are no labels. The nocall index is dropped when
    real birds are predicted as well.
    """
    labels = list(labels)
    if not labels:
        return "nocall"
    nocall_idx = BIRD_CODE["nocall"]
    if len(labels) > 1 and nocall_idx in labels:
        labels.remove(nocall_idx)
    return " ".join(INV_BIRD_CODE[x] for x in labels)


def prediction_for_clip(test_df: pd.DataFrame,
                        clip,
                        models,
                        public_model,
                        resnest_models,
                        nocall_models,
                        mel_params: dict,
                        threshold=0.5):
    """Predict bird labels for every row of ``test_df`` belonging to one clip.

    The site of the first row decides the mode:

    * ``site_1`` / ``site_2``: one 5 s image per row, scored in batches of
      ``infer_batch_size``. Each ``row_id`` gets its own label string.
    * anything else (``site_3``): every row yields all 5 s chunks of the clip.
      The chunks are scored in sub-batches of ``infer_batch_size``, and the row
      gets the union of labels over all chunks.

    A class is predicted when its blended probability is ``>= threshold``.

    Returns:
        dict mapping ``row_id`` to a space-separated label string, or "nocall".
    """
    for m in [*models, public_model, *resnest_models, *nocall_models]:
        m.eval()

    prediction_dict = {}

    dataset = TestDataset(df=test_df,
                          clip=clip,
                          img_size=image_size,
                          melspectrogram_parameters=mel_params)

    def score(batch):
        # Ensemble probabilities for one batch already placed on `device`.
        return _ensemble_probs(batch.to(device), models, public_model,
                               resnest_models, nocall_models)

    if test_df.site.values[0] in ["site_1", "site_2"]:
        # Site 1 & 2: one image per row, normal batch size.
        loader = data.DataLoader(dataset,
                                 batch_size=infer_batch_size,
                                 shuffle=False,
                                 num_workers=num_workers,
                                 drop_last=False,
                                 pin_memory=True)

        for image, row_id, _site in progress_bar(loader):
            events = score(image) >= threshold
            # Keep a separate result per frame.
            for frame_id, event in zip(row_id, events):
                prediction_dict[frame_id] = _labels_to_birds(
                    np.flatnonzero(event).tolist())
    else:
        # Site 3: one item holds all chunks of the clip, so the loader batch size is 1.
        loader = data.DataLoader(dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=num_workers,
                                 drop_last=False,
                                 pin_memory=True)

        for image, row_id, _site in progress_bar(loader):
            row_id = row_id[0]
            image = image.squeeze(0)  # drop the loader's batch dim -> (n_chunks, 3, H, W)

            all_events = set()
            # Score chunks in slices so very long clips do not form one huge batch.
            for start in range(0, image.size(0), infer_batch_size):
                batch = image[start:start + infer_batch_size]
                if batch.ndim == 3:
                    batch = batch.unsqueeze(0)
                hits = score(batch) >= threshold
                all_events.update(np.nonzero(hits)[1].tolist())

            # sorted: deterministic label order in the output string.
            prediction_dict[row_id] = _labels_to_birds(sorted(all_events))

    del loader, dataset

    gc.collect()
    torch.cuda.empty_cache()

    return prediction_dict

In [None]:
def prediction(test_df: pd.DataFrame,
               test_audio: Path,
               mel_params: dict,
               target_sr: int,
               threshold=0.5):
    """Load every model group and predict labels for all audio files in ``test_df``.

    Each distinct ``audio_id`` is loaded from ``test_audio / f"{audio_id}.mp3"``
    (mono, resampled to ``target_sr``) and scored by ``prediction_for_clip``.

    Returns:
        DataFrame with columns ``row_id`` and ``birds``, one row per predicted
        row_id, in the order the audio ids first appear in ``test_df``.
    """
    models = [get_model(m, p) for m, p in model_files]
    public_model = get_resnest_model(public_model_file[0], public_model_file[1])
    resnest_models = [get_resnest_model(m, p) for m, p in resnest_200_model_files]
    nocall_models = [get_nocall_resnest_model(m, p) for m, p in nocall_model_files]

    prediction_dfs = []
    # Suppress warnings (e.g. from audio decoding) only for this call instead of
    # silencing them for the rest of the session.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for audio_id in test_df.audio_id.unique():
            with timer(f"Loading {audio_id}"):
                clip, _ = librosa.load(test_audio / (audio_id + ".mp3"),
                                       sr=target_sr,
                                       mono=True,
                                       res_type="kaiser_fast")

            # Boolean mask instead of an f-string `query`: no quoting issues
            # if an id ever contains a quote character.
            test_df_for_audio_id = test_df[test_df.audio_id == audio_id].reset_index(drop=True)
            with timer(f"Prediction on {audio_id}"):
                prediction_dict = prediction_for_clip(test_df_for_audio_id,
                                                      clip=clip,
                                                      models=models,
                                                      public_model=public_model,
                                                      resnest_models=resnest_models,
                                                      nocall_models=nocall_models,
                                                      mel_params=mel_params,
                                                      threshold=threshold)
            prediction_dfs.append(pd.DataFrame({
                "row_id": list(prediction_dict.keys()),
                "birds": list(prediction_dict.values()),
            }))

    if not prediction_dfs:
        # pd.concat raises on an empty list; return an empty, well-formed frame.
        return pd.DataFrame({"row_id": [], "birds": []})

    return pd.concat(prediction_dfs, axis=0, sort=False).reset_index(drop=True)

## Prediction

In [None]:
# Run the full ensemble on the test set and write the Kaggle submission file.
submission = prediction(
    test_df=test,
    test_audio=TEST_AUDIO_DIR,
    mel_params=melspectrogram_parameters,
    target_sr=TARGET_SR,
    threshold=prob_threshold,
)
submission.to_csv("submission.csv", index=False)

In [None]:
# Size of the written submission, then its first rows for a visual check.
print(submission.shape)
submission.head(n=50)

## EOF