In [1]:
%matplotlib inline

In [None]:
# -*- coding: utf-8 -*-
""" Training script to input data and train model

This script loads the data and trains the model. The input functions are in
the converter_generator module, the model is in the model module, and the
optimizer helpers are in the optimizer module.

################################################################################
# Author: Weikun Han <weikunhan@gmail.com>
# Create Date: 03/10/2018
# Update:
# Reference: https://github.com/jhetherly/EnglishSpeechUpsampler
################################################################################
"""

import os

import numpy as np
import librosa
import tensorflow as tf
from converter_generator import bitrates_and_waveforms
from converter_generator import get_original_noise_pairs
from converter_generator import next_batch
from converter_generator import random_batch
from model import audio_u_net_dnn
from optimizer import setup_optimizer
from optimizer import learing_rate_scheduling

# Input locations -- edit these to point at your own dataset.
DATASETS_ROOT_DIR = './datasets'
FILE_NAME_LISTS_DIR = os.path.join(DATASETS_ROOT_DIR, 'final_dataset')

# Training hyper-parameters -- edit these to change the training run.
n_epochs = 10                       # number of passes over the training pairs
batch_size = 8                      # number of pairs drawn per batch
initial_learning_rate = 1e-3        # starting learning rate for the scheduler
learning_rate_decay_factor = 1e-2   # decay factor handed to the LR scheduler
n_epochs_per_decay = 2              # epochs per learning-rate decay period

#############
# DATA IMPORT
#############

# Training and validation (original, noise) file pairs read from the dataset
# file lists under FILE_NAME_LISTS_DIR.
train_original_noise_pairs = get_original_noise_pairs(FILE_NAME_LISTS_DIR,
                                                      'train')
val_original_noise_pairs = get_original_noise_pairs(FILE_NAME_LISTS_DIR,
                                                    'validation')

# The code below indexes the first training pair, so fail with a clear message
# rather than an IndexError when the training list is empty.
if len(train_original_noise_pairs) == 0:
    raise ValueError('No training pairs found under {}'.format(
        FILE_NAME_LISTS_DIR))

# Select the first original/noise pair; the first return value is the bit rate
# pair and the second is the waveform pair.
br_pair, wf_pair = bitrates_and_waveforms(train_original_noise_pairs[0])

# Original bit rate and waveform (index 0 of each pair).
original_bitrate = br_pair[0]
original_waveform = wf_pair[0]

# Reshape to a single column for mono waveforms; this sample's dtype/shape is
# used below to build the model input.
original_waveform = original_waveform.reshape((-1, 1))

# Number of training samples per epoch.
sample_per_epoch = len(train_original_noise_pairs)

print('Number of epochs: {}'.format(n_epochs))
print('Samples per epoch: {}'.format(sample_per_epoch))
print('Batch size: {}'.format(batch_size))
print('-------------------------Processing training---------------------------')

##################
# MODEL DEFINITION
##################

# Build the network from the dtype and shape of the reshaped sample waveform.
# `train_flag` is fed True while training and False during validation.
sample_dtype = original_waveform.dtype
sample_shape = original_waveform.shape
train_flag, x, y_pred = audio_u_net_dnn(sample_dtype, sample_shape)

# Placeholder for the true waveform; it mirrors the shape of the input x.
y = tf.placeholder(sample_dtype, shape=x.get_shape(), name='y')

# #############
# LOSS FUNCTION
# #############

# Mean squared error between prediction and true waveform, also exported as a
# TensorBoard scalar.
prediction_error = tf.subtract(y_pred, y)
mse = tf.reduce_mean(tf.square(prediction_error), name='mse')
tf.summary.scalar('MSE', mse)

# ####################
# OPTIMIZATION ROUTINE
# ####################

# Number of optimizer steps in one epoch, used to express the decay period in
# steps. Clamp to at least one step so the scheduler never gets decay_steps=0
# (e.g. when there are fewer samples than batch_size).
num_batches_per_epoch = float(sample_per_epoch) / batch_size
decay_steps = max(1, int(num_batches_per_epoch * n_epochs_per_decay))

# Decay the learning rate based on the number of steps.
# NOTE(review): the original called an undefined `make_variable_learning_rate`
# (from the reference EnglishSpeechUpsampler repo) with undefined upper-case
# constants. The imported `learing_rate_scheduling` is assumed to be its renamed
# equivalent: arguments (initial_lr, decay_steps, decay_factor, staircase),
# returning (learning_rate, global_step). TODO confirm against optimizer.py.
lr, global_step = learing_rate_scheduling(initial_learning_rate,
                                          decay_steps,
                                          learning_rate_decay_factor,
                                          False)

# global_step is passed through so the optimizer step advances the schedule.
min_args = {'global_step': global_step}
training_op = setup_optimizer(lr, mse, tf.train.AdamOptimizer,
                              using_batch_norm=True,
                              min_args=min_args)

##################
# TRAINING PROCESS
##################

