In [1]:
import data, data_utils
import importlib as I
#I.reload(data_utils)

In [2]:
# Load the lookup tables (data_ctl) together with the index-encoded words
# and phonemes.
data_ctl, idx_words, idx_phonemes = data.load_data()

# Split into train / test / validation pairs. Phonemes are passed first and
# end up in the X arrays; words end up in the Y arrays.
train_set, test_set, valid_set = data_utils.split_dataset(idx_phonemes, idx_words)
trainX, trainY = train_set
testX, testY = test_set
validX, validY = valid_set

In [3]:
# parameters
# Sequence lengths are read off the last axis of the training arrays.
xseq_len = trainX.shape[-1]
yseq_len = trainY.shape[-1]

# Vocabulary sizes: encoder side uses the phoneme table, decoder side the
# alphabet table.
# NOTE(review): the original inline hints here read "27" for idx2pho and "70"
# for idx2alpha, which looks swapped for a phoneme -> letter model - confirm.
xvocab_size = len(data_ctl['idx2pho'])
yvocab_size = len(data_ctl['idx2alpha'])

batch_size = 128
emb_dim = 128

In [4]:
import tensorflow as tf
import numpy as np

In [5]:
# Clear the default graph so that re-running the graph-building cells below
# starts from an empty graph instead of adding to the previous one.
tf.reset_default_graph()

In [6]:
# Encoder inputs: one int32 placeholder per input timestep. Each placeholder
# is fed a 1-D vector (shape (None,)) holding that timestep's symbol index
# for every example in the batch.
enc_ip = [tf.placeholder(dtype=tf.int32,
                         shape=(None,),
                         name='ei_{}'.format(t)) for t in range(xseq_len)]
# alternatively
#  enc_ip = tf.placeholder(shape=[None,xseq_len], dtype=tf.int32, name='enc_ip')

# Decoder targets: one int32 placeholder per output timestep. They are named
# 'labels_<t>' so they are distinguishable from the encoder's 'ei_<t>' nodes
# in the graph (the original reused the 'ei_' prefix for both lists).
labels = [tf.placeholder(dtype=tf.int32,
                         shape=(None,),
                         name='labels_{}'.format(t)) for t in range(yseq_len)]
# alternatively
#  labels = tf.placeholder(shape=[None,yseq_len], dtype=tf.int32, name='labels')

# Decoder inputs are the targets shifted right by one step, with an all-zero
# 'GO' vector in the first position.
# NOTE(review): this treats index 0 as the GO symbol - confirm it does not
# clash with a real vocabulary entry.
dec_ip = [tf.zeros_like(enc_ip[0], dtype=tf.int32, name='GO')] + labels[:-1]

In [7]:
# Dropout keep-probability, fed at run time (0.5 while training, 1.0 while
# evaluating - see train_batch / eval_step).
keep_prob = tf.placeholder(tf.float32)
# A single LSTM cell of width emb_dim with dropout applied to its outputs.
basic_cell = tf.nn.rnn_cell.DropoutWrapper(
        tf.nn.rnn_cell.BasicLSTMCell(emb_dim, state_is_tuple=True),
        output_keep_prob=keep_prob)
# Three-layer stack. Note that the list holds three references to the SAME
# wrapped cell object.
# NOTE(review): newer TensorFlow releases are stricter about reusing one cell
# object across layers - revisit this line if the notebook is upgraded.
stacked_lstm = tf.nn.rnn_cell.MultiRNNCell([basic_cell]*3, state_is_tuple=True)


# Both seq2seq graphs are built inside one variable scope (named 'decoder',
# although it contains the whole encoder-decoder model) so they can share
# variables.
with tf.variable_scope('decoder') as scope:
    # Training graph: the decoder is driven by dec_ip (GO + shifted labels).
    decode_outputs, decode_states = tf.nn.seq2seq.embedding_rnn_seq2seq(enc_ip,dec_ip, stacked_lstm,
                                        xvocab_size, yvocab_size, emb_dim)
    # Everything created from here on reuses the variables defined above.
    scope.reuse_variables()
    # testing
    # Inference graph: same weights, but with feed_previous=True so the
    # decoder is fed its own previous prediction rather than the labels.
    decode_outputs_test, decode_states_test = tf.nn.seq2seq.embedding_rnn_seq2seq(
        enc_ip, dec_ip, stacked_lstm, xvocab_size, yvocab_size,emb_dim,
        feed_previous=True)

