# A simple PyTorch baseline built on the ResNet34 model from the timm library. Training and inference code is provided.



# The training code computes spectrograms on the fly (an online transformation), so training is very slow. Pre-computing the spectrograms before training would speed it up considerably.


# Many thanks to the following notebooks, which served as references:
https://www.kaggle.com/tattaka/birdclef2022-submission-baseline

https://www.kaggle.com/myso1987/birdclef2022-pytorch-resnet34-starter-lb-0-50

etc.

# *The inference code is hard to follow and will be improved in the future.*


In [None]:
!pip install ../input/timm-package/timm-0.4.12-py3-none-any.whl

In [None]:
import os
import json
import tqdm
import random
import shutil
import pandas as pd
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T
from sklearn.model_selection import train_test_split, GroupKFold, StratifiedKFold, KFold
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score
import timm
from torchaudio.transforms import MelSpectrogram, Resample
import re
import torch.nn.functional as F
import soundfile as sf
import glob


class Config:
    """Static hyper-parameters shared by training and inference."""

    # Reproducibility
    seed = 2022

    # Data / model
    num_classes = 152
    img_size = 128
    pretrained = False

    # Training schedule
    epochs = 21
    batch_size = 48
    n_fold = 5
    learning_rate = 3e-4

    # Logging / output
    print_freq = 100
    model_save_dir = './'


CFG = Config()

