## Libraries

In [None]:
import sys
import cv2
import audioread
import logging
import os
import random
import time
import warnings
import glob
import pdb
import json

import librosa as lb
import numpy as np
import pandas as pd
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset

from collections import Counter
from contextlib import contextmanager
from pathlib import Path
from typing import Optional

from tqdm import tqdm

# Put the local timm source tree on sys.path so `import timm` resolves from it,
# then choose the compute device for the rest of the notebook.
pytorch_timm_path = "../input/timm-pytorch-image-models/pytorch-image-models-master"
sys.path.append(pytorch_timm_path)
import timm

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(DEVICE)

# Mel-spectrogram layers (PyTorch, GPU)

In [None]:
class DFTBase(nn.Module):
    r"""Base class providing dense DFT / inverse-DFT matrices built with numpy."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def _twiddle_matrix(n, sign):
        """Return the (n, n) complex matrix with entry (i, j) = root ** (i * j),
        where root = exp(sign * 2 * pi * 1j / n)."""
        idx = np.arange(n)
        root = np.exp(sign * 2 * np.pi * 1j / n)
        return np.power(root, np.outer(idx, idx))  # shape: (n, n)

    def dft_matrix(self, n):
        """Forward DFT matrix, using the root of unity exp(-2*pi*1j/n)."""
        return self._twiddle_matrix(n, -1)

    def idft_matrix(self, n):
        """Unscaled inverse DFT matrix, using the root of unity exp(+2*pi*1j/n)."""
        return self._twiddle_matrix(n, 1)


class DFT(DFTBase):
    def __init__(self, n, norm):
        r"""Discrete Fourier transforms computed as dense matrix products.

        Provides the DFT, the inverse DFT (IDFT), the DFT of a real signal
        (RDFT, which returns only the n // 2 + 1 non-negative frequency bins)
        and its inverse (IRDFT).

        Args:
          n: fft window size
          norm: None | 'ortho'. With None, ``dft``/``rdft`` are unscaled and
            ``idft``/``irdft`` divide by n. With 'ortho', every transform
            divides by sqrt(n). Any other value applies no scaling.
        """
        super(DFT, self).__init__()

        self.W = self.dft_matrix(n)
        self.inv_W = self.idft_matrix(n)

        self.W_real = torch.Tensor(np.real(self.W))
        self.W_imag = torch.Tensor(np.imag(self.W))
        self.inv_W_real = torch.Tensor(np.real(self.inv_W))
        self.inv_W_imag = torch.Tensor(np.imag(self.inv_W))

        self.n = n
        self.norm = norm
        # sqrt(n) for the 'ortho' normalisation. The file never imports `math`,
        # so the previous `math.sqrt` calls raised NameError.
        self._sqrt_n = float(np.sqrt(n))

    def dft(self, x_real, x_imag):
        r"""Calculate DFT of a signal.

        Args:
            x_real: (..., n), real part of a signal
            x_imag: (..., n), imag part of a signal

        Returns:
            z_real: (..., n), real part of output
            z_imag: (..., n), imag part of output
        """
        z_real = torch.matmul(x_real, self.W_real) - torch.matmul(x_imag, self.W_imag)
        z_imag = torch.matmul(x_imag, self.W_real) + torch.matmul(x_real, self.W_imag)
        # shape: (..., n)

        if self.norm == 'ortho':
            z_real /= self._sqrt_n
            z_imag /= self._sqrt_n

        return z_real, z_imag

    def idft(self, x_real, x_imag):
        r"""Calculate IDFT of a signal.

        Args:
            x_real: (..., n), real part of a signal
            x_imag: (..., n), imag part of a signal
        Returns:
            z_real: (..., n), real part of output
            z_imag: (..., n), imag part of output
        """
        z_real = torch.matmul(x_real, self.inv_W_real) - torch.matmul(x_imag, self.inv_W_imag)
        z_imag = torch.matmul(x_imag, self.inv_W_real) + torch.matmul(x_real, self.inv_W_imag)
        # shape: (..., n)

        if self.norm is None:
            # Both parts need the 1/n factor; scaling only z_real returned an
            # imaginary part n times too large.
            z_real /= self.n
            z_imag /= self.n
        elif self.norm == 'ortho':
            z_real /= self._sqrt_n
            z_imag /= self._sqrt_n

        return z_real, z_imag

    def rdft(self, x_real):
        r"""Calculate the DFT of a real signal, keeping the non-negative frequencies.

        Args:
            x_real: (..., n), real signal

        Returns:
            z_real: (..., n // 2 + 1), real part of output
            z_imag: (..., n // 2 + 1), imag part of output
        """
        n_rfft = self.n // 2 + 1
        z_real = torch.matmul(x_real, self.W_real[..., 0 : n_rfft])
        z_imag = torch.matmul(x_real, self.W_imag[..., 0 : n_rfft])
        # shape: (..., n // 2 + 1)

        if self.norm == 'ortho':
            z_real /= self._sqrt_n
            z_imag /= self._sqrt_n

        return z_real, z_imag

    def irdft(self, x_real, x_imag):
        r"""Invert ``rdft``: rebuild the real signal from its non-negative bins.

        Args:
            x_real: (..., n // 2 + 1), real part of a spectrum
            x_imag: (..., n // 2 + 1), imag part of a spectrum

        Returns:
            z_real: (..., n), reconstructed real signal
        """
        n_rfft = self.n // 2 + 1

        flip_x_real = torch.flip(x_real, dims=(-1,))
        flip_x_imag = torch.flip(x_imag, dims=(-1,))
        # shape: (..., n // 2 + 1)

        # The missing bins are the complex conjugates of bins 1 .. n - n_rfft,
        # taken in reverse order. For even n, the last given bin is the
        # Nyquist bin, which has no mirror, so skip it. For odd n there is no
        # Nyquist bin and every bin except DC is mirrored.
        start = 1 if self.n % 2 == 0 else 0
        x_real = torch.cat((x_real, flip_x_real[..., start : n_rfft - 1]), dim=-1)
        x_imag = torch.cat((x_imag, -1. * flip_x_imag[..., start : n_rfft - 1]), dim=-1)
        # shape: (..., n)

        z_real = torch.matmul(x_real, self.inv_W_real) - torch.matmul(x_imag, self.inv_W_imag)
        # shape: (..., n)

        if self.norm is None:
            z_real /= self.n
        elif self.norm == 'ortho':
            z_real /= self._sqrt_n

        return z_real


class STFT(DFTBase):
    def __init__(self, n_fft=2048, hop_length=None, win_length=None,
        window='hann', center=True, pad_mode='reflect', freeze_parameters=True):
        r"""PyTorch implementation of STFT with Conv1d, intended to match
        librosa.stft.

        The real and imaginary parts come from two Conv1d layers. Their kernels
        are the windowed DFT basis, strided by ``hop_length``.

        Args:
            n_fft: int, fft window size, e.g., 2048
            hop_length: int, hop length samples, e.g., 441; defaults to
                win_length // 4
            win_length: int, window length e.g., 2048; defaults to n_fft
            window: str, window function name, e.g., 'hann'
            center: bool, pad n_fft // 2 samples on both sides of the input
                before framing
            pad_mode: str, 'constant' or 'reflect'
            freeze_parameters: bool, set to True to freeze all parameters. Set
                to False to finetune all parameters.
        """
        super(STFT, self).__init__()

        assert pad_mode in ['constant', 'reflect']

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.center = center
        self.pad_mode = pad_mode

        # By default, use the entire frame.
        if self.win_length is None:
            self.win_length = n_fft

        # Set the default hop, if it's not already specified.
        if self.hop_length is None:
            self.hop_length = int(self.win_length // 4)

        fft_window = lb.filters.get_window(window, self.win_length, fftbins=True)

        # Pad the window out to n_fft size. `size` is passed by keyword because
        # librosa >= 0.10 makes it keyword-only. Older versions accept it too.
        fft_window = lb.util.pad_center(fft_window, size=n_fft)

        # DFT & IDFT matrix.
        self.W = self.dft_matrix(n_fft)

        out_channels = n_fft // 2 + 1

        self.conv_real = nn.Conv1d(in_channels=1, out_channels=out_channels,
            kernel_size=n_fft, stride=self.hop_length, padding=0, dilation=1,
            groups=1, bias=False)

        self.conv_imag = nn.Conv1d(in_channels=1, out_channels=out_channels,
            kernel_size=n_fft, stride=self.hop_length, padding=0, dilation=1,
            groups=1, bias=False)

        # Initialize Conv1d weights with the windowed DFT basis.
        self.conv_real.weight.data = torch.Tensor(
            np.real(self.W[:, 0 : out_channels] * fft_window[:, None]).T)[:, None, :]
        # (n_fft // 2 + 1, 1, n_fft)

        self.conv_imag.weight.data = torch.Tensor(
            np.imag(self.W[:, 0 : out_channels] * fft_window[:, None]).T)[:, None, :]
        # (n_fft // 2 + 1, 1, n_fft)

        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input):
        r"""Calculate STFT of batch of signals.

        Args:
            input: (batch_size, data_length), input signals.

        Returns:
            real: (batch_size, 1, time_steps, n_fft // 2 + 1)
            imag: (batch_size, 1, time_steps, n_fft // 2 + 1)
        """

        x = input[:, None, :]   # (batch_size, channels_num, data_length)

        if self.center:
            x = F.pad(x, pad=(self.n_fft // 2, self.n_fft // 2), mode=self.pad_mode)

        real = self.conv_real(x)
        imag = self.conv_imag(x)
        # (batch_size, n_fft // 2 + 1, time_steps)

        real = real[:, None, :, :].transpose(2, 3)
        imag = imag[:, None, :, :].transpose(2, 3)
        # (batch_size, 1, time_steps, n_fft // 2 + 1)

        return real, imag


class LogmelFilterBank(nn.Module):
    def __init__(self, sr=22050, n_fft=2048, n_mels=64, fmin=0.0, fmax=None,
        is_log=True, ref=1.0, amin=1e-10, top_db=80.0, freeze_parameters=True):
        r"""Map a power spectrogram to a (log) mel spectrogram in PyTorch.

        The mel filter bank is built with ``librosa.filters.mel`` and applied
        as a matrix product. The dB conversion is a PyTorch port of
        ``librosa.power_to_db``.

        Args:
            sr, n_fft, n_mels, fmin, fmax: mel filter bank settings. ``fmax``
                defaults to ``sr // 2``.
            is_log: convert the mel power to dB when True.
            ref, amin, top_db: dB-conversion settings as in
                ``librosa.power_to_db``. ``top_db=None`` disables the
                dynamic-range clamp.
            freeze_parameters: make the filter bank non-trainable.
        """
        super(LogmelFilterBank, self).__init__()

        self.is_log = is_log
        self.ref = ref
        self.amin = amin
        self.top_db = top_db
        if fmax is None:
            fmax = sr // 2

        melW = lb.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels,
            fmin=fmin, fmax=fmax).T
        # (n_fft // 2 + 1, mel_bins)

        self.melW = nn.Parameter(torch.Tensor(melW))

        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input):
        r"""Calculate (log) mel spectrogram from spectrogram.

        Args:
            input: (*, n_fft // 2 + 1), power spectrogram

        Returns:
            output: (*, mel_bins), (log) mel spectrogram
        """

        # Mel spectrogram
        mel_spectrogram = torch.matmul(input, self.melW)
        # (*, mel_bins)

        # Logmel spectrogram
        if self.is_log:
            output = self.power_to_db(mel_spectrogram)
        else:
            output = mel_spectrogram

        return output

    def power_to_db(self, input):
        r"""Convert power to dB; a PyTorch port of ``librosa.power_to_db``.

        Unlike applying librosa per example, the ``top_db`` floor is measured
        from the maximum over the whole ``input`` tensor (all batch items
        together).

        Raises:
            librosa ParameterError: if ``top_db`` is negative.
        """
        ref_value = self.ref
        log_spec = 10.0 * torch.log10(torch.clamp(input, min=self.amin, max=np.inf))
        log_spec -= 10.0 * np.log10(np.maximum(self.amin, ref_value))

        if self.top_db is not None:
            if self.top_db < 0:
                raise lb.util.exceptions.ParameterError('top_db must be non-negative')
            log_spec = torch.clamp(log_spec, min=log_spec.max().item() - self.top_db, max=np.inf)

        return log_spec


class Spectrogram(nn.Module):
    def __init__(self, n_fft=2048, hop_length=None, win_length=None,
        window='hann', center=True, pad_mode='reflect', power=2.0,
        freeze_parameters=True):
        r"""Spectrogram ``|STFT(x)| ** power``, with the STFT computed by the
        Conv1d-based ``STFT`` module.

        Args:
            n_fft, hop_length, win_length, window, center, pad_mode: forwarded
                to ``STFT``.
            power: exponent applied to the magnitude (2.0 = power spectrogram).
            freeze_parameters: forwarded to ``STFT``. It was previously ignored
                and always forced to True.
        """
        super(Spectrogram, self).__init__()

        self.power = power

        self.stft = STFT(n_fft=n_fft, hop_length=hop_length,
            win_length=win_length, window=window, center=center,
            pad_mode=pad_mode, freeze_parameters=freeze_parameters)

    def forward(self, input):
        r"""Calculate spectrogram of input signals.
        Args:
            input: (batch_size, data_length)

        Returns:
            spectrogram: (batch_size, 1, time_steps, n_fft // 2 + 1)
        """

        (real, imag) = self.stft(input)
        # (batch_size, 1, time_steps, n_fft // 2 + 1)

        spectrogram = real ** 2 + imag ** 2

        # real**2 + imag**2 is already |X|**2; only re-scale for other powers.
        if self.power != 2.0:
            spectrogram = spectrogram ** (self.power / 2.0)

        return spectrogram


class DropStripes(nn.Module):
    """Zero random stripes of a 4-D spectrogram batch, in place.

    Args:
      dim: int, 2 to drop along time_steps, 3 to drop along freq_bins
      drop_width: int, stripe widths are drawn from [0, drop_width)
      stripes_num: int, number of stripes drawn for each sample
    """

    def __init__(self, dim, drop_width, stripes_num):
        super(DropStripes, self).__init__()

        assert dim in [2, 3]    # dim 2: time; dim 3: frequency

        self.dim = dim
        self.drop_width = drop_width
        self.stripes_num = stripes_num

    def transform_slice(self, e, total_width):
        """Zero ``stripes_num`` random stripes of a single sample ``e`` in place.

        e: (channels, time_steps, freq_bins), one batch item (batch axis removed)
        """
        for _ in range(self.stripes_num):
            width = torch.randint(low=0, high=self.drop_width, size=(1,))[0]
            start = torch.randint(low=0, high=total_width - width, size=(1,))[0]
            stop = start + width

            if self.dim == 2:
                e[:, start:stop, :] = 0
            elif self.dim == 3:
                e[:, :, start:stop] = 0

    def forward(self, input):
        """Drop stripes in every sample and return ``input`` itself (modified in place).

        input: (batch_size, channels, time_steps, freq_bins)
        """
        assert input.ndimension() == 4

        # There is deliberately no `self.training` check: stripes are dropped in
        # eval mode too. TimmSED.preprocess only calls the augmenter while training.
        total_width = input.shape[self.dim]
        for sample_idx in range(input.shape[0]):
            self.transform_slice(input[sample_idx], total_width)

        return input


class SpecAugmentation(nn.Module):
    """SpecAugment-style masking: random time stripes, then random frequency stripes.

    [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D.
    and Le, Q.V., 2019. Specaugment: A simple data augmentation method
    for automatic speech recognition. arXiv preprint arXiv:1904.08779.

    Args:
      time_drop_width: int, upper bound (exclusive) of a time-stripe width
      time_stripes_num: int, time stripes per sample
      freq_drop_width: int, upper bound (exclusive) of a frequency-stripe width
      freq_stripes_num: int, frequency stripes per sample
    """

    def __init__(self, time_drop_width, time_stripes_num, freq_drop_width,
        freq_stripes_num):
        super(SpecAugmentation, self).__init__()

        self.time_dropper = DropStripes(
            dim=2, drop_width=time_drop_width, stripes_num=time_stripes_num)
        self.freq_dropper = DropStripes(
            dim=3, drop_width=freq_drop_width, stripes_num=freq_stripes_num)

    def forward(self, input):
        """input: (batch_size, channels, time_steps, freq_bins); modified in place and returned."""
        return self.freq_dropper(self.time_dropper(input))

## Dataset and Loading

In [None]:
# Scoring run if the hidden test soundscapes are present; otherwise fall back
# to the public train soundscapes.
TEST = any(Path("../input/birdclef-2021/test_soundscapes/").glob("*.ogg"))
if not TEST:
    DATADIR = Path("../input/birdclef-2021/train_soundscapes/")
    SAMPLE_SUB_PATH = None
else:
    DATADIR = Path("../input/birdclef-2021/test_soundscapes/")
    SAMPLE_SUB_PATH = "../input/birdclef-2021/sample_submission.csv"

In [None]:
# One row per soundscape file; the file stem is split on "_" into id/site/date.
data = pd.DataFrame(
    [(p.stem, *p.stem.split("_"), p) for p in DATADIR.glob("*.ogg")],
    columns=["filename", "id", "site", "date", "filepath"],
)

# Recording-site coordinates keyed by site code.
map_Latitude = {'COL': 5.57, 'COR': 10.12, 'SNE': 38.49, 'SSW': 42.47}
map_Longitude = {'COL': -75.85, 'COR': -84.51, 'SNE': -119.95, 'SSW': -76.45}

# Characters 4:6 of the date field are the month (presumably YYYYMMDD),
# stored 0-based.
data = data.assign(
    month=data['date'].apply(lambda d: int(d[4:6]) - 1),
    latitude=data['site'].map(map_Latitude),
    longitude=data['site'].map(map_Longitude),
)
print(data.shape)

df_train = pd.read_csv("../input/birdclef-2021/train_metadata.csv")

# Sorted unique primary labels (presumably the model's class order).
TARG_NAMES = np.unique(df_train['primary_label'].values).tolist()
print(f"targ length: {len(TARG_NAMES)}")

In [None]:
data.iloc[-3:]  # peek at the last three rows of the soundscape table

# Dataset

In [None]:
def mono_to_color(X, eps=1e-6, mean=None, std=None):
    """Standardise ``X`` and then min-max scale it to [0, 1].

    Works on torch tensors and numpy arrays. The result has the same type and
    shape as ``X``.

    Args:
        X: tensor or array of any shape. Statistics are taken over all of its
            elements, so a batch shares a single scaling.
        eps: guards the division by ``std`` and is the flatness threshold.
        mean, std: optional precomputed statistics (scalars or anything that
            broadcasts against ``X``). They are computed from ``X`` when None.

    Returns:
        ``X`` rescaled to [0, 1], or all zeros if its range after
        standardisation is at most ``eps``.
    """
    # `is None` rather than `or`: an explicit 0 must be honoured, and a
    # multi-element tensor has no truth value (`or` raised on it).
    if mean is None:
        mean = X.mean()
    if std is None:
        std = X.std()
    X = (X - mean) / (std + eps)

    _min, _max = X.min(), X.max()

    if (_max - _min) > eps:
        # Every value already lies in [_min, _max], so no clamp is needed.
        V = (X - _min) / (_max - _min)
    else:
        V = torch.zeros_like(X) if isinstance(X, torch.Tensor) else np.zeros_like(X)

    return V


class MelSpecComputer:
    """Callable that turns a waveform into a float32 log-power mel spectrogram with librosa."""

    def __init__(self, sr, n_mels, fmin, fmax, n_fft, hop_length):
        self.sr = sr
        self.n_mels = n_mels
        self.fmin = fmin
        self.fmax = fmax
        self.n_fft = n_fft
        self.hop_length = hop_length

    def __call__(self, y):
        """Return ``power_to_db(melspectrogram(y))`` as float32.

        For mono input librosa returns shape (n_mels, frames).
        """
        # `y` is passed by keyword: librosa >= 0.10 makes it keyword-only.
        melspec = lb.feature.melspectrogram(
            y=y, sr=self.sr, n_mels=self.n_mels, fmin=self.fmin, fmax=self.fmax, n_fft=self.n_fft, hop_length=self.hop_length
        )

        melspec = lb.power_to_db(melspec).astype(np.float32)
        return melspec


class TestDataset(Dataset):
    def __init__(self, data, sr=32000, n_mels=128, fmin=0, fmax=None, n_fft=2048, hop_length=512, duration=5, step=None, res_type="kaiser_fast", resample=True):
        """Dataset yielding every fixed-length audio window of each soundscape.

        Args:
            data: DataFrame with at least ``filepath``, ``month``, ``latitude``
                and ``longitude`` columns. Rows are fetched with
                ``.loc[idx]``, so the index is assumed to be 0..len-1.
            sr: target sample rate. Files at another rate are resampled when
                ``resample`` is True.
            n_mels, fmin, fmax, n_fft, hop_length: mel settings used by
                ``audio_to_image`` (``fmax`` defaults to ``sr // 2``).
            duration: window length in seconds.
            step: hop between windows in samples. It defaults to one window,
                i.e. non-overlapping windows.
            res_type: resampling method forwarded to ``librosa.resample``.
            resample: whether to resample files whose rate differs from ``sr``.
        """
        self.data = data

        self.sr = sr
        self.n_mels = n_mels
        self.fmin = fmin
        self.fmax = fmax or self.sr//2
        self.n_fft = n_fft
        self.hop_length = hop_length

        self.duration = duration
        self.audio_length = self.duration*self.sr
        self.step = step or self.audio_length

        self.res_type = res_type
        self.resample = resample
        self.mel_spec_computer = MelSpecComputer(sr=self.sr, n_mels=self.n_mels, fmin=self.fmin, fmax=self.fmax, n_fft=self.n_fft, hop_length=self.hop_length)

    def __len__(self):
        return len(self.data)

    @staticmethod
    def normalize(image):
        """Cast to float32, divide by 255 and repeat to 3 channels: (H, W) -> (3, H, W)."""
        image = image.astype("float32", copy=False) / 255.0
        image = np.stack([image, image, image])
        return image

    def audio_to_image(self, audio):
        """Convert one audio window into a 3-channel mel-spectrogram image."""
        # NOTE(review): mono_to_color already scales to [0, 1], so the /255 in
        # `normalize` leaves values in [0, 1/255]. This looks like a leftover
        # from a uint8 pipeline. __getitem__ does not use this path; confirm
        # before relying on it.
        melspec = self.mel_spec_computer(audio)
        image = mono_to_color(melspec)
        image = self.normalize(image)
        return image

    def read_file(self, filepath):
        """Load ``filepath`` and cut it into full-length windows.

        Returns an array with one row per window of ``audio_length`` samples.
        A trailing window shorter than ``audio_length`` is discarded. A file
        shorter than one window yields zero rows instead of raising.
        """
        audio, orig_sr = sf.read(filepath, dtype="float32")

        if self.resample and orig_sr != self.sr:
            # Rates by keyword: librosa >= 0.10 makes them keyword-only.
            audio = lb.resample(audio, orig_sr=orig_sr, target_sr=self.sr, res_type=self.res_type)

        return self._split_windows(audio)

    def _split_windows(self, audio):
        """Slice ``audio`` into ``audio_length``-sample windows every ``step`` samples."""
        audios = []
        for end in range(self.audio_length, len(audio) + self.step, self.step):
            start = max(0, end - self.audio_length)
            audios.append(audio[start:start + self.audio_length])

        # Only the last window can overrun the end of the audio and be short.
        if audios and len(audios[-1]) < self.audio_length:
            audios = audios[:-1]

        if not audios:
            # Empty input or audio shorter than one window: np.stack([]) would raise.
            return np.empty((0, self.audio_length) + audio.shape[1:], dtype=audio.dtype)
        return np.stack(audios)

    def __getitem__(self, idx):
        """Return ``(windows, meta)`` for row ``idx``.

        windows: raw audio windows from ``read_file`` (not spectrograms).
        meta: float32 tensor [cos(month), sin(month), latitude, longitude].
        The month terms stay 0 when ``month`` is negative.
        """
        sample = self.data.loc[idx]
        images = self.read_file(sample["filepath"])
        ### meta feat
        meta = np.zeros(4)
        if sample.month >= 0:
            meta[0] = month2cosine(sample.month)
            meta[1] = month2sin(sample.month)
        meta[2] = float(sample.latitude)
        meta[3] = float(sample.longitude)
        meta = torch.tensor(meta).float()
        ###
        return images, meta

def month2cosine(month):
    """Cosine of ``month`` placed on a circle with a period of 12 months."""
    return np.cos(2 * np.pi * month / 12)

def month2sin(month):
    """Sine of ``month`` placed on a circle with a period of 12 months."""
    return np.sin(2 * np.pi * month / 12)

# Define Model

In [None]:
def init_layer(layer):
    """Xavier-uniform initialise ``layer.weight``; zero ``layer.bias`` when it exists."""
    nn.init.xavier_uniform_(layer.weight)

    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.)


def init_bn(bn):
    """Reset a batch-norm layer to the identity affine transform (bias 0, weight 1)."""
    for param, value in ((bn.bias, 0.), (bn.weight, 1.0)):
        param.data.fill_(value)


def init_weights(model):
    """Initialise one module in place, choosing the scheme from its class name.

    - Conv2d*: Xavier-uniform weight (gain sqrt(2)), zero bias.
    - BatchNorm*: weight ~ N(1, 0.02), zero bias.
    - GRU*: orthogonal init of every parameter with more than one dimension.
    - Linear*: weight ~ N(0, 0.01), zero bias.
    Other modules are left untouched. Missing (None) weights or biases are
    skipped instead of raising.
    """
    classname = model.__class__.__name__
    if "Conv2d" in classname:
        nn.init.xavier_uniform_(model.weight, gain=np.sqrt(2))
        if model.bias is not None:
            model.bias.data.fill_(0)
    elif "BatchNorm" in classname:
        # affine=False batch norms have no weight/bias.
        if model.weight is not None:
            model.weight.data.normal_(1.0, 0.02)
        if model.bias is not None:
            model.bias.data.fill_(0)
    elif "GRU" in classname:
        for weight in model.parameters():
            if weight.dim() > 1:
                # Was misspelled `orghogonal_`, which raised AttributeError.
                nn.init.orthogonal_(weight.data)
    elif "Linear" in classname:
        model.weight.data.normal_(0, 0.01)
        if model.bias is not None:
            model.bias.data.zero_()


def interpolate(x: torch.Tensor, ratio: int):
    """Upsample along time by repeating each time step ``ratio`` times.

    Used to compensate for the temporal downsampling of a CNN.

    Args:
      x: (batch_size, time_steps, classes_num)
      ratio: int, number of consecutive copies of each time step
    Returns:
      upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    _batch, _steps, _classes = x.shape  # also rejects non 3-D input
    return torch.repeat_interleave(x, ratio, dim=1)


def pad_framewise_output(framewise_output: torch.Tensor, frames_num: int):
    """Resize ``framewise_output`` along time to exactly ``frames_num`` frames.

    Despite the name, this does not pad with the last frame. It resamples the
    time axis with bilinear interpolation (``align_corners=True``), so the
    first and last frames are kept and the frames in between are blended. The
    class axis is unchanged.

    Args:
      framewise_output: (batch_size, time_steps, classes_num)
      frames_num: int, target number of frames
    Outputs:
      output: (batch_size, frames_num, classes_num)
    """
    as_image = framewise_output.unsqueeze(1)  # channel axis for 2-D interpolation
    target_size = (frames_num, framewise_output.size(2))
    resized = F.interpolate(as_image, size=target_size, mode="bilinear", align_corners=True)
    return resized[:, 0]


class DownPool(nn.Module):
    """1-channel -> 3-channel stem: an average-pooled copy of the input
    concatenated with a 2-channel strided conv (conv -> batch norm -> ReLU)."""

    def __init__(self, pool_stride, conv_stride):
        super().__init__()
        self.avgpool = nn.AvgPool2d(kernel_size=3, stride=pool_stride, padding=1)
        conv = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=5,
                         stride=conv_stride, padding=2, bias=False)
        self.downconv = nn.Sequential(conv, nn.BatchNorm2d(2), nn.ReLU())

    def forward(self, x):
        pooled = self.avgpool(x)
        learned = self.downconv(x)
        return torch.cat([pooled, learned], dim=1)


class AttBlockV2(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation="linear"):
        """Attention pooling over the time axis.

        Two 1x1 Conv1d layers map (n_samples, in_features, n_time) to
        attention scores (``att``) and class scores (``cla``), each of shape
        (n_samples, out_features, n_time).

        Args:
            in_features: number of input channels.
            out_features: number of classes.
            activation: 'linear' or 'sigmoid', applied to the class scores.
        """
        super().__init__()

        self.activation = activation
        self.att = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)
        self.cla = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)

        self.init_weights()

    def init_weights(self):
        init_layer(self.att)
        init_layer(self.cla)

    def forward(self, x):
        """Pool ``x`` (n_samples, n_in, n_time) over time.

        Returns:
            (pooled, norm_att, cla):
            - pooled (n_samples, out_features) is the attention-weighted sum over time;
            - norm_att is the softmax over time of tanh(att(x));
            - cla holds the activated class scores.
            norm_att and cla have shape (n_samples, out_features, n_time).
        """
        norm_att = torch.softmax(torch.tanh(self.att(x)), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        """Apply ``self.activation`` to ``x``.

        Raises:
            ValueError: if the activation is not 'linear' or 'sigmoid'.
        """
        if self.activation == 'linear':
            return x
        elif self.activation == 'sigmoid':
            return torch.sigmoid(x)
        # Fail loudly instead of implicitly returning None, which only
        # surfaced later as an opaque TypeError inside forward().
        raise ValueError(f"Unsupported activation: {self.activation!r}")


class TimmSED(nn.Module):
    """Sound-event-detection model: log-mel front end, timm CNN encoder,
    attention pooling, and an MLP that mixes in 4 metadata features."""

    def __init__(self, base_model_name: str, pretrained=False,
                 num_classes=397, in_channels=1, norm_free=False, mix_up=False, downpool=False,
                 n_fft=2048, hop_length=800, fmin=20, fmax=16000, n_mels=128, pool_stride=(2,2), conv_stride=(2,2),
                 sr=32000):
        """
        Args:
            base_model_name: timm model name used as the encoder.
            pretrained: forwarded to ``timm.create_model``.
            num_classes: number of output classes.
            in_channels: stored as ``num_channel``. When 3, the single log-mel
                channel is repeated three times before the encoder.
            norm_free: set when the timm head exposes its linear layer as ``.fc``.
            mix_up: stored on the instance; not used by this class.
            downpool: prepend a ``DownPool`` stem. The encoder is then built
                for 3 input channels.
            n_fft, hop_length, fmin, fmax, n_mels: spectrogram settings.
            pool_stride, conv_stride: strides of the ``DownPool`` stem.
            sr: sample rate assumed by the mel filter bank (default 32000,
                previously hard-coded).
        """
        super(TimmSED, self).__init__()
        self.num_channel = in_channels
        self.mix_up = mix_up
        self.downpool = None
        if downpool:
            self.downpool = DownPool(pool_stride, conv_stride)
            in_channels = 3
        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(n_fft=n_fft, hop_length=hop_length,
                                                 win_length=n_fft, window="hann", center=True, pad_mode="reflect",
                                                 freeze_parameters=True)

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(sr=sr, n_fft=n_fft,
                                                 n_mels=n_mels,
                                                 fmin=fmin, fmax=fmax,
                                                 ref=1.0, amin=1e-10, top_db=None,
                                                 freeze_parameters=True)

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
                                               freq_drop_width=8, freq_stripes_num=2)

        base_model = timm.create_model(base_model_name, pretrained=pretrained, in_chans=in_channels)
        if 'rexnet' in base_model_name:
            layers = list(base_model.children())[:-1]
            fc = list(base_model.children())[-1].fc
        else:
            layers = list(base_model.children())[:-2]
            fc = list(base_model.children())[-1]
        self.encoder = nn.Sequential(*layers)
        if norm_free:
            fc = fc.fc
        in_features = fc.in_features
        self.fc1 = nn.Linear(in_features, in_features, bias=True)
        self.att_block = AttBlockV2(in_features, num_classes, activation="sigmoid")

        # Mixes the attention-pooled clip logit with the 4 metadata features.
        self.MLP = nn.Sequential(
            nn.Linear(num_classes+4, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, num_classes)
        )

        self.init_weight()

    def init_weight(self):
        init_layer(self.fc1)

    def preprocess(self, input_x, mixup_lambda=None):
        """Waveform -> normalised log-mel image.

        Args:
            input_x: (batch_size, audio_len) waveforms.
            mixup_lambda: accepted for interface compatibility; unused.

        Returns:
            (x, frames_num): x has shape (batch_size, 1, time_steps, mel_bins)
            and is scaled to [0, 1]; frames_num is time_steps.
        """
        x = self.spectrogram_extractor(input_x)  # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

        frames_num = x.shape[2]
        # NOTE(review): mono_to_color normalises with statistics of the whole
        # batch tensor, so one sample's scaling depends on the rest of the batch.
        x = mono_to_color(x)

        if self.training:
            x = self.spec_augmenter(x)

        return x, frames_num

    def forward(self, x, meta):
        """Score a batch of waveforms.

        Args:
            x: (batch_size, audio_len) waveforms.
            meta: (batch_size, 4) features concatenated to the pooled logit
                before ``self.MLP``.

        Returns:
            dict with "logit" (batch_size, num_classes) and "framewise_logit"
            (batch_size, frames_num, num_classes).
        """
        x, frames_num = self.preprocess(x)

        x = x.transpose(2, 3)  # (batch_size, 1, mel_bins, time_steps)
        if self.num_channel == 3:
            x = torch.cat([x, x, x], dim=1)
        if self.downpool is not None:
            x = self.downpool(x)

        x = self.encoder(x)  # (batch_size, channels, freq, frames)
        x = torch.mean(x, dim=2)  # (batch_size, channels, frames)

        # channel smoothing
        x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        x = x1 + x2  # (batch_size, channels, frames)

        x = F.dropout(x, p=0.5, training=self.training)
        x = x.transpose(1, 2)  # (batch_size, frames, channels)
        x = F.relu_(self.fc1(x))
        x = x.transpose(1, 2)  # (batch_size, channels, frames)
        x = F.dropout(x, p=0.5, training=self.training)

        # Only the attention weights are needed from the block. The raw
        # (pre-sigmoid) class scores are computed once and reused below instead
        # of calling att_block.cla(x) twice more.
        _, norm_att, _ = self.att_block(x)
        cla_logit = self.att_block.cla(x)  # (batch_size, num_classes, frames)

        # Attention-pooled clip logit, then mix in region/month metadata.
        xl = torch.sum(norm_att * cla_logit, dim=2)
        xl = torch.cat([xl, meta], 1)
        logit = self.MLP(xl)

        segmentwise_logit = cla_logit.transpose(1, 2)  # (batch_size, frames, num_classes)
        interpolate_ratio = frames_num // segmentwise_logit.size(1)

        framewise_logit = interpolate(segmentwise_logit, interpolate_ratio)
        framewise_logit = pad_framewise_output(framewise_logit, frames_num)

        output_dict = {
            "logit": logit,
            "framewise_logit": framewise_logit,
        }

        return output_dict


class BirdCallModel(nn.Module):
    """Clip classifier: timm backbone -> global average pool -> linear head."""

    def __init__(self, base_model_name: str, pretrained=False,
                 num_classes=397, in_channels=3):
        super().__init__()

        backbone = timm.create_model(base_model_name, pretrained=pretrained, in_chans=in_channels)
        children = list(backbone.children())
        # Drop the backbone's last two children (presumably its own pooling and
        # classifier — TODO confirm per model); the last one supplies in_features.
        self.encoder = nn.Sequential(*children[:-2])
        self.in_features = children[-1].in_features

        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(self.in_features, num_classes, bias=True)

        self.init_weight()

    def init_weight(self):
        init_layer(self.classifier)

    def forward(self, x):
        """Return {'logit': (batch_size, num_classes)} for an image batch ``x``."""
        batch = x.shape[0]
        pooled = self.pooling(self.encoder(x))  # (batch_size, in_features, 1, 1)
        return {'logit': self.classifier(pooled.view(batch, self.in_features))}


# Load Models

In [None]:
def prepare_model_for_inference(model, path: Path):
    """Load checkpoint weights from ``path`` into ``model`` and put it in eval mode.

    Two checkpoint layouts are accepted:
    - a dict that stores the weights under a ``"model"`` key;
    - a bare state dict.

    Tensors are mapped to CPU while loading. ``load_state_dict`` then copies them
    into the model's existing parameters, so a model already moved to a device
    stays there.

    Returns:
        The same ``model`` instance, in eval mode.

    Raises:
        RuntimeError: if the weights do not match the model (missing, unexpected or
            mis-shaped keys). The original error is surfaced instead of being
            hidden behind a fallback load.
    """
    ckpt = torch.load(path, map_location="cpu")
    if isinstance(ckpt, dict) and "model" in ckpt:
        state_dict = ckpt["model"]
    else:
        state_dict = ckpt
    model.load_state_dict(state_dict)
    model.eval()
    return model

**10 seconds models**

In [None]:
# Checkpoints for the 10-second-window ensemble. The loader loop that follows reads
# each tuple positionally as:
#   (kind, timm backbone name, checkpoint path, <index 3: not read by the loop>,
#    in_channels, n_fft, hop_length, norm_free, downpool, n_mels,
#    pool_stride, conv_stride)
# kind == "SED" builds a TimmSED; any other kind builds a BirdCallModel.
weights_10s_path = [
    ("SED", "tf_efficientnet_b4_ns", Path("../input/birdclef-2021-models-qishen/b4_10s_cyclic_v1_bs64_lr2e3_50epo_best_loss_fold3.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),
    ("SED", "tf_efficientnet_b4_ns", Path("../input/birdclef-2021-models-qishen/b4_10s_cyclic_v1_bs64_lr2e3_50epo_best_loss_fold4.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),
    
    ("SED", "tf_efficientnet_b5_ns", Path("../input/birdclef-2021-models-qishen/b5_10s_cyclic_v1_bs64_lr2e3_50epo_best_fold0.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),
    ("SED", "tf_efficientnet_b5_ns", Path("../input/birdclef-2021-models-qishen/b5_10s_cyclic_v1_bs64_lr2e3_50epo_best_fold4.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),

    ("SED", "rexnet_150", Path("../input/d/underwearfitting/bird-clef-10s-models/rexnet_150_10s_best_fold4.pthbest_loss"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),
    # ("SED", "rexnet_150", Path("../input/d/underwearfitting/bird-clef-10s-models/rexnet_150_10s_best_fold1.pthbest_loss"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),

    ("SED", "rexnet_200", Path("../input/d/underwearfitting/bird-clef-10s-models/rexnet_200_10s_best_fold1.pthbest_loss"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),
    # ("SED", "rexnet_200", Path("../input/d/underwearfitting/bird-clef-10s-models/rexnet_200_10s_best_fold2.pthbest_loss"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),

    ("SED", "nf_resnet50", Path("../input/lee-nf50-fold0-10s/nf50_128_cyclic_10s_lr2e3_50epo_bestscore_fold0.pth/nf50_128_cyclic_10s_lr2e3_50epo_bestscore_fold0.pth"), 0.3, 1, 2048, 512, True, False,128,(1,2), (1,2)),
    ("SED", "nf_resnet50", Path("../input/nf50-10s-cyclic/nf50_128_cyclic_10s_bestscore_fold4.pth"), 0.3, 1, 2048, 512, True, False,128,(1,2), (1,2)),

    ("SED", "eca_nfnet_l0", Path("../input/eca-nfnet-f34-10s/_eca_nfnet_l0-10s_50epo_bestloss_fold3.pth"), 0.3, 1, 2048, 512, True, False,128,(1,2), (1,2)),
    ("SED", "eca_nfnet_l0", Path("../input/eca-nfnet-f34-10s/_eca_nfnet_l0-10s_50epo_bestloss_fold4.pth"), 0.3, 1, 2048, 512, True, False,128,(1,2), (1,2)),

    ("SED", "eca_nfnet_l1", Path("../input/bird-eca-l1-10s/ecal1_128_cyclic_bestloss_fold1.pth"), 0.3, 1, 2048, 512, True, False,128,(1,2), (1,2)),
]

# Instantiate every 10 s checkpoint on DEVICE, load its weights and switch it to eval.
models_10s = []
for spec in tqdm(weights_10s_path):
    kind, backbone, ckpt_path = spec[0], spec[1], spec[2]
    if kind != "SED":
        net = BirdCallModel(backbone, num_classes=397, in_channels=3)
    else:
        net = TimmSED(
            backbone,
            num_classes=397,
            in_channels=spec[4],
            n_fft=spec[5],
            hop_length=spec[6],
            fmin=50,
            fmax=16000,
            n_mels=spec[9],
            norm_free=spec[7],
            downpool=spec[8],
            pool_stride=spec[10],
            conv_stride=spec[11],
        )
    models_10s.append(prepare_model_for_inference(net.to(DEVICE), ckpt_path))

len(models_10s)

**15 seconds models**

In [None]:
# Checkpoints for the 15-second-window ensemble. The loader loop that follows reads
# each tuple positionally as:
#   (kind, timm backbone name, checkpoint path, <index 3: not read by the loop>,
#    in_channels, n_fft, hop_length, norm_free, downpool, n_mels,
#    pool_stride, conv_stride)
# kind == "SED" builds a TimmSED; any other kind builds a BirdCallModel.
weights_15s_path = [
    ("SED", "tf_efficientnet_b4_ns", Path("../input/bird-clef-128-v12-fold0-b4ns/b4_cyclic_best_fold3.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),  

    ("SED", "nf_resnet50", Path("../input/birdclefnf50meta/nf50_128_cyclic_v1_bs64_lr2e3_50epo_bestscore_fold0.pth"), 0.3, 1, 2048, 512, True, False,128,(1,2), (1,2)),
    ("SED", "nf_resnet50", Path("../input/birdclefnf50meta/nf50_128_cyclic_v1_bs64_lr2e3_50epo_bestscore_fold4.pth"), 0.3, 1, 2048, 512, True, False,128,(1,2), (1,2)),
    ("SED", "nf_resnet50", Path("../input/birdclefnf50meta/nf50_128_cyclic_v1_bs64_lr2e3_50epo_bestscore_fold2.pth"), 0.3, 1, 2048, 512, True, False,128,(1,2), (1,2)),

    ("SED", "dm_nfnet_f0", Path("../input/dimf0-temp-birdelf/nff0_128_cyclic_v1_bs64_lr2e3_50epo_bestscore_fold3.pth"), 0.3, 1, 2048, 512, True, False,128,(1,2), (1,2)),

    ("SED", "rexnet_200", Path("../input/bird-clef-rexnet200-fold2/rexnet_200_cyclic_v1_bs64_lr2e3_50epo_best_fold2.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),
    ("SED", "rexnet_200", Path("../input/birdclef-2021-models-qishen/rexnet_200_cyclic_v1_bs62_lr2e3_50epo_best_fold0.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),
    # ("SED", "rexnet_200", Path("../input/birdclef-2021-models-qishen/rexnet_200_cyclic_v1_bs62_lr2e3_50epo_best_fold4.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),

    ("SED", "efficientnetv2_rw_s", Path("../input/effnetv2-birdelf-temp/effnetv2_best_fold0.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),

    ("SED", "eca_nfnet_l0", Path("../input/bird-eca-nfnet-l0-cyclic-fold0/nf50_128_cyclic_v1_bs64_lr2e3_50epo_bestloss_fold0.pth"), 0.3, 1, 2048, 512, True, False,128,(1,2), (1,2)),

    ("SED", "eca_nfnet_l1", Path("../input/birdclef2021-eca-nfnet-l1/eca_nfnet_l1_128_cyclic_v1_bs64_lr2e3_50epo_bestloss_fold0.pth"), 0.3, 1, 2048, 512, True, False,128,(1,2), (1,2)),

    ("SED", "rexnet_150", Path("../input/rex150-birdelf/rexnet_150_cyclic_v1_bs64_lr2e3_50epo_best_fold2.pthbest_sc"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),
    ("SED", "rexnet_150", Path("../input/rex150-birdelf/rexnet_150_cyclic_v1_bs64_lr2e3_50epo_best_fold4.pthbest_sc"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),

    ("SED", "rexnet_200", Path("../input/bird-25d-rex200-40epochs-fold0/rex200_cyclic_v1_sample4bs16_lr2e3_40epo_best_fold0.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),
]

# Instantiate every 15 s checkpoint on DEVICE, load its weights and switch it to eval.
models_15s = []
for spec in tqdm(weights_15s_path):
    kind, backbone, ckpt_path = spec[0], spec[1], spec[2]
    if kind != "SED":
        net = BirdCallModel(backbone, num_classes=397, in_channels=3)
    else:
        net = TimmSED(
            backbone,
            num_classes=397,
            in_channels=spec[4],
            n_fft=spec[5],
            hop_length=spec[6],
            fmin=50,
            fmax=16000,
            n_mels=spec[9],
            norm_free=spec[7],
            downpool=spec[8],
            pool_stride=spec[10],
            conv_stride=spec[11],
        )
    models_15s.append(prepare_model_for_inference(net.to(DEVICE), ckpt_path))

len(models_15s)

**20 seconds models**

In [None]:
# Checkpoints for the 20-second-window ensemble. The loader loop that follows reads
# each tuple positionally as:
#   (kind, timm backbone name, checkpoint path, <index 3: not read by the loop>,
#    in_channels, n_fft, hop_length, norm_free, downpool, n_mels,
#    pool_stride, conv_stride)
# kind == "SED" builds a TimmSED; any other kind builds a BirdCallModel.
# NOTE(review): trailing "# 0.xxx @0.yy" notes look like validation scores at a
# probability threshold -- TODO confirm with the author.
weights_20s_path = [
    ("SED", "rexnet_200", Path("../input/birdclef-2021-models-qishen/rexnet_200_20s_cyclic_v1_bs56_lr2e3_50epo_best_loss_fold0.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),
    ("SED", "rexnet_200", Path("../input/birdclef-2021-models-qishen/rexnet_200_20s_cyclic_v1_bs56_lr2e3_50epo_best_loss_fold1.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),
    
    ("SED", "tf_efficientnet_b5_ns", Path("../input/d/underwearfitting/birdclef-model-20s/b5ns_20s_best_fold3.pthbest_loss"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),
    ("SED", "tf_efficientnet_b5_ns", Path("../input/birdclef2021-eca-nfnet-l1/tf_efficientnet_b5_ns_20s_128_cyclic_v1_bs64_lr2e3_bestloss_fold3.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)), 
    
    ("SED", "tf_efficientnet_b4_ns", Path("../input/d/underwearfitting/birdclef-model-20s/b4ns_20s_best_fold3.pthbest_loss"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)), # 0.781 @0.42
    ("SED", "tf_efficientnet_b4_ns", Path("../input/birdclef2021-eca-nfnet-l1/tf_efficientnet_b4_20s_ns_128_cyclic_v1_bs64_lr2e3_bestscore_fold4.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)), 
    
    ("SED", "tf_efficientnet_b3_ns", Path("../input/birdclef-2021-models-qishen/b3_20s_cyclic_v1_bs64_lr2e3_50epo_best_loss_fold2.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)), #0.759 @0.45
    ("SED", "tf_efficientnet_b3_ns", Path("../input/birdclef-2021-models-qishen/b3_20s_cyclic_v1_bs64_lr2e3_50epo_best_loss_fold4.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),
    
    ("SED", "nf_resnet50", Path("../input/birdclef2021-eca-nfnet-l1/nf50_128_cyclic_20s_lr2e3_50epo_bestloss_fold0.pth"), 0.3, 1, 2048, 512, True, False,128,(1,2), (1,2)), # 0.772 @0.42
    ("SED", "nf_resnet50", Path("../input/birdclefnf50meta/nf50_128_cyclic_20s_lr2e3_bestloss_fold2.pth"), 0.3, 1, 2048, 512, True, False,128,(1,2), (1,2)), 
    
    ("SED", "efficientnetv2_rw_s", Path("../input/birdenetv2-20s-fold0/enetv2_20s_best_fold0.pth"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)), # 0.772 @0.42
    # ("SED", "efficientnetv2_rw_s", Path("../input/birdenetv2-20s-fold0/enetv2_20s_best_fold0.pthbest_loss"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)), 
    
    ("SED", "rexnet_150", Path("../input/d/underwearfitting/birdclef-model-20s/rex150_20s_best_fold2.pthbest_loss"), 0.3, 1, 2048, 512, False, False,128,(1,2), (1,2)),
]

# Instantiate every 20 s checkpoint on DEVICE, load its weights and switch it to eval.
models_20s = []
for spec in tqdm(weights_20s_path):
    kind, backbone, ckpt_path = spec[0], spec[1], spec[2]
    if kind != "SED":
        net = BirdCallModel(backbone, num_classes=397, in_channels=3)
    else:
        net = TimmSED(
            backbone,
            num_classes=397,
            in_channels=spec[4],
            n_fft=spec[5],
            hop_length=spec[6],
            fmin=50,
            fmax=16000,
            n_mels=spec[9],
            norm_free=spec[7],
            downpool=spec[8],
            pool_stride=spec[10],
            conv_stride=spec[11],
        )
    models_20s.append(prepare_model_for_inference(net.to(DEVICE), ckpt_path))

len(models_20s)

In [None]:
! nvidia-smi

# Get Probabilities

In [None]:
test_data = TestDataset(data=data)
# Fetch the first sample once: every `test_data[0]` re-runs TestDataset.__getitem__,
# and the original cell indexed it twice just to read two shapes.
first_x, first_meta = test_data[0]
len(test_data), first_x.shape, first_meta.shape

In [None]:
little_bs = 32  # max rows per forward pass when running each model


def _ensemble_sigmoid(models, xs, meta_rows):
    """Return the mean over ``models`` of ``sigmoid(logit)`` for every row of ``xs``.

    Each model is called as ``m(xs_batch, meta_batch)`` on slices of at most
    ``little_bs`` rows and must return a dict with a ``'logit'`` entry. Per-batch
    outputs are concatenated back in row order before averaging across models.
    """
    per_model = [
        torch.cat([
            m(xs[b:b + little_bs], meta_rows[b:b + little_bs])['logit'].sigmoid()
            for b in range(0, xs.shape[0], little_bs)
        ])
        for m in models
    ]
    return torch.stack(per_model, 0).mean(0)


def get_prob_zero_pad(x, meta, models, pad_seconds, sample_rate=32000):
    """Average ensemble probabilities over three zero-padding placements.

    Every row of ``x`` is lengthened by ``pad_seconds`` of zeros in three ways:
    (a) half before and half after the row, (b) all after the row, and (c) all
    before the row. The ensemble is scored on each variant and the three
    probability tensors are averaged.

    Args:
        x: 2-D tensor, one chunk per row (presumably raw audio samples).
        meta: 1-D metadata tensor, repeated for every row of ``x``.
        models: sequence of models called as ``m(rows, meta_rows)``.
        pad_seconds: total seconds of zeros added to each row.
        sample_rate: samples per second, used to convert ``pad_seconds`` to a
            sample count. The default matches the value the notebook hard-coded.

    Returns:
        Tensor with one row of probabilities per row of ``x``.
    """
    n_rows = x.shape[0]
    # new_zeros keeps x's device and dtype, so no CPU allocation plus copy.
    half_pad = x.new_zeros(n_rows, int(sample_rate * pad_seconds / 2))
    full_pad = x.new_zeros(n_rows, int(sample_rate * pad_seconds))
    meta_rows = meta.unsqueeze(0).repeat(n_rows, 1)

    # (a) zeros split evenly around the chunk
    prob_center = _ensemble_sigmoid(models, torch.cat([half_pad, x, half_pad], 1), meta_rows)
    # (b) chunk first, zeros appended on the right
    prob_pad_right = _ensemble_sigmoid(models, torch.cat([x, full_pad], 1), meta_rows)
    # (c) zeros first, chunk on the right
    prob_pad_left = _ensemble_sigmoid(models, torch.cat([full_pad, x], 1), meta_rows)

    return (prob_center + prob_pad_right + prob_pad_left) / 3.


# Number of neighbouring chunks on each side that form one long window, per mode.
_AVG_POOL_CONTEXT = {'10s': 1, '15s': 2, '20s': 3}


def get_prob_avg_pool(x, meta, models, mode):
    """Score overlapping multi-chunk windows and average them back per chunk.

    Let ``k`` be 1, 2 or 3 for mode '10s', '15s' or '20s'. The steps are:
    1. Add ``k`` all-zero rows (same width as ``x``'s rows) at both ends of ``x``.
    2. Concatenate every run of ``k + 1`` consecutive rows into one long window.
    3. Score each window with the ensemble.
    4. Give each original row the mean probability of the ``k + 1`` windows
       that contain it.

    Assumes the rows of ``x`` are consecutive in time -- confirm against
    TestDataset.

    Returns:
        Tensor with one row of probabilities per row of ``x``.

    Raises:
        ValueError: if ``mode`` is not one of '10s', '15s', '20s'.
    """
    try:
        k = _AVG_POOL_CONTEXT[mode]
    except KeyError:
        raise ValueError(
            f"mode must be one of {sorted(_AVG_POOL_CONTEXT)}, got {mode!r}"
        ) from None

    pad = x.new_zeros(k, x.shape[1])
    padded = torch.cat([pad, x, pad], 0)
    n_windows = padded.shape[0] - k
    # Row j of `windows` is padded[j], padded[j+1], ..., padded[j+k] laid side by side.
    windows = torch.cat([padded[i:i + n_windows] for i in range(k + 1)], 1)

    meta_rows = meta.unsqueeze(0).repeat(windows.shape[0], 1)
    prob = _ensemble_sigmoid(models, windows, meta_rows)

    # (windows, classes) -> (classes, 1, windows): a length-(k+1) average over the
    # window axis leaves exactly one value per original row.
    pooled = F.avg_pool1d(prob.transpose(1, 0).unsqueeze(1), kernel_size=k + 1, stride=1)
    return pooled.squeeze(1).transpose(1, 0)

In [None]:
# Each clip-length ensemble's prediction is weighted by how many models it contains.
n_total = len(models_15s) + len(models_10s) + len(models_20s)

model_prob = []
with torch.no_grad():
    for idx in tqdm(list(range(len(test_data)))):
        x, meta = test_data[idx]
        x, meta = torch.tensor(x).to(DEVICE), meta.to(DEVICE)

        # Within each ensemble, blend the two test-time strategies 50/50:
        # zero-padded chunks, and neighbour windows average-pooled back per chunk.
        prob_10s = (get_prob_zero_pad(x, meta, models_10s, pad_seconds=5)
                    + get_prob_avg_pool(x, meta, models_10s, mode='10s')) / 2.
        prob_15s = (get_prob_zero_pad(x, meta, models_15s, pad_seconds=10)
                    + get_prob_avg_pool(x, meta, models_15s, mode='15s')) / 2.
        prob_20s = (get_prob_zero_pad(x, meta, models_20s, pad_seconds=15)
                    + get_prob_avg_pool(x, meta, models_20s, mode='20s')) / 2.

        prob_final = (prob_15s * len(models_15s)
                      + prob_10s * len(models_10s)
                      + prob_20s * len(models_20s)) / n_total
        model_prob.append(prob_final.cpu())

# Stack the per-sample probability rows of all test samples into one tensor.
model_prob = torch.cat(model_prob)

# Post Process Probabilities

In [None]:
def get_thresh_preds(probs, thresh=0.5):
    """Return, per row of ``probs``, the class indices whose probability exceeds ``thresh``.

    Indices in each row's list are ordered from most to least probable.
    """
    ranked = (-probs).argsort(dim=1)
    counts = (probs > thresh).sum(dim=1)
    return [order[:count].cpu().numpy().tolist() for order, count in zip(ranked, counts)]


def get_bird_names(preds, ebird_names):
    """Turn per-row lists of class indices into space-separated bird-code strings.

    A row with no predicted indices becomes ``"nocall"``.
    """
    return [
        " ".join([ebird_names[bird_id] for bird_id in pred]) if pred else "nocall"
        for pred in preds
    ]


def preds_as_df(data, preds):
    """Build the submission frame with columns ``row_id`` and ``birds``.

    For every row of ``data`` (which provides ``id`` and ``site``), 120 row ids
    ``{id}_{site}_{seconds}`` are generated for seconds 5, 10, ..., 600.
    ``preds[0]`` supplies the ``birds`` strings in the same order.

    When ``SAMPLE_SUB_PATH`` is truthy, the frame is left-merged onto that
    file's ``row_id`` column (so its rows and order win), and any unmatched
    row gets ``"nocall"``.
    """
    row_ids = [
        f"{row.id}_{row.site}_{5*i}"
        for row in data.itertuples()
        for i in range(1, 121)
    ]
    birds = []
    birds.extend(preds[0])
    sub = pd.DataFrame({"row_id": row_ids, "birds": birds})

    if not SAMPLE_SUB_PATH:
        return sub

    sample_sub = pd.read_csv(SAMPLE_SUB_PATH, usecols=["row_id"])
    merged = sample_sub.merge(sub, on="row_id", how="left")
    merged["birds"] = merged["birds"].fillna("nocall")
    return merged

def get_metrics(s_true, s_pred):
    """Compare two space-separated label strings as sets.

    Returns a dict with ``f1``, ``prec`` and ``rec``, plus the raw counts:
    ``n_true``, ``n_pred`` and ``n`` (size of the intersection).

    An empty prediction string gives precision 0.0, and an empty truth string
    gives recall 0.0. The original raised ZeroDivisionError in both cases.
    """
    true_set = set(s_true.split())
    pred_set = set(s_pred.split())
    n = len(true_set & pred_set)
    n_true, n_pred = len(true_set), len(pred_set)

    prec = n / n_pred if n_pred else 0.0
    rec = n / n_true if n_true else 0.0
    f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0

    return {"f1": f1, "prec": prec, "rec": rec, "n_true": n_true, "n_pred": n_pred, "n": n}


In [None]:
preds = get_thresh_preds(model_prob, 0.36)
named_pred = get_bird_names(preds, TARG_NAMES)
prediction_df = preds_as_df(data, [named_pred])
prediction_df.to_csv("submission.csv", index=False)

# Local validation: only possible when the notebook runs on the labelled train soundscapes.
if "train_soundscapes" in DATADIR.name:
    train_labels = pd.read_csv("../input/birdclef-2021/train_soundscape_labels.csv")
    # Row-id prefixes (everything before the final "_seconds") excluded from the second table.
    remove_sound = ["7019_COR", "7954_COR", "31928_COR"]
    each_sub = pd.read_csv("submission.csv")
    # After the merge, birds_x is the label column and birds_y the submitted prediction.
    sub_target = train_labels.merge(each_sub, how="left", on="row_id")

    # Non-null counts of labels and predictions (the original printed birds_x twice).
    print(sub_target["birds_x"].notnull().sum(), sub_target["birds_y"].notnull().sum())
    assert sub_target["birds_x"].notnull().all()
    assert sub_target["birds_y"].notnull().all()

    df_metrics = pd.DataFrame([get_metrics(s_true, s_pred) for s_true, s_pred in zip(sub_target.birds_x, sub_target.birds_y)])
    df_metrics_1 = pd.DataFrame([
        get_metrics(s_true, s_pred)
        for s_true, s_pred, row_idx in zip(sub_target.birds_x, sub_target.birds_y, sub_target.row_id)
        if "_".join(row_idx.split("_")[:-1]) not in remove_sound
    ])

    print(f"{df_metrics.mean()}")
    print("=================")
    print(f"{df_metrics_1.mean()}")
    print("=================\n")