In [24]:
import pickle
import tensorflow as tf
import numpy as np

from nov27.prepare_data import parse_seq

## Build graph

In [32]:
tf.reset_default_graph()

In [42]:
hidden_state_size = 256
summaries_dir = 'nov27/summaries/{}'.format(hidden_state_size)

max_seq_len = 100
prefix = 'nov27/bible/kj{}'.format(max_seq_len)
seq_file = prefix + '.tfrecords'
vocab_file = prefix + '_vocab'

batch_size = 100
num_epochs = 100

with open(vocab_file, 'rb') as fin:
    ch_to_idx = pickle.load(fin)
    num_chars = len(ch_to_idx)    

dataset = tf.contrib.data.TFRecordDataset([seq_file])
dataset = dataset.map(parse_seq)
dataset = dataset.shuffle(1000).repeat(num_epochs).padded_batch(batch_size, [None])

iterator = dataset.make_one_shot_iterator()

In [34]:
cell_state = tf.placeholder(tf.float32, [batch_size, hidden_state_size])
hidden_state = tf.placeholder(tf.float32, [batch_size, hidden_state_size])
init_state = tf.nn.rnn_cell.LSTMStateTuple(cell_state, hidden_state)

with tf.name_scope("input"):
    inputs = iterator.get_next()

xs = inputs[:, :-1]
valid_xs_mask = tf.not_equal(xs, 0)

xs_seq_len = tf.reduce_sum(tf.to_int32(valid_xs_mask), axis=1)
one_hot_xs = tf.one_hot(xs, depth=num_chars)
ys = inputs[:, 1:]
one_hot_ys = tf.one_hot(ys, depth=num_chars)
    
rnn_cell = tf.nn.rnn_cell.LSTMCell(hidden_state_size, state_is_tuple=True)
outputs, state = tf.nn.dynamic_rnn(rnn_cell, one_hot_xs, sequence_length=xs_seq_len, initial_state=init_state)

with tf.variable_scope("output"):
    W_hy = tf.get_variable('W_hy', [hidden_state_size, num_chars])
    B_hy = tf.get_variable('B_hy', [num_chars])
    
logits = tf.tensordot(outputs, W_hy, axes=1) + B_hy
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=one_hot_ys, logits=logits)
masked_cross_entropy = tf.multiply(cross_entropy, tf.to_float(valid_xs_mask))
summed_entropy = tf.reduce_sum(masked_cross_entropy, 1)
sequence_entropy = tf.divide(summed_entropy, tf.to_float(xs_seq_len))
total_entropy = tf.reduce_mean(sequence_entropy)

train_op = tf.train.AdamOptimizer().minimize(total_entropy)

tf.summary.scalar('entropy', total_entropy)
summary_op = tf.summary.merge_all()  

In [35]:
saver = tf.train.Saver()
with tf.Session() as sesh:
    writer = tf.summary.FileWriter(summaries_dir, sesh.graph)
    sesh.run(tf.global_variables_initializer())
    
    for step in range(5000+1):
        _ = sesh.run([train_op], feed_dict={
                         hidden_state: np.zeros([batch_size, hidden_state_size]),
                         cell_state: np.zeros([batch_size, hidden_state_size])
                     })
        if step % 100 == 0:
            cost, summary, _ = sesh.run([total_entropy, summary_op, train_op], feed_dict={
                hidden_state: np.zeros((batch_size, hidden_state_size)),
                cell_state: np.zeros((batch_size, hidden_state_size))})
            
            writer.add_summary(summary, step)
            print("Step {}\tcost {}".format(step, cost))
        if step % 1000 == 0:
            saver.save(sesh, summaries_dir + "/model.ckpt", global_step=step)
writer.close()

