## This submission shows an ensemble of our own models, with a public score of 0.80 and a private score of 0.79 (rank 7-8, Version 6). All datasets, including the trained models, are public.
## We also use a TensorFlow model provided by the host, which ranked 2nd on the private leaderboard (see https://www.kaggle.com/code/leonshangguan/birdnet-inference).

In [None]:
!pip install ../input/torchlibrosa/torchlibrosa-0.0.5-py3-none-any.whl > /dev/null

In [None]:
import cv2
import audioread
import logging
import os
import sys
sys.path.append('../input/pytorch-image-models/pytorch-image-models-master')
import random
import time
import warnings

import librosa
import numpy as np
import pandas as pd
import soundfile as sf
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torchdata

from contextlib import contextmanager
from pathlib import Path
from typing import List
from typing import Optional
from sklearn import metrics

from albumentations.core.transforms_interface import ImageOnlyTransform
from torchlibrosa.stft import LogmelFilterBank, Spectrogram
from torchlibrosa.augmentation import SpecAugmentation
from tqdm import tqdm

import albumentations as A
import albumentations.pytorch.transforms as T

In [None]:
def set_seed(seed: int = 42):
    """Seed every random number generator used in this notebook.

    Args:
        seed: Value applied to ``random``, NumPy, the ``PYTHONHASHSEED``
            environment variable and PyTorch (CPU and all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    # Changing PYTHONHASHSEED at runtime only affects interpreters started
    # afterwards (e.g. spawned workers), not the current process.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore
    # Benchmark mode lets cuDNN pick kernels by timing them, which can differ
    # from run to run; keep it off so the deterministic flag is meaningful.
    torch.backends.cudnn.benchmark = False  # type: ignore
    
    
def get_logger(out_file=None):
    """Reset and configure the root logger at INFO level.

    Args:
        out_file: Optional path; when given, records are also written to this
            file in addition to the stream handler.

    Returns:
        The root ``logging.Logger`` with freshly installed handlers.
    """
    logger = logging.getLogger()
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    # Remove *and close* whatever was installed before. Merely dropping the
    # list would leave the previous FileHandler's file open every time this
    # cell is re-run.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    logger.setLevel(logging.INFO)

    new_handlers = [logging.StreamHandler()]
    if out_file is not None:
        new_handlers.append(logging.FileHandler(out_file))

    for handler in new_handlers:
        handler.setFormatter(formatter)
        handler.setLevel(logging.INFO)
        logger.addHandler(handler)

    logger.info("logger set up")
    return logger
    
    
@contextmanager
def timer(name: str, logger: Optional[logging.Logger] = None):
    """Context manager that reports when a named step starts and how long it took.

    Args:
        name: Label placed in square brackets in both messages.
        logger: Destination for the messages (``logger.info``); when ``None``
            the messages are printed instead.

    The closing message is only emitted when the wrapped body finishes
    without raising.
    """
    emit = print if logger is None else logger.info

    started_at = time.time()
    emit(f"[{name}] start")
    yield

    emit(f"[{name}] done in {time.time() - started_at:.2f} s")


def get_device() -> torch.device:
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# Notebook-wide singletons used by the cells below:
#   device -- where models and input batches are placed
#   logger -- root logger; records go to the stream handler and to main.log
# The RNGs are seeded once here, before any model is built.
device = get_device()
logger = get_logger("main.log")
set_seed(42)

# define model

In [None]:
def init_layer(layer):
    """Xavier-uniform initialise ``layer.weight`` and zero the bias if the layer has one."""
    nn.init.xavier_uniform_(layer.weight)

    bias = getattr(layer, "bias", None)
    if bias is None:
        return
    bias.data.fill_(0.)


def init_bn(bn):
    """Reset a batch-norm layer's affine parameters to scale 1 and shift 0."""
    bn.weight.data.fill_(1.0)
    bn.bias.data.fill_(0.)


def init_weights(model):
    """Initialise a single module, choosing the scheme from its class name.

    * ``*Conv2d*``    -- Xavier-uniform weights (gain sqrt(2)), zero bias
    * ``*BatchNorm*`` -- weights ~ N(1, 0.02), zero bias
    * ``*GRU*``       -- orthogonal init for every parameter with more than
      one dimension
    * ``*Linear*``    -- weights ~ N(0, 0.01), zero bias

    Modules whose class name matches none of these are left untouched.
    Layers built without a bias (or without affine parameters) are handled
    without raising.

    Args:
        model: The ``nn.Module`` to initialise in place.
    """
    classname = model.__class__.__name__

    if "Conv2d" in classname:
        nn.init.xavier_uniform_(model.weight, gain=np.sqrt(2))
        if model.bias is not None:
            model.bias.data.fill_(0)
    elif "BatchNorm" in classname:
        # affine=False batch norms carry neither weight nor bias.
        if model.weight is not None:
            model.weight.data.normal_(1.0, 0.02)
        if model.bias is not None:
            model.bias.data.fill_(0)
    elif "GRU" in classname:
        for weight in model.parameters():
            # Only matrices are orthogonalised; 1-D bias vectors are skipped.
            if weight.dim() > 1:
                nn.init.orthogonal_(weight.data)
    elif "Linear" in classname:
        model.weight.data.normal_(0, 0.01)
        if model.bias is not None:
            model.bias.data.zero_()


