In [None]:
%matplotlib inline
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.io.wavfile
import tensorflow as tf
from python_speech_features.base import mfcc
from scipy.fftpack import dct

import seq2seq_model

In [None]:
# Base directory prepended to every data path (transcripts and audio files).
root = './'
# Passed to seq2seq_model as `keep_prob` when the graph is built.
keep_prob = 0.8
# TODO: set to the maximum target sequence length once the data pipeline
# exists; None is an explicit placeholder so this cell at least parses/runs.
max_output_length = None
# Model size and batching, all forwarded to seq2seq_model.
rnn_size = 256
num_layers = 2
batch_size = 16

In [None]:
# Character alphabet: space, apostrophe, '+', '-', '.', 'A'-'Z' and '_'.
vocab = np.asarray(list(" '+-.ABCDEFGHIJKLMNOPQRSTUVWXYZ_"))

# Each character maps to its position in `vocab`; the special <GO> / <EOS>
# tokens get the next two ids after the last character.
vocab_to_int = {ch: idx for idx, ch in enumerate(vocab)}
vocab_to_int['<GO>'] = len(vocab_to_int)
vocab_to_int['<EOS>'] = len(vocab_to_int)


def onehot(x):
    """One-hot encode a sequence of characters against `vocab`.

    Args:
        x: sequence of single characters, e.g. ``list("HELLO")``. It is
            converted with ``np.asarray`` and flattened to one dimension.

    Returns:
        Boolean array of shape ``(len(x), len(vocab))``. Row ``i`` is True in
        the column whose ``vocab`` entry equals ``x[i]``; a character that is
        not in ``vocab`` yields an all-False row.
    """
    x = np.asarray(x).ravel()
    # Broadcast (n, 1) against (V,) -> (n, V). The width comes from `vocab`
    # itself rather than a hard-coded 32, so editing the alphabet stays in sync.
    return x[:, np.newaxis] == vocab

In [None]:
# Shared handle on the transcript listing; each get_next_input() call consumes one line.
out_file = tf.gfile.Open(root + 'wsj0/transcripts/wsj0/wsj0.trans')


def get_next_input():
    """Read the next transcript line and load the audio file it references.

    Each line of ``out_file`` is parsed as ``TEXT (path)``, where ``path`` is
    relative to ``root`` (format inferred from the parsing — confirm against
    wsj0.trans).

    Returns:
        ``(X, Y)`` with ``X = onehot(list(TEXT))`` and ``Y`` the MFCC features
        of the referenced audio, or ``None`` once the transcript file is
        exhausted.
    """
    line = out_file.readline()
    if not line:
        return None  # end of the transcript file
    # rsplit: only the last '(' opens the path, even if the text contains '('.
    text, path = line.rsplit('(', 1)
    # Drop the trailing newline (\n or \r\n) and the closing ')'; this also
    # works for a final line that has no newline.
    path = path.rstrip().rstrip(')')
    # FileOpen (defined in a later cell) adds whence support to seek();
    # presumably required by scipy.io.wavfile.read. Close it once read.
    with FileOpen(root + path, 'rb') as wav_file:
        sample_rate, signal = scipy.io.wavfile.read(wav_file)
    # NOTE(review): for speech recognition the MFCC frames are usually the
    # encoder input and the text the target — confirm which of X/Y feeds model_x.
    X = onehot(list(text))
    Y = mfcc(signal, sample_rate)
    return X, Y

In [None]:
# A tf.gfile.Open subclass whose seek() accepts the standard `whence` argument.
class FileOpen(tf.gfile.Open):
    """tf.gfile.Open whose seek() understands ``os.SEEK_SET/CUR/END``.

    Every mode is translated into an absolute position, and the base class's
    seek() is only ever called with that single absolute position.
    """

    def seek(self, position, whence=0):
        """Move the file position.

        Args:
            position: byte offset, interpreted according to ``whence``.
            whence: ``os.SEEK_SET`` (0, absolute), ``os.SEEK_CUR`` (1, relative
                to the current position) or ``os.SEEK_END`` (2, relative to
                the end of the file).

        Returns:
            The new absolute position, mirroring ``io`` file objects.

        Raises:
            ValueError: if ``whence`` is not 0, 1 or 2.
        """
        if whence == os.SEEK_SET:
            target = position
        elif whence == os.SEEK_CUR:
            target = self.tell() + position
        elif whence == os.SEEK_END:
            target = self.size() + position
        else:
            raise ValueError(
                'invalid whence ({}, should be 0, 1 or 2)'.format(whence))
        tf.gfile.Open.seek(self, target)
        return self.tell()

In [None]:
# Make a graph and its session
train_graph = tf.Graph()
train_session = tf.InteractiveSession(graph=train_graph)

