In [None]:
%matplotlib inline
import numpy as np
import pandas as pd
import time
import os
import copy
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import cv2
import librosa
from librosa import display
from pathlib import Path
import multiprocessing
from tqdm import tqdm
from matplotlib import gridspec
from datetime import datetime
import matplotlib.pyplot as plt
import plotly.express as px
import seaborn as sns
sns.set_palette(sns.color_palette('tab20', 20))
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
from sklearn import metrics
from plotly.offline import init_notebook_mode, iplot, plot
init_notebook_mode(connected=True)

In [None]:
# Timestamp taken when this cell runs; kept so elapsed notebook time can be computed later.
nb_start = datetime.now()

# Matplotlib defaults: figure size and base font size.
plt.rcParams.update({
    'figure.figsize': [16, 10],
    'font.size': 14,
})

# Let pandas render up to 99 rows and 99 columns before truncating.
pd.set_option('display.max_rows', 99)
pd.set_option('display.max_columns', 99)

In [None]:
from dataclasses import dataclass

@dataclass
class BirdLists:
    """Namespace holding hard-coded lists of short bird label codes.

    None of the attributes below carries a type annotation, so ``@dataclass``
    generates no fields for them: they are ordinary class attributes, shared by
    the class and every instance, and are read as ``BirdLists.all`` etc.
    The lists are mutable; modifying one in place affects every user.
    """

    # Full label vocabulary, in alphabetical order.
    # NOTE(review): these look like eBird species codes - confirm against the dataset.
    # NOTE: the name `all` shadows the builtin only inside this class body.
    all = [
        'aldfly', 'ameavo', 'amebit', 'amecro', 'amegfi', 'amekes', 'amepip', 'amered', 'amerob', 'amewig', 'amewoo',
        'amtspa', 'annhum', 'astfly', 'baisan', 'baleag', 'balori', 'banswa', 'barswa', 'bawwar', 'belkin1', 'belspa2',
        'bewwre', 'bkbcuc', 'bkbmag1', 'bkbwar', 'bkcchi', 'bkchum', 'bkhgro', 'bkpwar', 'bktspa', 'blkpho', 'blugrb1',
        'blujay', 'bnhcow', 'boboli', 'bongul', 'brdowl', 'brebla', 'brespa', 'brncre', 'brnthr', 'brthum', 'brwhaw',
        'btbwar', 'btnwar', 'btywar', 'buffle', 'buggna', 'buhvir', 'bulori', 'bushti', 'buwtea', 'buwwar', 'cacwre',
        'calgul', 'calqua', 'camwar', 'cangoo', 'canwar', 'canwre', 'carwre', 'casfin', 'caster1', 'casvir', 'cedwax',
        'chispa', 'chiswi', 'chswar', 'chukar', 'clanut', 'cliswa', 'comgol', 'comgra', 'comloo', 'commer', 'comnig',
        'comrav', 'comred', 'comter', 'comyel', 'coohaw', 'coshum', 'cowscj1', 'daejun', 'doccor', 'dowwoo', 'dusfly',
        'eargre', 'easblu', 'easkin', 'easmea', 'easpho', 'eastow', 'eawpew', 'eucdov', 'eursta', 'evegro', 'fiespa',
        'fiscro', 'foxspa', 'gadwal', 'gcrfin', 'gnttow', 'gockin', 'gocspa', 'goleag', 'grbher3', 'grcfly', 'greegr',
        'greroa', 'greyel', 'grhowl', 'grnher', 'grtgra', 'grycat', 'gryfly', 'haiwoo', 'hamfly', 'herthr', 'hoomer',
        'hoowar', 'horgre', 'horlar', 'houfin', 'houspa', 'houwre', 'indbun', 'juntit1', 'killde', 'labwoo', 'larspa',
        'lazbun', 'leabit', 'leafly', 'leasan', 'lecthr', 'lesgol', 'lesnig', 'lesyel', 'lewwoo', 'linspa', 'lobcur',
        'lobdow', 'logshr', 'lotduc', 'louwat', 'macwar', 'magwar', 'mallar3', 'marwre', 'merlin', 'moublu', 'mouchi',
        'moudov', 'norcar', 'norfli', 'norhar2', 'normoc', 'norpar', 'norpin', 'norsho', 'norwat', 'nrwswa', 'nutwoo',
        'olsfly', 'orcwar', 'osprey', 'ovenbi1', 'palwar', 'pasfly', 'pecsan', 'perfal', 'phaino', 'pibgre', 'pilwoo',
        'pingro', 'pinjay', 'pinsis', 'pinwar', 'plsvir', 'prawar', 'purfin', 'pygnut', 'rebmer', 'rebnut', 'rebsap',
        'rebwoo', 'redcro', 'redhea', 'reevir1', 'renpha', 'reshaw', 'rethaw', 'rewbla', 'ribgul', 'rinduc', 'robgro',
        'rocpig', 'rocwre', 'rthhum', 'ruckin', 'rudduc', 'rufgro', 'rufhum', 'rusbla', 'sagspa1', 'sagthr', 'savspa',
        'saypho', 'scatan', 'scoori', 'semplo', 'semsan', 'sheowl', 'shshaw', 'snobun', 'snogoo', 'solsan', 'sonspa',
        'sora', 'sposan', 'spotow', 'stejay', 'swahaw', 'swaspa', 'swathr', 'treswa', 'truswa', 'tuftit', 'tunswa',
        'veery', 'vesspa', 'vigswa', 'warvir', 'wesblu', 'wesgre', 'weskin', 'wesmea', 'wessan', 'westan', 'wewpew',
        'whbnut', 'whcspa', 'whfibi', 'whtspa', 'whtswi', 'wilfly', 'wilsni1', 'wiltur', 'winwre3', 'wlswar', 'wooduc',
        'wooscj2', 'woothr', 'y00475', 'yebfly', 'yebsap', 'yehbla', 'yelwar', 'yerwar', 'yetvir'
    ]
    # Subset of the codes above.
    # NOTE(review): the selection criterion behind "top 165" is not visible here - confirm.
    top_165 = [
        'amecro', 'amegfi', 'amekes', 'amered', 'amerob', 'amewig', 'amtspa', 'annhum', 'baleag', 'balori', 'barswa',
        'bawwar', 'belkin1', 'bewwre', 'bkbmag1', 'bkcchi', 'bkhgro', 'blkpho', 'blujay', 'bnhcow', 'bongul', 'brdowl',
        'brebla', 'brncre', 'brnthr', 'btbwar', 'btnwar', 'buffle', 'buggna', 'buhvir', 'bushti', 'buwtea', 'calgul',
        'calqua', 'cangoo', 'carwre', 'caster1', 'cedwax', 'chispa', 'chiswi', 'chswar', 'cliswa', 'comgol', 'comgra',
        'comloo', 'commer', 'comrav', 'comyel', 'coohaw', 'cowscj1', 'daejun', 'doccor', 'dowwoo', 'easblu', 'easkin',
        'easmea', 'easpho', 'eastow', 'eawpew', 'eucdov', 'eursta', 'fiespa', 'fiscro', 'foxspa', 'gadwal', 'gockin',
        'gocspa', 'grbher3', 'grcfly', 'greegr', 'greyel', 'grhowl', 'grnher', 'grtgra', 'grycat', 'haiwoo', 'herthr',
        'hoomer', 'horgre', 'horlar', 'houfin', 'houspa', 'houwre', 'indbun', 'killde', 'leafly', 'leasan', 'lesgol',
        'lesyel', 'linspa', 'logshr', 'magwar', 'mallar3', 'marwre', 'merlin', 'mouchi', 'moudov', 'norcar', 'norfli',
        'norhar2', 'normoc', 'norpar', 'norpin', 'norsho', 'nrwswa', 'nutwoo', 'orcwar', 'osprey', 'ovenbi1', 'palwar',
        'perfal', 'pibgre', 'pilwoo', 'pinsis', 'pinwar', 'purfin', 'rebmer', 'rebnut', 'rebwoo', 'redhea', 'reevir1',
        'reshaw', 'rethaw', 'rewbla', 'ribgul', 'rinduc', 'robgro', 'rocpig', 'rthhum', 'ruckin', 'rudduc', 'savspa',
        'saypho', 'scatan', 'semplo', 'shshaw', 'snogoo', 'sonspa', 'sposan', 'spotow', 'stejay', 'swaspa', 'swathr',
        'treswa', 'tuftit', 'veery', 'vigswa', 'warvir', 'wesblu', 'weskin', 'wesmea', 'westan', 'whbnut', 'whcspa',
        'whtspa', 'wilsni1', 'wiltur', 'wlswar', 'wooduc', 'woothr', 'y00475', 'yebsap', 'yelwar', 'yerwar', 'yetvir'
    ]
    # Another subset of the codes above.
    # NOTE(review): what quantity is "non-zero" for these codes is not shown in this file - confirm.
    non_zero = [
        'ameavo', 'amebit', 'amecro', 'amegfi', 'amekes', 'amepip', 'amerob', 'amewig', 'amewoo', 'amtspa', 'annhum',
        'astfly', 'baisan', 'baleag', 'banswa', 'belkin1', 'belspa2', 'bkbcuc', 'bkbmag1', 'bkbwar', 'bkcchi', 'bkchum',
        'bkpwar', 'bktspa', 'blkpho', 'blugrb1', 'blujay', 'boboli', 'bongul', 'brdowl', 'brebla', 'brespa', 'brthum',
        'brwhaw', 'btbwar', 'btnwar', 'btywar', 'buffle', 'buhvir', 'bushti', 'buwtea', 'buwwar', 'cacwre', 'calgul',
        'calqua', 'camwar', 'cangoo', 'canwar', 'canwre', 'casfin', 'caster1', 'casvir', 'cedwax', 'chiswi', 'chswar',
        'chukar', 'clanut', 'cliswa', 'comgol', 'commer', 'comnig', 'comred', 'comter', 'comyel', 'coohaw', 'coshum',
        'cowscj1', 'daejun', 'doccor', 'dusfly', 'eargre', 'easkin', 'easmea', 'eawpew', 'eucdov', 'evegro', 'fiespa',
        'fiscro', 'gadwal', 'gcrfin', 'gnttow', 'gocspa', 'goleag', 'grcfly', 'greroa', 'greyel', 'grnher', 'grtgra',
        'grycat', 'gryfly', 'hamfly', 'herthr', 'hoomer', 'hoowar', 'horgre', 'horlar', 'indbun', 'juntit1', 'labwoo',
        'larspa', 'lazbun', 'leabit', 'leafly', 'leasan', 'lecthr', 'lesgol', 'lesnig', 'lesyel', 'lewwoo', 'linspa',
        'lobcur', 'lobdow', 'logshr', 'lotduc', 'louwat', 'macwar', 'magwar', 'merlin', 'moublu', 'mouchi', 'norfli',
        'norhar2', 'norpin', 'norsho', 'norwat', 'nrwswa', 'olsfly', 'ovenbi1', 'palwar', 'pasfly', 'pecsan', 'perfal',
        'phaino', 'pibgre', 'pingro', 'pinjay', 'pinsis', 'pinwar', 'plsvir', 'prawar', 'purfin', 'pygnut', 'rebmer',
        'rebsap', 'rebwoo', 'redhea', 'reevir1', 'renpha', 'reshaw', 'rewbla', 'rinduc', 'robgro', 'rocwre', 'rthhum',
        'rudduc', 'rufgro', 'rufhum', 'rusbla', 'sagspa1', 'sagthr', 'saypho', 'scatan', 'scoori', 'semplo', 'semsan',
        'sheowl', 'shshaw', 'snobun', 'snogoo', 'solsan', 'sonspa', 'sora', 'sposan', 'stejay', 'swahaw', 'swaspa',
        'truswa', 'tunswa', 'veery', 'vesspa', 'vigswa', 'wesblu', 'wesgre', 'weskin', 'wessan', 'westan', 'wewpew',
        'whcspa', 'whfibi', 'whtswi', 'wilfly', 'wilsni1', 'wiltur', 'winwre3', 'wooduc', 'wooscj2', 'woothr', 'y00475',
        'yebfly', 'yebsap', 'yehbla', 'yetvir'
    ]

