# The script to train the deep neural networks

This module can use for processing training. You need modify the number of epochs you want, the size of batch, the initial learning rate, the devay factor for changing learning rate, and number of epochs to decay learning rate. And you can modify output directory you want and input directory you have.

In [3]:
# -*- coding: utf-8 -*-
""" The script to train the deep neural networks

This module can use for processing training. You need modify the number of 
epochs you want, the size of batch, the initial learning rate, the devay factor 
for changing learning rate, and number of epochs to decay learning rate. And you 
can modify output directory you want and input directory you have.

################################################################################
# Author: Weikun Han <weikunhan@gmail.com>
# Crate Date: 03/10/2018        
# Update:
# Reference: https://github.com/jhetherly/EnglishSpeechUpsampler
################################################################################
"""

import numpy as np
import os
import librosa
import tensorflow as tf
from converter_generator import bitrates_and_waveforms
from converter_generator import get_original_noise_pairs
from converter_generator import next_batch
from converter_generator import random_batch
from model import audio_u_net_dnn
from optimizer import setup_optimizer
from optimizer import learing_rate_scheduling

# Please modify input path  to locate you file
DATASETS_ROOT_DIR = './datasets'
FILE_NAME_LISTS_DIR = os.path.join(DATASETS_ROOT_DIR, 'final_dataset')
OUTPUT_TENSORBOARD_DIR = './output'
OUTPUT_MODEL_DIR = './output/model'

# Please modify setting for training
n_epochs = 10
batch_size = 8
initial_learning_rate = 0.001
decay_factor = 0.01
n_epochs_per_decay = 2

# Check location to save datasets
if not os.path.exists(OUTPUT_TENSORBOARD_DIR):
    os.makedirs(OUTPUT_TENSORBOARD_DIR)
if not os.path.exists(OUTPUT_MODEL_DIR):
    os.makedirs(OUTPUT_MODEL_DIR)

print('-------------------------Begining data input---------------------------')

#############
# DATA IMPORT
#############

train_original_noise_pairs = get_original_noise_pairs(FILE_NAME_LISTS_DIR,
                                                      'train')
val_original_noise_pairs = get_original_noise_pairs(FILE_NAME_LISTS_DIR,
                                                    'validation')

# Selet first original noise pair, return first is bit rate pair and second
# is the waveform pair
br_pair, wf_pair = bitrates_and_waveforms(train_original_noise_pairs[0])

# Get original bit rate and waveform
original_bitrate = br_pair[0]
original_waveform = wf_pair[0]

# reshape for mono waveforms
original_waveform = original_waveform.reshape((-1, 1))

# Number of sample for each epoch train
sample_per_epoch = len(train_original_noise_pairs)

print('Number of epochs: {}'.format(n_epochs))
print('Samples per epoch: {}'.format(sample_per_epoch))
print('Batch size: {}'.format(batch_size))
print('-------------------------Processing training---------------------------')

##################
# MODEL DEFINITION
##################

train_flag, x, y_pred = audio_u_net_dnn(original_waveform.dtype,
                                        original_waveform.shape)

# placeholder for the true waveform
y = tf.placeholder(original_waveform.dtype,
                   shape=x.get_shape(),
                   name='y')

###############
# LOSS FUNCTION
###############

with tf.name_scope('mse'):
    mse = tf.reduce_mean(tf.square(tf.subtract(y_pred, y)))

tf.summary.scalar('mse', mse)

######################
# OPTIMIZATION ROUTINE
######################

# Variable that affect learning rate.
n_batches_per_epoch = float(sample_per_epoch)/ batch_size
decay_steps = int(n_batches_per_epoch * n_epochs_per_decay)

# Decay the learning rate based on the number of steps.
learning_rate, global_step = learing_rate_scheduling(initial_learning_rate,
                                                     decay_steps,
                                                     decay_factor)

# Setup the training operator
min_args = {'global_step': global_step}
training_op = setup_optimizer(learning_rate,
                              mse,
                              tf.train.AdamOptimizer,
                              minimize_args=min_args)

