In [None]:
!pip uninstall torch --y
!pip uninstall torchaudio --y
!pip install ../input/pytorch-160-with-torchvision-070/torch-1.6.0cu101-cp37-cp37m-linux_x86_64.whl
!pip install ../input/torchaudio/torchaudio-0.6.0-cp37-cp37m-manylinux1_x86_64.whl

In [None]:
!pip install ../input/timm-wheel/*.whl

### import libraries

In [None]:
import os
import gc
import time
import math
import shutil
import random
import warnings
warnings.filterwarnings('ignore',".*PySoundFile")
import os
from joblib import delayed, Parallel

import cv2
import librosa
import audioread
import soundfile as sf

import numpy as np
import pandas as pd

from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import DataLoader
import timm

# Let wide/long frames (e.g. the submission table) display without truncation.
pd.set_option("display.max_rows", 500)
pd.set_option("display.max_columns", 500)

### read data

In [None]:
# Competition input locations. Keep the trailing slashes: file names are appended
# to these with plain string concatenation further down.
RAW_DATA = "../input/birdsong-recognition/"
TRAIN_AUDIO_DIR = RAW_DATA + "train_audio/"
TEST_AUDIO_DIR = RAW_DATA + "test_audio/"

In [None]:
# Training metadata (not referenced by the inference cells in this notebook section).
train = pd.read_csv("{}train.csv".format(RAW_DATA))

In [None]:
# If the competition's test_audio folder exists, run as a submission; otherwise fall
# back to the "birdcall-check" sample data so the notebook can still be exercised.
if os.path.exists(TEST_AUDIO_DIR):
    print("run submission")
    test = pd.read_csv("../input/birdsong-recognition/test.csv")
else:
    print("Run check")
    TEST_AUDIO_DIR = "../input/birdcall-check/test_audio/"
    test = pd.read_csv("../input/birdcall-check/test.csv")

### set parameters

## Definition

### Dataset

For `site_3`, I use the same procedure as for `site_1` and `site_2`: crop a 5-second window out of the recording and predict on that short clip.
The only difference is that for `site_3` I cut consecutive, non-overlapping 5-second windows from the start to the end of the recording, predict on every window, and then aggregate the per-window predictions (element-wise maximum) into a single prediction for the whole recording.

In [None]:
# Bird species code -> class index. Index i is column i of every model's 264-way
# output (num_class=264 below); predictions are decoded back through INV_BIRD_CODE.
# 'nocall' maps to -1, which is never an output column index.
BIRD_CODE = {
    'aldfly': 0, 'ameavo': 1, 'amebit': 2, 'amecro': 3, 'amegfi': 4,
    'amekes': 5, 'amepip': 6, 'amered': 7, 'amerob': 8, 'amewig': 9,
    'amewoo': 10, 'amtspa': 11, 'annhum': 12, 'astfly': 13, 'baisan': 14,
    'baleag': 15, 'balori': 16, 'banswa': 17, 'barswa': 18, 'bawwar': 19,
    'belkin1': 20, 'belspa2': 21, 'bewwre': 22, 'bkbcuc': 23, 'bkbmag1': 24,
    'bkbwar': 25, 'bkcchi': 26, 'bkchum': 27, 'bkhgro': 28, 'bkpwar': 29,
    'bktspa': 30, 'blkpho': 31, 'blugrb1': 32, 'blujay': 33, 'bnhcow': 34,
    'boboli': 35, 'bongul': 36, 'brdowl': 37, 'brebla': 38, 'brespa': 39,
    'brncre': 40, 'brnthr': 41, 'brthum': 42, 'brwhaw': 43, 'btbwar': 44,
    'btnwar': 45, 'btywar': 46, 'buffle': 47, 'buggna': 48, 'buhvir': 49,
    'bulori': 50, 'bushti': 51, 'buwtea': 52, 'buwwar': 53, 'cacwre': 54,
    'calgul': 55, 'calqua': 56, 'camwar': 57, 'cangoo': 58, 'canwar': 59,
    'canwre': 60, 'carwre': 61, 'casfin': 62, 'caster1': 63, 'casvir': 64,
    'cedwax': 65, 'chispa': 66, 'chiswi': 67, 'chswar': 68, 'chukar': 69,
    'clanut': 70, 'cliswa': 71, 'comgol': 72, 'comgra': 73, 'comloo': 74,
    'commer': 75, 'comnig': 76, 'comrav': 77, 'comred': 78, 'comter': 79,
    'comyel': 80, 'coohaw': 81, 'coshum': 82, 'cowscj1': 83, 'daejun': 84,
    'doccor': 85, 'dowwoo': 86, 'dusfly': 87, 'eargre': 88, 'easblu': 89,
    'easkin': 90, 'easmea': 91, 'easpho': 92, 'eastow': 93, 'eawpew': 94,
    'eucdov': 95, 'eursta': 96, 'evegro': 97, 'fiespa': 98, 'fiscro': 99,
    'foxspa': 100, 'gadwal': 101, 'gcrfin': 102, 'gnttow': 103, 'gnwtea': 104,
    'gockin': 105, 'gocspa': 106, 'goleag': 107, 'grbher3': 108, 'grcfly': 109,
    'greegr': 110, 'greroa': 111, 'greyel': 112, 'grhowl': 113, 'grnher': 114,
    'grtgra': 115, 'grycat': 116, 'gryfly': 117, 'haiwoo': 118, 'hamfly': 119,
    'hergul': 120, 'herthr': 121, 'hoomer': 122, 'hoowar': 123, 'horgre': 124,
    'horlar': 125, 'houfin': 126, 'houspa': 127, 'houwre': 128, 'indbun': 129,
    'juntit1': 130, 'killde': 131, 'labwoo': 132, 'larspa': 133, 'lazbun': 134,
    'leabit': 135, 'leafly': 136, 'leasan': 137, 'lecthr': 138, 'lesgol': 139,
    'lesnig': 140, 'lesyel': 141, 'lewwoo': 142, 'linspa': 143, 'lobcur': 144,
    'lobdow': 145, 'logshr': 146, 'lotduc': 147, 'louwat': 148, 'macwar': 149,
    'magwar': 150, 'mallar3': 151, 'marwre': 152, 'merlin': 153, 'moublu': 154,
    'mouchi': 155, 'moudov': 156, 'norcar': 157, 'norfli': 158, 'norhar2': 159,
    'normoc': 160, 'norpar': 161, 'norpin': 162, 'norsho': 163, 'norwat': 164,
    'nrwswa': 165, 'nutwoo': 166, 'olsfly': 167, 'orcwar': 168, 'osprey': 169,
    'ovenbi1': 170, 'palwar': 171, 'pasfly': 172, 'pecsan': 173, 'perfal': 174,
    'phaino': 175, 'pibgre': 176, 'pilwoo': 177, 'pingro': 178, 'pinjay': 179,
    'pinsis': 180, 'pinwar': 181, 'plsvir': 182, 'prawar': 183, 'purfin': 184,
    'pygnut': 185, 'rebmer': 186, 'rebnut': 187, 'rebsap': 188, 'rebwoo': 189,
    'redcro': 190, 'redhea': 191, 'reevir1': 192, 'renpha': 193, 'reshaw': 194,
    'rethaw': 195, 'rewbla': 196, 'ribgul': 197, 'rinduc': 198, 'robgro': 199,
    'rocpig': 200, 'rocwre': 201, 'rthhum': 202, 'ruckin': 203, 'rudduc': 204,
    'rufgro': 205, 'rufhum': 206, 'rusbla': 207, 'sagspa1': 208, 'sagthr': 209,
    'savspa': 210, 'saypho': 211, 'scatan': 212, 'scoori': 213, 'semplo': 214,
    'semsan': 215, 'sheowl': 216, 'shshaw': 217, 'snobun': 218, 'snogoo': 219,
    'solsan': 220, 'sonspa': 221, 'sora': 222, 'sposan': 223, 'spotow': 224,
    'stejay': 225, 'swahaw': 226, 'swaspa': 227, 'swathr': 228, 'treswa': 229,
    'truswa': 230, 'tuftit': 231, 'tunswa': 232, 'veery': 233, 'vesspa': 234,
    'vigswa': 235, 'warvir': 236, 'wesblu': 237, 'wesgre': 238, 'weskin': 239,
    'wesmea': 240, 'wessan': 241, 'westan': 242, 'wewpew': 243, 'whbnut': 244,
    'whcspa': 245, 'whfibi': 246, 'whtspa': 247, 'whtswi': 248, 'wilfly': 249,
    'wilsni1': 250, 'wiltur': 251, 'winwre3': 252, 'wlswar': 253, 'wooduc': 254,
    'wooscj2': 255, 'woothr': 256, 'y00475': 257, 'yebfly': 258, 'yebsap': 259,
    'yehbla': 260, 'yelwar': 261, 'yerwar': 262, 'yetvir': 263,'nocall':-1,
}

# Class index -> species code (inverse of BIRD_CODE; also maps -1 back to 'nocall').
INV_BIRD_CODE = {v: k for k, v in BIRD_CODE.items()}

In [None]:
class Dataset_Test:
    """Map-style dataset yielding all 5-second windows of one test recording per item.

    Item ``idx`` corresponds to the ``idx``-th unique ``audio_id`` in ``df`` and returns
    ``(windows, row_ids, site)``:

    * ``windows`` -- float32 tensor ``[k, sample_rate*duration]``; windows shorter than
      that are zero-padded symmetrically (the extra sample, if any, goes on the right).
    * ``row_ids`` -- submission row ids: one per window for site_1/site_2
      (``<site>_<audio_id>_<seconds>``), a single ``site_3_<audio_id>`` for site_3.
    * ``site`` -- the ``site`` value of the recording's first row in ``df``.

    site_1 / site_2: one window per row of ``df``, ending at ``row.seconds``.
    site_3: consecutive non-overlapping windows covering the whole recording; a trailing
    remainder of <= 2.5 s is dropped unless it is the only window.

    Args:
        df: test metadata with at least ``audio_id``, ``site`` and ``seconds`` columns.
        data_dir: directory prefix (string-concatenated, so include the trailing slash).
        sample_rate: decode sample rate passed to ``librosa.load``.
        duration: window length in seconds.
        waveform_transforms: stored on the instance but not applied by this class.
    """
    def __init__(
            self,
            df,
            data_dir,
            sample_rate=32000,
            duration=5,
            waveform_transforms=None):
        self.data_dir=data_dir
        self.clip_frames=sample_rate*duration
        self.df=df
        self.sample_rate=sample_rate
        self.duration=duration
        # NOTE(review): accepted for interface compatibility; nothing here applies it.
        self.waveform_transforms=waveform_transforms
        self.audio_ids=df.audio_id.unique()

    def __len__(self):
        return len(self.audio_ids)

    @staticmethod
    def _center_pad(window, length):
        """Zero-pad ``window`` symmetrically up to ``length`` samples (right gets the odd one)."""
        missing = length - len(window)
        if missing <= 0:
            return window
        left = missing // 2
        return np.pad(window, [left, missing - left], constant_values=0)

    def _site3_starts(self, total_frames):
        """Start offsets (in samples) of the non-overlapping site_3 windows."""
        starts = list(range(0, total_frames, self.clip_frames))
        # Drop a short trailing remainder (<= 2.5 s) as long as other windows remain.
        if len(starts) > 1 and total_frames - starts[-1] <= 2.5 * self.sample_rate:
            starts.pop(-1)
        # An empty decode would yield no windows and np.stack([]) would raise;
        # fall back to a single all-zero window so the recording still gets a row.
        return starts or [0]

    def __getitem__(self, idx: int):
        audio_id=self.audio_ids[idx]
        clip, _ = librosa.load(self.data_dir+'{}.mp3'.format(audio_id),
                                sr=self.sample_rate,
                                mono=True,
                                res_type="kaiser_fast")
        clip=clip.astype(np.float32)
        total_frames=len(clip)
        target_df=self.df[self.df.audio_id==audio_id]
        site=target_df.site.iloc[0]

        all_5s_clips=[]
        row_ids=[]
        if site=='site_3':
            for start in self._site3_starts(total_frames):
                window=clip[start:start+self.clip_frames]
                all_5s_clips.append(self._center_pad(window,self.clip_frames))
            # site_3 is scored once per recording, so only one row id.
            row_ids.append("site_3_{}".format(audio_id))
        else:
            for i in range(len(target_df)):
                row=target_df.iloc[i]
                end=int(row.seconds*self.sample_rate)
                # max(0, ...) avoids a negative start wrapping around to the clip's end.
                window=clip[max(0,end-self.clip_frames):min(total_frames,end)]
                all_5s_clips.append(self._center_pad(window,self.clip_frames))
                row_ids.append("{}_{}_{}".format(site,audio_id,int(row.seconds)))
        all_5s_clips=np.stack(all_5s_clips,axis=0)
        return torch.tensor(all_5s_clips),row_ids,site

### Model

In [None]:
class Identity(nn.Module):
    """No-op layer: forward returns its input object unchanged.

    Used by AttBlock as the 'linear' (non-sigmoid) activation.
    """
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        output = x
        return output


class AttBlock(nn.Module):
    '''
    Attention pooling over the time axis.

    Input:  x of shape [b, n_in, t] (t = time steps after the backbone's stride).
    Output: (pooled [b, n_out], weights [b, 1, t], per_step [b, n_out, t]), where
    weights is a softmax over t and pooled = sum_t weights * per_step.

    Submodule names (att.*, cla) are part of the checkpoint format; keep them stable.
    '''
    def __init__(self, n_in, n_out, activation='linear', hidden=512,dropout=0):
        super(AttBlock, self).__init__()
        layers = [
            nn.Conv1d(n_in, hidden, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Tanh(),
        ]
        if dropout > 0:
            layers.append(nn.Dropout(dropout))
        # final 1x1 conv -> one attention logit per time step, [b,1,t]
        layers.append(nn.Conv1d(hidden, 1, kernel_size=1, stride=1, padding=0, bias=True))
        self.att = nn.Sequential(*layers)
        # per-step class scores, [b,n_out,t]
        self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True)
        if activation == 'sigmoid':
            self.nonlinear_transform = nn.Sigmoid()
        else:
            self.nonlinear_transform = Identity()

    def forward(self, x):
        weights = torch.softmax(self.att(x), dim=-1)
        per_step = self.nonlinear_transform(self.cla(x))
        pooled = (weights * per_step).sum(dim=2)
        return pooled, weights, per_step
    

def interpolate(x, ratio):
    """Upsample along time by repeating every time step `ratio` times in a row.

    Used to undo the temporal downsampling of a CNN backbone so that per-segment
    scores line up (approximately) with input frames.

    Args:
      x: (batch_size, time_steps, classes_num) tensor
      ratio: int, number of consecutive copies of each time step

    Returns:
      (batch_size, time_steps * ratio, classes_num) tensor
    """
    batch_size, time_steps, classes_num = x.shape
    # (b, t, 1, c) -> (b, t, ratio, c): copies of each step sit next to each other
    tiled = x.unsqueeze(2).repeat(1, 1, ratio, 1)
    return tiled.reshape(batch_size, time_steps * ratio, classes_num)


def pad_framewise_output(framewise_output, frames_num):
    """Make framewise_output exactly `frames_num` frames long along dim 1.

    Shorter inputs are padded by repeating their last frame. Longer inputs are
    trimmed to the first `frames_num` frames (the result may be a view). Previously
    `repeat` got a negative count and crashed in that case.

    Args:
      framewise_output: (batch_size, time_steps, classes_num)
      frames_num: int, target number of frames

    Returns:
      (batch_size, frames_num, classes_num)

    Raises:
      ValueError: if framewise_output has no frames but frames_num > 0
        (there is no last frame to replicate).
    """
    current = framewise_output.shape[1]
    if frames_num <= current:
        return framewise_output[:, :frames_num, :]
    if current == 0:
        raise ValueError("cannot pad an empty framewise_output to {} frames".format(frames_num))
    # Replicate the final frame to fill the missing tail.
    pad = framewise_output[:, -1:, :].repeat(1, frames_num - current, 1)
    return torch.cat((framewise_output, pad), dim=1)

class SED_Model_ATT(nn.Module):
    """Sound-event-detection model: timm CNN backbone + attention pooling over time.

    forward(x) returns a dict with
      'clipwise_output':  [b, num_class] attention-pooled sigmoid scores, clamped to [0, 1]
      'framewise_output': [b, T * interpolate_ratio, num_class] per-step scores, each step
                          repeated interpolate_ratio times (not padded to the input length)
      'att':              [b, 1, T] attention weights (softmax over T)
    where T is the backbone's time length. This assumes forward_features returns
    [b, c, mel, T] (per the original shape comment). TODO confirm for each arch.

    Extra **kwargs are forwarded to timm.create_model.
    """
    def __init__(self,arch,num_class=264,dropout=0.4,dropout_att=0.25,**kwargs):
        super().__init__()
        self.base_model=timm.create_model(arch,pretrained=False,**kwargs)
        feature_dim=self.base_model.classifier.in_features
        # Each backbone time step is expanded to this many frames in framewise_output.
        self.interpolate_ratio=32
        self.dropout = nn.Dropout(p=dropout)
        self.att=AttBlock(feature_dim,num_class,activation='sigmoid',dropout=dropout_att)

    def forward(self,x):
        x=self.base_model.forward_features(x)  # [b,c,Mel,T]
        x=x.mean(dim=2)                         # average over the mel axis -> [b,c,T]
        x=self.dropout(x)

        clipwise_output,att,segmentwise_output=self.att(x.float())
        clipwise_output=torch.clamp(clipwise_output,0.0,1.0)

        segmentwise_output=segmentwise_output.transpose(1,2)  # [b,T,class]
        # NOTE(review): pad_framewise_output is not applied, so the frame count is
        # T*interpolate_ratio rather than the input's time length.
        framewise_output = interpolate(segmentwise_output, self.interpolate_ratio)

        return {
            'framewise_output': framewise_output,
            'clipwise_output': clipwise_output,
            'att': att,
        }

In [None]:
def gem(x, p=3, eps=1e-6):
    """Generalized-mean (GeM) pooling over the last two dims of x.

    Computes (mean(clamp(x, min=eps) ** p)) ** (1/p) over the full H x W window,
    so a [b, c, H, W] input becomes [b, c, 1, 1]. p=1 is average pooling; larger
    p moves towards max pooling.
    """
    window = (x.size(-2), x.size(-1))
    powered = x.clamp(min=eps).pow(p)
    pooled = F.avg_pool2d(powered, window)
    return pooled.pow(1. / p)

class GeM(nn.Module):
    """GeM pooling layer with a learnable exponent `p` (a 1-element Parameter)."""
    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return gem(x, p=self.p, eps=self.eps)

    def __repr__(self):
        p_value = self.p.data.tolist()[0]
        return '{}(p={:.4f}, eps={})'.format(self.__class__.__name__, p_value, self.eps)

class AdaptiveConcatPool2d(nn.Module):
    """Concatenate average-style and max pooling along channels, then flatten.

    For the default sz=(1, 1), a [b, c, H, W] input gives [b, 2c]: first the c
    averaged (or GeM-pooled when gem=True) features, then the c max-pooled ones.
    """
    def __init__(self, sz=None,gem=False):
        super().__init__()
        output_size = sz if sz else (1,1)
        if gem:
            self.ap = GeM()
        else:
            self.ap = nn.AdaptiveAvgPool2d(output_size)
        self.mp = nn.AdaptiveMaxPool2d(output_size)

    def forward(self, x):
        both = torch.cat([self.ap(x), self.mp(x)], 1)
        return both.view(x.size(0), -1)

class AttMaxPool(nn.Module):
    '''
    Attention-weighted average pooling concatenated with max pooling over time.

    Input:  [b, c, t]  (t = time steps after the backbone's stride)
    Output: [b, 2c]    (attention-pooled features first, then max-pooled)

    During training with drop_time > 0, attention weights are randomly dropped and
    renormalised to sum to 1 over time. The renormalisation is guarded so a sample
    whose weights are all dropped gets zeros instead of NaN.
    '''
    def __init__(self, n_in,hidden=512,dropout=0,drop_time=0.):
        super().__init__()
        att_module = [
            nn.Conv1d(n_in, hidden, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Tanh(),
        ]
        if dropout>0:
            att_module.append(nn.Dropout(dropout))

        att_module.append(nn.Conv1d(hidden,1,kernel_size=1,stride=1,padding=0,bias=True)) #[b,1,t]
        self.att=nn.Sequential(*att_module)
        self.drop_time=drop_time

    def forward(self, x):
        # x: (n_samples, n_in, n_time)
        mp=torch.max(x,dim=2)[0]
        A=torch.softmax(self.att(x),dim=-1)
        if self.drop_time>0 and self.training:
            A=nn.functional.dropout(A,p=self.drop_time,training=True)
            # Clamp the denominator: if every weight of a sample was dropped its sum is 0.
            A=A/A.sum(dim=2,keepdim=True).clamp(min=1e-8)
        x = torch.sum(A * x, dim=2)
        return torch.cat([x,mp],dim=1)

class ModelWarper(nn.Module):
    """Image-style classifier over spectrograms.

    Pipeline: timm backbone features -> AdaptiveConcatPool2d (avg || max, flattened)
    -> dropout -> linear head producing `num_class` logits. Attribute names
    (basemodel, pooling, dropout, fc) are the checkpoint's state_dict prefixes.
    """
    def __init__(self,arch,dropout=0.5,num_class=264):
        super().__init__()
        self.basemodel = timm.create_model(arch, pretrained=False)
        n_features = self.basemodel.get_classifier().in_features
        self.pooling = AdaptiveConcatPool2d()
        self.dropout = nn.Dropout(dropout)
        # concat pooling doubles the feature width
        self.fc = nn.Linear(2 * n_features, out_features=num_class)

    def forward(self,x):
        feats = self.basemodel.forward_features(x)
        pooled = self.pooling(feats)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(self.dropout(flat))
    
class ModelWarper_Att(nn.Module):
    """Classifier: backbone -> mean over dim 2 -> AttMaxPool over time -> dropout -> linear.

    Args:
        arch: timm model name, or 'PANN_res54' for the PANN ResNet54 backbone.
        dropout: dropout before the final linear layer.
        dropout_att: dropout inside AttMaxPool's attention branch.
        num_class: number of output logits.
        drop_time: AttMaxPool attention-weight dropout (training only).
        **kwargs: forwarded to timm.create_model (unused for 'PANN_res54').
    """
    def __init__(self,arch,dropout=0.5,dropout_att=0.25,num_class=264,drop_time=0,**kwargs):
        super().__init__()
        if arch=='PANN_res54':
            # Absolute import: a notebook's __main__ has no parent package, so the
            # relative form `from .pann_backbone import ResNet54` cannot resolve here.
            # NOTE(review): assumes pann_backbone.py is importable (e.g. from the working dir).
            from pann_backbone import ResNet54
            self.basemodel=ResNet54()
            # Deserialize on CPU so loading does not depend on the device the checkpoint
            # was saved from; load_state_dict copies the values into the module.
            ckpt=torch.load("./model/ResNet54.pth", map_location="cpu")
            # strict=False: missing/unexpected keys are silently ignored.
            self.basemodel.load_state_dict(ckpt['model'],strict=False)
            features=2048
        else:
            self.basemodel=timm.create_model(arch,pretrained=False,**kwargs)
            features=self.basemodel.get_classifier().in_features

        self.pooling=AttMaxPool(features,dropout=dropout_att,drop_time=drop_time)
        self.dropout=nn.Dropout(dropout)
        # AttMaxPool concatenates attention- and max-pooled features -> 2*features.
        self.fc=nn.Linear(2*features,out_features=num_class)

    def forward(self,x):
        x=self.basemodel.forward_features(x)  # presumably [b, c, freq, t] -- TODO confirm per backbone
        x=x.mean(dim=2)                        # -> [b, c, t]
        x=self.pooling(x)                      # -> [b, 2c]
        x=self.dropout(x)
        return self.fc(x)

In [None]:
class SpecNorm(nn.Module):
    """Min-max normalise each (sample, channel) plane of a [b, c, m, t] tensor to ~[0, 1].

    eps is added to the range so constant planes map to 0 instead of dividing by zero.
    """
    def __init__(self,eps=1e-6):
        super().__init__()
        self.eps=eps

    def forward(self,x):
        b,c,m,t=x.shape
        # reshape (not view) so non-contiguous inputs (e.g. transposed) also work.
        flat=x.reshape(b*c,-1)  # [b*c, m*t]
        _min = flat.min(dim=1,keepdim=True)[0]
        _max = flat.max(dim=1,keepdim=True)[0]
        flat=(flat-_min)/(_max-_min+self.eps)
        return flat.view(b,c,m,t)

class MelTransformer(nn.Module):
    """Waveform batch [b, n] -> 3-channel normalised log-mel image [b, 3, n_mels, frames].

    Channel i is a MelSpectrogram computed with n_fft[i] (only the first three
    entries are used). All three share the other settings. The stack is converted
    to dB and then min-max normalised per (sample, channel) with SpecNorm. Extra
    **kwargs go to every MelSpectrogram. The whole forward pass runs under no_grad.
    """
    def __init__(self,sample_rate=32000,
                 n_fft=[1024,1024,1024],
                 hop_length=313,
                 n_mels=224,
                 f_min=20,
                 f_max=16000,
                 **kwargs):
        super().__init__()

        def make_mel(window_size):
            # Identical settings for every channel except the FFT size.
            return torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=window_size,
                hop_length=hop_length,
                f_min=f_min,
                f_max=f_max,
                n_mels=n_mels,
                **kwargs
            )

        self.m1 = make_mel(n_fft[0])
        self.m2 = make_mel(n_fft[1])
        self.m3 = make_mel(n_fft[2])
        self.amp2db = torchaudio.transforms.AmplitudeToDB()
        self.norm = SpecNorm()

    def forward(self,x):
        with torch.no_grad():
            channels = [self.m1(x), self.m2(x), self.m3(x)]
            stacked = torch.stack(channels, dim=1)  # [b,3,m,t]
            return self.norm(self.amp2db(stacked))


import torchaudio
import torch.nn.functional as F

def resize(x,size=(256,512)):
    """Bilinearly resize a 4-D [b, c, h, w] tensor to size=(h, w).

    align_corners=False is PyTorch's default for bilinear mode, so results are
    unchanged. Passing it explicitly stops torch 1.6 (the version installed at the
    top) from emitting a UserWarning on every call.
    """
    return F.interpolate(x,size=size,mode='bilinear',align_corners=False)

def randomCrop(x,height=224):
    """Randomly crop `height` consecutive rows along dim 1 of a [3, h, w] tensor.

    Every start in [0, h - height] is possible. np.random.randint's upper bound is
    exclusive, hence the +1; without it the last start was unreachable and
    h == height raised.

    Raises:
        ValueError: if h < height.
    """
    max_start=x.shape[1]-height
    if max_start<0:
        raise ValueError("cannot crop {} rows from a tensor with {}".format(height,x.shape[1]))
    start=np.random.randint(0,max_start+1)
    return x[:,start:start+height,:]

def centerCrop(x,height=224):
    """Crop `height` rows centred along dim 1 of a [3, h, w] tensor (floor-biased offset)."""
    offset = (x.shape[1] - height) // 2
    stop = offset + height
    return x[:, offset:stop, :]

class SpecAugmenter:
    """Resize, crop and (training-only) SpecAugment-style masking for spectrogram batches.

    Args:
        resize: falsy -> no resize; int -> resize dim 2 (height) to that value and
            keep dim 3; tuple (h, w) -> resize both.
        crop_height: if truthy, crop this many rows along each sample's dim 1
            (centre crop when valid=True, random crop otherwise).
        valid: True for deterministic evaluation. False also applies, each with
            probability 0.5, a mask along each sample's dim 1 and dim 2, of width up
            to 1/8 of that axis.

    Input/output: [b, c, h, w] tensors; runs under torch.no_grad.
    """
    def __init__(self,resize=(256,512),crop_height=224,valid=False):
        self.resize_size=resize
        self.crop_height=crop_height
        self.valid=valid

    @torch.no_grad()
    def __call__(self,specs):
        if self.resize_size:
            if isinstance(self.resize_size, int):
                # int: resize height only, keep the time axis length
                target=(self.resize_size,specs.shape[3])
            else:
                target=self.resize_size
            specs=resize(specs,target)
        result=[]
        for x in specs:
            if self.valid:
                if self.crop_height:
                    x=centerCrop(x,self.crop_height)
            else:
                if self.crop_height:
                    x=randomCrop(x,self.crop_height)
                # torchaudio.functional is the public home of mask_along_axis;
                # torchaudio.transforms.F was only a private import alias.
                if random.random()<0.5:
                    x=torchaudio.functional.mask_along_axis(x,x.shape[1]//8,0,axis=1)
                if random.random() < 0.5:
                    x=torchaudio.functional.mask_along_axis(x,x.shape[2]//8,0,axis=2)
            result.append(x)
        return torch.stack(result,dim=0)

In [None]:
def _mel_config(hop_length, n_mels):
    """Mel settings shared by every pipeline; only hop length and mel-bin count vary."""
    return {
        "sample_rate": 32000,
        "n_fft": [1024, 1024, 1024],
        "hop_length": hop_length,
        "n_mels": n_mels,
        "f_min": 20,
        "f_max": 16000,
    }


mel_config = _mel_config(hop_length=313, n_mels=128)
mel_config_2 = _mel_config(hop_length=209, n_mels=128)
mel_config_3 = _mel_config(hop_length=313, n_mels=96)

# Spectrogram front-ends keyed by the 'preprocess' name that model configs refer to.
# Each pairs a MelTransformer (on GPU) with a deterministic SpecAugmenter
# (valid=True, int resize: only the mel axis is resized, no crop, no masking).
preprocessor = dict()
for _pp_name, _pp_mel, _pp_resize in (
    ('standard', mel_config, 256),
    ('big', mel_config_2, 320),
    ('small', mel_config_3, 224),
):
    transformer = MelTransformer(**_pp_mel).cuda()
    augmenter = SpecAugmenter(valid=True, resize=_pp_resize, crop_height=None)
    preprocessor[_pp_name] = {
        'transformer': transformer,
        'augmenter': augmenter,
    }




In [None]:
# Ensemble members: (weights directory, timm backbone / file prefix, preprocessing key, folds).
# Every member uses ModelWarper and the epoch-9 checkpoint of its fold.
# The regnety_040 group lists folds 0-3 only.
_ENSEMBLE_SPEC = [
    ("../input/bird-resnest50-mp3ft-5fold", "resnest50d_1s4x24d", "standard", range(5)),
    ("../input/birdcall-bigsize", "resnest50d_1s4x24d", "big", range(5)),
    ("../input/regnety40-224", "regnety_040", "small", range(4)),
]

model_configs = [
    {
        'arch': ModelWarper,
        'preprocess': preprocess,
        'path': "{}/{}_fold{}_epoch9.pth".format(weights_dir, base, fold),
        'base': base,
    }
    for weights_dir, base, preprocess, folds in _ENSEMBLE_SPEC
    for fold in folds
]

In [None]:
# Instantiate every ensemble member, load its fold checkpoint, move it to the GPU and
# switch to eval mode. pplist[i] is the preprocessing key for modellist[i].
modellist=[]
pplist=[]
for config in model_configs:
    model=config['arch'](config['base'],num_class=264)
    print("Loading",config['path'])
    # map_location='cpu': deserialize on the host regardless of the device the weights
    # were saved from. If the checkpoints hold GPU tensors, this avoids a second,
    # temporary GPU copy next to the model. load_state_dict copies into the module,
    # which then moves to the GPU.
    ckpt=torch.load(config['path'],map_location='cpu')['state_dict']
    model.load_state_dict(ckpt)
    model.cuda()
    model.eval()
    modellist.append(model)
    pplist.append(config['preprocess'])

print(len(modellist))
print(pplist)

In [None]:
# Release unused cached blocks held by PyTorch's CUDA caching allocator after model loading.
torch.cuda.empty_cache()

In [None]:
!nvidia-smi

In [None]:
# One item per test recording: all of its 5 s windows plus their submission row ids.
dataset = Dataset_Test(
    test,
    TEST_AUDIO_DIR,
    sample_rate=32000,
    duration=5,
)
# batch_size=1: each item already holds a variable number of windows, so the loop
# below strips the batch dimension again.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    drop_last=False,
    pin_memory=True,
)

## Prediction loop

In [None]:
# A class is reported only if its ensemble probability exceeds THRESHOLD.
THRESHOLD = 0.5
# Maximum number of species reported per submission row, by recording site.
MAX_PRED = dict(site_1=3, site_2=3, site_3=10)

In [None]:
# Time-shift test-time augmentation: every window is also scored after shifting it
# TTA_SHIFT samples earlier and later (zero-filled), and the three scores are averaged.
TTA_SHIFT = 155  # samples (~4.8 ms at the 32 kHz rate the dataset decodes to)


def _time_shift(clips, offset):
    """Return a zero-filled copy of `clips` ([k, n] waveforms) shifted along dim 1.

    offset > 0 moves audio earlier (drops the first `offset` samples, zero tail);
    offset < 0 moves it later (zero head, drops the last `-offset` samples).
    """
    shifted = torch.zeros_like(clips)
    if offset > 0:
        shifted[:, :-offset] = clips[:, offset:]
    elif offset < 0:
        shifted[:, -offset:] = clips[:, :offset]
    else:
        shifted.copy_(clips)
    return shifted


def _decode_labels(clip_pred, threshold, max_pred):
    """Turn one probability vector into a submission string.

    Returns up to `max_pred` species codes whose probability exceeds `threshold`,
    most confident first and space-separated, or "nocall" if none qualify.
    """
    pos_idx = np.where(clip_pred > threshold)[0]
    order = np.argsort(clip_pred[pos_idx])[::-1][:max_pred]
    pos_idx = pos_idx[order]
    if len(pos_idx) == 0:
        return "nocall"
    return " ".join(INV_BIRD_CODE[code] for code in pos_idx)


rows=[]
result=[]
for clips,row_ids,site in tqdm(dataloader):
    clips=clips.cuda()[0]  # strip the batch dim (batch_size=1): [k, n_samples]
    site_name=site[0]
    variants=(clips, _time_shift(clips, TTA_SHIFT), _time_shift(clips, -TTA_SHIFT))

    with torch.no_grad():
        # Spectrograms of the original and both shifted variants, per preprocessing key.
        specs={
            k: [v['augmenter'](v['transformer'](w)) for w in variants]
            for k, v in preprocessor.items()
        }
        batch_output=0.
        for model,p in zip(modellist,pplist):
            probs=[torch.sigmoid(model(s)) for s in specs[p]]
            # sqrt of the TTA-averaged probability, then averaged over the ensemble
            batch_output+=(sum(probs)/3)**0.5
        batch_output/=len(modellist)
        batch_output=batch_output.cpu().numpy()  # [k, 264]

    if site_name=='site_3':
        # site_3 is scored once per recording: element-wise max over its windows.
        batch_output=np.max(batch_output,axis=0,keepdims=True)  # [1, 264]

    for clip_pred in batch_output:
        result.append(_decode_labels(clip_pred, THRESHOLD, MAX_PRED[site_name]))
    for r in row_ids:
        rows.extend(r)
        

In [None]:
# One row per scored window (site_1/site_2) or recording (site_3), in dataloader order.
submission_df = pd.DataFrame({"row_id": rows, "birds": result})

In [None]:
# Inspect the submission table before writing it out.
submission_df

In [None]:
# Write the submission file (row_id, birds columns) without the DataFrame index.
submission_df.to_csv("submission.csv",index=False)