In [None]:
# Audio / batching configuration.
SAMPLE_RATE = 32_000        # samples per second
CHUNK_DURATION = 5          # length of one chunk (presumably seconds, given chunk_len below)
BATCH_SIZE = 64

# Number of samples in one chunk.
chunk_len = CHUNK_DURATION * SAMPLE_RATE

In [None]:
import numpy as np
import librosa
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class DFTBase(nn.Module):
    def __init__(self):
        """Base class providing dense DFT and inverse-DFT matrices."""
        super(DFTBase, self).__init__()

    def _fourier_matrix(self, n, sign):
        """Build the (n, n) complex matrix whose (j, k) entry is exp(sign * 2j * pi * j * k / n).

        Args:
          n: int, transform size
          sign: -1 for the forward DFT matrix, +1 for the inverse (unnormalized)
        """
        (x, y) = np.meshgrid(np.arange(n), np.arange(n))
        omega = np.exp(sign * 2 * np.pi * 1j / n)
        return np.power(omega, x * y)

    def dft_matrix(self, n):
        """Return the (n, n) complex forward DFT matrix (no normalization)."""
        return self._fourier_matrix(n, -1)

    def idft_matrix(self, n):
        """Return the (n, n) complex inverse DFT matrix (no 1/n normalization applied)."""
        return self._fourier_matrix(n, 1)


class DFT(DFTBase):
    def __init__(self, n, norm):
        """Calculate DFT, IDFT, RDFT, IRDFT by dense matrix multiplication.

        Args:
          n: fft window size
          norm: None | 'ortho'. With None the forward transforms are unscaled
            and the inverse transforms divide by n; with 'ortho' both
            directions divide by sqrt(n).
        """
        super(DFT, self).__init__()

        self.W = self.dft_matrix(n)
        self.inv_W = self.idft_matrix(n)

        # Real / imaginary parts as float tensors. These are plain attributes
        # (not parameters or registered buffers), so Module.to()/.cuda() does
        # not move them.
        self.W_real = torch.Tensor(np.real(self.W))
        self.W_imag = torch.Tensor(np.imag(self.W))
        self.inv_W_real = torch.Tensor(np.real(self.inv_W))
        self.inv_W_imag = torch.Tensor(np.imag(self.inv_W))

        self.n = n
        self.norm = norm

    def dft(self, x_real, x_imag):
        """Calculate DFT of signal.

        Args:
          x_real: (..., n), signal real part
          x_imag: (..., n), signal imag part
        Returns:
          z_real: (..., n), output real part
          z_imag: (..., n), output imag part
        """
        # Complex product (x_real + i x_imag) @ (W_real + i W_imag), split by part.
        z_real = torch.matmul(x_real, self.W_real) - torch.matmul(x_imag, self.W_imag)
        z_imag = torch.matmul(x_imag, self.W_real) + torch.matmul(x_real, self.W_imag)

        if self.norm == 'ortho':
            z_real /= math.sqrt(self.n)
            z_imag /= math.sqrt(self.n)

        return z_real, z_imag

    def idft(self, x_real, x_imag):
        """Calculate IDFT of signal.

        Args:
          x_real: (..., n), spectrum real part
          x_imag: (..., n), spectrum imag part
        Returns:
          z_real: (..., n), output real part
          z_imag: (..., n), output imag part
        """
        z_real = torch.matmul(x_real, self.inv_W_real) - torch.matmul(x_imag, self.inv_W_imag)
        z_imag = torch.matmul(x_imag, self.inv_W_real) + torch.matmul(x_real, self.inv_W_imag)

        if self.norm is None:
            # Both parts must be scaled, otherwise dft -> idft does not
            # round-trip for signals with a non-zero imaginary part.
            z_real /= self.n
            z_imag /= self.n
        elif self.norm == 'ortho':
            z_real /= math.sqrt(self.n)
            z_imag /= math.sqrt(self.n)

        return z_real, z_imag

    def rdft(self, x_real):
        """Calculate the DFT of a real signal, keeping the first n // 2 + 1 bins.

        Args:
          x_real: (..., n), real-valued signal
        Returns:
          z_real: (..., n // 2 + 1), output real part
          z_imag: (..., n // 2 + 1), output imag part
        """
        n_rfft = self.n // 2 + 1
        z_real = torch.matmul(x_real, self.W_real[..., 0 : n_rfft])
        z_imag = torch.matmul(x_real, self.W_imag[..., 0 : n_rfft])

        if self.norm == 'ortho':
            z_real /= math.sqrt(self.n)
            z_imag /= math.sqrt(self.n)

        return z_real, z_imag

    def irdft(self, x_real, x_imag):
        """Calculate the inverse of rdft, returning only the real signal.

        The full spectrum is rebuilt by conjugate symmetry from the
        n // 2 + 1 given bins; the rebuilt length equals n only for even n.

        Args:
          x_real: (..., n // 2 + 1), spectrum real part
          x_imag: (..., n // 2 + 1), spectrum imag part
        Returns:
          z_real: (..., n), reconstructed real signal
        """
        n_rfft = self.n // 2 + 1

        # Mirror bins 1 .. n_rfft - 2 (DC and the last bin are not repeated).
        flip_x_real = torch.flip(x_real, dims=(-1,))
        x_real = torch.cat((x_real, flip_x_real[..., 1 : n_rfft - 1]), dim=-1)

        # The mirrored half is the complex conjugate, hence the sign flip.
        flip_x_imag = torch.flip(x_imag, dims=(-1,))
        x_imag = torch.cat((x_imag, -1. * flip_x_imag[..., 1 : n_rfft - 1]), dim=-1)

        z_real = torch.matmul(x_real, self.inv_W_real) - torch.matmul(x_imag, self.inv_W_imag)

        if self.norm is None:
            z_real /= self.n
        elif self.norm == 'ortho':
            z_real /= math.sqrt(self.n)

        return z_real
        

