# Q3 - Training on the MUSAN dataset

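In this notebook we build the training pipeline for our processed MUSAN features. We split the clips into train / validation / test sets (65 % / 10 % / 25 %), train a three-class SwishNet classifier (noise / music / speech), and evaluate it on the held-out test set.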

In [None]:
import sys
import os
import time
import h5py
import tqdm
import pickle
from pathlib import Path
from IPython.display import clear_output

# in jupyter (lab / notebook), based on notebook path
module_path = str(Path.cwd().parents[0])

if module_path not in sys.path:
    sys.path.append(module_path)

from utils.SwishNet import *
from utils.SGDRScheduler import *

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score

import tensorflow as tf
import tensorflow.keras.backend as K

from tensorflow.keras.optimizers import Adam
from tensorflow.keras import Model
from tensorflow.keras.callbacks import Callback, ModelCheckpoint, EarlyStopping

In [None]:
# Record the TensorFlow version used for this run (useful when reproducing results).
tf.__version__

In [None]:
# Clip length in seconds; selects which derived feature file is used.
CLIP_LEN = 2
root_path = '../data/musan_data_derived_{clip}s.h5'.format(clip=CLIP_LEN)

In [None]:
%%time

# First produce a list of the file names to create splits

with open('../data/f_list_{}s.txt'.format(CLIP_LEN), 'rb') as fp:
    f_list = pickle.load(fp)

In [None]:
idx_list = [*range(len(f_list))]

# 65 % train; the remaining 35 % is split again so that 25/35 of it (25 % overall)
# becomes test and the rest (10 % overall) becomes validation.
train_idx, test_val_idx = train_test_split(
    idx_list, test_size = 0.35, shuffle = True, random_state = 2021
)
val_idx, test_idx = train_test_split(
    test_val_idx, test_size = 25/35, shuffle = True, random_state = 2021
)

In [None]:
class data_gen:
    """Load MFCC features and labels for one data split from an HDF5 file.

    Parameters
    ----------
    file_list : sequence of str
        HDF5 group keys; ``file_list[i]`` must contain ``'mfcc'`` and
        ``'label'`` datasets.
    idx_list : sequence of int
        Positions in ``file_list`` that belong to this split.
    data_path : str
        Path to the HDF5 feature file.
    """

    def __init__(self, file_list, idx_list, data_path):
        self.file_list = file_list
        self.idx_list = idx_list
        self.data_path = data_path

    def chunker(self, lst, n, shuffle):
        """Split ``lst`` into consecutive chunks of at most ``n`` items.

        When ``shuffle`` is true the items are shuffled first with NumPy's
        global RNG. A copy is shuffled so the caller's list (e.g. the split
        indices passed to ``__init__``) is never modified.
        """
        items = list(lst)  # copy: np.random.shuffle works in place
        if shuffle:
            np.random.shuffle(items)
        return [items[i:i + n] for i in range(0, len(items), n)]

    def gen(self, batch_size = 100):
        """Yield ``[features, labels]`` batches of this split in random order.

        The HDF5 file stays open while the generator is being consumed.
        """
        batches = self.chunker(self.idx_list, batch_size, shuffle = True)

        with h5py.File(self.data_path, mode = 'r') as db:
            for batch_indexes in batches:
                keys = [self.file_list[i] for i in batch_indexes]
                batch_features = np.array([db[k]['mfcc'][()] for k in keys])
                batch_labels = np.array([db[k]['label'][()] for k in keys])
                yield [batch_features, batch_labels]

    def build_dataset(self, chunk_size = 100):
        """Read the whole split into memory.

        Returns
        -------
        X : np.ndarray
            Features of all samples, batches stacked with ``np.vstack``.
        Y : np.ndarray
            Labels of all samples, batches joined with ``np.hstack``;
            aligned row-for-row with ``X``.

        Raises
        ------
        ValueError
            If the split contains no samples.
        """
        if len(self.idx_list) == 0:
            raise ValueError("Cannot build a dataset from an empty idx_list")

        st = time.time()
        # Ceil so the final, partial batch is counted (floor printed e.g. "11 / 10").
        n_batches = (len(self.idx_list) + chunk_size - 1) // chunk_size
        X = []
        Y = []
        for i, (x, y) in enumerate(self.gen(batch_size = chunk_size)):
            clear_output(wait = True)
            print('Batch {} / {} read'.format(str(i + 1), str(n_batches)))
            X.append(x)
            Y.append(y)

        print("Dataset built, now converting to numpy array")
        X = np.vstack(X)
        Y = np.hstack(Y)
        et = time.time()

        print("Took {}s to build dataset".format(str(et - st)))

        return X, Y

In [None]:
# Read the training split fully into memory.
train_gen = data_gen(file_list = f_list, idx_list = train_idx, data_path = root_path)
X_train, y_train = train_gen.build_dataset(chunk_size = 100)

In [None]:
print(X_train.shape)
print(y_train.shape)

# Features and labels must stay aligned one-to-one after batching.
assert X_train.shape[0] == y_train.shape[0], "feature/label count mismatch"
# The per-sample input shape below is taken from axes 1 and 2, so features must be 3-D.
assert X_train.ndim == 3, "expected features of shape (samples, dim1, dim2)"

input_shape = (X_train.shape[1], X_train.shape[2])

In [None]:
# Read the validation split fully into memory.
val_gen = data_gen(file_list = f_list, idx_list = val_idx, data_path = root_path)
X_val, y_val = val_gen.build_dataset(chunk_size = 100)

In [None]:
print(X_val.shape)
print(y_val.shape)

# Features and labels must stay aligned, and per-sample shape must match the training set.
assert X_val.shape[0] == y_val.shape[0], "feature/label count mismatch"
assert X_val.shape[1:] == input_shape, "validation feature shape differs from training"

In [None]:
# Read the held-out test split fully into memory.
test_gen = data_gen(file_list = f_list, idx_list = test_idx, data_path = root_path)
X_test, y_test = test_gen.build_dataset(chunk_size = 100)

In [None]:
print(X_test.shape)
print(y_test.shape)

# Features and labels must stay aligned, and per-sample shape must match the training set.
assert X_test.shape[0] == y_test.shape[0], "feature/label count mismatch"
assert X_test.shape[1:] == input_shape, "test feature shape differs from training"

In [None]:
initial_lr = 5e-4

# Three output classes: noise, music, speech.
net = SwishNet(input_shape=input_shape, classes=3)
net.summary()
# Integer class labels -> sparse categorical cross-entropy.
# The metric is named 'acc' (not 'accuracy') so the logged keys are 'acc' / 'val_acc'
# on both TF 1.x and 2.x. With 'accuracy', TF 2.x logs 'accuracy' / 'val_accuracy'.
# The callbacks later in this notebook monitor 'val_acc'.
net.compile(loss='sparse_categorical_crossentropy',
            optimizer=Adam(learning_rate=initial_lr),
            metrics=['acc'])

In [None]:
file_path = "../model_params/model_{}s.h5".format(CLIP_LEN)
if Path(file_path).exists():
    if input('Target path exists... REMOVE? [Y/N] :').lower()=='y':
        os.remove(str(file_path))
# Make sure the checkpoint directory exists before training tries to write to it.
Path(file_path).parent.mkdir(parents = True, exist_ok = True)

BATCH_SIZE = 128

# NOTE: the monitored key must match the accuracy name Keras logs. 'val_acc' vs
# 'val_accuracy' depends on the metric string given to compile() and the TF version;
# confirm against history.history. Higher accuracy is better, hence mode='max'.
checkpoint = ModelCheckpoint(file_path, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
es = EarlyStopping(monitor='val_acc', mode='max', verbose=1, patience=20)
# Peak of the SGDR cycle reuses the optimizer's initial learning rate.
LR_schedule = SGDRScheduler(min_lr=1e-5, max_lr=initial_lr, steps_per_epoch=np.ceil(len(train_idx)/BATCH_SIZE), lr_decay=0.8, cycle_length=5, mult_factor=1.5)

# Inputs are in-memory arrays, so the generator-only options (workers,
# use_multiprocessing, max_queue_size) had no effect and are omitted.
history = net.fit(X_train,
                  y_train,
                  validation_data = (X_val, y_val),
                  epochs=120,
                  batch_size = BATCH_SIZE,
                  verbose=1,
                  shuffle = True,
                  callbacks = [checkpoint, es, LR_schedule])

In [None]:
# Per-epoch training vs. validation loss.
fig, ax = plt.subplots()
ax.plot(history.history['val_loss'], label = "Validation Loss")
ax.plot(history.history['loss'], label = "Training Loss")
ax.set_title("Loss Over Training")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend(loc = "best")

loss_plot_path = Path("../docs/plots/loss_{}s.png".format(CLIP_LEN))
loss_plot_path.parent.mkdir(parents = True, exist_ok = True)
# dpi=1000 yields very large PNGs; lower it if file size matters.
fig.savefig(str(loss_plot_path), dpi = 1000)

In [None]:
# The accuracy key in the history depends on the metric string / TF version
# ('acc' or 'accuracy'); pick whichever was actually logged.
acc_key = 'acc' if 'acc' in history.history else 'accuracy'

fig, ax = plt.subplots()
ax.plot(history.history['val_' + acc_key], label = "Validation Acc")
ax.plot(history.history[acc_key], label = "Training Acc")
ax.set_title("Accuracy Over Training")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
ax.legend(loc = "best")

acc_plot_path = Path("../docs/plots/acc_{}s.png".format(CLIP_LEN))
acc_plot_path.parent.mkdir(parents = True, exist_ok = True)
fig.savefig(str(acc_plot_path), dpi = 1000)

## Evaluation

In [None]:
file_path = "../model_params/model_{}s.h5".format(CLIP_LEN)

# Restore the best checkpoint written by ModelCheckpoint during training.
# Load topologically rather than by_name=True: `net` has the same architecture as
# the saved model, and by_name loading silently skips any layer whose
# auto-generated name differs (e.g. after SwishNet was rebuilt), leaving it untrained.
net.load_weights(file_path)

In [None]:
# Evaluate on the held-out split; with one compiled metric this returns [loss, accuracy].
test_loss, test_acc = net.evaluate(x = X_test, y = y_test, verbose = 1)

In [None]:
# Test-set accuracy returned by net.evaluate.
test_acc

In [None]:
# One row of per-class scores per test sample (presumably softmax outputs -- depends on SwishNet's final layer).
preds = net.predict(x = X_test, verbose = 1)

In [None]:
# Class names in label-id order. NOTE(review): presumably this order matches the
# integer labels stored in the HDF5 file -- confirm against the preprocessing step.
catg = ['noise', 'music', 'speech']
label2idx = {name: pos for pos, name in enumerate(catg)}
idx2label = dict(enumerate(catg))

In [None]:
# Map integer class ids (true labels and highest-scoring predictions) back to names.
y_labels = [idx2label[label_id] for label_id in y_test]
pred_labels = [idx2label[best] for best in np.argmax(preds, axis = 1)]

In [None]:
def sns_acc(y_labels, pred_labels):
    """Speech-vs-non-speech accuracy.

    Collapses the class names to a binary problem ('speech' -> 1, anything
    else -> 0) and returns the accuracy of the binary predictions.
    """
    def is_speech(label):
        return int(label == 'speech')

    truth = [is_speech(label) for label in y_labels]
    guess = [is_speech(label) for label in pred_labels]
    return accuracy_score(truth, guess)

In [None]:
# Binary speech / non-speech accuracy on the test set.
sns_acc_ = sns_acc(y_labels, pred_labels)
sns_acc_

In [None]:
labels = ['noise', 'music', 'speech']
# `labels` is keyword-only in scikit-learn >= 1.0 (positional use raises TypeError).
# normalize='true' divides each row by the number of samples with that true label.
cm = confusion_matrix(y_labels, pred_labels, labels = labels, normalize = 'true')

df_cm = pd.DataFrame(cm, index = labels, columns = labels)

fig, ax = plt.subplots()
sns.heatmap(df_cm, annot = True, ax = ax, cmap = 'summer')
ax.set_title("Normalised Confusion Matrix for {}s Clips".format(CLIP_LEN))
ax.set_xlabel('Predicted')
ax.set_ylabel('True')

cm_plot_path = Path("../docs/plots/cm_{}s.png".format(CLIP_LEN))
cm_plot_path.parent.mkdir(parents = True, exist_ok = True)
fig.savefig(str(cm_plot_path), dpi = 1000)
plt.show()