def interpolate(x: torch.Tensor, ratio: int):
    """Upsample along the time axis by repeating every step ``ratio`` times.

    Used to undo the time-resolution reduction introduced by the CNN.
    Args:
      x: (batch_size, time_steps, classes_num)
      ratio: int, number of copies made of each time step
    Returns:
      upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    n_batch, n_steps, n_classes = x.shape
    repeated = x.unsqueeze(2).repeat(1, 1, ratio, 1)
    return repeated.reshape(n_batch, n_steps * ratio, n_classes)


def pad_framewise_output(framewise_output: torch.Tensor, frames_num: int):
    """Resize the time axis of ``framewise_output`` to ``frames_num`` frames.

    The resize is a bilinear interpolation with ``align_corners=True``; the
    class axis keeps its size.
    Args:
      framewise_output: (batch_size, frames, classes_num)
      frames_num: int, number of frames in the result
    Outputs:
      output: (batch_size, frames_num, classes_num)
    """
    target_size = (frames_num, framewise_output.size(2))
    # F.interpolate in bilinear mode wants a 4-D (N, C, H, W) tensor, so a
    # dummy channel axis is added and removed again afterwards.
    as_image = framewise_output.unsqueeze(1)
    resized = F.interpolate(
        as_image,
        size=target_size,
        mode="bilinear",
        align_corners=True)
    return resized.squeeze(1)


class AttBlockV2(nn.Module):
    """Attention pooling over the time axis.

    Two parallel kernel-size-1 ``Conv1d`` layers map ``in_features`` channels
    to ``out_features`` outputs: ``att`` yields attention scores and ``cla``
    yields per-step values. The values are averaged over time using the
    softmax-normalised attention as weights.

    Args:
        in_features: Number of input channels.
        out_features: Number of outputs (classes).
        activation: ``"linear"`` (identity) or ``"sigmoid"``; applied to the
            ``cla`` branch.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation="linear"):
        super().__init__()

        self.activation = activation
        conv_kwargs = dict(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)
        self.att = nn.Conv1d(**conv_kwargs)
        self.cla = nn.Conv1d(**conv_kwargs)

        self.init_weights()

    def init_weights(self):
        """Apply ``init_layer`` to both convolution branches."""
        init_layer(self.att)
        init_layer(self.cla)

    def forward(self, x):
        """Pool ``x`` over its last axis.

        Args:
            x: (n_samples, n_in, n_time)

        Returns:
            Tuple ``(pooled, norm_att, cla)`` where ``pooled`` is
            (n_samples, n_out) and the other two are (n_samples, n_out, n_time).
        """
        # tanh bounds the scores before the softmax over the time axis.
        norm_att = torch.softmax(torch.tanh(self.att(x)), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        """Apply the configured activation to ``x``.

        Raises:
            ValueError: If ``self.activation`` is neither ``'linear'`` nor
                ``'sigmoid'`` (instead of silently returning ``None``).
        """
        if self.activation == 'linear':
            return x
        if self.activation == 'sigmoid':
            return torch.sigmoid(x)
        raise ValueError(
            f"Unsupported activation {self.activation!r}; "
            "expected 'linear' or 'sigmoid'")


class TimmSED(nn.Module):
    """Sound-event-detection model: a timm backbone followed by attention pooling.

    Args:
        base_model_name: Architecture name passed to ``timm.create_model``.
        config: Object exposing at least ``n_mels`` (size of the input batch
            norm) and ``in_channels`` (read in ``forward``). Required despite
            the ``None`` default.
        pretrained: Forwarded to ``timm.create_model``.
        num_classes: Number of outputs of the attention head.
        in_channels: Number of input channels the backbone is built for.

    Raises:
        ValueError: If ``config`` is ``None``.
    """

    def __init__(
        self, 
        base_model_name: str, 
        config=None,
        pretrained=False, 
        num_classes=24, 
        in_channels=1
    ):
        super().__init__()

        if config is None:
            # Fail early with a clear message rather than with an
            # AttributeError on ``None.n_mels`` a few lines below.
            raise ValueError("TimmSED requires a config with n_mels and in_channels")

        self.config = config

        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64 // 2, 
            time_stripes_num=2,
            freq_drop_width=8 // 2, 
            freq_stripes_num=2)

        self.bn0 = nn.BatchNorm2d(self.config.n_mels)

        base_model = timm.create_model(
            base_model_name, 
            pretrained=pretrained, 
            num_classes=0,
            global_pool="",
            in_chans=in_channels,
        )

        # Keep everything except the last two child modules of the backbone.
        layers = list(base_model.children())[:-2]
        self.encoder = nn.Sequential(*layers)

        in_features = base_model.num_features

        self.fc1 = nn.Linear(in_features, in_features, bias=True)
        self.att_block = AttBlockV2(
            in_features, num_classes, activation="sigmoid")

        self.init_weight()

    def init_weight(self):
        """Initialise the layers added on top of the backbone."""
        init_bn(self.bn0)
        init_layer(self.fc1)

    def forward(self, input_data):
        """Run the model on a batch of spectrogram images.

        Args:
            input_data: (batch_size, channels, time_steps, mel_bins). When
                ``config.in_channels`` is not 3 only channel 0 is used.

        Returns:
            dict with keys ``framewise_output``, ``segmentwise_output``,
            ``logit``, ``framewise_logit`` and ``clipwise_output``.
        """
        if self.config.in_channels == 3:
            x = input_data
        else:
            x = input_data[:, [0], :, :]  # (batch_size, 1, time_steps, mel_bins)

        frames_num = x.shape[2]

        # Move the mel axis into the channel position so bn0 normalises per
        # mel bin, then move it back.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        # SpecAugment is applied to a random quarter of the training batches.
        if self.training and random.random() < 0.25:
            x = self.spec_augmenter(x)

        x = x.transpose(2, 3)

        x = self.encoder(x)

        # Aggregate in frequency axis
        x = torch.mean(x, dim=2)

        x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        x = x1 + x2

        x = F.dropout(x, p=0.5, training=self.training)
        x = x.transpose(1, 2)
        x = F.relu_(self.fc1(x))
        x = x.transpose(1, 2)
        x = F.dropout(x, p=0.5, training=self.training)

        (clipwise_output, norm_att, segmentwise_output) = self.att_block(x)

        # Pre-activation values of the classification branch, computed once
        # and shared by the clip-level and segment-level logits.
        raw_cla = self.att_block.cla(x)
        logit = torch.sum(norm_att * raw_cla, dim=2)
        segmentwise_logit = raw_cla.transpose(1, 2)
        segmentwise_output = segmentwise_output.transpose(1, 2)

        interpolate_ratio = frames_num // segmentwise_output.size(1)

        # Get framewise output
        framewise_output = interpolate(segmentwise_output,
                                       interpolate_ratio)
        framewise_output = pad_framewise_output(framewise_output, frames_num)

        framewise_logit = interpolate(segmentwise_logit, interpolate_ratio)
        framewise_logit = pad_framewise_output(framewise_logit, frames_num)

        output_dict = {
            "framewise_output": framewise_output,
            "segmentwise_output": segmentwise_output,
            "logit": logit,
            "framewise_logit": framewise_logit,
            "clipwise_output": clipwise_output
        }

        return output_dict