class STFT(DFTBase):
    def __init__(self, n_fft=2048, hop_length=None, win_length=None, 
        window='hann', center=True, pad_mode='reflect', freeze_parameters=True):
        """Implementation of STFT with Conv1d. The function is intended to
        have the same output as librosa.core.stft.

        Args:
          n_fft: int, FFT size; also the Conv1d kernel size
          hop_length: int or None, stride between frames; defaults to
            win_length // 4
          win_length: int or None, window length; defaults to n_fft
          window: window specification passed to librosa.filters.get_window
          center: bool, if True forward() pads n_fft // 2 samples on each side
          pad_mode: 'constant' | 'reflect', padding mode used when center=True
          freeze_parameters: bool, if True the conv weights get
            requires_grad=False
        """
        super(STFT, self).__init__()

        assert pad_mode in ['constant', 'reflect']

        self.n_fft = n_fft
        self.center = center
        self.pad_mode = pad_mode

        # By default, use the entire frame
        if win_length is None:
            win_length = n_fft

        # Set the default hop, if it's not already specified
        if hop_length is None:
            hop_length = int(win_length // 4)

        fft_window = librosa.filters.get_window(window, win_length, fftbins=True)

        # Pad the window out to n_fft size. `size` is passed by keyword
        # because librosa >= 0.10 made it keyword-only; the keyword form also
        # works on older releases.
        fft_window = librosa.util.pad_center(fft_window, size=n_fft)

        # DFT matrix
        self.W = self.dft_matrix(n_fft)

        out_channels = n_fft // 2 + 1

        self.conv_real = nn.Conv1d(in_channels=1, out_channels=out_channels, 
            kernel_size=n_fft, stride=hop_length, padding=0, dilation=1, 
            groups=1, bias=False)

        self.conv_imag = nn.Conv1d(in_channels=1, out_channels=out_channels, 
            kernel_size=n_fft, stride=hop_length, padding=0, dilation=1, 
            groups=1, bias=False)

        # Each output channel is one windowed DFT basis vector, so the strided
        # convolution computes the framed transform directly.
        self.conv_real.weight.data = torch.Tensor(
            np.real(self.W[:, 0 : out_channels] * fft_window[:, None]).T)[:, None, :]
        # (n_fft // 2 + 1, 1, n_fft)

        self.conv_imag.weight.data = torch.Tensor(
            np.imag(self.W[:, 0 : out_channels] * fft_window[:, None]).T)[:, None, :]
        # (n_fft // 2 + 1, 1, n_fft)

        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input):
        """Compute the real and imaginary parts of the STFT.

        Args:
          input: (batch_size, data_length)
        Returns:
          real: (batch_size, 1, time_steps, n_fft // 2 + 1)
          imag: (batch_size, 1, time_steps, n_fft // 2 + 1)
        """

        x = input[:, None, :]   # (batch_size, channels_num, data_length)

        if self.center:
            x = F.pad(x, pad=(self.n_fft // 2, self.n_fft // 2), mode=self.pad_mode)

        real = self.conv_real(x)
        imag = self.conv_imag(x)
        # (batch_size, n_fft // 2 + 1, time_steps)

        real = real[:, None, :, :].transpose(2, 3)
        imag = imag[:, None, :, :].transpose(2, 3)
        # (batch_size, 1, time_steps, n_fft // 2 + 1)

        return real, imag


def magphase(real, imag):
    """Split a complex value given as (real, imag) into magnitude and phase.

    Returns:
      mag: magnitude sqrt(real^2 + imag^2)
      cos: real / mag, with the divisor floored at 1e-10
      sin: imag / mag, with the divisor floored at 1e-10
    """
    mag = (real ** 2 + imag ** 2) ** 0.5
    # Floor the divisor so zero-magnitude bins do not divide by zero.
    safe_mag = torch.clamp(mag, 1e-10, np.inf)
    return mag, real / safe_mag, imag / safe_mag


class ISTFT(DFTBase):
    def __init__(self, n_fft=2048, hop_length=None, win_length=None, 
        window='hann', center=True, pad_mode='reflect', freeze_parameters=True):
        """Implementation of ISTFT with Conv1d. The function is intended to
        have the same output as librosa.core.istft.

        Args:
          n_fft: int, FFT size
          hop_length: int or None, stride between frames; defaults to
            win_length // 4
          win_length: int or None, window length; defaults to n_fft
          window: window specification passed to librosa.filters.get_window
          center: bool, if True forward() trims n_fft // 2 samples per side
          pad_mode: 'constant' | 'reflect' (stored; not used by forward here)
          freeze_parameters: bool, if True all weights get requires_grad=False
        """
        super(ISTFT, self).__init__()

        assert pad_mode in ['constant', 'reflect']

        # Resolve the defaults BEFORE storing them: the layers built below and
        # init_weights()/get_ifft_window() read self.hop_length and
        # self.win_length and cannot work with None.
        # By default, use the entire frame
        if win_length is None:
            win_length = n_fft

        # Set the default hop, if it's not already specified
        if hop_length is None:
            hop_length = int(win_length // 4)

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.center = center
        self.pad_mode = pad_mode

        # IDFT matrix, including the 1 / n_fft normalization
        self.W = self.idft_matrix(n_fft) / n_fft

        self.conv_real = nn.Conv1d(in_channels=n_fft, out_channels=n_fft, 
            kernel_size=1, stride=1, padding=0, dilation=1, 
            groups=1, bias=False)

        self.conv_imag = nn.Conv1d(in_channels=n_fft, out_channels=n_fft, 
            kernel_size=1, stride=1, padding=0, dilation=1, 
            groups=1, bias=False)

        # 1x1 conv that reverses bins 1 .. n_fft // 2 - 1 (stand-in for torch.flip).
        self.reverse = nn.Conv1d(in_channels=n_fft // 2 + 1, 
            out_channels=n_fft // 2 - 1, kernel_size=1, bias=False)

        # Transposed conv with identity weights performs the overlap-add.
        self.overlap_add = nn.ConvTranspose2d(in_channels=n_fft, 
            out_channels=1, kernel_size=(n_fft, 1), stride=(self.hop_length, 1), bias=False)

        # Cache of the window-sum envelope; rebuilt in forward() when the
        # output length changes.
        self.ifft_window_sum = []

        self.init_weights()

        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad = False

    def init_weights(self):
        """Fill every layer with its fixed (non-random) weights."""
        ifft_window = librosa.filters.get_window(self.window, self.win_length, fftbins=True)
        # (win_length,)

        # Pad the window to n_fft. `size` is passed by keyword because
        # librosa >= 0.10 made it keyword-only.
        ifft_window = librosa.util.pad_center(ifft_window, size=self.n_fft)

        self.conv_real.weight.data = torch.Tensor(
            np.real(self.W * ifft_window[None, :]).T)[:, :, None]
        # (n_fft, n_fft, 1)

        self.conv_imag.weight.data = torch.Tensor(
            np.imag(self.W * ifft_window[None, :]).T)[:, :, None]
        # (n_fft, n_fft, 1)

        # Anti-diagonal selection matrix over input bins 1 .. n_fft // 2 - 1.
        tmp = np.zeros((self.n_fft // 2 - 1, self.n_fft // 2 + 1, 1))
        tmp[:, 1 : -1, 0] = np.array(np.eye(self.n_fft // 2 - 1)[::-1])
        self.reverse.weight.data = torch.Tensor(tmp)
        # (n_fft // 2 - 1, n_fft // 2 + 1, 1)

        self.overlap_add.weight.data = torch.Tensor(np.eye(self.n_fft)[:, None, :, None])
        # (n_fft, 1, n_fft, 1)

    def get_ifft_window(self, n_frames):
        """Return the clipped window sum-square envelope for n_frames frames,
        as a tensor on this module's device."""
        device = next(self.parameters()).device

        # Keyword arguments throughout: librosa >= 0.10 made them keyword-only.
        ifft_window_sum = librosa.filters.window_sumsquare(
            window=self.window, n_frames=n_frames,
            win_length=self.win_length, n_fft=self.n_fft, hop_length=self.hop_length)

        # Avoid dividing by (near) zero at the signal edges.
        ifft_window_sum = np.clip(ifft_window_sum, 1e-8, np.inf)
        ifft_window_sum = torch.Tensor(ifft_window_sum).to(device)
        return ifft_window_sum

    def forward(self, real_stft, imag_stft, length=None):
        """Reconstruct waveforms from the real and imaginary STFT parts.

        Args:
          real_stft: (batch_size, 1, time_steps, n_fft // 2 + 1)
          imag_stft: (batch_size, 1, time_steps, n_fft // 2 + 1)
          length: int or None. If given, the output is trimmed or
            zero-padded to exactly this many samples. If None and
            center=True, n_fft // 2 samples are trimmed from each end.
        Returns:
          real: (batch_size, data_length)
        """
        assert real_stft.ndimension() == 4 and imag_stft.ndimension() == 4
        device = next(self.parameters()).device
        batch_size = real_stft.shape[0]

        real_stft = real_stft[:, 0, :, :].transpose(1, 2)
        imag_stft = imag_stft[:, 0, :, :].transpose(1, 2)
        # (batch_size, n_fft // 2 + 1, time_steps)

        # Full stft, using flip is not supported by ONNX.
        # full_real_stft = torch.cat((real_stft, torch.flip(real_stft[:, 1 : -1, :], dims=[1])), dim=1)
        # full_imag_stft = torch.cat((imag_stft, - torch.flip(imag_stft[:, 1 : -1, :], dims=[1])), dim=1)
        full_real_stft = torch.cat((real_stft, self.reverse(real_stft)), dim=1)
        full_imag_stft = torch.cat((imag_stft, - self.reverse(imag_stft)), dim=1)
        # (batch_size, n_fft, time_steps)

        # IDFT
        s_real = self.conv_real(full_real_stft) - self.conv_imag(full_imag_stft)
        s_real = s_real[..., None]  # (batch_size, n_fft, time_steps, 1)
        y = self.overlap_add(s_real)[:, 0, :, 0]    # (batch_size, samples_num)

        # Divide window
        if len(self.ifft_window_sum) != y.shape[1]:
            frames_num = real_stft.shape[2]
            self.ifft_window_sum = self.get_ifft_window(frames_num)
            
        y = y / self.ifft_window_sum[None, 0 : y.shape[1]]

        # Trim or pad to length
        if length is None:
            if self.center:
                # Parenthesized so the trim is symmetric: `-self.n_fft // 2`
                # parses as (-n_fft) // 2, which differs for odd n_fft.
                y = y[:, self.n_fft // 2 : -(self.n_fft // 2)]
        else:
            if self.center:
                start = self.n_fft // 2
            else:
                start = 0

            y = y[:, start : start + length]
            (batch_size, len_y) = y.shape
            if y.shape[-1] < length:
                y = torch.cat((y, torch.zeros(batch_size, length - len_y).to(device)), dim=-1)
        
        return y
        

class Spectrogram(nn.Module):
    def __init__(self, n_fft=2048, hop_length=None, win_length=None, 
        window='hann', center=True, pad_mode='reflect', power=2.0, 
        freeze_parameters=True):
        """Calculate spectrogram using pytorch. The STFT is implemented with 
        Conv1d and is intended to have the same output as librosa.core.stft.

        Args:
          n_fft, hop_length, win_length, window, center, pad_mode: forwarded
            unchanged to STFT
          power: float, exponent applied to the STFT magnitude (2.0 gives the
            power spectrogram, 1.0 the magnitude spectrogram)
          freeze_parameters: accepted for interface compatibility; the inner
            STFT is always built with freeze_parameters=True
        """
        super(Spectrogram, self).__init__()

        self.power = power

        self.stft = STFT(n_fft=n_fft, hop_length=hop_length, 
            win_length=win_length, window=window, center=center, 
            pad_mode=pad_mode, freeze_parameters=True)

    def forward(self, input):
        """Compute |STFT(input)| ** power.

        Args:
          input: waveform batch passed straight to self.stft.forward
        Returns:
          spectrogram: same shape as each STFT output
        """

        (real, imag) = self.stft.forward(input)

        # Squared magnitude, i.e. the power == 2.0 case.
        spectrogram = real ** 2 + imag ** 2

        if self.power != 2.0:
            # Convert |X|^2 to |X|^power. The instance attribute must be
            # used here; no local or global `power` exists in this scope.
            spectrogram = spectrogram ** (self.power / 2.0)

        return spectrogram


class LogmelFilterBank(nn.Module):
    def __init__(self, sr=32000, n_fft=2048, n_mels=64, fmin=50, fmax=14000, is_log=True, 
        ref=1.0, amin=1e-10, top_db=80.0, freeze_parameters=True):
        """Calculate logmel spectrogram using pytorch. The mel filter bank is 
        taken from librosa.filters.mel and stored as a parameter.

        Args:
          sr, n_fft, n_mels, fmin, fmax: forwarded to librosa.filters.mel
          is_log: bool, if True forward() returns decibels, else mel power
          ref: float, reference value for the decibel conversion
          amin: float, lower floor applied before taking the logarithm
          top_db: float or None, if not None the output is clipped to at most
            top_db below its maximum; must be non-negative
          freeze_parameters: bool, if True melW gets requires_grad=False
        """
        super(LogmelFilterBank, self).__init__()

        self.is_log = is_log
        self.ref = ref
        self.amin = amin
        self.top_db = top_db

        mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels,
            fmin=fmin, fmax=fmax).T
        # (n_fft // 2 + 1, mel_bins)

        self.melW = nn.Parameter(torch.Tensor(mel_basis))

        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input):
        """Project a spectrogram onto the mel filter bank.

        Args:
          input: (..., n_fft // 2 + 1); the last axis is multiplied by melW
        Returns:
          output: (..., mel_bins), in decibels when is_log is True
        """

        # Mel spectrogram
        mel_spectrogram = torch.matmul(input, self.melW)

        # Logmel spectrogram
        if self.is_log:
            output = self.power_to_db(mel_spectrogram)
        else:
            output = mel_spectrogram

        return output


    def power_to_db(self, input):
        """Power to db, this function is the pytorch implementation of 
        librosa.core.power_to_db.

        Raises:
          ValueError: if top_db is negative.
        """
        if self.top_db is not None and self.top_db < 0:
            # ValueError is used because no `ParameterError` name is in
            # scope in this file.
            raise ValueError('top_db must be non-negative')

        ref_value = self.ref
        log_spec = 10.0 * torch.log10(torch.clamp(input, min=self.amin, max=np.inf))
        log_spec -= 10.0 * np.log10(np.maximum(self.amin, ref_value))

        if self.top_db is not None:
            # The maximum is taken over the whole tensor (all batch items).
            log_spec = torch.clamp(log_spec, min=log_spec.max().item() - self.top_db, max=np.inf)

        return log_spec


In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F


class DropStripes(nn.Module):
    def __init__(self, dim, drop_width, stripes_num):
        """Zero out random stripes along one axis of a 4-D batch (training only).

        Args:
          dim: int, axis to stripe: 2 (time) or 3 (frequency)
          drop_width: int, exclusive upper bound on the width of one stripe
          stripes_num: int, number of stripes drawn per sample
        """
        super(DropStripes, self).__init__()

        # dim 2: time; dim 3: frequency
        assert dim in [2, 3]

        self.stripes_num = stripes_num
        self.drop_width = drop_width
        self.dim = dim

    def forward(self, input):
        """Mask `input` in place and return it.

        input: (batch_size, channels, time_steps, freq_bins)
        """
        assert input.ndimension() == 4

        # Nothing is dropped outside training mode.
        if self.training is False:
            return input

        axis_length = input.shape[self.dim]
        sample_count = input.shape[0]

        # Each sample gets its own independently drawn stripes.
        for sample_idx in range(sample_count):
            self.transform_slice(input[sample_idx], axis_length)

        return input

    def transform_slice(self, e, total_width):
        """Zero `stripes_num` random stripes of one sample, in place.

        e: (channels, time_steps, freq_bins)
        """
        for _ in range(self.stripes_num):
            # Stripe width first, then a start position that keeps it in range.
            width = torch.randint(low=0, high=self.drop_width, size=(1,))[0]
            start = torch.randint(low=0, high=total_width - width, size=(1,))[0]
            stripe = slice(start, start + width)

            # `e` has lost the batch axis, so batch dim 2 / 3 is axis 1 / 2 here.
            if self.dim == 2:
                e[:, stripe, :] = 0
            elif self.dim == 3:
                e[:, :, stripe] = 0


class SpecAugmentation(nn.Module):
    def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, 
        freq_stripes_num):
        """SpecAugment-style masking: time stripes followed by frequency stripes.

        [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D. 
        and Le, Q.V., 2019. Specaugment: A simple data augmentation method 
        for automatic speech recognition. arXiv preprint arXiv:1904.08779.

        Args:
          time_drop_width: int, drop_width of the time-axis DropStripes
          time_stripes_num: int, stripes_num of the time-axis DropStripes
          freq_drop_width: int, drop_width of the frequency-axis DropStripes
          freq_stripes_num: int, stripes_num of the frequency-axis DropStripes
        """
        super(SpecAugmentation, self).__init__()

        # Axis 2 is time, axis 3 is frequency.
        self.time_dropper = DropStripes(
            dim=2, drop_width=time_drop_width, stripes_num=time_stripes_num)
        self.freq_dropper = DropStripes(
            dim=3, drop_width=freq_drop_width, stripes_num=freq_stripes_num)

    def forward(self, input):
        """Apply the time dropper, then the frequency dropper, to `input`."""
        return self.freq_dropper(self.time_dropper(input))


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torchlibrosa.stft import Spectrogram, LogmelFilterBank
# from torchlibrosa.augmentation import SpecAugmentation


def do_mixup(x, mixup_lambda):
    """Blend each even-indexed item of x with the odd-indexed item after it.

    Args:
      x: (batch_size * 2, ...)
      mixup_lambda: (batch_size * 2,), per-item mixing weights

    Returns:
      out: (batch_size, ...), where
        out[i] = x[2i] * mixup_lambda[2i] + x[2i + 1] * mixup_lambda[2i + 1]
    """
    # Move the batch axis last so the 1-D weights broadcast over it.
    even_term = x[0::2].transpose(0, -1) * mixup_lambda[0::2]
    odd_term = x[1::2].transpose(0, -1) * mixup_lambda[1::2]
    mixed = even_term + odd_term
    return mixed.transpose(0, -1)


def interpolate(x, ratio):
    """Upsample along the time axis by repeating every frame `ratio` times.

    Used to compensate for the time-resolution loss caused by CNN downsampling.

    Args:
      x: (batch_size, time_steps, classes_num)
      ratio: int, number of copies of each frame

    Returns:
      upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    batch_size, time_steps, classes_num = x.shape
    # Add a repeat axis right after time, tile it, then fold it back into time.
    repeated = x.unsqueeze(2).repeat(1, 1, ratio, 1)
    return repeated.reshape(batch_size, time_steps * ratio, classes_num)


def pad_framewise_output(framewise_output, frames_num):
    """Extend a framewise output to `frames_num` frames by repeating its last frame.

    Args:
      framewise_output: (batch_size, frames, classes_num)
      frames_num: int, target number of frames after padding

    Outputs:
      output: (batch_size, frames_num, classes_num)
    """
    missing_frames = frames_num - framewise_output.shape[1]

    # Copies of the final frame, one per missing frame.
    last_frame = framewise_output[:, -1:, :]
    tail = last_frame.repeat(1, missing_frames, 1)

    return torch.cat((framewise_output, tail), dim=1)


def init_layer(layer):
    """Xavier-uniform initialize a layer's weight and zero its bias, if it has one."""
    nn.init.xavier_uniform_(layer.weight)

    # Layers built with bias=False expose `bias` as None; others may lack it entirely.
    bias = getattr(layer, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.)


def init_bn(bn):
    """Reset a batch-norm layer's affine parameters: weight to 1, bias to 0."""
    bn.weight.data.fill_(1.)
    bn.bias.data.fill_(0.)


class ConvBlock(nn.Module):
    """Two 3x3 conv + batch-norm + ReLU stages followed by 2-D pooling."""

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        # Both convolutions preserve spatial size (3x3, stride 1, padding 1)
        # and carry no bias.
        conv_kwargs = dict(kernel_size=(3, 3), stride=(1, 1),
                           padding=(1, 1), bias=False)

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels, **conv_kwargs)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Initialize the convolutions, then the batch-norm layers."""
        for conv in (self.conv1, self.conv2):
            init_layer(conv)
        for bn in (self.bn1, self.bn2):
            init_bn(bn)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Run both conv stages, then pool.

        Args:
          input: 4-D tensor accepted by conv1
          pool_size: pooling kernel size
          pool_type: 'max' | 'avg' | 'avg+max' (sum of both poolings)
        """
        x = F.relu_(self.bn1(self.conv1(input)))
        x = F.relu_(self.bn2(self.conv2(x)))

        # pool_type is validated only after the convolutions have run.
        if pool_type == 'max':
            return F.max_pool2d(x, kernel_size=pool_size)
        if pool_type == 'avg':
            return F.avg_pool2d(x, kernel_size=pool_size)
        if pool_type == 'avg+max':
            avg_pooled = F.avg_pool2d(x, kernel_size=pool_size)
            max_pooled = F.max_pool2d(x, kernel_size=pool_size)
            return avg_pooled + max_pooled

        raise Exception('Incorrect argument!')


class ConvBlock5x5(nn.Module):
    """One 5x5 conv + batch-norm + ReLU stage followed by 2-D pooling."""

    def __init__(self, in_channels, out_channels):
        super(ConvBlock5x5, self).__init__()

        # 5x5, stride 1, padding 2 preserves spatial size; no bias.
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=(5, 5), stride=(1, 1),
                               padding=(2, 2), bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Initialize the convolution, then the batch-norm layer."""
        init_layer(self.conv1)
        init_bn(self.bn1)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Run the conv stage, then pool.

        Args:
          input: 4-D tensor accepted by conv1
          pool_size: pooling kernel size
          pool_type: 'max' | 'avg' | 'avg+max' (sum of both poolings)
        """
        x = F.relu_(self.bn1(self.conv1(input)))

        # pool_type is validated only after the convolution has run.
        if pool_type == 'max':
            return F.max_pool2d(x, kernel_size=pool_size)
        if pool_type == 'avg':
            return F.avg_pool2d(x, kernel_size=pool_size)
        if pool_type == 'avg+max':
            avg_pooled = F.avg_pool2d(x, kernel_size=pool_size)
            max_pooled = F.max_pool2d(x, kernel_size=pool_size)
            return avg_pooled + max_pooled

        raise Exception('Incorrect argument!')


class AttBlock(nn.Module):
    def __init__(self, n_in, n_out, activation='linear', temperature=1.):
        """Attention pooling over the last (time) axis.

        Args:
          n_in: int, number of input channels
          n_out: int, number of output channels
          activation: 'linear' | 'sigmoid', applied to the classification branch
          temperature: float, stored on the module; not used by forward()
        """
        super(AttBlock, self).__init__()

        self.activation = activation
        self.temperature = temperature
        self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True)
        self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True)

        # Not used by forward(); kept so the module's parameter set is unchanged.
        self.bn_att = nn.BatchNorm1d(n_out)
        self.init_weights()

    def init_weights(self):
        """Initialize both 1x1 convolutions and the batch-norm layer."""
        init_layer(self.att)
        init_layer(self.cla)
        init_bn(self.bn_att)

    def forward(self, x):
        """Pool x over time with learned attention weights.

        Args:
          x: (n_samples, n_in, n_time)
        Returns:
          x: (n_samples, n_out), attention-weighted sum of `cla` over time
          norm_att: (n_samples, n_out, n_time), softmax attention weights
          cla: (n_samples, n_out, n_time), per-frame classification output
        """
        # Attention logits are clamped to [-10, 10] before the softmax over time.
        norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        """Apply the configured activation to x.

        Raises:
          ValueError: if self.activation is neither 'linear' nor 'sigmoid'.
        """
        if self.activation == 'linear':
            return x
        if self.activation == 'sigmoid':
            return torch.sigmoid(x)
        # Fail here with a clear message; silently returning None would
        # only surface later as an obscure error in forward().
        raise ValueError(
            "Unsupported activation {!r}; expected 'linear' or 'sigmoid'".format(self.activation))


class Cnn14(nn.Module):
    """Waveform-in audio tagger: log-mel front end, six ConvBlocks, two FC layers.

    Args:
        sample_rate: sampling rate passed to the log-mel filter bank.
        window_size: STFT ``n_fft`` and window length.
        hop_size: STFT hop length.
        mel_bins: number of mel bands; also the channel count of ``bn0``.
        fmin: lowest mel filter frequency.
        fmax: highest mel filter frequency.
        classes_num: size of the output layer.

    Attributes:
        dataset_mean, dataset_std: plain floats used in ``forward`` to
            standardise the log-mel spectrogram.  They default to 0 and 1
            (a no-op) and are not part of the state_dict.
    """

    def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
                 fmax, classes_num):

        super(Cnn14, self).__init__()

        window = 'hann'
        center = True
        pad_mode = 'reflect'
        ref = 1.0
        amin = 1e-10
        top_db = None

        self.dataset_mean = 0.
        self.dataset_std = 1.

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size,
                                                 win_length=window_size, window=window, center=center,
                                                 pad_mode=pad_mode,
                                                 freeze_parameters=True)

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size,
                                                 n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin,
                                                 top_db=top_db,
                                                 freeze_parameters=True)

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(time_drop_width=32, time_stripes_num=2,
                                               freq_drop_width=8, freq_stripes_num=2)

        # forward() moves the mel axis into the channel position before bn0, so
        # the layer needs one channel per mel band.  This used to be hard-coded
        # to 64, which only worked for mel_bins == 64.
        self.bn0 = nn.BatchNorm2d(mel_bins)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.fc_audioset = nn.Linear(2048, classes_num, bias=True)

        self.init_weight()

    def init_weight(self):
        """Initialise the input batch-norm and both fully connected layers."""
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, input, mixup_lambda=None):
        """Run the network on a batch of waveforms.

        Args:
            input: (batch_size, data_length) waveform batch.
            mixup_lambda: optional argument forwarded to ``do_mixup``; only
                used in training mode.

        Returns:
            dict with ``'clipwise_output'`` (sigmoid of ``fc_audioset``) and
            ``'embedding'`` (dropout applied to the ``fc1`` activations).
        """

        x = self.spectrogram_extractor(input)  # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

        # Standardise with externally supplied statistics (no-op by default).
        x = (x - self.dataset_mean) / self.dataset_std

        # Swap axes 1 and 3 so bn0 normalises per mel band, then swap back.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        if self.training:
            x = self.spec_augmenter(x)

        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

        x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        # Collapse the last axis, then combine max and mean over axis 2.
        x = torch.mean(x, dim=3)

        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        embedding = F.dropout(x, p=0.5, training=self.training)
        # The classifier reads the pre-dropout activations, not `embedding`.
        clipwise_output = torch.sigmoid(self.fc_audioset(x))

        output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding}

        return output_dict

