In [None]:
import os
from tqdm.notebook import tqdm
import wandb
import random
import numpy as np
import pandas as pd
import warnings
import matplotlib.pyplot as plt
from IPython.display import Audio, display
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import librosa
from librosa.display import waveshow

import torch
import torchvision
import torchvision.models as models
import torchaudio
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset
import torchaudio.functional as F
import torchaudio.transforms as T

# Record library versions so results can be tied to a specific environment.
for _module in (torch, torchaudio, wandb):
    print(_module.__version__)

# Dark theme for every plotly figure in this notebook.
import plotly.io as pio
pio.templates.default = "plotly_dark"

# NOTE(review): this silences *all* warnings, including ones that flag real bugs.
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Central experiment configuration; also passed to wandb.init as the run config.
config = {
    'seed': 42,
    'use_wandb': True,

    'batch_size': 256,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'num_epochs': 15,

    # preprocessing
    'target_sample_rate': 32_000,
    'n_fft': 1024,
    'hop_length': 512,
    'n_mels': 64,
    # NOTE(review): not read via config anywhere in this notebook; BirdCLEFDataset
    # derives its clip length from target_sample_rate * duration_seconds instead.
    'num_samples': 22050,
    'duration_seconds': 7,
    'num_classes': 152,

    # model
    'lr': 0.003,
}

In [None]:
if config['use_wandb']:
    from kaggle_secrets import UserSecretsClient
    # Pass the key straight to wandb so the secret is not left behind in a
    # notebook global, where a later cell could display or log it.
    wandb.login(key=UserSecretsClient().get_secret("WANDB_API_KEY"))