In [8]:
# we weight the losses based on timestep of decoder output
loss_weights = [tf.ones_like(l, dtype=tf.float32) for l in labels] # gives [1, 1, ..., 1,1] - equal weights
loss = tf.nn.seq2seq.sequence_loss(decode_outputs, labels, loss_weights, yvocab_size)
train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)

In [9]:
def get_feed(X, Y):
    """Build the feed dict that binds one batch to the graph placeholders.

    X and Y are indexed by timestep: X[t] is fed to enc_ip[t] for each of the
    xseq_len encoder steps and Y[t] to labels[t] for each of the yseq_len
    decoder steps. keep_prob is NOT set here; callers add it themselves.
    """
    feed = {}
    for step in range(xseq_len):
        feed[enc_ip[step]] = X[step]
    for step in range(yseq_len):
        feed[labels[step]] = Y[step]
    return feed

In [None]:
# training session
# Smoke test: run a single optimisation step in a throw-away session.
with tf.Session() as sess:
    # tf.initialize_all_variables is deprecated (see the warning printed by
    # the training cell); use its replacement.
    sess.run(tf.global_variables_initializer())
    # create a generator
    train_batch_gen = data_utils.batch_gen(trainX, trainY, batch_size)
    X, Y = next(train_batch_gen)
    feed_dict = get_feed(X, Y)
    feed_dict[keep_prob] = 0.5  # dropout active while training
    _, out = sess.run([train_op, loss], feed_dict)

In [None]:
# Peek at one example from each of ten random test batches.
rtest_batch_gen = data_utils.rand_batch_gen(testX, testY, batch_size)
for i in range(10):
    batchX, batchY = next(rtest_batch_gen)
    # Batches are indexed by timestep first (get_feed feeds X[t] per step),
    # so transpose before picking example 40; plain batchX[40] would select
    # timestep 40 instead.
    print(i, batchX.T[40], batchY.T[40])

In [None]:
# Decode example 12 of the input batch. X holds phoneme indices, so it must
# be looked up in 'idx2pho' (as the results table at the end of the notebook
# does), and the batch is timestep-major, so transpose before indexing.
data_utils.decode_word(batchX.T[12], data_ctl['idx2pho'])

In [None]:
# Decode example 12 of the target batch. Y holds letter indices, so it must
# be looked up in 'idx2alpha' (as the results table at the end of the notebook
# does), and the batch is timestep-major, so transpose before indexing.
data_utils.decode_phonemes(batchY.T[12], data_ctl['idx2alpha'])

## Training

In [10]:
def train_batch(train_batch_gen):
    """Run one optimisation step on the next batch and return its loss.

    Pulls a batch from `train_batch_gen`, feeds it with dropout enabled
    (keep_prob = 0.5) and executes `train_op` in the module-level `sess`.
    """
    inputs, targets = next(train_batch_gen)
    feed = get_feed(inputs, targets)
    feed[keep_prob] = 0.5
    # The first fetch is the training op itself; only the loss value is kept.
    fetched = sess.run([train_op, loss], feed)
    return fetched[1]

