In [1]:
import glob
import os
import random
import sys
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import librosa
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torchdata
from torchaudio.transforms import AmplitudeToDB, MelSpectrogram
from tqdm.auto import tqdm

import concurrent.futures


warnings.filterwarnings("ignore")

In [2]:
# Record the torch / timm versions this notebook was run with, one per line.
print(torch.__version__, timm.__version__, sep="\n")

1.13.0+cpu
0.6.12


In [3]:
# End times (in seconds) of consecutive 5-second windows: 5, 10, ..., 600.
seconds = list(range(5, 605, 5))


class CFG:
    """Inference-time constants shared by the dataset and prediction code.

    Attributes are read directly off the class (``CFG.sample_rate``,
    ``CFG.target_columns``); the class is not instantiated in this notebook.
    """
    # Sample rate (Hz) TestDataset uses to turn window seconds into sample
    # offsets; assumed to match the rate the clips are loaded at.
    sample_rate = 32000

    # 264 class labels in model-output order: column i of a model's output is
    # reported under target_columns[i] in the submission, so reordering this
    # list silently mislabels every prediction. Presumably the species codes
    # of the training label set -- TODO confirm against the training config.
    target_columns = [
        'abethr1', 'abhori1', 'abythr1', 'afbfly1', 'afdfly1', 'afecuc1',
        'affeag1', 'afgfly1', 'afghor1', 'afmdov1', 'afpfly1', 'afpkin1',
        'afpwag1', 'afrgos1', 'afrgrp1', 'afrjac1', 'afrthr1', 'amesun2',
        'augbuz1', 'bagwea1', 'barswa', 'bawhor2', 'bawman1', 'bcbeat1',
        'beasun2', 'bkctch1', 'bkfruw1', 'blacra1', 'blacuc1', 'blakit1',
        'blaplo1', 'blbpuf2', 'blcapa2', 'blfbus1', 'blhgon1', 'blhher1',
        'blksaw1', 'blnmou1', 'blnwea1', 'bltapa1', 'bltbar1', 'bltori1',
        'blwlap1', 'brcale1', 'brcsta1', 'brctch1', 'brcwea1', 'brican1',
        'brobab1', 'broman1', 'brosun1', 'brrwhe3', 'brtcha1', 'brubru1',
        'brwwar1', 'bswdov1', 'btweye2', 'bubwar2', 'butapa1', 'cabgre1',
        'carcha1', 'carwoo1', 'categr', 'ccbeat1', 'chespa1', 'chewea1',
        'chibat1', 'chtapa3', 'chucis1', 'cibwar1', 'cohmar1', 'colsun2',
        'combul2', 'combuz1', 'comsan', 'crefra2', 'crheag1', 'crohor1',
        'darbar1', 'darter3', 'didcuc1', 'dotbar1', 'dutdov1', 'easmog1',
        'eaywag1', 'edcsun3', 'egygoo', 'equaka1', 'eswdov1', 'eubeat1',
        'fatrav1', 'fatwid1', 'fislov1', 'fotdro5', 'gabgos2', 'gargan',
        'gbesta1', 'gnbcam2', 'gnhsun1', 'gobbun1', 'gobsta5', 'gobwea1',
        'golher1', 'grbcam1', 'grccra1', 'grecor', 'greegr', 'grewoo2',
        'grwpyt1', 'gryapa1', 'grywrw1', 'gybfis1', 'gycwar3', 'gyhbus1',
        'gyhkin1', 'gyhneg1', 'gyhspa1', 'gytbar1', 'hadibi1', 'hamerk1',
        'hartur1', 'helgui', 'hipbab1', 'hoopoe', 'huncis1', 'hunsun2',
        'joygre1', 'kerspa2', 'klacuc1', 'kvbsun1', 'laudov1', 'lawgol',
        'lesmaw1', 'lessts1', 'libeat1', 'litegr', 'litswi1', 'litwea1',
        'loceag1', 'lotcor1', 'lotlap1', 'luebus1', 'mabeat1', 'macshr1',
        'malkin1', 'marsto1', 'marsun2', 'mcptit1', 'meypar1', 'moccha1',
        'mouwag1', 'ndcsun2', 'nobfly1', 'norbro1', 'norcro1', 'norfis1',
        'norpuf1', 'nubwoo1', 'pabspa1', 'palfly2', 'palpri1', 'piecro1',
        'piekin1', 'pitwhy', 'purgre2', 'pygbat1', 'quailf1', 'ratcis1',
        'raybar1', 'rbsrob1', 'rebfir2', 'rebhor1', 'reboxp1', 'reccor',
        'reccuc1', 'reedov1', 'refbar2', 'refcro1', 'reftin1', 'refwar2',
        'rehblu1', 'rehwea1', 'reisee2', 'rerswa1', 'rewsta1', 'rindov',
        'rocmar2', 'rostur1', 'ruegls1', 'rufcha2', 'sacibi2', 'sccsun2',
        'scrcha1', 'scthon1', 'shesta1', 'sichor1', 'sincis1', 'slbgre1',
        'slcbou1', 'sltnig1', 'sobfly1', 'somgre1', 'somtit4', 'soucit1',
        'soufis1', 'spemou2', 'spepig1', 'spewea1', 'spfbar1', 'spfwea1',
        'spmthr1', 'spwlap1', 'squher1', 'strher', 'strsee1', 'stusta1',
        'subbus1', 'supsta1', 'tacsun1', 'tafpri1', 'tamdov1', 'thrnig1',
        'trobou1', 'varsun2', 'vibsta2', 'vilwea1', 'vimwea1', 'walsta1',
        'wbgbir1', 'wbrcha2', 'wbswea1', 'wfbeat1', 'whbcan1', 'whbcou1',
        'whbcro2', 'whbtit5', 'whbwea1', 'whbwhe3', 'whcpri2', 'whctur2',
        'wheslf1', 'whhsaw1', 'whihel1', 'whrshr1', 'witswa1', 'wlwwar',
        'wookin1', 'woosan', 'wtbeat1', 'yebapa1', 'yebbar1', 'yebduc1',
        'yebere1', 'yebgre1', 'yebsto1', 'yeccan1', 'yefcan', 'yelbis1',
        'yenspu1', 'yertin1', 'yesbar1', 'yespet1', 'yetgre1', 'yewgre1'
        ]