# define config

In [None]:
# Normalisation constants for single-channel images. The parentheses have no
# trailing comma, so these are plain floats rather than 1-tuples.
mean = (0.485) # R only for RGB
std = (0.229) # R only for RGB

# Albumentations pipelines for the 1-channel path; the 'train' and 'valid'
# entries are identical (normalisation only).
albu_transforms = {
    'train' : A.Compose([
            A.Normalize(mean, std),
    ]),
    'valid' : A.Compose([
            A.Normalize(mean, std),
    ]),
}


# Per-channel normalisation constants for 3-channel images.
mean2 = (0.485, 0.456, 0.406) # RGB
std2 = (0.229, 0.224, 0.225) # RGB

# Same pipelines as above, for the 3-channel path.
albu_transforms2 = {
    'train' : A.Compose([
            A.Normalize(mean2, std2),
    ]),
    'valid' : A.Compose([
            A.Normalize(mean2, std2),
    ]),
}

In [None]:
class CFG_eca_nfnet_l0:
    # Settings for the eca_nfnet_l0 models (five fold checkpoints).
    # In this file sr / n_mels / n_fft / hop_length / fmin / fmax are read by
    # compute_melspec, n_mels / in_channels by TimmSED, and base_model_name /
    # pretrained / num_classes / in_channels / ckpt_path by the model loading
    # cell. seed, train, transforms and duration are not read by any code
    # visible here -- presumably leftovers from training; verify before removing.

    seed = 96000
    train = False

    transforms = {
        "train": [{"name": "Normalize"}],
        "valid": [{"name": "Normalize"}]
    }
    
    duration = 5
    n_mels = 128
    fmin = 50
    fmax = 14000
    n_fft = 1024
    hop_length = 320
    sr = 32000

    # Class labels; index i names output i of the model (152 entries,
    # matching num_classes).
    target_columns = 'afrsil1 akekee akepa1 akiapo akikik amewig aniani apapan arcter \
                      barpet bcnher belkin1 bkbplo bknsti bkwpet blkfra blknod bongul \
                      brant brnboo brnnod brnowl brtcur bubsan buffle bulpet burpar buwtea \
                      cacgoo1 calqua cangoo canvas caster1 categr chbsan chemun chukar cintea \
                      comgal1 commyn compea comsan comwax coopet crehon dunlin elepai ercfra eurwig \
                      fragul gadwal gamqua glwgul gnwtea golphe grbher3 grefri gresca gryfra gwfgoo \
                      hawama hawcoo hawcre hawgoo hawhaw hawpet1 hoomer houfin houspa hudgod iiwi incter1 \
                      jabwar japqua kalphe kauama laugul layalb lcspet leasan leater1 lessca lesyel lobdow lotjae \
                      madpet magpet1 mallar3 masboo mauala maupar merlin mitpar moudov norcar norhar2 normoc norpin \
                      norsho nutman oahama omao osprey pagplo palila parjae pecsan peflov perfal pibgre pomjae puaioh \
                      reccar redava redjun redpha1 refboo rempar rettro ribgul rinduc rinphe rocpig rorpar rudtur ruff \
                      saffin sander semplo sheowl shtsan skylar snogoo sooshe sooter1 sopsku1 sora spodov sposan \
                      towsol wantat1 warwhe1 wesmea wessan wetshe whfibi whiter whttro wiltur yebcar yefcan zebdov'.split()

   
    base_model_name = "eca_nfnet_l0"
    pretrained = False
    num_classes = 152
    in_channels = 1
    
    # One checkpoint per fold; each becomes one ensemble member.
    ckpt_path = [
        "../input/birdclef-ensemble/eca_nfnet_l0_fold_0.bin",
        "../input/birdclef-ensemble/eca_nfnet_l0_fold_1.bin",
        "../input/birdclef-ensemble/eca_nfnet_l0_fold_2.bin",
        "../input/birdclef-ensemble/eca_nfnet_l0_fold_3.bin",
        "../input/birdclef-ensemble/eca_nfnet_l0_fold_4.bin"
    ]
    
    
