In [None]:
# Copy the bundled timm sources into the writable working directory, then
# install the package from that local copy.
!cp -r ../input/timm-pytorch-image-models /kaggle/working/
!pip install /kaggle/working/timm-pytorch-image-models/pytorch-image-models-master/

In [None]:
from glob import glob
import os
import random
import warnings
from functools import partial

# import colorednoise as cn
import librosa
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import scipy as sp
import soundfile as sf
import timm
import torch
import torch.optim as optim
from pytorch_lightning import LightningDataModule, callbacks

# from pytorch_lightning.utilities import rank_zero_info
from sklearn import model_selection
from sklearn.metrics import f1_score
from torch import nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from torchaudio.transforms import AmplitudeToDB, MelSpectrogram
from torchvision.transforms import Resize

In [None]:
import cv2
import audioread
import logging
import os
import random
import time
import warnings

import librosa
import numpy as np
import pandas as pd
import soundfile as sf
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torchdata

from contextlib import contextmanager
from pathlib import Path
from typing import Optional

from albumentations.core.transforms_interface import ImageOnlyTransform
# from torchlibrosa.stft import LogmelFilterBank, Spectrogram
# from torchlibrosa.augmentation import SpecAugmentation
from tqdm import tqdm

In [None]:
def set_seed(seed: int = 42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy, the ``PYTHONHASHSEED`` environment
    variable and PyTorch (CPU and all CUDA devices), then puts cuDNN into
    deterministic mode.

    Parameters
    ----------
    seed : int
        Value used for every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore
    # The cuDNN autotuner may pick different kernels from run to run, which
    # would undo the determinism requested above, so it has to stay off.
    torch.backends.cudnn.benchmark = False  # type: ignore
    
    
def get_logger(out_file=None):
    """Reset the root logger and attach fresh INFO-level handlers.

    Parameters
    ----------
    out_file : str or path-like, optional
        When given, log records are also written to this file.

    Returns
    -------
    logging.Logger
        The root logger, with one stream handler and, if ``out_file`` was
        given, one file handler.
    """
    logger = logging.getLogger()
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    # Detach AND close whatever was installed before: re-running the cell would
    # otherwise leave the previous log file open for the life of the kernel.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()
    logger.setLevel(logging.INFO)

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)

    if out_file is not None:
        fh = logging.FileHandler(out_file)
        fh.setFormatter(formatter)
        fh.setLevel(logging.INFO)
        logger.addHandler(fh)
    logger.info("logger set up")
    return logger
    
    
@contextmanager
def timer(name: str, logger: Optional[logging.Logger] = None):
    """Context manager that reports the start and the wall-clock duration of a block.

    Messages go to ``logger.info`` when a logger is supplied and to ``print``
    otherwise. The closing message is only emitted when the wrapped block
    finishes without raising.
    """
    emit = print if logger is None else logger.info
    started_at = time.time()
    emit(f"[{name}] start")
    yield
    emit(f"[{name}] done in {time.time() - started_at:.2f} s")

In [None]:
# Root logger writing to the console and to main.log; reused by the timer()
# calls in prediction().
logger = get_logger("main.log")
# Same value as CFG.seed, which is declared in a later cell.
set_seed(1213)

In [None]:
class CFG:
    """Static configuration for this inference notebook.

    Every attribute is a class-level constant; the class is never instantiated
    in the visible cells.
    """
    ######################
    # Globals #
    ######################
    seed = 1213

    ######################
    # Data #
    ######################
    train_datadir = Path("../input/birdclef-2021/train_short_audio")
    train_csv = "../input/birdclef-2021/train_metadata.csv"
    train_soundscape = "../input/birdclef-2021/train_soundscape_labels.csv"

    ######################
    # Dataset #
    ######################
    # Per-phase waveform transforms; each "name" is resolved through globals()
    # by get_transforms().
    transforms = {
        "train": [{"name": "Normalize"}],
        "valid": [{"name": "Normalize"}],
        "test": [{"name": "Normalize"}]
    }
    # Passed as train_period to both the model and TestDataset.
    period = 30
    # Passed as n_mels to Model in load_model().
    n_mels = 256
    
    # NOTE(review): not referenced in the visible cells; TestDataset hard-codes SR = 32000.
    sample_rate = 32000
    # Class names; a label's position in this list is its output index
    # (prediction_for_clip maps indices back through this list).
    target_columns = [
        "afrsil1",
        "akekee",
        "akepa1",
        "akiapo",
        "akikik",
        "amewig",
        "aniani",
        "apapan",
        "arcter",
        "barpet",
        "bcnher",
        "belkin1",
        "bkbplo",
        "bknsti",
        "bkwpet",
        "blkfra",
        "blknod",
        "bongul",
        "brant",
        "brnboo",
        "brnnod",
        "brnowl",
        "brtcur",
        "bubsan",
        "buffle",
        "bulpet",
        "burpar",
        "buwtea",
        "cacgoo1",
        "calqua",
        "cangoo",
        "canvas",
        "caster1",
        "categr",
        "chbsan",
        "chemun",
        "chukar",
        "cintea",
        "comgal1",
        "commyn",
        "compea",
        "comsan",
        "comwax",
        "coopet",
        "crehon",
        "dunlin",
        "elepai",
        "ercfra",
        "eurwig",
        "fragul",
        "gadwal",
        "gamqua",
        "glwgul",
        "gnwtea",
        "golphe",
        "grbher3",
        "grefri",
        "gresca",
        "gryfra",
        "gwfgoo",
        "hawama",
        "hawcoo",
        "hawcre",
        "hawgoo",
        "hawhaw",
        "hawpet1",
        "hoomer",
        "houfin",
        "houspa",
        "hudgod",
        "iiwi",
        "incter1",
        "jabwar",
        "japqua",
        "kalphe",
        "kauama",
        "laugul",
        "layalb",
        "lcspet",
        "leasan",
        "leater1",
        "lessca",
        "lesyel",
        "lobdow",
        "lotjae",
        "madpet",
        "magpet1",
        "mallar3",
        "masboo",
        "mauala",
        "maupar",
        "merlin",
        "mitpar",
        "moudov",
        "norcar",
        "norhar2",
        "normoc",
        "norpin",
        "norsho",
        "nutman",
        "oahama",
        "omao",
        "osprey",
        "pagplo",
        "palila",
        "parjae",
        "pecsan",
        "peflov",
        "perfal",
        "pibgre",
        "pomjae",
        "puaioh",
        "reccar",
        "redava",
        "redjun",
        "redpha1",
        "refboo",
        "rempar",
        "rettro",
        "ribgul",
        "rinduc",
        "rinphe",
        "rocpig",
        "rorpar",
        "rudtur",
        "ruff",
        "saffin",
        "sander",
        "semplo",
        "sheowl",
        "shtsan",
        "skylar",
        "snogoo",
        "sooshe",
        "sooter1",
        "sopsku1",
        "sora",
        "spodov",
        "sposan",
        "towsol",
        "wantat1",
        "warwhe1",
        "wesmea",
        "wessan",
        "wetshe",
        "whfibi",
        "whiter",
        "whttro",
        "wiltur",
        "yebcar",
        "yefcan",
        "zebdov",
    ]
    # Forward and reverse lookups between class name and output index.
    bird2id = {b: i for i, b in enumerate(target_columns)}
    id2bird = {i: b for i, b in enumerate(target_columns)}
    # NOTE(review): not referenced in the visible cells.
    scored_birds = ["akiapo", "aniani", "apapan", "barpet", "crehon", "elepai", "ercfra", "hawama", "hawcre", "hawgoo", "hawhaw", "hawpet1", "houfin", "iiwi", "jabwar", "maupar", "omao", "puaioh", "skylar", "warwhe1", "yefcan"]
    # NOTE(review): not referenced in the visible cells; prediction_for_clip
    # builds its DataLoader with batch_size=1 and default workers.
    loader_params = {
        "train": {
            "batch_size": 64,
            "num_workers": 20,
            "shuffle": True
        },
        "valid": {
            "batch_size": 64,
            "num_workers": 20,
            "shuffle": False
        },
        "test": {
            "batch_size": 64,
            "num_workers": 20,
            "shuffle": False
        }
    }