class TestDataset(torchdata.Dataset):
    """Fixed-length 5-second windows cut from one in-memory audio clip.

    Each row of ``df`` is one window: ``row_id`` is passed through unchanged
    and ``seconds`` is the window's END time, so a row with ``seconds=s``
    covers samples ``[sr * (s - 5), sr * s)`` of ``clip``.

    Args:
        df: Frame with ``row_id`` and ``seconds`` columns, one row per window.
            Rows are looked up by position, so any index is accepted.
        clip: 1-D waveform of the whole recording; assumed to be sampled at
            ``CFG.sample_rate`` (TODO confirm against the loader).
        config: Stored as ``self.config``; not otherwise used by this class.
    """

    # Window length in seconds; every returned waveform has sr * duration samples.
    duration = 5

    def __init__(self,
                 df: pd.DataFrame,
                 clip: np.ndarray,
                 config=None,
                 ):
        self.df = df
        self.clip = clip
        self.sr = CFG.sample_rate
        self.config = config

    def __len__(self):
        """Number of windows, i.e. rows of ``df``."""
        return len(self.df)

    def __getitem__(self, idx: int):
        """Return the ``idx``-th window as a dict ready for ``default_collate``.

        Keys: ``row_id``; ``wave`` (float32, exactly ``sr * duration`` samples,
        zero-padded where the clip ends early); and the placeholder tensors
        ``rating``, ``loss_target`` and ``embedding``.
        """
        # Positional lookup: .loc[idx] raises KeyError (or picks the wrong row)
        # whenever df does not carry a default 0..n-1 index.
        sample = self.df.iloc[idx]
        row_id = sample.row_id

        end_seconds = int(sample.seconds)
        start_seconds = end_seconds - self.duration
        n_samples = self.sr * self.duration

        wave = self.clip[self.sr * start_seconds : self.sr * end_seconds].astype(np.float32)
        # A clip shorter than the window's end yields a short or empty slice, and
        # default_collate cannot stack unequal lengths; pad with silence instead.
        if len(wave) < n_samples:
            wave = np.pad(wave, (0, n_samples - len(wave)))

        return {
            "row_id": row_id,
            "wave": wave,
            # Constant placeholder tensors; presumably part of the input dict the
            # exported models were traced with -- TODO confirm they are unused.
            "rating": torch.ones(1),
            "loss_target": torch.ones(1),
            # Random (unseeded) 1 x 264 tensor; 264 equals len(CFG.target_columns).
            "embedding": torch.rand(264).unsqueeze(0),
        }


