# Import Party!

In [None]:
import os,json,re,time,pdb,ast
import random
import numpy as np 
import pandas as pd 
from glob import glob
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from tqdm import tqdm
from collections import Counter
from IPython.display import Audio, display
from sklearn.model_selection import train_test_split

import warnings
warnings.filterwarnings('ignore')

# Pytorch imports for neural networks and tensor manipulations
import torch, torchaudio
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torchvision.transforms import Resize
from torchaudio.transforms import MelSpectrogram
from torch.utils.data import Dataset, DataLoader, random_split

import pytorch_lightning as pl


# Libraries for visualization
!pip install torchsummary -q
import torchsummary
from termcolor import cprint




In [None]:
class config:
    """Notebook-wide hyper-parameters and run settings (read as class attributes)."""

    # Reproducibility / cross-validation
    seed = 2022
    num_fold = 5

    # Audio front-end
    sample_rate = 32_000
    n_fft = 1024
    hop_length = 512
    n_mels = 64
    duration = 7

    # Task
    num_classes = 152

    # Training
    train_batch_size = 32
    valid_batch_size = 64
    model_name = 'resnet50'
    epochs = 2
    learning_rate = 1e-4
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator.

    Note:
        ``PYTHONHASHSEED`` only takes effect in interpreters started after it is set
        (e.g. worker subprocesses); it does not change hashing in the running kernel.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current device (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may pick different kernels between runs; disable it so that
    # `deterministic=True` actually yields reproducible results.
    torch.backends.cudnn.benchmark = False


In [None]:
def display_(path, base_dir, disp_=False, class_dict_path="class_dict.json"):
    """Load a metadata CSV, add full audio paths and integer-encode ``primary_label``.

    Labels are numbered in order of first appearance. The ``index -> label`` mapping is
    written as JSON to ``class_dict_path``; JSON turns the integer keys into strings.

    Args:
        path: CSV file with at least ``filename`` and ``primary_label`` columns.
        base_dir: directory joined (with ``'/'``) in front of each ``filename``.
        disp_: when True, also display the frame, its ``info()`` and ``describe()``.
        class_dict_path: where the label mapping JSON is written.

    Returns:
        The loaded DataFrame with an added ``path_filename`` column and integer
        codes in ``primary_label``.
    """
    df = pd.read_csv(path)
    df['path_filename'] = df.filename.map(lambda x: base_dir + '/' + x)

    # factorize numbers labels in order of first appearance (same as enumerate(unique()))
    # in one vectorised pass, instead of one in-place replace() scan per class.
    # Rows with a missing label get code -1.
    codes, labels = pd.factorize(df['primary_label'])
    df['primary_label'] = codes
    # tolist() converts NumPy scalars into JSON-serialisable Python objects.
    class_dict = {index: label for index, label in enumerate(labels.tolist())}

    with open(class_dict_path, "w") as fh:
        json.dump(class_dict, fh)

    if disp_:
        display(df)
        df.info()  # info() prints its own report and returns None
        display(df.describe())

    return df
    

In [None]:
def plot(df):
    plt.figure(figsize=(20, 6))
    sns.countplot(df['primary_label'])
    plt.xticks(rotation=90)
    plt.title("Distribution of Primary Labels", fontsize=20)
    plt.show()
    
    plt.figure(figsize=(20, 6))
    sns.countplot(df['rating'])
    plt.title("Distribution of Ratings", fontsize=20)
    plt.show()
    
    df['type'] = df['type'].apply(lambda x : ast.literal_eval(x))
    top = Counter([typ.lower() for lst in df['type'] for typ in lst])
    top = dict(top.most_common(10))
    plt.figure(figsize=(20, 6))
    sns.barplot(x=list(top.keys()), y=list(top.values()), palette='hls')
    plt.title("Top 10 song types")
    plt.show()

# Basic Info display

In [None]:
def main():
    """Seed RNGs, load the training metadata with a verbose summary, then draw the EDA plots."""
    seed_everything(config.seed)
    base_train_dir = '../input/birdclef-2022/train_audio'
    path_trainmeta = '../input/birdclef-2022/train_metadata.csv'
    train_meta = display_(path_trainmeta, base_train_dir, disp_=True)

    print('*' * 100)
    print('SHORT EDA From base NB consolidated')
    plot(train_meta)


main()

# Basic audio file display

In [None]:
# from torchaudio tutorial docs

def print_stats(waveform, sample_rate=None, src=None):
    """Print shape, dtype and max/min/mean/std of a waveform tensor, then the tensor itself.

    The source banner and the sample rate are printed only when ``src`` / ``sample_rate``
    are truthy.
    """
    divider = "-" * 10
    if src:
        print(divider)
        print("Source:", src)
        print(divider)
    if sample_rate:
        print("Sample Rate:", sample_rate)
    print("Shape:", tuple(waveform.shape))
    print("Dtype:", waveform.dtype)

    summary_rows = (
        (" - Max:     ", waveform.max()),
        (" - Min:     ", waveform.min()),
        (" - Mean:    ", waveform.mean()),
        (" - Std Dev: ", waveform.std()),
    )
    for prefix, stat in summary_rows:
        print(f"{prefix}{stat.item():6.3f}")

    print()
    print(waveform)
    print()


def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
    """Plot every channel of a ``(channels, frames)`` waveform against time in seconds.

    One subplot per channel. Channel labels are shown only for multi-channel input.
    Axis limits are applied only when ``xlim`` / ``ylim`` are truthy.
    """
    samples = waveform.numpy()
    n_channels, n_frames = samples.shape
    seconds = torch.arange(0, n_frames) / sample_rate

    fig, axs = plt.subplots(n_channels, 1)
    axs = [axs] if n_channels == 1 else axs
    for ch, ax in enumerate(axs):
        ax.plot(seconds, samples[ch], linewidth=1)
        ax.grid(True)
        if n_channels > 1:
            ax.set_ylabel(f"Channel {ch+1}")
        if xlim:
            ax.set_xlim(xlim)
        if ylim:
            ax.set_ylim(ylim)
    fig.suptitle(title)
    plt.show(block=False)


def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
    """Draw a matplotlib ``specgram`` for each channel of a ``(channels, frames)`` waveform.

    One subplot per channel. Channel labels are shown only for multi-channel input.
    """
    samples = waveform.numpy()
    n_channels, _ = samples.shape

    fig, axs = plt.subplots(n_channels, 1)
    axs = [axs] if n_channels == 1 else axs
    for ch, ax in enumerate(axs):
        ax.specgram(samples[ch], Fs=sample_rate)
        if n_channels > 1:
            ax.set_ylabel(f"Channel {ch+1}")
        if xlim:
            ax.set_xlim(xlim)
    fig.suptitle(title)
    plt.show(block=False)


def play_audio(waveform, sample_rate):
    """Show an inline IPython audio player for a mono or stereo ``(channels, frames)`` waveform.

    Raises:
        ValueError: if the waveform does not have exactly 1 or 2 channels.
    """
    samples = waveform.numpy()
    n_channels, _ = samples.shape

    if n_channels == 1:
        payload = samples[0]
    elif n_channels == 2:
        payload = (samples[0], samples[1])
    else:
        raise ValueError("Waveform with more than 2 channels are not supported.")
    display(Audio(payload, rate=sample_rate))


def _get_sample(path, resample=None):
    """Load ``path`` through sox effects, mixing down to channel 1 and optionally resampling.

    When ``resample`` is truthy, a low-pass filter at ``resample // 2`` Hz is applied
    before the ``rate`` change to ``resample`` Hz.
    """
    effect_chain = [["remix", "1"]]
    if resample:
        effect_chain += [
            ["lowpass", f"{resample // 2}"],
            ["rate", f"{resample}"],
        ]
    return torchaudio.sox_effects.apply_effects_file(path, effects=effect_chain)


def get_sample(path=None, *, resample=None):
    """Load a sample audio file via ``_get_sample`` (mono mix-down, optional resampling).

    Args:
        path: audio file to load. When omitted, falls back to a global ``SAMPLE_WAV_PATH``
            (the torchaudio tutorial's convention). This notebook never defines that name,
            so pass ``path`` explicitly.
        resample: target sample rate forwarded to ``_get_sample``; falsy keeps the original rate.

    Returns:
        Whatever ``_get_sample`` returns (the result of ``apply_effects_file``).

    Raises:
        ValueError: if ``path`` is omitted and no ``SAMPLE_WAV_PATH`` global exists.
            Previously this case surfaced as an opaque NameError.
    """
    if path is None:
        path = globals().get("SAMPLE_WAV_PATH")
        if path is None:
            raise ValueError(
                "get_sample: pass `path` explicitly or define SAMPLE_WAV_PATH first"
            )
    return _get_sample(path, resample=resample)


def inspect_file(path):
    """Print a source banner, the file size in bytes and the ``torchaudio.info`` metadata of ``path``."""
    banner = "-" * 10
    print(banner)
    print("Source:", path)
    print(banner)
    size_in_bytes = os.path.getsize(path)
    print(f" - File size: {size_in_bytes} bytes")
    print(f" - {torchaudio.info(path)}")

In [None]:
def display_audiodata(df, idx):
    """Inspect the recording in row ``idx`` of ``df``.

    Prints the metadata and statistics, plots the waveform and spectrogram, and shows a player.

    Args:
        df: metadata frame with a ``path_filename`` column (as produced by ``display_``).
        idx: row label looked up in ``df.path_filename``.
    """
    audio_path = df.path_filename[idx]
    # Prints e.g.:
    # AudioMetaData(sample_rate=32000, num_frames=1504653, num_channels=1, bits_per_sample=0, encoding=VORBIS)
    print(torchaudio.info(audio_path))

    waveform, sample_rate = torchaudio.load(audio_path)
    print_stats(waveform, sample_rate=sample_rate)
    for render in (plot_waveform, plot_specgram, play_audio):
        render(waveform, sample_rate)


In [None]:
def main():
    """Load the training metadata and inspect the first five recordings (stats, plots, player)."""
    train_meta = display_(
        '../input/birdclef-2022/train_metadata.csv',
        '../input/birdclef-2022/train_audio',
    )
    for row in range(5):
        display_audiodata(train_meta, row)


main()

# Input Pipeline

In [None]:
class base_pipe(Dataset):
    """Map-style dataset that yields ``(features, label)`` for each row of a metadata frame.

    Each item loads the audio at ``path_filename`` and applies ``transform`` in order. It keeps
    channel 0 of the result and re-adds a leading channel axis. The label is the row's
    ``primary_label``.

    Args:
        df: frame with ``path_filename`` and ``primary_label`` columns. Rows are looked up
            with ``.loc[index]``, so the index should match the sampler's positions
            (e.g. a RangeIndex after ``reset_index``).
        size: stored as ``self.size``; not used by this class.
        transform: sequence of callables applied to the loaded audio. ``None`` means no
            transform. When omitted, a fresh ``[MelSpectrogram(n_mels=128), Resize((128, 128))]``
            is built for this instance.
    """

    # Sentinel for "use the default transforms". None already means "no transforms", so
    # it cannot double as the default. Building the list per instance avoids sharing one
    # mutable default list (and its modules) across every dataset.
    _DEFAULT_TRANSFORM = object()

    def __init__(self,
                 df,
                 size=640,
                 transform=_DEFAULT_TRANSFORM
    ):
        super().__init__()
        if transform is self._DEFAULT_TRANSFORM:
            # NOTE(review): torchaudio's MelSpectrogram defaults to sample_rate=16000 while
            # config.sample_rate is 32 kHz — confirm the intended mel scale.
            transform = [
                MelSpectrogram(n_mels=128),
                Resize((128, 128)),
            ]
        self.metadata = df
        self.size = size
        self.transform = transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, index):
        path = self.metadata.loc[index, "path_filename"]
        label = self.metadata.loc[index, "primary_label"]
        mono_audio = self.load_audio(path)
        # Restore a channel axis so the item is (1, ...) for the Conv2d input.
        mono_audio = mono_audio.unsqueeze(dim=0)
        return mono_audio, label

    def load_audio(self, path):
        """Load ``path``, apply each transform in order and return channel 0 of the result."""
        audio, _ = torchaudio.load(path)
        if self.transform is not None:
            for aug in self.transform:
                audio = aug(audio)

        return audio[0, :]
    
