## Dependencies

In [None]:
!pip install ../input/timm-pytorch-image-models/pytorch-image-models-master/

## Libraries

In [None]:
# import glob
import os
import random
import warnings
from functools import partial

# import colorednoise as cn
import librosa
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import scipy as sp
import soundfile as sf
import timm
import torch
import torch.optim as optim
from pytorch_lightning import LightningDataModule, callbacks

# from pytorch_lightning.utilities import rank_zero_info
from sklearn import model_selection
from sklearn.metrics import f1_score
from torch import nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from torchaudio.transforms import AmplitudeToDB, MelSpectrogram


In [None]:
import cv2
import audioread
import logging
import os
import random
import time
import warnings

import librosa
import numpy as np
import pandas as pd
import soundfile as sf
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torchdata

from contextlib import contextmanager
from pathlib import Path
from typing import Optional

from albumentations.core.transforms_interface import ImageOnlyTransform
# from torchlibrosa.stft import LogmelFilterBank, Spectrogram
# from torchlibrosa.augmentation import SpecAugmentation
from tqdm import tqdm

## Utilities

In [None]:
def set_seed(seed: int = 42):
    """
    Seed every random number generator used here and request deterministic cuDNN.
    Parameters
    ----------
    seed: int
        value given to `random`, NumPy, PYTHONHASHSEED and PyTorch (CPU and CUDA)
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    # Seed all visible GPUs, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore
    # The cuDNN auto-tuner may pick a different kernel from run to run, which
    # undoes the determinism requested above, so it has to be switched off.
    torch.backends.cudnn.benchmark = False  # type: ignore
    
    
def get_logger(out_file=None):
    """
    Reset the root logger and attach a console handler (plus an optional file handler)
    Parameters
    ----------
    out_file: str or path-like, optional
        when given, log records are also written to this file
    Returns
    -------
    logger: logging.Logger
        the root logger, set to INFO level
    """
    logger = logging.getLogger()
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    # Detach AND close the previous handlers. Merely overwriting the handler
    # list leaves earlier log files open every time this cell is re-run.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()
    logger.setLevel(logging.INFO)

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)

    if out_file is not None:
        fh = logging.FileHandler(out_file)
        fh.setFormatter(formatter)
        fh.setLevel(logging.INFO)
        logger.addHandler(fh)
    logger.info("logger set up")
    return logger
    
    
@contextmanager
def timer(name: str, logger: Optional[logging.Logger] = None):
    """
    Context manager reporting when a block starts and how long it took
    Parameters
    ----------
    name: str
        label placed in square brackets in both messages
    logger: logging.Logger, optional
        destination for the messages (INFO level); `print` is used when omitted
    """
    def _emit(msg: str):
        if logger is None:
            print(msg)
        else:
            logger.info(msg)

    t0 = time.time()
    _emit(f"[{name}] start")
    try:
        yield
    finally:
        # Report the elapsed time even when the timed block raises; the
        # exception itself still propagates to the caller.
        _emit(f"[{name}] done in {time.time() - t0:.2f} s")

In [None]:
# Send log records to the console and to "main.log", then seed the RNGs.
logger = get_logger("main.log")
set_seed(1213)

## Config

In [None]:
class CFG:
    """Static configuration namespace: every setting is a class attribute read as `CFG.<name>`."""
    ######################
    # Globals #
    ######################
    seed = 1213
    epochs = 35
    train = True
    folds = [0]
    img_size = 224
    main_metric = "epoch_f1_at_05"
    minimize_metric = False

    ######################
    # Data #
    ######################
    train_datadir = Path("../input/birdclef-2021/train_short_audio")
    train_csv = "../input/birdclef-2021/train_metadata.csv"
    train_soundscape = "../input/birdclef-2021/train_soundscape_labels.csv"

    ######################
    # Dataset #
    ######################
    # Per-phase lists of transform specs; `get_transforms` looks each "name" up
    # in the module globals and instantiates it with the optional "params".
    transforms = {
        "train": [{"name": "Normalize"}],
        "valid": [{"name": "Normalize"}],
        "test": [{"name": "Normalize"}]
    }
    period = 30
#     period = 5
    n_mels = 128
    
    sample_rate = 32000

    # Class labels; presumably one entry per output class (`num_classes` below) — TODO confirm the count.
    target_columns = [
        'acafly', 'acowoo', 'aldfly', 'ameavo', 'amecro',
        'amegfi', 'amekes', 'amepip', 'amered', 'amerob',
        'amewig', 'amtspa', 'andsol1', 'annhum', 'astfly',
        'azaspi1', 'babwar', 'baleag', 'balori', 'banana',
        'banswa', 'banwre1', 'barant1', 'barswa', 'batpig1',
        'bawswa1', 'bawwar', 'baywre1', 'bbwduc', 'bcnher',
        'belkin1', 'belvir', 'bewwre', 'bkbmag1', 'bkbplo',
        'bkbwar', 'bkcchi', 'bkhgro', 'bkmtou1', 'bknsti', 'blbgra1',
        'blbthr1', 'blcjay1', 'blctan1', 'blhpar1', 'blkpho',
        'blsspa1', 'blugrb1', 'blujay', 'bncfly', 'bnhcow', 'bobfly1',
        'bongul', 'botgra', 'brbmot1', 'brbsol1', 'brcvir1', 'brebla',
        'brncre', 'brnjay', 'brnthr', 'brratt1', 'brwhaw', 'brwpar1',
        'btbwar', 'btnwar', 'btywar', 'bucmot2', 'buggna', 'bugtan',
        'buhvir', 'bulori', 'burwar1', 'bushti', 'butsal1', 'buwtea',
        'cacgoo1', 'cacwre', 'calqua', 'caltow', 'cangoo', 'canwar',
        'carchi', 'carwre', 'casfin', 'caskin', 'caster1', 'casvir',
        'categr', 'ccbfin', 'cedwax', 'chbant1', 'chbchi', 'chbwre1',
        'chcant2', 'chispa', 'chswar', 'cinfly2', 'clanut', 'clcrob',
        'cliswa', 'cobtan1', 'cocwoo1', 'cogdov', 'colcha1', 'coltro1',
        'comgol', 'comgra', 'comloo', 'commer', 'compau', 'compot1',
        'comrav', 'comyel', 'coohaw', 'cotfly1', 'cowscj1', 'cregua1',
        'creoro1', 'crfpar', 'cubthr', 'daejun', 'dowwoo', 'ducfly', 'dusfly',
        'easblu', 'easkin', 'easmea', 'easpho', 'eastow', 'eawpew', 'eletro',
        'eucdov', 'eursta', 'fepowl', 'fiespa', 'flrtan1', 'foxspa', 'gadwal',
        'gamqua', 'gartro1', 'gbbgul', 'gbwwre1', 'gcrwar', 'gilwoo',
        'gnttow', 'gnwtea', 'gocfly1', 'gockin', 'gocspa', 'goftyr1',
        'gohque1', 'goowoo1', 'grasal1', 'grbani', 'grbher3', 'grcfly',
        'greegr', 'grekis', 'grepew', 'grethr1', 'gretin1', 'greyel',
        'grhcha1', 'grhowl', 'grnher', 'grnjay', 'grtgra', 'grycat',
        'gryhaw2', 'gwfgoo', 'haiwoo', 'heptan', 'hergul', 'herthr',
        'herwar', 'higmot1', 'hofwoo1', 'houfin', 'houspa', 'houwre',
        'hutvir', 'incdov', 'indbun', 'kebtou1', 'killde', 'labwoo', 'larspa',
        'laufal1', 'laugul', 'lazbun', 'leafly', 'leasan', 'lesgol', 'lesgre1',
        'lesvio1', 'linspa', 'linwoo1', 'littin1', 'lobdow', 'lobgna5', 'logshr',
        'lotduc', 'lotman1', 'lucwar', 'macwar', 'magwar', 'mallar3', 'marwre',
        'mastro1', 'meapar', 'melbla1', 'monoro1', 'mouchi', 'moudov', 'mouela1',
        'mouqua', 'mouwar', 'mutswa', 'naswar', 'norcar', 'norfli', 'normoc', 'norpar',
        'norsho', 'norwat', 'nrwswa', 'nutwoo', 'oaktit', 'obnthr1', 'ocbfly1',
        'oliwoo1', 'olsfly', 'orbeup1', 'orbspa1', 'orcpar', 'orcwar', 'orfpar',
        'osprey', 'ovenbi1', 'pabspi1', 'paltan1', 'palwar', 'pasfly', 'pavpig2',
        'phivir', 'pibgre', 'pilwoo', 'pinsis', 'pirfly1', 'plawre1', 'plaxen1',
        'plsvir', 'plupig2', 'prowar', 'purfin', 'purgal2', 'putfru1', 'pygnut',
        'rawwre1', 'rcatan1', 'rebnut', 'rebsap', 'rebwoo', 'redcro', 'reevir1',
        'rehbar1', 'relpar', 'reshaw', 'rethaw', 'rewbla', 'ribgul', 'rinkin1',
        'roahaw', 'robgro', 'rocpig', 'rotbec', 'royter1', 'rthhum', 'rtlhum',
        'ruboro1', 'rubpep1', 'rubrob', 'rubwre1', 'ruckin', 'rucspa1', 'rucwar',
        'rucwar1', 'rudpig', 'rudtur', 'rufhum', 'rugdov', 'rumfly1', 'runwre1',
        'rutjac1', 'saffin', 'sancra', 'sander', 'savspa', 'saypho', 'scamac1',
        'scatan', 'scbwre1', 'scptyr1', 'scrtan1', 'semplo', 'shicow', 'sibtan2',
        'sinwre1', 'sltred', 'smbani', 'snogoo', 'sobtyr1', 'socfly1', 'solsan',
        'sonspa', 'soulap1', 'sposan', 'spotow', 'spvear1', 'squcuc1', 'stbori',
        'stejay', 'sthant1', 'sthwoo1', 'strcuc1', 'strfly1', 'strsal1', 'stvhum2',
        'subfly', 'sumtan', 'swaspa', 'swathr', 'tenwar', 'thbeup1', 'thbkin',
        'thswar1', 'towsol', 'treswa', 'trogna1', 'trokin', 'tromoc', 'tropar',
        'tropew1', 'tuftit', 'tunswa', 'veery', 'verdin', 'vigswa', 'warvir',
        'wbwwre1', 'webwoo1', 'wegspa1', 'wesant1', 'wesblu', 'weskin', 'wesmea',
        'westan', 'wewpew', 'whbman1', 'whbnut', 'whcpar', 'whcsee1', 'whcspa',
        'whevir', 'whfpar1', 'whimbr', 'whiwre1', 'whtdov', 'whtspa', 'whwbec1',
        'whwdov', 'wilfly', 'willet1', 'wilsni1', 'wiltur', 'wlswar', 'wooduc',
        'woothr', 'wrenti', 'y00475', 'yebcha', 'yebela1', 'yebfly', 'yebori1',
        'yebsap', 'yebsee1', 'yefgra1', 'yegvir', 'yehbla', 'yehcar1', 'yelgro',
        'yelwar', 'yeofly1', 'yerwar', 'yeteup1', 'yetvir']

    ######################
    # Loaders #
    ######################
    loader_params = {
        "train": {
            "batch_size": 64,
            "num_workers": 20,
            "shuffle": True
        },
        "valid": {
            "batch_size": 64,
            "num_workers": 20,
            "shuffle": False
        },
        "test": {
            "batch_size": 64,
            "num_workers": 20,
            "shuffle": False
        }
    }

    ######################
    # Model #
    ######################
    # One single-entry dict per checkpoint, keyed by a backbone name.
    # Presumably {timm backbone name: checkpoint path} for each ensemble
    # member — the consumer is not visible here, confirm against it.
    models_cfg = [
        {"resnet34": "../input/birdclef20201-checkpoints/checkpoints/exp004/fold0__20210417-17-00-45/lightning_logs/version_0/checkpoints/best_loss.ckpt",
},
        {"efficientnet_b0": "../input/birdclef20201-checkpoints/checkpoints/exp004/fold0__20210418-10-44-57/best_loss.ckpt",},
        {"resnet18": "../input/birdclef20201-checkpoints/checkpoints/exp005/fold0__20210420-06-34-46/best_loss.ckpt"},
        {"ecaresnet26t": "../input/birdclef20201-checkpoints/checkpoints/exp005/fold0__20210421-00-17-35/best_loss.ckpt",},
        {"mixnet_m": "../input/birdclef20201-checkpoints/checkpoints/exp005/fold0__20210421-11-01-38/best_loss.ckpt",},
        {"repvgg_b0": "../input/birdclef20201-checkpoints/checkpoints/exp005/fold0__20210421-21-55-49/best_loss.ckpt",},
        {"resnest26d": "../input/birdclef20201-checkpoints/checkpoints/exp005/fold0__20210422-11-13-45/best_loss.ckpt"},
        {"resnest50d_1s4x24d": "../input/birdclef20201-checkpoints/checkpoints/exp005/fold0__20210423-04-29-31/best_loss.ckpt",},
        
        {"repvgg_b0": "../input/birdclef20201-checkpoints/checkpoints/exp008/repvgg_b0/fold0/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        
        {"resnet34": "../input/birdclef20201-checkpoints/checkpoints/exp009/resnet34/fold0/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"efficientnet_b0": "../input/birdclef20201-checkpoints/checkpoints/exp009/efficientnet_b0/fold0/lightning_logs/version_1/checkpoints/best_loss.ckpt",},
        {"resnest50d_1s4x24d":"../input/birdclef20201-checkpoints/checkpoints/exp009/resnest50d_1s4x24d/fold2/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnest50d_1s4x24d":"../input/birdclef20201-checkpoints/checkpoints/exp009/resnest50d_1s4x24d/fold3/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnet18":"../input/birdclef20201-checkpoints/checkpoints/exp009/resnet18/fold3/lightning_logs/version_0/checkpoints/best_loss.ckpt"},
        {"ecaresnet26t": "../input/birdclef20201-checkpoints/checkpoints/exp009/ecaresnet26t/fold0/lightning_logs/version_0/checkpoints/best_loss.ckpt"},
        {"ecaresnet26t": "../input/birdclef20201-checkpoints/checkpoints/exp009/ecaresnet26t/fold3/lightning_logs/version_0/checkpoints/best_loss.ckpt"},
        {"mixnet_m": "../input/birdclef20201-checkpoints/checkpoints/exp009/mixnet_m/fold1/lightning_logs/version_0/checkpoints/best_loss.ckpt"},
        {"resnest26d": "../input/birdclef20201-checkpoints/checkpoints/exp009/resnest26d/fold0/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnest26d": "../input/birdclef20201-checkpoints/checkpoints/exp009/resnest26d/fold1/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnest26d": "../input/birdclef20201-checkpoints/checkpoints/exp009/resnest26d/fold2/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"repvgg_b0":"../input/birdclef20201-checkpoints/checkpoints/exp009/repvgg_b0/fold2/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        # NOTE(review): the next path has a doubled "checkpoints/checkpoints" segment, unlike its siblings — confirm it exists on disk.
        {"repvgg_b0":"../input/birdclef20201-checkpoints/checkpoints/exp009/repvgg_b0/fold3/lightning_logs/version_0/checkpoints/checkpoints/best_loss.ckpt"},
        {"repvgg_b2":"../input/birdclef20201-checkpoints/checkpoints/exp009/repvgg_b2/fold2/lightning_logs/version_0/checkpoints/best_loss.ckpt"},
        {"repvgg_b2":"../input/birdclef20201-checkpoints/checkpoints/exp009/repvgg_b2/fold3/lightning_logs/version_0/checkpoints/best_loss.ckpt"},
        
        {"resnet34": "../input/birdclef20201-checkpoints/checkpoints/exp009/resnet34/fold2/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"efficientnet_b0": "../input/birdclef20201-checkpoints/checkpoints/exp009/efficientnet_b0/fold1/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnest50d_1s4x24d": "../input/birdclef20201-checkpoints2/checkpoints/exp009/resnest50d_1s4x24d/fold0/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnet18": "../input/birdclef20201-checkpoints2/checkpoints/exp009/resnet18/fold0/lightning_logs/version_0/checkpoints/best_loss.ckpt"},
        {"ecaresnet26t": "../input/birdclef20201-checkpoints2/checkpoints/exp009/ecaresnet26t/fold0/lightning_logs/version_1/checkpoints/best_loss.ckpt",},
        {"mixnet_m": "../input/birdclef20201-checkpoints2/checkpoints/exp009/mixnet_m/fold3/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnest26d": "../input/birdclef20201-checkpoints2/checkpoints/exp009/resnest26d/fold2/lightning_logs/version_1/checkpoints/best_loss.ckpt"},
        {"repvgg_b0": "../input/birdclef20201-checkpoints2/checkpoints/exp009/repvgg_b0/fold0/lightning_logs/version_0/checkpoints/best_loss.ckpt"},
        {"repvgg_b2": "../input/birdclef20201-checkpoints2/checkpoints/exp009/repvgg_b2/fold1/lightning_logs/version_0/checkpoints/best_loss.ckpt"},
        
        {"efficientnet_b0": "../input/birdclef20201-checkpoints2/checkpoints/exp010/efficientnet_b0_ft30/fold1/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnest50d_1s4x24d": "../input/birdclef20201-checkpoints2/checkpoints/exp010/resnest50d_1s4x24d_ft30/fold2/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"ecaresnet26t":"../input/birdclef20201-checkpoints2/checkpoints/exp010/ecaresnet26t_ft30/fold0/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnest26d": "../input/birdclef20201-checkpoints2/checkpoints/exp010/resnest26d_ft30/fold1/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"repvgg_b2":"../input/birdclef20201-checkpoints2/checkpoints/exp010/repvgg_b2_ft30/fold3/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        
        {"efficientnet_b0":"../input/birdclef20201-checkpoints2/checkpoints/exp010/efficientnet_b0_ft30/fold0/lightning_logs/version_0/checkpoints/best_loss.ckpt"},
        {"resnest50d_1s4x24d": "../input/birdclef20201-checkpoints2/checkpoints/exp010/resnest50d_1s4x24d_ft30/fold3/lightning_logs/version_0/checkpoints/best_loss.ckpt"},
        {"ecaresnet26t": "../input/birdclef20201-checkpoints2/checkpoints/exp010/ecaresnet26t_ft30/fold3/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnest26d": "../input/birdclef20201-checkpoints2/checkpoints/exp010/resnest26d_ft30/fold0/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnest26d": "../input/birdclef20201-checkpoints2/checkpoints/exp010/resnest26d_ft30/fold2/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"repvgg_b2": "../input/birdclef20201-checkpoints2/checkpoints/exp010/repvgg_b2_ft30/fold2/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        
        {"resnest50d_1s4x24d": "../input/birdclef20201-checkpoints2/checkpoints/exp009/resnest50d_1s4x24d/fold1/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnet50d":"../input/birdclef20201-checkpoints2/checkpoints/exp009/resnet50d/fold1/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnet50d":"../input/birdclef20201-checkpoints2/checkpoints/exp009/resnet50d/fold2/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"efficientnet_b2": "../input/birdclef20201-checkpoints2/checkpoints/exp009/efficientnet_b2/fold3/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"efficientnet_b2": "../input/birdclef20201-checkpoints2/checkpoints/exp009/efficientnet_b2/fold0/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        
        {"resnest50d_1s4x24d": "../input/birdclef20201-checkpoints2/checkpoints/exp010/resnest50d_1s4x24d_ft30/fold1/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnet50d": "../input/birdclef20201-checkpoints2/checkpoints/exp010/resnet50d_ft30/fold1/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnet50d": "../input/birdclef20201-checkpoints2/checkpoints/exp010/resnet50d_ft30/fold2/lightning_logs/version_0/checkpoints/best_loss.ckpt", },
        {"efficientnet_b2": "../input/birdclef20201-checkpoints2/checkpoints/exp010/efficientnet_b2_ft30/fold3/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"efficientnet_b2": "../input/birdclef20201-checkpoints2/checkpoints/exp010/efficientnet_b2_ft30/fold0/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        
        {"wide_resnet50_2": "../input/birdclef20201-checkpoints2/checkpoints/exp009/wide_resnet50_2/fold0/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"wide_resnet50_2": "../input/birdclef20201-checkpoints2/checkpoints/exp009/wide_resnet50_2/fold2/lightning_logs/version_0/checkpoints/best_loss.ckpt"},
        
        {"repvgg_b2": "../input/birdclef20201-checkpoints3/checkpoints/exp011/repvgg_b2/fold0/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"repvgg_b2": "../input/birdclef20201-checkpoints3/checkpoints/exp011/repvgg_b2/fold1/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnet50d":"../input/birdclef20201-checkpoints3/checkpoints/exp011/resnet50d/fold2/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"resnet50d":"../input/birdclef20201-checkpoints3/checkpoints/exp011/resnet50d/fold3/lightning_logs/version_0/checkpoints/best_loss.ckpt",},
        {"wide_resnet50_2":"../input/birdclef20201-checkpoints3/checkpoints/exp011/wide_resnet50_2/fold2/lightning_logs/version_0/checkpoints/best_loss.ckpt"},
        {"wide_resnet50_2":"../input/birdclef20201-checkpoints3/checkpoints/exp011/wide_resnet50_2/fold3/lightning_logs/version_0/checkpoints/best_loss.ckpt"},
    ]
    pretrained = False
    num_classes = 397
    in_channels = 1
# Report how many checkpoint entries are configured.
print(f"model_num: {len(CFG.models_cfg)}")

## Data Loading

In [None]:
# Sample rate constant for this notebook.
TARGET_SR = 32000
# TEST is True only when the test soundscape folder holds at least one .ogg
# file; otherwise the train soundscapes are used instead.
TEST = any(Path("../input/birdclef-2021/test_soundscapes/").glob("*.ogg"))
DATADIR = Path(
    "../input/birdclef-2021/test_soundscapes/"
    if TEST
    else "../input/birdclef-2021/train_soundscapes/"
)

In [None]:
# One row per .ogg file found in DATADIR.
all_audios = [*DATADIR.glob("*.ogg")]
# A row id is the first two "_"-separated pieces of the file name.
all_audio_ids = ["_".join(ogg_path.name.split("_")[:2]) for ogg_path in all_audios]
submission_df = pd.DataFrame({"row_id": all_audio_ids})
submission_df

## Define Model

In [None]:
def gem_freq(x, p=3, eps=1e-6):
    """
    Generalized-mean pooling over the second-to-last axis
    Parameters
    ----------
    x: torch.Tensor
        input whose second-to-last axis is collapsed to size 1
    p: float or torch.Tensor
        exponent of the generalized mean
    eps: float
        lower clamp applied to `x` before raising it to `p`
    Returns
    -------
    pooled: torch.Tensor
        (mean over that axis of clamp(x) ** p) ** (1 / p)
    """
    powered = x.clamp(min=eps).pow(p)
    # A kernel spanning the full second-to-last axis averages it away.
    averaged = F.avg_pool2d(powered, (x.size(-2), 1))
    return averaged.pow(1.0 / p)


class GeMFreq(nn.Module):
    """Module wrapper around `gem_freq` whose exponent `p` is a learnable parameter."""

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.eps = eps
        # One-element parameter initialised to `p`.
        self.p = torch.nn.Parameter(torch.ones(1) * p)

    def forward(self, x):
        """Pool `x` with the current exponent and the stored clamp value."""
        return gem_freq(x, eps=self.eps, p=self.p)


class NormalizeMelSpec(nn.Module):
    """
    Standardise each sample, then min-max rescale it to [0, 1].

    Samples whose standardised value range is not larger than `eps` are
    returned as all zeros.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, X):
        """Normalise a 3-D batch `X` (statistics are taken over axes 1 and 2)."""
        # Zero mean / unit variance per sample.
        Xstd = (X - X.mean((1, 2), keepdim=True)) / (
            X.std((1, 2), keepdim=True) + self.eps
        )

        # Per-sample extrema of the standardised data.
        lowest = Xstd.min(-1)[0].min(-1)[0]
        highest = Xstd.max(-1)[0].max(-1)[0]
        spread = highest - lowest
        rescalable = spread > self.eps * torch.ones_like(spread)

        result = torch.zeros_like(Xstd)
        if not rescalable.sum():
            return result

        low_sel = lowest[rescalable, None, None]
        high_sel = highest[rescalable, None, None]
        clipped = torch.max(torch.min(Xstd[rescalable], high_sel), low_sel)
        result[rescalable] = (clipped - low_sel) / (high_sel - low_sel)
        return result