def seed_everything(seed: int):
    """Seed every RNG the pipeline uses and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy and torch (CPU and the current CUDA
    device), and exports ``PYTHONHASHSEED``. The environment variable only
    affects interpreters started afterwards, because the current process
    fixes its hash seed at start-up.

    Args:
        seed: the seed value applied to all generators.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick (possibly non-deterministic)
    # algorithms per run, which defeats `deterministic = True` above.
    torch.backends.cudnn.benchmark = False


def extract_call(data, call='call'):
    """Return "True" if the regex `call` occurs anywhere in `data`, else "False".

    Used as a row filter on a metadata ``type`` column. The result is the
    *string* "True"/"False" because the caller compares against "True".

    NOTE(review): the caller derives class indices from the rows this filter
    keeps. A checkpoint trained with a different filter may therefore need its
    index->label mapping rebuilt.

    Args:
        data: text to search (non-string values such as NaN count as "no match").
        call: regular expression to look for.
    """
    try:
        # re.search takes the PATTERN first and the text second.
        found = re.search(call, data) is not None
    except TypeError:
        # Missing metadata arrives as float NaN (or another non-str): no call.
        found = False
    return "True" if found else "False"

def get_train_transforms():
    """Build the training image pipeline: resize, normalise, convert to tensor.

    No augmentation is applied; the steps match the validation pipeline.
    """
    steps = [
        A.Resize(CFG.img_size, CFG.img_size),
        A.Normalize(mean=[0.485], std=[0.229], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return A.Compose(steps, p=1.)
        
def get_val_transforms():
    """Build the validation/inference image pipeline.

    Resizes to ``CFG.img_size`` squared, normalises a single channel whose
    values lie in [0, 255], and converts to a tensor.
    """
    resize = A.Resize(CFG.img_size, CFG.img_size)
    normalize = A.Normalize(mean=[0.485], std=[0.229], max_pixel_value=255.0, p=1.0)
    to_tensor = ToTensorV2(p=1.0)
    return A.Compose([resize, normalize, to_tensor], p=1.)

    
class MyDataset(Dataset):
    """Mel-spectrogram dataset for bird audio.

    Both modes share one feature pipeline (``_to_model_input``):

        mono audio
        -> resample to 32 kHz
        -> crop / zero-pad to 5 s
        -> power mel-spectrogram
        -> dB
        -> min-max scaling to [0, 255]
        -> ``transforms``

    ``transforms`` must be callable as ``transforms(image=...)`` on an HxWx1
    numpy array and return a dict with an ``'image'`` entry.

    mode='train':
        ``image_paths`` is an indexable of audio file paths. ``label_paths`` is
        the matching indexable of integer class labels. Item ``idx`` uses the
        first 5 s of file ``idx`` and returns ``(inputs, label_tensor)``.

    mode='test':
        ``image_paths`` is the path of ONE soundscape file. ``label_paths`` is
        a DataFrame with ``row_id`` and ``seconds`` columns, read via
        ``.loc[idx]``, where ``seconds`` is the END of a 5 s window. Item
        ``idx`` returns ``(inputs, row_id)``.
    """

    def __init__(self, image_paths=None, label_paths=None, transforms=None, mode='train'):
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.transforms = transforms
        self.mode = mode
        self.len = len(label_paths)
        self.target_sample_rate = 32000
        self.clip_seconds = 5
        self.num_samples = self.target_sample_rate * self.clip_seconds
        self.mel_spectrogram = T.MelSpectrogram(sample_rate=self.target_sample_rate, n_fft=2048, win_length=None, hop_length=1024, center=True,
                                                pad_mode="reflect", power=2.0, norm='slaney', onesided=True, n_mels=128,
                                                mel_scale="htk", )
        self.amplitude_to_db = T.AmplitudeToDB()
        # Resample modules keyed by source rate, built once instead of per item.
        self._resamplers = {}
        # Test mode reads the same soundscape for every row: load it lazily, once.
        self._soundscape = None

    def __getitem__(self, idx):
        if self.mode == 'train':
            audio, sample_rate = torchaudio.load(self.image_paths[idx])
            inputs = self._to_model_input(self.to_mono(audio), sample_rate)
            labels = torch.tensor(self.label_paths[idx], dtype=torch.long)
            return inputs, labels

        if self.mode == 'test':
            audio, sample_rate = self._load_soundscape()
            sample = self.label_paths.loc[idx, :]  # test has no label, only the clip window
            start_index, end_index = self._window_bounds(int(sample.seconds), sample_rate)
            inputs = self._to_model_input(audio[start_index:end_index], sample_rate)
            return inputs, sample.row_id

        raise ValueError(f"unknown mode: {self.mode!r}")

    def _load_soundscape(self):
        """Return (mono waveform, sample_rate) of the test soundscape, cached."""
        if self._soundscape is None:
            audio, sample_rate = torchaudio.load(self.image_paths)
            self._soundscape = (self.to_mono(audio), sample_rate)
        return self._soundscape

    def _window_bounds(self, end_seconds, sample_rate):
        """Return [start, end) sample indices of the clip ending at `end_seconds`.

        Indices are in the FILE's sample rate, because slicing happens before
        resampling. The start is clamped at 0 so a short offset never wraps
        around as a negative index.
        """
        start_seconds = max(end_seconds - self.clip_seconds, 0)
        return int(start_seconds * sample_rate), int(end_seconds * sample_rate)

    def _resample(self, audio, sample_rate):
        """Resample a mono waveform to the target rate (no-op if already there)."""
        if sample_rate == self.target_sample_rate or audio.numel() == 0:
            return audio
        resampler = self._resamplers.get(sample_rate)
        if resampler is None:
            resampler = Resample(sample_rate, self.target_sample_rate)
            self._resamplers[sample_rate] = resampler
        return resampler(audio)

    def _to_model_input(self, audio, sample_rate):
        """Turn a mono waveform at `sample_rate` into the transformed model input."""
        audio = self._resample(audio, sample_rate)
        if audio.shape[0] > self.num_samples:
            audio = self.crop_audio(audio)
        else:
            audio = self.pad_audio(audio)
        mel = self.amplitude_to_db(self.mel_spectrogram(audio))
        mel = self.scale_minmax(mel.numpy(), 0, 255)
        image = mel[:, :, np.newaxis]  # [n_mels, frames] -> [n_mels, frames, 1]
        return self.transforms(image=image)['image']

    def pad_audio(self, audio):
        """Zero-pad a 1-D waveform on the right up to `num_samples`."""
        pad_length = self.num_samples - audio.shape[0]
        return F.pad(audio, (0, pad_length))

    def crop_audio(self, audio):
        """Keep the first `num_samples` samples."""
        return audio[:self.num_samples]

    def to_mono(self, audio):
        """Average a (channels, samples) waveform over its channels."""
        return torch.mean(audio, dim=0)

    def scale_minmax(self, X, min=0.0, max=1.0):
        """Linearly map X's value range onto [min, max].

        A constant X (e.g. a fully padded, silent clip) has zero range. It is
        mapped to `min` instead of dividing by zero and producing NaNs.
        """
        lo, hi = X.min(), X.max()
        if hi == lo:
            X_std = np.zeros_like(X)
        else:
            X_std = (X - lo) / (hi - lo)
        return X_std * (max - min) + min

    def __len__(self):
        return self.len


# =============================== model ========================
class MyModel(nn.Module):
    """timm ResNet34 backbone taking 1-channel input, with a new linear head.

    The backbone is stored as ``self.model``, so state-dict keys are
    prefixed with ``model.``.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super().__init__()
        backbone = timm.create_model('resnet34', in_chans=1, pretrained=pretrained)
        # Swap the classifier for one sized to this label set.
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the backbone's raw scores (the inference code applies sigmoid)."""
        return self.model(x)

