# Training pipeline for MFCC (train + save)

In [1]:
import sys
import os
from git_root import git_root

import numpy as np
import tensorflow as tf

from tensorflow.keras.optimizers import Adam

sys.path.append(git_root("utils"))
from utils import load_params

from fetch_data import fetch_data_local
from fetch_data import to_numpy_arrays, prepare_tf_dataset

sys.path.append(git_root("models"))
from MFCC_model import setup_model

from validation_utils import plot_history

In [2]:
# Record the TensorFlow version the notebook was executed with.
print(tf.__version__)

2.0.0


In [3]:
# Load the project parameters through the shared `load_params` helper from utils.
# NOTE(review): `params` is not referenced again in the visible cells — confirm
# whether hyper-parameters below (learning rate, epochs) should come from it.
params = load_params()

In [69]:
# Fetch both MFCC splits from local storage: the training split first,
# then the test split (same call order as the log lines printed below).
train, test = (
    fetch_data_local(map_type="mfcc", train=is_train)
    for is_train in (True, False)
)

Fetching: data_mfcc_train.json
Fetching: data_mfcc_test.json


In [70]:
# Sanity check: show the shape of each fetched split, label on one line
# and shape on the next.
print("train", train.shape, sep="\n")
print("test", test.shape, sep="\n")

train
(900, 3)
test
(100, 3)


In [71]:
# Length of the entry stored in the second column of the first training row.
len(train.iloc[0, 1])

30

In [72]:
# Convert each fetched split with `to_numpy_arrays` (MFCC mode).
# Both names are rebound to the converted result, so re-running this cell
# without re-running the fetch cell would feed converted data back in.
train, test = (
    to_numpy_arrays(split, mfcc=True)
    for split in (train, test)
)

In [73]:
# Sanity check: the first two components of each converted split should
# have matching lengths.
print("train", len(train[0]), len(train[1]), sep="\n")
print("test", len(test[0]), len(test[1]), sep="\n")

train
900
900
test
100
100


In [74]:
# Wrap the first two components of each split with `prepare_tf_dataset`.
# Later cells iterate the results to obtain (sample_batch, label_batch) pairs.
train, test = (
    prepare_tf_dataset(train[0], train[1]),
    prepare_tf_dataset(test[0], test[1]),
)

In [75]:
# Pull the first batch out of each dataset and show the sample/label shapes.
(tr_sample_batch, tr_label_batch), (te_sample_batch, te_label_batch) = (
    next(iter(train)),
    next(iter(test)),
)
print("train", f"{tr_sample_batch.shape} {tr_label_batch.shape}", sep="\n")
print("test", f"{te_sample_batch.shape} {te_label_batch.shape}", sep="\n")

train
(32, 40, 50, 30) (32,)
test
(32, 40, 50, 30) (32,)


In [168]:
# Build the MFCC network and display its layer summary.
# `setup_model` is imported directly from MFCC_model in the imports cell;
# no `mfcc` module alias is bound in this notebook, so it must be called
# by its bare name for the notebook to survive Restart & Run All.
net = setup_model()
net.summary()

Model: "sequential_46"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv1 (Conv2D)               (None, 31, 41, 12)        36012     
_________________________________________________________________
pooling1 (AveragePooling2D)  (None, 15, 20, 12)        0         
_________________________________________________________________
affine (AffineScalar)        (None, 15, 20, 12)        2         
_________________________________________________________________
conv2 (Conv2D)               (None, 13, 18, 12)        1308      
_________________________________________________________________
pooling2 (GlobalAveragePooli (None, 12)                0         
_________________________________________________________________
dense (Dense)                (None, 10)                130       
Total params: 37,452
Trainable params: 37,452
Non-trainable params: 0
_________________________________________________

In [169]:
# Configure training: Adam at a 1e-3 learning rate, sparse categorical
# cross-entropy as the loss, and accuracy as the tracked metric.
net.compile(
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
    optimizer=Adam(learning_rate=0.001),
)

In [None]:
# Train the network and keep the returned History object for plotting/reporting.
# Since the dataset already takes care of batching,
# we don't pass a `batch_size` argument.
# The test dataset is passed as validation data, so the reported `val_*`
# metrics are computed on the test split.
# NOTE(review): the recorded output below shows "Epoch N/500", i.e. it comes
# from a run with a different `epochs` value than the 30 set here — re-run
# the cell to refresh the output, or confirm which value is intended.
history = net.fit(train, epochs=30, validation_data=test)

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
Epoch 53/500
Epoch 54/500
Epoch 55/500
Epoch 56/500
Epoch 57/500
Epoch 58/500
Epoch 59/500
Epoch 60/500
Epoch 61/500
Epoch 62/500
Epoch 63/500
Epoch 64/500
Epoch 65/500
Epoch 66/500
Epoch 67/500
Epoch 68/500
Epoch 69/500
Epoch 70/500
Epoch 71/500
Epoch 72/500
Epoch 73/500
Epoch 74/500
Epoch 75/500
Epoch 76/500
Epoch 77/500
Epoch 78

In [None]:
# Visualise the training run using the `plot_history` helper from validation_utils.
plot_history(history)

In [None]:
# Report the validation accuracy recorded for the last epoch
# (the final entry, not necessarily the best one), to three decimals.
print(f"Final validation accuracy is: {history.history['val_accuracy'][-1]:.3f}")

In [None]:
# Safety switch for persisting the trained model.
# It is defined here explicitly so the cell runs on a fresh kernel;
# set it to True to write (and overwrite) models/saved_models/mfcc.h5.
overwrite = False

if overwrite:
    try:
        file_path = git_root("models", "saved_models", "mfcc.h5")
        net.save(file_path)
    except Exception as err:
        # Saving stays best-effort, but the underlying cause is reported and
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        print("Runtime Error: you might want to check the saved model.")
        print(f"Cause: {err!r}")
else:
    print("CAN'T OVERWRITE.")