In [1]:
%matplotlib inline
import os
import tensorflow as tf
import numpy as np
import pandas as pd
from scipy.fftpack import dct
import matplotlib.pyplot as plt
from python_speech_features.base import mfcc
from seq2seq_model import seq2seq_model
import scipy.io.wavfile
import time

In [2]:
# --- Data and checkpoint locations ---
root = 'gs://wsj-data/wsj0/'
checkpoint = "gs://wsj-data/best_model.ckpt"

# --- Model dimensions ---
rnn_size, num_layers = 256, 2
max_input_len = 1000
max_output_len = 100
keep_prob = 0.8

# --- Optimisation schedule ---
batch_size = 2
num_epochs = 5
learning_rate = 0.0005
learning_rate_decay = 0.95
min_learning_rate = 0.0005

# --- Progress reporting and early stopping ---
display_step = 1  # Report the training loss once every display_step batches
stop_early = 0
stop = 3  # Stop training once the update loss fails to decrease in 3 consecutive update checks

In [3]:
# Character-level vocabulary: the transcript symbols followed by the two
# special sequence markers.
vocab = np.asarray(list(" '+-.ABCDEFGHIJKLMNOPQRSTUVWXYZ_") + ['<GO>', '<EOS>'])

# Each symbol maps to its position in `vocab`.
vocab_to_int = {symbol: index for index, symbol in enumerate(vocab)}

In [4]:
# A custom class inheriting tf.gfile.Open for providing seek with whence
class FileOpen(tf.gfile.Open):
    """tf.gfile.Open variant whose seek() accepts a `whence` argument.

    Readers that expect the standard file-object seek signature (this notebook
    hands instances to scipy.io.wavfile.read) call seek(offset, whence); this
    subclass translates those calls into the absolute positions that the base
    class seek() is given.
    """

    def seek(self, position, whence=0):
        """Move the file cursor.

        Args:
            position: Offset, interpreted according to `whence`.
            whence: 0 = from the start of the file (default), 1 = relative to
                the current position, 2 = relative to the end of the file.

        Raises:
            ValueError: if `whence` is not 0, 1 or 2.
        """
        if whence == 0:
            target = position
        elif whence == 1:
            target = self.tell() + position
        elif whence == 2:
            # size() is the total file length reported by the base class.
            target = self.size() + position
        else:
            # The original raised the undefined name `FileError`, which
            # surfaced as a confusing NameError instead of a usable message.
            raise ValueError(
                'Invalid whence ({!r}); expected 0, 1 or 2'.format(whence))
        tf.gfile.Open.seek(self, target)

In [5]:
# https://github.com/zszyellow/WER-in-python/blob/master/wer.py
def wer(r, h):
    """
    Calculate the word error rate (WER) used in ASR, as a percentage.

    The edit distance (substitutions, insertions, deletions) between the
    reference `r` and the hypothesis `h` is divided by the reference length.
    You can use it like this: wer("what is it".split(), "what is".split())

    Args:
        r: Reference sequence of tokens (any sequence supporting len/indexing).
        h: Hypothesis sequence of tokens.

    Returns:
        float: 100 * edit_distance / len(r). For an empty reference the rate
        is undefined; 0.0 is returned when the hypothesis is empty too and
        float('inf') otherwise.
    """
    n_ref, n_hyp = len(r), len(h)
    if n_ref == 0:
        # Avoid dividing by zero on an empty reference.
        return 0.0 if n_hyp == 0 else float('inf')

    # Build the DP matrix. A wide integer dtype is required: the original
    # uint8 table silently wrapped around once a distance exceeded 255.
    d = np.zeros((n_ref + 1, n_hyp + 1), dtype=np.int64)
    d[0, :] = np.arange(n_hyp + 1)   # turning "" into h[:j] costs j insertions
    d[:, 0] = np.arange(n_ref + 1)   # turning r[:i] into "" costs i deletions

    for i in range(1, n_ref + 1):
        for j in range(1, n_hyp + 1):
            if r[i - 1] == h[j - 1]:
                d[i, j] = d[i - 1, j - 1]
            else:
                substitute = d[i - 1, j - 1] + 1
                insert = d[i, j - 1] + 1
                delete = d[i - 1, j] + 1
                d[i, j] = min(substitute, insert, delete)
    return float(d[n_ref, n_hyp]) / n_ref * 100

In [6]:
# Value passed as `numcep` to mfcc(); also the last dimension of the input batches.
numcep = 13

# Transcript index; get_next_input() consumes one line from this handle per call.
out_file = tf.gfile.Open(root + 'transcripts/wsj0/wsj0.trans')

