# This notebook uses pre-computed mel spectrograms. I created the mel spectrograms in the following notebook:
>https://www.kaggle.com/afiaibnath/create-melspec-faster-esc-50

In [None]:
import IPython.display as display

import glob
from collections import Counter

import math

import librosa
import librosa.display
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import torch
import torchaudio
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from pathlib import Path
from PIL import Image
import soundfile as sf
from torch.utils.data import Dataset
from torchvision import models, transforms

In [None]:
import sys
sys.path.append("../input/timmeffnetv2")

import timm

In [None]:
# Show the first entry listed in the training spectrogram directory.
os.listdir('../input/create-melspec-faster-esc-50/train')[:1]

In [None]:
# Directories holding the pre-computed spectrogram images, one per split
# (train / valid / test). Each path ends with a trailing slash.
PATH_ESC50_TRAIN="../input/create-melspec-faster-esc-50/train/"
PATH_ESC50_VALID="../input/create-melspec-faster-esc-50/valid/"
PATH_ESC50_TEST="../input/create-melspec-faster-esc-50/test/"

In [None]:
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device="cpu"):
    """Train ``model`` and print validation metrics after every epoch.

    Each epoch performs one optimisation pass over ``train_loader`` followed
    by one evaluation pass over ``val_loader``. After the evaluation pass a
    summary line with the mean training loss, the mean validation loss and
    the validation accuracy is printed.

    Args:
        model: Module mapping an input batch to per-class scores of shape
            (batch, classes). It is not moved by this function, so it is
            expected to already live on ``device``.
        optimizer: Optimizer bound to the parameters of ``model``.
        loss_fn: Callable ``loss_fn(output, targets)`` returning a scalar
            tensor (treated as the mean loss of the batch).
        train_loader: Iterable of ``(inputs, targets)`` batches that exposes
            a sized ``dataset`` attribute.
        val_loader: Same contract as ``train_loader``; used for evaluation.
        epochs: Number of epochs to run.
        device: Device every batch is moved to before the forward pass.

    Returns:
        None. Metrics are only printed.
    """
    for epoch in range(1, epochs + 1):
        training_loss = 0.0
        valid_loss = 0.0

        model.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            # Weight the batch mean by the batch size so the epoch value is
            # a per-sample mean even when the last batch is smaller.
            training_loss += loss.item() * inputs.size(0)
        training_loss /= len(train_loader.dataset)

        model.eval()
        num_correct = 0
        num_examples = 0
        # Evaluation needs no gradients: skipping graph construction saves
        # memory and time during the validation pass.
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                output = model(inputs)
                loss = loss_fn(output, targets)
                valid_loss += loss.item() * inputs.size(0)
                # Softmax is monotonic, so the arg-max of the raw scores is
                # already the predicted class.
                predictions = output.argmax(dim=1)
                correct = torch.eq(predictions, targets).view(-1)
                num_correct += correct.sum().item()
                num_examples += correct.shape[0]
        valid_loss /= len(val_loader.dataset)

        # A loader can yield no batches (e.g. drop_last on a tiny split);
        # report NaN instead of dividing by zero.
        accuracy = num_correct / num_examples if num_examples else float("nan")

        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss,
        valid_loss, accuracy))
        
