In [1]:
%matplotlib inline

In [1]:
# -*- coding: utf-8 -*-
""" Training script to input data and train model

This module is used to run model training. The input functions are in the
converter_generator module, the model is in the model module, and the
optimizer is in the optimizer module.

################################################################################
# Author: Weikun Han <weikunhan@gmail.com>
# Create Date: 03/10/2018
# Update:
# Reference: https://github.com/jhetherly/EnglishSpeechUpsampler
################################################################################
"""

import numpy as np
import os
import librosa
import tensorflow as tf
from converter_generator import bitrates_and_waveforms
from converter_generator import get_original_noise_pairs
from converter_generator import next_batch
from converter_generator import random_batch
from model import audio_u_net_dnn
from optimizer import setup_optimizer
from optimizer import learing_rate_scheduling

# Input/output locations. Adjust these to match where your files live.
#   DATASETS_ROOT_DIR      - root folder of the datasets
#   FILE_NAME_LISTS_DIR    - folder handed to get_original_noise_pairs
#   OUTPUT_TENSORBOARD_DIR - parent folder for the TensorBoard event files
#   OUTPUT_MODEL_DIR       - folder that receives the model checkpoints
DATASETS_ROOT_DIR = './datasets'
FILE_NAME_LISTS_DIR = os.path.join(DATASETS_ROOT_DIR, 'final_dataset')
OUTPUT_TENSORBOARD_DIR = './output'
OUTPUT_MODEL_DIR = './output/model'

# Make sure the output folders exist. `exist_ok=True` replaces the separate
# exists()-then-create check, so there is no window in which another process
# can create the folder between the check and the call, and a path that
# exists but is not a directory is reported immediately.
os.makedirs(OUTPUT_TENSORBOARD_DIR, exist_ok=True)
os.makedirs(OUTPUT_MODEL_DIR, exist_ok=True)

# Training settings. Adjust as needed.
n_epochs = 10                  # number of passes used to size the train loop
batch_size = 8                 # examples per training/validation batch
initial_learning_rate = 0.001  # passed to learing_rate_scheduling
decay_factor = 0.01            # passed to learing_rate_scheduling
n_epochs_per_decay = 2         # epochs per learning-rate decay period

print('-------------------------Begining data input---------------------------')

#############
# DATA IMPORT
#############

# Load the (original, noise) pair lists for both splits: training first,
# then validation.
train_original_noise_pairs, val_original_noise_pairs = (
    get_original_noise_pairs(FILE_NAME_LISTS_DIR, split_name)
    for split_name in ('train', 'validation'))

# The first training pair serves as the reference example.
# bitrates_and_waveforms returns the bit-rate pair first and the waveform
# pair second; element 0 of each is kept as the "original" value.
br_pair, wf_pair = bitrates_and_waveforms(train_original_noise_pairs[0])
original_bitrate = br_pair[0]

# Reshape the reference waveform into a single column (mono audio).
original_waveform = wf_pair[0].reshape((-1, 1))

# One epoch covers as many samples as there are training pairs.
sample_per_epoch = len(train_original_noise_pairs)

# Report the run configuration before the graph is built.
print('Number of epochs: {}'.format(n_epochs))
print('Samples per epoch: {}'.format(sample_per_epoch))
print('Batch size: {}'.format(batch_size))
print('-------------------------Processing training---------------------------')

##################
# MODEL DEFINITION
##################

# Build the network from the dtype and shape of the reference waveform.
# The helper returns three graph handles: `train_flag` (fed True/False by the
# training loop), `x` (fed with element 1 of each batch) and `y_pred` (the
# tensor compared against `y` in the loss).
# NOTE(review): presumably `x` is the noisy-input placeholder and `y_pred` the
# reconstructed waveform -- confirm against model.audio_u_net_dnn.
train_flag, x, y_pred = audio_u_net_dnn(original_waveform.dtype,
                                        original_waveform.shape)