class CFG_tf_efficientnetv2_s_in21k:
    # Settings for the tf_efficientnetv2_s_in21k model (single checkpoint).
    # Field usage mirrors the other CFG_* classes: the spectrogram fields feed
    # compute_melspec, the model fields feed TimmSED and the loading cell.
    # NOTE(review): the instance of this config is commented out of the
    # ensemble list further down, so it appears to be unused at inference.

    seed = 97000
    train = False

    transforms = {
        "train": [{"name": "Normalize"}],
        "valid": [{"name": "Normalize"}]
    }
    
    duration = 5
    n_mels = 64
    fmin = 50
    fmax = 14000
    n_fft = 1024
    hop_length = 320
    sr = 32000

    # Class labels; index i names output i of the model (152 entries,
    # matching num_classes).
    target_columns = 'afrsil1 akekee akepa1 akiapo akikik amewig aniani apapan arcter \
                      barpet bcnher belkin1 bkbplo bknsti bkwpet blkfra blknod bongul \
                      brant brnboo brnnod brnowl brtcur bubsan buffle bulpet burpar buwtea \
                      cacgoo1 calqua cangoo canvas caster1 categr chbsan chemun chukar cintea \
                      comgal1 commyn compea comsan comwax coopet crehon dunlin elepai ercfra eurwig \
                      fragul gadwal gamqua glwgul gnwtea golphe grbher3 grefri gresca gryfra gwfgoo \
                      hawama hawcoo hawcre hawgoo hawhaw hawpet1 hoomer houfin houspa hudgod iiwi incter1 \
                      jabwar japqua kalphe kauama laugul layalb lcspet leasan leater1 lessca lesyel lobdow lotjae \
                      madpet magpet1 mallar3 masboo mauala maupar merlin mitpar moudov norcar norhar2 normoc norpin \
                      norsho nutman oahama omao osprey pagplo palila parjae pecsan peflov perfal pibgre pomjae puaioh \
                      reccar redava redjun redpha1 refboo rempar rettro ribgul rinduc rinphe rocpig rorpar rudtur ruff \
                      saffin sander semplo sheowl shtsan skylar snogoo sooshe sooter1 sopsku1 sora spodov sposan \
                      towsol wantat1 warwhe1 wesmea wessan wetshe whfibi whiter whttro wiltur yebcar yefcan zebdov'.split()

   
    base_model_name = "tf_efficientnetv2_s_in21k"
    pretrained = False
    num_classes = 152
    in_channels = 1
    
    ckpt_path = [
        "../input/birdclef-ensemble/tf_efficientnetv2_s_in21k_fold_0.bin"
    ]
    
    
class CFG_tf_efficientnetv2_m_in21k:
    # Settings for the tf_efficientnetv2_m_in21k models (five fold checkpoints).
    # Field usage mirrors the other CFG_* classes: the spectrogram fields feed
    # compute_melspec, the model fields feed TimmSED and the loading cell.

    seed = 98000
    train = False

    transforms = {
        "train": [{"name": "Normalize"}],
        "valid": [{"name": "Normalize"}]
    }
    
    duration = 5
    n_mels = 64
    fmin = 50
    fmax = 14000
    n_fft = 1024
    hop_length = 320
    sr = 32000

    # Class labels; index i names output i of the model (152 entries,
    # matching num_classes).
    target_columns = 'afrsil1 akekee akepa1 akiapo akikik amewig aniani apapan arcter \
                      barpet bcnher belkin1 bkbplo bknsti bkwpet blkfra blknod bongul \
                      brant brnboo brnnod brnowl brtcur bubsan buffle bulpet burpar buwtea \
                      cacgoo1 calqua cangoo canvas caster1 categr chbsan chemun chukar cintea \
                      comgal1 commyn compea comsan comwax coopet crehon dunlin elepai ercfra eurwig \
                      fragul gadwal gamqua glwgul gnwtea golphe grbher3 grefri gresca gryfra gwfgoo \
                      hawama hawcoo hawcre hawgoo hawhaw hawpet1 hoomer houfin houspa hudgod iiwi incter1 \
                      jabwar japqua kalphe kauama laugul layalb lcspet leasan leater1 lessca lesyel lobdow lotjae \
                      madpet magpet1 mallar3 masboo mauala maupar merlin mitpar moudov norcar norhar2 normoc norpin \
                      norsho nutman oahama omao osprey pagplo palila parjae pecsan peflov perfal pibgre pomjae puaioh \
                      reccar redava redjun redpha1 refboo rempar rettro ribgul rinduc rinphe rocpig rorpar rudtur ruff \
                      saffin sander semplo sheowl shtsan skylar snogoo sooshe sooter1 sopsku1 sora spodov sposan \
                      towsol wantat1 warwhe1 wesmea wessan wetshe whfibi whiter whttro wiltur yebcar yefcan zebdov'.split()

   
    base_model_name = "tf_efficientnetv2_m_in21k"
    pretrained = False
    num_classes = 152
    in_channels = 1
    
    # One checkpoint per fold; each becomes one ensemble member.
    ckpt_path = [
        "../input/birdclef-ensemble/tf_efficientnetv2_m_in21k_fold_0.bin",
        "../input/birdclef-ensemble/tf_efficientnetv2_m_in21k_fold_1.bin",
        "../input/birdclef-ensemble/tf_efficientnetv2_m_in21k_fold_2.bin",
        "../input/birdclef-ensemble/tf_efficientnetv2_m_in21k_fold_3.bin",
        "../input/birdclef-ensemble/tf_efficientnetv2_m_in21k_fold_4.bin"
    ]
    