def prediction_for_clip(audio_path, test_df, models, threshold=0.05, *,
                        device=None, label_map=None, multi_label=False):
    """Predict labels for every 5-second row of one soundscape.

    Args:
        audio_path: path of the soundscape audio file.
        test_df: DataFrame with ``row_id`` and ``seconds`` columns, indexed
            0..len-1 (``seconds`` is the end of each 5 s window).
        models: iterable of trained models; their sigmoid outputs are averaged.
        threshold: a row is "nocall" unless some class's mean probability
            reaches this value.
        device: device to run on. Defaults to the device of the first model's
            parameters.
        label_map: dict mapping class index -> label string. Defaults to the
            module-level ``class_dict``.
        multi_label: if False (default), a called row gets only its single most
            probable class. If True, it gets every class at or above
            ``threshold``, space-separated.

    Returns:
        dict mapping each row_id string to "nocall" or a label string.
    """
    models = list(models)
    if device is None:
        # Run wherever the (first) model's weights already live.
        device = next(models[0].parameters()).device
    if label_map is None:
        label_map = class_dict
    for model in models:
        model.eval()

    test_dataset = MyDataset(audio_path, test_df, transforms=get_val_transforms(), mode='test')
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    prediction_dict = {}
    with torch.no_grad():
        for inputs, row_ids in test_loader:
            inputs = inputs.to(device)
            # Average per-model sigmoid probabilities -> (batch, n_classes).
            probas = torch.stack(
                [torch.sigmoid(model(inputs)) for model in models]
            ).mean(dim=0).cpu().numpy()
            # The default collate function batches the string row_ids into a
            # list, so key by each element. str() of the whole batch produced
            # keys such as "['soundscape_1_5']" that never matched a
            # submission row.
            for row_id, row_probas in zip(row_ids, probas):
                above = row_probas >= threshold
                if not above.any():
                    prediction_dict[row_id] = "nocall"
                    continue
                if multi_label:
                    indices = np.flatnonzero(above).tolist()
                else:
                    indices = [int(np.argmax(row_probas))]
                prediction_dict[row_id] = " ".join(label_map[i] for i in indices)
    return prediction_dict

if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    seed_everything(CFG.seed)

    root_path = "../input/birdclef-2022/"
    input_path = os.path.join(root_path, "train_audio")
    train_meta = pd.read_csv(os.path.join(root_path, "train_metadata.csv"))
    # Vectorised path join instead of a per-row .loc assignment loop.
    train_meta["filename"] = [os.path.join(input_path, name) for name in train_meta["filename"]]

    print("Length of data before call extraction : {}".format(len(train_meta)))
    train_meta["type"] = train_meta["type"].apply(extract_call)
    train_meta = train_meta[train_meta["type"] == "True"].reset_index(drop=True)
    # pandas >= 2.0 rejects a positional `axis` in DataFrame.drop.
    train_meta = train_meta.drop(columns="type")
    print("Length of data after call extraction : {}".format(len(train_meta)))

    # Class index <-> label mapping in order of first appearance (same order as
    # .unique()). It must reproduce the mapping used when the checkpoint was
    # trained. factorize replaces a chained per-label replace(inplace=True),
    # which is deprecated and does nothing under Copy-on-Write.
    codes, uniques = pd.factorize(train_meta["primary_label"])
    class_dict = dict(enumerate(uniques.tolist()))
    train_meta["primary_label"] = codes
    print(class_dict)

    #================================inference==========================
    test_audios = sorted(glob.glob(os.path.join(root_path, "test_soundscapes", "*.ogg")))
    sample_submission = pd.read_csv(os.path.join(root_path, "sample_submission.csv"))

    threshold = 0.2

    model = MyModel(num_classes=CFG.num_classes, pretrained=CFG.pretrained).to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only kernel.
    state_dict = torch.load('../input/train-bird-pytorch-baseline/fold_1_best.pth', map_location=device)
    model.load_state_dict(state_dict)
    models = [model]

    prediction_dicts = {}
    for audio_path in test_audios:
        print(audio_path)
        file_id = os.path.splitext(os.path.basename(audio_path))[0]
        seconds = list(range(5, 65, 5))
        row_ids = [f"{file_id}_{second}" for second in seconds]
        test_df = pd.DataFrame({"row_id": row_ids, "seconds": seconds})
        prediction_dicts.update(
            prediction_for_clip(audio_path, test_df, models=models, threshold=threshold)
        )

    # Submission row_ids look like "<prefix>_<id>_<bird>_<seconds>", while the
    # predictions above are keyed "<prefix>_<id>_<seconds>".
    for i, sub_row_id in enumerate(sample_submission["row_id"]):
        parts = sub_row_id.split("_")
        key = "_".join((parts[0], parts[1], parts[3]))
        target_bird = parts[2]
        if key in prediction_dicts:
            # Whole-token match rather than a substring test on the label string.
            # Column 1 holds the target flag (assumes sample_submission layout).
            sample_submission.iat[i, 1] = target_bird in prediction_dicts[key].split()
    sample_submission.to_csv("submission.csv", index=False)
    print(sample_submission)