class pl_pipeline(pl.LightningDataModule):
    """LightningDataModule that splits a metadata frame into train/validation datasets.

    Args:
        ds: Dataset class, instantiated as ``ds(split_df)`` for each split.
        df: metadata frame to split.
        bs: batch size used by both loaders.
        test_size: fraction of rows held out for validation.
        random_state: seed forwarded to ``train_test_split``. ``None`` keeps the previous
            behaviour of drawing from NumPy's global RNG.
        num_workers: DataLoader worker processes (0 loads in the main process).
    """

    def __init__(
        self,
        ds,
        df,
        bs,
        test_size=0.25,
        random_state=None,
        num_workers=0,
    ):
        super().__init__()
        train_df, val_df = train_test_split(df, test_size=test_size, random_state=random_state)
        # drop=True: give each split a fresh 0..n-1 index (needed by the dataset's .loc lookups)
        # without carrying the old index along as an extra 'index' column.
        self.train_df = train_df.reset_index(drop=True)
        self.val_df = val_df.reset_index(drop=True)
        self.ds = ds
        self.bs = bs
        self.num_workers = num_workers

    def train_dataloader(self):
        train_ds = self.ds(self.train_df)
        # Reshuffle every epoch; train_test_split only shuffles once.
        return DataLoader(train_ds, batch_size=self.bs, shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        val_ds = self.ds(self.val_df)
        return DataLoader(val_ds, batch_size=self.bs, shuffle=False,
                          num_workers=self.num_workers)
    
    
    

# Model

In [None]:
# Convolution shape updating function
def conv_shape(shape, kernel_size, stride, padding, dilation=1):
    """Return the ``(H, W)`` spatial output size of a 2-D convolution.

    Applies the ``nn.Conv2d`` size formula per axis:
    ``(L + 2*padding - dilation*(kernel_size - 1) - 1) // stride + 1``.
    With ``dilation=1`` this reduces to ``(L - kernel_size + 2*padding) // stride + 1``.

    Args:
        shape: sequence whose first two items are the input ``H`` and ``W``.
        kernel_size, stride, padding, dilation: each either a scalar (same for both
            axes) or an ``(h, w)`` pair, as accepted by ``nn.Conv2d``.
    """
    def _pair(value):
        # Broadcast a scalar to both axes; pass (h, w) pairs through.
        try:
            h, w = value
        except TypeError:
            return value, value
        return h, w

    H, W = shape[0], shape[1]
    (kh, kw), (sh, sw), (ph, pw), (dh, dw) = map(
        _pair, (kernel_size, stride, padding, dilation)
    )
    H = (H + 2 * ph - dh * (kh - 1) - 1) // sh + 1
    W = (W + 2 * pw - dw * (kw - 1) - 1) // sw + 1
    return H, W

class Conv(nn.Module):
    """Pre-activation convolution unit: ``BatchNorm2d -> Conv2d -> ReLU``.

    This is a plain layer, so it subclasses ``nn.Module``. Lightning-specific behaviour
    belongs to the top-level LightningModule that trains the network.

    Args:
        in_channels: channels of the input (also the BatchNorm width).
        out_channels: channels produced by the convolution.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        momentum: BatchNorm running-statistics momentum.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=(1, 1),
                 padding=(0, 0),
                 momentum=0.15):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(in_channels, momentum=momentum),
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.ReLU()
        )

    def forward(self, x):
        return self.conv_block(x)


class CLEFNetwork(nn.Module):
    """Small CNN classifier: a strided conv stem, ``num_downs - 1`` strided convs, then an MLP head.

    Expects input of shape ``(batch, in_channels, H, W)``. ``H``/``W`` must match the
    constructor arguments, because the first Linear layer is sized from them. The output
    is ``(batch, num_classes)`` logits.

    Args:
        num_classes: number of output logits.
        in_channels: channels of the input image/spectrogram.
        H, W: input spatial size used to size the head.
        num_downs: total number of stride-2 convolutions (stem included). Each one after
            the stem doubles the channel count starting from ``C = 8``.
    """

    def __init__(self,
                 num_classes,
                 in_channels=1,
                 H=128,
                 W=128,
                 num_downs=3):
        super().__init__()

        self.num_C = num_classes
        self.num_downs = num_downs
        self.in_channels = in_channels
        self.C = 8
        # Must run after num_downs is set; mirrors the conv stack built below.
        self.H, self.W = self.calc_HW(H, W)
        self.in_conv_block = Conv(self.in_channels, self.C, 7, (2, 2))
        self.conv_block = nn.ModuleList(
            [
                Conv(self.C * 2**i, self.C * 2**(i + 1), 3, (2, 2))
                for i in range(self.num_downs - 1)
            ]
        )
        flat_features = self.H * self.W * self.C * 2**(self.num_downs - 1)
        # ReLUs between the Linear layers: without them the three layers compose into a
        # single linear map and the extra depth adds nothing. (This changes fc_block's
        # state_dict indices relative to older checkpoints.)
        self.fc_block = nn.Sequential(
            nn.Linear(flat_features, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, self.num_C)
        )

    def calc_HW(self, H, W):
        """Spatial size after the stem (k=7, s=2, p=0) and each downsampling conv (k=3, s=2, p=0).

        These kernel/stride/padding values must stay in sync with the Conv layers built in
        ``__init__``.
        """
        H, W = conv_shape((H, W), 7, 2, 0)
        for _ in range(self.num_downs - 1):
            H, W = conv_shape((H, W), 3, 2, 0)
        return H, W

    def forward(self, x):
        x = self.in_conv_block(x)
        for block in self.conv_block:
            x = block(x)
        # flatten (unlike view) also works on non-contiguous tensors.
        x = torch.flatten(x, 1)
        return self.fc_block(x)
  

In [None]:
def main():
    """Build CLEFNetwork for the dataset's label count and sanity-check it.

    Runs a forward pass on random input, lists the parameters and prints a torchsummary.
    """
    seed_everything(config.seed)
    path_trainmeta = '../input/birdclef-2022/train_metadata.csv'
    base_train_dir = '../input/birdclef-2022/train_audio'
    # Called for its side effect of (re)writing class_dict.json, which is read back below.
    display_(path_trainmeta, base_train_dir)

    class_labels_path = "./class_dict.json"
    with open(class_labels_path, "r") as fh:
        class_labels = json.load(fh)
    num_classes = len(class_labels)
    print("Number of class : {}".format(num_classes))

    model = CLEFNetwork(num_classes)
    rand_data = torch.rand(5, 1, 128, 128)
    # Shape probe only; no autograd graph needed.
    with torch.no_grad():
        print(model(rand_data).shape)

    for name, param in model.named_parameters():
        print(f"{name} : {param.shape}, requires_grad : {param.requires_grad}")

    torchsummary.summary(model, (1, 128, 128), device="cpu")


main()

# Lightning Classifier

In [None]:
class clssifier(pl.LightningModule):
    """Lightning wrapper that trains ``model`` as a classifier with Adam.

    Logs ``train_loss``/``train_acc`` from each training step and ``val_loss``/``val_acc``
    from each validation step.

    Args:
        model: network mapping a batch of inputs to class logits (``(batch, classes)``).
        loss: criterion called as ``loss(logits, labels)``. Defaults to a new
            ``CrossEntropyLoss`` per instance.
        lr: Adam learning rate.
    """

    def __init__(
        self,
        model,
        loss=None,
        lr=1e-4
    ):
        super().__init__()
        self.model = model
        # Build the default criterion per instance rather than sharing one module
        # created once in the default argument.
        self.loss = CrossEntropyLoss() if loss is None else loss
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def accuracy_func(self, pred, true):
        """Return the number of rows whose argmax over ``dim=1`` equals ``true`` (0-d tensor)."""
        # Vectorised count; the previous Python sum() iterated the tensor element by element.
        return (torch.argmax(pred, dim=1) == true).sum()

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def _shared_step(self, batch, stage):
        """Compute loss and accuracy for one batch and log them as ``{stage}_loss``/``{stage}_acc``."""
        patch, label = batch
        output = self.model(patch)
        loss = self.loss(output, label)
        # Convert the correct-count into a fraction of the batch; guard empty batches.
        acc = self.accuracy_func(output, label).float() / max(label.shape[0], 1)
        self.log(f'{stage}_loss', loss, prog_bar=True)
        self.log(f'{stage}_acc', acc, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val')

# Let's get the party started!

In [None]:
def main():
    """Train CLEFNetwork with Lightning on a 75/25 train/validation split of the metadata."""
    seed_everything(config.seed)
    path_trainmeta = '../input/birdclef-2022/train_metadata.csv'
    base_train_dir = '../input/birdclef-2022/train_audio'
    # display_ also (re)writes class_dict.json, which is read back below.
    df = display_(path_trainmeta, base_train_dir)

    # NOTE(review): differs from config.train_batch_size (32); kept as in the original run.
    batch_size = 128
    data_module = pl_pipeline(base_pipe, df, batch_size)

    with open("./class_dict.json", "r") as fh:
        num_classes = len(json.load(fh))

    model = CLEFNetwork(num_classes)
    pl_classifier = clssifier(model, lr=config.learning_rate)
    # Follow config.device instead of hard-coding 'gpu', which fails on CPU-only machines.
    accelerator = 'gpu' if config.device == 'cuda' else 'cpu'
    trainer = pl.Trainer(accelerator=accelerator, max_epochs=config.epochs)
    trainer.fit(pl_classifier, data_module)


main()

# TODO
* Add a data augmentation block (GPU-compatible pipeline)
* Try different model architectures
* Read the discussion forum
* Analyze other PyTorch implementations
* Add wandb logging


# Reference

https://www.kaggle.com/code/utcarshagrawal/birdclef-audio-pytorch-tutorial

https://www.kaggle.com/code/sagnik1511/birdclef-2022-model-building-and-training#Neural-Network-Model

[pytorch audio tutorial](https://pytorch.org/audio/stable/tutorials/audio_io_tutorial.html)