class AttHead(nn.Module):
    """
    Attention head turning a backbone feature map into clip-level predictions.

    The feature map is pooled with `GeMFreq`, passed through a dense block, and
    two 1x1 convolutions produce per-frame class logits (`fix_scale`) and
    per-frame attention scores (`attention`). Clip-level outputs are the
    attention-weighted sums over time.

    Parameters
    ----------
    in_chans: int
        channel count of the incoming feature map
    p: float
        dropout probability (half of it is used before the linear layer)
    num_class: int
        number of output classes
    train_period: float
        clip length the model is fed; must be >= `infer_period`
    infer_period: float
        clip length scored in eval mode; when shorter than `train_period`,
        only the centred fraction `infer_period / train_period` of the time
        axis is used
    """

    def __init__(
        self, in_chans, p=0.5, num_class=397, train_period=15.0, infer_period=5.0
    ):
        super().__init__()
        self.train_period = train_period
        self.infer_period = infer_period
        self.pooling = GeMFreq()

        self.dense_layers = nn.Sequential(
            nn.Dropout(p / 2),
            nn.Linear(in_chans, 512),
            nn.ReLU(),
            nn.Dropout(p),
        )
        self.attention = nn.Conv1d(
            in_channels=512,
            out_channels=num_class,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        self.fix_scale = nn.Conv1d(
            in_channels=512,
            out_channels=num_class,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

    def forward(self, feat):
        """
        Returns
        -------
        tuple
            (logits, clipwise_pred, framewise_logit, time_att,
            clipwise_pred_long). The two framewise entries are laid out as
            (bs, time, num_class). `clipwise_pred_long` is the prediction over
            the whole time axis; it equals `clipwise_pred` whenever no
            cropping takes place.
        """
        feat = self.pooling(feat).squeeze(-2).permute(0, 2, 1)  # (bs, time, ch)

        feat = self.dense_layers(feat).permute(0, 2, 1)  # (bs, 512, time)
        time_att = torch.tanh(self.attention(feat))
        assert self.train_period >= self.infer_period

        # Per-frame class logits, computed once instead of once per use.
        framewise_logit = self.fix_scale(feat)  # (bs, num_class, time)
        if self.training or self.train_period == self.infer_period:
            att = torch.softmax(time_att, dim=-1)
            clipwise_pred = torch.sum(torch.sigmoid(framewise_logit) * att, dim=-1)
            logits = torch.sum(framewise_logit * att, dim=-1)
            # No cropping here, so the "long" prediction is the prediction
            # itself. Previously this name was left unassigned on this path
            # and the return statement raised UnboundLocalError.
            clipwise_pred_long = clipwise_pred
        else:
            clipwise_pred_long = torch.sum(
                torch.sigmoid(framewise_logit) * torch.softmax(time_att, dim=-1),
                dim=-1,
            )

            # Keep only the centred window covering infer_period/train_period
            # of the time axis.
            feat_time = feat.size(-1)
            start = (
                feat_time / 2 - feat_time * (self.infer_period / self.train_period) / 2
            )
            end = start + feat_time * (self.infer_period / self.train_period)
            start = int(start)
            end = int(end)
            feat = feat[:, :, start:end]
            time_att = time_att[:, :, start:end]
            # Attention is re-normalised over the cropped window only.
            att = torch.softmax(time_att, dim=-1)
            framewise_logit = self.fix_scale(feat)
            clipwise_pred = torch.sum(torch.sigmoid(framewise_logit) * att, dim=-1)
            logits = torch.sum(framewise_logit * att, dim=-1)
        return (
            logits,
            clipwise_pred,
            framewise_logit.permute(0, 2, 1),
            time_att.permute(0, 2, 1),
            clipwise_pred_long,
        )


class AttModel(nn.Module):
    """
    timm feature backbone followed by an `AttHead`.

    A log-mel extractor (`logmelspec_extractor`) is built and stored on the
    module, but `forward` does not call it: the input is passed straight to
    the backbone.
    """

    def __init__(
        self,
        backbone="resnet34",
        p=0.5,
        n_mels=224,
        num_class=397,
        train_period=15.0,
        infer_period=5.0,
        in_chans=1,
    ):
        super().__init__()
        self.n_mels = n_mels

        # Waveform -> mel spectrogram -> dB scale -> per-sample normalisation.
        mel_stage = MelSpectrogram(
            32000,
            n_mels=n_mels,
            f_min=20,
            n_fft=2048,
            hop_length=512,
            normalized=True,
        )
        db_stage = AmplitudeToDB(top_db=80.0)
        self.logmelspec_extractor = nn.Sequential(
            mel_stage, db_stage, NormalizeMelSpec()
        )

        self.backbone = timm.create_model(
            backbone, features_only=True, pretrained=False, in_chans=in_chans
        )
        # The head consumes the last feature map, so size it from that stage.
        last_stage_channels = self.backbone.feature_info.channels()[-1]
        self.head = AttHead(
            last_stage_channels,
            p=p,
            num_class=num_class,
            train_period=train_period,
            infer_period=infer_period,
        )

    def forward(self, input):
        """Run the backbone on `input` and feed its last feature map to the head."""
        feature_maps = self.backbone(input)
        return self.head(feature_maps[-1])
    
class Model(nn.Module):
    """
    Thin container holding an `AttModel` under the attribute `model`.

    No `forward` is defined here; the wrapped network is reached through
    `.model`.
    """

    def __init__(
        self,
        backbone="resnet34",
        p=0.5,
        n_mels=224,
        num_class=397,
        train_period=15.0,
        infer_period=5.0,
        in_chans=1,
    ):
        super().__init__()
        self.model = AttModel(
            backbone=backbone,
            p=p,
            n_mels=n_mels,
            num_class=num_class,
            train_period=train_period,
            infer_period=infer_period,
            in_chans=in_chans,
        )


## Dataset

In [None]:
class TestDataset(torchdata.Dataset):
    """
    Serves one waveform window per row of `df`.

    The clip is stored three times back to back so that a window reaching
    before the start or past the end of the recording wraps into a
    neighbouring copy instead of being cut short.
    """

    def __init__(self, df: pd.DataFrame, clip: np.ndarray, train_period=30, 
                 waveform_transforms=None):
        self.df = df
        self.clip = np.concatenate([clip, clip, clip])
        self.train_period = train_period
        self.waveform_transforms=waveform_transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        """Return `(waveform, row_id)` for the row labelled `idx`."""
        SR = 32000
        row = self.df.loc[idx, :]

        # The row's `seconds` marks the end of a 5 s segment.
        segment_end = int(row.seconds)
        segment_start = int(segment_end - 5)

        # Context added on each side so the window lasts `train_period` seconds.
        margin = (self.train_period - 5) / 2
        # Index 0 of the middle copy of the tripled clip.
        middle_offset = len(self.clip) // 3
        last = int(SR * (segment_end + margin) + middle_offset)
        first = int(SR * (segment_start - margin) + middle_offset)

        y = np.nan_to_num(self.clip[first:last].astype(np.float32))
        if self.waveform_transforms:
            y = self.waveform_transforms(y)

        # Clean again: a transform may introduce NaN/inf.
        return np.nan_to_num(y), row.row_id

In [None]:
def get_transforms(phase: str):
    """
    Build the transform pipeline configured in `CFG.transforms` for one phase
    Parameters
    ----------
    phase: str
        key into `CFG.transforms`
    Returns
    -------
    pipeline: Compose or None
        None when nothing is configured or no configured name resolves to an
        object in the module globals (unknown names are skipped)
    """
    phase_confs = None if CFG.transforms is None else CFG.transforms[phase]
    if phase_confs is None:
        return None

    built = []
    for conf in phase_confs:
        factory = globals().get(conf["name"])
        if factory is None:
            continue
        kwargs = conf.get("params")
        built.append(factory(**({} if kwargs is None else kwargs)))

    return Compose(built) if len(built) > 0 else None


def get_waveform_transforms(config: dict, phase: str):
    """
    Return the waveform transform pipeline for `phase`
    Parameters
    ----------
    config: dict
        unused; kept so existing call sites keep working (the transform
        specs are read from `CFG.transforms` by `get_transforms`)
    phase: str
        key into `CFG.transforms`
    Returns
    -------
    pipeline: Compose or None
    """
    # `get_transforms` accepts only the phase; forwarding `config` as well
    # raised a TypeError.
    return get_transforms(phase)


def get_spectrogram_transforms(config: dict, phase: str):
    """
    Build the spectrogram transform pipeline described in `config`
    Parameters
    ----------
    config: dict
        may hold a 'spectrogram_transforms' mapping of phase -> list of
        {"name": ..., "params": ...} specs
    phase: str
        key into that mapping
    Returns
    -------
    pipeline: albumentations.Compose or None
        None when nothing is configured or no name could be resolved. Names
        are looked up in albumentations first, then in the module globals;
        unknown names are skipped.
    """
    transforms = config.get('spectrogram_transforms')
    if transforms is None or transforms[phase] is None:
        return None

    # Imported here because only a submodule of albumentations is imported at
    # the top of the file, so the alias `A` was otherwise undefined.
    import albumentations as A

    trns_list = []
    for trns_conf in transforms[phase]:
        trns_name = trns_conf["name"]
        trns_params = trns_conf.get("params")
        if trns_params is None:
            trns_params = {}

        trns_cls = getattr(A, trns_name, None)
        if trns_cls is None:
            trns_cls = globals().get(trns_name)
        if trns_cls is not None:
            trns_list.append(trns_cls(**trns_params))

    if len(trns_list) > 0:
        return A.Compose(trns_list, p=1.0)
    return None


class Normalize:
    """Peak-normalise a waveform so that its largest absolute sample is 1."""

    def __call__(self, y: np.ndarray):
        """
        Parameters
        ----------
        y: numpy.ndarray
            input audio
        Returns
        -------
        normalized: numpy.ndarray
            `y / max(|y|)` in Fortran memory order; an empty or all-zero
            input is returned unscaled
        """
        # A silent (or empty) clip has no peak to scale by; dividing anyway
        # would fill the output with NaN (or fail on an empty array).
        if y.size == 0:
            return np.asfortranarray(y)
        max_vol = np.abs(y).max()
        if max_vol == 0:
            return np.asfortranarray(y)
        return np.asfortranarray(y / max_vol)


class NewNormalize:
    """Remove the mean of a waveform, then scale its peak absolute value to 1."""

    def __call__(self, y: np.ndarray):
        """
        Parameters
        ----------
        y: numpy.ndarray
            input audio
        Returns
        -------
        normalized: numpy.ndarray
            `(y - mean) / max(|y - mean|)`; a constant input yields zeros
        """
        y_mm = y - y.mean()
        # numpy arrays have no `.abs()` method; the builtin `abs` works on
        # arrays (and on tensors alike).
        peak = abs(y_mm).max()
        if peak == 0:
            # Constant signal: nothing left to scale after centring.
            return y_mm
        return y_mm / peak


class Compose:
    """Chain several callables: each one receives the previous one's output."""

    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, y: np.ndarray):
        """Apply the stored transforms to `y` in list order and return the result."""
        result = y
        for step in self.transforms:
            result = step(result)
        return result


class AudioTransform:
    """
    Base class for waveform augmentations applied with probability `p`.

    Subclasses implement `apply`; with `always_apply=True` the random draw is
    skipped and the augmentation always runs.
    """

    def __init__(self, always_apply=False, p=0.5):
        self.always_apply = always_apply
        self.p = p

    def __call__(self, y: np.ndarray):
        """Return `apply(y)` when the transform fires, otherwise `y` untouched."""
        # Short-circuit keeps the RNG untouched when always_apply is set.
        if self.always_apply or np.random.rand() < self.p:
            return self.apply(y)
        return y

    def apply(self, y: np.ndarray):
        """Perform the augmentation; must be overridden by subclasses."""
        raise NotImplementedError


class NoiseInjection(AudioTransform):
    """Add white Gaussian noise scaled by a level drawn from [0, max_noise_level)."""

    def __init__(self, always_apply=False, p=0.5, max_noise_level=0.5, sr=32000):
        super().__init__(always_apply, p)

        self.sr = sr
        # (low, high) bounds for the uniformly drawn noise level.
        self.noise_level = (0.0, max_noise_level)

    def apply(self, y: np.ndarray, **params):
        """Return `y` plus scaled standard-normal noise, cast back to `y`'s dtype."""
        level = np.random.uniform(*self.noise_level)
        gaussian = np.random.randn(len(y))
        return (y + gaussian * level).astype(y.dtype)


class GaussianNoise(AudioTransform):
    """Add white noise at a signal-to-noise ratio drawn from [min_snr, max_snr) dB."""

    def __init__(self, always_apply=False, p=0.5, min_snr=5, max_snr=20, sr=32000):
        super().__init__(always_apply, p)

        self.sr = sr
        self.min_snr = min_snr
        self.max_snr = max_snr

    def apply(self, y: np.ndarray, **params):
        """Return `y` plus peak-scaled white noise, cast back to `y`'s dtype."""
        snr_db = np.random.uniform(self.min_snr, self.max_snr)
        # Peak amplitudes: the noise peak sits `snr_db` below the signal peak.
        signal_peak = np.sqrt(y ** 2).max()
        noise_peak = signal_peak / (10 ** (snr_db / 20))

        white = np.random.randn(len(y))
        white_peak = np.sqrt(white ** 2).max()
        return (y + white * 1 / white_peak * noise_peak).astype(y.dtype)


class PinkNoise(AudioTransform):
    """
    Add pink (1/f) noise at a signal-to-noise ratio drawn from [min_snr, max_snr) dB.

    The noise comes from `cn.powerlaw_psd_gaussian` when a module-level `cn`
    exists. The `colorednoise` import is commented out in this notebook, so a
    numpy implementation is used otherwise instead of raising NameError.
    """

    def __init__(self, always_apply=False, p=0.5, min_snr=5, max_snr=20, sr=32000):
        super().__init__(always_apply, p)

        self.min_snr = min_snr
        self.max_snr = max_snr
        self.sr = sr

    @staticmethod
    def _numpy_pink_noise(n_samples: int):
        """Shape white Gaussian noise so its power spectrum falls off as 1/f."""
        if n_samples < 2:
            return np.random.randn(n_samples)
        spectrum = np.fft.rfft(np.random.randn(n_samples))
        freqs = np.fft.rfftfreq(n_samples)
        freqs[0] = freqs[1]  # avoid dividing by zero at the DC bin
        # Power ~ 1/f means amplitude ~ 1/sqrt(f).
        return np.fft.irfft(spectrum / np.sqrt(freqs), n=n_samples)

    def apply(self, y: np.ndarray, **params):
        """Return `y` plus peak-scaled pink noise, cast back to `y`'s dtype."""
        snr = np.random.uniform(self.min_snr, self.max_snr)
        a_signal = np.sqrt(y ** 2).max()
        a_noise = a_signal / (10 ** (snr / 20))

        colored = globals().get("cn")
        if colored is not None:
            pink_noise = colored.powerlaw_psd_gaussian(1, len(y))
        else:
            pink_noise = self._numpy_pink_noise(len(y))
        # The overall scale of the generator is irrelevant: the noise is
        # rescaled by its own peak right here.
        a_pink = np.sqrt(pink_noise ** 2).max()
        augmented = (y + pink_noise * 1 / a_pink * a_noise).astype(y.dtype)
        return augmented


class PitchShift(AudioTransform):
    """Shift the pitch by a random whole number of semitone steps."""

    def __init__(self, always_apply=False, p=0.5, max_range=5, sr=32000):
        super().__init__(always_apply, p)
        self.max_range = max_range
        self.sr = sr

    def apply(self, y: np.ndarray, **params):
        """Return `y` pitch-shifted by a step count drawn from [-max_range, max_range)."""
        # randint's upper bound is exclusive, so +max_range itself is never drawn.
        n_steps = np.random.randint(-self.max_range, self.max_range)
        # `sr` and `n_steps` are passed by keyword: newer librosa releases
        # reject them as positional arguments, and keywords work on older ones too.
        augmented = librosa.effects.pitch_shift(y, sr=self.sr, n_steps=n_steps)
        return augmented


class TimeStretch(AudioTransform):
    """Time-stretch the waveform by a rate drawn uniformly from [0, max_rate)."""

    def __init__(self, always_apply=False, p=0.5, max_rate=1, sr=32000):
        super().__init__(always_apply, p)
        self.max_rate = max_rate
        self.sr = sr

    def apply(self, y: np.ndarray, **params):
        """Return `y` stretched by the drawn rate."""
        # NOTE(review): the lower bound of 0 allows rates arbitrarily close to
        # zero, which would produce extremely long outputs — confirm the
        # intended range.
        rate = np.random.uniform(0, self.max_rate)
        # `rate` is passed by keyword: newer librosa releases reject it as a
        # positional argument, and the keyword works on older ones too.
        augmented = librosa.effects.time_stretch(y, rate=rate)
        return augmented


def _db2float(db: float, amplitude=True):
    if amplitude:
        return 10**(db / 20)
    else:
        return 10 ** (db / 10)


def volume_down(y: np.ndarray, db: float):
    """
    Attenuate audio by a number of decibels
    Parameters
    ----------
    y: numpy.ndarray
        stereo / monaural input audio
    db: float
        amount of attenuation in decibel
    Returns
    -------
    applied: numpy.ndarray
        `y` multiplied by the amplitude ratio for `-db`
    """
    gain = _db2float(-db)
    return y * gain


def volume_up(y: np.ndarray, db: float):
    """
    Amplify audio by a number of decibels
    Parameters
    ----------
    y: numpy.ndarray
        stereo / monaural input audio
    db: float
        amount of amplification in decibel
    Returns
    -------
    applied: numpy.ndarray
        `y` multiplied by the amplitude ratio for `db`
    """
    gain = _db2float(db)
    return y * gain


class RandomVolume(AudioTransform):
    """Waveform augmentation that applies a random constant gain.

    Parameters
    ----------
    always_apply, p:
        Forwarded unchanged to ``AudioTransform.__init__``.
    limit: float
        The gain in decibels is drawn uniformly from ``[-limit, limit)``.
    """

    def __init__(self, always_apply=False, p=0.5, limit=10):
        super().__init__(always_apply, p)
        self.limit = limit

    def apply(self, y: np.ndarray, **params):
        """Return ``y`` scaled by a freshly drawn gain (up or down)."""
        db = np.random.uniform(-self.limit, self.limit)
        if db >= 0:
            return volume_up(y, db)
        # volume_down() negates its argument itself, so it must receive the
        # positive magnitude; handing it the negative draw would amplify the
        # signal instead of attenuating it.
        return volume_down(y, -db)


class OneOf:
    """Callable wrapper that applies exactly one randomly picked transform.

    ``transforms`` is a list of callables taking and returning a waveform;
    each call picks an index with ``np.random.choice`` and applies only
    that entry.
    """

    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, y: np.ndarray):
        picked = self.transforms[np.random.choice(len(self.transforms))]
        return picked(y)