# Output locations, relative to the working directory. The original wrote
# TensorBoard logs to the filesystem root ('/tensorboard'), which is normally
# not writable.
TENSORBOARD_DIR = './tensorboard'
CHECKPOINT_DIR = 'model_checkpoints'

# tf.train.Saver.save refuses to write into a missing parent directory, so
# create the checkpoint directory up front.
if not os.path.isdir(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)


def _validation_loss(sess, pairs, size):
    """Return the average MSE over every batch yielded by next_batch.

    Args:
        sess: the active tf.Session.
        pairs: validation pairs as returned by get_original_noise_pairs
            (presumably (original, noise) file pairs -- TODO confirm).
        size: batch size passed to next_batch.

    Returns:
        Mean of the per-batch MSE values, or NaN when next_batch yields no
        batches (the original divided by zero in that case).
    """
    batch_losses = []
    for val_batch in next_batch(size, pairs):
        # Index 1 is fed as the model input, index 0 as the target.
        loss = sess.run(mse, feed_dict={train_flag: False,
                                        x: val_batch[1],
                                        y: val_batch[0]})
        batch_losses.append(np.mean(loss))
    if not batch_losses:
        return float('nan')
    return float(np.mean(batch_losses))


# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Merge all summaries (the MSE scalar) and create the TensorBoard writer.
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(TENSORBOARD_DIR, tf.get_default_graph())

# Initialize the variables for the session.
init = tf.global_variables_initializer()

# The log files are managed by the same `with` so they are closed even if
# training raises.
with tf.Session() as sess, \
        open('val_loss_log.txt', 'w') as val_loss_file, \
        open('train_loss_log.txt', 'w') as train_loss_file:
    sess.run(init)
    model_name = y_pred.name.replace('/', '_').replace(':', '_')

    # Number of batches per epoch; at least 1 so the epoch check below never
    # takes a modulus by zero when there are fewer samples than batch_size.
    n_batchs = max(1, sample_per_epoch // batch_size)

    # Total number of training iterations.
    n_iterations = n_epochs * n_batchs

    for i in range(n_iterations):

        # True on the last iteration of each epoch.
        new_epoch_flag = ((i + 1) % n_batchs == 0)

        # At each epoch boundary: record the epoch number and run validation.
        if new_epoch_flag:
            epoch_number = (i + 1) // n_batchs

            print('Calculating validation loss by total {} iterations'.format(
                len(val_original_noise_pairs) / batch_size))

            val_loss = _validation_loss(sess, val_original_noise_pairs,
                                        batch_size)

            print("Epoch is: {}, Validation Loss is: {}".format(epoch_number,
                                                                val_loss))

            # Record the average validation loss for each epoch.
            val_loss_file.write('Epoch is: {}, Validation Loss is:{}\n'.format(
                epoch_number, val_loss))

        print('The training iterations is: {}'.format(i))

        # Draw a random training batch; index 1 is the input, index 0 the target.
        train_batch = random_batch(batch_size, train_original_noise_pairs)
        train_feed = {train_flag: True,
                      x: train_batch[1],
                      y: train_batch[0]}

        if new_epoch_flag:
            # Exactly one optimizer step per iteration. At an epoch boundary,
            # that step also fetches the summary and the training loss.
            summary, _, train_loss = sess.run([merged, training_op, mse],
                                              feed_dict=train_feed)

            print("Epoch is: {}, Training Loss is: {}".format(epoch_number,
                                                              train_loss))

            train_writer.add_summary(summary, i)

            # Record the training loss for each epoch.
            train_loss_file.write(
                'Epoch is: {}, Training Loss is: {}\n'.format(epoch_number,
                                                              train_loss))

            # Store the training model every 3 epochs.
            if epoch_number % 3 == 0:
                saver.save(sess, os.path.join(
                    CHECKPOINT_DIR,
                    '{}_{}.ckpt'.format(model_name, epoch_number)))
        else:
            sess.run(training_op, feed_dict=train_feed)

    # Save the final variables to disk.
    save_path = saver.save(sess, os.path.join(
        CHECKPOINT_DIR, '{}_final.ckpt'.format(model_name)))

# Flush pending TensorBoard events to disk and release the writer.
train_writer.close()

print("Model checkpoints will be saved in file: {}".format(save_path))
print('------------------------Finished model training------------------------')






"""
truth, example = read_file_pair(val_truth_ds_pairs[1])
y_reco = model.eval(feed_dict={train_flag: False,
                               x: example.reshape(1, -1, 1)},
                    session=sess).flatten()

print('difference between truth and example (first 20 elements)')
print(truth.flatten()[:20] - example.flatten()[:20])
print('difference between truth and reconstruction (first 20 elements)')
print(truth.flatten()[:20] - y_reco[:20])

print('writting output audio files')
librosa.output.write_wav('full_train_validation_true.wav',
                         y=truth.flatten(), sr=true_br)
librosa.output.write_wav('full_train_validation_ds.wav',
                         y=example.flatten(), sr=true_br)
librosa.output.write_wav('full_train_validation_reco.wav',
                         y=y_reco, sr=true_br)
                         
""""""