In [None]:
def _resnet_conv3x3(in_planes, out_planes):
    """Return a bias-free 3x3 convolution with stride 1 and padding 1."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        padding=1,
        bias=False,
    )

def _resnet_conv1x1(in_planes, out_planes):
    """Return a bias-free 1x1 convolution with stride 1."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        bias=False,
    )

class _ResnetBasicBlock(nn.Module):
    """Two-convolution residual block.

    When ``stride == 2`` the input of the main branch is average-pooled 2x2
    before the first convolution; the convolutions themselves use stride 1.
    The skip branch is ``downsample(x)`` when a ``downsample`` module is given,
    otherwise ``x`` itself.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('_ResnetBasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in _ResnetBasicBlock")

        self.stride = stride

        # Sub-modules are registered in this order: conv1, bn1, relu, conv2,
        # bn2, downsample.
        self.conv1 = _resnet_conv3x3(inplanes, planes)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = _resnet_conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample

        self.init_weights()

    def init_weights(self):
        """Initialise conv/bn pairs, then zero the scale of the last batch-norm."""
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            init_layer(conv)
            init_bn(bn)
        # With a zero scale on bn2 the main branch initially contributes nothing.
        nn.init.constant_(self.bn2.weight, 0)

    def forward(self, x):
        """Return ``relu(main_branch(x) + skip(x))``."""
        identity = x

        out = F.avg_pool2d(x, kernel_size=(2, 2)) if self.stride == 2 else x

        out = self.relu(self.bn1(self.conv1(out)))
        out = F.dropout(out, p=0.1, training=self.training)
        out = self.bn2(self.conv2(out))

        if self.downsample is not None:
            identity = self.downsample(identity)

        out += identity
        return self.relu(out)


class _ResnetBottleneck(nn.Module):
    """Three-convolution (1x1 -> 3x3 -> 1x1) residual block.

    The output has ``planes * expansion`` channels.  When ``stride == 2`` the
    input of the main branch is average-pooled 2x2 first; the convolutions
    themselves use stride 1.  The skip branch is ``downsample(x)`` when a
    ``downsample`` module is given, otherwise ``x`` itself.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion

        self.stride = stride

        # Sub-modules are registered in this order: conv1, bn1, conv2, bn2,
        # conv3, bn3, relu, downsample.
        self.conv1 = _resnet_conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = _resnet_conv3x3(width, width)
        self.bn2 = norm_layer(width)
        self.conv3 = _resnet_conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

        self.init_weights()

    def init_weights(self):
        """Initialise conv/bn pairs, then zero the scale of the last batch-norm."""
        pairs = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, bn in pairs:
            init_layer(conv)
            init_bn(bn)
        # With a zero scale on bn3 the main branch initially contributes nothing.
        nn.init.constant_(self.bn3.weight, 0)

    def forward(self, x):
        """Return ``relu(main_branch(x) + skip(x))``."""
        identity = x

        # Only the main branch sees the pooled input; the skip branch starts
        # from the original tensor.
        pooled = F.avg_pool2d(x, kernel_size=(2, 2)) if self.stride == 2 else x

        out = self.relu(self.bn1(self.conv1(pooled)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = F.dropout(out, p=0.1, training=self.training)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(identity)

        out += identity
        return self.relu(out)

class _ResNet(nn.Module):
    """Stack of four residual stages (64, 128, 256, 512 base planes).

    Args:
        block: residual block class; must expose an ``expansion`` attribute.
        layers: number of blocks in each of the four stages.
        zero_init_residual: accepted but not read anywhere in this class.
        groups: forwarded to every block.
        width_per_group: forwarded to every block as ``base_width``.
        replace_stride_with_dilation: ``None`` or three booleans, one per
            strided stage; ``True`` swaps the stride for dilation.
        norm_layer: normalisation layer factory, ``nn.BatchNorm2d`` by default.
    """

    def __init__(self, block, layers, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super().__init__()

        self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        dilate2, dilate3, dilate4 = replace_stride_with_dilation
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=dilate2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=dilate3)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=dilate4)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        Updates ``self.inplanes`` (and ``self.dilation`` when ``dilate`` is set)
        so that the next stage is wired to this stage's output.
        """
        norm_layer = self._norm_layer
        out_planes = planes * block.expansion
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        # A projection on the skip path is needed when the stage changes the
        # resolution or the channel count.  Only strides 1 and 2 are handled;
        # for any other stride no projection is created.
        downsample = None
        needs_projection = stride != 1 or self.inplanes != out_planes
        if needs_projection and stride in (1, 2):
            projection = _resnet_conv1x1(self.inplanes, out_planes)
            norm = norm_layer(out_planes)
            init_layer(projection)
            init_bn(norm)
            if stride == 1:
                downsample = nn.Sequential(projection, norm)
            else:
                downsample = nn.Sequential(nn.AvgPool2d(kernel_size=2), projection, norm)

        # Only the first block of a stage carries the stride / projection.
        first = block(self.inplanes, planes, stride, downsample, self.groups,
                      self.base_width, previous_dilation, norm_layer)
        self.inplanes = out_planes
        rest = [
            block(self.inplanes, planes, groups=self.groups,
                  base_width=self.base_width, dilation=self.dilation,
                  norm_layer=norm_layer)
            for _ in range(1, blocks)
        ]

        return nn.Sequential(first, *rest)

    def forward(self, x):
        """Pass ``x`` through the four stages in order."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        return x

class ResNet38(nn.Module):
    """Waveform-in audio tagger: log-mel front end, ConvBlock stem, ResNet trunk.

    Args:
        sample_rate: sampling rate passed to the log-mel filter bank.
        window_size: STFT ``n_fft`` and window length.
        hop_size: STFT hop length.
        mel_bins: number of mel bands; also the channel count of ``bn0``.
        fmin: lowest mel filter frequency.
        fmax: highest mel filter frequency.
        classes_num: size of the output layer.
    """

    def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
                 fmax, classes_num):

        super(ResNet38, self).__init__()

        window = 'hann'
        center = True
        pad_mode = 'reflect'
        ref = 1.0
        amin = 1e-10
        top_db = None

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size,
                                                 win_length=window_size, window=window, center=center,
                                                 pad_mode=pad_mode,
                                                 freeze_parameters=True)

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size,
                                                 n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin,
                                                 top_db=top_db,
                                                 freeze_parameters=True)

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
                                               freq_drop_width=8, freq_stripes_num=2)

        # forward() moves the mel axis into the channel position before bn0, so
        # the layer needs one channel per mel band.  This used to be hard-coded
        # to 64, which only worked for mel_bins == 64.
        self.bn0 = nn.BatchNorm2d(mel_bins)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        # self.conv_block2 = ConvBlock(in_channels=64, out_channels=64)

        self.resnet = _ResNet(block=_ResnetBasicBlock, layers=[3, 4, 6, 3], zero_init_residual=True)

        self.conv_block_after1 = ConvBlock(in_channels=512, out_channels=2048)

        self.fc1 = nn.Linear(2048, 2048)
        self.fc_audioset = nn.Linear(2048, classes_num, bias=True)

        self.init_weights()

    def init_weights(self):
        """Initialise the input batch-norm and both fully connected layers."""
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, input, mixup_lambda=None):
        """Run the network on a batch of waveforms.

        Args:
            input: (batch_size, data_length) waveform batch.
            mixup_lambda: optional argument forwarded to ``do_mixup``; only
                used in training mode.

        Returns:
            dict with ``'clipwise_output'`` (sigmoid of ``fc_audioset``) and
            ``'embedding'`` (dropout applied to the ``fc1`` activations).
        """

        x = self.spectrogram_extractor(input)  # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

        # Swap axes 1 and 3 so bn0 normalises per mel band, then swap back.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        if self.training:
            x = self.spec_augmenter(x)

        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

        x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training, inplace=True)
        x = self.resnet(x)
        x = F.avg_pool2d(x, kernel_size=(2, 2))
        x = F.dropout(x, p=0.2, training=self.training, inplace=True)
        x = self.conv_block_after1(x, pool_size=(1, 1), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training, inplace=True)
        # Collapse the last axis, then combine max and mean over axis 2.
        x = torch.mean(x, dim=3)

        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        embedding = F.dropout(x, p=0.5, training=self.training)
        # The classifier reads the pre-dropout activations, not `embedding`.
        clipwise_output = torch.sigmoid(self.fc_audioset(x))

        output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding}

        return output_dict