Step 0	cost 4.427514553070068
Step 100	cost 3.019747734069824
Step 200	cost 2.4725136756896973
Step 300	cost 2.2674739360809326
Step 400	cost 2.1300594806671143
Step 500	cost 2.0357859134674072
Step 600	cost 1.9588953256607056
Step 700	cost 1.9420647621154785
Step 800	cost 1.8960859775543213
Step 900	cost 1.8563787937164307
Step 1000	cost 1.775043249130249
Step 1100	cost 1.7879135608673096
Step 1200	cost 1.7991793155670166
Step 1300	cost 1.7858997583389282
Step 1400	cost 1.735626220703125
Step 1500	cost 1.6872413158416748
Step 1600	cost 1.6137914657592773
Step 1700	cost 1.667502522468567
Step 1800	cost 1.6832202672958374
Step 1900	cost 1.627293586730957
Step 2000	cost 1.5775028467178345
Step 2100	cost 1.5419609546661377
Step 2200	cost 1.5398194789886475
Step 2300	cost 1.5281741619110107
Step 2400	cost 1.5041083097457886
Step 2500	cost 1.4725151062011719
Step 2600	cost 1.5203986167907715
Step 2700	cost 1.5064752101898193
Step 2800	cost 1.5421086549758911
Step 2900	cost 1.386820435523986

## Generate output

In [43]:
tf.reset_default_graph()

summaries_dir = 'nov27/summaries/{}/'.format(hidden_state_size)
max_seq_len = 100
prefix = 'nov27/bible/kj{}'.format(max_seq_len)
seq_file = prefix + '.tfrecords'
vocab_file = prefix + '_vocab'

batch_size = 100
num_epochs = 100

with open(vocab_file, 'rb') as fin:
    ch_to_idx = pickle.load(fin)
    num_chars = len(ch_to_idx) 
    idx_to_ch = {v: k for k, v in ch_to_idx.items()}

start_char = '<S>'
stop_char = '</S>'
    
dataset = tf.contrib.data.TFRecordDataset([seq_file])
dataset = dataset.map(parse_seq)
dataset = dataset.shuffle(1000).repeat(num_epochs).padded_batch(batch_size, [None])

iterator = dataset.make_one_shot_iterator()

In [44]:
cell_state = tf.placeholder(tf.float32, [None, hidden_state_size])
hidden_state = tf.placeholder(tf.float32, [None, hidden_state_size])
init_state = tf.nn.rnn_cell.LSTMStateTuple(cell_state, hidden_state)

input_char = tf.placeholder(tf.float32, [None, None, num_chars])
    
rnn_cell = tf.nn.rnn_cell.LSTMCell(hidden_state_size, state_is_tuple=True)
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_char, sequence_length=[1], initial_state=init_state)

with tf.variable_scope("output"):
    W_hy = tf.get_variable('W_hy', [hidden_state_size, num_chars])
    B_hy = tf.get_variable('B_hy', [num_chars])
    
logits = tf.tensordot(outputs, W_hy, axes=1) + B_hy
probs_op = tf.nn.softmax(logits=logits)

In [55]:
model_path = summaries_dir + "model.ckpt-5000"

output = ''
saver = tf.train.Saver()

with tf.Session() as sesh:
    saver.restore(sesh, model_path)
    
    # init
    cur_char = start_char
    cur_char_vec = np.zeros((1, 1, num_chars))
    cur_char_pos = ch_to_idx[start_char]
    cur_char_vec[0, 0, cur_char_pos] = 1.
    
    cur_hidden_state = np.zeros((1, hidden_state_size))
    cur_output_state= np.zeros((1, hidden_state_size))
    
    while True:
        probs, cur_state = sesh.run([probs_op, state], feed_dict={
            input_char: cur_char_vec,
            cell_state: cur_output_state,
            hidden_state: cur_hidden_state,
        })
        probs = np.squeeze(probs)
        cur_char_pos = np.random.choice(num_chars, p=probs)
        cur_char = idx_to_ch[cur_char_pos]
        
        if cur_char == stop_char:
            break
            
        output +=  cur_char
        
        cur_char_vec = np.zeros((1, 1, num_chars))
        cur_char_vec[0, 0, cur_char_pos] = 1.
        cur_output_state, cur_hidden_state = cur_state
            
print(output[:-1])

INFO:tensorflow:Restoring parameters from nov27/summaries/256/model.ckpt-5000
 And when the LORD spead to pering to her unto thempinion be
praised her spake.

