In [None]:
import os
import numpy as np
import tensorflow as tf
import time
from datetime import timedelta

from temporalcontext import settings
from temporalcontext.functions import read_selmap, read_folds_info, \
    get_inputs_for_lstm, get_lstm_model_filename


# Choose which hybrid variant to train: leave exactly one of the three
# assignments below uncommented. The chosen tag is passed to the input loader
# and used to build the name of the saved model file.
#lstm_type = 't1'   # Uses only scores from CNN
#lstm_type = 't2'   # Uses only embeddings from CNN
lstm_type = 't3'   # Uses scores & embeddings (concatenated) from the CNN

# Location of the "secondary" annotations, built from the project settings.
# NOTE: Set this to None if there were no "secondary" annotations in the dataset
secondary_annots_path = os.path.join(settings.raw_data_root, settings.added_annot_dir)


def variable_learning_rate(epoch):
    """Return the learning rate to use for a given training epoch.

    Implements a step schedule: the base rate of 0.001 applies to every epoch
    index up to and including 35; from index 36 onwards the rate drops to one
    third of the base rate.

    Args:
        epoch: Epoch index to look up. This is the value handed in by the
            Keras ``LearningRateScheduler`` callback (presumably 0-based --
            confirm against the Keras docs), or 0 when asking for the
            initial rate.

    Returns:
        The learning rate (a float) for ``epoch``.
    """
    base_rate = 0.001
    stage_scales = (1.0, 1/3)
    # Last epoch index (inclusive) of each stage except the final one.
    stage_last_epochs = (35,)

    stage_rates = [base_rate * scale for scale in stage_scales]

    # First stage whose last epoch has not been passed yet; once every
    # boundary is behind us, the final rate applies.
    return next(
        (rate for rate, last_epoch in zip(stage_rates, stage_last_epochs)
         if epoch <= last_epoch),
        stage_rates[-1])


In [None]:
# Load the file selection map and the per-fold file-index assignments. Each
# fold's 'train' indices are looked up in `selmap` below to pick its inputs.
selmap = read_selmap(os.path.join(settings.raw_data_root, 'selmap.csv'))
fold_file_idxs = read_folds_info(os.path.join(settings.raw_data_root, 'folds_info.txt'))

# Wall-clock training durations in seconds, keyed by the experiment's
# 'time_steps' value; one entry is appended per trained fold.
train_times = dict()