In [None]:
# STFT / log-mel front-end settings shared by the models built below.
_window, _center, _pad_mode = 'hann', True, 'reflect'
_ref, _amin, _top_db = 1.0, 1e-10, None

# Constructor arguments used to instantiate Cnn14 before its head and front end
# are swapped out; entries tagged "original" are the values the author marked
# as such (presumably the pretrained configuration -- TODO confirm).
_model_config = dict(
    sample_rate=32000,
    window_size=1024,  # original
    hop_size=320,  # original
    mel_bins=64,  # original
    fmin=50,
    fmax=14000,
    classes_num=527,  # original
)

def my_cnn14(n_fft, n_mels, n_classes=100, hop_size=320, fmin=160, fmax=10300):
    """Build a Cnn14 whose head and audio front end match the given settings.

    The network is first created from ``_model_config``; then the output
    layer, the STFT extractor, the log-mel filter bank and the input
    batch-norm are replaced so they agree with the arguments.

    Args:
        n_fft: STFT size (also used as the window length).
        n_mels: number of mel bands (also the channel count of ``bn0``).
        n_classes: number of output classes.
        hop_size: STFT hop length.
        fmin: lowest mel filter frequency.
        fmax: highest mel filter frequency.

    Returns:
        The customised ``Cnn14`` instance.
    """
    net = Cnn14(**_model_config)

    # New classification head.
    head = nn.Linear(2048, n_classes, bias=True)
    net.fc_audioset = head
    init_layer(net.fc_audioset)

    # New front end: STFT followed by the log-mel filter bank.
    stft_kwargs = dict(
        n_fft=n_fft, hop_length=hop_size, win_length=n_fft, window=_window,
        center=_center, pad_mode=_pad_mode, freeze_parameters=True)
    net.spectrogram_extractor = Spectrogram(**stft_kwargs)

    logmel_kwargs = dict(
        sr=SAMPLE_RATE, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax,
        ref=_ref, amin=_amin, top_db=_top_db, freeze_parameters=True)
    net.logmel_extractor = LogmelFilterBank(**logmel_kwargs)

    # Input batch-norm sized to the new number of mel bands.
    net.bn0 = nn.BatchNorm2d(n_mels)
    init_bn(net.bn0)

    return net