#     models_cfg = [{"resnest26d": glob("../input/birdclef2022-weights/exp024/stage1/resnest26d/fold0/**/best_loss.ckpt", recursive=True)[0],}]
    # One single-entry dict per model: {backbone name: checkpoint path}.
    # The glob is evaluated when the class body runs; `[0]` raises IndexError
    # right here if no checkpoint matches the pattern.
    models_cfg = [{"resnest26d": glob("../input/birdclef2022-weights/exp024_base/stage1/resnest26d/fold0/**/best_loss.ckpt", recursive=True)[0],}]
    num_classes = len(target_columns)
# Report how many models will be ensembled.
print(f"model_num: {len(CFG.models_cfg)}")

In [None]:
# NOTE(review): TARGET_SR is not referenced in the visible cells; TestDataset
# hard-codes the same value as SR.
TARGET_SR = 32000
# Directory whose *.ogg files are scored.
DATADIR = Path("../input/birdclef-2022/test_soundscapes/")

In [None]:
# Every soundscape recording to run inference on.
all_audios = list(DATADIR.glob("*.ogg"))
# Template whose second column is overwritten in the submission cell.
sample_submission = pd.read_csv('../input/birdclef-2022/sample_submission.csv')
sample_submission

In [None]:
class FilterAugmentation(nn.Module):
    """Random per-frequency-band gain applied to a batch of spectrograms.

    The frequency axis is cut into a random number of bands and each band is
    multiplied by a gain drawn (in dB) from ``db_range``; gains are either
    constant per band (``"step"``) or interpolated between band edges
    (``"linear"``).

    Parameters
    ----------
    p : float
        Probability of applying the augmentation on a given call.
    db_range : sequence of two numbers
        Lower and upper bound of the random gain, in dB.
    n_band : sequence of two ints
        Number of bands is drawn from ``[n_band[0], n_band[1])``.
    min_bw : int
        Minimum width of a band, in frequency bins.
    filter_type : str or float
        ``"step"`` or ``"linear"``. A number is treated as the probability of
        choosing ``"step"`` (with its own preset band settings) on each call.
    """

    def __init__(
        self, p=0.5, db_range=[-6, 6], n_band=[3, 6], min_bw=6, filter_type="linear"
    ):
        super().__init__()
        # Copy the sequences so the shared default lists are never aliased.
        self.db_range = list(db_range)
        self.n_band = list(n_band)
        self.min_bw = min_bw
        self.filter_type = filter_type
        self.p = p

    def forward(self, features):
        """Return ``features`` (batch, freq, time) with the random filter applied.

        The input is returned untouched in eval mode, when the random draw
        exceeds ``p``, or when fewer than two bands are drawn.
        """
        if self.p < np.random.rand() or not self.training:
            return features

        # Work on locals: a numeric filter_type has to be re-drawn on every
        # call, so the module's own configuration must never be overwritten.
        filter_type, n_band, min_bw = self.filter_type, self.n_band, self.min_bw
        if not isinstance(filter_type, str):
            if torch.rand(1).item() < filter_type:
                filter_type, n_band, min_bw = "step", [2, 5], 4
            else:
                filter_type, n_band, min_bw = "linear", [3, 6], 6
        if filter_type not in ("step", "linear"):
            raise ValueError(f"unsupported filter_type: {filter_type!r}")

        batch_size, n_freq_bin, _ = features.shape
        n_freq_band = torch.randint(
            low=n_band[0], high=n_band[1], size=(1,)
        ).item()  # [low, high)
        if n_freq_band <= 1:
            return features

        # Shrink the minimum band width until at least one valid boundary
        # position is left; torch.randint needs a strictly positive upper bound.
        while min_bw > 0 and n_freq_bin - n_freq_band * min_bw + 1 < 1:
            min_bw -= 1
        inner_edges = (
            torch.sort(
                torch.randint(
                    0, n_freq_bin - n_freq_band * min_bw + 1, (n_freq_band - 1,)
                )
            )[0]
            + torch.arange(1, n_freq_band) * min_bw
        )
        # Plain Python ints are valid slice bounds and linspace step counts.
        edges = [0] + inner_edges.tolist() + [n_freq_bin]

        db_low, db_high = self.db_range[0], self.db_range[1]
        if filter_type == "step":
            band_factors = (
                torch.rand((batch_size, n_freq_band)).to(features)
                * (db_high - db_low)
                + db_low
            )
            band_factors = 10 ** (band_factors / 20)

            freq_filt = torch.ones((batch_size, n_freq_bin, 1)).to(features)
            for i in range(n_freq_band):
                freq_filt[:, edges[i] : edges[i + 1], :] = (
                    band_factors[:, i].unsqueeze(-1).unsqueeze(-1)
                )
        else:  # "linear"
            band_factors = (
                torch.rand((batch_size, n_freq_band + 1)).to(features)
                * (db_high - db_low)
                + db_low
            )
            freq_filt = torch.ones((batch_size, n_freq_bin, 1)).to(features)
            for i in range(n_freq_band):
                width = edges[i + 1] - edges[i]
                for j in range(batch_size):
                    freq_filt[j, edges[i] : edges[i + 1], :] = torch.linspace(
                        band_factors[j, i].item(),
                        band_factors[j, i + 1].item(),
                        width,
                    ).unsqueeze(-1)
            # Gains were interpolated in dB; convert to linear amplitude once.
            freq_filt = 10 ** (freq_filt / 20)
        return features * freq_filt

