# <font color='brown' size=4>Objective:</font> 
        
<p> Audio competitions are a joy to work with, as they offer plenty of opportunities to learn. This notebook is an overview of one of the state-of-the-art models for audio representation, wav2vec. It learns audio representations as part of its pretraining process, and the learned representations can then be used on downstream tasks, similar to language models in the NLP world.</p>

Note:
<p>We will start with a few basics of audio and gradually move on to recent advances in audio representations.</p>

<font color='brown' size=4>Buckle up</font><br>

In [None]:
!wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_large.pt
!pip install soundfile 

!pip -q install timm
!pip -q install torchlibrosa
!pip -q install audiomentations
!pip install -q transformers==4.3.0

!pip uninstall -y typing
# GitHub no longer serves the unauthenticated git:// protocol, so install fairseq over https
!pip install soundfile git+https://github.com/pytorch/fairseq.git@b8ea8a9b72c82192da07e3377adf4ebbde16716d
    

In [None]:
import os, glob, random, sys, time
from functools import partial

import numpy as np, pandas as pd
import soundfile as sf
import librosa
import librosa.display
import matplotlib.pyplot as plt
from tqdm import tqdm

from sklearn import metrics
from sklearn.metrics import log_loss
from sklearn.model_selection import StratifiedKFold

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d
from torch.utils.data import Dataset
from albumentations.pytorch.functional import img_to_tensor

import timm
from timm.models.efficientnet import tf_efficientnet_b4_ns, tf_efficientnet_b3_ns, \
    tf_efficientnet_b5_ns, tf_efficientnet_b2_ns, tf_efficientnet_b6_ns, tf_efficientnet_b7_ns, tf_efficientnet_b0_ns

from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
import audiomentations as A

import fairseq
import wandb
from transformers import get_linear_schedule_with_warmup, Wav2Vec2Tokenizer, Wav2Vec2Model

# helper module shipped as a Kaggle dataset (mixup / interpolation utilities)
sys.path.append('/kaggle/input/pytorch-utils/')
from pytorch_utils import do_mixup, interpolate, pad_framewise_output


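# keep the W&B API key out of the notebook: store it as a Kaggle secret and read it at runtime
from kaggle_secrets import UserSecretsClient
os.environ["WANDB_API_KEY"] = UserSecretsClient().get_secret("WANDB_API_KEY")  # use the label you gave your secret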

#wandb.login()

# Table of Contents

- 1. Audio representations
   - 1.1 Fourier transform
   - 1.2 Spectrograms
   - 1.3 Mel spectrograms

- 2. Wav2vec overview
   - 2.1 Architecture
   - 2.2 Pipeline view
      
- 3. Wav2vec family

- 4. Modeling based on wav2vec representation
   - 4.1 Preparing dataset
   - 4.2 Dataset class
   - 4.3 Model definition
   - 4.4 Utils
   - 4.5 Engine
   - 4.6 Fold run
  
- 5. Acknowledgements

# <font color='brown' size=4>1. Audio representations</font>

<p>A signal is a variation in a certain quantity over time. For audio, the quantity that varies is air pressure. How do we capture this information digitally? We can take samples of the air pressure over time. The rate at which we sample the data can vary, but is most commonly 44.1kHz, or 44,100 samples per second. What we have captured is a waveform for the signal, and this can be interpreted, modified, and analyzed with computer software.</p>
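A tiny illustration of sampling, assuming a synthetic 440 Hz tone rather than a competition recording: half a second of audio at 44.1 kHz is just an array of 22,050 numbers, and dividing the array length by the sample rate recovers the duration.

```python
import numpy as np

sr = 44_100          # samples per second
duration_s = 0.5     # seconds of audio to generate
freq_hz = 440.0      # pitch of the illustrative test tone

t = np.arange(int(sr * duration_s)) / sr           # time of each sample, in seconds
waveform = 0.5 * np.sin(2 * np.pi * freq_hz * t)   # sampled "air pressure", amplitude 0.5

print(len(waveform))        # 22050 samples
print(len(waveform) / sr)   # 0.5 seconds
```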

In [None]:
test_file='../input/rfcx-species-audio-detection/train/00204008d.flac'
y, sr = librosa.load(test_file)  # librosa resamples to its default 22,050 Hz; pass sr=None to keep the native rate
plt.plot(y);
plt.title('Signal');
plt.xlabel('Time (samples)');
plt.ylabel('Amplitude');

## <font color='brown' size=4>1.1 Fourier transform</font>

<p>The Fourier transform is a mathematical transform that decomposes a signal into its individual frequencies and the amplitude of each frequency. In other words, it converts the signal from the time domain into the frequency domain. The result is called a spectrum.

The fast Fourier transform (FFT) is an algorithm that computes the (discrete) Fourier transform efficiently. It is widely used in signal processing.</p>

<img src='https://miro.medium.com/max/700/1*xTYCtcx_7otHVu-uToI9dA.png' width=1000>
<div align="center"><font size="3">Source: Google</font></div>
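A quick numerical illustration using NumPy's FFT (`np.fft.rfft` returns only the non-negative-frequency bins of a real signal). The test signal, a mix of a 50 Hz sine and a weaker 120 Hz sine sampled for exactly one second so that each bin is 1 Hz wide, is an illustrative assumption, not competition data. The two largest bins land on the two component frequencies, and rescaling by `2 / N` recovers their amplitudes.

```python
import numpy as np

sr = 1_000                                       # sample rate in Hz (illustrative)
t = np.arange(sr) / sr                           # exactly 1 second -> 1 Hz per FFT bin
signal = np.sin(2 * np.pi * 50 * t) + 0.5 * np.sin(2 * np.pi * 120 * t)

spectrum = np.abs(np.fft.rfft(signal))           # magnitude of each non-negative frequency bin
freqs = np.fft.rfftfreq(len(signal), d=1 / sr)   # centre frequency of each bin in Hz

peak_freqs = sorted(freqs[np.argsort(spectrum)[-2:]].tolist())
amplitudes = 2 * spectrum[[50, 120]] / len(signal)   # bins 50 and 120 are 50 Hz and 120 Hz here

print(peak_freqs)   # [50.0, 120.0]
print(amplitudes)   # approximately [1.0, 0.5]
```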

In [None]:
n_fft = 2048
ft = np.abs(librosa.stft(y[:n_fft], hop_length = n_fft+1))
plt.plot(ft);
plt.title('Spectrum');
plt.xlabel('Frequency Bin');
plt.ylabel('Amplitude');

## <font color='brown' size=4>1.2 Spectrograms</font>

The fast Fourier transform is a powerful tool that allows us to analyze the frequency content of a signal, but what if our signal’s frequency content varies over time?

When the FFT is computed on overlapping, windowed segments of the signal, the result is called a spectrogram: one spectrum per segment, showing how the frequency content changes over time.

<img src='https://miro.medium.com/max/700/1*tIBRdtG3EfjmSIlraWVIxw.png' width=1000>
<div align="center"><font size="3">Source: Google</font></div>
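The windowed-FFT idea can be sketched directly in NumPy: slice the signal into overlapping frames, multiply each by a Hann window, and take the FFT of each frame. This is a simplified sketch, not librosa's implementation: `librosa.stft` additionally pads the signal so frames are centred (`center=True`) and uses a periodic Hann window, so its output has a few more frames and slightly different values. The test signal, whose pitch jumps from 440 Hz to 1000 Hz halfway through, is illustrative.

```python
import numpy as np

def stft_magnitude(signal, n_fft=2048, hop_length=512):
    """Magnitude STFT: FFT of overlapping Hann-windowed frames (no centre padding)."""
    window = np.hanning(n_fft)
    n_frames = 1 + (len(signal) - n_fft) // hop_length
    frames = np.stack(
        [signal[i * hop_length : i * hop_length + n_fft] * window for i in range(n_frames)],
        axis=1,
    )                                               # shape: (n_fft, n_frames)
    return np.abs(np.fft.rfft(frames, axis=0))      # shape: (n_fft // 2 + 1, n_frames)

# 1 second of audio whose pitch jumps from 440 Hz to 1000 Hz at t = 0.5 s
sr = 22_050
t = np.arange(sr) / sr
signal = np.where(t < 0.5, np.sin(2 * np.pi * 440 * t), np.sin(2 * np.pi * 1000 * t))

spec = stft_magnitude(signal)
bin_hz = sr / 2048                       # width of one frequency bin (~10.8 Hz)
print(spec.shape)                        # (1025, 40)
print(spec[:, 0].argmax() * bin_hz)      # peak near 440 Hz in the first frame
print(spec[:, -1].argmax() * bin_hz)     # peak near 1000 Hz in the last frame
```

Each column of `spec` is one short-time spectrum, which is exactly what the spectrogram plot below stacks side by side.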

In [None]:
spec = np.abs(librosa.stft(y, hop_length=512))
spec = librosa.amplitude_to_db(spec, ref=np.max)
librosa.display.specshow(spec, sr=sr, x_axis='time', y_axis='log');
plt.colorbar(format='%+2.0f dB');
plt.title('Spectrogram');

## <font color='brown' size=4>1.3 Mel spectrograms</font>

<p> Studies have shown that humans do not perceive frequencies on a linear scale. We are better at detecting differences in lower frequencies than in higher frequencies. For example, we can easily tell the difference between 500 and 1000 Hz, but we will hardly be able to tell the difference between 10,000 and 10,500 Hz, even though the distance between the two frequencies in each pair is the same (500 Hz).</p>

<p>A mel spectrogram is a spectrogram where the frequencies are converted to the mel scale.</p>
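One common way to express this compression is the HTK-style mel formula $m = 2595 \log_{10}(1 + f/700)$. A minimal sketch applying it to the two frequency pairs from the paragraph above follows. Note that this is only one of several mel variants: librosa's `hz_to_mel` and `melspectrogram` default to the Slaney formulation (`htk=False`), so librosa's mel values differ from these numbers, although the qualitative effect is the same.

```python
import numpy as np

def hz_to_mel(f_hz):
    """HTK-style mel scale: m = 2595 * log10(1 + f / 700)."""
    return 2595.0 * np.log10(1.0 + np.asarray(f_hz, dtype=float) / 700.0)

low_gap = hz_to_mel(1000) - hz_to_mel(500)         # ~393 mel for a 500 Hz step at low frequency
high_gap = hz_to_mel(10_500) - hz_to_mel(10_000)   # ~51 mel for the same 500 Hz step up high
print(round(float(low_gap)), round(float(high_gap)))
```

The same 500 Hz step covers roughly seven to eight times more of the mel axis at low frequencies, which is why a mel spectrogram devotes more rows to the low end.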

In [None]:
mel_spect = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=2048, hop_length=1024, fmax=8000)
mel_spect = librosa.power_to_db(mel_spect, ref=np.max)
# pass the same sr, hop_length and fmax used above so the time and mel axes are labelled correctly
librosa.display.specshow(mel_spect, sr=sr, hop_length=1024, y_axis='mel', fmax=8000, x_axis='time');
plt.title('Mel Spectrogram');
plt.colorbar(format='%+2.0f dB');

<p>Now that we know the basics of traditional signal processing, we can look at recent advances in ASR (automatic speech recognition) models, one of which is wav2vec, released by Facebook AI.</p>

## <font color='brown' size=4>2. Wav2vec overview</font>

<p>Wav2vec uses self-supervised learning to learn vector representations of sound frames directly from the raw audio waveform. This is similar to how word2vec learns word embeddings from a text corpus. Wav2vec samples random parts of the sound file and learns to predict whether a given part lies in the near future of the current offset position. This is somewhat similar to the masked-word task used to train transformers such as BERT.</p>

<p>The algorithm is trained on unlabeled data since it uses the temporal structure of the data to produce labels and it uses random sampling to produce contrasting negative examples. It is a binary classification task (is the proposed processed sound frame in the near future of the current offset or not) and it uses contrastive loss as its objective. In training for this binary classification task, it learns vector representations of sound frames (one 512 dim vector for each 10ms of sound).</p> 
    

**These vector representations are useful features because they concentrate information relevant to predicting speech. These vectors can then be used instead of spectrogram vectors**

<img src='https://d3i71xaburhd42.cloudfront.net/96bd1cd9b37cc9eea6ecc1b46afc29f95a10d424/2-Figure1-1.png' width=1000>
<div align="center"><font size="3">Source: Google</font></div>
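Following the wav2vec paper, the objective described above can be written per prediction step $k$ as a binary cross-entropy between the true future latent $\mathbf{z}_{i+k}$ and sampled distractors $\tilde{\mathbf{z}}$, both scored against a step-specific affine projection $h_k(\mathbf{c}_i)$ of the context vector:

$$\mathcal{L}_k = -\sum_{i=1}^{T-k}\Big(\log\sigma\big(\mathbf{z}_{i+k}^\top h_k(\mathbf{c}_i)\big) + \lambda\,\mathbb{E}_{\tilde{\mathbf{z}}\sim p_n}\big[\log\sigma\big(-\tilde{\mathbf{z}}^\top h_k(\mathbf{c}_i)\big)\big]\Big)$$

where $\sigma$ is the logistic sigmoid and $p_n$ the distractor distribution; the total loss sums $\mathcal{L}_k$ over the prediction steps. Below is a minimal NumPy sketch of a single step on random toy tensors. The uniform distractor sampling over the other frames of the same sequence, `n_negatives=10`, `lam=1.0`, and averaging rather than summing over time steps are assumptions of this sketch; the real implementation lives in fairseq.

```python
import numpy as np

def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    return -np.logaddexp(0.0, -x)

def contrastive_step_loss(z, c_proj, k, n_negatives=10, lam=1.0, seed=0):
    """Binary contrastive loss for one prediction step k (wav2vec-style sketch).

    z      : (T, D) latent frames from the encoder
    c_proj : (T, D) step-k predictions h_k(c_t), one per time step
    """
    rng = np.random.default_rng(seed)
    T = z.shape[0]
    total = 0.0
    for t in range(T - k):
        pos_score = z[t + k] @ c_proj[t]                   # score of the true future frame
        candidates = np.delete(np.arange(T), t + k)        # distractors: any other frame
        neg_idx = rng.choice(candidates, size=n_negatives)
        neg_scores = z[neg_idx] @ c_proj[t]
        total -= log_sigmoid(pos_score) + lam * np.mean(log_sigmoid(-neg_scores))
    return total / (T - k)

# Toy check: predictions that point at the true future frame should get a lower loss
rng = np.random.default_rng(42)
T, D, k = 50, 16, 2
z = rng.standard_normal((T, D))
good = np.zeros_like(z)
good[: T - k] = z[k:]                        # h_k(c_t) = z_{t+k}: a "perfect" predictor
random_pred = rng.standard_normal((T, D))    # predictions unrelated to the future

loss_good = contrastive_step_loss(z, good, k)
loss_random = contrastive_step_loss(z, random_pred, k)
print(loss_good, loss_random)
```

Minimising this loss pushes the context network to produce representations that identify the true upcoming frame among distractors, which is what makes the learned vectors useful features.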


## <font color='brown' size=4>2.2 Pipeline overview for general wav2vec model</font>

* Pretraining
* Acoustic model
* Decoding

<p>After the self-supervised pretraining stage, wav2vec is used to produce the input features for an acoustic model. The learned vectors are meant to carry more useful information than hand-crafted features such as (mel) spectrograms, so they can replace those features as the input to an ASR model.</p>
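For the decoding stage, a minimal sketch of greedy CTC decoding, assuming an acoustic model trained with a CTC objective (the `facebook/wav2vec2-base-960h` checkpoint used later was fine-tuned this way on LibriSpeech). Greedy decoding takes the most likely token per frame, merges consecutive repeats, then drops the blank token. Production ASR systems typically use beam search, often with a language model, instead. The 5-token vocabulary and the one-hot "logits" below are purely hypothetical.

```python
import numpy as np

def ctc_greedy_decode(logits, id_to_char, blank_id=0):
    """Greedy CTC decoding: argmax per frame, merge repeats, then drop blanks."""
    best_path = np.argmax(logits, axis=-1)    # (T,) most likely token per frame
    chars = []
    prev = None
    for token in best_path:
        if token != prev and token != blank_id:
            chars.append(id_to_char[int(token)])
        prev = token
    return "".join(chars)

# Hypothetical vocabulary; frame-level logits faked as one-hot rows
id_to_char = {0: "<blank>", 1: "h", 2: "e", 3: "l", 4: "o"}
frames = [1, 1, 0, 2, 3, 3, 0, 3, 4]          # the blank between the l's keeps them distinct
logits = np.eye(len(id_to_char))[frames]      # shape: (T, vocab)
print(ctc_greedy_decode(logits, id_to_char))  # hello
```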

## <font color='brown' size=4>3. Wav2vec family</font>

The wav2vec series consists of four core papers, all from Facebook AI.

All these papers are building blocks of what could be a major innovation not only in speech recognition but also in many other speech-related downstream tasks:

* wav2vec: Unsupervised Pre-training for Speech Recognition
* vq-wav2vec: Self-Supervised Learning of Discrete Speech Representations
* wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
* Self-training and Pre-training are Complementary for Speech Recognition

<p> I am no expert in the speech domain, but these are some of the recent advances in speech recognition systems. We will extract features with the wav2vec 2.0 model and fine-tune it for our use case. For more detailed reading, please go through the links in the Acknowledgements section.</p>

## <font color='brown' size=4>4. Modeling based on wav2vec representation</font>

<p>We will take a pretrained wav2vec model (the convolutional feature extractor of the fairseq wav2vec checkpoint, or the full Hugging Face wav2vec2 model) and fine-tune it for our downstream task.</p>

The code is heavily inspired by this [notebook](https://www.kaggle.com/gopidurgaprasad/rfcs-audio-detection-pytorch-stater).
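To get a feel for what the convolutional feature extractor produces, here is a small sketch that counts the output frames of a stack of unpadded 1-D convolutions for a given number of input samples, using $L_{out} = \lfloor (L_{in} - \text{kernel}) / \text{stride} \rfloor + 1$ per layer. The `(kernel, stride)` lists are our recollection of the published wav2vec and wav2vec 2.0 encoder configurations and should be treated as assumptions. For one second of 16 kHz audio they give roughly one frame per 10 ms for wav2vec and one per 20 ms for wav2vec 2.0. Because the frame count scales with the number of samples, feeding 48 kHz audio roughly triples the frames per second of recording.

```python
def conv_output_length(n_samples, layers):
    """Number of output frames of a stack of unpadded 1-D convolutions."""
    length = n_samples
    for kernel, stride in layers:
        length = (length - kernel) // stride + 1
    return length

# (kernel, stride) per layer; assumed from the published encoder configurations
WAV2VEC_ENCODER = [(10, 5), (8, 4), (4, 2), (4, 2), (4, 2)]                     # total stride 160
WAV2VEC2_ENCODER = [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]   # total stride 320

one_second_16k = 16_000
print(conv_output_length(one_second_16k, WAV2VEC_ENCODER))   # 98 (~10 ms per frame)
print(conv_output_length(one_second_16k, WAV2VEC2_ENCODER))  # 49 (~20 ms per frame)
```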

## <font color='brown' size=4>4.1 Preparing dataset</font>

In [None]:
root_dir='../input/rfcx-species-audio-detection/'

train = pd.read_csv(root_dir+"train_tp.csv").sort_values("recording_id")
ss = pd.read_csv(root_dir+"sample_submission.csv")

FOLDS = 5
SEED = 42

train_gby = train.groupby("recording_id")[["species_id"]].first().reset_index()
train_gby = train_gby.sample(frac=1, random_state=SEED).reset_index(drop=True)
train_gby.loc[:, 'kfold'] = -1

X = train_gby["recording_id"].values
y = train_gby["species_id"].values

kfold = StratifiedKFold(n_splits=FOLDS)
for fold, (t_idx, v_idx) in enumerate(kfold.split(X, y)):
    train_gby.loc[v_idx, "kfold"] = fold

train = train.merge(train_gby[['recording_id', 'kfold']], on="recording_id", how="left")
print(train.kfold.value_counts())
train.to_csv("train_folds.csv", index=False)

In [None]:
train.head(2)

## <font color='brown' size=4>4.2 Dataset class</font>

In [None]:
class AudioDataset(Dataset):
    def __init__(self, df, period=10, transforms=None, data_path="train", train=True):
        self.period = period
        self.transforms = transforms
        self.data_path = data_path
        self.train = train
        
        if train: 
            dfgby = df.groupby("recording_id").agg(lambda x: list(x)).reset_index()
            self.recording_ids = dfgby["recording_id"].values
            self.species_ids = dfgby["species_id"].values
            self.t_mins = dfgby["t_min"].values
            self.t_maxs = dfgby["t_max"].values
        else:
            self.recording_ids = df["recording_id"].values

    
    def __len__(self):
        return len(self.recording_ids)
    
    def __getitem__(self, idx):

        recording_id = self.recording_ids[idx]

        if self.train:
            species_id = self.species_ids[idx]
            t_min, t_max = self.t_mins[idx], self.t_maxs[idx]
        else:
            species_id = [0]
            t_min, t_max = [0], [0]

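        # NOTE: the RFCx recordings appear to be 48 kHz, while the wav2vec / wav2vec2 checkpoints
        # were pretrained on 16 kHz audio; the (slower) librosa line below resamples if needed
        y, sr = sf.read(f"{self.data_path}/{recording_id}.flac")
        #y, sr = librosa.load(f"{self.data_path}/{recording_id}.flac", sr=16000) 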

        len_y = len(y)
        effective_length = sr * self.period
        rint = np.random.randint(len(t_min))
        tmin, tmax = round(sr * t_min[rint]), round(sr * t_max[rint])

        if len_y < effective_length:
            start = np.random.randint(effective_length - len_y)
            new_y = np.zeros(effective_length, dtype=y.dtype)
            new_y[start:start+len_y] = y
            y = new_y.astype(np.float32)

        elif len_y > effective_length:
            center = round((tmin + tmax) / 2)
            big = center - effective_length
            if big < 0:
                big = 0

            start = np.random.randint(big, center)
            y = y[start:start+effective_length]
            if len(y) < effective_length:
                new_y = np.zeros(effective_length, dtype=y.dtype)
                start1 = np.random.randint(effective_length - len(y))
                new_y[start1:start1+len(y)] = y
                y = new_y.astype(np.float32)
            else:
                y = y.astype(np.float32)
        else:
            y = y.astype(np.float32)
            start = 0
        
        if self.transforms:
            y = self.transforms(samples=y, sample_rate=sr)
            
        start_time = start / sr
        end_time = (start + effective_length) / sr

        label = np.zeros(24, dtype='f')

        if self.train:
            for i in range(len(t_min)):
                if (t_min[i] >= start_time) & (t_max[i] <= end_time):
                    label[species_id[i]] = 1
                elif start_time <= ((t_min[i] + t_max[i]) / 2) <= end_time:
                    label[species_id[i]] = 1
        
        return {
            "waveform" : y,
            "target" : torch.tensor(label, dtype=torch.float),
            "id" : recording_id
        }


In [None]:
a=AudioDataset(train,data_path=root_dir+'train')
a[0]

## <font color='brown' size=4>4.3 Model definition</font>

<p>Hugging Face has added the wav2vec2 model to recent versions of the transformers library. This makes it much easier to use the pretrained models for downstream fine-tuning.</p>

In [None]:
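class AudioClassifier(nn.Module):
    def __init__(self, path, classes_num, hfmodel=False):
        super().__init__()

        if not hfmodel:
            # fairseq wav2vec (1.0) checkpoint: only its conv feature extractor is used in forward()
            self.model, self.cfg, self.task = fairseq.checkpoint_utils.load_model_ensemble_and_task([path])
            self.model = self.model[0]
            self.fc = Linear(512, classes_num)   # the feature extractor outputs 512 channels
        else:
            # Hugging Face wav2vec2 checkpoint (e.g. facebook/wav2vec2-base-960h, hidden size 768)
            self.model = Wav2Vec2Model.from_pretrained(path)
            # 256 = 768 // 3, because the pooling in forward() runs over the hidden dimension
            self.fc = Linear(256, classes_num)

        self.dropout = Dropout(0.3)  # defined but unused: forward() applies F.dropout(p=0.5)
        self.hfmodel = hfmodel

    def forward(self, input, spec_aug=False, mixup_lambda=None):
        # spec_aug and mixup_lambda are accepted for API compatibility but ignored
        if not self.hfmodel:
            wav2feature = self.model.feature_extractor(input)        # (batch, 512, frames)
            x1 = F.max_pool1d(wav2feature, kernel_size=3, stride=1)  # local max over time
            x2 = F.avg_pool1d(wav2feature, kernel_size=3, stride=1)  # local mean over time
            x = x1 + x2
            x = F.dropout(x, p=0.5, training=self.training)
            x = self.fc(torch.sum(x, dim=2))                         # sum over time -> (batch, 512)
        else:
            wav2feature = self.model(input).last_hidden_state        # (batch, frames, 768)
            # pool1d acts on the last axis, i.e. the 768 hidden units (stride 3 -> 256), not on time
            x1 = F.max_pool1d(wav2feature, kernel_size=3)
            x2 = F.avg_pool1d(wav2feature, kernel_size=3)
            x = x1 + x2
            x = F.dropout(x, p=0.5, training=self.training)
            x = self.fc(torch.sum(x, dim=1))                         # sum over time -> (batch, 256)

        return x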

In [None]:
m=AudioClassifier('./wav2vec_large.pt',24)

res=m(torch.tensor(a[0]['waveform']).reshape(1,-1))
res.shape

## <font color='brown' size=4>4.4 Utils</font>

In [None]:
# https://www.kaggle.com/c/rfcx-species-audio-detection/discussion/198418#1086063
def _one_sample_positive_class_precisions(scores, truth):
    num_classes = scores.shape[0]
    pos_class_indices = np.flatnonzero(truth > 0)

    if not len(pos_class_indices):
        return pos_class_indices, np.zeros(0)

    retrieved_classes = np.argsort(scores)[::-1]

    # built-in int/bool/float: the np.int / np.bool / np.float aliases were removed in NumPy 1.24
    class_rankings = np.zeros(num_classes, dtype=int)
    class_rankings[retrieved_classes] = range(num_classes)

    retrieved_class_true = np.zeros(num_classes, dtype=bool)
    retrieved_class_true[class_rankings[pos_class_indices]] = True

    retrieved_cumulative_hits = np.cumsum(retrieved_class_true)

    precision_at_hits = (
            retrieved_cumulative_hits[class_rankings[pos_class_indices]] /
            (1 + class_rankings[pos_class_indices].astype(float)))
    return pos_class_indices, precision_at_hits

def lwlrap(truth, scores):
    assert truth.shape == scores.shape
    num_samples, num_classes = scores.shape
    precisions_for_samples_by_classes = np.zeros((num_samples, num_classes))
    for sample_num in range(num_samples):
        pos_class_indices, precision_at_hits = _one_sample_positive_class_precisions(scores[sample_num, :], truth[sample_num, :])
        precisions_for_samples_by_classes[sample_num, pos_class_indices] = precision_at_hits

    labels_per_class = np.sum(truth > 0, axis=0)
    weight_per_class = labels_per_class / float(np.sum(labels_per_class))

    per_class_lwlrap = (np.sum(precisions_for_samples_by_classes, axis=0) /
                        np.maximum(1, labels_per_class))
    return per_class_lwlrap, weight_per_class

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

class MetricMeter(object):
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.y_true = []
        self.y_pred = []
    
    def update(self, y_true, y_pred):
        self.y_true.extend(y_true.cpu().detach().numpy().tolist())
        self.y_pred.extend(torch.sigmoid(y_pred).cpu().detach().numpy().tolist())

    @property
    def avg(self):
        
        score_class, weight = lwlrap(np.array(self.y_true), np.array(self.y_pred))
        self.score = (score_class * weight).sum()

        return {
            "lwlrap" : self.score
        }
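The `lwlrap` helper above returns per-class scores and per-class weights; their weighted sum equals the plain mean, over every positive (recording, species) label, of the precision at that label's rank. A minimal, self-contained sketch of this equivalent formulation, handy for sanity-checking the metric on toy data, is shown below. The function name `lwlrap_score` and the toy arrays are hypothetical, and ties in scores are broken arbitrarily by `argsort`.

```python
import numpy as np

def lwlrap_score(truth, scores):
    """Label-weighted label-ranking average precision (mean over all positive labels)."""
    truth, scores = np.asarray(truth), np.asarray(scores)
    precisions = []
    for t_row, s_row in zip(truth, scores):
        order = np.argsort(-s_row)                  # classes from highest to lowest score
        rank = np.empty_like(order)
        rank[order] = np.arange(len(order))         # rank[c] == 0 for the top-scored class
        for c in np.flatnonzero(t_row > 0):
            hits = np.sum(t_row[order[: rank[c] + 1]] > 0)   # positives ranked at or above c
            precisions.append(hits / (rank[c] + 1))
    return float(np.mean(precisions))

truth = [[1, 0, 1],
         [0, 1, 0]]
scores = [[0.9, 0.8, 0.1],
          [0.2, 0.7, 0.1]]
print(lwlrap_score(truth, scores))  # (1 + 2/3 + 1) / 3 = 0.888...
```

In the first recording, species 2 is ranked third behind one negative, so its precision is 2/3; every other positive label is ranked first and scores 1.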

## <font color='brown' size=4>4.5 Engine</font>

In [None]:
def train_epoch(config, model, loader, criterion, optimizer, scheduler, epoch,tokenizer):
    losses = AverageMeter()
    scores = MetricMeter()

    model.train()
    t = tqdm(loader)
    

    for i, sample in enumerate(t):
        optimizer.zero_grad()
        
        waveform_list=sample['waveform'].tolist()
        input_values =tokenizer(waveform_list, return_tensors = "pt").input_values
        
        input = input_values.to(config['device'])
        target = sample['target'].to(config['device'])
        output = model(input)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if scheduler and config['step_scheduler']:
            scheduler.step()

        bs = input.size(0)
        #scores.update(target, torch.sigmoid(torch.max(output['framewise_output'], dim=1)[0]))
        scores.update(target, output)
        losses.update(loss.item(), bs)

        t.set_description(f"Train E:{epoch} - Loss:{losses.avg:0.4f}")
    t.close()
    return scores.avg, losses.avg
        
def valid_epoch(config, model, loader, criterion, epoch,tokenizer):
    losses = AverageMeter()
    scores = MetricMeter()
  
    model.eval()
    
    with torch.no_grad():
        t = tqdm(loader)
        for i, sample in enumerate(t):
            
            waveform_list=sample['waveform'].tolist()
            input_values =tokenizer(waveform_list, return_tensors = "pt").input_values
        
            input = input_values.to(config['device'])
            target = sample['target'].to(config['device'])
            output = model(input)
            loss = criterion(output, target)

            bs = input.size(0)
            scores.update(target,output)
            losses.update(loss.item(), bs)
            t.set_description(f"Valid E:{epoch} - Loss:{losses.avg:0.4f}")
    t.close()
    return scores.avg, losses.avg

def test_epoch(config, model, loader,tokenizer):

    model.eval()
    pred_list = []
    id_list = []
    with torch.no_grad():
        t = tqdm(loader)
        for i, sample in enumerate(t):
            waveform_list=sample['waveform'].tolist()
            input_values =tokenizer(waveform_list, return_tensors = "pt").input_values
            
            input = input_values.to(config['device'])
            id = sample["id"]
            output = torch.sigmoid(model(input)).cpu().detach().numpy().tolist()
            pred_list.extend(output)
            id_list.extend(id)

    return pred_list, id_list

## <font color='brown' size=4>4.6 Fold run</font>

<p>Let's train it for a few epochs (the config below runs a single epoch as a quick demo; increase `epochs` for a real run).</p>

In [None]:

def seed_everything(seed):
  # Setting seed
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def main(fold,config):
    seed_everything(config['seed'])

    config['fold'] = fold
    config['save_path'] = os.path.join(config['output_dir'],config['exp_name'])
    os.makedirs(config['save_path'], exist_ok=True)

    train_df = pd.read_csv(config['train_csv'])
    sub_df = pd.read_csv(config['sub_csv'])
    
    if config['DEBUG']:
        train_df = train_df.sample(200)

    train_fold = train_df[train_df.kfold != fold]
    valid_fold = train_df[train_df.kfold == fold]

    train_dataset = AudioDataset(
        df = train_fold,
        period=config['period'],
        transforms=None,
        data_path=root_dir+'train',
        train=True
    )

    valid_dataset = AudioDataset(
        df = valid_fold,
        period=config['period'],
        #stride=5,
        transforms=None,
        data_path=root_dir+'train',
        train=True
    )
    

    test_dataset = AudioDataset(
        df = sub_df,
        period=60,
        #stride=5,
        transforms=None,
        data_path=root_dir+'test',
        train=False
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=config['num_workers']
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        drop_last=False,
        num_workers=config['num_workers']
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config['batch_size']//2,
        shuffle=False,
        drop_last=False,
        num_workers=config['num_workers']
    )

    tokenizer = Wav2Vec2Tokenizer.from_pretrained(config['hf_path'])
    
    #model = AudioClassifier('./wav2vec_large.pt',24)  # fairseq wav2vec checkpoint (hfmodel=False)
    model = AudioClassifier(config['hf_path'],24,True)
    
    model = model.to(config['device'])

    if config['pretrain_weights']:
        model.load_state_dict(torch.load(config['pretrain_weights'], map_location=config['device']), strict=False)
        model = model.to(config['device'])

    criterion = BCEWithLogitsLoss() #PANNsLoss() #MaskedBCEWithLogitsLoss() #BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    num_train_steps = int(len(train_loader) * config['epochs'])
    num_warmup_steps = int(0.1 * config['epochs'] * len(train_loader))
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps)

    best_lwlrap = -np.inf
    early_stop_count = 0
    wandb.watch(model, criterion, log="all", log_freq=10)

    for epoch in range(config['start_epoch'], config['epochs']):
        train_avg, train_loss = train_epoch(config, model, train_loader, criterion, optimizer, scheduler, epoch,tokenizer)
        valid_avg, valid_loss = valid_epoch(config, model, valid_loader, criterion, epoch,tokenizer)

        if config['epoch_scheduler']:
            scheduler.step()
        
        content = f"""
                {time.ctime()} \n
                Fold:{config['fold']}, Epoch:{epoch}, lr:{optimizer.param_groups[0]['lr']:.7}\n
                Train Loss:{train_loss:0.4f} - LWLRAP:{train_avg['lwlrap']:0.4f}\n
                Valid Loss:{valid_loss:0.4f} - LWLRAP:{valid_avg['lwlrap']:0.4f}\n
        """
        print(content)
        
        wandb.log({"LWLRAP_train": train_avg['lwlrap']})
        wandb.log({"LWLRAP_valid": valid_avg['lwlrap']})

        with open(config['save_path']+'/log_'+config['exp_name']+'.txt', 'a') as appender:
            appender.write(content+'\n')
        
        if valid_avg['lwlrap'] > best_lwlrap:
            print(f"########## >>>>>>>> Model Improved From {best_lwlrap} ----> {valid_avg['lwlrap']}")
            torch.save(model.state_dict(), os.path.join(config['save_path'],'fold-'+str(config['fold'])+'.bin'))
            best_lwlrap = valid_avg['lwlrap']
            early_stop_count = 0
        else:
            early_stop_count += 1
      
        if config['early_stop'] == early_stop_count:
            print("\n we reached early stoping count :", early_stop_count)
            break
    
    model.load_state_dict(torch.load(os.path.join(config['save_path'],'fold-'+str(config['fold'])+'.bin'), map_location=config['device']))
    model = model.to(config['device'])

    target_cols = sub_df.columns[1:].values.tolist()
    test_pred, ids = test_epoch(config, model, test_loader,tokenizer)
    print(np.array(test_pred).shape)

    test_pred_df = pd.DataFrame({
        "recording_id" : sub_df.recording_id.values
    })
    test_pred_df[target_cols] = test_pred
    test_pred_df.to_csv(os.path.join(config['save_path'], f"fold-{config['fold']}-submission.csv"), index=False)
    print(os.path.join(config['save_path'], f"fold-{config['fold']}-submission.csv"))


if __name__ == "__main__":

    for fold in range(1):
        config = dict(
            DEBUG = False,
            wandb = False,
            exp_name = "W2V_V1",
            network = "AudioClassifier",
            pretrain_weights = None,
            lr = 1e-3,
            step_scheduler = True,
            epoch_scheduler = False,
            period = 10,
            seed = 42,
            start_epoch = 0,
            epochs = 1,
            batch_size = 2,
            num_workers = 2,
            early_stop = 10,

            device = ('cuda' if torch.cuda.is_available() else 'cpu'),
            train_csv = "./train_folds.csv",
            test_csv = "test_df.csv",
            sub_csv = root_dir+"sample_submission.csv",
            output_dir = "weights",
            hf_path="facebook/wav2vec2-base-960h"
        )

        with wandb.init(project=f"pytorch_audio_{config['exp_name']}"+f"_wand_{fold}", config=config):
            # access all HPs through wandb.config, so logging matches execution!
            config = wandb.config
            main(fold, config)

        print(f"***********Fold:{fold} done*******")
        print("\n")

## <font color='brown' size=4>5. Acknowledgements</font>

1. https://medium.com/analytics-vidhya/understanding-the-mel-spectrogram-fca2afa2ce53
2. https://maelfabien.github.io/machinelearning/wav2vec/#e-end-to-end-model
3. https://www.kaggle.com/gopidurgaprasad/rfcs-audio-detection-pytorch-stater    