In [None]:
from data_pipeline_v4 import DataGen
import tensorflow as tf
import pandas as pd
import numpy as np
from sklearn.utils import class_weight
import os
from datetime import datetime

Set up the dataset from our generator:

In [None]:
# Number of epochs handed to model.fit(epochs=...) in the training cell.
num_epochs = 5
# True  -> reuse the last (alphabetically sorted) entry under ./training/ as the run folder.
# False -> start a new ./training/training_<timestamp> run folder.
load_latest_model = True

In [None]:
def _fixup_shape(x, y):
  x.set_shape([None, 259, 128]) # n, h, w, c
  y.set_shape([None]) # n, nb_classes
  return x, y

batch_size = 64
split_seed = 42  # fixed so train/test membership is identical on every run
tracks = pd.read_csv('./data/processed_genres_mel.csv')

# Parse filepaths
track_fpaths = list(tracks['fpath'])
track_fpaths = ['./data/fma_medium' + fpath for fpath in track_fpaths]
track_labels = list(tracks['parent_genre_id'])

# Set up generator processing function
gen = DataGen()

# Kept for reference: the full (paths, labels) pair before splitting
data = (track_fpaths, track_labels)

# Define the split ratio for train/test datasets
num_train_samples = int(0.8 * len(track_fpaths))
num_test_samples = len(track_fpaths) - num_train_samples

# Split ONCE, on plain Python indices, before any tf.data pipeline exists.
# Calling take()/skip() on a repeated + reshuffling dataset does not give
# disjoint subsets: every new iteration reshuffles, so the "train" and "test"
# branches would both end up drawing from the whole track list.
shuffled_idx = np.random.default_rng(split_seed).permutation(len(track_fpaths))
train_idx = shuffled_idx[:num_train_samples]
test_idx = shuffled_idx[num_train_samples:]


def _make_dataset(indices, shuffle):
  """Build an endlessly repeating, batched dataset over a subset of tracks.

  Args:
    indices: positions into ``track_fpaths`` / ``track_labels`` to include.
    shuffle: if True, reshuffle the (path, label) pairs on every pass and let
      the parallel map return samples out of order; if False, keep a fixed
      order so repeated evaluation sees the same batches.

  Returns:
    A ``tf.data.Dataset`` yielding ``(features, labels)`` batches of
    ``batch_size`` elements with static shapes set by ``_fixup_shape``.
  """
  fpaths = [track_fpaths[i] for i in indices]
  labels = [track_labels[i] for i in indices]
  ds = tf.data.Dataset.from_tensor_slices((fpaths, labels))
  if shuffle:
    # Shuffle the lightweight (path, label) pairs, not the decoded samples,
    # so the shuffle buffer never has to hold the loaded features.
    ds = ds.shuffle(buffer_size=len(fpaths), reshuffle_each_iteration=True)
  ds = ds.repeat()
  ds = ds.map(lambda fpath, label: tuple(tf.py_function(gen.get_sample, [fpath, label], [tf.float32, tf.int32])),
              num_parallel_calls=tf.data.AUTOTUNE, deterministic=not shuffle)
  return ds.batch(batch_size).map(_fixup_shape).prefetch(tf.data.AUTOTUNE)


# Disjoint train and test pipelines
train_dataset = _make_dataset(train_idx, shuffle=True)
test_dataset = _make_dataset(test_idx, shuffle=False)

Compute the class weights for balancing:

In [None]:
genres = np.array(tracks['parent_genre_id'])
genre_classes = np.unique(genres)
class_weights = class_weight.compute_class_weight(class_weight='balanced',
                                                  classes=genre_classes,
                                                  y=genres)

# compute_class_weight returns one weight per entry of `classes`, in that
# order. Key each weight by its actual label value rather than by its position:
# the two only coincide when the labels are exactly 0..K-1, and a positional
# key would silently attach weights to the wrong classes otherwise.
class_weights = {int(label): float(weight)
                 for label, weight in zip(genre_classes, class_weights)}

class_weights

Add a checkpoint callback that saves the model weights every two epochs:

In [None]:
# Get the current saving/loading folder
training_root = './training/'
os.makedirs(training_root, exist_ok=True)  # a fresh checkout has no training folder yet

save_dir = None
if load_latest_model:
    # Only real, non-hidden directories count as previous runs; stray files or
    # folders such as .ipynb_checkpoints must not be picked as the "latest" run.
    previous_runs = sorted(
        entry for entry in os.listdir(training_root)
        if not entry.startswith('.') and os.path.isdir(os.path.join(training_root, entry))
    )
    if previous_runs:
        # Run folders are named training_<timestamp>, so the last one in
        # sorted order is the most recent.
        save_dir = training_root + previous_runs[-1]
    else:
        print('No previous run found under ./training/ - starting a new one.')

if save_dir is None:
    dt_now = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    save_dir = f'./training/training_{dt_now}'

# Make sure the run folder exists before any callback tries to write into it
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Checkpoint file names embed the epoch number, zero-padded to four digits;
# the {epoch:04d} placeholder is filled in by Keras at save time.
checkpoint_path = "/".join((save_dir, "cp-{epoch:04d}.ckpt"))
checkpoint_dir = os.path.dirname(checkpoint_path)

# Number of whole batches in one pass over the training split
n_batches = num_train_samples // batch_size

# Weights-only checkpoint, written once every 2 * n_batches training batches
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_freq=n_batches * 2,
    save_weights_only=True,
)

And one for backups so we can continue training if interrupted:

In [None]:
# Backup state is written at the end of every epoch and kept after training
# finishes (delete_checkpoint=False).
# NOTE(review): backup_dir is the same folder the checkpoint and history
# callbacks write to; the Keras docs ask for a directory that is not shared
# with other callbacks - consider a dedicated sub-folder. TODO confirm.
backup_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir=save_dir,
    delete_checkpoint=False,
    save_freq="epoch",
)

And one for history logging:

In [None]:
# Per-epoch metrics go to history.csv inside the run folder; append=True adds
# to an existing file instead of overwriting it.
csv_logger = tf.keras.callbacks.CSVLogger(
    filename=save_dir + '/history.csv',
    append=True,
)

Build and train the model:

In [None]:
# Architecture: one bidirectional LSTM followed by a 16-unit linear layer.
# Inputs have shape (259, 128) per sample - presumably time frames x mel
# bands; TODO confirm against DataGen.
model = tf.keras.models.Sequential()
model.add(
    tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(259, dropout=0.2, recurrent_dropout=0.2),
        input_shape=(259, 128),
    )
)
# No activation here: the layer emits raw logits, matching from_logits=True below.
model.add(tf.keras.layers.Dense(16))

model.summary()

model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["acc"],
)

# Both datasets repeat forever, so the number of batches per epoch has to be
# given explicitly for training and for validation.
history = model.fit(
    x=train_dataset,
    validation_data=test_dataset,
    epochs=num_epochs,
    steps_per_epoch=num_train_samples // batch_size,
    validation_steps=num_test_samples // batch_size,
    class_weight=class_weights,
    callbacks=[backup_callback, cp_callback, csv_logger],
)