In [None]:
def seed_everything(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's global RNG, and torch on the CPU and on
    all CUDA devices. Also configures cuDNN to use deterministic kernels.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python interpreters launched after this point; the running
    # interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # torch.cuda.manual_seed only seeds the current GPU; seed all of them.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different algorithms between runs; disable it.
    torch.backends.cudnn.benchmark = False

seed_everything(seed=config['seed'])

In [None]:
# Input locations: the raw competition data and the preprocessed feature dataset.
ROOT_DIR = os.path.abspath('../input/birdclef-2022')
RAW_TRAIN_DATA_DIR = os.path.join(ROOT_DIR, 'train_audio')
TRAIN_DATA_PATH = os.path.join(ROOT_DIR, 'train_metadata.csv')
PROCESSED_ROOT_DIR = os.path.abspath('../input/birdcleff2022mfcc7seconds')
# Derived from PROCESSED_ROOT_DIR instead of repeating the literal path.
# ('processesed' is kept exactly as spelled in the dataset's directory name.)
PROCESSED_TRAIN_DATA_DIR = os.path.join(PROCESSED_ROOT_DIR, 'birdcall_processesed', 'data')

In [None]:
# Training metadata, one row per recording; columns not needed here are dropped.
df = pd.read_csv(TRAIN_DATA_PATH).drop(
    columns=['scientific_name', 'common_name', 'author', 'license', 'rating', 'url', 'type'],
)
df.head(10)

In [None]:
# Map each species folder in train_audio to an integer class id, and back.
# sorted(): os.listdir order is arbitrary and filesystem-dependent, so without it
# the mapping could differ between machines or runs.
# NOTE(review): if the preprocessed feature files were labelled with a different
# mapping, confirm the two agree before using these dicts with them.
species_labels = sorted(os.listdir(RAW_TRAIN_DATA_DIR))
label_to_id = {label: i for i, label in enumerate(species_labels)}
id_to_label = {i: label for i, label in enumerate(species_labels)}

In [None]:
# To train on raw audio with BirdCLEFDataset, integer-encode the labels and add
# file paths (BirdCLEFDataset.__getitem__ reads both 'primary_label' and 'path'):
# df['primary_label'] = df['primary_label'].replace(label_to_id)
# df['path'] = df['filename'].apply(lambda x: os.path.join(RAW_TRAIN_DATA_DIR, x))
df.head(10)

In [None]:
class BirdCLEFDataset(Dataset):
    """Raw-audio BirdCLEF dataset that yields fixed-length, transformed clips.

    Each item is loaded from ``df['path']``, mixed down to mono, resampled to the
    target rate, cut or right-padded to ``target_sample_rate * duration_seconds``
    samples and finally passed through ``transforms`` (if given).

    Args:
        df: DataFrame with a ``path`` column and, for ``split='train'``, an
            integer-encoded ``primary_label`` column. Rows are addressed by
            position, so any index (e.g. one left over from a split) works.
        data_dir: Stored as ``self.data_dir``; not read by the methods below.
        split: ``'train'`` returns ``(features, label)``; any other value
            returns only the features.
        transforms: Optional callable applied to the fixed-length waveform. If it
            has a ``.to`` method (e.g. a torchaudio transform), it is moved to ``device``.
        target_sample_rate: Defaults to ``config['target_sample_rate']``.
        duration_seconds: Defaults to ``config['duration_seconds']``.
        device: Defaults to ``config['device']``.

    NOTE(review): waveforms are moved to ``device`` inside ``__getitem__``; with a
    CUDA device this is presumably incompatible with ``DataLoader(num_workers>0)``.
    """

    def __init__(self, df, data_dir, split='train', transforms=None,
                 target_sample_rate=None, duration_seconds=None, device=None):
        self.df = df
        self.data_dir = data_dir
        self.split = split
        self.device = config['device'] if device is None else device
        if target_sample_rate is None:
            target_sample_rate = config['target_sample_rate']
        if duration_seconds is None:
            duration_seconds = config['duration_seconds']
        self.target_sr = target_sample_rate
        self.num_samples = target_sample_rate * duration_seconds
        # transforms defaults to None, which has no .to(); only move real modules.
        if transforms is not None and hasattr(transforms, 'to'):
            transforms = transforms.to(self.device)
        self.transforms = transforms
        # Building a Resample kernel is costly, so keep one per source sample rate.
        self._resamplers = {}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        waveform, sample_rate = torchaudio.load(self.df['path'].iloc[idx])
        waveform = waveform.to(self.device)
        # Mix down before resampling. Both steps are linear, so the order does not
        # change the result mathematically, but resampling one channel is cheaper.
        waveform = self._mix_down(waveform)
        waveform = self._resample(waveform, sample_rate)
        waveform = self._cut_if_necessary(waveform)
        waveform = self._right_pad_if_necessary(waveform)
        if self.transforms is not None:
            waveform = self.transforms(waveform)

        if self.split == 'train':
            label = torch.tensor(self.df['primary_label'].iloc[idx])
            return waveform, label
        return waveform

    def _cut_if_necessary(self, signal):
        """Truncate a (channels, time) signal to at most ``num_samples`` samples."""
        if signal.shape[1] > self.num_samples:
            signal = signal[:, :self.num_samples]
        return signal

    def _right_pad_if_necessary(self, signal):
        """Zero-pad a (channels, time) signal on the right up to ``num_samples``."""
        if signal.shape[1] < self.num_samples:
            num_missing_samples = self.num_samples - signal.shape[1]
            # (left, right) padding applied to the last dimension.
            signal = torch.nn.functional.pad(signal, (0, num_missing_samples))
        return signal

    def _resample(self, waveform, sr):
        """Resample ``waveform`` from ``sr`` to ``self.target_sr`` (no-op if equal)."""
        if sr == self.target_sr:
            return waveform
        resampler = self._resamplers.get(sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(sr, self.target_sr).to(self.device)
            self._resamplers[sr] = resampler
        return resampler(waveform)

    def _mix_down(self, waveform):
        """Average all channels of a (channels, time) waveform into one mono channel."""
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform

class BirdclefMFCCDataset(Dataset):
    """Dataset over preprocessed feature arrays stored on disk.

    Args:
        df: DataFrame with a ``filepath`` column (paths readable by ``np.load``)
            and, for ``split='train'``, an integer ``label`` column. Rows are
            addressed by position, so the index does not need to be reset.
        split: ``'train'`` returns ``(features, label)``; any other value returns
            only the features, and ``label`` is then not required.
        transforms: Optional callable applied to the loaded tensor.
    """

    def __init__(self, df, split='train', transforms=None):
        self.df = df
        self.split = split
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        features = torch.from_numpy(np.load(self.df['filepath'].iloc[idx]))
        # `is not None` rather than truthiness: e.g. an empty nn.Sequential is falsy.
        if self.transforms is not None:
            features = self.transforms(features)

        if self.split == 'train':
            # Build the label tensor once; passing an existing tensor back into
            # torch.tensor() copies it and emits a UserWarning.
            label = torch.tensor(self.df['label'].iloc[idx])
            return features, label
        return features

In [None]:
# mel_spectrogram = torchaudio.transforms.MelSpectrogram(
#     sample_rate=config['target_sample_rate'],
#     n_fft=config['n_fft'],
#     hop_length=config['hop_length'],
#     n_mels=config['n_mels']
# )

# dataset = BirdCLEFDataset(df, RAW_TRAIN_DATA_DIR, 'train', transforms=mel_spectrogram)
# signal, label = dataset[np.random.randint(0, len(dataset))]

In [None]:
# librosa.display.specshow(signal[0].cpu().numpy(), sr=config['target_sample_rate'], hop_length=config['hop_length'])
# plt.xlabel("Time")
# plt.ylabel("Mel bin")
# plt.colorbar()
# plt.show()

In [None]:
# ResNet-50 adapted to single-channel spectrogram input. The variable keeps the
# name `resnet18` because later cells refer to it.
resnet18 = models.resnet50()
# Replace the stem so the network takes 1 input channel instead of RGB.
resnet18.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# ResNet-50's pooled feature size is 2048 (512 is ResNet-18's), so read it from
# the existing head instead of hard-coding it.
resnet18.fc = torch.nn.Linear(resnet18.fc.in_features, config['num_classes'])

# Smoke test on a dummy batch. Frame count matches a centred STFT over one clip.
# NOTE(review): assumes the preprocessed features are (1, n_mels, frames); confirm.
n_frames = 1 + config['target_sample_rate'] * config['duration_seconds'] // config['hop_length']
x = torch.randn(2, 1, config['n_mels'], n_frames)
# eval() + no_grad(): a train-mode forward pass would update BatchNorm running
# statistics with random data before training even starts.
resnet18.eval()
with torch.no_grad():
    output = resnet18(x)
resnet18.train()
assert output.shape == (x.shape[0], config['num_classes'])
print(output.shape)

In [None]:
# Start a Weights & Biases run that records the full config dict.
if config['use_wandb']:
    wandb_init_kwargs = {
        'project': "BirdCLEF 2022",
        'entity': "raghavprabhakar",
        'config': config,
    }
    wandb.init(**wandb_init_kwargs)

In [None]:
# Runs one epoch of training or evaluation.
def fit(model, dataset, dataloader, optim, criterion, mode='train', device=None):
    """Run one epoch over ``dataloader`` and return per-sample loss and accuracy.

    Args:
        model: Module mapping a batch of inputs to class logits.
        dataset: Kept for backward compatibility; statistics are normalised by
            the number of samples actually seen instead.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        optim: Optimizer stepped once per batch when ``mode == 'train'``.
        criterion: Loss function. Assumed to return the batch *mean* (the
            ``nn.CrossEntropyLoss`` default used in this notebook).
        mode: ``'train'`` updates the model; any other value only evaluates.
        device: Target device for each batch; defaults to ``config['device']``.

    Returns:
        ``(epoch_loss, epoch_acc)``: mean per-sample loss and accuracy.
    """
    if device is None:
        device = config['device']
    is_train = mode == 'train'
    model.train(is_train)

    running_loss = 0.0
    running_corrects = 0
    n_seen = 0

    # len(dataloader) counts the final partial batch, unlike len(dataset) // batch_size.
    tqdm_loop = tqdm(dataloader, total=len(dataloader), desc=mode, leave=True)

    # Track gradients only when training, regardless of the caller's grad mode.
    with torch.set_grad_enabled(is_train):
        for images, labels in tqdm_loop:
            images, labels = images.to(device), labels.to(device)

            if is_train:
                optim.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)

            if is_train:
                loss.backward()
                optim.step()

            batch_size = labels.size(0)
            # criterion returns the batch mean; weight it by the batch size so the
            # epoch value is a true per-sample mean.
            running_loss += loss.item() * batch_size
            running_corrects += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += batch_size
            tqdm_loop.set_postfix(loss=running_loss / n_seen, acc=running_corrects / n_seen)

    # Guard against an empty loader.
    n_seen = max(n_seen, 1)
    epoch_loss = running_loss / n_seen
    epoch_acc = running_corrects / n_seen
    return epoch_loss, epoch_acc

In [None]:
# Cross-entropy over the class logits; Adam over every model parameter.
loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.Adam(params=resnet18.parameters(), lr=config['lr'])

In [None]:
from sklearn.model_selection import train_test_split

# One row per preprocessed feature file. The class id is the integer after the
# last underscore in the file name (e.g. 'xxx_12.npy' -> 12).
# sorted() makes the row order, and therefore the seeded split, independent of
# filesystem listing order. A separate name keeps the metadata `df` intact.
mfcc_files = sorted(os.listdir(PROCESSED_TRAIN_DATA_DIR))
mfcc_df = pd.DataFrame({
    'filepath': [os.path.join(PROCESSED_TRAIN_DATA_DIR, f) for f in mfcc_files],
    'label': [int(f.split('_')[-1].split('.')[0]) for f in mfcc_files],
})

train_df, valid_df = train_test_split(mfcc_df, test_size=0.2, random_state=config['seed'])
# drop=True discards the shuffled index instead of keeping it as an extra column.
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)
print(train_df.shape, valid_df.shape)