class CFG_tf_efficientnet_b0:
    # Settings for the tf_efficientnet_b0_ns models (three fold checkpoints).
    # Uses a different spectrogram setup from the configs above (256 mel
    # bins, n_fft 2048, hop 512). Field usage mirrors the other CFG_* classes.

    seed = 630
    train = False

    transforms = {
        "train": [{"name": "Normalize"}],
        "valid": [{"name": "Normalize"}]
    }
    
    duration = 5
    n_mels = 256
    fmin = 16
    fmax = 16386
    n_fft = 2048
    hop_length = 512
    sr = 32000

    # Class labels; index i names output i of the model (152 entries,
    # matching num_classes).
    target_columns = 'afrsil1 akekee akepa1 akiapo akikik amewig aniani apapan arcter \
                      barpet bcnher belkin1 bkbplo bknsti bkwpet blkfra blknod bongul \
                      brant brnboo brnnod brnowl brtcur bubsan buffle bulpet burpar buwtea \
                      cacgoo1 calqua cangoo canvas caster1 categr chbsan chemun chukar cintea \
                      comgal1 commyn compea comsan comwax coopet crehon dunlin elepai ercfra eurwig \
                      fragul gadwal gamqua glwgul gnwtea golphe grbher3 grefri gresca gryfra gwfgoo \
                      hawama hawcoo hawcre hawgoo hawhaw hawpet1 hoomer houfin houspa hudgod iiwi incter1 \
                      jabwar japqua kalphe kauama laugul layalb lcspet leasan leater1 lessca lesyel lobdow lotjae \
                      madpet magpet1 mallar3 masboo mauala maupar merlin mitpar moudov norcar norhar2 normoc norpin \
                      norsho nutman oahama omao osprey pagplo palila parjae pecsan peflov perfal pibgre pomjae puaioh \
                      reccar redava redjun redpha1 refboo rempar rettro ribgul rinduc rinphe rocpig rorpar rudtur ruff \
                      saffin sander semplo sheowl shtsan skylar snogoo sooshe sooter1 sopsku1 sora spodov sposan \
                      towsol wantat1 warwhe1 wesmea wessan wetshe whfibi whiter whttro wiltur yebcar yefcan zebdov'.split()

   
    base_model_name = "tf_efficientnet_b0_ns"
    pretrained = False
    num_classes = 152
    in_channels = 1
    
    # One checkpoint per fold; each becomes one ensemble member.
    ckpt_path = [
        "../input/tf-efficientnet-b0-ns/fold-0_0.8157349896480331.bin",
        "../input/tf-efficientnet-b0-ns/fold-1_0.8130277442702051.bin",
        "../input/tf-efficientnet-b0-ns/fold-2_0.81753840842396.bin",
    ]
    
class CFG_resnext50:
    # Settings for the resnext50_32x4d models (four checkpoints). This is the
    # only config with in_channels = 3, which selects the 3-channel image path.
    # The model loading cell also attaches one extra EfficientNet-B0
    # checkpoint to this config. Field usage mirrors the other CFG_* classes.

    seed = 630
    train = False

    transforms = {
        "train": [{"name": "Normalize"}],
        "valid": [{"name": "Normalize"}]
    }
    
    duration = 5
    n_mels = 384
    fmin = 300
    fmax = 16000
    n_fft = 2048
    hop_length = 417
    sr = 32000

    # Class labels; index i names output i of the model (152 entries,
    # matching num_classes).
    target_columns = 'afrsil1 akekee akepa1 akiapo akikik amewig aniani apapan arcter \
                      barpet bcnher belkin1 bkbplo bknsti bkwpet blkfra blknod bongul \
                      brant brnboo brnnod brnowl brtcur bubsan buffle bulpet burpar buwtea \
                      cacgoo1 calqua cangoo canvas caster1 categr chbsan chemun chukar cintea \
                      comgal1 commyn compea comsan comwax coopet crehon dunlin elepai ercfra eurwig \
                      fragul gadwal gamqua glwgul gnwtea golphe grbher3 grefri gresca gryfra gwfgoo \
                      hawama hawcoo hawcre hawgoo hawhaw hawpet1 hoomer houfin houspa hudgod iiwi incter1 \
                      jabwar japqua kalphe kauama laugul layalb lcspet leasan leater1 lessca lesyel lobdow lotjae \
                      madpet magpet1 mallar3 masboo mauala maupar merlin mitpar moudov norcar norhar2 normoc norpin \
                      norsho nutman oahama omao osprey pagplo palila parjae pecsan peflov perfal pibgre pomjae puaioh \
                      reccar redava redjun redpha1 refboo rempar rettro ribgul rinduc rinphe rocpig rorpar rudtur ruff \
                      saffin sander semplo sheowl shtsan skylar snogoo sooshe sooter1 sopsku1 sora spodov sposan \
                      towsol wantat1 warwhe1 wesmea wessan wetshe whfibi whiter whttro wiltur yebcar yefcan zebdov'.split()

   
    base_model_name = "resnext50_32x4d"
    pretrained = False
    num_classes = 152
    in_channels = 3
    
    # Active checkpoints; the commented-out paths are alternative files for
    # folds 1-3 that are not loaded.
    ckpt_path = [
        "../input/resnext50/resnext50_fold-0_0.7777777777777778.bin",
#         "../input/resnext50/resnext50_fold-1_0.8084007574453433.bin",
#         "../input/resnext50/resnext50_fold-2_0.8004904536696444.bin",
#         "../input/resnext50/resnext50_fold-3_0.7966936334857543.bin",
        "../input/resnext50/resnext50_fold-1_0.8105864037363778.bin",
        "../input/resnext50/resnext50_fold-2_0.8004908835904627.bin",
        "../input/resnext50/resnext50_fold-3_0.7987388334209142.bin"
        
    ]

