In [None]:
# imports
import os
import math
import random
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
import moviepy.editor as mped
import matplotlib.pyplot as plt
import IPython
from IPython.display import Audio

# Device used for the model and every batch; falls back to the CPU when no
# CUDA device is available so the notebook still runs on GPU-less machines.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# dataset
class AudioDataset(Dataset):
    """Sliding-window dataset of normalised spectrogram chunks taken from .mp4 audio.

    Every ``.mp4`` file found under ``<path>/<Dataset>/Negative`` and
    ``<path>/<Dataset>/Positive`` (``<Dataset>`` is ``dataset.capitalize()``) is
    decoded, mixed down to mono and concatenated into one long recording per
    label.  A dB-scaled (mel) spectrogram is computed once for each recording,
    and samples are windows of ``chunk_size`` seconds taken every
    ``chunk_step`` seconds.

    Indices ``0 .. pos_audio_len - 1`` address positive chunks (label ``1.``);
    the remaining indices address negative chunks (label ``0.``).
    """

    # Constructor
    def __init__(self, path, dataset = 'train', chunk_size = 4, chunk_step = 1/40, sr = 22050,
                 frame_length = 1024, win_length = 1024, hop_length = 512, use_mel = True, mel_count = 128,
                 keep_audio = True, use_wobble = True, wobble = 2):
        """Load all audio, build the spectrograms and normalise them.

        Args:
            path: Root directory holding the ``Train``/``Test`` style folders.
            dataset: Name of the split; capitalised to form the folder name.
            chunk_size: Length of one sample in seconds.
            chunk_step: Distance between consecutive sample starts in seconds.
            sr: Sample rate the audio is decoded at, in Hz.
            frame_length: FFT size in samples.
            win_length: Analysis window length in samples.
            hop_length: Hop between spectrogram frames in samples.
            use_mel: Build a mel spectrogram instead of a power spectrogram.
            mel_count: Number of mel bands when ``use_mel`` is set.
            keep_audio: Keep the raw waveforms so ``__getitem__`` can return them.
            use_wobble: Randomly shift each chunk start by a whole number of hops.
            wobble: Largest shift, in hops, applied when ``use_wobble`` is set.

        Raises:
            FileNotFoundError: If no ``.mp4`` file could be loaded for a label.
        """
        self.path = path
        self.chunk_size = chunk_size
        self.chunk_step = chunk_step
        self.sr = sr
        self.frame_length = frame_length
        self.win_length = win_length
        self.hop_length = hop_length
        self.use_mel = use_mel
        self.mel_count = mel_count
        self.pos_audio = None
        self.neg_audio = None
        self.keep_audio = keep_audio
        self.use_wobble = use_wobble
        self.wobble = wobble

        print('Loading audio into memory')
        self.neg_audio = self._load_label_audio(dataset, 'Negative')
        self.pos_audio = self._load_label_audio(dataset, 'Positive')
        print('Finished loading audio into memory')

        self.pos_audio_len = self._chunk_count(self.pos_audio.shape[1])
        self.neg_audio_len = self._chunk_count(self.neg_audio.shape[1])

        print('Generating positive spectogram')
        self.pos_spectogram = self.get_spectogram(self.pos_audio)
        if not self.keep_audio:
            self.pos_audio = None  # waveform no longer needed; release the memory
        print('Generating negative spectogram')
        self.neg_spectogram = self.get_spectogram(self.neg_audio)
        if not self.keep_audio:
            self.neg_audio = None

        # Mean and std are hard coded for the author's dataset.  They can be
        # re-derived with torch.mean / torch.std over the concatenation of the
        # positive and negative spectrograms along the time axis (dim 2).
        if use_mel:
            self.spec_mean = -16.5 # mel spectogram
            self.spec_std = 14.5 # mel spectogram
        else:
            self.spec_mean = -30 # spectogram
            self.spec_std = 15 # spectogram

        self.pos_spectogram = (self.pos_spectogram - self.spec_mean) / self.spec_std
        self.neg_spectogram = (self.neg_spectogram - self.spec_mean) / self.spec_std

    def _load_file(self, file):
        """Decode the audio track of one video file into a (1, samples) mono tensor."""
        video = mped.VideoFileClip(file)
        audio = None
        try:
            audio = video.audio
            if audio is None:
                raise ValueError(f'{file} has no audio track')
            samples = audio.to_soundarray(fps=self.sr, nbytes=2)
        finally:
            # Always release the readers, even when decoding fails.
            video.close()
            if audio is not None:
                audio.close()
        audio_tensor = torch.Tensor(samples).t()
        return torch.mean(audio_tensor, dim=0, keepdim=True) # Stereo to mono

    def _load_label_audio(self, dataset, label):
        """Concatenate the audio of every loadable .mp4 file found for ``label``."""
        label_dir = os.path.join(self.path, dataset.capitalize(), label)
        clips = []
        for root, _directories, files in os.walk(label_dir, topdown=False):
            for file in files:
                if not file.endswith('.mp4'):
                    continue
                file_path = os.path.join(root, file)
                try:
                    clips.append(self._load_file(file_path))
                except Exception as err:
                    # Best effort: report the broken file and carry on with the rest.
                    print(f'Error loading {file_path}')
                    print(f'    {type(err).__name__}: {err}')
        if not clips:
            raise FileNotFoundError(f'No loadable .mp4 audio found under {label_dir}')
        # A single concatenation avoids re-copying the growing tensor per file.
        return torch.cat(clips, 1)

    def _chunk_count(self, sample_count):
        """Number of chunk start positions for a recording of ``sample_count`` samples."""
        total_steps = math.ceil(sample_count / (self.chunk_step * self.sr))
        # Clamp so a recording shorter than one chunk contributes no samples
        # instead of a negative length.
        return max(0, int(total_steps - (self.chunk_size / self.chunk_step)))

    def get_spectogram(self, audio):
        """Return the dB-scaled (mel or power) spectrogram of ``audio``."""
        if self.use_mel:
            transform = torchaudio.transforms.MelSpectrogram(sample_rate = self.sr,
                                                             n_fft = self.frame_length,
                                                             win_length = self.win_length,
                                                             hop_length = self.hop_length,
                                                             n_mels = self.mel_count,
                                                             window_fn = torch.hamming_window)
            spectogram = transform(audio)
        else:
            window = torch.hamming_window(self.win_length, requires_grad=False)
            spectogram = torchaudio.functional.spectrogram(audio,
                                                            pad = 0,
                                                            n_fft = self.frame_length,
                                                            win_length = self.win_length,
                                                            hop_length = self.hop_length,
                                                            window = window,
                                                            power = 2,
                                                            normalized = False)

        spectogram = torchaudio.functional.amplitude_to_DB(spectogram,
                                                            amin = 1e-10,
                                                            multiplier = 10.0,
                                                            db_multiplier = 0)

        return spectogram

    # Get the length
    def __len__(self):
        """Total number of positive plus negative chunks."""
        return self.pos_audio_len + self.neg_audio_len

    # Getter
    def __getitem__(self, idx):
        """Return ``(spectogram, audio, label)`` for chunk ``idx``.

        ``spectogram`` is a 2-D slice of the normalised spectrogram,
        zero-padded on the time axis to a fixed number of frames.  ``audio`` is
        the matching 1-D waveform, zero-padded to the chunk length, or
        ``torch.tensor([0])`` when ``keep_audio`` is false.  ``label`` is
        ``1.`` for positive chunks and ``0.`` for negative ones.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            # Required so plain iteration over the dataset terminates.
            raise IndexError(f'index {idx} is out of range for {total} chunks')

        if idx >= self.pos_audio_len:
            chunk_idx = idx - self.pos_audio_len
            source_audio, source_spectogram, label = self.neg_audio, self.neg_spectogram, 0.
        else:
            chunk_idx = idx
            source_audio, source_spectogram, label = self.pos_audio, self.pos_spectogram, 1.

        start = math.ceil(chunk_idx * self.chunk_step * self.sr)
        if self.use_wobble:
            # Shift by whole hops so the waveform start stays on a frame boundary offset.
            start += random.randint(0, self.wobble) * self.hop_length

        chunk_samples = math.ceil(self.chunk_size * self.sr)
        spec_length = math.floor((self.chunk_size * self.sr) / self.hop_length)

        start_spectogram = math.floor(start / self.hop_length)
        end_spectogram = start_spectogram + spec_length
        spectogram = source_spectogram[:,:,start_spectogram:end_spectogram].squeeze(0)
        if spectogram.shape[1] < spec_length:
            # Chunk runs past the end of the recording: pad the time axis with zeros.
            pad = torch.zeros(spectogram.shape[0], spec_length)
            pad[:,:spectogram.shape[1]] = spectogram
            spectogram = pad

        if self.keep_audio:
            audio = source_audio[:,start:start + chunk_samples].squeeze(0)
            if audio.shape[0] < chunk_samples:
                pad = torch.zeros(chunk_samples)
                pad[:audio.shape[0]] = audio
                audio = pad
        else:
            # Placeholder so every sample still has three collatable fields.
            audio = torch.tensor([0])

        return spectogram, audio, label

In [None]:
# model class
class LSTM(torch.nn.Module):
    """Stacked LSTM classifier that scores a sequence from its final time step.

    The last LSTM output is batch-normalised, projected to ``num_outputs``
    values by a linear layer and squashed into (0, 1) with a sigmoid.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_outputs, dropout = 0.2):
        """Build the layers.

        Args:
            input_size: Number of features per time step.
            hidden_size: Width of each LSTM layer.
            num_layers: Number of stacked LSTM layers.
            num_outputs: Number of sigmoid outputs.
            dropout: Dropout passed to ``torch.nn.LSTM``.
        """
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # x -> batch_size, sequence_length, input_size
        self.lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.batchnorm = torch.nn.BatchNorm1d(hidden_size)
        self.fc = torch.nn.Linear(hidden_size, num_outputs)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid scores of shape (batch_size, num_outputs) for ``x``.

        ``x`` is laid out as (batch_size, sequence_length, input_size).
        """
        # No explicit (h0, c0): torch.nn.LSTM starts from zeros on the input's
        # own device when no state is given, so the module no longer depends on
        # a global device constant.
        out, _ = self.lstm(x)
        # out: batch_size, sequence_length, hidden_size
        out = out[:, -1, :]
        # out: batch_size, hidden_size i.e. take the last in the sequence only
        out = self.batchnorm(out)
        out = self.fc(out)
        return self.sigmoid(out)

In [None]:
# hyperparameters
# --- audio chunking ---
chunk_size = 3 # seconds
chunk_step = 1/40 # seconds
use_wobble = True
wobble = 1 # wobbles the dataset by 0 or 1 hop randomly (no point increasing this unless chunk_step is reduced)

# --- spectrogram ---
sr = 22050 # Hz
frame_length = 1024 # samples
win_length = 1024 # samples
hop_length = 512 # samples
use_mel = True
mel_count = 128

# --- model shape ---
# Features per time step: one per mel band, or one per one-sided FFT bin
# (frame_length // 2 + 1, which is also correct for an odd frame_length).
input_size = mel_count if use_mel else frame_length // 2 + 1
hidden_size = 1024
# Spectrogram frames per chunk; must match the frame count the dataset emits.
sequence_length = math.floor((chunk_size * sr) / hop_length)
num_layers = 2
num_outputs = 1

# --- training ---
num_epochs = 100
batch_size = 128
learning_rate = 0.0001
dropout = 0.15
threshold = 0.99 # a score must exceed this to count as a positive prediction
model_file = 'model.pth'

In [None]:
# model, loss function and optimiser
model = LSTM(input_size, hidden_size, num_layers, num_outputs, dropout).to(DEVICE)
# Resume from a previous checkpoint when one exists; otherwise start fresh.
if os.path.isfile(model_file):
    try:
        # map_location lets a checkpoint saved on another device load onto DEVICE.
        state_dict = torch.load(model_file, map_location=DEVICE)
        model.load_state_dict(state_dict)
    except Exception as err:
        # Best effort: keep the freshly initialised weights, but say why the load failed.
        print('Failed to load saved model')
        print(f'    {type(err).__name__}: {err}')
print(model)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# training and test sets
# Drop any datasets left over from a previous run of this cell first, which
# frees up some memory before the new ones are loaded.
train_dataset = None
test_dataset = None

# Both splits are built with exactly the same settings; only the split name differs.
shared_dataset_options = dict(chunk_size = chunk_size,
                              chunk_step = chunk_step,
                              sr = sr,
                              frame_length = frame_length,
                              win_length = win_length,
                              hop_length = hop_length,
                              use_mel = use_mel,
                              mel_count = mel_count,
                              keep_audio = True,
                              use_wobble = use_wobble,
                              wobble = wobble)

train_dataset = AudioDataset('Data', dataset = 'train', **shared_dataset_options)

test_dataset = AudioDataset('Data', dataset = 'test', **shared_dataset_options)

In [None]:
# Shuffled loaders for both splits; the test loader is shuffled too, so each
# periodic check during training draws a different batch.
train_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=True)
    for split in (train_dataset, test_dataset)
)

In [None]:
# training
def _show_examples(heading, indices, audio, spectograms, outputs):
    """Print the score, plot the spectrogram and play the audio of each indexed sample."""
    for idx in indices:
        row = int(idx[0])  # first column of a nonzero() row is the batch position
        print(heading)
        print(f'Score: {outputs[row].item()}')
        spect = spectograms[row].squeeze(0).to('cpu')
        plt.imshow(spect, cmap='rainbow_r')
        plt.show()
        clip = audio[row]
        # Only offer playback when a waveform tensor is actually available.
        if torch.is_tensor(clip):
            IPython.display.display(Audio(data=clip, rate=sr))


def analyse_predictions(title, predictions, labels, audio, spectograms, outputs,
                        show_false_negatives = False, show_false_positives = True):
    """Print accuracy and confusion counts for one batch and optionally show the errors.

    Args:
        title: Heading printed before the statistics.
        predictions: Thresholded 0/1 predictions, same shape as ``labels``.
        labels: Ground-truth 0/1 labels.
        audio: Per-sample waveforms, indexed by batch position.
        spectograms: Per-sample spectrograms, indexed by batch position.
        outputs: Raw model scores, indexed by batch position.
        show_false_negatives: Plot and play every false negative.
        show_false_positives: Plot and play every false positive.
    """
    matches = predictions == labels
    misses = predictions != labels
    negatives = 1 - labels  # 1 where the label is 0, 0 where it is 1

    true_pos = matches * labels
    true_neg = matches * negatives
    false_pos = misses * negatives
    false_neg = misses * labels

    accuracy = torch.sum(matches) / labels.shape[0]
    print(title)
    print(f'Accuracy: {accuracy}')
    print(f'True Positive: {torch.sum(true_pos)}')
    print(f'True Negative: {torch.sum(true_neg)}')
    print(f'False Positive: {torch.sum(false_pos)}')
    if show_false_positives:
        _show_examples('False Positive Example:', false_pos.nonzero(as_tuple=False),
                       audio, spectograms, outputs)
    print(f'False Negative: {torch.sum(false_neg)}')
    if show_false_negatives:
        _show_examples('False Negative Example:', false_neg.nonzero(as_tuple=False),
                       audio, spectograms, outputs)

model.train()
n_total_steps = len(train_loader)
for epoch in range(num_epochs):
    for i, (spectograms, audio, labels) in enumerate(train_loader):
        spectograms = spectograms.reshape(-1, sequence_length, input_size).to(DEVICE)
        labels = torch.unsqueeze(labels,1).float().to(DEVICE)

        # Forward pass
        outputs = model(spectograms)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Progress report every 100 steps.
        if (i+1) % 100 != 0:
            continue
        print('=======================================================================================')
        print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.12f}')

        # Full check and checkpoint every 1000 steps.
        if (i+1) % 1000 != 0:
            continue

        # How are we doing against one batch of the training set and one of the test set?
        model.eval()
        with torch.no_grad():  # evaluation only: no autograd graph needed
            for eval_title, eval_loader, eval_show_false_negatives in (('TRAIN SET', train_loader, True),
                                                                      ('TEST SET', test_loader, False)):
                (eval_plots, eval_audio, eval_labels) = next(iter(eval_loader))
                eval_inputs = eval_plots.reshape(-1, sequence_length, input_size).to(DEVICE)
                eval_labels = torch.unsqueeze(eval_labels,1).float().to(DEVICE)
                eval_outputs = model(eval_inputs)
                eval_predictions = (eval_outputs>threshold).float()
                analyse_predictions(eval_title, eval_predictions, eval_labels, eval_audio, eval_plots,
                                    eval_outputs,
                                    show_false_positives = True,
                                    show_false_negatives = eval_show_false_negatives)
        model.train()

        # save model state
        torch.save(model.state_dict(), model_file)