In [None]:
def read_audio_fast(path):
    """Decode an audio file to mono and resample it to ``SAMPLE_RATE``.

    Args:
        path: location of the audio file.

    Returns:
        Tuple ``(clip, SAMPLE_RATE)``.  The clip is resampled whenever the
        decoder reports a positive native rate; otherwise it is returned as
        decoded.
    """
    # NOTE(review): __audioread_load is a private librosa helper and may change
    # between releases.
    clip, sr_native = librosa.core.audio.__audioread_load(
        path, offset=0.0, duration=None, dtype=np.float32)
    clip = librosa.to_mono(clip)
    if sr_native > 0:
        # Keyword arguments work on every librosa release; newer releases
        # reject the sample rates when they are passed positionally.
        clip = librosa.resample(
            clip, orig_sr=sr_native, target_sr=SAMPLE_RATE, res_type='kaiser_fast')
    return clip, SAMPLE_RATE

def random_chunk(y, duration, sr=SAMPLE_RATE):
    """Return a random contiguous excerpt of ``duration`` seconds from ``y``.

    Args:
        y: 1-D sample array.
        duration: excerpt length in seconds.
        sr: sampling rate of ``y``.

    Returns:
        A slice of ``y`` with ``int(duration * sr)`` samples.

    Raises:
        ValueError: if ``y`` is shorter than the requested excerpt.
    """
    sample_size = int(duration * sr)
    n_samples = y.shape[0]
    if n_samples < sample_size:
        raise ValueError('Too short clip')
    # randint's upper bound is exclusive: the +1 makes the last valid offset
    # reachable and keeps a clip of exactly `sample_size` samples from raising.
    t = np.random.randint(0, n_samples - sample_size + 1)
    return y[t: t + sample_size]