In [None]:
# One instance per configuration class. All settings are class attributes,
# so the instances start out with no state of their own.
config_eca_nfnet_l0 = CFG_eca_nfnet_l0()
config_tf_efficientnetv2_s_in21k = CFG_tf_efficientnetv2_s_in21k()
config_tf_efficientnetv2_m_in21k = CFG_tf_efficientnetv2_m_in21k()
config_tf_efficientnet_b0 = CFG_tf_efficientnet_b0()
config_resnext50 = CFG_resnext50()

In [None]:
# Configurations that take part in the ensemble, in loading order.
# The efficientnetv2_s config is deliberately left out (commented).
config_enesemble = [
    config_eca_nfnet_l0,
#     config_tf_efficientnetv2_s_in21k,
    config_tf_efficientnetv2_m_in21k,
    config_tf_efficientnet_b0,
    config_resnext50
]

# define dataset

In [None]:
def compute_melspec(y, params):
    """
    Computes a mel-spectrogram of a signal with librosa.
    Arguments:
        y {np array} -- signal
        params {object} -- Spectrogram settings. Expected to have the attributes sr, n_mels, n_fft, hop_length, fmin, fmax
    Returns:
        np array -- Mel-spectrogram as returned by librosa.feature.melspectrogram (no decibel conversion is applied here)
    """
    spectrogram_kwargs = dict(
        sr=params.sr,
        n_mels=params.n_mels,
        n_fft=params.n_fft,
        hop_length=params.hop_length,
        fmin=params.fmin,
        fmax=params.fmax,
    )
    return librosa.feature.melspectrogram(y=y, **spectrogram_kwargs)


def mono_to_color(X, eps=1e-6, mean=None, std=None):
    """
    Converts a 2D array to a single-channel uint8 image in [0, 255]
    Arguments:
        X {numpy array [H x W]} -- 2D array to convert
    Keyword Arguments:
        eps {float} -- To avoid dividing by 0 (default: {1e-6})
        mean {None, float or np array} -- Mean for normalization; computed from X when None (default: {None})
        std {None, float or np array} -- Std for normalization; computed from X when None (default: {None})
    Returns:
        numpy array [H x W x 1] -- uint8 image (all zeros when X is constant)
    """
    X = np.expand_dims(X, axis=-1)

    # Standardize. ``is None`` rather than ``or``: an explicit 0 must be
    # honoured and array-valued statistics have no single truth value.
    mean = X.mean() if mean is None else mean
    std = X.std() if std is None else std
    X = (X - mean) / (std + eps)

    # Normalize to [0, 255]
    _min, _max = X.min(), X.max()

    if (_max - _min) > eps:
        V = 255 * (X - _min) / (_max - _min)
        V = V.astype(np.uint8)
    else:
        # Flat input: there is no range to stretch.
        V = np.zeros_like(X, dtype=np.uint8)

    return V

def mono_to_color3(X, eps=1e-6, mean=None, std=None):
    """
    Converts a one channel array to a 3 channel uint8 image in [0, 255]
    Arguments:
        X {numpy array [H x W]} -- 2D array to convert
    Keyword Arguments:
        eps {float} -- To avoid dividing by 0 (default: {1e-6})
        mean {None, float or np array} -- Mean for normalization; computed from X when None (default: {None})
        std {None, float or np array} -- Std for normalization; computed from X when None (default: {None})
    Returns:
        numpy array [H x W x 3] -- uint8 image with three identical input channels (all zeros when X is constant)
    """
    X = np.stack([X, X, X], axis=-1)

    # Standardize. ``is None`` rather than ``or``: an explicit 0 must be
    # honoured and array-valued statistics have no single truth value.
    mean = X.mean() if mean is None else mean
    std = X.std() if std is None else std
    X = (X - mean) / (std + eps)

    # Normalize to [0, 255]
    _min, _max = X.min(), X.max()

    if (_max - _min) > eps:
        V = 255 * (X - _min) / (_max - _min)
        V = V.astype(np.uint8)
    else:
        # Flat input: there is no range to stretch.
        V = np.zeros_like(X, dtype=np.uint8)

    return V