# Set the graph to default to ensure that it is ready for training
with train_graph.as_default():
    # tf.placeholder requires an explicit dtype. Targets and lengths are int32
    # because sequence_loss / sequence_mask below consume them as integer ids
    # and lengths.
    # NOTE(review): model_x's dtype (float32) and unspecified shape are
    # assumptions — confirm against what seq2seq_model expects for input_data.
    model_x = tf.placeholder(tf.float32, name='input_data')
    model_y = tf.placeholder(tf.int32, [None, None], name='targets')
    texts_lengths = tf.placeholder(tf.int32, [None], name='text_length')
    summaries_lengths = tf.placeholder(tf.int32, [None], name='summary_length')
    # Fed on every step so the training loop's learning-rate decay actually
    # reaches the optimizer (a Python float would be baked into the graph).
    lr = tf.placeholder(tf.float32, [], name='learning_rate')

    # Create the training and inference logits.
    # NOTE(review): `seq2seq_model` is bound by `import seq2seq_model`, i.e. it
    # is a module, and calling a module raises TypeError — this presumably
    # should be `seq2seq_model.seq2seq_model(...)`; confirm against that module.
    training_logits, inference_logits = seq2seq_model(input_data=model_x,
                                                      target_data=model_y,
                                                      keep_prob=keep_prob,
                                                      text_length=texts_lengths,
                                                      summary_length=summaries_lengths,
                                                      max_output_length=max_output_length,
                                                      # vocab_to_int also assigns ids to <GO>/<EOS>
                                                      # after the last character, which len(vocab)
                                                      # would not cover (presumably sizes the
                                                      # embedding / output projection).
                                                      vocab_size=len(vocab_to_int),
                                                      rnn_size=rnn_size,
                                                      num_layers=num_layers,
                                                      vocab_to_int=vocab_to_int,
                                                      batch_size=batch_size)

    # Create tensors for the training logits and inference logits
    training_logits = tf.identity(training_logits.rnn_output, 'logits')
    inference_logits = tf.identity(inference_logits.sample_id, name='predictions')

    # Weights for sequence_loss: 1.0 up to each example's target length and
    # 0.0 for padding, padded to the longest target in the batch.
    masks = tf.sequence_mask(summaries_lengths, tf.reduce_max(summaries_lengths),
                             dtype=tf.float32, name='masks')

    with tf.name_scope("optimization"):
        # Loss function
        cost = tf.contrib.seq2seq.sequence_loss(
            training_logits,
            model_y,
            masks)

        # Optimizer
        optimizer = tf.train.AdamOptimizer(lr)

        # Gradient Clipping: clamp each gradient element to [-5, 5]; variables
        # without a gradient are skipped.
        gradients = optimizer.compute_gradients(cost)
        capped_gradients = [(tf.clip_by_value(grad, -5., 5.), var) for grad, var in gradients if grad is not None]
        train_op = optimizer.apply_gradients(capped_gradients)

In [None]:
# Train the Model
learning_rate_decay = 0.95
min_learning_rate = 0.0005
display_step = 20  # Check training loss after every 20 batches
stop_early = 0
stop = 3  # If the update loss does not decrease in 3 consecutive update checks, stop training
per_epoch = 3  # Make 3 update checks per epoch
# Batches between update checks. Clamped to >= 1: for small datasets the raw
# value is 0 or negative, which would break `batch_i % update_check` and
# `update_loss / update_check` below.
update_check = max(1, (len(sorted_texts_short) // batch_size // per_epoch) - 1)

# NOTE(review): `epochs`, `learning_rate`, `get_batches`, `sorted_texts_short`
# and `sorted_summaries_short` are not defined in any visible cell, and `lr`
# must be a learning-rate placeholder in train_graph — define them before running.

update_loss = 0
batch_loss = 0
summary_update_loss = []  # Record the update losses for saving improvements in the model

checkpoint = "best_model.ckpt"
with tf.Session(graph=train_graph) as sess:
    sess.run(tf.global_variables_initializer())
    # Build the Saver once; constructing it inside the loop adds a fresh set
    # of save/restore ops to the graph on every new record.
    saver = tf.train.Saver()

    # If we want to continue training a previous session
    #loader = tf.train.import_meta_graph("./" + checkpoint + '.meta')
    #loader.restore(sess, checkpoint)

    for epoch_i in range(1, epochs + 1):
        update_loss = 0
        batch_loss = 0
        # The `_batch` suffix keeps these from rebinding the graph cell's
        # `summaries_lengths` / `texts_lengths` placeholder names.
        for batch_i, (summaries_batch, texts_batch, summaries_lengths_batch, texts_lengths_batch) in enumerate(
                get_batches(sorted_summaries_short, sorted_texts_short, batch_size)):
            start_time = time.time()
            # Keys are the graph cell's placeholders (model_x is passed as
            # input_data, model_y as target_data). keep_prob is a plain float
            # given to seq2seq_model at build time, so it is not fed here.
            _, loss = sess.run(
                [train_op, cost],
                {model_x: texts_batch,
                 model_y: summaries_batch,
                 lr: learning_rate,
                 summaries_lengths: summaries_lengths_batch,
                 texts_lengths: texts_lengths_batch})

            batch_loss += loss
            update_loss += loss
            end_time = time.time()
            batch_time = end_time - start_time

            if batch_i % display_step == 0 and batch_i > 0:
                print('Epoch {:>3}/{} Batch {:>4}/{} - Loss: {:>6.3f}, Seconds: {:>4.2f}'
                      .format(epoch_i,
                              epochs,
                              batch_i,
                              len(sorted_texts_short) // batch_size,
                              batch_loss / display_step,
                              batch_time * display_step))
                batch_loss = 0

            if batch_i % update_check == 0 and batch_i > 0:
                print("Average loss for this update:", round(update_loss / update_check, 3))
                summary_update_loss.append(update_loss)

                # If the update loss is at a new minimum, save the model
                if update_loss <= min(summary_update_loss):
                    print('New Record!')
                    stop_early = 0
                    saver.save(sess, checkpoint)

                else:
                    print("No Improvement.")
                    stop_early += 1
                    if stop_early == stop:
                        break
                update_loss = 0

        # Reduce learning rate, but not below its minimum value
        learning_rate *= learning_rate_decay
        if learning_rate < min_learning_rate:
            learning_rate = min_learning_rate

        if stop_early == stop:
            print("Stopping Training.")
            break