class CosineVolume(AudioTransform):
    """Waveform augmentation with a gain that follows one cosine period.

    Parameters
    ----------
    always_apply, p:
        Forwarded unchanged to ``AudioTransform.__init__``.
    limit: float
        The peak gain in decibels is drawn uniformly from ``[-limit, limit)``.
    """

    def __init__(self, always_apply=False, p=0.5, limit=10):
        super().__init__(always_apply, p)
        self.limit = limit

    def apply(self, y: np.ndarray, **params):
        """Scale each sample of ``y`` by a cosine-shaped gain envelope."""
        peak_db = np.random.uniform(-self.limit, self.limit)
        n_samples = len(y)
        # Phase runs from 0 towards 2*pi across the clip, i.e. one full period.
        phase = np.arange(n_samples) / n_samples * np.pi * 2
        gain = _db2float(np.cos(phase) * peak_db)
        return y * gain


def drop_stripes(image: np.ndarray, dim: int, drop_width: int, stripes_num: int):
    """Overwrite random contiguous stripes of ``image`` with its minimum value.

    The array is modified in place and also returned.

    Parameters
    ----------
    image: numpy.ndarray
        array to mask
    dim: int
        axis along which stripes are placed
    drop_width: int
        exclusive upper bound on the width of a single stripe; each width is
        drawn from ``[0, drop_width)`` (capped at the axis length)
    stripes_num: int
        number of stripes to draw; stripes may overlap

    Returns
    -------
    numpy.ndarray
        the same array object, with the stripes filled in
    """
    total_width = image.shape[dim]
    # A stripe cannot be wider than the axis; capping also keeps the `high`
    # bound of the second randint strictly positive.
    max_width = min(drop_width, total_width)
    if max_width <= 0:
        return image

    lowest_value = image.min()
    for _ in range(stripes_num):
        distance = int(np.random.randint(low=0, high=max_width))
        begin = int(np.random.randint(low=0, high=total_width - distance))

        # Select the whole stripe [begin, begin + distance) along `dim`; every
        # other axis is taken in full. This is the same slice for every axis,
        # so a stripe of the drawn width is masked whatever `dim` is.
        stripe = [slice(None)] * image.ndim
        stripe[dim] = slice(begin, begin + distance)
        image[tuple(stripe)] = lowest_value
    return image