In [None]:
# Datasets over the train / validation splits of the preprocessed features.
train_dataset = BirdclefMFCCDataset(train_df)
valid_dataset = BirdclefMFCCDataset(valid_df)

# Shuffle only the training data; validation order does not affect the metrics.
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=False)

In [None]:
# Per-epoch metric history, stored in the final checkpoint.
train_loss, valid_loss = [], []
train_acc, valid_acc = [], []
best_acc = 0
BEST_MODEL_PATH = 'resnet50_best.pt'
LAST_CHECKPOINT_PATH = 'resnet50_last.pt'

# Move the model once. nn.Module.to() works in place, so the optimizer created
# earlier still refers to the same parameter objects.
resnet18.to(config['device'])

# Run train loop for given number of epochs
for epoch in range(config['num_epochs']):
    print(f'Epoch: {epoch+1} / {config["num_epochs"]}')
    print('-' * 10)

    # Train model one epoch and display statistics
    train_epoch_loss, train_epoch_acc = fit(resnet18, train_dataset, train_dataloader, optim, loss_fn, mode='train')
    train_loss.append(train_epoch_loss)
    train_acc.append(train_epoch_acc)
    print(f'Train Loss: {train_epoch_loss:.4f} | Train Acc: {train_epoch_acc:.4f}')

    # Run validation for one epoch (no gradients needed) and display statistics
    with torch.no_grad():
        valid_epoch_loss, valid_epoch_acc = fit(resnet18, valid_dataset, valid_dataloader, optim, loss_fn, mode='valid')
    valid_loss.append(valid_epoch_loss)
    valid_acc.append(valid_epoch_acc)
    print(f'Valid Loss: {valid_epoch_loss:.4f} | Valid Acc: {valid_epoch_acc:.4f}')

    if config['use_wandb']:
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_epoch_loss,
            "valid_loss": valid_epoch_loss,
            "train_acc": train_epoch_acc,
            "valid_acc": valid_epoch_acc,
        })

    # Keep the weights with the best validation accuracy (ties favour the later epoch).
    if valid_epoch_acc >= best_acc:
        print(f'Model improved from {best_acc} to {valid_epoch_acc}, Saving best model...')
        torch.save(resnet18.state_dict(), BEST_MODEL_PATH)
        best_acc = valid_epoch_acc

# Save the final model, optimizer state and all metrics in one checkpoint.
checkpoint = {
    'total_epochs': config['num_epochs'],
    'state_dict': resnet18.state_dict(),
    'optimizer': optim.state_dict(),
    'train_loss': train_loss,
    'train_acc': train_acc,
    'val_loss': valid_loss,
    'val_acc': valid_acc,
    'best_val_acc': best_acc,
}
torch.save(checkpoint, LAST_CHECKPOINT_PATH)
print("Model Saved")