In [11]:
def eval_step(eval_batch_gen):
    """Evaluate the model on the next batch from `eval_batch_gen`.

    Dropout is disabled (keep_prob = 1.0). Runs in the module-level `sess`.

    Returns a tuple (loss_v, dec_op_v, batchX, batchY) where loss_v is the
    value of `loss`, dec_op_v holds the outputs of the feed-previous decoder
    with the first two axes swapped so the batch axis comes first, and
    batchX / batchY are the raw batch arrays that were fed.
    """
    batchX, batchY = next(eval_batch_gen)
    feed = get_feed(batchX, batchY)
    feed[keep_prob] = 1.
    loss_v, step_outputs = sess.run([loss, decode_outputs_test], feed)
    # The decoder returns one array per timestep; stack them and swap the
    # leading (time, batch) axes.
    dec_op_v = np.swapaxes(np.array(step_outputs), 0, 1)
    return loss_v, dec_op_v, batchX, batchY

In [12]:
def eval_batch(eval_batch_gen, num_batches, step_fn=None):
    """Average loss and exact-match accuracy over several evaluation batches.

    Args:
        eval_batch_gen: batch generator handed to the step function.
        num_batches: number of batches to evaluate.
        step_fn: optional callable with the same contract as `eval_step`
            (takes the generator, returns (loss, outputs, batchX, batchY)
            with outputs indexed [example, timestep, symbol]). Defaults to
            `eval_step`.

    Returns:
        (mean_loss, accuracy): the mean of the per-batch losses and the
        fraction of examples whose predicted sequence equals the target
        sequence at every position.
    """
    if step_fn is None:
        step_fn = eval_step
    losses, predict_loss = [], []
    for _ in range(num_batches):
        loss_v, dec_op_v, batchX, batchY = step_fn(eval_batch_gen)
        losses.append(loss_v)
        # Most likely symbol per position, computed once per batch rather
        # than once per example.
        predicted = np.argmax(dec_op_v, axis=2)
        # Predictions are output symbols, so they must be compared with the
        # targets (batchY). The original compared against the inputs (batchX),
        # which made the accuracy meaningless. Targets are timestep-major,
        # hence the transpose to get one row per example.
        targets = np.asarray(batchY).T
        for real, predict in zip(targets, predicted):
            predict_loss.append(np.array_equal(real, predict))
    return np.mean(losses), np.mean(predict_loss)

In [None]:
# Sanity check: evaluate one batch with freshly initialised (untrained)
# weights in a throw-away session.
with tf.Session() as sess:
    # tf.initialize_all_variables is deprecated; use its replacement.
    sess.run(tf.global_variables_initializer())
    batchX, batchY = next(train_batch_gen)
    feed_dict = get_feed(batchX, batchY)
    feed_dict[keep_prob] = 1.  # no dropout when evaluating
    loss_v, dec_op_val = sess.run([loss, decode_outputs_test], feed_dict)

In [None]:
# dec_op_val is the list fetched from decode_outputs_test; show the shape of
# its first element.
dec_op_val[0].shape

In [None]:
#train_batch_gen = data_utils.rand_batch_gen(trainX, trainY, batch_size)
# Pull one batch from the current training generator and show both shapes.
a, b = next(train_batch_gen)
print(a.shape, b.shape)

In [13]:
import sys

In [14]:
# Small random batches for periodic evaluation, full-size batches for training.
val_batch_gen = data_utils.rand_batch_gen(validX, validY, 16)
train_eval_batch_gen = data_utils.rand_batch_gen(trainX, trainY, 16)
# Use the shared batch_size parameter instead of repeating the literal 128.
train_batch_gen = data_utils.rand_batch_gen(trainX, trainY, batch_size)

# This session is intentionally left open: train_batch / eval_step and the
# evaluation cells further down read it as the module-level `sess`.
sess = tf.Session()
# tf.initialize_all_variables is deprecated; use its replacement.
sess.run(tf.global_variables_initializer())

for i in range(100000):
    try:
        train_batch(train_batch_gen)
        # Every 5000 steps, report the loss averaged over 16 evaluation
        # batches of validation data and of training data.
        if i % 5000 == 0:
            val_loss, val_predict = eval_batch(val_batch_gen, 16)
            train_loss, train_predict = eval_batch(train_eval_batch_gen, 16)
            print('\nIteration #{}'.format(i))
            print('val   loss : {}'.format(val_loss))
            print('train loss : {}'.format(train_loss))

            #print("val loss   : {0}, val predict   = {1}%".format(val_loss, val_predict * 100))
            #print("train loss : {0}, train predict = {1}%".format(train_loss, train_predict * 100))

            sys.stdout.flush()
    except KeyboardInterrupt:
        # Allow stopping early from the notebook while keeping the session
        # (and the weights trained so far) alive.
        print("interrupted by user at {}".format(i))
        break