class TimeFreqMasking(ImageOnlyTransform):
    """Image transform that masks random stripes along both axes of a 2-D array.

    Stripes along axis 0 use the ``freq_*`` settings and stripes along axis 1
    use the ``time_*`` settings; masking itself is delegated to
    ``drop_stripes``. Inputs that are not 2-D are returned as an unmodified copy.
    """

    def __init__(self,
                 time_drop_width: int,
                 time_stripes_num: int,
                 freq_drop_width: int,
                 freq_stripes_num: int,
                 always_apply=False,
                 p=0.5):
        super().__init__(always_apply, p)
        self.time_drop_width = time_drop_width
        self.time_stripes_num = time_stripes_num
        self.freq_drop_width = freq_drop_width
        self.freq_stripes_num = freq_stripes_num

    def apply(self, img, **params):
        """Return a masked copy of ``img``; the input array is left untouched."""
        masked = img.copy()
        if img.ndim != 2:
            return masked

        # Axis 0 first, then axis 1 -- the order fixes the random draw sequence.
        passes = (
            (0, self.freq_drop_width, self.freq_stripes_num),
            (1, self.time_drop_width, self.time_stripes_num),
        )
        for axis, width, count in passes:
            masked = drop_stripes(
                masked, dim=axis, drop_width=width, stripes_num=count)
        return masked

