In [69]:
import time
from collections import namedtuple

import numpy as np
import tensorflow as tf

In [70]:
with open('anna.txt', 'r') as f:
    text = f.read()

vocab = set(text)
vocab_to_int = { c: i for i, c in enumerate(vocab) }
int_to_vocab = dict(enumerate(vocab))
chars = np.array([vocab_to_int[c] for c in text])

In [71]:
chars

array([ 8, 66, 79, ..., 67, 49, 77])

In [72]:
def split_data(chars, batch_size, num_steps, split_frac=0.9):
    
    
    slice_size = batch_size * num_steps
    n_batches = len(chars) // slice_size
    
    x = chars[:slice_size * n_batches]
    y = chars[1:slice_size * n_batches + 1]
    
    x = np.stack(np.split(x, batch_size))
    y = np.stack(np.split(y, batch_size))
    
    
    split_idx = int( split_frac * n_batches)
    x_train, y_train = x[:split_idx], y[:split_idx]
    x_val, y_val = y[split_idx:], y[split_idx:]
    
    return x_train, y_train, x_val, y_val

In [73]:
x_train, y_train, x_val, y_val = split_data(chars, 10, 50)

In [74]:
x_train.shape

(10, 198500)

In [75]:
def get_batch(arr, num_steps):
    batch_size, slice_size = arr[0].shape
    
    n_batches = int(slice_size / num_steps)
    for b in range(n_batches):
        yield [x[:, b*num_steps: (b+1)*num_steps] for x in arr]

In [76]:
def build_rnn(num_classes, batch_size=50, num_steps=50, lstm_size=128, num_layers=2,
              learning_rate=0.001, grad_clip=5, sampling=False):
    
    # When we're using this network for sampling later, we'll be passing in
    # one character at a time, so providing an option for that
    if sampling == True:
        batch_size, num_steps = 1, 1

    tf.reset_default_graph()
    
    # Declare placeholders we'll feed into the graph
    inputs = tf.placeholder(tf.int32, [batch_size, num_steps], name='inputs')
    targets = tf.placeholder(tf.int32, [batch_size, num_steps], name='targets')
    
    # Keep probability placeholder for drop out layers
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    
    # One-hot encoding the input and target characters
    x_one_hot = tf.one_hot(inputs, num_classes)
    y_one_hot = tf.one_hot(targets, num_classes)

    ### Build the RNN layers
    # Use a basic LSTM cell
    cell = tf.contrib.rnn.MultiRNNCell([
        tf.contrib.rnn.DropoutWrapper(tf.contrib.rnn.BasicLSTMCell(lstm_size), output_keep_prob=keep_prob) for _ in range(num_layers)
    ])
    initial_state = cell.zero_state(batch_size, tf.float32)

    ### Run the data through the RNN layers
    # This makes a list where each element is on step in the sequence
    rnn_inputs = [tf.squeeze(i, squeeze_dims=[1]) for i in tf.split(x_one_hot, num_steps, 1)]
    
    # Run each sequence step through the RNN and collect the outputs
    outputs, state = tf.contrib.rnn.static_rnn(cell, rnn_inputs, initial_state=initial_state)
    final_state = state
    
    # Reshape output so it's a bunch of rows, one output row for each step for each batch
    seq_output = tf.concat(outputs, axis=1)
    output = tf.reshape(seq_output, [-1, lstm_size])
    
    # Now connect the RNN outputs to a softmax layer
    with tf.variable_scope('softmax'):
        softmax_w = tf.Variable(tf.truncated_normal((lstm_size, num_classes), stddev=0.1))
        softmax_b = tf.Variable(tf.zeros(num_classes))
    
    # Since output is a bunch of rows of RNN cell outputs, logits will be a bunch
    # of rows of logit outputs, one for each step and batch
    logits = tf.matmul(output, softmax_w) + softmax_b
    
    # Use softmax to get the probabilities for predicted characters
    preds = tf.nn.softmax(logits, name='predictions')
    
    # Reshape the targets to match the logits
    y_reshaped = tf.reshape(y_one_hot, [-1, num_classes])
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_reshaped)
    cost = tf.reduce_mean(loss)

    # Optimizer for training, using gradient clipping to control exploding gradients
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), grad_clip)
    train_op = tf.train.AdamOptimizer(learning_rate)
    optimizer = train_op.apply_gradients(zip(grads, tvars))
    
    # Export the nodes
    # NOTE: I'm using a namedtuple here because I think they are cool
    export_nodes = ['inputs', 'targets', 'initial_state', 'final_state',
                    'keep_prob', 'cost', 'preds', 'optimizer']
    Graph = namedtuple('Graph', export_nodes)
    local_dict = locals()
    graph = Graph(*[local_dict[each] for each in export_nodes])
    
    return graph