In [None]:
# Use the competition's test audio when it is mounted; otherwise fall back to
# the 'cornellfaketest' dataset (presumably a stand-in with the same layout --
# TODO confirm), where the CSV and the audio share one directory.
if os.path.exists('/kaggle/input/birdsong-recognition/test_audio'):
    BASE_TEST_DIR = '/kaggle/input/birdsong-recognition'
    TEST_AUDIO_DIR = '/kaggle/input/birdsong-recognition/test_audio'
else:
    BASE_TEST_DIR = '/kaggle/input/cornellfaketest'
    TEST_AUDIO_DIR = '/kaggle/input/cornellfaketest'
print(BASE_TEST_DIR, TEST_AUDIO_DIR)
df_test = pd.read_csv(f'{BASE_TEST_DIR}/test.csv')
# Quick look at the test metadata: size, first rows, distinct values per site.
df_test.shape
df_test.head()
df_test.groupby('site').nunique()

In [None]:
# Lookup table: audio_id -> site, built from the distinct (audio_id, site) pairs.
audio_site = dict(
    df_test[['audio_id', 'site']]
    .drop_duplicates()
    .itertuples(index=False, name=None)
)

In [None]:
def get_model_predictions_for_clip_fast(y, model, device, classes):
    """Score one recording in fixed-length windows with a single forward pass.

    The clip is cut into windows of ``CHUNK_DURATION`` seconds starting every
    5 seconds; samples after the last full window are ignored.

    Args:
        y: 1-D sample array at ``SAMPLE_RATE``.
        model: callable returning a dict with a ``'clipwise_output'`` tensor.
        device: torch device the batch is moved to.
        classes: column names for the model outputs, in output order.

    Returns:
        DataFrame with one row per window: one column per class plus
        ``start_second``.  Empty (same columns) when the clip is shorter than
        one window.
    """
    chunk_len = CHUNK_DURATION * SAMPLE_RATE
    duration = y.shape[0] // SAMPLE_RATE
    batch = []
    start_seconds = []
    for start in range(0, duration - CHUNK_DURATION + 1, 5):
        # The range bound guarantees a full CHUNK_DURATION window fits here.
        end = start + CHUNK_DURATION
        chunk = y[start * SAMPLE_RATE: end * SAMPLE_RATE]
        if len(chunk) != chunk_len:
            print(chunk.shape)
            break
        # Record the start only for windows that are actually scored, so the
        # two lists always stay the same length.
        batch.append(chunk)
        start_seconds.append(start)

    if not batch:
        # Too short for a single window: nothing to feed to the model.
        return pd.DataFrame(columns=list(classes) + ['start_second'])

    batch = np.asarray(batch)
    tensors = torch.from_numpy(batch)
    tensors = tensors.to(device)
    with torch.no_grad():
        preds = model(tensors)['clipwise_output']
    test_preds = preds.cpu().numpy()
    pred_df = pd.DataFrame(test_preds, columns=classes)
    pred_df['start_second'] = start_seconds
    return pred_df

In [None]:
# Distinct recording ids per site, in order of first appearance in test.csv.
sites = ['site_1', 'site_2', 'site_3']
audio_ids_by_site = {
    site_name: df_test.loc[df_test.site == site_name, 'audio_id'].unique().tolist()
    for site_name in sites
}
for s in sites:
    print(s, len(audio_ids_by_site[s]))

# Site 1 & 2

In [None]:
!ls /kaggle/input/cornellmodels/

In [None]:
# Ensemble members as (checkpoint file name, keyword arguments for my_cnn14).
#
# Checkpoints that are currently switched off:
#   'models_CNN14_CV0_C262_M2_5139_13143345.pth'
#       n_fft=1536, hop_size=320, n_mels=128, n_classes=262, fmin=160, fmax=12000
#   'models_CNN14_CV0_C262_F18147_05073327.pth'
#       n_fft=1536, hop_size=360, n_mels=128, n_classes=262, fmin=500, fmax=11000
#   'models_CNN14_CV0_C165_M2_N0_5816_13174140.pth'
#       n_fft=1536, hop_size=320, n_mels=128, n_classes=165, fmin=300, fmax=11000
models = [
    (
        'models_CNN14_CV1_C262_M1_N1_7045_14005959.pth',
        dict(n_fft=1536, hop_size=320, n_mels=128, n_classes=262, fmin=300, fmax=11000),
    ),
    (
        'models_CNN14_CV3_C165_M3_N1_F18256_14102207.pth',
        dict(n_fft=1536, hop_size=320, n_mels=128, n_classes=165, fmin=300, fmax=11000),
    ),
    (
        'models_CNN14_CV3_C262_M0_N1_7063_14141546.pth',
        dict(n_fft=1536, hop_size=320, n_mels=128, n_classes=262, fmin=300, fmax=11000),
    ),
]

In [None]:
# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def load_model(model_name, model_config):
    """Build a Cnn14 for ``model_config`` and load the named checkpoint into it.

    Args:
        model_name: checkpoint file name inside ``/kaggle/input/cornellmodels``.
        model_config: keyword arguments for ``my_cnn14``.

    Returns:
        The model with weights loaded.  It is not moved to ``device`` and not
        switched to eval mode here.
    """
    net = my_cnn14(**model_config)
    checkpoint_path = f'/kaggle/input/cornellmodels/{model_name}'
    state = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(state)
    return net

In [None]:
def get_spec_stats(ys, model):
    """Estimate mean and std of the model's log-mel features over recordings.

    Every recording is cut into consecutive full 5-second chunks (a trailing
    partial chunk is ignored), each chunk is pushed through the model's
    spectrogram and log-mel extractors, and the per-chunk mean and std are
    averaged over all chunks.

    Args:
        ys: iterable of 1-D sample arrays at ``SAMPLE_RATE``.
        model: object exposing ``spectrogram_extractor`` and ``logmel_extractor``.

    Returns:
        Tuple ``(test_mean, test_std)``: the averages of the per-chunk means
        and of the per-chunk stds.
    """
    chunk_seconds = 5
    test_means = []
    test_stds = []
    for y in tqdm(ys):
        duration = y.shape[0] // SAMPLE_RATE
        # Stop at the last *full* chunk: a shorter trailing chunk would make
        # the batch ragged and np.array / torch.from_numpy would fail on it.
        batch = [
            y[start * SAMPLE_RATE: (start + chunk_seconds) * SAMPLE_RATE]
            for start in range(0, duration - chunk_seconds + 1, chunk_seconds)
        ]
        if not batch:
            # Recording shorter than one chunk: nothing to measure.
            continue
        batch = np.array(batch)
        batch = torch.from_numpy(batch)
        tensors = batch.to(device)
        # Statistics only; no autograd graph is needed.
        with torch.no_grad():
            specs = model.spectrogram_extractor(tensors)
            specs = model.logmel_extractor(specs)
        specs = specs.cpu().numpy()
        for spec in specs:
            test_means.append(spec.mean())
            test_stds.append(spec.std())
    test_mean = np.mean(test_means)
    test_std = np.mean(test_stds)
    print(f'Test stats: {test_mean:.3} +- {test_std:.3}')
    return test_mean, test_std

In [None]:
site12_preds = []