Instructions for updating:
Use `tf.global_variables_initializer` instead.

Iteration #0
val   loss : 3.255979061126709
train loss : 3.2577998638153076

Iteration #5000
val   loss : 0.6567713022232056
train loss : 0.6086580753326416

Iteration #10000
val   loss : 0.4550465941429138
train loss : 0.38915368914604187

Iteration #15000
val   loss : 0.30922865867614746
train loss : 0.3227866291999817

Iteration #20000
val   loss : 0.32482245564460754
train loss : 0.2962823212146759

Iteration #25000
val   loss : 0.2795450687408447
train loss : 0.2243996113538742

Iteration #30000
val   loss : 0.2611120045185089
train loss : 0.23757104575634003

Iteration #35000
val   loss : 0.2358589470386505
train loss : 0.19370242953300476

Iteration #40000
val   loss : 0.2391183078289032
train loss : 0.18203002214431763

Iteration #45000
val   loss : 0.20033389329910278
train loss : 0.15348316729068756

Iteration #50000
val   loss : 0.24642911553382874
train loss : 0.1512463390827179
interrupted by user a

In [15]:
# Random batches of 16 examples drawn from the held-out test split.
test_batch_gen = data_utils.rand_batch_gen(testX, testY, 16)

In [38]:
# Evaluate a single random test batch with the trained model.
eval_loss, output, X, Y = eval_step(test_batch_gen)
# Pick the highest-scoring symbol index at every position (last axis holds
# the per-symbol scores).
model_op = output.argmax(axis=2)

In [40]:
# Side-by-side table: input pronunciation, reference spelling, model spelling.
print('{0: <23} {1: <20} {2: <20}\n'.format('pronunciation','real spelling','model spelling'))
# X and Y are timestep-major, so len(X) is the sequence length, not the number
# of examples. Iterate over the batch axis instead; the original range(len(X))
# printed the wrong number of rows (or raised IndexError) whenever sequence
# length and batch size differed.
for i in range(X.shape[1]):
    pronounce = data_utils.decode_word(X.T[i], data_ctl['idx2pho'])
    real_spell = data_utils.decode_phonemes( Y.T[i], data_ctl['idx2alpha'])
    model_spell = data_utils.decode_phonemes(model_op[i], data_ctl['idx2alpha'])
    # NOTE(review): the [::2] slice keeps every other character - presumably
    # decode_phonemes interleaves a separator between symbols; confirm.
    print('{0: <23} {1:<19}  {2:<20}'.format(pronounce, real_spell[::2], model_spell[::2]))

pronunciation           real spelling        model spelling      

MAH0KLAE1NAH0HHAE0N     mcclanahan           mcclanahan          
BAH1NDAH0SWEH2R         bundeswehr           bundessware         
IH0NRIY1KWEH0Z          enriquez             enriques            
LOW0NEH1ROW0            lonero               lonero              
BEH1NAH0VIY0DEH0S       benevides            benevedese          
DIH0STRAH1KTIH0V        destructive          distructive         
CHAO1NCHUW0LIY0         cianciulli           chonchully          
PRAA2KLAH0MEY1SHAH0NZ   proclamations        proclamations       
FRIY1THIH1NGKER0        freethinker          freethinker         
KAH0NTEH1MPTAH0BAH0L    contemptible         contemptable        
PRIY1SKUW2LER0          preschooler          prescouler          
TOW2MIY0IY1CHIY0        tomiichi             tomiacio            
PAA1STAH0L              postle               postel              
GUW2SIY0AO1RAH0         gusciora             gusiora             
STAE1GNEY