In [None]:
class TestDataset(torchdata.Dataset):
    """Dataset yielding one normalised mel-spectrogram image per row of ``df``.

    Args:
        df: Frame with a ``row_id`` column and a ``seconds`` column holding
            the end time of each segment; indexed ``0..len(df) - 1``.
        clip: 1-D audio signal sampled at ``config.sr``.
        config: Object providing the spectrogram settings read by
            ``compute_melspec`` plus ``sr`` and ``in_channels``.
    """

    def __init__(self, 
                 df: pd.DataFrame, 
                 clip: np.ndarray,
                 config=None,
                ):
        
        self.df = df
        self.clip = clip
        self.config = config

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        """Return ``{"image": ..., "row_id": ...}`` for the segment at ``idx``."""
        sample = self.df.loc[idx, :]
        row_id = sample.row_id

        # Each row covers the 5 seconds that end at ``sample.seconds``.
        end_seconds = int(sample.seconds)
        start_seconds = int(end_seconds - 5)

        sr = self.config.sr
        y = self.clip[sr * start_seconds : sr * end_seconds].astype(np.float32)

        image = compute_melspec(y, self.config)
        image = librosa.power_to_db(image.astype(np.float32), ref=np.max)

        # Use this dataset's own config, not whatever module-level ``config``
        # happens to be bound when the item is fetched.
        if self.config.in_channels == 3:
            to_image, transform = mono_to_color3, albu_transforms2['valid']
        else:
            to_image, transform = mono_to_color, albu_transforms['valid']

        image = to_image(image).astype(np.uint8)
        # .T reverses the axes: (H, W, C) -> (C, W, H).
        image = transform(image=image)['image'].T

        return {
            "image": image,
            "row_id": row_id,
        }

# get models

In [None]:
def _load_sed_model(config, base_model_name, ckpt_path):
    """Build a TimmSED for ``config``, load ``ckpt_path`` into it and return it in eval mode."""
    model = TimmSED(
        base_model_name=base_model_name,
        config=config,
        pretrained=config.pretrained,
        num_classes=config.num_classes,
        in_channels=config.in_channels
    )
    model.to(device)
    # map_location follows the active device so the notebook also runs on a
    # CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from datasets you trust.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()
    return model


# List of (config, [models]) pairs consumed by prediction_for_clip.
models = []

for config in config_enesemble:

    config_models = [
        _load_sed_model(config, config.base_model_name, path)
        for path in config.ckpt_path
    ]

    if config.base_model_name == "resnext50_32x4d":
        # One extra EfficientNet-B0 checkpoint is run with the ResNeXt
        # config's input settings. The backbone name is passed explicitly
        # instead of overwriting config.base_model_name, so this cell can be
        # executed again without mismatching checkpoints and architectures.
        config_models.extend(
            _load_sed_model(config, "tf_efficientnet_b0_ns", path)
            for path in ["../input/eb0-final/fold-4_0.7996463306808134.bin"]
        )

    models.append((config, config_models))

# prediction

In [None]:
# Every .ogg file in the competition's test soundscape directory.
all_audios = list(Path("../input/birdclef-2022/test_soundscapes/").glob("*.ogg"))#*10

In [None]:
# Submission template; its rows are filled in by the final cells.
sample_submission = pd.read_csv('../input/birdclef-2022/sample_submission.csv')

# Notebook preview of the first rows.
sample_submission.head()