def find_lr(model, loss_fn, optimizer, train_loader, init_value=1e-8, final_value=10.0, device="cpu"):
    """Run a learning-rate range test over a single pass of ``train_loader``.

    The learning rate of the optimizer's first parameter group starts at
    ``init_value`` and is multiplied by a constant factor after every batch
    so that it reaches ``final_value`` on the last batch. The model is
    trained for real while this happens, so its weights are modified.

    Args:
        model: Module to probe; expected to already live on ``device``.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar
            tensor.
        optimizer: Optimizer bound to ``model``. Only ``param_groups[0]`` has
            its learning rate swept; it is left at the last value used.
        train_loader: Sized iterable of ``(inputs, targets)`` batches.
        init_value: Learning rate used for the first batch.
        final_value: Learning rate reached on the last batch.
        device: Device every batch is moved to.

    Returns:
        Tuple ``(log_lrs, losses)`` of equally long lists: the base-10 log of
        the learning rate used for each recorded batch and the loss observed
        on it. When more than 20 points were recorded, the first 10 and the
        last 5 are dropped.
    """
    # Use at least one step so a single-batch loader does not divide by zero.
    number_in_epoch = max(len(train_loader) - 1, 1)
    update_step = (final_value / init_value) ** (1 / number_in_epoch)
    lr = init_value
    optimizer.param_groups[0]["lr"] = lr
    best_loss = 0.0
    losses = []
    log_lrs = []
    for batch_num, (inputs, targets) in enumerate(train_loader, start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        # Work with a plain float so no reference to the autograd graph is
        # kept in ``best_loss`` and comparisons are ordinary Python ones.
        current_loss = loss.item()

        # Stop as soon as the loss explodes or stops being a finite number
        # (a NaN would otherwise never satisfy the "> 4 * best" comparison).
        exploded = batch_num > 1 and current_loss > 4 * best_loss
        if exploded or not math.isfinite(current_loss):
            break

        # Record the best loss
        if batch_num == 1 or current_loss < best_loss:
            best_loss = current_loss

        # Store the values
        losses.append(current_loss)
        log_lrs.append(math.log10(lr))

        # Do the backward pass and optimize
        loss.backward()
        optimizer.step()

        # Update the lr for the next step
        lr *= update_step
        optimizer.param_groups[0]["lr"] = lr

    # Drop the noisy head and tail of the curve when enough points exist.
    if len(log_lrs) > 20:
        return log_lrs[10:-5], losses[10:-5]
    return log_lrs, losses

In [None]:
class FrequencyMask(object):
    """Overwrite a random horizontal band (rows on dim 1) of a tensor image.

      Example:
        >>> transforms.Compose([
        >>>     transforms.ToTensor(),
        >>>     FrequencyMask(max_width=10, use_mean=False),
        >>> ])

    Args:
        max_width (int): Largest number of rows one mask may cover
            (inclusive). Must be at least 1.
        use_mean (bool): Fill the band with the mean of the whole tensor
            when True, with zeros otherwise.

    Raises:
        ValueError: If ``max_width`` is smaller than 1.
    """

    def __init__(self, max_width, use_mean=True):
        if max_width < 1:
            raise ValueError("max_width must be at least 1, got {}".format(max_width))
        self.max_width = max_width
        self.use_mean = use_mean

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of
            size (C, H, W) where the frequency
            mask is to be applied. It is modified in place.

        Returns:
            Tensor: Transformed image with Frequency Mask.
        """
        # The band runs along dim 1, so the start row must be drawn from the
        # size of that same dimension (not from the width).
        start = random.randrange(0, tensor.shape[1])
        # Width is drawn from 1..max_width; a band running past the last row
        # is clipped by the slice.
        end = start + random.randint(1, self.max_width)
        # The mean is taken before the assignment, i.e. over unmasked data.
        fill_value = tensor.mean() if self.use_mean else 0
        tensor[:, start:end, :] = fill_value
        return tensor

    def __repr__(self):
        return "{}(max_width={}, use_mean={})".format(
            self.__class__.__name__, self.max_width, self.use_mean)

In [None]:
class TimeMask(object):
    """Overwrite a random vertical band (columns on dim 2) of a tensor image.

      Example:
        >>> transforms.Compose([
        >>>     transforms.ToTensor(),
        >>>     TimeMask(max_width=10, use_mean=False),
        >>> ])

    Args:
        max_width (int): Largest number of columns one mask may cover
            (inclusive). Must be at least 1.
        use_mean (bool): Fill the band with the mean of the whole tensor
            when True, with zeros otherwise.

    Raises:
        ValueError: If ``max_width`` is smaller than 1.
    """

    def __init__(self, max_width, use_mean=True):
        if max_width < 1:
            raise ValueError("max_width must be at least 1, got {}".format(max_width))
        self.max_width = max_width
        self.use_mean = use_mean

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of
            size (C, H, W) where the time mask
            is to be applied. It is modified in place.

        Returns:
            Tensor: Transformed image with Time Mask.
        """
        # The band runs along dim 2, so the start column must be drawn from
        # the size of that same dimension (not from the height).
        start = random.randrange(0, tensor.shape[2])
        # Width is drawn from 1..max_width; a band running past the last
        # column is clipped by the slice.
        end = start + random.randint(1, self.max_width)
        # The mean is taken before the assignment, i.e. over unmasked data.
        fill_value = tensor.mean() if self.use_mean else 0
        tensor[:, :, start:end] = fill_value
        return tensor

    def __repr__(self):
        return "{}(max_width={}, use_mean={})".format(
            self.__class__.__name__, self.max_width, self.use_mean)

In [None]:
class PrecomputedESC50(Dataset):
    """Dataset of pre-computed spectrogram PNG images with integer labels.

    Every ``*.png`` file directly inside ``path`` becomes one item. The label
    is parsed from the file name: the text after the last ``-`` with the
    ``.wav.png`` suffix removed, converted to ``int``.

    Args:
        path: Directory containing the PNG files.
        max_freqmask_width: ``max_width`` passed to ``FrequencyMask``.
        max_timemask_width: ``max_width`` passed to ``TimeMask``.
        use_mean: ``use_mean`` flag passed to both masks.
        dpi: Accepted for backward compatibility; not used.
        augment: When True (default) each mask is applied with probability
            0.5 on every access. Pass False to get only the tensor
            conversion and normalisation, e.g. for evaluation splits.

    Raises:
        ValueError: If a file name does not end in an integer label as
            described above.
    """

    def __init__(self, path, max_freqmask_width, max_timemask_width, use_mean=True, dpi=50, augment=True):
        # glob() yields files in arbitrary order; sort so that a given index
        # always refers to the same file.
        files = sorted(Path(path).glob('*.png'))
        self.items = [(f, int(f.name.split("-")[-1].replace(".wav.png", ""))) for f in files]
        self.length = len(self.items)
        self.max_freqmask_width = max_freqmask_width
        self.max_timemask_width = max_timemask_width
        self.use_mean = use_mean
        self.augment = augment

        steps = [
            transforms.ToTensor(),
            # Fixed per-channel mean / std used to normalise the RGB image.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        if augment:
            steps.append(transforms.RandomApply(
                [FrequencyMask(self.max_freqmask_width, self.use_mean)], p=0.5))
            steps.append(transforms.RandomApply(
                [TimeMask(self.max_timemask_width, self.use_mean)], p=0.5))
        self.img_transforms = transforms.Compose(steps)

    def __getitem__(self, index):
        """Return ``(transformed_image_tensor, label)`` for ``index``."""
        filename, label = self.items[index]
        # The context manager closes the file handle once the pixel data
        # has been copied into the converted RGB image.
        with Image.open(filename) as img:
            rgb = img.convert('RGB')
        return (self.img_transforms(rgb), label)

    def __len__(self):
        """Return the number of PNG files found at construction time."""
        return self.length

In [None]:
# Alternative backbone that was tried: models.resnet50(pretrained=True)
# The model used here is timm's "tf_efficientnetv2_b3", created with
# pretrained=True.
spec_effnet = timm.create_model("tf_efficientnetv2_b3", pretrained=True)

# Freeze every parameter whose name does not contain "bn" (presumably this
# leaves only the batch-norm parameters trainable -- verify against the
# model's parameter names).
for name, param in spec_effnet.named_parameters():
    if "bn" in name:
        continue
    param.requires_grad = False

# Swap in a new classification head with 50 outputs. It is created after
# the freeze loop above, so its parameters keep requires_grad=True.
spec_effnet.classifier = nn.Sequential(
    nn.Linear(spec_effnet.classifier.in_features, 500),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(500, 50),
)

In [None]:
#spec_effnet

In [None]:
# Batch size shared by all three loaders.
bs = 16

# One dataset and one shuffled loader per split; all three use the same
# mask widths.
esc50pre_train = PrecomputedESC50(
    PATH_ESC50_TRAIN, max_freqmask_width=10, max_timemask_width=10)
esc50pre_valid = PrecomputedESC50(
    PATH_ESC50_VALID, max_freqmask_width=10, max_timemask_width=10)
esc50pre_test = PrecomputedESC50(
    PATH_ESC50_TEST, max_freqmask_width=10, max_timemask_width=10)

esc50_train_loader = torch.utils.data.DataLoader(
    dataset=esc50pre_train, batch_size=bs, shuffle=True)
esc50_val_loader = torch.utils.data.DataLoader(
    dataset=esc50pre_valid, batch_size=bs, shuffle=True)
esc50_test_loader = torch.utils.data.DataLoader(
    dataset=esc50pre_test, batch_size=bs, shuffle=True)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
lr = 1e-2
loss_fn = nn.CrossEntropyLoss()

spec_effnet.to(device)
# Snapshot the weights before the sweep: find_lr trains the model while it
# raises the learning rate, so the weights are changed by the call below.
torch.save(spec_effnet.state_dict(), "spec_effnet.pth")

optimizer = optim.Adam(params=spec_effnet.parameters(), lr=lr)
logs, losses = find_lr(
    model=spec_effnet,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_loader=esc50_train_loader,
    device=device,
)
# x axis: log10 of the learning rate, y axis: loss recorded at that rate.
plt.plot(logs, losses)

In [None]:
# Position of the smallest recorded loss in the sweep.
idx = np.asarray(losses).argmin()

In [None]:
# log10 of the learning rate at which the smallest loss was recorded
# (logs and losses are index-aligned lists returned by find_lr).
(logs[idx])

In [None]:
import gc

# Release cached CUDA memory, then drop the notebook names that are no
# longer needed. The three loaders still reference their datasets, so only
# the names are removed here.
torch.cuda.empty_cache()
del esc50pre_test
del esc50pre_train
del esc50pre_valid
del logs
del losses
gc.collect()

In [None]:
# Restore the weights saved before the learning-rate sweep.
spec_effnet.load_state_dict(torch.load("spec_effnet.pth"))

# One parameter group per top-level sub-module. Groups without an explicit
# 'lr' use the optimizer default given below (1e-2).
# NOTE(review): the replaced classifier head gets the smallest learning rate
# (1e-8) of all groups -- confirm this is intended.
optimizer = optim.Adam(
    [
        dict(params=spec_effnet.conv_stem.parameters()),
        dict(params=spec_effnet.bn1.parameters()),
        dict(params=spec_effnet.act1.parameters()),
        dict(params=spec_effnet.blocks.parameters(), lr=1e-4),
        dict(params=spec_effnet.conv_head.parameters(), lr=1e-4),
        dict(params=spec_effnet.bn2.parameters(), lr=1e-4),
        dict(params=spec_effnet.act2.parameters(), lr=1e-4),
        dict(params=spec_effnet.global_pool.parameters(), lr=1e-4),
        dict(params=spec_effnet.classifier.parameters(), lr=1e-8),
    ],
    lr=1e-2,
)

# Stage 1: 10 epochs with the requires_grad flags as they currently are.
train(
    model=spec_effnet,
    optimizer=optimizer,
    loss_fn=nn.CrossEntropyLoss(),
    train_loader=esc50_train_loader,
    val_loader=esc50_val_loader,
    epochs=10,
    device=device,
)

# Stage 2: make every parameter trainable and continue for 30 more epochs
# with the same optimizer.
for param in spec_effnet.parameters():
    param.requires_grad = True

#torch.cuda.empty_cache()

train(
    model=spec_effnet,
    optimizer=optimizer,
    loss_fn=nn.CrossEntropyLoss(),
    train_loader=esc50_train_loader,
    val_loader=esc50_val_loader,
    epochs=30,
    device=device,
)