In [77]:
batch_size = 100
num_steps = 100 
lstm_size = 512
num_layers = 2
learning_rate = 0.001
keep_prob = 0.5

In [78]:
epochs = 20
# Save every N iterations
save_every_n = 200
train_x, train_y, val_x, val_y = split_data(chars, batch_size, num_steps)

model = build_rnn(len(vocab), 
                  batch_size=batch_size,
                  num_steps=num_steps,
                  learning_rate=learning_rate,
                  lstm_size=lstm_size,
                  num_layers=num_layers)

saver = tf.train.Saver(max_to_keep=100)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    # Use the line below to load a checkpoint and resume training
    #saver.restore(sess, 'checkpoints/______.ckpt')
    
    n_batches = int(train_x.shape[1]/num_steps)
    iterations = n_batches * epochs
    for e in range(epochs):
        
        # Train network
        new_state = sess.run(model.initial_state)
        loss = 0
        for b, (x, y) in enumerate(get_batch([train_x, train_y], num_steps), 1):
            iteration = e*n_batches + b
            start = time.time()
            feed = {model.inputs: x,
                    model.targets: y,
                    model.keep_prob: keep_prob,
                    model.initial_state: new_state}
            batch_loss, new_state, _ = sess.run([model.cost, model.final_state, model.optimizer], 
                                                 feed_dict=feed)
            loss += batch_loss
            end = time.time()
            print('Epoch {}/{} '.format(e+1, epochs),
                  'Iteration {}/{}'.format(iteration, iterations),
                  'Training loss: {:.4f}'.format(loss/b),
                  '{:.4f} sec/batch'.format((end-start)))
        
            
            if (iteration%save_every_n == 0) or (iteration == iterations):
                # Check performance, notice dropout has been set to 1
                val_loss = []
                new_state = sess.run(model.initial_state)
                for x, y in get_batch([val_x, val_y], num_steps):
                    feed = {model.inputs: x,
                            model.targets: y,
                            model.keep_prob: 1.,
                            model.initial_state: new_state}
                    batch_loss, new_state = sess.run([model.cost, model.final_state], feed_dict=feed)
                    val_loss.append(batch_loss)

                print('Validation loss:', np.mean(val_loss),
                      'Saving checkpoint!')
                saver.save(sess, "checkpoints/i{}_l{}_v{:.3f}.ckpt".format(iteration, lstm_size, np.mean(val_loss)))             

Epoch 1/20  Iteration 1/3960 Training loss: 4.4183 10.9695 sec/batch
Epoch 1/20  Iteration 2/3960 Training loss: 4.3781 6.9920 sec/batch
Epoch 1/20  Iteration 3/3960 Training loss: 4.2063 7.4278 sec/batch
Epoch 1/20  Iteration 4/3960 Training loss: 4.3075 7.1437 sec/batch
Epoch 1/20  Iteration 5/3960 Training loss: 4.2529 7.2229 sec/batch
Epoch 1/20  Iteration 6/3960 Training loss: 4.1978 7.4531 sec/batch
Epoch 1/20  Iteration 7/3960 Training loss: 4.1327 7.4209 sec/batch
Epoch 1/20  Iteration 8/3960 Training loss: 4.0638 7.4006 sec/batch
Epoch 1/20  Iteration 9/3960 Training loss: 3.9996 7.4449 sec/batch
Epoch 1/20  Iteration 10/3960 Training loss: 3.9476 7.8308 sec/batch
Epoch 1/20  Iteration 11/3960 Training loss: 3.9001 7.1966 sec/batch
Epoch 1/20  Iteration 12/3960 Training loss: 3.8597 7.2074 sec/batch
Epoch 1/20  Iteration 13/3960 Training loss: 3.8234 7.2311 sec/batch
Epoch 1/20  Iteration 14/3960 Training loss: 3.7924 6.9860 sec/batch
Epoch 1/20  Iteration 15/3960 Training los

KeyboardInterrupt: 