# Birdsong Pytorch Baseline: ResNeSt50-fast (Inference)

## About

In this notebook, I try ResNeSt, which is one of the state-of-the-art architectures in image recognition.  

This is a notebook for **_inference & submission_**. The training process is shared in a separate notebook:  
https://www.kaggle.com/ttahara/training-birdsong-baseline-resnest50-fast  
If you want to know the experimental details, see that notebook.

Most of this notebook is based on the [great baseline](https://www.kaggle.com/hidehisaarai1213/inference-pytorch-birdcall-resnet-baseline) shared by @hidehisaarai1213.  
Thank you for sharing!

## Prepare

### import libraries

In [None]:
import os
import gc
import time
import math
import shutil
import random
import warnings
import typing as tp
from pathlib import Path
from contextlib import contextmanager
import torchvision.models as models
import yaml
from joblib import delayed, Parallel

import cv2
import librosa
import audioread
import soundfile as sf

import numpy as np
import pandas as pd

from fastprogress import progress_bar
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU
from torch.nn.modules.utils import _pair
import torch.utils.data as data

pd.options.display.max_rows = 500
pd.options.display.max_columns = 500

### define utilities

In [None]:
def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch random number generators.

    The same ``seed`` is applied to ``random``, ``numpy.random``, the torch
    CPU generator and the torch CUDA generator, and is also exported as the
    ``PYTHONHASHSEED`` environment variable. The cuDNN ``deterministic`` /
    ``benchmark`` flags are left untouched.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    

@contextmanager
def timer(name: str, *, verbose: bool = False) -> tp.Iterator[None]:
    """Context manager that optionally reports the wall time of its body.

    Args:
        name: Label used in the printed messages.
        verbose: When True, print a start line on entry and the elapsed
            seconds on exit. Defaults to False, which keeps the context
            manager silent.

    Yields:
        None. The managed block runs while the timer is active.
    """
    t0 = time.time()
    if verbose:
        print("[{}] start".format(name))
    try:
        yield
    finally:
        # Report in ``finally`` so the duration is shown even if the body raises.
        if verbose:
            print("[{}] done in {:.0f} s".format(name, time.time() - t0))

In [None]:
# logger = get_logger("main.log")
# Seed the Python / NumPy / torch generators once, before anything else runs.
set_seed(1213)

### read data

In [None]:
# Directory layout: everything is resolved relative to the parent of the
# current working directory.
ROOT = Path.cwd().parent
INPUT_ROOT = ROOT.joinpath("input")
RAW_DATA = INPUT_ROOT.joinpath("birdsong-recognition")

# Audio folders inside the competition data set.
# (Disabled alternative kept for reference:
#  TRAIN_RESAMPLED_AUDIO_DIRS = [
#    INPUT_ROOT / "birdsong-resampled-train-audio-{:0>2}".format(i)  for i in range(5)
#  ])
TRAIN_AUDIO_DIR, TEST_AUDIO_DIR = (
    RAW_DATA.joinpath(leaf) for leaf in ("train_audio", "test_audio")
)

In [None]:
# Load train.csv from the competition data directory.
train = pd.read_csv(RAW_DATA / "train.csv")

In [None]:
# When the competition's test_audio directory does not exist, fall back to the
# audio folder and test.csv of the "birdcall-check" data set; otherwise read
# test.csv from the competition data directory.
if not TEST_AUDIO_DIR.exists():
    TEST_AUDIO_DIR = INPUT_ROOT / "birdcall-check" / "test_audio"
    test = pd.read_csv(INPUT_ROOT / "birdcall-check" / "test.csv")
else:
    test = pd.read_csv(RAW_DATA / "test.csv")

In [None]:
# Notebook preview: first rows of the training table.
train.head()

In [None]:
# Notebook preview: first rows of the test table.
test.head()

In [None]:
# Copy the sample submission to submission.csv up front so that an output file
# exists even if a later cell fails.
sub = pd.read_csv("../input/birdsong-recognition/sample_submission.csv")
sub.to_csv("submission.csv", index=False)  # this will be overwritten if everything goes well

### set parameters

In [None]:
# Sampling rate (Hz) the audio is expected to be in.
TARGET_SR = 32000

# Arguments handed to get_model().
# NOTE(review): the model classes visible in this notebook build a torchvision
# resnet18 and load weights from paths hard-coded in get_model(); confirm
# whether "base_model_name" and "trained_weights" are still read anywhere.
model_config = dict(
    base_model_name="resnest50_fast_1s1x64d",
    pretrained=False,
    num_classes=264,
    trained_weights="../input/bird-seed-v2/bird.pth",
)

# Keyword arguments forwarded to librosa.feature.melspectrogram.
# With sr=32000 and hop_length=256 one second of audio spans 125 frames.
melspectrogram_parameters = dict(
    n_mels=155,
    fmin=0,
    fmax=16000,
    n_fft=1024,
    hop_length=256,
)



## Definition

### Dataset

For `site_3`, I decided to use the same procedure as for `site_1` and `site_2`: crop 5 seconds out of the clip and predict on that short clip.
The only difference is that I crop 5-second short clips from the start to the end of the `site_3` clip, predict on each of them, and then aggregate the predictions over all of those short clips.

In [None]:
# Maps each bird code string to its integer class index (264 entries, 0-263).
BIRD_CODE = {
    'aldfly': 0, 'ameavo': 1, 'amebit': 2, 'amecro': 3, 'amegfi': 4,
    'amekes': 5, 'amepip': 6, 'amered': 7, 'amerob': 8, 'amewig': 9,
    'amewoo': 10, 'amtspa': 11, 'annhum': 12, 'astfly': 13, 'baisan': 14,
    'baleag': 15, 'balori': 16, 'banswa': 17, 'barswa': 18, 'bawwar': 19,
    'belkin1': 20, 'belspa2': 21, 'bewwre': 22, 'bkbcuc': 23, 'bkbmag1': 24,
    'bkbwar': 25, 'bkcchi': 26, 'bkchum': 27, 'bkhgro': 28, 'bkpwar': 29,
    'bktspa': 30, 'blkpho': 31, 'blugrb1': 32, 'blujay': 33, 'bnhcow': 34,
    'boboli': 35, 'bongul': 36, 'brdowl': 37, 'brebla': 38, 'brespa': 39,
    'brncre': 40, 'brnthr': 41, 'brthum': 42, 'brwhaw': 43, 'btbwar': 44,
    'btnwar': 45, 'btywar': 46, 'buffle': 47, 'buggna': 48, 'buhvir': 49,
    'bulori': 50, 'bushti': 51, 'buwtea': 52, 'buwwar': 53, 'cacwre': 54,
    'calgul': 55, 'calqua': 56, 'camwar': 57, 'cangoo': 58, 'canwar': 59,
    'canwre': 60, 'carwre': 61, 'casfin': 62, 'caster1': 63, 'casvir': 64,
    'cedwax': 65, 'chispa': 66, 'chiswi': 67, 'chswar': 68, 'chukar': 69,
    'clanut': 70, 'cliswa': 71, 'comgol': 72, 'comgra': 73, 'comloo': 74,
    'commer': 75, 'comnig': 76, 'comrav': 77, 'comred': 78, 'comter': 79,
    'comyel': 80, 'coohaw': 81, 'coshum': 82, 'cowscj1': 83, 'daejun': 84,
    'doccor': 85, 'dowwoo': 86, 'dusfly': 87, 'eargre': 88, 'easblu': 89,
    'easkin': 90, 'easmea': 91, 'easpho': 92, 'eastow': 93, 'eawpew': 94,
    'eucdov': 95, 'eursta': 96, 'evegro': 97, 'fiespa': 98, 'fiscro': 99,
    'foxspa': 100, 'gadwal': 101, 'gcrfin': 102, 'gnttow': 103, 'gnwtea': 104,
    'gockin': 105, 'gocspa': 106, 'goleag': 107, 'grbher3': 108, 'grcfly': 109,
    'greegr': 110, 'greroa': 111, 'greyel': 112, 'grhowl': 113, 'grnher': 114,
    'grtgra': 115, 'grycat': 116, 'gryfly': 117, 'haiwoo': 118, 'hamfly': 119,
    'hergul': 120, 'herthr': 121, 'hoomer': 122, 'hoowar': 123, 'horgre': 124,
    'horlar': 125, 'houfin': 126, 'houspa': 127, 'houwre': 128, 'indbun': 129,
    'juntit1': 130, 'killde': 131, 'labwoo': 132, 'larspa': 133, 'lazbun': 134,
    'leabit': 135, 'leafly': 136, 'leasan': 137, 'lecthr': 138, 'lesgol': 139,
    'lesnig': 140, 'lesyel': 141, 'lewwoo': 142, 'linspa': 143, 'lobcur': 144,
    'lobdow': 145, 'logshr': 146, 'lotduc': 147, 'louwat': 148, 'macwar': 149,
    'magwar': 150, 'mallar3': 151, 'marwre': 152, 'merlin': 153, 'moublu': 154,
    'mouchi': 155, 'moudov': 156, 'norcar': 157, 'norfli': 158, 'norhar2': 159,
    'normoc': 160, 'norpar': 161, 'norpin': 162, 'norsho': 163, 'norwat': 164,
    'nrwswa': 165, 'nutwoo': 166, 'olsfly': 167, 'orcwar': 168, 'osprey': 169,
    'ovenbi1': 170, 'palwar': 171, 'pasfly': 172, 'pecsan': 173, 'perfal': 174,
    'phaino': 175, 'pibgre': 176, 'pilwoo': 177, 'pingro': 178, 'pinjay': 179,
    'pinsis': 180, 'pinwar': 181, 'plsvir': 182, 'prawar': 183, 'purfin': 184,
    'pygnut': 185, 'rebmer': 186, 'rebnut': 187, 'rebsap': 188, 'rebwoo': 189,
    'redcro': 190, 'redhea': 191, 'reevir1': 192, 'renpha': 193, 'reshaw': 194,
    'rethaw': 195, 'rewbla': 196, 'ribgul': 197, 'rinduc': 198, 'robgro': 199,
    'rocpig': 200, 'rocwre': 201, 'rthhum': 202, 'ruckin': 203, 'rudduc': 204,
    'rufgro': 205, 'rufhum': 206, 'rusbla': 207, 'sagspa1': 208, 'sagthr': 209,
    'savspa': 210, 'saypho': 211, 'scatan': 212, 'scoori': 213, 'semplo': 214,
    'semsan': 215, 'sheowl': 216, 'shshaw': 217, 'snobun': 218, 'snogoo': 219,
    'solsan': 220, 'sonspa': 221, 'sora': 222, 'sposan': 223, 'spotow': 224,
    'stejay': 225, 'swahaw': 226, 'swaspa': 227, 'swathr': 228, 'treswa': 229,
    'truswa': 230, 'tuftit': 231, 'tunswa': 232, 'veery': 233, 'vesspa': 234,
    'vigswa': 235, 'warvir': 236, 'wesblu': 237, 'wesgre': 238, 'weskin': 239,
    'wesmea': 240, 'wessan': 241, 'westan': 242, 'wewpew': 243, 'whbnut': 244,
    'whcspa': 245, 'whfibi': 246, 'whtspa': 247, 'whtswi': 248, 'wilfly': 249,
    'wilsni1': 250, 'wiltur': 251, 'winwre3': 252, 'wlswar': 253, 'wooduc': 254,
    'wooscj2': 255, 'woothr': 256, 'y00475': 257, 'yebfly': 258, 'yebsap': 259,
    'yehbla': 260, 'yelwar': 261, 'yerwar': 262, 'yetvir': 263
}

# Reverse lookup: class index -> bird code string.
INV_BIRD_CODE = {v: k for k, v in BIRD_CODE.items()}
'''

BIRD_CODE = {
    'amegfi': 0, 'blujay': 1, 'horlar': 2, 'norcar': 3, 'mallar3': 4,
    'bkhgro': 5, 'yerwar': 6, 'orcwar': 7, 'norwat': 8, 'carwre': 9,
    'normoc': 10, 'marwre': 11, 'houwre': 12, 'barswa': 13, 'eastow': 14,
    'easmea': 15, 'winwre3': 16, 'foxspa': 17, 'sonspa': 18, 'amered': 19,
    'scoori': 20, 'boboli': 21, 'tuftit': 22, 'bkcchi': 23, 'bulori': 24,
    'comred': 25, 'houspa': 26, 'brespa': 27, 'linspa': 28, 'swathr': 29,
    'wesmea': 30, 'woothr': 31, 'chswar': 32, 'eucdov': 33, 'brncre': 34,
    'norfli': 35, 'comyel': 36, 'wewpew': 37, 'cangoo': 38, 'indbun': 39,
    'redcro': 40, 'haiwoo': 41, 'ruckin': 42, 'houfin': 43, 'spotow': 44,
    'stejay': 45, 'hoowar': 46, 'chispa': 47, 'astfly': 48, 'amecro': 49,

}

INV_BIRD_CODE = {v: k for k, v in BIRD_CODE.items()}
'''

In [None]:
'''def mono_to_color(X: np.ndarray,
                  mean=None,
                  std=None,
                  norm_max=None,
                  norm_min=None,
                  eps=1e-6):
    """
    Code from https://www.kaggle.com/daisukelab/creating-fat2019-preprocessed-data
    """
    # Stack X as [X,X,X]
    X = np.stack([X, X, X], axis=-1)

    # Standardize
    mean = mean or X.mean()
    X = X - mean
    std = std or X.std()
    Xstd = X / (std + eps)
    _min, _max = Xstd.min(), Xstd.max()
    norm_max = norm_max or _max
    norm_min = norm_min or _min
    if (_max - _min) > eps:
        # Normalize to [0, 255]
        V = Xstd
        V[V < norm_min] = norm_min
        V[V > norm_max] = norm_max
        V = 255 * (V - norm_min) / (norm_max - norm_min)
        V = V.astype(np.uint8)
    else:
        # Just zero
        V = np.zeros_like(Xstd, dtype=np.uint8)
    return V


class TestDataset(data.Dataset):
    def __init__(self, df: pd.DataFrame, clip: np.ndarray,
                 img_size=224, melspectrogram_parameters={}):
        self.df = df
        self.clip = clip
        self.img_size = img_size
        self.melspectrogram_parameters = melspectrogram_parameters
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx: int):
        SR = 32000
        sample = self.df.loc[idx, :]
        site = sample.site
        row_id = sample.row_id
        
        if site == "site_3":
            y = self.clip.astype(np.float32)
            len_y = len(y)
            start = 0
            end = SR * 5
            images = []
            while len_y > start:
                y_batch = y[start:end].astype(np.float32)
                if len(y_batch) != (SR * 5):
                    break
                start = end
                end = end + SR * 5
                
                melspec = librosa.feature.melspectrogram(y_batch,
                                                         sr=SR,
                                                         **self.melspectrogram_parameters)
                melspec = librosa.power_to_db(melspec).astype(np.float32)
                
                #melspec = librosa.pcen(melspec, sr=32000, hop_length=melspectrogram_parameters['hop_length'])
                
                print('this is the melspec shape ',melspec.shape)
                image = mono_to_color(melspec)
                height, width, _ = image.shape
                #image = cv2.resize(image,(self.img_size, self.img_size))
                image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
                image = np.flipud(image)
                image = np.moveaxis(image, 2, 0)
                image = (image / 255.0).astype(np.float32)
                images.append(image)
            images = np.asarray(images)
            return images, row_id, site
        else:
            end_seconds = int(sample.seconds)
            start_seconds = int(end_seconds - 5)
            
            start_index = SR * start_seconds
            end_index = SR * end_seconds
            
            y = self.clip[start_index:end_index].astype(np.float32)

            melspec = librosa.feature.melspectrogram(y, sr=SR, **self.melspectrogram_parameters)
            melspec = librosa.power_to_db(melspec).astype(np.float32)
            
            #melspec = librosa.pcen(melspec, sr=32000, hop_length=melspectrogram_parameters['hop_length'])

            image = mono_to_color(melspec)
            height, width, _ = image.shape
            #image = cv2.resize(image,(self.img_size, self.img_size))
            image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
            image = np.flipud(image)
            image = np.moveaxis(image, 2, 0)
            image = (image / 255.0).astype(np.float32)

            return image, row_id, site'''

In [None]:
'''def mono_to_color(X: np.ndarray,
                  mean=None,
                  std=None,
                  norm_max=None,
                  norm_min=None,
                  eps=1e-6):
    """
    Code from https://www.kaggle.com/daisukelab/creating-fat2019-preprocessed-data
    """
    # Stack X as [X,X,X]
    X = np.stack([X, X, X], axis=-1)

    # Standardize
    mean = mean or X.mean()
    X = X - mean
    std = std or X.std()
    Xstd = X / (std + eps)
    _min, _max = Xstd.min(), Xstd.max()
    norm_max = norm_max or _max
    norm_min = norm_min or _min
    if (_max - _min) > eps:
        # Normalize to [0, 255]
        V = Xstd
        V[V < norm_min] = norm_min
        V[V > norm_max] = norm_max
        V = 255 * (V - norm_min) / (norm_max - norm_min)
        V = V.astype(np.uint8)
    else:
        # Just zero
        V = np.zeros_like(Xstd, dtype=np.uint8)
    return V


class TestDataset(data.Dataset):
    def __init__(self, df: pd.DataFrame, clip: np.ndarray,
                 img_size=224, melspectrogram_parameters={}):
        self.df = df
        self.clip = clip
        self.img_size = img_size
        self.melspectrogram_parameters = melspectrogram_parameters
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx: int):
        SR = 32000
        sample = self.df.loc[idx, :]
        site = sample.site
        row_id = sample.row_id
        
        if site == "site_3":
            y = self.clip.astype(np.float32)
            len_y = len(y)
            start = 0
            end   =  start + 32000 * 1 # 1 sec
            images = []
            
            melspec = librosa.feature.melspectrogram(y,
                                                         sr=SR,
                                                         **self.melspectrogram_parameters)
            melspec_base = librosa.power_to_db(melspec).astype(np.float32)
            print('site 3 mel ',melspec_base.shape)
            
            
            
            while len_y > start:
                y_batch = y[start:end].astype(np.float32)
                if len(y_batch) != 32000*1:# to break at the last if the time exceeds 5 sec
                    #print('put break')
                    break
                start = end - (32000*1)//2
                end = start + 32000 * 1
                
                melspec = librosa.feature.melspectrogram(y_batch,
                                                         sr=SR,
                                                         **self.melspectrogram_parameters)
                melspec_base = librosa.power_to_db(melspec).astype(np.float32)
                
                #melspec = librosa.pcen(melspec, sr=32000, hop_length=melspectrogram_parameters['hop_length'])
                

                #print('start ',start)
                #print('end ',end)

                #melspec = melspec_base[:, start : end ]
                #print('this is the melspec shape site 3',melspec.shape)

                image = mono_to_color(melspec_base)
                height, width, _ = image.shape
                #image = cv2.resize(image,(self.img_size, self.img_size))
                image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
                #print(image.shape)
                image = np.flipud(image)
                image = np.moveaxis(image, 2, 0)
                image = (image / 255.0).astype(np.float32)
                images.append(image)
            images = np.asarray(images)
            #print('this is site 3 images ',images.shape)
            return images, row_id, site
        else:
            end_seconds = int(sample.seconds)
            start_seconds = int(end_seconds - 5)
            
            start_index = SR * start_seconds
            end_index = SR * end_seconds
            
            y = self.clip[start_index:end_index].astype(np.float32)

            melspec = librosa.feature.melspectrogram(y, sr=SR, **self.melspectrogram_parameters)
            melspec_base = librosa.power_to_db(melspec).astype(np.float32)
            #print('this is the melspec shape site 1',melspec.shape)
            images = []
            for iii in range(9):
                    start =  iii * (626//5)//2
                    end   =  start + (626//5)
                    if end > 626:
                        #print("mistake on site 1")
                        break
                    #print('start ',start)
                    #print('end ',end)

                    melspec = melspec_base[:, start : end ]
                    #print('this is the melspec shape 1',melspec.shape)

                    image = mono_to_color(melspec)
                    height, width, _ = image.shape
                    #image = cv2.resize(image,(self.img_size, self.img_size))
                    image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
                    image = np.flipud(image)
                    image = np.moveaxis(image, 2, 0)
                    image = (image / 255.0).astype(np.float32)
                    images.append(image)
            #print('this is site images shape ',np.array(images).shape)

            return np.array(images), row_id, site'''

In [None]:
def mono_to_color(X: np.ndarray,
                  mean: tp.Optional[float] = None,
                  std: tp.Optional[float] = None,
                  norm_max: tp.Optional[float] = None,
                  norm_min: tp.Optional[float] = None,
                  eps: float = 1e-6) -> np.ndarray:
    """Convert a single-channel array into a 3-channel ``uint8`` image.

    Code from https://www.kaggle.com/daisukelab/creating-fat2019-preprocessed-data

    The input is stacked three times along a new last axis, standardized,
    clipped to ``[norm_min, norm_max]`` and rescaled to ``[0, 255]``.

    Args:
        X: Array to convert; the output has shape ``X.shape + (3,)``.
        mean: Value subtracted before scaling. ``None`` uses the array mean.
        std: Divisor applied after centering (``eps`` is added to it).
            ``None`` uses the standard deviation of the centered array.
        norm_max: Upper clipping bound in standardized units. ``None`` uses
            the maximum of the standardized array.
        norm_min: Lower clipping bound in standardized units. ``None`` uses
            the minimum of the standardized array.
        eps: Small constant guarding against division by zero and used as the
            threshold below which the value range counts as flat.

    Returns:
        ``uint8`` array. All zeros when the standardized data, or the
        requested normalization range, is flat (narrower than ``eps``).
    """
    # Stack X as [X,X,X]
    X = np.stack([X, X, X], axis=-1)

    # Standardize. ``is None`` (not ``or``) so that an explicit 0 is honoured.
    if mean is None:
        mean = X.mean()
    X = X - mean
    if std is None:
        std = X.std()
    Xstd = X / (std + eps)

    _min, _max = Xstd.min(), Xstd.max()
    if norm_max is None:
        norm_max = _max
    if norm_min is None:
        norm_min = _min

    # Also require a non-degenerate target range so that an explicit
    # norm_max == norm_min cannot divide by zero.
    if (_max - _min) > eps and (norm_max - norm_min) > eps:
        # Normalize to [0, 255]
        V = Xstd
        V[V < norm_min] = norm_min
        V[V > norm_max] = norm_max
        V = 255 * (V - norm_min) / (norm_max - norm_min)
        return V.astype(np.uint8)

    # Just zero
    return np.zeros_like(Xstd, dtype=np.uint8)


class TestDataset(data.Dataset):
    """Dataset that turns rows of a test table into stacks of spectrogram images.

    Every item is built from ``clip`` (one audio signal shared by all rows of
    ``df``). The audio is converted to a dB-scaled mel spectrogram, which is cut
    into half-overlapping windows of ``WINDOW_FRAMES`` frames; each window is
    rendered as a 3-channel float image.

    Args:
        df: Table with ``site`` and ``row_id`` columns, plus ``seconds`` for
            rows whose site is not ``"site_3"``. Rows are fetched with
            ``df.loc[idx]`` (assumes a default 0..n-1 index — TODO confirm
            against the caller).
        clip: 1-D audio signal. Sampled at 32 kHz according to the constant
            used in ``__getitem__``.
        img_size: Height, in pixels, every window image is resized to.
        melspectrogram_parameters: Keyword arguments forwarded to
            ``librosa.feature.melspectrogram``. ``None`` means no extra
            arguments.
    """

    # Window width in spectrogram frames. With sr=32000 and hop_length=256 this
    # corresponds to one second of audio (presumably the intended window —
    # revisit if hop_length changes).
    WINDOW_FRAMES = 626 // 5
    # Upper bound on windows for a 5-second segment (626 frames, half overlap).
    MAX_SEGMENT_WINDOWS = 9

    def __init__(self, df: pd.DataFrame, clip: np.ndarray,
                 img_size=224, melspectrogram_parameters=None):
        self.df = df
        self.clip = clip
        self.img_size = img_size
        # A fresh dict per instance instead of a shared mutable default.
        self.melspectrogram_parameters = (
            {} if melspectrogram_parameters is None else melspectrogram_parameters
        )

    def __len__(self):
        return len(self.df)

    def _melspec_db(self, y: np.ndarray, sr: int) -> np.ndarray:
        """Return the float32 dB-scaled mel spectrogram of ``y``."""
        # ``y`` is passed by keyword: recent librosa releases reject it positionally.
        melspec = librosa.feature.melspectrogram(
            y=y, sr=sr, **self.melspectrogram_parameters)
        return librosa.power_to_db(melspec).astype(np.float32)

    def _to_image(self, melspec: np.ndarray) -> np.ndarray:
        """Render one spectrogram window as a (3, img_size, width) float32 image in [0, 1]."""
        image = mono_to_color(melspec)
        height, width, _ = image.shape
        # Resize to the target height while keeping the aspect ratio.
        image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
        image = np.flipud(image)
        image = np.moveaxis(image, 2, 0)
        return (image / 255.0).astype(np.float32)

    def _window_images(self, melspec_db: np.ndarray, max_windows: int) -> list:
        """Cut ``melspec_db`` into half-overlapping windows and render each one.

        Only complete windows are produced, so all returned images share the
        same shape. The list is empty when the spectrogram is shorter than a
        single window.
        """
        n_frames = melspec_db.shape[-1]
        images = []
        for window_idx in range(max_windows):
            start = window_idx * self.WINDOW_FRAMES // 2
            end = start + self.WINDOW_FRAMES
            if end > n_frames:
                break
            images.append(self._to_image(melspec_db[:, start:end]))
        return images

    def __getitem__(self, idx: int):
        """Return ``(images, row_id, site)`` for row ``idx``.

        For ``"site_3"`` the whole clip is windowed. For any other site only
        the 5 seconds ending at the row's ``seconds`` value are used.
        ``images`` is the array of stacked window images.
        """
        SR = 32000
        sample = self.df.loc[idx, :]
        site = sample.site
        row_id = sample.row_id

        if site == "site_3":
            y = self.clip.astype(np.float32)
            melspec_db = self._melspec_db(y, SR)
            whole_windows = melspec_db.shape[-1] // self.WINDOW_FRAMES
            # 2 * k - 1 half-overlapping windows fit into k whole windows.
            images = self._window_images(melspec_db, 2 * whole_windows - 1)
            return np.asarray(images), row_id, site

        end_seconds = int(sample.seconds)
        start_seconds = int(end_seconds - 5)
        y = self.clip[SR * start_seconds:SR * end_seconds].astype(np.float32)
        melspec_db = self._melspec_db(y, SR)
        # Bounded by the actual frame count, so a segment cut short by the end
        # of the clip yields fewer windows instead of ragged or empty ones.
        images = self._window_images(melspec_db, self.MAX_SEGMENT_WINDOWS)
        return np.array(images), row_id, site

### model

* I forked this code from the authors' original implementation. [GitHub](https://github.com/zhanghang1989/ResNeSt)

In [None]:
"""
class Autopool(nn.Module):
    def __init__(self, input_size, ):
        super(Autopool, self).__init__()
        self.alpha = nn.Parameter(requires_grad= True)
        self.alpha.data = torch.ones([input_size], dtype=torch.float32, requires_grad= True, device=device)
        self.sigmoid_layer = nn.Sigmoid()
        #self.softmax_layer = nn.Softmax(dim=2)
        
        
    def forward(self,x):
        
        sigmoid_output = self.sigmoid_layer(x)
        alpa_mult_out = torch.mul(sigmoid_output, self.alpha)
        
        
        max_tensor = torch.max(alpa_mult_out,dim = 1)
        max_tensor_unsqueezed = max_tensor.values.unsqueeze(dim=1)
        #print('alpa_mult_out shape ',alpa_mult_out.shape)
        #print('max_tensor shape ',max_tensor.values.shape)
        
        softmax_numerator = torch.exp(alpa_mult_out.sub(max_tensor_unsqueezed))
        
        softmax_den =torch.sum(softmax_numerator, dim = 1)
        softmax_den = softmax_den.unsqueeze(dim=1)
        
        weights  = softmax_numerator/softmax_den
        
        final_out = torch.sum(torch.mul(sigmoid_output, weights),dim = 1)
        return final_out, sigmoid_output

        


class Yuvsub(nn.Module):   
    def __init__(self, args):
        super(Yuvsub, self).__init__()
        self.species = nn.Sequential(
            nn.Linear(2048, 1024), nn.ReLU(), nn.Dropout(p=0.2),
            nn.Linear(1024, 1024), nn.ReLU(), nn.Dropout(p=0.2),
            nn.Linear(1024, 264))
        
        
        
    def forward(self, GAP):
        
        #print('gap shape ',GAP.shape)
        #GAP = torch.flatten(GAP, 1)
        #print('gap shape ',GAP.shape)
        spe = self.species(GAP)
        
        #cnt = self.counter(GAP)
        return spe

class Yuvgru(nn.Module):
    def __init__(self, args):
        super(Yuvgru, self).__init__()
        self.gru_layer = torch.nn.GRU(input_size = 1024, hidden_size=args["gru_hidden_size"],num_layers = args["gru_layers"],
                                      dropout=0,bidirectional = args["gru_bidirectional"])
    def forward(self,x):
        Routput, hn = self.gru_layer(x)## not passing the hiddenstate and it will be default to 0
        return Routput, hn

class YuvNet(nn.Module):   
    def __init__(self, args):
        super(YuvNet, self).__init__()

        #print(args["name"] , args["params"]["pretrained"])
        self.model =ResNet(
                        Bottleneck, [3, 4, 6, 3],
                        radix=1, groups=1, bottleneck_width=64,
                        deep_stem=True, stem_width=32, avg_down=True,
                        avd=True, avd_first=True)
        del self.model.fc
        self.model.fc = Yuvsub(args)
        #self.Gru = Yuvgru(args)
        
        
        self.autopool = Autopool(264)
           
    def forward(self, x):
        batch_size, time_steps, C, H, W = x.size()
        c_in = x.view(batch_size * time_steps, C, H, W)
        
        #print('c_in shape ',c_in.shape)
        #print('c_in type ',c_in.dtype)
        spe = self.model(c_in)
        #print('shape of spe ',spe.shape)
        spe = spe.view(batch_size, time_steps, -1)
        final_output, sigmoid_output = self.autopool(spe)
        
        '''print('gap shape ',GAP.shape)
        GAP = torch.flatten(GAP, 1)
        print('gap shape ',GAP.shape)
        spe = self.species(GAP)
        cnt = self.counter(GAP)'''
        #return final_output, sigmoid_output
        return final_output, sigmoid_output


"""

class Autopool(nn.Module):
    """Softmax-weighted pooling of sigmoid activations over dimension 1.

    Each entry along the last dimension has a learnable scale ``alpha``
    (initialised to 1). The pooling weights are a softmax over dimension 1 of
    ``sigmoid(x) * alpha``; the pooled output is the weighted sum of
    ``sigmoid(x)`` over that dimension.

    Args:
        input_size: Size of the last dimension of the input (length of ``alpha``).
    """

    def __init__(self, input_size, ):
        super().__init__()
        # Created on the default device; moving the module with ``.to(...)``
        # moves the parameter too, so no module-level ``device`` is needed.
        self.alpha = nn.Parameter(torch.ones([input_size], dtype=torch.float32))
        self.sigmoid_layer = nn.Sigmoid()

    def forward(self, x):
        """Pool ``x`` over dimension 1.

        Args:
            x: Tensor whose dimension 1 is pooled away and whose last dimension
                matches ``alpha`` (YuvNet passes ``(batch, time_steps, classes)``).

        Returns:
            Tuple ``(final_out, sigmoid_output)``: the pooled tensor with
            dimension 1 removed, and the un-pooled ``sigmoid(x)``.
        """
        sigmoid_output = self.sigmoid_layer(x)
        # torch.softmax performs the max-subtracted (numerically stable) softmax.
        weights = torch.softmax(torch.mul(sigmoid_output, self.alpha), dim=1)
        final_out = torch.sum(torch.mul(sigmoid_output, weights), dim=1)
        return final_out, sigmoid_output

class Yuvsub(nn.Module):
    """Classification head: Linear(512, 512) -> ReLU -> Dropout(0.2) -> Linear(512, 264).

    ``args`` is accepted for interface compatibility but is not read.
    """

    def __init__(self, args):
        super().__init__()
        head_layers = [
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(512, 264),
        ]
        self.species = nn.Sequential(*head_layers)

    def forward(self, GAP):
        """Map 512-feature inputs to 264 raw class scores."""
        return self.species(GAP)

class Yuvgru(nn.Module):
    """Thin wrapper around a GRU that takes 1024-feature inputs.

    ``args`` must provide ``"gru_hidden_size"``, ``"gru_layers"`` and
    ``"gru_bidirectional"``.
    """

    def __init__(self, args):
        super().__init__()
        gru_options = dict(
            input_size=1024,
            hidden_size=args["gru_hidden_size"],
            num_layers=args["gru_layers"],
            dropout=0,
            bidirectional=args["gru_bidirectional"],
        )
        self.gru_layer = torch.nn.GRU(**gru_options)

    def forward(self, x):
        """Run the GRU on ``x`` and return ``(output, final_hidden_state)``."""
        # No initial hidden state is supplied, so the GRU starts from zeros.
        sequence_out, last_hidden = self.gru_layer(x)
        return sequence_out, last_hidden

class YuvNet(nn.Module):
    """ResNet-18 backbone with a ``Yuvsub`` head, pooled over windows by ``Autopool``.

    The backbone scores every window of a clip independently; ``Autopool``
    then combines the per-window scores into one score vector per clip.

    Args:
        args: Configuration mapping forwarded to ``Yuvsub``.
    """

    def __init__(self, args):
        super().__init__()

        # Randomly initialised backbone. Called without arguments so it works
        # both with torchvision releases that take ``pretrained`` and with
        # newer ones that take ``weights``.
        self.model = models.resnet18()
        # Replace the ImageNet classifier with the task-specific head.
        self.model.fc = Yuvsub(args)

        self.autopool = Autopool(264)

    def forward(self, x):
        """Score a batch of window stacks.

        Args:
            x: 5-D tensor ``(batch_size, time_steps, C, H, W)``.

        Returns:
            Tuple ``(final_output, sigmoid_output)`` as produced by
            ``Autopool``: the pooled per-clip output and the per-window
            sigmoid outputs.
        """
        batch_size, time_steps, C, H, W = x.size()
        # Fold the window axis into the batch axis for the 2-D backbone.
        # ``reshape`` also accepts non-contiguous input, unlike ``view``.
        c_in = x.reshape(batch_size * time_steps, C, H, W)

        spe = self.model(c_in)
        spe = spe.reshape(batch_size, time_steps, -1)
        final_output, sigmoid_output = self.autopool(spe)
        return final_output, sigmoid_output


In [None]:
# Prefer the GPU but fall back to the CPU so the notebook also runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoints of the four ensemble members, in the order they are returned.
_ENSEMBLE_WEIGHT_PATHS = (
    "../input/bird-fold1-150-v1/bird.pth",
    "../input/bird-label-v1/bird.pth",
    "../input/bird-fold-22/bird.pth",
    "../input/bird-light/bird.pth",
)


def _load_ensemble_member(args: tp.Dict, weights_path: str):
    """Build one ``YuvNet``, load ``weights_path`` into it and switch it to eval mode."""
    model = YuvNet(args)
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def get_model(args: tp.Dict):
    """Load the four ensemble models.

    Args:
        args: Configuration mapping passed to every ``YuvNet`` constructor.

    Returns:
        Tuple ``(model_1, model_2, model_3, model_4)``; every model is on
        ``device`` and in eval mode.
    """
    return tuple(
        _load_ensemble_member(args, weights_path)
        for weights_path in _ENSEMBLE_WEIGHT_PATHS
    )

In [None]:
# Build the four ensemble members; get_model returns them as a 4-tuple.
model_1, model_2, model_3, model_4 = get_model(model_config)

In [None]:
# Notebook preview: show the module structure of model_1.
model_1

In [None]:
# Notebook preview: show the module structure of model_2.
model_2

## Prediction loop

In [None]:
def _as_model_list(model) -> tp.List[nn.Module]:
    """Normalise ``model`` (one module or an iterable of modules) to a list."""
    if isinstance(model, nn.Module):
        return [model]
    return list(model)


def _ensemble_proba(models: tp.Sequence[nn.Module], batch: torch.Tensor) -> np.ndarray:
    """Return the mean of the flattened first outputs of ``models`` on ``batch``."""
    probas = []
    with torch.no_grad():
        for net in models:
            output, _ = net(batch)
            probas.append(output.detach().cpu().numpy().reshape(-1))
    return sum(probas) / float(len(probas))


def prediction_for_clip(test_df: pd.DataFrame,
                        clip: np.ndarray,
                        model: tp.Union[nn.Module, tp.Sequence[nn.Module]],
                        mel_params: dict,
                        threshold=0.5):
    """Predict the label string of every row of one audio clip.

    Args:
        test_df: rows of the test table that belong to this clip.
        clip: waveform of the clip, handed to ``TestDataset``.
        model: a single network or a sequence of networks forming an
            ensemble. Each network must return a 2-tuple whose first element
            is the clip-level prediction. The networks are expected to live
            on the device selected below (CUDA when available, else CPU).
        mel_params: mel-spectrogram parameters handed to ``TestDataset``.
        threshold: a class is reported when the averaged score is
            ``>= threshold``.

    Returns:
        Dict mapping ``row_id`` to a space-separated string of bird codes,
        or to ``"nocall"`` when no class reaches the threshold.

    Raises:
        ValueError: if ``model`` is an empty sequence.
    """
    models = _as_model_list(model)
    if not models:
        raise ValueError("prediction_for_clip needs at least one model")
    # NOTE(review): rows outside site_1/site_2 have always been scored with
    # only the first two ensemble members, presumably to save inference
    # time -- confirm this is intended before extending it to all models.
    long_clip_models = models[:2]

    dataset = TestDataset(df=test_df,
                          clip=clip,
                          img_size=224,
                          melspectrogram_parameters=mel_params)
    loader = data.DataLoader(dataset, batch_size=1, shuffle=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for net in models:
        net.eval()

    # Rows outside site_1/site_2 are processed in chunks of this many
    # segments to avoid prediction on one large batch.
    chunk_size = 32

    prediction_dict = {}
    for image, row_id, site in progress_bar(loader):
        # The loader uses batch_size=1, so take the single element.
        site = site[0]
        row_id = row_id[0]

        if site in {"site_1", "site_2"}:
            proba = _ensemble_proba(models, image.to(device))
            labels = np.argwhere(proba >= threshold).reshape(-1).tolist()
        else:
            # Drop the loader's batch axis, then feed each chunk of segments
            # as one sample (leading axis restored with unsqueeze).
            segments = image.squeeze(0)
            detected = set()
            for start in range(0, segments.size(0), chunk_size):
                chunk = segments[start:start + chunk_size].unsqueeze(0).to(device)
                proba = _ensemble_proba(long_clip_models, chunk)
                detected.update(np.argwhere(proba >= threshold).reshape(-1).tolist())
            # Sort so the label order does not depend on set iteration order.
            labels = sorted(detected)

        if labels:
            prediction_dict[row_id] = " ".join(INV_BIRD_CODE[idx] for idx in labels)
        else:
            prediction_dict[row_id] = "nocall"
    return prediction_dict

In [None]:
# Notebook display cell: preview the first rows of the test table.
test.head()

In [None]:
# Notebook display cell: largest value in the `seconds` column of the test table.
test.seconds.max()

In [None]:
def prediction(test_df: pd.DataFrame,
               test_audio: Path,
               model_config: dict,
               mel_params: dict,
               target_sr: int,
               threshold=0.5):
    """Run inference for every audio file referenced in ``test_df``.

    Args:
        test_df: test table; must provide an ``audio_id`` column.
        test_audio: directory containing ``<audio_id>.mp3`` files.
        model_config: configuration passed to ``get_model``.
        mel_params: mel-spectrogram parameters forwarded to
            ``prediction_for_clip``.
        target_sr: sampling rate the audio is resampled to on load.
        threshold: decision threshold forwarded to ``prediction_for_clip``.

    Returns:
        DataFrame with the columns ``row_id`` and ``birds``. It has no rows
        when ``test_df`` references no audio file.
    """
    unique_audio_id = test_df.audio_id.unique()
    if len(unique_audio_id) == 0:
        # Nothing to predict: skip loading the checkpoints and avoid
        # pd.concat raising on an empty list.
        return pd.DataFrame(columns=["row_id", "birds"])

    model = get_model(model_config)

    warnings.filterwarnings("ignore")
    prediction_dfs = []
    for audio_id in unique_audio_id:
        with timer(f"Loading {audio_id}"):
            clip, _ = librosa.load(test_audio / (audio_id + ".mp3"),
                                   sr=target_sr,
                                   mono=True,
                                   res_type="kaiser_fast")

        # A boolean mask instead of a string-built query keeps working when
        # an id contains a quote character.
        test_df_for_audio_id = test_df[
            test_df.audio_id == audio_id].reset_index(drop=True)
        with timer(f"Prediction on {audio_id}"):
            prediction_dict = prediction_for_clip(test_df_for_audio_id,
                                                  clip=clip,
                                                  model=model,
                                                  mel_params=mel_params,
                                                  threshold=threshold)
        prediction_dfs.append(pd.DataFrame({
            "row_id": list(prediction_dict.keys()),
            "birds": list(prediction_dict.values()),
        }))

    return pd.concat(prediction_dfs, axis=0, sort=False).reset_index(drop=True)

## Prediction

In [None]:
# Run inference over the whole test table. The decision threshold is set to
# 0.4 here, below the function's default of 0.5.
submission = prediction(test_df=test,
                        test_audio=TEST_AUDIO_DIR,
                        model_config=model_config,
                        mel_params=melspectrogram_parameters,
                        target_sr=TARGET_SR,
                        threshold=0.4)
# Write the predictions to disk without the DataFrame index column.
submission.to_csv("submission.csv", index=False)

In [None]:
# Notebook display cell: show the resulting submission table.
submission

## EOF