def get_next_input():
    """Read the next transcript line and load the utterance it names.

    Each line of `out_file` is expected to look like ``TEXT (file_id)``: the
    part before '(' is the transcript, the part inside the parentheses is the
    path of the wav file relative to ``root + 'wav/'``.

    Returns:
        (X, Y): X is the array returned by mfcc() for the audio (computed with
        `numcep` coefficients); Y is the transcript as a list of vocabulary ids.

    Raises:
        EOFError: when `out_file` has no more lines.
        ValueError: if the line does not contain exactly one '('.
        KeyError: if the transcript contains a character outside `vocab_to_int`.
    """
    trans = out_file.readline()
    if not trans:
        # Make end-of-data explicit instead of failing on the tuple unpack below.
        raise EOFError('No more transcript lines in out_file')

    cont, wav_name = trans.split('(')
    # Drop the closing ')' and any line ending. Unlike slicing off the last two
    # characters, this also works when the final line has no trailing newline.
    wav_name = wav_name.strip().rstrip(')')

    # Close the audio handle as soon as the samples have been read.
    with FileOpen(root + 'wav/' + wav_name, 'rb') as wav_file:
        sample_rate, signal = scipy.io.wavfile.read(wav_file)

    Y = [vocab_to_int[c] for c in cont]
    X = mfcc(signal, sample_rate, numcep=numcep)
    return X, Y

def pad_sentence_batch(sentence_batch):
    """Right-pad every sentence with the <EOS> id so all share one length.

    The common length is one more than the longest sentence in the batch, so
    every returned sentence ends with at least one <EOS>. The input lists are
    not modified; new lists are returned.
    """
    target_len = 1 + max(len(sentence) for sentence in sentence_batch)
    eos_id = vocab_to_int['<EOS>']

    padded_batch = []
    for sentence in sentence_batch:
        missing = target_len - len(sentence)
        padded_batch.append(sentence + [eos_id] * missing)
    return padded_batch

def get_next_batch():
    """Assemble the next `batch_size` examples into padded arrays.

    Returns:
        A tuple ``(input_batch, output_batch, input_batch_length,
        output_batch_length)``:
          * input_batch: float32 array of shape
            (batch_size, max_input_len, numcep), zero-padded along axis 1.
          * output_batch: array of target ids, padded with <EOS> by
            pad_sentence_batch().
          * input_batch_length: number of rows actually filled per example.
          * output_batch_length: target length per example, counting the one
            <EOS> that pad_sentence_batch() always appends.

    Inputs with more than `max_input_len` rows are truncated to fit the fixed
    input buffer. Previously they raised a broadcast ValueError, which the
    training loop reported as the end of the epoch.

    Any exception raised by get_next_input() (including the one signalling the
    end of the transcript file) propagates to the caller.
    """
    input_batch = np.zeros((batch_size, max_input_len, numcep), 'float32')
    input_batch_length = np.zeros((batch_size), 'int')
    output_batch_length = np.zeros((batch_size), 'int')
    output_batch = []  # Variable-length targets; padded once the batch is complete

    for i in range(batch_size):
        inp, out = get_next_input()
        # Never write more rows than the buffer holds.
        n_rows = min(inp.shape[0], max_input_len)
        input_batch[i, :n_rows] = inp[:n_rows]
        input_batch_length[i] = n_rows
        output_batch.append(out)
        output_batch_length[i] = len(out) + 1  # +1 for the appended <EOS>

    output_batch = np.asarray(pad_sentence_batch(output_batch))
    return (input_batch, output_batch, input_batch_length, output_batch_length)

In [7]:
# Make a graph and its session
train_graph = tf.Graph()
# NOTE(review): train_session is not referenced by the training cell in this
# notebook, which opens its own tf.Session on train_graph — confirm whether
# this InteractiveSession is still needed.
train_session = tf.InteractiveSession(graph=train_graph)

# Set the graph to default to ensure that it is ready for training
with train_graph.as_default():
    # Placeholders fed on every training step. The batch dimension is fixed to
    # batch_size; the target time dimension is left dynamic (None).
    model_input = tf.placeholder(tf.float32, [batch_size, max_input_len, numcep], name='model_input')
    model_output = tf.placeholder(tf.int32, [batch_size, None], name='model_output')
    input_lengths = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
    output_lengths = tf.placeholder(tf.int32, [batch_size], name='output_lengths')
    learning_rate_tensor = tf.placeholder(tf.float32, name='learning_rate')
    
    # Create the training and inference logits
    # (seq2seq_model comes from the project-local seq2seq_model module.)
    training_logits, inference_logits = seq2seq_model(input_data=model_input,
                                                      target_data=model_output,
                                                      keep_prob=keep_prob,
                                                      input_lengths=input_lengths,
                                                      output_lengths=output_lengths,
                                                      max_output_length=max_output_len,
                                                      vocab_size=len(vocab_to_int),
                                                      rnn_size=rnn_size,
                                                      num_layers=num_layers,
                                                      vocab_to_int=vocab_to_int,
                                                      batch_size=batch_size)
    
    # Create tensors for the training logits and inference logits
    # The explicit names ('logits', 'predictions') give these tensors stable
    # names inside the graph.
    training_logits = tf.identity(training_logits.rnn_output, 'logits')
    inference_logits = tf.identity(inference_logits.sample_id, name='predictions')
    
    # Create the weights for sequence_loss
    # The mask is as wide as the longest output length in the batch.
    masks = tf.sequence_mask(output_lengths, tf.reduce_max(output_lengths), dtype=tf.float32, name='masks')

    with tf.name_scope("optimization"):
        # Loss function
        cost = tf.contrib.seq2seq.sequence_loss(
            training_logits,
            model_output,
            masks)

        # Optimizer
        # The learning rate is a placeholder so it can change between steps.
        optimizer = tf.train.AdamOptimizer(learning_rate_tensor)

        # Gradient Clipping
        # Each gradient is clipped element-wise to [-5, 5]; variables for which
        # no gradient was computed (None) are skipped.
        gradients = optimizer.compute_gradients(cost)
        capped_gradients = [(tf.clip_by_value(grad, -5., 5.), var) for grad, var in gradients if grad is not None]
        train_op = optimizer.apply_gradients(capped_gradients)

