# Learnable Time-Frequency Representations
This assignment was about training models to learn different time-frequency representations of audio. I chose to work with the Rainforest dataset because I'm familiar with it from a previous Kaggle competition and because I thought it would be fun to use a bioacoustics dataset for the assignment.

In [None]:
import math
from pathlib import Path
import random
import uuid
import time

import IPython.display as ipd
import librosa
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
from tqdm.notebook import tqdm

torchaudio.set_audio_backend('sox_io')

Let's define some paths to the Rainforest dataset:

In [None]:
input_data = Path('/kaggle/input')
output_data = Path('/kaggle/working')

rainforest_data = input_data / 'rfcx-species-audio-detection'

train_data = rainforest_data / 'train'
test_data = rainforest_data / 'test'

df = pd.read_csv(rainforest_data / 'train_tp.csv')

The dataframe describing the Rainforest dataset has the following format:

In [None]:
df.head()

For each row in the dataframe, we extract a short segment of audio from the file corresponding to `recording_id`. The segment is centered midway between `t_min` and `t_max`, and for simplicity all extracted segments have the same length. We do the segmentation in advance and save the extracted and normalized segments as tensors, for faster loading when we train the models.
We specify the wanted length and sample rate of the extracted segments, as well as the encoder we want to use:

In [None]:
SAMPLE_RATE = 22050
CLIP_LEN_SECONDS = 1.
CLIP_LEN_SAMPLES = int(SAMPLE_RATE * CLIP_LEN_SECONDS)
N_FFT = 300
ENCODER = torchaudio.transforms.Spectrogram(n_fft=N_FFT)
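
As a quick sanity check on these settings, the sketch below estimates the shape of the spectrogram for one clip. It assumes torchaudio's documented defaults for `Spectrogram` (window length equal to `n_fft`, hop length of half the window, and centered frames); the helper name `spectrogram_shape` is only illustrative.

```python
def spectrogram_shape(n_samples, n_fft, hop_length=None, center=True):
    """Return (n_freq_bins, n_frames) for a one-sided STFT."""
    if hop_length is None:
        # Assumed default: half of the window length (window length = n_fft)
        hop_length = n_fft // 2
    n_freq_bins = n_fft // 2 + 1
    if center:
        # Centered frames pad n_fft // 2 samples on each side of the signal
        n_frames = 1 + n_samples // hop_length
    else:
        n_frames = 1 + (n_samples - n_fft) // hop_length
    return n_freq_bins, n_frames


# One second of audio at 22050 Hz with n_fft = 300
shape = spectrogram_shape(22050, 300)
print(shape)  # (151, 148)
```

Under these assumptions the encoder output has 151 frequency bins, which matches the `in_channels=151` of the first decoder layer defined later.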

We also create a directory for the extracted segments and one for the model weights:

In [None]:
waveform_tensors = output_data / 'waveform-tensors'

# exist_ok=True lets the cell be re-run without raising FileExistsError
waveform_tensors.mkdir(exist_ok=True)

In [None]:
weights_dir = output_data / 'weights'

weights_dir.mkdir(exist_ok=True)

`get_normalized_segment` takes a file path and a timestamp as input and returns a normalized waveform of length `CLIP_LEN_SECONDS` seconds, centered on the provided timestamp. We loop through the entire dataframe, extract each segment of interest, and store it as a torch tensor in the `waveform_tensors` directory.

In [None]:
def get_normalized_segment(fpath, mid_segment_timestamp):
    audio, sr = torchaudio.load(fpath)
    audio = audio.squeeze()
    resampler = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
    audio = resampler(audio)
    # We index into the resampled audio, so use SAMPLE_RATE rather than the file's original rate
    start_idx = int(SAMPLE_RATE * (mid_segment_timestamp - CLIP_LEN_SECONDS / 2))
    end_idx = start_idx + CLIP_LEN_SAMPLES
    # Shift the window back inside the recording if it overruns either end
    if start_idx < 0:
        start_idx = 0
        end_idx = CLIP_LEN_SAMPLES
    elif end_idx > len(audio):
        end_idx = len(audio)
        start_idx = end_idx - CLIP_LEN_SAMPLES
    seg = audio[start_idx:end_idx]
    # Remove the DC offset, then scale to a peak amplitude of 1
    seg = seg - torch.mean(seg)
    seg = seg / torch.max(torch.abs(seg)).clamp(min=1e-8)
    return seg

for row_idx, row in tqdm(df.iterrows(), total=df.shape[0]):
    fpath = train_data / (row['recording_id'] + '.flac')
    mid_segment_timestamp = (row['t_min'] + row['t_max']) / 2.
    segment = get_normalized_segment(fpath, mid_segment_timestamp)
    assert segment.shape == torch.Size([CLIP_LEN_SAMPLES])
    fname = str(uuid.uuid4()) + '.pt'
    torch.save(segment, waveform_tensors / fname)
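
To make the windowing and normalization steps easier to follow, here is a dependency-free sketch of the same two ideas on plain Python values. It assumes the indices are computed at the target sample rate of 22050 Hz; the helper names `segment_bounds` and `normalize` are illustrative and not part of the notebook.

```python
def segment_bounds(mid_timestamp, n_total, clip_len=22050, sample_rate=22050):
    """Return (start, end) sample indices of a clip centered on mid_timestamp,
    shifted so that it stays inside a recording of n_total samples."""
    start = int(sample_rate * mid_timestamp) - clip_len // 2
    end = start + clip_len
    if start < 0:
        start, end = 0, clip_len
    elif end > n_total:
        start, end = n_total - clip_len, n_total
    return start, end


def normalize(samples, eps=1e-8):
    """Subtract the mean, then scale so the largest absolute value is 1."""
    mean = sum(samples) / len(samples)
    centered = [s - mean for s in samples]
    peak = max(abs(s) for s in centered)
    return [s / max(peak, eps) for s in centered]


n_total = 60 * 22050  # a hypothetical 60-second recording
print(segment_bounds(0.2, n_total))   # clamped to the start: (0, 22050)
print(segment_bounds(30.0, n_total))  # fully inside: (650475, 672525)
print(segment_bounds(59.9, n_total))  # clamped to the end: (1300950, 1323000)
print(normalize([1.0, 2.0, 3.0]))     # [-1.0, 0.0, 1.0]
```

Every window has exactly `clip_len` samples, which is what the shape assertion in the loop above relies on.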

We create a class `WaveformDataset` to conveniently load data as we need it. We also create a few helper methods for visualizing the data.

In [None]:
class WaveformDataset(Dataset):
    def __init__(self, fpaths, encoder):
        self.fpaths = fpaths
        self.encoder = encoder

    def __len__(self):
        return len(self.fpaths)

    def __getitem__(self, idx):
        return torch.load(self.fpaths[idx])
        
    def show_sample(self, idx=None, waveform=None):
        if waveform is None:
            assert idx is not None
            waveform = torch.load(self.fpaths[idx])
        self.plot_waveform(waveform)
        self.plot_encoded(waveform)
        self.display_audio(waveform)
        
    def plot_waveform(self, waveform):
        plt.figure()
        plt.title('Waveform')
        plt.plot(waveform.detach().numpy())
    
    def plot_encoded(self, waveform):
        encoded_waveform = self.encoder(waveform)
        plt.figure()
        plt.title('Encoded waveform')
        plt.imshow(encoded_waveform.detach().numpy())
        
    def display_audio(self, waveform):
        ipd.display(ipd.Audio(waveform.detach().numpy(), rate=SAMPLE_RATE))

    def show_random_sample(self):
        self.show_sample(idx=random.randrange(len(self)))

Let's create a dataset and look at a random sample:

In [None]:
ds = WaveformDataset(list(waveform_tensors.iterdir()), ENCODER)
ds.show_random_sample()

We now define the model. It takes normalized audio waveforms as input. The first part of the model, the encoder, applies a waveform-to-time-frequency transformation. The second part, the decoder, consists of transposed 1D convolution layers that progressively upsample the time dimension of the activations while simultaneously reducing the frequency (channel) dimension, producing a pure time-domain signal.

In [None]:
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.tconv1 = nn.ConvTranspose1d(
            in_channels=151,
            out_channels=128,
            kernel_size=16,
            stride=4,
            dilation=4,
        )
        
        self.tconv2 = nn.ConvTranspose1d(
            in_channels=128,
            out_channels=64,
            kernel_size=32,
            stride=4,
            dilation=4,
        )
        
        self.tconv3 = nn.ConvTranspose1d(
            in_channels=64,
            out_channels=32,
            kernel_size=32,
            stride=4,
            dilation=4,
        )
        
        self.tconv4 = nn.ConvTranspose1d(
            in_channels=32,
            out_channels=1,
            kernel_size=74,
            stride=2,
            dilation=1,
        )
        
    def forward(self, x):
        x = self.tconv1(x)
        x = F.leaky_relu(x)
        x = self.tconv2(x)
        x = F.leaky_relu(x)
        x = self.tconv3(x)
        x = F.leaky_relu(x)
        x = self.tconv4(x)
        x = x.squeeze()
        return x

class EncoderDecoder(nn.Module):
    def __init__(self, encoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = Decoder()

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
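
The kernel sizes, strides and dilations in `Decoder` have to line up so that the output has exactly one second of samples. The sketch below checks this with the output-length formula from the `ConvTranspose1d` documentation, assuming zero padding and an input of 148 spectrogram frames (the frame count expected from `n_fft=300` with torchaudio's default hop length). The helper name `tconv_out_len` is illustrative.

```python
def tconv_out_len(l_in, kernel_size, stride, dilation=1, padding=0, output_padding=0):
    """Output length of a 1D transposed convolution."""
    return ((l_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


# (kernel_size, stride, dilation) for tconv1 .. tconv4
layers = [(16, 4, 4), (32, 4, 4), (32, 4, 4), (74, 2, 1)]

length = 148  # assumed number of time frames coming out of the encoder
lengths = []
for kernel_size, stride, dilation in layers:
    length = tconv_out_len(length, kernel_size, stride, dilation)
    lengths.append(length)

print(lengths)  # [649, 2717, 10989, 22050]
```

The final length of 22050 equals `CLIP_LEN_SAMPLES`, so the reconstruction can be compared sample by sample with the input.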

After trying multiple things, what worked "best" was minimizing the L1 loss between the log-scaled encodings (spectrograms) of the input and of the reconstructed output. I also tried minimizing the L1 loss between the waveforms directly, but that gave even worse results.

In [None]:
def spec_loss(pred, target, encoder, criterion):
    encoded_pred = encoder(pred)
    encoded_target = encoder(target)
    log_pred = torch.log(encoded_pred + 1e-8)
    log_target = torch.log(encoded_target + 1e-8)
    l1 = criterion(log_pred, log_target)
    return l1
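
To make the loss concrete, here is a dependency-free sketch of the same computation on small nested lists standing in for spectrograms. The function name `log_l1` and the toy values are illustrative only; the training code uses the torch version above.

```python
import math


def log_l1(pred_spec, target_spec, eps=1e-8):
    """Mean absolute difference between the logs of two equally shaped 2D lists."""
    total = 0.0
    count = 0
    for pred_row, target_row in zip(pred_spec, target_spec):
        for p, t in zip(pred_row, target_row):
            # eps keeps log() finite when a bin is exactly zero
            total += abs(math.log(p + eps) - math.log(t + eps))
            count += 1
    return total / count


target = [[1.0, 4.0], [0.5, 2.0]]
doubled = [[2 * v for v in row] for row in target]

print(log_l1(target, target))   # identical inputs give 0.0
print(log_l1(doubled, target))  # very close to log(2)
```

Doubling the energy in every bin costs about log(2) regardless of how large the bin was, so quiet and loud regions of the spectrogram contribute on a comparable scale.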

The training loop is pretty straightforward. I decrease the learning rate by a factor of 10 when the validation loss stops decreasing.

In [None]:
def train(model, encoder, train_ds, val_ds, weights_path, device, batch_size=32, epochs=20):

    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size)

    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, verbose=True, factor=0.1)

    criterion = nn.L1Loss()

    model = model.to(device)
    criterion = criterion.to(device)
    encoder = encoder.to(device)

    best_val_loss = math.inf

    for epoch in tqdm(range(epochs), desc='Training'):
        start_time = time.time()
        model.train()
        train_loss = 0.

        for waveform_train_batch in train_dl:
            waveform_train_batch = waveform_train_batch.to(device)
            waveform_train_preds = model(waveform_train_batch)
            loss = spec_loss(waveform_train_preds, waveform_train_batch, encoder, criterion)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_dl)

        with torch.no_grad():
            val_loss = 0.
            model.eval()
            for waveform_val_batch in val_dl:
                waveform_val_batch = waveform_val_batch.to(device)
                waveform_val_preds = model(waveform_val_batch)
                loss = spec_loss(waveform_val_preds, waveform_val_batch, encoder, criterion)
                val_loss += loss.item()

            val_loss /= len(val_dl)

        elapsed = time.time() - start_time
        print(f'Epoch {epoch+1} (time: {elapsed:.0f}s): train_loss: {train_loss} val_loss: {val_loss}')

        if val_loss < best_val_loss:
            print(f'Saving new best model at epoch {epoch+1} (val_loss improved from {best_val_loss} to {val_loss})')
            torch.save(model, weights_path)
            best_val_loss = val_loss

        scheduler.step(val_loss)
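
The learning-rate schedule can be illustrated without PyTorch. The class below is a simplified stand-in for `ReduceLROnPlateau`, assuming the "min" mode used here and ignoring the real scheduler's improvement threshold, cooldown and minimum learning rate. The name `PlateauLR` and the validation losses are made up for the example.

```python
class PlateauLR:
    """Simplified plateau scheduler: cut lr after more than `patience` bad epochs."""

    def __init__(self, lr, factor=0.1, patience=2):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float('inf')
        self.num_bad_epochs = 0

    def step(self, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            # The plateau lasted longer than `patience` epochs: reduce the lr
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


val_losses = [1.0, 0.9, 0.95, 0.95, 0.95, 0.95]  # hypothetical values
sched = PlateauLR(lr=1e-2)
lrs = [sched.step(v) for v in val_losses]
print(lrs)
```

With `patience=2`, the learning rate stays at 0.01 through the first two stalled epochs and drops to about 0.001 on the third.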

In [None]:
weights_path = weights_dir / 'weights1.pt'

fpaths = list(waveform_tensors.iterdir())
train_fpaths, val_fpaths = train_test_split(fpaths, test_size=0.33)

train_ds = WaveformDataset(train_fpaths, ENCODER)
val_ds = WaveformDataset(val_fpaths, ENCODER)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = EncoderDecoder(ENCODER)

train(model, ENCODER, train_ds, val_ds, weights_path, device, epochs=50)

We can now compare a few audio snippets to their reconstructed counterparts:

In [None]:
model = torch.load(weights_dir / 'weights1.pt')
model = model.cpu()
model.eval()

val_ds = WaveformDataset(val_fpaths, ENCODER.cpu())

In [None]:
x = val_ds[random.randrange(len(val_ds))]
val_ds.show_sample(waveform=x)

x_reconstruct = model(x.unsqueeze(0))
val_ds.show_sample(waveform=x_reconstruct)

In [None]:
x = val_ds[random.randrange(len(val_ds))]
val_ds.show_sample(waveform=x)

x_reconstruct = model(x.unsqueeze(0))
val_ds.show_sample(waveform=x_reconstruct)

We can see (and hear!) that the reconstructed audio does not sound very good, although the shapes of the waveforms are pretty similar. Looking at the spectrograms, we can see that the model manages to localize the birdcall in time, but the dominant frequencies are off.