def prediction_for_clip(audio_path, models=None):
    """Predict per-species probabilities for every 5-second window of one clip.

    The clip is loaded at ``CFG.sample_rate`` and cut into the windows whose
    end times are listed in the module-level ``seconds``. Every model's
    sigmoid outputs are averaged per window.

    Args:
        audio_path: Path (str or ``Path``) to an audio file; its file stem
            prefixes every row id.
        models: Sequence of callables that take the collated batch dict and
            return a ``(batch, n_classes)`` tensor (assumed -- the code applies
            ``sigmoid`` and iterates rows). Defaults to the module-level
            ``models_ensemble``.

    Returns:
        Dict mapping ``"<stem>_<end_second>"`` to ``{species_code: mean
        probability}`` over ``CFG.target_columns``; empty if ``models`` is empty.
    """
    if models is None:
        models = models_ensemble

    clip, _ = librosa.load(audio_path, sr=CFG.sample_rate)
    name_ = Path(audio_path).stem
    row_ids = [f"{name_}_{second}" for second in seconds]

    test_df = pd.DataFrame({
        "row_id": row_ids,
        "seconds": seconds,
    })
    dataset = TestDataset(df=test_df, clip=clip)

    loader = torchdata.DataLoader(
        dataset,
        # The checkpoint names end in "_bs4"; presumably traced at batch size 4,
        # so keep these in sync -- TODO confirm.
        batch_size=4,
        # os.cpu_count() may return None, which DataLoader rejects.
        num_workers=os.cpu_count() or 0,
        drop_last=False,
        shuffle=False,
        pin_memory=True,
    )

    # Iterate the loader once and run every model on each batch, rather than
    # re-iterating it (and restarting its workers) once per model. Each row
    # still receives its predictions in model order, so the mean is unchanged.
    per_row = {}
    with torch.no_grad():
        for inputs in loader:
            batch_row_ids = [str(row_id) for row_id in inputs.pop("row_id")]
            for model in models:
                probas = model(inputs).sigmoid().numpy()
                for row_id, row_probas in zip(batch_row_ids, probas):
                    per_row.setdefault(row_id, []).append(row_probas)

    prediction_dict = {}
    for row_id, model_probas in per_row.items():
        # Values are already post-sigmoid probabilities, not logits.
        mean_probas = np.array(model_probas).mean(0)
        prediction_dict[row_id] = {
            label: mean_probas[i] for i, label in enumerate(CFG.target_columns)
        }

    return prediction_dict


# Root directory of the exported TorchScript checkpoints.
MODEL_DIR = "/kaggle/input/birdclef2023-4th-models"

# (experiment directory, checkpoint file) for each ensemble member, in load order.
MODEL_CHECKPOINTS = [
    ("exp105_eca_nfnet_l0", "fold_0_model_jit_bs4.pt"),
    ("exp106_eca_nfnet_l0", "fold_3_model_jit_bs4.pt"),
    ("exp107_eca_nfnet_l0", "fold_4_model_jit_bs4.pt"),
    ("exp108_eca_nfnet_l0", "fold_2_model_jit_bs4.pt"),
]


def _load_jit_model(path):
    """Load one TorchScript checkpoint onto the CPU and switch it to eval mode.

    Inference inputs are never moved off the CPU, so loading onto the CPU
    avoids a failure or device mismatch for checkpoints saved from a GPU.
    """
    model = torch.jit.load(path, map_location="cpu")
    model.eval()
    return model


models_ensemble = [
    _load_jit_model(os.path.join(MODEL_DIR, exp_dir, file_name))
    for exp_dir, file_name in MODEL_CHECKPOINTS
]
# Individual handles kept for any cell that refers to them by name; update this
# line together with MODEL_CHECKPOINTS.
model1, model2, model3, model4 = models_ensemble


def main():
    """Score every test soundscape and write ``submission.csv``.

    Each ``.ogg`` file in the test directory is scored by
    ``prediction_for_clip`` on a pool of 4 threads. The per-clip dicts are
    merged into one frame with a ``row_id`` column followed by one column per
    entry of ``CFG.target_columns``.
    """
    # glob's result order depends on the filesystem; sort it so the
    # submission's row order is deterministic.
    all_audios = sorted(glob.glob('/kaggle/input/birdclef-2023/test_soundscapes/*.ogg'))

    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        dicts = list(executor.map(prediction_for_clip, all_audios))

    prediction_dicts = {}
    for d in dicts:
        prediction_dicts.update(d)

    if prediction_dicts:
        submission = pd.DataFrame.from_dict(prediction_dicts, "index").rename_axis("row_id").reset_index()
    else:
        # No test audio found: still write the expected header rather than a
        # file containing only a row_id column.
        submission = pd.DataFrame(columns=["row_id", *CFG.target_columns])
    submission.to_csv("submission.csv", index=False)


if __name__ == "__main__":
    main()