for lstm_exp in settings.lstm_experiments:

    # Re-created with the same seed for every experiment, so each experiment
    # consumes an identical sequence of shuffles across its folds.
    random_state = np.random.RandomState(settings.random_seed)

    for fold_idx, fold_info in enumerate(fold_file_idxs):

        print('---------- Type {:s}, segment_advance={:.2f}, PP={:d}, Fold {:02d} ----------'.format(
            lstm_type, lstm_exp['segment_advance'], lstm_exp['pp'], fold_idx + 1))

        # Per-fold, per-segment-advance directory holding both the LSTM inputs
        # and the trained models.
        fold_seg_root = os.path.join(settings.project_root, settings.folds_dir,
                                     'f{:02d}'.format(fold_idx + 1),
                                     'seg_adv_{:.2f}'.format(lstm_exp['segment_advance']))

        input_root = os.path.join(fold_seg_root, settings.lstm_data_dir)
        model_dir = os.path.join(fold_seg_root, settings.models_dir)

        # Bundle of parameters describing the secondary annotations, or None
        # when the dataset has none.
        secondary_annots_info = None if secondary_annots_path is None else (
            secondary_annots_path,
            settings.annot_duration,
            settings.segment_length,
            lstm_exp['segment_advance'])

        # Get data
        song_data_x, song_data_y, nonsong_data_x, nonsong_data_y = \
            get_inputs_for_lstm(
                lstm_type,
                input_root,
                [selmap[f_idx][0] for f_idx in fold_info['train']],
                lstm_exp['time_steps'],
                lstm_exp['pp'],
                settings.section_suffixes,
                secondary_annots_info)

        # Song samples are labelled positive (== 1) or negative (anything
        # else); every non-song sample is used as a negative.
        pos_samples_idxs = np.where(song_data_y == 1)[0]
        neg_samples_idxs = np.where(song_data_y != 1)[0]
        num_nonsong_samples = nonsong_data_y.shape[0]
        nonsong_samples_idxs = np.arange(num_nonsong_samples)

        print('All available samples: {:6d} song ({:6d} Pos, {:6d} Neg), {:6d} non-song ({:6d} Other)'.format(
            song_data_x.shape[0], len(pos_samples_idxs), len(neg_samples_idxs),
            nonsong_data_x.shape[0], num_nonsong_samples))

        # Shuffle the indices for random splitting into train & eval subsets.
        # NOTE: the order of these calls determines the random stream; keep it.
        random_state.shuffle(pos_samples_idxs)
        random_state.shuffle(neg_samples_idxs)
        random_state.shuffle(nonsong_samples_idxs)

        # Restrict to limits. Non-song samples are usually well below limit, so take all as neg.
        # The budget left for song negatives is clamped at zero: a negative
        # slice bound would drop samples from the END of the array (keeping
        # almost all of them) instead of selecting none.
        pos_samples_idxs = pos_samples_idxs[:settings.max_per_class_training_samples]
        song_neg_budget = max(0, settings.max_per_class_training_samples - num_nonsong_samples)
        neg_samples_idxs = neg_samples_idxs[:song_neg_budget]

        # Update counts
        num_pos_samples = len(pos_samples_idxs)
        num_neg_samples = len(neg_samples_idxs)

        # Identify the points where the song-pos/song-neg/nonsong-neg groups are to be split
        pos_eval_split = int(round((1.0 - settings.validation_split) * num_pos_samples))
        neg_eval_split = int(round((1.0 - settings.validation_split) * num_neg_samples))
        non_song_eval_split = int(round((1.0 - settings.validation_split) * num_nonsong_samples))

        num_combined_samples = num_pos_samples + num_neg_samples + num_nonsong_samples
        num_combined_train_samples = (pos_eval_split + neg_eval_split + non_song_eval_split)
        num_combined_eval_samples = num_combined_samples - num_combined_train_samples
        num_neg_train_samples = neg_eval_split + non_song_eval_split

        print('Pos samples  : {:7d} training, {:5d} eval'.format(
            pos_eval_split, num_pos_samples - pos_eval_split))
        print('Neg samples  : {:7d} ({:7d} + {:6d}) training, {:5d} ({:5d} + {:5d}) eval'.format(
            num_neg_train_samples, neg_eval_split, non_song_eval_split,
            (num_neg_samples - neg_eval_split) + (num_nonsong_samples - non_song_eval_split),
            num_neg_samples - neg_eval_split, num_nonsong_samples - non_song_eval_split))
        print('Total samples: {:7d} training, {:5d} eval'.format(
            num_combined_train_samples, num_combined_eval_samples))
        print('Input shape: {}'.format(song_data_x.shape[1:]))

        # Both the class weights and the output-bias initialisation below
        # divide by these counts; fail with a clear message instead of a bare
        # ZeroDivisionError.
        if pos_eval_split == 0 or num_neg_train_samples == 0:
            raise ValueError(
                'Fold {:02d}: need at least one positive and one negative training sample '
                '(got {:d} positive, {:d} negative)'.format(
                    fold_idx + 1, pos_eval_split, num_neg_train_samples))

        # Number of batches needed to cover the eval subset once. A plain
        # Python int is used because the `np.int` alias no longer exists in
        # current NumPy releases.
        val_steps = int(np.ceil(num_combined_eval_samples / settings.batch_size))

        # Combine the song-pos/song-neg groups to form train & eval subsets
        song_train_idxs = np.concatenate([pos_samples_idxs[:pos_eval_split], neg_samples_idxs[:neg_eval_split]])
        song_eval_idxs = np.concatenate([pos_samples_idxs[pos_eval_split:], neg_samples_idxs[neg_eval_split:]])
        nonsong_train_idxs = nonsong_samples_idxs[:non_song_eval_split]
        nonsong_eval_idxs = nonsong_samples_idxs[non_song_eval_split:]
        del pos_samples_idxs, neg_samples_idxs, nonsong_samples_idxs

        # Shuffle so that all pos & all neg samples don't stay together
        random_state.shuffle(song_train_idxs)
        random_state.shuffle(song_eval_idxs)

        # Combine train/eval splits from song and nonsong sets
        train_data_x = np.concatenate([song_data_x[song_train_idxs, ...], nonsong_data_x[nonsong_train_idxs, ...]], axis=0)
        train_data_y = np.concatenate([song_data_y[song_train_idxs], nonsong_data_y[nonsong_train_idxs]], axis=0)
        eval_data_x = np.concatenate([song_data_x[song_eval_idxs, ...], nonsong_data_x[nonsong_eval_idxs, ...]], axis=0)
        eval_data_y = np.concatenate([song_data_y[song_eval_idxs], nonsong_data_y[nonsong_eval_idxs]], axis=0)

        del song_data_x, song_data_y, nonsong_data_x, nonsong_data_y
        del song_train_idxs, song_eval_idxs, nonsong_train_idxs, nonsong_eval_idxs

        # Weight each class by total / (2 * class_count), so that both classes
        # carry the same total weight in the loss despite unequal counts.
        class_weights = {
            0: num_combined_train_samples / (2 * num_neg_train_samples),
            1: num_combined_train_samples / (2 * pos_eval_split)
        }

        # Start every fold from a fresh Keras state and the same TF seed.
        tf.keras.backend.clear_session()
        tf.random.set_seed(settings.random_seed)

        train_dataset = tf.data.Dataset.from_tensor_slices((train_data_x, train_data_y))
        train_dataset = train_dataset.cache().shuffle(settings.buffer_size).batch(settings.batch_size)

        # The eval set repeats indefinitely; `validation_steps` passed to
        # fit() bounds how many batches are drawn per evaluation.
        eval_dataset = tf.data.Dataset.from_tensor_slices((eval_data_x, eval_data_y))
        eval_dataset = eval_dataset.batch(settings.batch_size).repeat()

        # Construct LSTM network. The output bias starts at log(pos/neg), i.e.
        # the log-odds of the positive class, so the initial sigmoid output
        # equals the positive fraction of the training set.
        lstm_model = tf.keras.models.Sequential([
            tf.keras.layers.LSTM(32, return_sequences=True, input_shape=train_data_x.shape[1:], dropout=0.05),
            tf.keras.layers.LSTM(16, dropout=0.05),
            tf.keras.layers.Dense(1, activation='sigmoid',
                                  bias_initializer=tf.keras.initializers.Constant(
                                      np.log([pos_eval_split/num_neg_train_samples])))
        ])

        lstm_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=variable_learning_rate(0)),
                           loss='binary_crossentropy',
                           metrics=[tf.keras.metrics.BinaryAccuracy()])

        # Train. NOTE(review): per the Keras docs `shuffle` is ignored for
        # tf.data inputs; shuffling comes from the dataset pipeline above.
        start_time = time.time()
        history = lstm_model.fit(
            x=train_dataset,
            epochs=settings.epochs,
            validation_data=eval_dataset,
            validation_freq=settings.epochs_between_evals,
            validation_steps=val_steps,
            class_weight=class_weights,
            initial_epoch=0,
            shuffle=True,
            callbacks=[tf.keras.callbacks.LearningRateScheduler(variable_learning_rate)],
            verbose=2)
        end_time = time.time()

        # Save trained model. Reset metrics before saving
        lstm_model.reset_metrics()
        output_lstm_model_filename = get_lstm_model_filename(
            lstm_type, lstm_exp['time_steps'], lstm_exp['pp'])
        # Make sure the target directory exists so a finished training run is
        # not lost to a missing folder.
        os.makedirs(model_dir, exist_ok=True)
        lstm_model.save(os.path.join(model_dir, output_lstm_model_filename))

        # Free up some memory before next iteration
        del lstm_model, train_dataset, eval_dataset, train_data_x, train_data_y, eval_data_x, eval_data_y

        # Record this fold's duration under the experiment's time-steps value.
        curr_training_time = end_time - start_time
        train_times.setdefault(lstm_exp['time_steps'], []).append(curr_training_time)
        print('Training time: {}'.format(timedelta(seconds=curr_training_time)))

# Final report: for every time-steps setting, show the fastest, slowest and
# mean per-fold training duration collected in `train_times`.
print(
    '',
    '================================================================================',
    'Training times for hybrid type "{:s}":'.format(lstm_type),
    '         : [Min., Max., Avg.]',
    sep='\n')

summary_row = ' {:s} TS{:3d}: [{}, {}, {}]'
for num_time_steps in train_times:
    fold_durations = train_times[num_time_steps]
    mean_duration = sum(fold_durations) / len(fold_durations)
    # Durations are stored in seconds; timedelta renders them as H:MM:SS.
    shortest, longest, average = (
        timedelta(seconds=secs)
        for secs in (min(fold_durations), max(fold_durations), mean_duration))
    print(summary_row.format(lstm_type, num_time_steps, shortest, longest, average))