# Placeholder for the true (target) waveform; it uses the same dtype as the
# reference waveform and the same static shape as `x`.
y = tf.placeholder(original_waveform.dtype,
                   shape=x.get_shape(),
                   name='y')

###############
# LOSS FUNCTION
###############

# Mean squared error between prediction and target, averaged over every
# element. It is also registered as a scalar summary so that
# tf.summary.merge_all() picks it up for TensorBoard.
mse = tf.reduce_mean(tf.square(tf.subtract(y_pred, y)), name='mse')
tf.summary.scalar('mse', mse)

######################
# OPTIMIZATION ROUTINE
######################

# Quantities that drive the learning-rate schedule.
# n_batches_per_epoch is kept fractional here, while the training loop
# truncates its own batches-per-epoch count to an integer, so decay_steps can
# differ slightly from the number of steps actually taken in
# n_epochs_per_decay epochs.
n_batches_per_epoch = float(sample_per_epoch)/ batch_size
decay_steps = int(n_batches_per_epoch * n_epochs_per_decay)

# Build the learning-rate tensor and the global step counter.
# NOTE(review): presumably a decay applied every `decay_steps` steps by
# `decay_factor` -- confirm against optimizer.learing_rate_scheduling.
learning_rate, global_step = learing_rate_scheduling(initial_learning_rate,
                                                     decay_steps,
                                                     decay_factor)

# Set up the training operator: Adam on the MSE loss. `global_step` is handed
# over through `minimize_args`.
# NOTE(review): presumably forwarded to the optimizer's minimize() so each
# training step increments the counter -- confirm in optimizer.setup_optimizer.
min_args = {'global_step': global_step}
training_op = setup_optimizer(learning_rate,
                              mse,
                              tf.train.AdamOptimizer,
                              minimize_args=min_args)

##################
# TRAINING PROCESS
##################

# Saver used to write checkpoints of all the graph variables.
saver = tf.train.Saver()

# Op that initializes every global variable once the session is running.
init = tf.global_variables_initializer()

# The session and both loss logs share one `with` statement so that all three
# are released even if training raises part-way through.
with tf.Session() as sess, \
        open('val_loss_log.txt', 'w') as val_loss_file, \
        open('train_loss_log.txt', 'w') as train_loss_file:
    sess.run(init)

    # TensorBoard: merge all registered summaries and write the events (plus
    # the graph) under OUTPUT_TENSORBOARD_DIR/tensorboard.
    merged = tf.summary.merge_all()
    tensorboard_path = os.path.join(OUTPUT_TENSORBOARD_DIR, 'tensorboard')
    train_writer = tf.summary.FileWriter(tensorboard_path, sess.graph)

    # Checkpoint file prefix derived from the output tensor's name, with '/'
    # and ':' replaced by '_'.
    model_name = y_pred.name.replace('/', '_').replace(':', '_')

    # Number of batches per epoch (truncated to an integer) and the total
    # number of training iterations.
    n_batchs = int(sample_per_epoch / batch_size)
    n_iterations = n_epochs * n_batchs

    try:
        for i in range(n_iterations):

            # True on the last iteration of every epoch.
            new_epoch_flag = ((i + 1) % n_batchs == 0)

            # At the end of an epoch, measure the validation loss first.
            if new_epoch_flag:
                epoch_number = (i + 1) // n_batchs

                print('Calculating validation loss by total {} iterations'
                      .format(len(val_original_noise_pairs) / batch_size))

                total_val_loss = 0
                val_count = 0

                # Accumulate the MSE of each validation batch; train_flag is
                # fed False so the network runs in inference mode.
                for val_batch in next_batch(batch_size,
                                            val_original_noise_pairs):
                    total_val_loss += sess.run(
                        mse,
                        feed_dict={train_flag: False,
                                   x: val_batch[1],
                                   y: val_batch[0]})
                    val_count += 1

                # Average validation loss. An empty validation iterator
                # yields NaN instead of raising ZeroDivisionError.
                if val_count:
                    val_loss = total_val_loss / val_count
                else:
                    val_loss = float('nan')

                print("Epoch is: {}, Validation Loss is: {}".format(
                    epoch_number, val_loss))

                # Record the average validation loss for this epoch.
                val_loss_file.write(
                    'Epoch is: {}, Validation Loss is:{}\n'.format(
                        epoch_number, val_loss))

            print('The training iterations is: {}'.format(i))

            # Draw a random training batch; element 1 feeds the network input
            # and element 0 is the target.
            train_batch = random_batch(batch_size, train_original_noise_pairs)
            train_feed = {train_flag: True,
                          x: train_batch[1],
                          y: train_batch[0]}

            # Ordinary iteration: a single optimizer step, nothing to log.
            if not new_epoch_flag:
                sess.run(training_op, feed_dict=train_feed)
                continue

            # End-of-epoch iteration: take the optimizer step and fetch the
            # summary and loss in the same run. The step is executed exactly
            # once, so the batch is not trained on twice and the global step
            # does not advance twice.
            summary, _, train_loss = sess.run([merged, training_op, mse],
                                              feed_dict=train_feed)

            print("Epoch is: {}, Training Loss is: {}".format(epoch_number,
                                                              train_loss))

            train_writer.add_summary(summary, i)

            # Record the training loss for this epoch.
            train_loss_file.write(
                'Epoch is: {}, Training Loss is: {}\n'.format(epoch_number,
                                                              train_loss))

            # Store an intermediate checkpoint every 3 epochs, under the
            # per-epoch path built here.
            if epoch_number % 3 == 0:
                model_path = os.path.join(
                    OUTPUT_MODEL_DIR,
                    '{}_{}.ckpt'.format(model_name, epoch_number))
                save_path = saver.save(sess, model_path)
    finally:
        # Flush and close the TensorBoard event file even on failure.
        train_writer.close()

    # Save the final variables to disk.
    model_path = os.path.join(
        OUTPUT_MODEL_DIR, '{}_final.ckpt'.format(model_name))
    save_path = saver.save(sess, model_path)