In [None]:
class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation):
        super(_ASPPModule, self).__init__()
        planes = int(planes)
        self.atrous_conv = nn.Conv2d(
            inplanes,
            int(planes),
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class ASPP(nn.Module):
    """Atrous spatial pyramid pooling block.

    Runs the input through one 1x1 branch, three dilated 3x3 branches and a
    global-average-pooling branch, concatenates the five results along the
    channel axis and projects them back to ``mid_c`` channels with a 1x1
    convolution, batch norm, ReLU and dropout.
    """

    def __init__(self, inplanes=512, mid_c=256, dilations=[1, 6, 12, 18]):
        super().__init__()
        self.aspp1 = _ASPPModule(inplanes, mid_c, 1, padding=0, dilation=dilations[0])
        # Branches aspp2..aspp4 differ only in their dilation rate.
        for position in (1, 2, 3):
            rate = dilations[position]
            setattr(
                self,
                f"aspp{position + 1}",
                _ASPPModule(inplanes, mid_c, 3, padding=rate, dilation=rate),
            )
        width = int(mid_c)
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, width, 1, stride=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(),
        )
        self.conv1 = nn.Conv2d(width * 5, width, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        """Fuse the five branches and return the projected feature map."""
        branches = [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        pooled = self.global_avg_pool(x)
        # Stretch the 1x1 pooled map back to the spatial size of the branches.
        pooled = F.interpolate(
            pooled, size=branches[-1].size()[2:], mode="bilinear", align_corners=True
        )
        fused = torch.cat(branches + [pooled], dim=1)
        return self.dropout(self.relu(self.bn1(self.conv1(fused))))

    def _init_weight(self):
        """Kaiming-normal conv weights; batch norm reset to scale 1, shift 0."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
            elif isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)


In [None]:
def gem_freq(x, p=3, eps=1e-6):
    """Generalized-mean pooling over the second-to-last axis of ``x``.

    Values are clamped to at least ``eps``, raised to ``p``, averaged with a
    pooling window spanning the whole second-to-last axis, and the ``p``-th
    root is taken, so that axis comes out with size 1.
    """
    powered = x.clamp(min=eps).pow(p)
    pooled = F.avg_pool2d(powered, (x.size(-2), 1))
    return pooled.pow(1.0 / p)


class GeMFreq(nn.Module):
    """Module wrapper around ``gem_freq`` whose exponent ``p`` is a learnable parameter."""

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        initial_power = torch.ones(1) * p
        self.p = torch.nn.Parameter(initial_power)
        self.eps = eps

    def forward(self, x):
        """Pool ``x`` with the current exponent and the configured ``eps``."""
        pooled = gem_freq(x, p=self.p, eps=self.eps)
        return pooled


class NormalizeMelSpec(nn.Module):
    """Standardize each item of a 3-D batch, then min-max scale it to [0, 1].

    Items whose standardized value range is not larger than ``eps`` (for
    example constant inputs) come out as all zeros.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, X):
        """Normalize ``X`` over its last two axes, independently per batch item."""
        centered = X - X.mean((1, 2), keepdim=True)
        Xstd = centered / (X.std((1, 2), keepdim=True) + self.eps)

        lowest = Xstd.min(-1)[0].min(-1)[0]
        highest = Xstd.max(-1)[0].max(-1)[0]
        value_range = highest - lowest
        usable = value_range > self.eps * torch.ones_like(value_range)

        scaled = torch.zeros_like(Xstd)
        if not usable.sum():
            return scaled

        low_sel = lowest[usable, None, None]
        high_sel = highest[usable, None, None]
        clipped = torch.max(torch.min(Xstd[usable], high_sel), low_sel)
        scaled[usable] = (clipped - low_sel) / (high_sel - low_sel)
        return scaled

class AttHead(nn.Module):
    """Attention-pooling classification head on top of a 2-D feature map.

    The feature map goes through ASPP, is pooled over the frequency axis, and
    is turned into per-frame class logits (``fix_scale``) and per-frame
    attention scores (``attention``); clip-level outputs are attention-weighted
    sums over time.

    Parameters
    ----------
    in_chans : int
        Channels of the incoming feature map.
    p : float
        Dropout rate after the dense layer (half of it is applied before).
    num_class : int
        Number of output classes.
    train_period, infer_period : float
        Lengths of the training window and of the scored window. In eval mode,
        when they differ, only the central ``infer_period / train_period``
        fraction of the time axis feeds the main outputs.
    """

    def __init__(
        self, in_chans, p=0.5, num_class=397, train_period=15.0, infer_period=5.0
    ):
        super().__init__()
        self.train_period = train_period
        self.infer_period = infer_period
        self.pooling = GeMFreq()

        self.dense_layers = nn.Sequential(
            nn.Dropout(p / 2),
            nn.Linear(in_chans, 512),
            nn.ReLU(),
            nn.Dropout(p),
        )
        self.attention = nn.Conv1d(
            in_channels=512,
            out_channels=num_class,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        self.fix_scale = nn.Conv1d(
            in_channels=512,
            out_channels=num_class,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        self.aspp = ASPP(in_chans, in_chans, dilations=[1, 4, 8, 12])

    def forward(self, feat):
        """Compute clip-level and frame-level predictions.

        Returns
        -------
        tuple
            ``(logits, clipwise_pred, framewise_logits, time_att,
            clipwise_pred_long)``. ``clipwise_pred_long`` is pooled over the
            whole time axis; when no cropping takes place (training, or equal
            periods) it is the same tensor as ``clipwise_pred``.
        """
        feat = self.aspp(feat)
        feat = self.pooling(feat).squeeze(-2).permute(0, 2, 1)  # (bs, time, ch)

        feat = self.dense_layers(feat).permute(0, 2, 1)  # (bs, 512, time)
        time_att = torch.tanh(self.attention(feat))
        assert self.train_period >= self.infer_period
        if self.training or self.train_period == self.infer_period:
            framewise_logits = self.fix_scale(feat)
            att = torch.softmax(time_att, dim=-1)
            # sum((bs, num_class, time), -1) -> (bs, num_class)
            clipwise_pred = torch.sum(torch.sigmoid(framewise_logits) * att, dim=-1)
            logits = torch.sum(framewise_logits * att, dim=-1)
            # Nothing is cropped here, so the full-window prediction is the
            # clip prediction itself. Without this binding the return below
            # raised UnboundLocalError on this branch.
            clipwise_pred_long = clipwise_pred
        else:
            framewise_pred_long = torch.sigmoid(self.fix_scale(feat))
            clipwise_pred_long = torch.sum(
                framewise_pred_long * torch.softmax(time_att, dim=-1), dim=-1
            )

            # Keep only the central part of the time axis that corresponds to
            # the scored window.
            feat_time = feat.size(-1)
            start = (
                feat_time / 2 - feat_time * (self.infer_period / self.train_period) / 2
            )
            end = start + feat_time * (self.infer_period / self.train_period)
            start = int(start)
            end = int(end)
            feat = feat[:, :, start:end]
            time_att = time_att[:, :, start:end]
            att = torch.softmax(time_att, dim=-1)
            # One projection of the cropped features serves all three outputs.
            framewise_logits = self.fix_scale(feat)
            clipwise_pred = torch.sum(torch.sigmoid(framewise_logits) * att, dim=-1)
            logits = torch.sum(framewise_logits * att, dim=-1)
        return (
            logits,
            clipwise_pred,
            framewise_logits,
            time_att,
            clipwise_pred_long,
        )


class AttModel(nn.Module):
    """timm backbone followed by an ``AttHead``, plus a bundled log-mel front end.

    ``logmelspec_extractor`` is built here but is NOT called by ``forward``;
    the caller has to run it and add the channel axis itself (as
    ``prediction_for_clip`` does).
    """

    def __init__(
        self,
        backbone="resnet34",
        p=0.5,
        n_mels=224,
        num_class=397,
        train_period=15.0,
        infer_period=5.0,
        in_chans=1,
    ):
        super().__init__()
        self.n_mels = n_mels

        mel = MelSpectrogram(
            32000,
            n_mels=n_mels,
            f_min=20,
            n_fft=2048,
            hop_length=512,
            normalized=True,
        )
        to_db = AmplitudeToDB(top_db=80.0)
        # p=0. means the filter augmentation never fires.
        band_filter = FilterAugmentation(
            p=0., db_range=[-7.5, 6], n_band=[2, 3], min_bw=6, filter_type="linear"
        )
        scaler = NormalizeMelSpec()
        target_size = (int(n_mels), int(n_mels // 2 * (train_period // 5)))
        self.logmelspec_extractor = nn.Sequential(
            mel, to_db, band_filter, scaler, Resize(size=target_size)
        )

        self.backbone = timm.create_model(
            backbone, features_only=True, pretrained=False, in_chans=in_chans
        )
        deepest_channels = self.backbone.feature_info.channels()[-1]
        self.head = AttHead(
            deepest_channels,
            p=p,
            num_class=num_class,
            train_period=train_period,
            infer_period=infer_period,
        )

    def forward(self, input):
        """Run the backbone on a ready-made image and feed its last feature map to the head."""
        deepest_features = self.backbone(input)[-1]
        return self.head(deepest_features)
    
class Model(nn.Module):
    """Thin container that holds an ``AttModel`` under the attribute ``model``.

    It defines no ``forward``; ``load_model`` loads the checkpoint into this
    wrapper and then unwraps ``.model``.
    """

    def __init__(
        self,
        backbone="resnet34",
        p=0.5,
        n_mels=224,
        num_class=CFG.num_classes,
        train_period=CFG.period,
        infer_period=5.0,
        in_chans=1,
    ):
        super().__init__()
        self.model = AttModel(
            backbone=backbone,
            p=p,
            n_mels=n_mels,
            num_class=num_class,
            train_period=train_period,
            infer_period=infer_period,
            in_chans=in_chans,
        )

In [None]:
class TestDataset(torchdata.Dataset):
    """Serves one fixed-length audio window per row of ``df``.

    Each row names the end (in seconds) of a 5-second segment. The returned
    window is ``train_period`` seconds long and centred on that segment. The
    recording is extended on both sides with a time-reversed copy of itself,
    and anything still outside that extended signal is zero-filled.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``row_id`` and ``seconds``.
    clip : numpy.ndarray
        The recording, sampled at 32 kHz.
    train_period : number
        Window length in seconds.
    waveform_transforms : callable, optional
        Applied to the extracted window.
    """

    def __init__(self, df: pd.DataFrame, clip: np.ndarray, train_period=30, 
                 waveform_transforms=None):
        self.df = df
        # [reversed | original | reversed]: the original starts at len // 3.
        self.clip = np.concatenate([clip[::-1], clip, clip[::-1]])
        self.train_period = train_period
        self.waveform_transforms=waveform_transforms
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx: int):
        """Return ``(window, row_id)`` for the row at position ``idx``."""
        SR = 32000
        # Positional lookup: the DataLoader hands out 0..len-1 whatever index
        # labels the frame carries.
        sample = self.df.iloc[idx]
        row_id = sample.row_id

        end_seconds = int(sample.seconds)
        start_seconds = int(end_seconds - 5)

        context = (self.train_period - 5) / 2
        origin = len(self.clip) // 3
        end_index = int(SR * (end_seconds + context) + origin)
        start_index = int(SR * (start_seconds - context) + origin)

        # A negative start would be read as "count from the end" by the slice,
        # and an overrun would silently shorten the window, so clamp to the
        # available signal and zero-fill the rest to keep the length fixed.
        window = max(end_index - start_index, 0)
        left_pad = min(max(-start_index, 0), window)
        valid = self.clip[max(start_index, 0):max(end_index, 0)]
        right_pad = window - left_pad - len(valid)
        y = valid.astype(np.float32)
        if left_pad or right_pad:
            pad_width = [(left_pad, right_pad)] + [(0, 0)] * (y.ndim - 1)
            y = np.pad(y, pad_width)

        y = np.nan_to_num(y)

        if self.waveform_transforms:
            y = self.waveform_transforms(y)

        # A transform may divide by zero on silent audio; scrub the result too.
        y = np.nan_to_num(y)
        
        return y, row_id

In [None]:
def get_transforms(phase: str):
    """Build the waveform transform pipeline configured for ``phase``.

    Reads ``CFG.transforms[phase]``, a list of ``{"name": ..., "params": ...}``
    entries, and instantiates each class found in this module's globals under
    that name. Names that are not found are skipped.

    Returns
    -------
    Compose or None
        ``None`` when nothing is configured or nothing could be instantiated.
    """
    phase_specs = None if CFG.transforms is None else CFG.transforms[phase]
    if phase_specs is None:
        return None

    built = []
    for spec in phase_specs:
        name = spec["name"]
        params = spec.get("params")
        if params is None:
            params = {}
        transform_cls = globals().get(name)
        if transform_cls is not None:
            built.append(transform_cls(**params))

    return Compose(built) if len(built) > 0 else None


def get_waveform_transforms(config: dict, phase: str):
    """Return the waveform transform pipeline for ``phase``.

    ``config`` is accepted for signature compatibility only: ``get_transforms``
    takes just the phase and reads its settings from ``CFG``. Forwarding
    ``config`` to it raised ``TypeError``.
    """
    return get_transforms(phase)


def get_spectrogram_transforms(config: dict, phase: str):
    """Build the spectrogram transform pipeline for ``phase`` from ``config``.

    ``config["spectrogram_transforms"][phase]`` is a list of
    ``{"name": ..., "params": ...}`` entries. Each name is looked up in
    ``albumentations`` first and in this module's globals second; names found
    in neither are skipped.

    Returns
    -------
    albumentations.Compose or None
        ``None`` when nothing is configured or nothing could be instantiated.
    """
    transforms = config.get('spectrogram_transforms')
    if transforms is None:
        return None
    phase_specs = transforms[phase]
    if not phase_specs:
        return None

    # The notebook never binds the name `A`, so it is imported here; doing it
    # after the early returns keeps unconfigured phases free of the dependency.
    import albumentations as A

    trns_list = []
    for trns_conf in phase_specs:
        trns_name = trns_conf["name"]
        trns_params = trns_conf.get("params")
        if trns_params is None:
            trns_params = {}
        if hasattr(A, trns_name):
            trns_cls = getattr(A, trns_name)
        else:
            trns_cls = globals().get(trns_name)
        if trns_cls is not None:
            trns_list.append(trns_cls(**trns_params))

    if len(trns_list) > 0:
        return A.Compose(trns_list, p=1.0)
    return None
        
class Normalize:
    """Scale a waveform so that its largest absolute sample equals 1."""

    def __call__(self, y: np.ndarray):
        """Return the peak-normalized waveform as a Fortran-ordered array.

        Empty or completely silent input is returned unscaled, since dividing
        by a zero peak would fill the result with NaN.
        """
        if len(y) == 0:
            return np.asfortranarray(y)
        max_vol = np.abs(y).max()
        if max_vol == 0:
            return np.asfortranarray(y)
        y_vol = y * 1 / max_vol
        return np.asfortranarray(y_vol)


class NewNormalize:
    """Remove the mean of a waveform, then scale its peak magnitude to 1."""

    def __call__(self, y: np.ndarray):
        """Return ``(y - mean) / max(|y - mean|)``."""
        y_mm = y - y.mean()
        # Built-in abs() goes through __abs__; numpy arrays have no .abs()
        # method, so the previous spelling raised AttributeError on them.
        return y_mm / abs(y_mm).max()


class Compose:
    """Chain of transforms applied one after another."""

    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, y: np.ndarray):
        """Feed ``y`` through every transform in order and return the final result."""
        result = y
        for step in self.transforms:
            result = step(result)
        return result


class AudioTransform:
    """Base class for waveform augmentations that fire with probability ``p``.

    Subclasses implement ``apply``. With ``always_apply`` set the transform is
    applied on every call and no random number is drawn.
    """

    def __init__(self, always_apply=False, p=0.5):
        self.always_apply = always_apply
        self.p = p

    def __call__(self, y: np.ndarray):
        """Return ``apply(y)`` when the transform fires, otherwise ``y`` unchanged."""
        if self.always_apply or np.random.rand() < self.p:
            return self.apply(y)
        return y

    def apply(self, y: np.ndarray):
        """Transform ``y``; must be overridden."""
        raise NotImplementedError
        
class NoiseInjection(AudioTransform):
    """Add white Gaussian noise scaled by a random level in ``[0, max_noise_level]``."""

    def __init__(self, always_apply=False, p=0.5, max_noise_level=0.5, sr=32000):
        super().__init__(always_apply, p)

        self.noise_level = (0.0, max_noise_level)
        self.sr = sr

    def apply(self, y: np.ndarray, **params):
        """Return ``y`` plus scaled standard-normal noise, cast back to ``y``'s dtype."""
        lowest, highest = self.noise_level
        level = np.random.uniform(lowest, highest)
        noise = np.random.randn(len(y))
        noisy = y + noise * level
        return noisy.astype(y.dtype)


class GaussianNoise(AudioTransform):
    """Add white noise at a random peak signal-to-noise ratio (in dB)."""

    def __init__(self, always_apply=False, p=0.5, min_snr=5, max_snr=20, sr=32000):
        super().__init__(always_apply, p)

        self.min_snr = min_snr
        self.max_snr = max_snr
        self.sr = sr

    def apply(self, y: np.ndarray, **params):
        """Mix in white noise whose peak is ``snr`` dB below the peak of ``y``."""
        snr = np.random.uniform(self.min_snr, self.max_snr)
        signal_peak = np.sqrt(y ** 2).max()
        noise_peak = signal_peak / (10 ** (snr / 20))

        white_noise = np.random.randn(len(y))
        white_peak = np.sqrt(white_noise ** 2).max()
        scaled_noise = white_noise * 1 / white_peak * noise_peak
        return (y + scaled_noise).astype(y.dtype)


class PinkNoise(AudioTransform):
    """Add pink (1/f power spectrum) noise at a random peak signal-to-noise ratio.

    The noise is generated with NumPy. The previous implementation called
    ``cn.powerlaw_psd_gaussian``, but ``cn`` is never bound in this notebook
    (its import is commented out), so ``apply`` raised ``NameError``.
    """

    def __init__(self, always_apply=False, p=0.5, min_snr=5, max_snr=20, sr=32000):
        super().__init__(always_apply, p)

        self.min_snr = min_snr
        self.max_snr = max_snr
        self.sr = sr

    @staticmethod
    def _pink_noise(n_samples: int) -> np.ndarray:
        """Draw ``n_samples`` of Gaussian noise shaped to a 1/f power spectrum."""
        if n_samples < 2:
            # Too short to have a non-zero frequency bin to shape.
            return np.random.randn(n_samples)
        freqs = np.fft.rfftfreq(n_samples)
        # The DC bin is 0 Hz; reuse the lowest non-zero frequency there so the
        # 1/sqrt(f) amplitude stays finite.
        freqs[0] = freqs[1]
        amplitude = freqs ** -0.5
        spectrum = amplitude * (
            np.random.randn(len(freqs)) + 1j * np.random.randn(len(freqs))
        )
        return np.fft.irfft(spectrum, n=n_samples)

    def apply(self, y: np.ndarray, **params):
        """Mix in pink noise whose peak is ``snr`` dB below the peak of ``y``."""
        snr = np.random.uniform(self.min_snr, self.max_snr)
        a_signal = np.sqrt(y ** 2).max()
        a_noise = a_signal / (10 ** (snr / 20))

        pink_noise = self._pink_noise(len(y))
        # Only the shape of the spectrum matters: the noise is rescaled by its
        # own peak before being mixed in.
        a_pink = np.sqrt(pink_noise ** 2).max()
        augmented = (y + pink_noise * 1 / a_pink * a_noise).astype(y.dtype)
        return augmented


class PitchShift(AudioTransform):
    """Shift the pitch by a random whole number of semitone steps."""

    def __init__(self, always_apply=False, p=0.5, max_range=5, sr=32000):
        super().__init__(always_apply, p)
        self.max_range = max_range
        self.sr = sr

    def apply(self, y: np.ndarray, **params):
        """Pitch-shift ``y`` by a step count drawn from ``[-max_range, max_range)``."""
        n_steps = np.random.randint(-self.max_range, self.max_range)
        # `sr` and `n_steps` are keyword-only from librosa 0.10 on; passing
        # them by name works on older releases as well.
        augmented = librosa.effects.pitch_shift(y, sr=self.sr, n_steps=n_steps)
        return augmented

class TimeStretch(AudioTransform):
    """Time-stretch the waveform by a random rate."""

    def __init__(self, always_apply=False, p=0.5, max_rate=1, sr=32000):
        super().__init__(always_apply, p)
        self.max_rate = max_rate
        self.sr = sr

    def apply(self, y: np.ndarray, **params):
        """Stretch ``y`` with a rate drawn uniformly from ``[0, max_rate)``."""
        rate = np.random.uniform(0, self.max_rate)
        # `rate` is keyword-only from librosa 0.10 on; passing it by name works
        # on older releases as well.
        augmented = librosa.effects.time_stretch(y, rate=rate)
        return augmented


def _db2float(db: float, amplitude=True):
    if amplitude:
        return 10**(db / 20)
    else:
        return 10 ** (db / 10)


def volume_down(y: np.ndarray, db: float):
    """
    Attenuate audio by a given number of decibels.
    Parameters
    ----------
    y: numpy.ndarray
        input audio
    db: float
        attenuation in decibels; the applied gain is ``_db2float(-db)``
    Returns
    -------
    applied: numpy.ndarray
        ``y`` multiplied by that gain
    """
    gain = _db2float(-db)
    return y * gain


def volume_up(y: np.ndarray, db: float):
    """
    Amplify audio by a given number of decibels.
    Parameters
    ----------
    y: numpy.ndarray
        input audio
    db: float
        amplification in decibels; the applied gain is ``_db2float(db)``
    Returns
    -------
    applied: numpy.ndarray
        ``y`` multiplied by that gain
    """
    gain = _db2float(db)
    return y * gain

class RandomVolume(AudioTransform):
    """Change the volume by a random amount within ``[-limit, limit]`` dB."""

    def __init__(self, always_apply=False, p=0.5, limit=10):
        super().__init__(always_apply, p)
        self.limit = limit

    def apply(self, y: np.ndarray, **params):
        """Amplify for a non-negative draw, attenuate for a negative one."""
        db = np.random.uniform(-self.limit, self.limit)
        if db >= 0:
            return volume_up(y, db)
        # volume_down expects the size of the reduction as a positive number.
        # Passing the negative draw through made it amplify instead, so the
        # volume was never actually lowered.
        return volume_down(y, -db)


class OneOf:
    """Apply exactly one transform, picked uniformly at random on each call."""

    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, y: np.ndarray):
        """Pick one of the stored transforms and return its result for ``y``."""
        picked = np.random.choice(len(self.transforms))
        return self.transforms[picked](y)


class CosineVolume(AudioTransform):
    """Modulate the volume along the clip with one full cosine cycle.

    The peak deviation, in dB, is drawn uniformly from ``[-limit, limit]``.
    """

    def __init__(self, always_apply=False, p=0.5, limit=10):
        super().__init__(always_apply, p)
        self.limit = limit

    def apply(self, y: np.ndarray, **params):
        """Multiply ``y`` by a gain curve that follows ``cos`` over one period."""
        db = np.random.uniform(-self.limit, self.limit)
        phase = np.arange(len(y)) / len(y) * np.pi * 2
        gains = _db2float(np.cos(phase) * db)
        return y * gains


def drop_stripes(image: np.ndarray, dim: int, drop_width: int, stripes_num: int):
    """Overwrite random stripes of ``image`` along axis ``dim`` with its minimum value.

    Parameters
    ----------
    image : numpy.ndarray
        Array to mask; it is modified in place and also returned.
    dim : int
        Axis along which stripes are placed.
    drop_width : int
        Exclusive upper bound on the width of a stripe.
    stripes_num : int
        Number of stripes to draw.

    Returns
    -------
    numpy.ndarray
        The same ``image`` object.
    """
    total_width = image.shape[dim]
    lowest_value = image.min()
    for _ in range(stripes_num):
        distance = np.random.randint(low=0, high=drop_width, size=(1,))[0]
        begin = np.random.randint(
            low=0, high=total_width - distance, size=(1,))[0]

        # Mask the whole range [begin, begin + distance) on the chosen axis.
        # Axes 1 and 2 used to overwrite only the single index
        # `begin + distance`, even when the drawn width was 0.
        stripe = [slice(None)] * image.ndim
        stripe[dim] = slice(begin, begin + distance)
        image[tuple(stripe)] = lowest_value
    return image

class TimeFreqMasking(ImageOnlyTransform):
    """Albumentations transform that masks random stripes of a 2-D image.

    Axis 0 is masked with the ``freq_*`` settings and axis 1 with the
    ``time_*`` settings. Images that are not 2-D are only copied.
    """

    def __init__(self,
                 time_drop_width: int,
                 time_stripes_num: int,
                 freq_drop_width: int,
                 freq_stripes_num: int,
                 always_apply=False,
                 p=0.5):
        super().__init__(always_apply, p)
        self.time_drop_width = time_drop_width
        self.time_stripes_num = time_stripes_num
        self.freq_drop_width = freq_drop_width
        self.freq_stripes_num = freq_stripes_num

    def apply(self, img, **params):
        """Return a masked copy of ``img``; the input itself is left untouched."""
        masked = img.copy()
        if img.ndim != 2:
            return masked
        masked = drop_stripes(
            masked,
            dim=0,
            drop_width=self.freq_drop_width,
            stripes_num=self.freq_stripes_num,
        )
        return drop_stripes(
            masked,
            dim=1,
            drop_width=self.time_drop_width,
            stripes_num=self.time_stripes_num,
        )

In [None]:
from torch.nn.modules.batchnorm import _BatchNorm

def prepare_model_for_inference(model, path: Path):
    """Load the ``"state_dict"`` entry of a checkpoint into ``model`` and switch it to eval mode.

    The checkpoint is mapped to the CPU when CUDA is unavailable; otherwise
    ``torch.load`` keeps its default device mapping.

    Returns
    -------
    The same ``model`` object.
    """
    load_kwargs = {} if torch.cuda.is_available() else {"map_location": "cpu"}
    checkpoint = torch.load(path, **load_kwargs)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    return model

In [None]:
def prediction_for_clip(test_df: pd.DataFrame, 
                        clip: np.ndarray, 
                        models, 
                        threshold=0.05, 
                        threshold_long=None):
    """Predict the labels of every row of ``test_df`` for one recording.

    Each row's window is converted to an image with the first model's
    ``logmelspec_extractor`` and scored by every model; the clip-level
    probabilities are averaged over the models. A class is reported when its
    mean probability reaches ``threshold`` and, if ``threshold_long`` is given,
    its mean long-window probability also reaches ``threshold_long``.

    Returns
    -------
    dict
        ``row_id`` -> space-separated class names, or ``"nocall"`` when no
        class qualifies.
    """
    dataset = TestDataset(
        df=test_df,
        clip=clip,
        train_period=CFG.period,
        waveform_transforms=get_transforms(phase="test"),
    )
    loader = torchdata.DataLoader(dataset, batch_size=1, shuffle=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    prediction_dict = {}
    for waveform, row_ids in tqdm(loader):
        row_key = str(row_ids[0])
        waveform = waveform.to(device)

        short_scores = []
        long_scores = []
        with torch.no_grad():
            # The image is computed once and shared by all models.
            image = models[0].logmelspec_extractor(waveform)[:, None]
            for model in models:
                with torch.cuda.amp.autocast():
                    _, clipwise_pred, _, _, clipwise_pred_long = model(image)
                short_scores.append(clipwise_pred.detach().cpu().numpy().reshape(-1))
                long_scores.append(clipwise_pred_long.detach().cpu().numpy().reshape(-1))
        probas = np.array(short_scores)
        probas_long = np.array(long_scores)

        events = probas.mean(0) >= threshold
        if threshold_long is not None:
            # Both criteria have to hold.
            events = events & (probas_long.mean(0) >= threshold_long)

        labels = np.argwhere(events).reshape(-1).tolist()
        if labels:
            prediction_dict[row_key] = " ".join(
                CFG.target_columns[label] for label in labels
            )
        else:
            prediction_dict[row_key] = "nocall"
    return prediction_dict

In [None]:
def load_model(backbone_name, weight_path):
    """Build a ``Model`` for ``backbone_name``, load its checkpoint and return the inner network.

    The wrapper is created with the notebook's ``CFG`` settings, filled from
    ``weight_path``, put in eval mode and moved to the GPU when one is
    available. Its ``.model`` attribute is what gets returned.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    build_kwargs = dict(
        p=0.5,
        n_mels=CFG.n_mels,
        num_class=CFG.num_classes,
        train_period=CFG.period,
        infer_period=5,
    )
    wrapper = Model(backbone_name, **build_kwargs)
    wrapper = prepare_model_for_inference(wrapper, weight_path).to(device)
    return wrapper.model

def prediction(test_audios,
               models_cfg,
               threshold=0.05, 
               threshold_long=None):
    """Run the model ensemble over every recording and collect the predictions.

    Parameters
    ----------
    test_audios : iterable of pathlib.Path
        Recordings to score.
    models_cfg : list of dict
        One dict per model; only its first ``backbone name -> checkpoint path``
        entry is used.
    threshold, threshold_long
        Forwarded to ``prediction_for_clip``.

    Returns
    -------
    dict
        ``row_id`` -> predicted label string, for twelve 5-second segments
        (ending at 5, 10, ..., 60 s) of each recording.
    """
    models = []
    for model_entry in models_cfg:
        backbone_name = list(model_entry.keys())[0]
        weight_path = list(model_entry.values())[0]
        models.append(load_model(backbone_name, weight_path))
    warnings.filterwarnings("ignore")

    prediction_dicts = {}
    for audio_path in test_audios:
        with timer(f"Loading {str(audio_path)}", logger):
            clip, _ = sf.read(audio_path)

        # Row ids are "<file name without extension>_<segment end second>".
        stem = "_".join(audio_path.name.split(".")[:-1])
        seconds = list(range(5, 65, 5))
        row_ids = [stem + f"_{second}" for second in seconds]
        print(row_ids)
        test_df = pd.DataFrame({
            "row_id": row_ids,
            "seconds": seconds
        })

        with timer(f"Prediction on {audio_path}", logger):
            clip_predictions = prediction_for_clip(
                test_df,
                clip=clip,
                models=models,
                threshold=threshold,
                threshold_long=threshold_long,
            )
        prediction_dicts.update(clip_predictions)
    return prediction_dicts

In [None]:
# A class is reported when it passes both the 5-second-window threshold and
# the full-window threshold.
threshold = 0.05
threshold_long = 0.1

prediction_dicts = prediction(test_audios=all_audios,
           models_cfg=CFG.models_cfg,
           threshold=threshold, 
           threshold_long=threshold_long)
print(prediction_dicts)

for i in range(len(sample_submission)):
    sample = sample_submission.row_id[i]
    # The parts are joined as <0>_<1>_<3> to form the prediction key; part 2
    # is the bird being asked about.
    parts = sample.split("_")
    key = parts[0] + "_" + parts[1] + "_" + parts[3]
    target_bird = parts[2]
    print(key, target_bird)
    if key in prediction_dicts:
        # Compare against whole labels: `in` on the raw space-joined string is
        # a substring test and could match inside a longer bird code.
        sample_submission.iat[i, 1] = (target_bird in prediction_dicts[key].split())
sample_submission.to_csv("submission.csv", index=False)
# submission = prediction(test_audios=all_audios,
#                         models_cfg=CFG.models_cfg,
#                         threshold=threshold, 
#                         threshold_long=threshold_long)
# submission.to_csv("submission.csv", index=False)

In [None]:
# Display the filled-in submission table for a visual check.
sample_submission