for site in ['site_1', 'site_2']:
    # Decode every recording of this site once and reuse it for all models.
    test_audios = {}
    for audio_id in tqdm(audio_ids_by_site[site]):
        path = os.path.join(TEST_AUDIO_DIR, f'{audio_id}.mp3')
        y, _ = read_audio_fast(path)
        test_audios[audio_id] = y

    for model_name, model_config in models:
        model = load_model(model_name, model_config)
        model = model.to(device)
        _ = model.eval()

        # The class list is encoded in the checkpoint name.
        if '_C262_' in model_name:
            classes = BirdLists.all
        elif '_C165_' in model_name:
            classes = BirdLists.top_165
        elif '_C202_' in model_name:
            classes = BirdLists.non_zero
        else:
            # Fail loudly instead of silently reusing the previous model's
            # class list (which would mislabel every output column).
            raise ValueError(f'Cannot infer the class list from model name: {model_name}')

        # '_N1_' checkpoints get their input standardised with statistics
        # measured on this site's recordings.
        if '_N1_' in model_name:
            test_mean, test_std = get_spec_stats(test_audios.values(), model)
            model.dataset_mean = test_mean
            model.dataset_std = test_std

        for audio_id, y in tqdm(test_audios.items()):
            pred_df = get_model_predictions_for_clip_fast(y, model, device, classes)
            pred_df['audio_id'] = audio_id
            pred_df['site'] = site
            site12_preds.append(pred_df)
# Frames from models with fewer classes lack some bird columns; concat leaves
# those cells as NaN.
predictions_12 = pd.concat(site12_preds)

In [None]:
# Sanity check of the stacked per-model predictions: size, both ends, and the
# number of non-null values per column.
predictions_12.shape
predictions_12.head()
predictions_12.tail()
predictions_12.count()


In [None]:
# Species groups that get their own decision threshold.  The group names are
# the author's; the visible code only shows that each group maps to one
# threshold value.
zero = [
    'redcro', 'norcar', 'houwre', 'spotow', 'bewwre', 'carwre', 'swathr', 'comrav', 'warvir', 'whbnut', 'marwre',
    'normoc', 'wesmea', 'amered', 'whtspa', 'houfin', 'ruckin', 'eastow', 'tuftit', 'chispa', 'grhowl', 'savspa',
    'wlswar', 'buggna', 'brncre', 'aldfly', 'orcwar', 'bkhgro', 'bnhcow', 'houspa', 'yerwar', 'haiwoo', 'dowwoo',
    'bawwar', 'foxspa', 'balori', 'rebnut', 'eursta', 'gockin', 'killde', 'comgra', 'barswa', 'osprey', 'mallar3',
    'pilwoo', 'moudov', 'easpho', 'comloo', 'treswa', 'easblu', 'rethaw', 'grbher3', 'yelwar', 'greegr', 'ribgul',
    'nutwoo', 'rocpig', 'brnthr', 'norpar', 'bulori',
]

common = ['whcspa', 'grycat', 'herthr', 'amerob', 'reevir1', 'comyel', 'norfli', 'amecro']
rare = [
    'doccor', 'baleag', 'buffle', 'amekes', 'annhum', 'reshaw', 'coohaw', 'lesgol', 'norsho', 'commer', 'norhar2',
    'rinduc', 'amewig', 'blkpho', 'hoomer', 'comgol', 'cowscj1', 'wiltur', 'buwtea', 'rudduc', 'bushti', 'rebmer',
    'stejay', 'norpin', 'vigswa', 'gocspa', 'redhea', 'wesblu', 'shshaw', 'horgre', 'calqua', 'perfal', 'labwoo',
    'wewpew', 'bkchum', 'truswa', 'lotduc', 'brwhaw', 'pasfly', 'swahaw', 'eargre', 'whfibi', 'tunswa', 'phaino',
    'brthum', 'rufgro', 'wooscj2', 'cacwre', 'rebsap', 'whtswi', 'goleag', 'snobun', 'renpha', 'coshum', 'baisan',
    'sheowl', 'lewwoo', 'lesnig', 'gcrfin', 'sagspa1', 'chukar', 'lecthr',
]

# Every species starts at the default threshold; the groups then override it
# in the order zero -> common -> rare, so a later group wins on overlap.
BIRD_THRESHOLD = 0.33
THRESHOLDS = dict.fromkeys(BirdLists.all, BIRD_THRESHOLD)
THRESHOLDS.update(dict.fromkeys(zero, 0.6))
THRESHOLDS.update(dict.fromkeys(common, 0.3))
THRESHOLDS.update(dict.fromkeys(rare, 0.55))

THRESHOLDS

In [None]:
# Average the models per (site, recording, window).  Classes a model does not
# predict are counted as probability 0 in that average.
mean_predictions_12 = (
    predictions_12
    .fillna(0)
    .groupby(['site', 'audio_id', 'start_second'])
    .mean()
    .reset_index()
)
mean_predictions_12['birdmax'] = mean_predictions_12[BirdLists.all].max(axis=1)
# Row ids are keyed by the window's end second.
mean_predictions_12['seconds'] = mean_predictions_12['start_second'] + CHUNK_DURATION
mean_predictions_12['row_id'] = (
    mean_predictions_12['site'] + '_'
    + mean_predictions_12['audio_id'] + '_'
    + mean_predictions_12['seconds'].astype(str)
)
# Space-separated list of every species that reaches its own threshold.
multi_birds = [
    ' '.join(bird for bird, p in row.items() if p >= THRESHOLDS[bird])
    for _, row in mean_predictions_12[BirdLists.all].iterrows()
]
mean_predictions_12['bird_guesses'] = multi_birds
mean_predictions_12.shape
mean_predictions_12.head()

In [None]:
# Distribution of the highest class probability per row, followed by the
# fraction of rows whose top probability exceeds a few candidate thresholds.
mean_predictions_12[BirdLists.all].max(axis=1).hist()
np.mean(mean_predictions_12.birdmax > 0.5)
np.mean(mean_predictions_12.birdmax > 0.4)
np.mean(mean_predictions_12.birdmax > 0.33)
np.mean(mean_predictions_12.birdmax > 0.3)


In [None]:
# Keep only rows whose guess string is longer than one character, which drops
# the rows where no species reached its threshold (empty string).
site12_birds = mean_predictions_12.loc[mean_predictions_12.bird_guesses.str.len() > 1, ['row_id', 'bird_guesses']]
site12_birds.shape
site12_birds.head()

# Site 3

In [None]:
site3_preds = []
site = 'site_3'
# Decode every site-3 recording once.
test_audios = {}
for audio_id in tqdm(audio_ids_by_site[site]):
    path = os.path.join(TEST_AUDIO_DIR, f'{audio_id}.mp3')
    y, _ = read_audio_fast(path)
    test_audios[audio_id] = y

# Site 3 is scored with the first ensemble member only.
for model_name, model_config in models[:1]:
    model = load_model(model_name, model_config)
    model = model.to(device)
    _ = model.eval()
    classes = BirdLists.all

    # Same input standardisation as for sites 1 and 2: '_N1_' checkpoints are
    # given statistics measured on the recordings being scored.  Without this
    # the model would run with dataset_mean=0 / dataset_std=1.
    # NOTE(review): confirm that skipping this step for site 3 was not intentional.
    if '_N1_' in model_name:
        test_mean, test_std = get_spec_stats(test_audios.values(), model)
        model.dataset_mean = test_mean
        model.dataset_std = test_std

    for audio_id, y in tqdm(test_audios.items()):
        pred_df = get_model_predictions_for_clip_fast(y, model, device, classes)
        pred_df['audio_id'] = audio_id
        pred_df['site'] = site
        site3_preds.append(pred_df)
predictions_3 = pd.concat(site3_preds)

In [None]:

# Per window: the highest class probability and the species it belongs to.
predictions_3['birdmax'] = predictions_3[BirdLists.all].max(axis=1)
# argmax works on the raw array, so the column position maps straight back to
# the species list.
predictions_3['bird_guess'] = [BirdLists.all[idx] for idx in predictions_3[BirdLists.all].values.argmax(axis=1)]

In [None]:
# Confident windows only (top probability above 0.5), counted per recording
# and species; the 'site' column ends up holding the number of such windows.
guesses_3 = predictions_3[predictions_3.birdmax > 0.5].groupby(['audio_id', 'bird_guess'])[['site']].count().reset_index()
guesses_3

In [None]:
# One space-separated species string per site-3 recording, joined with the
# recording's row_id from the test metadata.
site_3_birds = (
    guesses_3
    .groupby('audio_id')['bird_guess']
    .agg(' '.join)
    .rename('bird_guesses')
    .reset_index()
    .merge(df_test[['audio_id', 'row_id']], how='left', on='audio_id')
)
site_3_birds

# Submission

In [None]:
# Stack the guesses for sites 1/2 and for site 3 into one (row_id, bird_guesses) table.
guesses = pd.concat([
    site12_birds[['row_id', 'bird_guesses']],
    site_3_birds[['row_id', 'bird_guesses']]
])
guesses.shape
guesses.head()
guesses.tail()

In [None]:
# Start from the sample submission, so every expected row_id is present with
# its default 'birds' value.
submission = pd.read_csv('/kaggle/input/birdsong-recognition/sample_submission.csv')
submission.head()
submission.shape

# Attach our guesses; rows without a guess get NaN in 'bird_guesses'.
submission = submission.merge(
    guesses[['row_id', 'bird_guesses']],
    how='left',
    on='row_id',
)

# Overwrite the default only where a non-trivial guess string exists.
idx = submission['bird_guesses'].str.len().gt(1)
submission['birds'] = submission['birds'].mask(idx, submission['bird_guesses'])
submission

In [None]:
# Write only the two submission columns, without the DataFrame index.
submission[['row_id', 'birds']].to_csv('submission.csv', index=False)

In [None]:
end = datetime.now()
# timedelta.seconds is only the seconds component (whole days are dropped);
# total_seconds() is the full elapsed time.
print(end, int((end - nb_start).total_seconds()))