print("Model checkpoints will be saved in file: {}".format(save_path))
print('------------------------Finished model training------------------------')






"""
truth, example = read_file_pair(val_truth_ds_pairs[1])
y_reco = model.eval(feed_dict={train_flag: False,
                               x: example.reshape(1, -1, 1)},
                    session=sess).flatten()

print('difference between truth and example (first 20 elements)')
print(truth.flatten()[:20] - example.flatten()[:20])
print('difference between truth and reconstruction (first 20 elements)')
print(truth.flatten()[:20] - y_reco[:20])

print('writting output audio files')
librosa.output.write_wav('full_train_validation_true.wav',
                         y=truth.flatten(), sr=true_br)
librosa.output.write_wav('full_train_validation_ds.wav',
                         y=example.flatten(), sr=true_br)
librosa.output.write_wav('full_train_validation_reco.wav',
                         y=y_reco, sr=true_br)
                         
"""

-------------------------Begining data input---------------------------
Number of epochs: 10
Samples per epoch: 814
Batch size: 8
-------------------------Processing training---------------------------
The network summary for audio_u_net_dnn
    input: [80000, 1]
    downsample layer: [39998, 8]
    downsample layer: [19998, 16]
    downsample layer: [9998, 32]
    downsample layer: [4998, 64]
    downsample layer: [2498, 128]
    downsample layer: [1248, 256]
    downsample layer: [623, 512]
    downsample layer: [311, 1024]
    bottleneck layer: [154, 2048]
    upsample layer: [304, 2048]
    upsample layer: [604, 1024]
    upsample layer: [1204, 512]
    upsample layer: [2404, 256]
    upsample layer: [4804, 128]
    upsample layer: [9604, 64]
    upsample layer: [19204, 32]
    upsample layer: [38404, 16]
    restack layer: [40002, 15]
    final convolution layer: [40000, 2]
    output: [80000, 1]
--------------------Finished model building--------------------
INFO:tensorflow:Summa

