$$\Huge{\textit{Pretrained networks for Audio}}$$

# Introduction: What are PANNs?

PANNs stands for Pretrained Audio Neural Networks.
They appeared in [this paper](https://arxiv.org/abs/1912.10211) by Qiuqiang Kong et al.

The idea of PANNs is to use the huge [AudioSet dataset](https://research.google.com/audioset/) to pretrain large-scale neural networks for audio pattern recognition.


But audio pattern recognition is a broad subject, so let us describe the common tasks and methods.

## Tasks in Audio

### Speech recognition

Alexa, Google Home, and even the simple speech-to-text feature on mobile phones use speech recognition.

Speech recognition transcribes the natural language spoken in an audio clip into text.

### Audio Tagging

Audio tagging can be seen as the classification task for audio clips.

For example, the [GTZAN dataset](http://marsyas.info/downloads/datasets.html) asks you to tag the genre of music clips (rock, blues, pop, jazz, ...).

### Sound event detection

Sound event detection can be seen as the [instance segmentation](https://paperswithcode.com/task/instance-segmentation) task for audio clips.

Since audio is one-dimensional data, you can also use audio tagging for sound event detection, for example by tagging a 2 s window of the audio clip every 0.5 s.
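A minimal sketch of this sliding-window idea in Python with NumPy. `tag_fn` is a hypothetical stand-in for any audio tagger that maps a waveform chunk to per-class scores (for instance a PANNs model wrapped to return NumPy arrays); the toy tagger and the synthetic clip exist only to make the sketch runnable.

```python
import numpy as np

def sliding_window_detection(waveform, sample_rate, tag_fn,
                             window_s=2.0, hop_s=0.5):
    """Tag fixed-length windows of a clip to get a coarse event timeline.

    Returns (start_times, scores), scores having shape (num_windows, num_classes).
    tag_fn is a hypothetical tagger: 1-D window -> 1-D array of class scores.
    """
    win = int(window_s * sample_rate)
    hop = int(hop_s * sample_rate)
    starts, scores = [], []
    # Slide the window until it would run past the end of the clip
    for start in range(0, len(waveform) - win + 1, hop):
        starts.append(start / sample_rate)
        scores.append(tag_fn(waveform[start:start + win]))
    return np.array(starts), np.stack(scores)

# Toy tagger (illustration only): class 0 = "quiet", class 1 = "loud"
def toy_tag_fn(chunk):
    loud = float(np.mean(np.abs(chunk)) > 0.1)
    return np.array([1.0 - loud, loud])

sr = 1000
clip = np.zeros(5 * sr)
clip[3 * sr:] = 0.5          # a "loud" event starting at 3 s
times, scores = sliding_window_detection(clip, sr, toy_tag_fn)
```

Each row of `scores` tags one window, so the rows form a coarse timeline of events whose time resolution is the hop (0.5 s here).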

## Methods for Audio

### Audio as a time series

Audio is an amplitude varying over time: 1-D data, a time series. We call it a waveform.

Plenty of algorithms exist for such data:

*   Recurrent neural networks (RNN, LSTM, GRU, ...)
*   Causal convolutional neural networks

But used this way, they only exploit the time dimension, whereas frequency is often crucial information when it comes to identifying a sound.
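To make "causal" concrete: a causal convolution computes each output from the current and past samples only, by padding the waveform on the left. Below is a minimal NumPy sketch of the definition (not how PyTorch or PANNs implement it); note that it still works along the time axis only, which is exactly the limitation described above.

```python
import numpy as np

def causal_conv1d(x, w):
    """Causal 1-D convolution: y[t] = sum_k w[k] * x[t - k].

    x is left-padded with len(w) - 1 zeros, so y[t] only depends on
    x[0], ..., x[t] and never on future samples.
    """
    k = len(w)
    x_padded = np.concatenate([np.zeros(k - 1), x])
    # Reversing w turns the sliding dot product into a convolution
    return np.array([np.dot(w[::-1], x_padded[t:t + k]) for t in range(len(x))])

x = np.array([1.0, 2.0, 3.0, 4.0])
w = np.array([0.5, 0.5])     # average of the current and previous sample
y = causal_conv1d(x, w)      # [0.5, 1.5, 2.5, 3.5]
```

Changing a future sample (say `x[3]`) leaves `y[0]`, `y[1]` and `y[2]` unchanged, which is the property that makes the convolution causal.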


### Audio as an image

Because both frequency **and** time patterns often matter in audio, we usually want to keep both in the input of the neural network.

That means we transform our 1-D data into 2-D data, just like an image.


Then you can use any neural network based on 2-D convolutions.

There are multiple ways of transforming audio into images:

*   Log mel spectrograms
*   Wavegrams

And PANNs lets you combine the two!

#### Logmel

Log mel spectrograms are built using:

1.   A short-time Fourier transform (STFT)
2.   A mel filter bank
3.   The logarithm function

They are described in [this paper](https://arxiv.org/abs/1808.01935).
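The three steps above can be sketched in plain NumPy. This is a minimal illustration under stated assumptions: the HTK mel formula, no filter normalization, and the settings used later in this notebook (32 kHz, window 1024, hop 320, 64 mel bins, 50 to 14000 Hz). PANNs itself computes these features with `torchlibrosa` (`Spectrogram` and `LogmelFilterBank`), whose defaults differ in details such as the mel scale variant and filter normalization.

```python
import numpy as np

def hz_to_mel(f):
    # HTK-style mel scale (assumption: librosa defaults to the Slaney variant)
    return 2595.0 * np.log10(1.0 + f / 700.0)

def mel_to_hz(m):
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

def mel_filterbank(sr, n_fft, n_mels, fmin, fmax):
    """Triangular mel filters, shape (n_mels, n_fft // 2 + 1)."""
    fft_freqs = np.linspace(0.0, sr / 2, n_fft // 2 + 1)
    # n_mels + 2 points equally spaced in mel give the edges of the triangles
    hz_points = mel_to_hz(np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2))
    fb = np.zeros((n_mels, len(fft_freqs)))
    for i in range(n_mels):
        left, center, right = hz_points[i], hz_points[i + 1], hz_points[i + 2]
        rising = (fft_freqs - left) / (center - left)
        falling = (right - fft_freqs) / (right - center)
        fb[i] = np.maximum(0.0, np.minimum(rising, falling))
    return fb

def log_mel_spectrogram(x, sr=32000, n_fft=1024, hop=320,
                        n_mels=64, fmin=50, fmax=14000, amin=1e-10):
    """Returns a (frames, n_mels) log mel spectrogram in dB."""
    # 1. Short-time Fourier transform: centered frames, periodic Hann window
    x = np.pad(x, n_fft // 2, mode='reflect')
    window = np.hanning(n_fft + 1)[:-1]
    n_frames = 1 + (len(x) - n_fft) // hop
    frames = np.stack([x[i * hop:i * hop + n_fft] * window for i in range(n_frames)])
    power = np.abs(np.fft.rfft(frames, axis=1)) ** 2    # (frames, n_fft // 2 + 1)
    # 2. Mel filter bank: project the linear frequency bins onto n_mels bands
    mel = power @ mel_filterbank(sr, n_fft, n_mels, fmin, fmax).T
    # 3. Logarithm (in dB), clamped so silent bands do not give log(0)
    return 10.0 * np.log10(np.maximum(mel, amin))

sr = 32000
t = np.arange(sr) / sr
S = log_mel_spectrogram(np.sin(2 * np.pi * 1000 * t))   # 1 s, 1 kHz tone
print(S.shape)   # (101, 64): 32000 // 320 + 1 frames, 64 mel bands
```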

![](https://encrypted-tbn0.gstatic.com/images?q=tbn%3AANd9GcRPF8wAi-2Y48scV3uIEGkimJ8FVYatfloaFw&usqp=CAU)

#### Wavegram

Wavegrams were introduced in the [PANNs paper](https://arxiv.org/abs/1912.10211):

"*Wavegram is our proposed feature that is similar to log mel spectrogram, but is learnt from a neural network. The design of Wavegram is to learn a time-frequency representation that is a modification to Fourier transform. A Wavegram has a time axis and a frequency axis.*"

PANNs then combines this with log mel spectrograms:

![PANNs combining Wavegram and log mel spectrogram](https://qiita-user-contents.imgix.net/https%3A%2F%2Fqiita-image-store.s3.ap-northeast-1.amazonaws.com%2F0%2F264781%2Fd8ead0ba-89ab-f5e3-32ab-4b06209415fd.png?ixlib=rb-1.2.2&auto=format&gif-q=60&q=75&s=49a4cecae966253f2543eaea80325e72)


# Use PANNs for audio tagging on a custom dataset

See [this repository](https://github.com/qiuqiangkong/panns_transfer_to_gtzan) for a ready-to-use implementation.

If you want to build your own, this section describes how to do it step by step!

## 1) Prepare the data

### Download dataset

First, download the [GTZAN dataset](http://marsyas.info/downloads/datasets.html).

This can take a while depending on your connection: the archive is about 1.1 GB.

In [2]:
!wget http://opihi.cs.uvic.ca/sound/genres.tar.gz

--2022-03-07 20:54:16--  http://opihi.cs.uvic.ca/sound/genres.tar.gz
Resolving opihi.cs.uvic.ca (opihi.cs.uvic.ca)... 142.104.68.135
Connecting to opihi.cs.uvic.ca (opihi.cs.uvic.ca)|142.104.68.135|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1225571541 (1.1G) [application/x-gzip]
Saving to: ‘genres.tar.gz’


2022-03-07 21:14:43 (976 KB/s) - ‘genres.tar.gz’ saved [1225571541/1225571541]



Then extract it

In [1]:
!tar -xvf  'genres.tar.gz'



### Pack dataset

General information about the dataset:

In [3]:
labels = ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 
    'pop', 'reggae', 'rock']
    
lb_to_idx = {lb: idx for idx, lb in enumerate(labels)}
idx_to_lb = {idx: lb for idx, lb in enumerate(labels)}
classes_num = len(labels)

#### Normalization

The sample rate is important: you need to use *the same sample rate* as the model!

PANNs were trained on 32 kHz audio, so we set that here.

In [4]:
sample_rate = 32000

The audio also needs to be mono. Fortunately, the librosa library can load a .wav file resampled to a chosen sample_rate and downmixed to mono.

In [5]:
import librosa, os

audio_path = os.path.join('genres', 'blues', 'blues.00000.wav')
# Resample to sample_rate and downmix to mono while loading
(audio, _) = librosa.load(audio_path, sr=sample_rate, mono=True)




We can play it in Colab using `IPython.display`:

In [None]:
from IPython.display import Audio

Audio(audio, rate=sample_rate)

#### Packing into h5 with h5py

The HDF5 (.h5) format is much faster to load than .wav files.

So we preprocess the data once and save it in HDF5 format using h5py.

First, we define how to format our audio clips:

In [None]:
import numpy as np

# Fix every clip to 30 s: zero-pad if shorter, truncate if longer
clip_samples = sample_rate * 30

def pad_truncate_sequence(x, clip_samples):
    if len(x) < clip_samples:
        return np.concatenate((x, np.zeros(clip_samples - len(x))))
    else:
        return x[0:clip_samples]

Then how to get the target (y).

Note that the target labels are taken from the audio file names (e.g. `blues.00000.wav` → `blues`), so make sure the files of your own dataset follow the same naming convention.

In [None]:
def get_target(audio_name, lb_to_idx):
    return lb_to_idx[audio_name.split('.')[0]]

Then we define the function that packs our dataset into HDF5 format.

(This takes around 17 minutes for GTZAN.)

In [None]:
import h5py
import os, time
import numpy as np
import librosa
from tqdm import tqdm

def pack_audio_files_to_hdf5(dataset_dir, packed_hdf5_path,
                             sample_rate, clip_samples,
                             classes_num, lb_to_idx):

    # Paths
    audios_dir = os.path.join(dataset_dir)
    if not packed_hdf5_path.endswith('.h5'):
        packed_hdf5_path += '.h5'
    if os.path.exists(packed_hdf5_path):
        os.remove(packed_hdf5_path)
    if os.path.dirname(packed_hdf5_path) != '':
        os.makedirs(os.path.dirname(packed_hdf5_path), exist_ok=True)

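    (audio_names, audio_paths) = traverse_folder(audios_dir)

    # Sort (name, path) pairs together so that names and paths stay aligned
    pairs = sorted(zip(audio_names, audio_paths))
    audio_names = [name for (name, _) in pairs]
    audio_paths = [path for (_, path) in pairs]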
    audios_num = len(audio_names)

    # targets are found using get_target
    targets = [get_target(audio_name, lb_to_idx) for audio_name in audio_names]

    meta_dict = {
        'audio_name': np.array(audio_names), 
        'audio_path': np.array(audio_paths), 
        'target': np.array(targets), 
        'fold': np.arange(len(audio_names)) % 10 + 1}

    feature_time = time.time()
    with h5py.File(packed_hdf5_path, 'w') as hf:
        hf.create_dataset(
            name='audio_name', 
            shape=(audios_num,), 
            dtype='S80')

        hf.create_dataset(
            name='waveform', 
            shape=(audios_num, clip_samples), 
            dtype=np.int16)

        hf.create_dataset(
            name='target', 
            shape=(audios_num, classes_num), 
            dtype=np.float32)

        hf.create_dataset(
            name='fold', 
            shape=(audios_num,), 
            dtype=np.int32)
 
        for n in tqdm(range(audios_num), total=audios_num):
            audio_name = meta_dict['audio_name'][n]
            fold = meta_dict['fold'][n]
            audio_path = meta_dict['audio_path'][n]
            target = meta_dict['target'][n]

            (audio, _) = librosa.core.load(audio_path, sr=sample_rate, mono=True)
            audio = pad_truncate_sequence(audio, clip_samples)

            hf['audio_name'][n] = audio_name.encode()
            hf['waveform'][n] = float32_to_int16(audio)
            hf['target'][n] = to_one_hot(target, classes_num)
            hf['fold'][n] = fold

    print('Write hdf5 to {}'.format(packed_hdf5_path))
    print('Time: {:.3f} s'.format(time.time() - feature_time))

def to_one_hot(k, classes_num):
    target = np.zeros(classes_num)
    target[k] = 1
    return target

def traverse_folder(fd):
    paths = []
    names = []

    for root, dirs, files in os.walk(fd):
        for name in files:
            if name.endswith('.wav'):
                filepath = os.path.join(root, name)
                names.append(name)
                paths.append(filepath)

    return names, paths

def float32_to_int16(x):
    if np.max(np.abs(x)) > 1.:
        x /= np.max(np.abs(x))
    return (x * 32767.).astype(np.int16)

dataset_dir = 'genres'
packed_hdf5_path = 'GTZAN_dataset.h5'
pack_audio_files_to_hdf5(dataset_dir=dataset_dir,
                         packed_hdf5_path=packed_hdf5_path,
                         sample_rate=sample_rate,
                         clip_samples=clip_samples,
                         classes_num=classes_num,
                         lb_to_idx=lb_to_idx)

Now we can inspect the audio files packed into the HDF5 file:

In [None]:
# Use a context manager so the file handle is closed afterwards
with h5py.File(packed_hdf5_path, 'r') as f:
    for key in f:
        print(f[key])

All 1,000 waveforms have 960,000 samples = 32,000 × 30 (clip_samples = sample_rate × 30 s).
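Storing the waveforms as `np.int16` rather than float32 halves the size of the HDF5 file. A quick back-of-the-envelope check, reusing the two conversion helpers from this notebook (copied here so the snippet runs on its own):

```python
import numpy as np

def float32_to_int16(x):      # same as in the packing cell
    if np.max(np.abs(x)) > 1.:
        x /= np.max(np.abs(x))
    return (x * 32767.).astype(np.int16)

def int16_to_float32(x):      # same as in the Dataset cell
    return (x / 32767.).astype(np.float32)

clips_num, clip_samples = 1000, 32000 * 30
size_int16 = clips_num * clip_samples * np.dtype(np.int16).itemsize      # bytes
size_float32 = clips_num * clip_samples * np.dtype(np.float32).itemsize  # bytes
print(f'{size_int16 / 1e9:.2f} GB as int16 vs {size_float32 / 1e9:.2f} GB as float32')
# 1.92 GB as int16 vs 3.84 GB as float32

# The price is quantization: the round-trip error stays below about one int16 step
x = np.linspace(-1, 1, 10001).astype(np.float32)
max_error = np.max(np.abs(int16_to_float32(float32_to_int16(x.copy())) - x))
```

These figures cover the raw waveform array only; the other datasets in the file (names, targets, folds) are tiny in comparison.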

### Create the dataloaders

For a small dataset you could load everything into RAM, but audio datasets are often too big, so we load the data in batches (which is why HDF5 is useful).

For that we use a DataLoader.

Since we work with PyTorch, we use `torch.utils.data.DataLoader`.

It needs:

1.   A dataset, which maps a key (here, an index in the HDF5 file) to a waveform and its target
2.   A sampler, which yields the keys that form each batch
3.   A collate function, which groups the samples into a batch
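Conceptually, a DataLoader built with a `batch_sampler` and a `collate_fn` runs the loop sketched below (the real one adds worker processes, memory pinning and prefetching). Because the notebook's samplers yield dicts rather than integers, the dataset's `__getitem__` receives a dict. `ToyDataset`, `toy_batches` and `toy_collate` are hypothetical stand-ins used only to make the sketch runnable.

```python
import numpy as np

def simulate_batch_loading(dataset, batch_sampler, collate_fn):
    """Conceptual equivalent of
    DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn)
    without worker processes, memory pinning or prefetching."""
    for batch_keys in batch_sampler:                     # sampler: keys of one batch
        samples = [dataset[key] for key in batch_keys]   # dataset: key -> sample
        yield collate_fn(samples)                        # collate: samples -> batch

# Hypothetical toy components; the keys are dicts, as in this notebook
class ToyDataset:
    def __getitem__(self, meta):
        i = meta['index_in_hdf5']
        return {'waveform': np.full(4, float(i)), 'target': i}

toy_batches = [[{'index_in_hdf5': 0}, {'index_in_hdf5': 1}], [{'index_in_hdf5': 2}]]

def toy_collate(list_data_dict):
    return {key: np.array([d[key] for d in list_data_dict]) for key in list_data_dict[0]}

batches = list(simulate_batch_loading(ToyDataset(), toy_batches, toy_collate))
print(batches[0]['waveform'].shape, batches[1]['target'])   # (2, 4) [2]
```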

We also need to decide on data augmentation beforehand, because it affects the sampler: mixup mixes pairs of clips, so with mixup the training sampler must yield batches twice as large.


#### Data augmentation

In [None]:
augmentation = ['mixup']

class Mixup(object):
    def __init__(self, mixup_alpha, random_seed=1234):
        """Mixup coefficient generator.
        """
        self.mixup_alpha = mixup_alpha
        self.random_state = np.random.RandomState(random_seed)

    def get_lambda(self, batch_size):
        """Get mixup random coefficients.
        Args:
          batch_size: int
        Returns:
          mixup_lambdas: (batch_size,)
        """
        mixup_lambdas = []
        for n in range(0, batch_size, 2):
            lam = self.random_state.beta(self.mixup_alpha, self.mixup_alpha, 1)[0]
            mixup_lambdas.append(lam)
            mixup_lambdas.append(1. - lam)

        return np.array(mixup_lambdas)


def do_mixup(x, mixup_lambda):
    """Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes 
    (1, 3, 5, ...).
    Args:
      x: (batch_size * 2, ...)
      mixup_lambda: (batch_size * 2,)
    Returns:
      out: (batch_size, ...)
    """
    out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \
        x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1)
    return out
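Mixup blends pairs of clips and their targets with a coefficient λ drawn from Beta(α, α): x_mixed = λ·x_a + (1 - λ)·x_b, and likewise for the one-hot targets. `Mixup.get_lambda` and `do_mixup` above pair the even and odd indexes of a double-size batch. The NumPy sketch below shows the same pairing for illustration (the function names are ours, not from PANNs; it mixes raw arrays, whereas the PANNs model applies mixup to its log mel and Wavegram features inside `forward`).

```python
import numpy as np

def sample_mixup_lambdas(pairs_num, alpha=1.0, rng=None):
    """One mixing coefficient per pair, drawn from Beta(alpha, alpha)."""
    rng = np.random.default_rng() if rng is None else rng
    return rng.beta(alpha, alpha, size=pairs_num)

def mixup_pairs(x, lam):
    """lam[i] * x[2i] + (1 - lam[i]) * x[2i + 1]; shape (2 * B, ...) -> (B, ...)."""
    lam = lam.reshape((-1,) + (1,) * (x.ndim - 1))   # broadcast over non-batch dims
    return lam * x[0::2] + (1.0 - lam) * x[1::2]

rng = np.random.default_rng(1234)
waveforms = np.stack([np.zeros(5), np.ones(5), np.zeros(5), np.ones(5)])  # 2 pairs
targets = np.eye(10)[[0, 3, 1, 1]]                                          # one-hot
lam = sample_mixup_lambdas(len(waveforms) // 2, alpha=1.0, rng=rng)
mixed_x = mixup_pairs(waveforms, lam)   # each row equals 1 - lam[i]
mixed_y = mixup_pairs(targets, lam)     # soft targets; each row still sums to 1
```

The same λ must be used for the inputs and the targets, which is why the training loop passes `batch_data_dict['mixup_lambda']` both to the model and to `do_mixup` on the targets.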

#### Dataset

In [None]:
class Hdf5Dataset(object):
    def __init__(self):
        """This class takes the meta of an audio clip as input, and returns
        the waveform and target of the audio clip. This class is used by DataLoader.
        It takes no arguments: each meta dict carries the HDF5 path and index.
        """
        pass
    
    def __getitem__(self, meta):
        """Load waveform and target of an audio clip.
        
        Args:
          meta: {
            'audio_name': str, 
            'hdf5_path': str, 
            'index_in_hdf5': int}
        Returns: 
          data_dict: {
            'audio_name': str, 
            'waveform': (clip_samples,), 
            'target': (classes_num,)}
        """
        hdf5_path = meta['hdf5_path']
        index_in_hdf5 = meta['index_in_hdf5']

        with h5py.File(hdf5_path, 'r') as hf:
            audio_name = hf['audio_name'][index_in_hdf5].decode()
            waveform = int16_to_float32(hf['waveform'][index_in_hdf5])
            target = hf['target'][index_in_hdf5].astype(np.float32)

        data_dict = {
            'audio_name': audio_name, 'waveform': waveform, 'target': target}
            
        return data_dict

def int16_to_float32(x):
    return (x / 32767.).astype(np.float32)

# Dataset
dataset = Hdf5Dataset()

#### Sampler

In [None]:
class TrainSampler(object):
    def __init__(self, hdf5_path, holdout_fold, batch_size, random_seed=1234):
        """Shuffled sampler. Generate batch meta for training on all folds
        except holdout_fold.
        
        Args:
          hdf5_path: string
          holdout_fold: int, fold kept out for validation
          batch_size: int
          random_seed: int
        """

        self.hdf5_path = hdf5_path
        self.batch_size = batch_size
        self.random_state = np.random.RandomState(random_seed)

        with h5py.File(hdf5_path, 'r') as hf:
            self.folds = hf['fold'][:].astype(np.float32)

        self.indexes = np.where(self.folds != int(holdout_fold))[0]
        self.audios_num = len(self.indexes)

        # Shuffle indexes
        self.random_state.shuffle(self.indexes)
        
        self.pointer = 0

    def __iter__(self):
        """Generate batch meta for training (loops forever).
        
        Returns:
          batch_meta: list of dicts, e.g.: [
            {'hdf5_path': 'GTZAN_dataset.h5', 'index_in_hdf5': int},
            ...]
        """
        batch_size = self.batch_size

        while True:
            batch_meta = []
            i = 0
            while i < batch_size:
                index = self.indexes[self.pointer]
                self.pointer += 1

                # Shuffle indexes and reset pointer
                if self.pointer >= self.audios_num:
                    self.pointer = 0
                    self.random_state.shuffle(self.indexes)
                
                # Use the index read before the pointer moved (and before any reshuffle)
                batch_meta.append({
                    'hdf5_path': self.hdf5_path, 
                    'index_in_hdf5': index})
                i += 1

            yield batch_meta

    def state_dict(self):
        state = {
            'indexes': self.indexes,
            'pointer': self.pointer}
        return state
            
    def load_state_dict(self, state):
        self.indexes = state['indexes']
        self.pointer = state['pointer']


class EvaluateSampler(object):
    def __init__(self, hdf5_path, holdout_fold, batch_size, random_seed=1234):
        """Sequential sampler. Generate batch meta for evaluating holdout_fold.
        
        Args:
          hdf5_path: string
          holdout_fold: int, fold used for validation
          batch_size: int
          random_seed: int, unused (keeps the signature of TrainSampler)
        """

        self.hdf5_path = hdf5_path
        self.batch_size = batch_size

        with h5py.File(hdf5_path, 'r') as hf:
            self.folds = hf['fold'][:].astype(np.float32)

        self.indexes = np.where(self.folds == int(holdout_fold))[0]
        self.audios_num = len(self.indexes)
        
    def __iter__(self):
        """Generate batch meta for evaluation (one pass over holdout_fold).
        
        Returns:
          batch_meta: list of dicts, e.g.: [
            {'hdf5_path': 'GTZAN_dataset.h5', 'index_in_hdf5': int},
            ...]
        """
        batch_size = self.batch_size
        pointer = 0

        while pointer < self.audios_num:
            batch_indexes = np.arange(pointer, 
                min(pointer + batch_size, self.audios_num))

            batch_meta = []

            for i in batch_indexes:
                batch_meta.append({
                    'hdf5_path': self.hdf5_path, 
                    'index_in_hdf5': self.indexes[i]})

            pointer += batch_size
            yield batch_meta


batch_size = 8
holdout_fold = 2

# Data generator
train_sampler = TrainSampler(
    hdf5_path=packed_hdf5_path,
    holdout_fold=holdout_fold,
    batch_size=batch_size * 2 if 'mixup' in augmentation else batch_size)

validate_sampler = EvaluateSampler(
    hdf5_path=packed_hdf5_path, 
    holdout_fold=holdout_fold, 
    batch_size=batch_size)
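A quick check of what `holdout_fold = 2` means with the fold rule used when packing (`np.arange(len(audio_names)) % 10 + 1`), assuming the GTZAN layout of 10 genres × 100 clips and file names that sort genre by genre:

```python
import numpy as np

labels_num, clips_per_label = 10, 100
targets = np.repeat(np.arange(labels_num), clips_per_label)   # sorted names: genre by genre
folds = np.arange(len(targets)) % 10 + 1                      # same rule as the packing cell
holdout_fold = 2

val_mask = folds == holdout_fold
print(val_mask.sum(), (~val_mask).sum())    # 100 900
print(np.bincount(targets[val_mask]))       # [10 10 10 10 10 10 10 10 10 10]

# With mixup the training sampler yields 2 * 8 = 16 clips per iteration
iterations_per_epoch = (~val_mask).sum() / (2 * 8)
print(iterations_per_epoch)                 # 56.25
```

So every genre contributes 10 clips to the validation fold, and `stop_iteration = 1000` corresponds to roughly 18 passes over the training fold.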

#### Combine for train and validation DataLoaders

In [None]:
import torch

def collate_fn(list_data_dict):
    """Collate data into batches.
    Args:
      list_data_dict, e.g., [{'audio_name': str, 'waveform': (clip_samples,), ...}, 
                             {'audio_name': str, 'waveform': (clip_samples,), ...},
                             ...]
    Returns:
      np_data_dict, dict, e.g.,
          {'audio_name': (batch_size,), 'waveform': (batch_size, clip_samples), ...}
    """
    np_data_dict = {}
    
    for key in list_data_dict[0].keys():
        np_data_dict[key] = np.array([data_dict[key] for data_dict in list_data_dict])
    
    return np_data_dict


# Data loader
num_workers = 0
train_loader = torch.utils.data.DataLoader(dataset=dataset, 
    batch_sampler=train_sampler, collate_fn=collate_fn, 
    num_workers=num_workers, pin_memory=True)

validate_loader = torch.utils.data.DataLoader(dataset=dataset, 
    batch_sampler=validate_sampler, collate_fn=collate_fn, 
    num_workers=num_workers, pin_memory=True)

## 2) Download a pretrained model

The dataset is composed of waveforms only.

This means it is up to the model to compute features such as log mel spectrograms or Wavegrams.

Luckily for us, that is exactly what makes PANNs so convenient!

We can feed in a raw waveform and let the pretrained PANNs model compute these features itself, without having to implement anything!

PANNs is a direct mapping: waveform → audio pattern embedding.

Of course, we still need to adapt how those embeddings are used for our custom dataset.

Download a [pretrained model](https://zenodo.org/record/3576403#.XxnzEefRZPY).

Here we take the best-performing model on AudioSet: Wavegram_Logmel_Cnn14 (mAP = 0.439).

In [None]:
!wget "https://zenodo.org/record/3576403/files/Wavegram_Logmel_Cnn14_mAP%3D0.439.pth?download=1" -O Wavegram_Logmel_Cnn14.pth

In [None]:
pretrained_model_path = 'Wavegram_Logmel_Cnn14.pth'

Then get the architecture from the [PANNs official GitHub repository](https://github.com/qiuqiangkong/audioset_tagging_cnn), here [Wavegram_Logmel_Cnn14](https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/594d28a85fc6c24212f5be2aaf2aff8dcebace83/pytorch/models.py#L2322), and install and import the required libraries.

In [None]:
!pip install torchlibrosa

First, the ConvPreWavBlock and ConvBlock modules used to build the Wavegram:

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvPreWavBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        
        super(ConvPreWavBlock, self).__init__()
        
        self.conv1 = nn.Conv1d(in_channels=in_channels, 
                              out_channels=out_channels,
                              kernel_size=3, stride=1,
                              padding=1, bias=False)
                              
        self.conv2 = nn.Conv1d(in_channels=out_channels, 
                              out_channels=out_channels,
                              kernel_size=3, stride=1, dilation=2, 
                              padding=2, bias=False)
                              
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.bn2 = nn.BatchNorm1d(out_channels)

        self.init_weight()
        
    def init_weight(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

        
    def forward(self, input, pool_size):
        
        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))
        x = F.max_pool1d(x, kernel_size=pool_size)
        
        return x


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        
        super(ConvBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=in_channels, 
                              out_channels=out_channels,
                              kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False)
                              
        self.conv2 = nn.Conv2d(in_channels=out_channels, 
                              out_channels=out_channels,
                              kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False)
                              
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()
        
    def init_weight(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

        
    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        
        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))
        if pool_type == 'max':
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg':
            x = F.avg_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg+max':
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2
        else:
            raise Exception('Incorrect argument!')
        
        return x


def init_layer(layer):
    """Initialize a Linear or Convolutional layer. """
    nn.init.xavier_uniform_(layer.weight)
 
    if hasattr(layer, 'bias'):
        if layer.bias is not None:
            layer.bias.data.fill_(0.)
            
    
def init_bn(bn):
    """Initialize a Batchnorm layer. """
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.)

And then the model itself

In [None]:
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation

class Wavegram_Logmel_Cnn14(nn.Module):
    def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 
        fmax, classes_num):
        
        super(Wavegram_Logmel_Cnn14, self).__init__()

        window = 'hann'
        center = True
        pad_mode = 'reflect'
        ref = 1.0
        amin = 1e-10
        top_db = None

        # Wavegram
        self.pre_conv0 = nn.Conv1d(in_channels=1, out_channels=64, kernel_size=11, stride=5, padding=5, bias=False)
        self.pre_bn0 = nn.BatchNorm1d(64)
        self.pre_block1 = ConvPreWavBlock(64, 64)
        self.pre_block2 = ConvPreWavBlock(64, 128)
        self.pre_block3 = ConvPreWavBlock(128, 128)
        self.pre_block4 = ConvBlock(in_channels=4, out_channels=64)

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 
            win_length=window_size, window=window, center=center, pad_mode=pad_mode, 
            freeze_parameters=True)

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 
            n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 
            freeze_parameters=True)

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 
            freq_drop_width=8, freq_stripes_num=2)

        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=128, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.fc_audioset = nn.Linear(2048, classes_num, bias=True)
        
        self.init_weight()

    def init_weight(self):
        init_layer(self.pre_conv0)
        init_bn(self.pre_bn0)
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)
 
    def forward(self, input, mixup_lambda=None):
        """ Input: (batch_size, data_length)"""

        a1 = F.relu_(self.pre_bn0(self.pre_conv0(input[:, None, :])))
        a1 = self.pre_block1(a1, pool_size=4)
        a1 = self.pre_block2(a1, pool_size=4)
        a1 = self.pre_block3(a1, pool_size=4)
        a1 = a1.reshape((a1.shape[0], -1, 32, a1.shape[-1])).transpose(2, 3)
        a1 = self.pre_block4(a1, pool_size=(2, 1))

        x = self.spectrogram_extractor(input)   # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)    # (batch_size, 1, time_steps, mel_bins)
        
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        if self.training:
            x = self.spec_augmenter(x)

        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)
            a1 = do_mixup(a1, mixup_lambda)
        
        x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')

        # Concatenate Wavegram and log mel features along the channel dimension
        x = torch.cat((x, a1), dim=1)


        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = torch.mean(x, dim=3)
        
        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        
        embedding = F.dropout(x, p=0.5, training=self.training)
        clipwise_output = torch.sigmoid(self.fc_audioset(x))
        
        output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding}

        return output_dict
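Why can `torch.cat((x, a1), dim=1)` concatenate the two branches? Their time and frequency axes must match. A small worked computation for a 30 s clip at 32 kHz, using the standard output-length formula of `nn.Conv1d` and pooling layers, shows that the Wavegram's total stride (5 × 4 × 4 × 4 = 320 samples) equals hop_size = 320 (set in the next section), so both branches end up with 1500 frames and 32 frequency bins:

```python
def conv_out_len(length, kernel, stride=1, padding=0, dilation=1):
    """Output length of nn.Conv1d / max or average pooling along one axis."""
    return (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

clip_samples = 32000 * 30   # 960,000 samples
hop_size = 320

# Wavegram branch (time axis)
t = conv_out_len(clip_samples, kernel=11, stride=5, padding=5)   # pre_conv0: 192000
for _ in range(3):                                               # pre_block1..3
    t = conv_out_len(t, kernel=3, padding=1)                     # conv1 keeps the length
    t = conv_out_len(t, kernel=3, padding=2, dilation=2)         # conv2 keeps the length
    t = conv_out_len(t, kernel=4, stride=4)                      # max_pool1d(4)
wavegram_frames = conv_out_len(t, kernel=2, stride=2)            # pre_block4, pool (2, 1)
wavegram_bins = 32                                               # 128 channels reshaped to 4 x 32

# Log mel branch: centered STFT, then conv_block1 with pool (2, 2)
stft_frames = clip_samples // hop_size + 1                       # 3001
logmel_frames = conv_out_len(stft_frames, kernel=2, stride=2)    # 1500
logmel_bins = 64 // 2                                            # 64 mel bins pooled by 2

print(wavegram_frames, logmel_frames, wavegram_bins, logmel_bins)   # 1500 1500 32 32
```

Both branches output 64 channels, so the concatenation has 128 channels, which is what `conv_block2` expects.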


## 3) Finetune on your dataset

Make sure you use a GPU accelerator on Colab.
If not, go to Runtime > Change runtime type and select GPU.

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

### Build your custom model head

In [None]:
class Transfer_Audio_Tagging(nn.Module):
    def __init__(self, base, freeze_base):
        """Classifier for a new task using a pretrained PANNs model (here Wavegram_Logmel_Cnn14) as a sub-module.
        """
        super(Transfer_Audio_Tagging, self).__init__()        
        self.base = base

        # Transfer to another task layer
        self.fc_transfer = nn.Linear(2048, classes_num, bias=True)

        if freeze_base:
            # Freeze AudioSet pretrained layers
            for param in self.base.parameters():
                param.requires_grad = False

        self.init_weights()

    def init_weights(self):
        init_layer(self.fc_transfer)

    def load_from_pretrain(self, pretrained_checkpoint_path):
        checkpoint = torch.load(pretrained_checkpoint_path,
                                map_location=torch.device(device))
        self.base.load_state_dict(checkpoint['model'])

    def forward(self, input, mixup_lambda=None):
        """Input: (batch_size, data_length)
        """
        output_dict = self.base(input, mixup_lambda)
        embedding = output_dict['embedding']

        clipwise_output =  torch.log_softmax(self.fc_transfer(embedding), dim=-1)
        output_dict['clipwise_output'] = clipwise_output
 
        return output_dict


# Options for Logmel
mel_bins = 64
fmin = 50
fmax = 14000
window_size = 1024
hop_size = 320
audioset_classes_num = 527

base = Wavegram_Logmel_Cnn14(sample_rate, window_size, hop_size, mel_bins, fmin, 
            fmax, audioset_classes_num)

# Whether we freeze the base or not
freeze_base = False

model = Transfer_Audio_Tagging(base, freeze_base)

Load the weights pretrained on AudioSet:

In [None]:
model.load_from_pretrain(pretrained_model_path)

# GPU
if 'cuda' in device:
    model.to(device)

    # Parallel
    # print(f'GPU number: {torch.cuda.device_count()}')
    # model = torch.nn.DataParallel(model)

In [None]:
!nvidia-smi

### Build your custom train loop

#### Optimizer

In [None]:
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
    betas=(0.9, 0.999), eps=1e-08, weight_decay=0., amsgrad=True)

#### Loss function

In [None]:
def loss_func(output_dict, target_dict):
    loss = - torch.mean(target_dict['target'] * output_dict['clipwise_output'])
    return loss
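`loss_func` averages -target · log p over both the batch axis and the class axis, so it equals the usual cross-entropy (summed over classes, averaged over clips) divided by classes_num. A NumPy check on random logits is below; the constant factor changes the printed loss values, while Adam's per-parameter gradient normalization should make training largely insensitive to it.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)                 # for numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

rng = np.random.default_rng(0)
batch_size, classes_num = 8, 10
logits = rng.normal(size=(batch_size, classes_num))
targets = np.eye(classes_num)[rng.integers(0, classes_num, size=batch_size)]  # one-hot

notebook_loss = -np.mean(targets * log_softmax(logits))                    # as in loss_func
cross_entropy = -np.mean(np.sum(targets * log_softmax(logits), axis=-1))   # per-clip mean
print(np.isclose(notebook_loss, cross_entropy / classes_num))              # True
```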

#### Data augmentation

In [None]:
if 'mixup' in augmentation:
    mixup_augmenter = Mixup(mixup_alpha=1.)

#### Checkpoints

In [None]:
checkpoints_dir = 'checkpoints'
os.makedirs(checkpoints_dir, exist_ok=True)

#### Custom forward function

In [None]:
def forward(model, generator, return_input=False, 
    return_target=False):
    """Forward data to a model.
    
    Args: 
      model: object
      generator: object
      return_input: bool
      return_target: bool
    Returns:
      audio_name: (audios_num,)
      clipwise_output: (audios_num, classes_num)
      (ifexist) segmentwise_output: (audios_num, segments_num, classes_num)
      (ifexist) framewise_output: (audios_num, frames_num, classes_num)
      (optional) return_input: (audios_num, segment_samples)
      (optional) return_target: (audios_num, classes_num)
    """
    def append_to_dict(dict, key, value):
        if key in dict.keys():
            dict[key].append(value)
        else:
            dict[key] = [value]

    output_dict = {}
    device = next(model.parameters()).device

    # Forward data to a model in mini-batches
    for n, batch_data_dict in enumerate(generator):
        batch_waveform = move_data_to_device(batch_data_dict['waveform'], device)
        
        with torch.no_grad():
            model.eval()
            batch_output = model(batch_waveform)

        append_to_dict(output_dict, 'audio_name', batch_data_dict['audio_name'])

        append_to_dict(output_dict, 'clipwise_output', 
            batch_output['clipwise_output'].data.cpu().numpy())
            
        if return_input:
            append_to_dict(output_dict, 'waveform', batch_data_dict['waveform'])
            
        if return_target:
            if 'target' in batch_data_dict.keys():
                append_to_dict(output_dict, 'target', batch_data_dict['target'])

    for key in output_dict.keys():
        output_dict[key] = np.concatenate(output_dict[key], axis=0)

    return output_dict


In [None]:
def calculate_accuracy(y_true, y_score):
    N = y_true.shape[0]
    accuracy = np.sum(np.argmax(y_true, axis=-1) == np.argmax(y_score, axis=-1)) / N
    return accuracy

def move_data_to_device(x, device):
    if 'float' in str(x.dtype):
        x = torch.Tensor(x)
    elif 'int' in str(x.dtype):
        x = torch.LongTensor(x)
    else:
        return x

    return x.to(device)

#### Training loop

In [None]:
from sklearn import metrics

iteration = 0
stop_iteration = 1000
validation_cycle = 50
checkpoint_cycle = 50

# Train on mini batches
for batch_data_dict in train_loader:
    
    print(f'Iteration {iteration}', end=', ')
    
    if 'mixup' in augmentation:
        batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(len(batch_data_dict['waveform']))
    
    # Move data to device as tensor
    for key in batch_data_dict.keys():
        batch_data_dict[key] = move_data_to_device(batch_data_dict[key], device)
    
    # Train
    model.train()

    if 'mixup' in augmentation:
        batch_output_dict = model(batch_data_dict['waveform'], 
            batch_data_dict['mixup_lambda'])
        batch_target_dict = {'target': do_mixup(batch_data_dict['target'], 
            batch_data_dict['mixup_lambda'])}
    else:
        batch_output_dict = model(batch_data_dict['waveform'], None)
        batch_target_dict = {'target': batch_data_dict['target']}

    # loss
    loss = loss_func(batch_output_dict, batch_target_dict)
    print(f'loss: {loss}', end='')

    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Evaluate
    if iteration % validation_cycle == 0 and iteration > 0:

        output_dict = forward(model, validate_loader, return_target=True)
        clipwise_output = output_dict['clipwise_output']    # (audios_num, classes_num)
        target = output_dict['target']    # (audios_num, classes_num)

        cm = metrics.confusion_matrix(np.argmax(target, axis=-1),
                                        np.argmax(clipwise_output, axis=-1),
                                        labels=None)
        
        val_accuracy = calculate_accuracy(target, clipwise_output)
        print(f', val_acc:{val_accuracy}')
        print('Confusion Matrix:')
        print(cm)
        print(idx_to_lb)

    # Save model 
    if iteration % checkpoint_cycle == 0 and iteration > 0:
        checkpoint = {
            'iteration': iteration, 
            'model': model.state_dict()}

        checkpoint_name = f'{iteration}_iterations.pth'
        checkpoint_path = os.path.join(checkpoints_dir, checkpoint_name)
        
        torch.save(checkpoint, checkpoint_path)
        print(f'Model saved at {checkpoint_name}')
    
    print()

    # Stop learning
    if iteration == stop_iteration:
        break 

    iteration += 1

We easily reach a validation accuracy above 0.70!

Of course, we could do much better with longer training and appropriate tuning, but let's leave it at that for now!

# From Audio tagging to Sound event detection

# Parameter importance

## Pretrained weights

## Data augmentation

## Hyperparameters

*   Learning rate: 0.001 to 0.0001
*   mel_bins: 64
*   hop_size: smaller is better
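What hop_size trades off, in numbers: at 32 kHz a smaller hop gives a finer time step and more frames per clip, which also means more computation. Note that with the Wavegram_Logmel_Cnn14 above, changing hop_size alone would likely break the concatenation of the two branches, since the Wavegram's total stride is fixed at 320 samples (and the pretrained weights were trained with hop_size = 320). A small sketch:

```python
def frame_stats(hop_size, sample_rate=32000, clip_seconds=30):
    """Time step (ms), frames per second and centered-STFT frames for one clip."""
    step_ms = 1000 * hop_size / sample_rate
    frames_per_second = sample_rate / hop_size
    frames_per_clip = sample_rate * clip_seconds // hop_size + 1
    return step_ms, frames_per_second, frames_per_clip

for hop_size in (320, 160):
    print(hop_size, frame_stats(hop_size))
# 320 (10.0, 100.0, 3001)
# 160 (5.0, 200.0, 6001)

window_ms = 1000 * 1024 / 32000   # window_size = 1024 samples -> 32.0 ms
```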