## Get model

In [None]:
from torch.nn.modules.batchnorm import _BatchNorm

def prepare_model_for_inference(model, path: Path):
    """Load checkpoint weights into ``model`` and switch it to eval mode.

    The checkpoint at ``path`` must contain a ``"state_dict"`` entry. Without
    CUDA the tensors are mapped onto the CPU; otherwise ``torch.load`` uses
    its default placement. The same ``model`` object is returned.
    """
    # map_location=None is torch.load's default, i.e. "no remapping".
    map_location = None if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    return model

In [None]:
def prediction_for_clip(test_df: pd.DataFrame, 
                        clip: np.ndarray, 
                        models, 
                        threshold=0.05, 
                        threshold_long=None):
    """Predict a label string for every row of ``test_df`` on one audio clip.

    Each sample from ``TestDataset`` is converted once with the first model's
    ``logmelspec_extractor`` and then fed to every model. A model call must
    return five values; the second (``clipwise_pred``) and the fifth
    (``clipwise_pred_long``) are averaged across models.

    A class is reported when its mean ``clipwise_pred`` reaches ``threshold``
    and, if ``threshold_long`` is given, its mean ``clipwise_pred_long`` also
    reaches ``threshold_long``.

    Returns
    -------
    dict
        ``row_id`` -> space-separated entries of ``CFG.target_columns``, or
        ``"nocall"`` when no class passes.
    """
    loader = torchdata.DataLoader(
        TestDataset(df=test_df,
                    clip=clip,
                    train_period=CFG.period,
                    waveform_transforms=get_transforms(phase="test")),
        batch_size=1,
        shuffle=False,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    prediction_dict = {}
    for image, row_id in tqdm(loader):
        # batch_size is 1, so the batch holds exactly one row id.
        row_id = row_id[0]
        image = image.to(device)

        short_preds, long_preds = [], []
        with torch.no_grad():
            # The spectrogram is computed once and shared by all models.
            image = models[0].logmelspec_extractor(image)[:, None]
            for model in models:
                with torch.cuda.amp.autocast():
                    _, clipwise_pred, _, _, clipwise_pred_long = model(image)
                short_preds.append(clipwise_pred.detach().cpu().numpy().reshape(-1))
                long_preds.append(clipwise_pred_long.detach().cpu().numpy().reshape(-1))
        probas = np.array(short_preds)
        probas_long = np.array(long_preds)

        events = probas.mean(0) >= threshold
        if threshold_long is not None:
            # Both criteria have to hold for a class to count as present.
            events = events & (probas_long.mean(0) >= threshold_long)

        labels = np.argwhere(events).reshape(-1).tolist()
        if labels:
            prediction_dict[row_id] = " ".join(CFG.target_columns[idx] for idx in labels)
        else:
            prediction_dict[row_id] = "nocall"
    return prediction_dict

In [None]:
def load_model(backbone_name, weight_path):
    """Build a ``Model`` for ``backbone_name``, load its weights, return the inner net.

    The wrapper is configured from ``CFG`` (``n_mels``, ``num_classes``,
    ``period``) with ``infer_period=5``, restored from ``weight_path`` via
    ``prepare_model_for_inference`` and moved to CUDA when available. Only
    the wrapper's ``.model`` attribute is returned.
    """
    wrapper = Model(
        backbone_name,
        p=0.5,
        n_mels=CFG.n_mels,
        num_class=CFG.num_classes,
        train_period=CFG.period,
        infer_period=5,
    )
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    wrapper = prepare_model_for_inference(wrapper, weight_path).to(target_device)
    return wrapper.model

def prediction(test_audios,
               models_cfg,
               threshold=0.05, 
               threshold_long=None):
    """Run the model ensemble over every audio file and collect the predictions.

    Parameters
    ----------
    test_audios:
        iterable of audio paths; each must expose ``.name`` (the row-id prefix
        is the first two ``_``-separated parts of the file name)
    models_cfg:
        iterable of mappings; the first key / first value of each entry are
        passed to ``load_model`` as backbone name and weight path
    threshold, threshold_long:
        forwarded to ``prediction_for_clip``

    Returns
    -------
    pandas.DataFrame
        columns ``row_id`` and ``birds``, one row per 5-second step from 5 to
        600 for every file (generated regardless of the clip's real length)
    """
    models = [
        load_model(list(cfg.keys())[0], list(cfg.values())[0])
        for cfg in models_cfg
    ]
    warnings.filterwarnings("ignore")

    per_file_frames = []
    for audio_path in test_audios:
        with timer(f"Loading {str(audio_path)}", logger):
            clip, _ = sf.read(audio_path)

        # Row ids look like "<part0>_<part1>_<end second>".
        prefix = "_".join(audio_path.name.split("_")[:2])
        seconds = list(range(5, 605, 5))
        test_df = pd.DataFrame({
            "row_id": [prefix + f"_{second}" for second in seconds],
            "seconds": seconds
        })

        with timer(f"Prediction on {audio_path}", logger):
            prediction_dict = prediction_for_clip(test_df,
                                                  clip=clip,
                                                  models=models,
                                                  threshold=threshold,
                                                  threshold_long=threshold_long)
        per_file_frames.append(pd.DataFrame({
            "row_id": list(prediction_dict.keys()),
            "birds": list(prediction_dict.values())
        }))

    return pd.concat(per_file_frames, axis=0, sort=False).reset_index(drop=True)

## Prediction

In [None]:
# Decision cutoffs handed to prediction() further down.
# `threshold` is compared with the ensemble-mean `clipwise_pred`, and
# `threshold_long` with the ensemble-mean `clipwise_pred_long`; a class is
# reported only when both cutoffs are reached (see prediction_for_clip).
threshold = 0.025
threshold_long = 0.05

In [None]:
# # site_mask = pd.read_csv("../input/site-month-mask/site_mask.csv")
# # site_mask["all"] = site_mask.drop('other', axis=1).T.sum()
# # rare_threshold = 10
# # common_threshold = 200
# # print("common, rare: ", site_mask["all"][site_mask["all"].values>common_threshold].sum(), site_mask["all"][site_mask["all"].values<rare_threshold].sum())
# # print(f"common classes:", (site_mask["all"].values>common_threshold).sum(), "rare classes: ", (site_mask["all"].values<rare_threshold).sum())
# # threshold = np.where((site_mask["all"].values>rare_threshold), np.ones_like(site_mask["all"].values) * threshold, np.ones_like(site_mask["all"].values) * threshold * 4)
# # threshold = np.where((site_mask["all"].values>common_threshold), np.ones_like(site_mask["all"].values) * threshold / 2, np.ones_like(site_mask["all"].values) * threshold)

# # threshold_long = np.where((site_mask["all"].values>rare_threshold), np.ones_like(site_mask["all"].values) * threshold_long, np.ones_like(site_mask["all"].values) * threshold_long * 4)
# # threshold_long = np.where((site_mask["all"].values>common_threshold), np.ones_like(site_mask["all"].values) * threshold_long / 2, np.ones_like(site_mask["all"].values) * threshold_long)

# site_mask = pd.read_csv("../input/site-month-mask/site_mask.csv")
# site_mask["all"] = site_mask.drop('other', axis=1).T.sum()
# rare_threshold = 10
# print("common, rare: ", site_mask["all"][site_mask["all"].values>rare_threshold].sum(), site_mask["all"][site_mask["all"].values<rare_threshold].sum())
# print("common classes: ", (site_mask["all"].values>rare_threshold).sum())
# threshold = np.where((site_mask["all"].values>rare_threshold), np.ones_like(site_mask["all"].values) * threshold, np.ones_like(site_mask["all"].values) * threshold * 4)
# threshold_long = np.where((site_mask["all"].values>rare_threshold), np.ones_like(site_mask["all"].values) * threshold_long, np.ones_like(site_mask["all"].values) * threshold_long * 4)

In [None]:
# month_mask = pd.read_csv("../input/site-month-mask/month_mask.csv")

# month_mask["all"] = month_mask.T.sum()
# rare_threshold = 100
# print("common, rare: ", month_mask["all"][month_mask["all"].values>rare_threshold].sum(), month_mask["all"][month_mask["all"].values<rare_threshold].sum())
# print("common classes: ", (month_mask["all"].values>rare_threshold).sum())
# threshold = np.where((month_mask["all"].values>rare_threshold), np.ones_like(month_mask["all"].values) * threshold, np.ones_like(month_mask["all"].values) * 0.1)
# threshold_long = np.where((month_mask["all"].values>rare_threshold), np.ones_like(month_mask["all"].values) * threshold_long, np.ones_like(month_mask["all"].values) * 0.2)
# month_mask["all"].plot()

In [None]:
# Run the full ensemble over every soundscape and write the result as the
# submission file (columns: row_id, birds).
# NOTE(review): `all_audios` is not defined in this part of the notebook --
# presumably an earlier cell builds it; confirm it survives Restart & Run All.
submission = prediction(test_audios=all_audios,
                        models_cfg=CFG.models_cfg,
                        threshold=threshold, 
                        threshold_long=threshold_long)
submission.to_csv("submission.csv", index=False)

In [None]:
# birds2id = {b : i for i, b in enumerate(CFG.target_columns)}
# month_mask = pd.read_csv("../input/site-month-mask/month_mask.csv")
# site_mask = pd.read_csv("../input/site-month-mask/site_mask.csv")
# index = 0
# remove_num = 0
# for audio_path in all_audios:
#     site = audio_path.name.split("_")[1]
#     month = audio_path.name.split("_")[2][4:6]
#     m = month_mask[month]
#     for second in range(5, 605, 5):
#         row_id = "_".join(audio_path.name.split("_")[:2]) + f"_{second}"
#         birds = submission.iloc[index, 1].split(" ")
#         birds_removed = []
#         for b in birds:
#             if b != "nocall":
#                 if site == "COR" or site == "COL":
#                     if site_mask.loc[birds2id[b], "COR"] == 0 and site_mask.loc[birds2id[b], "COL"] == 0:
#                         remove_num += 1
#                     else:
#                         if month_mask.loc[birds2id[b], month] == 0:
#                             remove_num += 1
#                         else:
#                             birds_removed.append(b)
#                 else:
#                     if site_mask.loc[birds2id[b], site] == 0:
#                         remove_num += 1
#                     else:
#                         if month_mask.loc[birds2id[b], month] == 0:
#                             remove_num += 1
#                         else:
#                             birds_removed.append(b)
#         birds_removed = list(set(birds_removed))
#         if len(birds_removed) == 0:
#             birds_removed.append("nocall")
#         submission.iloc[index, 1] = " ".join(birds_removed)
#         index += 1
# print(remove_num)
# # print(submission)
# submission.to_csv("submission.csv", index=False)

In [None]:
# Read the written file back so the cell output shows what was saved.
pd.read_csv("submission.csv")

In [None]:
def get_metrics(s_true, s_pred):
    """Compute set-based precision, recall and F1 for one pair of label strings.

    Parameters
    ----------
    s_true: str
        whitespace-separated ground-truth labels
    s_pred: str
        whitespace-separated predicted labels

    Returns
    -------
    dict
        keys ``f1``, ``prec``, ``rec`` plus the counts ``n_true``, ``n_pred``
        and ``n`` (size of the intersection). A ratio whose denominator is
        zero is reported as 0 instead of raising.
    """
    true_labels = set(s_true.split())
    pred_labels = set(s_pred.split())
    n = len(true_labels & pred_labels)
    n_true, n_pred = len(true_labels), len(pred_labels)

    # An empty label string yields an empty set; guard the divisions so such
    # rows score 0 rather than aborting with ZeroDivisionError.
    prec = n / n_pred if n_pred else 0.0
    rec = n / n_true if n_true else 0.0
    f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0

    return {"f1": f1, "prec": prec, "rec": rec, "n_true": n_true, "n_pred": n_pred, "n": n}

In [None]:
# Local validation: when the test_soundscapes folder has no .ogg files, fall
# back to the labelled train soundscapes and score the submission against
# their ground truth.
TARGET_PATH = None
TEST_AUDIO_ROOT = Path("../input/birdclef-2021/test_soundscapes")
SAMPLE_SUB_PATH = "../input/birdclef-2021/sample_submission.csv"
if not any(TEST_AUDIO_ROOT.glob("*.ogg")):
    TEST_AUDIO_ROOT = Path("../input/birdclef-2021/train_soundscapes")
    SAMPLE_SUB_PATH = None
    # SAMPLE_SUB_PATH = "../input/birdclef-2021/sample_submission.csv"
    TARGET_PATH = Path("../input/birdclef-2021/train_soundscape_labels.csv")

if TARGET_PATH:
    sub_target = pd.read_csv(TARGET_PATH)
    # Both frames carry a `birds` column, so after the merge `birds_x` is the
    # ground truth and `birds_y` is the prediction.
    sub_target = sub_target.merge(submission, how="left", on="row_id")

    # Report matched rows for BOTH sides (the second count used to repeat
    # `birds_x`, which hid missing predictions).
    print(sub_target["birds_x"].notnull().sum(), sub_target["birds_y"].notnull().sum())
    assert sub_target["birds_x"].notnull().all()
    assert sub_target["birds_y"].notnull().all()

    df_metrics = pd.DataFrame([get_metrics(s_true, s_pred) for s_true, s_pred in zip(sub_target.birds_x, sub_target.birds_y)])

    print(df_metrics.mean())