The training iterations is: 217
The training iterations is: 218
The training iterations is: 219
The training iterations is: 220
The training iterations is: 221
The training iterations is: 222
The training iterations is: 223
The training iterations is: 224
The training iterations is: 225
The training iterations is: 226
The training iterations is: 227
The training iterations is: 228
The training iterations is: 229
The training iterations is: 230
The training iterations is: 231
The training iterations is: 232
The training iterations is: 233
The training iterations is: 234
The training iterations is: 235
The training iterations is: 236
The training iterations is: 237
The training iterations is: 238
The training iterations is: 239
The training iterations is: 240
The training iterations is: 241
The training iterations is: 242
The training iterations is: 243
The training iterations is: 244
The training iterations is: 245
The training iterations is: 246
The training iterations is: 247
The trai

The training iterations is: 463
The training iterations is: 464
The training iterations is: 465
The training iterations is: 466
The training iterations is: 467
The training iterations is: 468
The training iterations is: 469
The training iterations is: 470
The training iterations is: 471
The training iterations is: 472
The training iterations is: 473
The training iterations is: 474
The training iterations is: 475
The training iterations is: 476
The training iterations is: 477
The training iterations is: 478
The training iterations is: 479
The training iterations is: 480
The training iterations is: 481
The training iterations is: 482
The training iterations is: 483
The training iterations is: 484
The training iterations is: 485
The training iterations is: 486
The training iterations is: 487
The training iterations is: 488
The training iterations is: 489
The training iterations is: 490
The training iterations is: 491
The training iterations is: 492
The training iterations is: 493
The trai

Epoch is: 7, Training Loss is: 0.0033227477688342333
The training iterations is: 707
The training iterations is: 708
The training iterations is: 709
The training iterations is: 710
The training iterations is: 711
The training iterations is: 712
The training iterations is: 713
The training iterations is: 714
The training iterations is: 715
The training iterations is: 716
The training iterations is: 717
The training iterations is: 718
The training iterations is: 719
The training iterations is: 720
The training iterations is: 721
The training iterations is: 722
The training iterations is: 723
The training iterations is: 724
The training iterations is: 725
The training iterations is: 726
The training iterations is: 727
The training iterations is: 728
The training iterations is: 729
The training iterations is: 730
The training iterations is: 731
The training iterations is: 732
The training iterations is: 733
The training iterations is: 734
The training iterations is: 735
The training iterat

The training iterations is: 952
The training iterations is: 953
The training iterations is: 954
The training iterations is: 955
The training iterations is: 956
The training iterations is: 957
The training iterations is: 958
The training iterations is: 959
The training iterations is: 960
The training iterations is: 961
The training iterations is: 962
The training iterations is: 963
The training iterations is: 964
The training iterations is: 965
The training iterations is: 966
The training iterations is: 967
The training iterations is: 968
The training iterations is: 969
The training iterations is: 970
The training iterations is: 971
The training iterations is: 972
The training iterations is: 973
The training iterations is: 974
The training iterations is: 975
The training iterations is: 976
The training iterations is: 977
The training iterations is: 978
The training iterations is: 979
The training iterations is: 980
The training iterations is: 981
The training iterations is: 982
The trai

"\ntruth, example = read_file_pair(val_truth_ds_pairs[1])\ny_reco = model.eval(feed_dict={train_flag: False,\n                               x: example.reshape(1, -1, 1)},\n                    session=sess).flatten()\n\nprint('difference between truth and example (first 20 elements)')\nprint(truth.flatten()[:20] - example.flatten()[:20])\nprint('difference between truth and reconstruction (first 20 elements)')\nprint(truth.flatten()[:20] - y_reco[:20])\n\nprint('writting output audio files')\nlibrosa.output.write_wav('full_train_validation_true.wav',\n                         y=truth.flatten(), sr=true_br)\nlibrosa.output.write_wav('full_train_validation_ds.wav',\n                         y=example.flatten(), sr=true_br)\nlibrosa.output.write_wav('full_train_validation_reco.wav',\n                         y=y_reco, sr=true_br)\n                         \n"