# Train MST model

In [18]:
%load_ext autoreload
%autoreload 2

from keras import optimizers
from keras.callbacks import CSVLogger
import numpy as np
import os
import librosa
import glob
import sys

sys.path.insert(0,'../..')
from sed_endtoend.mst.model import MST
from sed_endtoend.data_generator import DataGenerator, Scaler

from params import *

# Expose only CUDA device index 1 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# File-loading options.
Nfiles = None            # cap on the number of files per split; None keeps every file
load_subset = Nfiles     # limit used below to slice the train/validation file lists
resume = False           # flag is defined here but not read in the visible cells

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [20]:
# Keyword arguments shared by both DataGenerator instances; every value is a
# name brought in by the star-import of `params`.
params = {'sequence_time': sequence_time, 'sequence_hop_time': sequence_hop_time,
          'label_list': label_list, 'audio_hop': audio_hop, 'audio_win': audio_win,
          'n_fft': n_fft, 'sr': sr, 'mel_bands': mel_bands, 'normalize': normalize_data,
          'frames': frames, 'get_annotations': get_annotations, 'dataset': dataset}

# Number of audio hops that fit in one sequence, rounded up.
sequence_frames = int(np.ceil(sequence_time*sr/audio_hop))

# Datasets
train_files = sorted(glob.glob(os.path.join(audio_folder, 'train', '*.wav')))
val_files = sorted(glob.glob(os.path.join(audio_folder, 'validate', '*.wav')))

# Optionally keep only the first `load_subset` files of each split (quick runs).
if load_subset is not None:
    train_files = train_files[:load_subset]
    val_files = val_files[:load_subset]

# Not read anywhere in the visible cells; kept so that any later cell that
# expects these names keeps working.
train_labels = {}
train_mel = {}
val_labels = {}
val_mel = {}


def _label_path(wav_path, split):
    """Return the annotation file path that belongs to `wav_path`.

    The annotation has the same base name as the audio file, a '.txt'
    extension, and lives under `label_folder/<split>`.

    `os.path.splitext` is used instead of `str.replace('.wav', '.txt')` so that
    only the final extension is swapped: a '.wav' substring inside the name is
    left alone and an upper-case '.WAV' extension is still replaced.
    """
    stem, _ext = os.path.splitext(os.path.basename(wav_path))
    return os.path.join(label_folder, split, stem + '.txt')


# NOTE(review): the message below says "scaler", but this step only builds the
# audio-path -> annotation-path index; the scaler is fitted at the end of the cell.
print('Founding scaler')
labels = {}  # audio file path -> annotation file path, for both splits
for wav_path in train_files:
    labels[wav_path] = _label_path(wav_path, 'train')
for wav_path in val_files:
    labels[wav_path] = _label_path(wav_path, 'validate')

# Generators
print('Making training generator')
training_generator = DataGenerator(train_files, labels, **params)

# From here on `params` carries a hop equal to the sequence length, so the
# validation generator is built with that hop.
params['sequence_hop_time'] = sequence_time # To calculate F1_1s

print('Making validation generator')
validation_generator = DataGenerator(val_files, labels, **params)

print('Getting data')

# return_all() yields four values; only the first and third are used here
# (network inputs and mel targets, judging by how they are passed to fit()).
x_val, _, mel_val, _ = validation_generator.return_all()
x_train, _, mel_train, _ = training_generator.return_all()

# Fit the scaler on the training targets only, then apply the same transform
# to both splits so validation data does not influence the scaling.
scaler = Scaler(normalizer=normalize_data)
scaler.fit(mel_train)

mel_train = scaler.transform(mel_train)
mel_val = scaler.transform(mel_val)

Founding scaler
Making training generator
Making validation generator
Getting data
0.0 %
10.0 %
20.0 %
30.0 %
40.0 %
50.0 %
60.0 %
70.0 %
80.0 %
90.0 %
0.0 %
10.0 %
20.0 %
30.0 %
40.0 %
50.0 %
60.0 %
70.0 %
80.0 %
90.0 %


In [21]:
# Length of axis 1 of the validation inputs.
# NOTE(review): x_val is what gets fed to the model input, whose summary shows
# a waveform-length axis; if so this is a sample count rather than a frame
# count. The value is not used in this cell -- confirm before relying on it.
sequence_frames = x_val.shape[1]
sequence_samples = int(sequence_time*sr)

# Build model

print('\nBuilding model...')

model = MST(mel_bands, sequence_samples, audio_win, audio_hop)

model.summary()

opt = optimizers.Adam(lr=learning_rate)

# Fit model
print('\nFitting model...')

# Make sure the experiment folder exists: both the CSV log and the weight file
# are written into it, and a missing folder would abort the run.
if not os.path.isdir(expfolder):
    os.makedirs(expfolder)

# Per-epoch metrics are written to <expfolder>/training.log.
csv_logger = CSVLogger(os.path.join(expfolder, 'training.log'))

# The network is trained as a regressor onto the scaled mel targets (MSE loss).
model.compile(loss='mean_squared_error', optimizer=opt)

# An explicit validation set is supplied, so no validation_split is passed.
history = model.fit(x=x_train, y=mel_train, batch_size=2*batch_size,
                    epochs=epochs, verbose=fit_verbose,
                    shuffle=True,
                    callbacks=[csv_logger],
                    validation_data=(x_val, mel_val))

# NOTE(review): despite the file name, these are the weights after the final
# epoch; no best-epoch selection happens in this cell.
model.save_weights(os.path.join(expfolder, 'weights_best.hdf5'))


Building model...
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_7 (InputLayer)         (None, 22050, 1)          0         
_________________________________________________________________
conv1d_19 (Conv1D)           (None, 44, 512)           524800    
_________________________________________________________________
batch_normalization_19 (Batc (None, 44, 512)           2048      
_________________________________________________________________
activation_19 (Activation)   (None, 44, 512)           0         
_________________________________________________________________
conv1d_20 (Conv1D)           (None, 44, 256)           393472    
_________________________________________________________________
batch_normalization_20 (Batc (None, 44, 256)           1024      
_________________________________________________________________
activation_20 (Activation)   (None, 44, 256)           0 