##################
# TRAINING PROCESS
##################

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# initialize the variables for the session
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    
    # initialize tensorboard file writers
    merged = tf.summary.merge_all()
    tensorboard_path = os.path.join(OUTPUT_TENSORBOARD_DIR, 'tensorboard')
    train_writer = tf.summary.FileWriter(tensorboard_path, sess.graph)
    model_name = y_pred.name[: 15]
    val_loss_file = open('val_loss_log.txt', 'w')
    train_loss_file = open('train_loss_log.txt', 'w')

    # The number of batchs in each epoch train
    n_batchs = int(sample_per_epoch / batch_size)

    # The number of iteration
    n_iterations = n_epochs * n_batchs

    # Start train loop
    for i in range(n_iterations):
        
        print('The training iterations is: {}'.format(i))
        
        # Random generate next batch in the training file
        train_batch = random_batch(batch_size, train_original_noise_pairs)

        # Setting a flag to cheack iteration go to next epoch
        new_epoch_flag = ((i + 1) % n_batchs == 0)

        # Record epoch number
        if new_epoch_flag:
            epoch_number = int((i + 1) / n_batchs)

        # Start validation if enter new epoch
        if new_epoch_flag:

            print('Calculating validation loss by total {} iterations'.format(
                len(val_original_noise_pairs) / batch_size))

            total_val_loss = 0
            val_count = 0

            # Find each validation loss for same batch size
            for val_batch in next_batch(batch_size, val_original_noise_pairs):
                loss = sess.run([mse],
                                feed_dict={train_flag: False,
                                           x: val_batch[1],
                                           y: val_batch[0]})
                total_val_loss += np.mean(loss)
                val_count += 1

            # Calculate the the average validaton loss
            val_loss = total_val_loss / val_count

            print('Epoch is: {}, Validation Loss is: {}'.format(epoch_number, 
                                                                val_loss))
            
            # Record the average validation loss for each epoch
            val_loss_file.write(
                'Epoch is: {}, Validation Loss is:{}\n'.format(epoch_number, 
                                                               val_loss))

        # Start recording traning loss if enter a new epoch
        if new_epoch_flag:
            summary, _, train_loss = sess.run([merged, training_op, mse],
                                              feed_dict={train_flag: True,
                                                         x: train_batch[1],
                                                         y: train_batch[0]})
                
            print('Epoch is: {}, Training Loss is: {}'.format(epoch_number, 
                                                              train_loss))
                
            train_writer.add_summary(summary, i)
                
            # Record the training loss for each epoch
            train_loss_file.write(
                'Epoch is: {}, Training Loss is: {}\n'.format(epoch_number, 
                                                              train_loss))
                
            # Store the training model every 3 epoch
            if epoch_number % 3 == 0:
                model_path = os.path.join(
                    OUTPUT_MODEL_DIR, '{}_{}.ckpt'.format(model_name, 
                                                          epoch_number))
                save_path = saver.save(sess, model_path)

        # Run tensorflow for each train batch
        sess.run(training_op, 
                 feed_dict={train_flag: True,
                            x: train_batch[1],
                            y: train_batch[0]})

    val_loss_file.close()
    train_loss_file.close()

    # Save the variables to disk.
    model_path = os.path.join(
        OUTPUT_MODEL_DIR, '{}_final.ckpt'.format(model_name))
    save_path = saver.save(sess, model_path)

print("Model checkpoints will be saved in file: {}".format(save_path))
print('------------------------Finished model training------------------------')

-------------------------Begining data input---------------------------
Number of epochs: 10
Samples per epoch: 814
Batch size: 8
-------------------------Processing training---------------------------
The network summary for audio_u_net_dnn
    input: [None, 80000, 1]
    downsample layer: [None, 39998, 8]
    downsample layer: [None, 19998, 16]
    downsample layer: [None, 9998, 32]
    downsample layer: [None, 4998, 64]
    downsample layer: [None, 2498, 128]
    downsample layer: [None, 1248, 256]
    downsample layer: [None, 623, 512]
    downsample layer: [None, 311, 1024]
    bottleneck layer: [None, 154, 2048]
    upsample layer: [None, 304, 2048]
    upsample layer: [None, 604, 1024]
    upsample layer: [None, 1204, 512]
    upsample layer: [None, 2404, 256]
    upsample layer: [None, 4804, 128]
    upsample layer: [None, 9604, 64]
    upsample layer: [None, 19204, 32]
    upsample layer: [None, 38404, 16]
    restack layer: [None, 40002, 15]
    final convolution layer: [None

InvalidArgumentError: You must feed a value for placeholder tensor 'audio_u_net_dnn_1/Placeholder' with dtype bool
	 [[Node: audio_u_net_dnn_1/Placeholder = Placeholder[dtype=DT_BOOL, shape=<unknown>, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Caused by op 'audio_u_net_dnn_1/Placeholder', defined at:
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 486, in start
    self.io_loop.start()
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/tornado/platform/asyncio.py", line 127, in start
    self.asyncio_loop.run_forever()
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/asyncio/base_events.py", line 422, in run_forever
    self._run_once()
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/asyncio/base_events.py", line 1432, in _run_once
    handle._run()
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/asyncio/events.py", line 145, in _run
    self._callback(*self._args)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/tornado/platform/asyncio.py", line 117, in _handle_events
    handler_func(fileobj, events)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/tornado/stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 450, in _handle_events
    self._handle_recv()
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 480, in _handle_recv
    self._run_callback(callback, msg)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 432, in _run_callback
    callback(*args, **kwargs)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/tornado/stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/ipykernel/ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/ipykernel/zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2662, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2785, in _run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2903, in run_ast_nodes
    if self.run_code(code, result):
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-d35328a1f840>", line 83, in <module>
    original_waveform.shape)
  File "/Users/WeikunHan/Documents/GitHub/Weikun-Zhengshuang/model.py", line 593, in audio_u_net_dnn
    train_flag = tf.placeholder(tf.bool)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 1680, in placeholder
    return gen_array_ops._placeholder(dtype=dtype, shape=shape, name=name)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3141, in _placeholder
    "Placeholder", dtype=dtype, shape=shape, name=name)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3160, in create_op
    op_def=op_def)
  File "/Users/WeikunHan/anaconda3/envs/over-the-air_SRA/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1625, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'audio_u_net_dnn_1/Placeholder' with dtype bool
	 [[Node: audio_u_net_dnn_1/Placeholder = Placeholder[dtype=DT_BOOL, shape=<unknown>, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