In [8]:
# Training loop: one pass over the transcript file per epoch, one optimizer
# step per batch.
with tf.Session(graph=train_graph) as sess:
    sess.run(tf.global_variables_initializer())

    # If we want to continue training a previous session
    #loader = tf.train.import_meta_graph("./" + checkpoint + '.meta')
    #loader.restore(sess, checkpoint)

    for epoch_i in range(1, num_epochs+1):
        batch_loss = 0   # loss accumulated since the last progress line
        window_time = 0  # seconds spent in sess.run since the last progress line
        batch_i = 0
        out_file.seek(0)  # rewind the transcript index for this epoch
        while True:
            batch_i += 1
            try:
                input_batch, output_batch, input_lengths_batch, output_lengths_batch = get_next_batch()
            except Exception as exc:
                # Running out of transcript lines ends the epoch. Catch
                # Exception rather than using a bare except so that Ctrl-C
                # (KeyboardInterrupt) still interrupts training, and print the
                # cause so a data problem (e.g. a character missing from the
                # vocabulary) is not mistaken for a normal end of epoch.
                print("Epoch {} completed ({}: {})".format(
                    epoch_i, type(exc).__name__, exc))
                break

            start_time = time.time()
            _, loss = sess.run(
                [train_op, cost],
                {model_input: input_batch,
                 model_output: output_batch,
                 learning_rate_tensor: learning_rate,
                 output_lengths: output_lengths_batch,
                 input_lengths: input_lengths_batch})
            window_time += time.time() - start_time
            batch_loss += loss

            if batch_i % display_step == 0:
                # Report the mean loss and the real elapsed time of the last
                # display_step batches (not the last batch time scaled up).
                print('Epoch {:>3}/{} Batch {:>4} - Loss: {:>6.3f}, Seconds: {:>4.2f}'
                      .format(epoch_i,
                              num_epochs,
                              batch_i,
                              batch_loss / display_step,
                              window_time))
                batch_loss = 0
                window_time = 0

        # Reduce learning rate, but not below its minimum value
        learning_rate = max(learning_rate * learning_rate_decay, min_learning_rate)

Epoch   1/5 Batch    1 - Loss:  3.543, Seconds: 6.85
Epoch   1/5 Batch    2 - Loss:  3.354, Seconds: 7.04
Epoch   1/5 Batch    3 - Loss:  3.151, Seconds: 8.81
Epoch   1/5 Batch    4 - Loss:  3.074, Seconds: 10.36
Epoch   1/5 Batch    5 - Loss:  3.003, Seconds: 6.58
Epoch   1/5 Batch    6 - Loss:  2.932, Seconds: 6.95
Epoch   1/5 Batch    7 - Loss:  2.986, Seconds: 6.72
Epoch   1/5 Batch    8 - Loss:  3.016, Seconds: 6.84
Epoch   1/5 Batch    9 - Loss:  2.976, Seconds: 7.21
Epoch   1/5 Batch   10 - Loss:  3.012, Seconds: 7.04
Epoch   1/5 Batch   11 - Loss:  2.986, Seconds: 7.24
Epoch   1/5 Batch   12 - Loss:  2.906, Seconds: 7.83
Epoch   1/5 Batch   13 - Loss:  2.998, Seconds: 7.97
Epoch   1/5 Batch   14 - Loss:  2.907, Seconds: 7.61
Epoch   1/5 Batch   15 - Loss:  2.897, Seconds: 7.35
Epoch   1/5 Batch   16 - Loss:  2.963, Seconds: 7.00
Epoch   1/5 Batch   17 - Loss:  2.896, Seconds: 7.32
Epoch   1/5 Batch   18 - Loss:  2.970, Seconds: 6.83
Epoch   1/5 Batch   19 - Loss:  2.896, Second

KeyboardInterrupt: 