In [None]:
def prediction_for_clip(
    test_df: pd.DataFrame, 
    clip: np.ndarray, 
    models, 
    threshold=0.05, 
):
    """Run every model of the ensemble on one clip and threshold the mean scores.

    Args:
        test_df: One row per segment, with ``row_id`` and ``seconds`` columns.
        clip: 1-D audio signal of the whole clip.
        models: Iterable of ``(config, [model, ...])`` pairs.
        threshold: Minimum ensemble-averaged score for a class to be kept.

    Returns:
        dict mapping each row_id to ``{label: mean_score}`` for the classes
        at or above ``threshold`` (an empty dict when there are none).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # row_id -> list of per-model score vectors (one entry per class)
    scores_by_row = {}
    # All configs are expected to share one label list; the last one seen is
    # used to name the classes.
    target_columns = None

    for config, config_models in models:
        target_columns = config.target_columns

        dataset = TestDataset(
            df=test_df, 
            clip=clip,
            config=config,
        )

        loader = torchdata.DataLoader(
            dataset, 
            batch_size=12, 
            num_workers=8,
            drop_last=False,
            shuffle=False,
            pin_memory=True
        )

        for data in tqdm(loader):
            row_ids = [str(row_id) for row_id in data['row_id']]
            for row_id in row_ids:
                scores_by_row.setdefault(row_id, [])

            image = data['image'].to(device)

            for model in config_models:
                with torch.no_grad():
                    output = model(image)

                # Per-class maximum over the time axis, moved to the CPU once
                # per batch rather than once per row.
                batch_scores = (
                    output['framewise_output'].max(dim=1).values.detach().cpu().numpy()
                )
                for row_id, row_scores in zip(row_ids, batch_scores):
                    scores_by_row[row_id].append(row_scores)

    # ensemble
    prediction_dict = {}
    for row_id, model_scores in scores_by_row.items():
        if not model_scores:
            # No model produced a score for this row (config without models).
            prediction_dict[row_id] = {}
            continue

        mean_scores = np.array(model_scores).mean(0)
        labels = np.argwhere(mean_scores >= threshold).reshape(-1).tolist()
        prediction_dict[row_id] = {
            target_columns[label]: mean_scores[label] for label in labels
        }

    return prediction_dict

In [None]:
def crop_or_pad(y, length, sr, train=True, probs=None):
    """
    Crops or zero-pads an array so that it has exactly the chosen length
    Arguments:
        y {1D np array} -- Array to crop or pad
        length {int} -- Length of the result, in samples
        sr {int} -- Sampling rate
    Keyword Arguments:
        train {bool} -- Whether we are at train time. If so, crop randomly, else return the beginning of y (default: {True})
        probs {None or numpy array} -- Probabilities used to choose the second at which the crop starts (default: {None})
    Returns:
        1D np array -- float32 array of ``length`` samples
    """
    if len(y) <= length:
        # Too short (or exact): append zeros up to the requested length.
        y = np.concatenate([y, np.zeros(length - len(y))])
    else:
        last_start = len(y) - length
        if not train:
            start = 0
        elif probs is None:
            start = np.random.randint(last_start)
        else:
            start = (
                    np.random.choice(np.arange(len(probs)), p=probs) + np.random.random()
            )
            # Clamp so that a start drawn near the end of the signal cannot
            # produce a crop shorter than ``length``.
            start = min(int(sr * (start)), last_start)

        y = y[start: start + length]

    return y.astype(np.float32)

In [None]:
def prediction(
    test_audios,
    threshold=0.05, 
):
    """Predict every 5-second segment of each audio file.

    Each file is loaded at 32 kHz, cropped or zero-padded to 60 seconds and
    split into twelve segments whose row ids are ``<file name>_<end second>``.

    Args:
        test_audios: Iterable of ``pathlib.Path`` objects.
        threshold: Forwarded to ``prediction_for_clip``.

    Returns:
        dict mapping row_id to the per-row result of ``prediction_for_clip``,
        merged over all files.
    """
    warnings.filterwarnings("ignore")

    merged_predictions = {}

    for audio_path in test_audios:
        with timer(f"Loading {str(audio_path)}", logger):
            clip, _ = librosa.load(audio_path, sr=32000)

        clip = crop_or_pad(clip, 32000*60, 32000, False)

        # Segment end times: 5, 10, ..., 60.
        seconds = list(range(5, 65, 5))
        # File name without its last extension; remaining dots become "_".
        name_ = "_".join(audio_path.name.split(".")[:-1])
        segment_frame = pd.DataFrame({
            "row_id": [name_+f"_{second}" for second in seconds],
            "seconds": seconds
        })

        with timer(f"Prediction on {audio_path}", logger):
            clip_predictions = prediction_for_clip(
                segment_frame,
                clip=clip,
                models=models,
                threshold=threshold, 
            )

        merged_predictions.update(clip_predictions)

    return merged_predictions

In [None]:
# Minimum ensemble-averaged score for a class to be kept for a segment.
threshold = 0.015 # N/A

# Run the whole ensemble on every test soundscape.
prediction_dicts = prediction(
    test_audios=all_audios,
    threshold=threshold, 
)

In [None]:
import copy

def post_processing(prediction_dicts):
    """Smooth predictions with the neighbouring segments and format them as strings.

    For every row, the scores of the segments 5 seconds earlier and later
    (same row-id prefix) are merged into a copy of the row's own scores:

    * a bird already predicted for the row has the neighbour's score added;
    * a bird not predicted for the row takes the neighbour's score; when both
      neighbours predict it, the later neighbour's score replaces the earlier.

    The result for a row is the up-to-five best birds by merged score,
    followed by any of the row's own birds not already listed.

    Args:
        prediction_dicts: dict ``row_id -> {bird: score}``; row ids end in
            ``_<second>``. The input is not modified.

    Returns:
        dict ``row_id -> str`` with space-separated bird labels in a
        deterministic order, or ``"nocall"`` when nothing is predicted.
    """
    processed_prediction_dicts = copy.deepcopy(prediction_dicts)

    for row_id, own_scores in prediction_dicts.items():
        *prefix_parts, second_text = row_id.split("_")
        second = int(second_text)

        # Previous segment first, then the next one (the order matters for
        # the replace rule described above).
        for neighbour_second in (second - 5, second + 5):
            neighbour_id = "_".join(prefix_parts + [str(neighbour_second)])
            if neighbour_id not in prediction_dicts:
                continue

            for bird, score in prediction_dicts[neighbour_id].items():
                if bird in own_scores:
                    processed_prediction_dicts[row_id][bird] += score
                else:
                    processed_prediction_dicts[row_id][bird] = score

    reformat_prediction_dicts = {}

    for row_id, merged_scores in processed_prediction_dicts.items():
        if not merged_scores:
            reformat_prediction_dicts[row_id] = "nocall"
            continue

        ranked = sorted(
            merged_scores.items(), key=lambda pair: pair[1], reverse=True)[:5]
        voted_birds = [bird for bird, _ in ranked]
        orig_birds = list(prediction_dicts[row_id])

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # output string no longer depends on set iteration order.
        ordered_birds = dict.fromkeys(voted_birds + orig_birds)
        reformat_prediction_dicts[row_id] = " ".join(ordered_birds)

    return reformat_prediction_dicts
            
# row_id -> space-separated bird labels (or "nocall") for every segment.
post_processed_prediction_dicts = post_processing(prediction_dicts)

In [None]:
# Notebook display of the post-processed predictions.
post_processed_prediction_dicts

In [None]:
# Fill the second column of the submission template. Template row ids are
# split on "_" into four parts, of which the third is the bird label; the
# other three, re-joined, give the segment key used by the predictions.
for i, sample in enumerate(sample_submission.row_id):

    parts = sample.split("_")
    key = "_".join((parts[0], parts[1], parts[3]))
    target_bird = parts[2]

    if key in post_processed_prediction_dicts:
        # Compare against whole labels; a plain ``in`` on the joined string
        # would also match substrings of other labels.
        predicted_birds = post_processed_prediction_dicts[key].split()
        sample_submission.iat[i, 1] = (target_bird in predicted_birds)

sample_submission.to_csv("submission.csv", index=False)

In [None]:
# Notebook preview of the filled-in submission.
sample_submission.head()