# Anna KaRNNa

In this notebook, I'll build a character-wise RNN trained on Anna Karenina, one of my all-time favorite books. It'll be able to generate new text based on the text from the book.

This network is based on Andrej Karpathy's [post on RNNs](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) and [implementation in Torch](https://github.com/karpathy/char-rnn). I also drew on [this post at r2rt](http://r2rt.com/recurrent-neural-networks-in-tensorflow-ii.html) and on [Sherjil Ozair's implementation](https://github.com/sherjilozair/char-rnn-tensorflow) on GitHub. Below is the general architecture of the character-wise RNN.

<img src="assets/charseq.jpeg" width="500">

In [1]:
import time
from collections import namedtuple

import numpy as np
import tensorflow as tf

First we'll load the text file and convert it into integers for our network to use.

In [2]:
with open('anna.txt', 'r') as f:
    text=f.read()
vocab = set(text)
vocab_to_int = {c: i for i, c in enumerate(vocab)}
int_to_vocab = dict(enumerate(vocab))
chars = np.array([vocab_to_int[c] for c in text], dtype=np.int32)

In [3]:
text[:100]

'Chapter 1\n\n\nHappy families are all alike; every unhappy family is unhappy in its own\nway.\n\nEverythin'

In [4]:
chars[:100]

array([70, 65, 50, 79, 34, 31,  9, 67, 16, 57, 57, 57, 19, 50, 79, 79, 38,
       67, 73, 50, 25, 66, 45, 66, 31,  4, 67, 50,  9, 31, 67, 50, 45, 45,
       67, 50, 45, 66, 39, 31, 22, 67, 31, 61, 31,  9, 38, 67, 48, 28, 65,
       50, 79, 79, 38, 67, 73, 50, 25, 66, 45, 38, 67, 66,  4, 67, 48, 28,
       65, 50, 79, 79, 38, 67, 66, 28, 67, 66, 34,  4, 67, 53, 12, 28, 57,
       12, 50, 38, 35, 57, 57, 14, 61, 31,  9, 38, 34, 65, 66, 28], dtype=int32)
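As a quick sanity check on the encoding step, here's a minimal sketch of the same idea on a toy string. The `toy_` names are made up for this example and aren't part of the notebook. I'm also sorting the vocabulary here, which the cell above doesn't do, so that the mapping comes out the same on every run.

```python
import numpy as np

# Toy stand-in for the contents of anna.txt (assumption: any string works here)
toy_text = "Happy families are all alike"

# sorted() gives a reproducible character order; the notebook cell uses an unsorted set
toy_vocab = sorted(set(toy_text))
toy_vocab_to_int = {c: i for i, c in enumerate(toy_vocab)}
toy_int_to_vocab = dict(enumerate(toy_vocab))

# Encode the text as integers, then decode it back to characters
encoded = np.array([toy_vocab_to_int[c] for c in toy_text], dtype=np.int32)
decoded = ''.join(toy_int_to_vocab[int(i)] for i in encoded)

print(encoded[:10])
print(decoded == toy_text)
```

If the round trip gives back the original string, the two dictionaries are consistent with each other.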

Now I need to split up the data into batches, and into training and validation sets. I should be making a test set here, but I'm not going to worry about that. My test will be if the network can generate new text.

Here I'll make both input and target arrays. The targets are the same as the inputs, except shifted one character over. I'll also drop the last bit of data so that I'll only have completely full batches.

The idea here is to make a 2D matrix where the number of rows is equal to the batch size, that is, the number of sequences fed to the network in parallel. Each row will be one long concatenated string from the character data. We'll split this data into a training set and validation set using the `split_frac` keyword. With the default of 0.9, this keeps the first 90% of the batches in the training set and the remaining 10% in the validation set.

In [5]:
def split_data(chars, batch_size, num_steps, split_frac=0.9):
    """ 
    Split character data into training and validation sets, inputs and targets for each set.
    
    Arguments
    ---------
    chars: character array
    batch_size: Number of sequences in each batch
    num_steps: Number of sequence steps to keep in the input and pass to the network
    split_frac: Fraction of batches to keep in the training set
    
    
    Returns train_x, train_y, val_x, val_y
    """
    
    
    slice_size = batch_size * num_steps
    n_batches = int(len(chars) / slice_size)
    
    # Drop the last few characters to make only full batches
    x = chars[: n_batches*slice_size]
    y = chars[1: n_batches*slice_size + 1]
    
    # Split the data into batch_size slices, then stack them into a 2D matrix 
    x = np.stack(np.split(x, batch_size))
    y = np.stack(np.split(y, batch_size))
    
    # Now x and y are arrays with dimensions batch_size x n_batches*num_steps
    
    # Split into training and validation sets, keep the first split_frac batches for training
    split_idx = int(n_batches*split_frac)
    train_x, train_y= x[:, :split_idx*num_steps], y[:, :split_idx*num_steps]
    val_x, val_y = x[:, split_idx*num_steps:], y[:, split_idx*num_steps:]
    
    return train_x, train_y, val_x, val_y

In [6]:
train_x, train_y, val_x, val_y = split_data(chars, 10, 200)

In [7]:
train_x.shape

(10, 178400)

In [8]:
train_x[:,:10]

array([[70, 65, 50, 79, 34, 31,  9, 67, 16, 57],
       [51, 28, 23, 67, 65, 31, 67, 25, 53, 61],
       [67, 11, 50, 34, 11, 65, 66, 28, 63, 67],
       [53, 34, 65, 31,  9, 67, 12, 53, 48, 45],
       [67, 34, 65, 31, 67, 45, 50, 28, 23, 80],
       [67, 40, 65,  9, 53, 48, 63, 65, 67, 45],
       [34, 67, 34, 53, 57, 23, 53, 35, 57, 57],
       [53, 67, 65, 31,  9,  4, 31, 45, 73, 42],
       [65, 50, 34, 67, 66,  4, 67, 34, 65, 31],
       [31,  9,  4, 31, 45, 73, 67, 50, 28, 23]], dtype=int32)
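To make the shapes and the one-character shift easier to see, here's a small sketch that repeats the same steps on a toy array of 41 integers. The numbers (`batch_size = 2`, `num_steps = 5`) are picked only for illustration.

```python
import numpy as np

# Toy data: 41 "characters", so one is left over after forming full batches
toy_chars = np.arange(41)
batch_size, num_steps, split_frac = 2, 5, 0.9

slice_size = batch_size * num_steps           # characters consumed per batch
n_batches = len(toy_chars) // slice_size      # number of full batches

# Inputs, and targets shifted one character ahead
x = toy_chars[:n_batches * slice_size]
y = toy_chars[1:n_batches * slice_size + 1]

# One row per sequence in the batch
x = np.stack(np.split(x, batch_size))
y = np.stack(np.split(y, batch_size))

# Keep the first split_frac of the batches for training, the rest for validation
split_idx = int(n_batches * split_frac)
toy_train_x, toy_val_x = x[:, :split_idx * num_steps], x[:, split_idx * num_steps:]

print(x.shape)            # (2, 20)
print(x[:, :5])           # first five inputs of each row
print(y[:, :5])           # the same positions, one character later
print(toy_train_x.shape, toy_val_x.shape)
```

Each entry of `y` is the character that follows the corresponding entry of `x`, which is exactly what the network is trained to predict.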

I'll write another function to grab batches out of the arrays made by `split_data`. Here each batch will be a sliding window on these arrays with size `batch_size x num_steps`. For example, if we want our network to train on a sequence of 100 characters, `num_steps = 100`. For the next batch, we'll shift this window to the next sequence of `num_steps` characters. In this way we can feed batches to the network and the cell states will carry over from one batch to the next.

In [9]:
def get_batch(arrs, num_steps):
    batch_size, slice_size = arrs[0].shape
    
    n_batches = int(slice_size/num_steps)
    for b in range(n_batches):
        yield [x[:, b*num_steps: (b+1)*num_steps] for x in arrs]
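Here's a small sketch of the windowing on a toy array. `toy_get_batch` repeats the logic of `get_batch` so the example runs on its own; the array values are made up.

```python
import numpy as np

def toy_get_batch(arrs, num_steps):
    # Same windowing idea as get_batch, repeated so this example is self-contained
    n_batches = arrs[0].shape[1] // num_steps
    for b in range(n_batches):
        yield [a[:, b * num_steps:(b + 1) * num_steps] for a in arrs]

# Two rows (batch_size = 2) of 12 steps each; targets are the inputs shifted by one
toy_x = np.arange(24).reshape(2, 12)
toy_y = toy_x + 1

batches = list(toy_get_batch([toy_x, toy_y], num_steps=4))
for bx, by in batches:
    print(bx.tolist(), by.tolist())
```

Row 0 of the second window starts right where row 0 of the first window stopped. That continuity is what makes it meaningful to pass the final LSTM state of one batch in as the initial state of the next.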

In [10]:
def build_rnn(num_classes, batch_size=50, num_steps=50, lstm_size=128, num_layers=2,
              learning_rate=0.001, grad_clip=5, sampling=False):
        
    if sampling:
        batch_size, num_steps = 1, 1

    tf.reset_default_graph()
    
    # Declare placeholders we'll feed into the graph
    with tf.name_scope('inputs'):
        inputs = tf.placeholder(tf.int32, [batch_size, num_steps], name='inputs')
        x_one_hot = tf.one_hot(inputs, num_classes, name='x_one_hot')
    
    with tf.name_scope('targets'):
        targets = tf.placeholder(tf.int32, [batch_size, num_steps], name='targets')
        y_one_hot = tf.one_hot(targets, num_classes, name='y_one_hot')
        y_reshaped = tf.reshape(y_one_hot, [-1, num_classes])
    
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    
    # Build the RNN layers
    with tf.name_scope("RNN_layers"):
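        def build_cell():
            lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
            return tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob)
        # Create one cell object per layer; some TF 1.x releases raise an error
        # when the same cell object is reused across layers
        cell = tf.contrib.rnn.MultiRNNCell([build_cell() for _ in range(num_layers)])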
    
    with tf.name_scope("RNN_init_state"):
        initial_state = cell.zero_state(batch_size, tf.float32)

    # Run the data through the RNN layers
    with tf.name_scope("RNN_forward"):
        rnn_inputs = [tf.squeeze(i, squeeze_dims=[1]) for i in tf.split(x_one_hot, num_steps, 1)]
        outputs, state = tf.contrib.rnn.static_rnn(cell, rnn_inputs, initial_state=initial_state)
    
    final_state = state
    
    # Reshape output so it's a bunch of rows, one row for each cell output
    with tf.name_scope('sequence_reshape'):
        seq_output = tf.concat(outputs, axis=1,name='seq_output')
        output = tf.reshape(seq_output, [-1, lstm_size], name='graph_output')
    
    # Now connect the RNN outputs to a softmax layer and calculate the cost
    with tf.name_scope('logits'):
        softmax_w = tf.Variable(tf.truncated_normal((lstm_size, num_classes), stddev=0.1),
                               name='softmax_w')
        softmax_b = tf.Variable(tf.zeros(num_classes), name='softmax_b')
        logits = tf.matmul(output, softmax_w) + softmax_b

    with tf.name_scope('predictions'):
        preds = tf.nn.softmax(logits, name='predictions')
    
    
    with tf.name_scope('cost'):
        loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_reshaped, name='loss')
        cost = tf.reduce_mean(loss, name='cost')

    # Optimizer for training, using gradient clipping to control exploding gradients
    with tf.name_scope('train'):
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), grad_clip)
        train_op = tf.train.AdamOptimizer(learning_rate)
        optimizer = train_op.apply_gradients(zip(grads, tvars))
    
    # Export the nodes 
    export_nodes = ['inputs', 'targets', 'initial_state', 'final_state',
                    'keep_prob', 'cost', 'preds', 'optimizer']
    Graph = namedtuple('Graph', export_nodes)
    local_dict = locals()
    graph = Graph(*[local_dict[each] for each in export_nodes])
    
    return graph

## Hyperparameters

Here I'm defining the hyperparameters for the network. The two you probably haven't seen before are `lstm_size` and `num_layers`. These set the number of hidden units in each LSTM layer and the number of LSTM layers, respectively. Making these bigger gives the network more capacity, which usually lowers the training loss, but you'll have to watch out for overfitting. If your validation loss is much larger than the training loss, you're probably overfitting. In that case, decrease the size of the network or decrease the dropout keep probability (which applies more dropout).

In [11]:
batch_size = 100
num_steps = 100
lstm_size = 512
num_layers = 2
learning_rate = 0.001
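To get a feel for how `lstm_size` and `num_layers` drive the size of the network, here's a rough parameter count. This is a sketch under two assumptions: the standard LSTM layout, where each layer has four gates with a weight matrix over the concatenated input and hidden state plus a bias, and a hypothetical vocabulary of 83 characters (the real value is `len(vocab)`). `count_params` is a helper made up for this example.

```python
def count_params(num_classes, lstm_size, num_layers):
    """Rough count of trainable parameters for stacked LSTM layers plus a softmax layer."""
    total = 0
    input_size = num_classes              # the first layer sees one-hot characters
    for _ in range(num_layers):
        # Four gates, each with weights over [input, hidden] plus a bias vector
        total += 4 * ((input_size + lstm_size) * lstm_size + lstm_size)
        input_size = lstm_size            # later layers see the previous layer's output
    # Softmax layer: weight matrix plus bias
    total += lstm_size * num_classes + num_classes
    return total

assumed_vocab_size = 83                   # hypothetical; use len(vocab) in the notebook
for size in (128, 256, 512):
    print(size, count_params(assumed_vocab_size, size, num_layers=2))
```

The count grows roughly with the square of `lstm_size`, so doubling the hidden size about quadruples the number of weights the network has available to memorize the training text.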

## Write out the graph for TensorBoard

In [12]:
model = build_rnn(len(vocab), 
                  batch_size=batch_size,
                  num_steps=num_steps,
                  learning_rate=learning_rate,
                  lstm_size=lstm_size,
                  num_layers=num_layers)

with tf.Session() as sess:
    
    sess.run(tf.global_variables_initializer())
    file_writer = tf.summary.FileWriter('./logs/3', sess.graph)

## Training

Time for training, which is pretty straightforward. Here I pass in some data and get an LSTM state back. Then I pass that state back into the network so the next batch can continue from the state of the previous batch. Every so often (set by `save_every_n`) I calculate the validation loss and save a checkpoint.

In [13]:
!mkdir -p checkpoints/anna

In [14]:
epochs = 10
save_every_n = 200
train_x, train_y, val_x, val_y = split_data(chars, batch_size, num_steps)

model = build_rnn(len(vocab), 
                  batch_size=batch_size,
                  num_steps=num_steps,
                  learning_rate=learning_rate,
                  lstm_size=lstm_size,
                  num_layers=num_layers)

saver = tf.train.Saver(max_to_keep=100)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    # Use the line below to load a checkpoint and resume training
    #saver.restore(sess, 'checkpoints/anna20.ckpt')
    
    n_batches = int(train_x.shape[1]/num_steps)
    iterations = n_batches * epochs
    for e in range(epochs):
        
        # Train network
        new_state = sess.run(model.initial_state)
        loss = 0
        for b, (x, y) in enumerate(get_batch([train_x, train_y], num_steps), 1):
            iteration = e*n_batches + b
            start = time.time()
            feed = {model.inputs: x,
                    model.targets: y,
                    model.keep_prob: 0.5,
                    model.initial_state: new_state}
            batch_loss, new_state, _ = sess.run([model.cost, model.final_state, model.optimizer], 
                                                 feed_dict=feed)
            loss += batch_loss
            end = time.time()
            print('Epoch {}/{} '.format(e+1, epochs),
                  'Iteration {}/{}'.format(iteration, iterations),
                  'Training loss: {:.4f}'.format(loss/b),
                  '{:.4f} sec/batch'.format((end-start)))
        
            
            if (iteration%save_every_n == 0) or (iteration == iterations):
                # Check performance; keep_prob is set to 1 so dropout is disabled.
                # A separate val_state keeps the training state from being overwritten.
                val_loss = []
                val_state = sess.run(model.initial_state)
                for x, y in get_batch([val_x, val_y], num_steps):
                    feed = {model.inputs: x,
                            model.targets: y,
                            model.keep_prob: 1.,
                            model.initial_state: val_state}
                    batch_loss, val_state = sess.run([model.cost, model.final_state], feed_dict=feed)
                    val_loss.append(batch_loss)

                print('Validation loss:', np.mean(val_loss),
                      'Saving checkpoint!')
                saver.save(sess, "checkpoints/anna/i{}_l{}_{:.3f}.ckpt".format(iteration, lstm_size, np.mean(val_loss)))

Epoch 1/10  Iteration 1/1780 Training loss: 4.4195 1.4248 sec/batch
Epoch 1/10  Iteration 2/1780 Training loss: 4.3791 0.1471 sec/batch
Epoch 1/10  Iteration 3/1780 Training loss: 4.2382 0.1392 sec/batch
Epoch 1/10  Iteration 4/1780 Training loss: 4.5991 0.1352 sec/batch
Epoch 1/10  Iteration 5/1780 Training loss: 4.4899 0.1354 sec/batch
Epoch 1/10  Iteration 6/1780 Training loss: 4.3978 0.1353 sec/batch
Epoch 1/10  Iteration 7/1780 Training loss: 4.2962 0.1358 sec/batch
Epoch 1/10  Iteration 8/1780 Training loss: 4.2025 0.1354 sec/batch
Epoch 1/10  Iteration 9/1780 Training loss: 4.1165 0.1352 sec/batch
Epoch 1/10  Iteration 10/1780 Training loss: 4.0443 0.1353 sec/batch
Epoch 1/10  Iteration 11/1780 Training loss: 3.9839 0.1354 sec/batch
Epoch 1/10  Iteration 12/1780 Training loss: 3.9335 0.1351 sec/batch
Epoch 1/10  Iteration 13/1780 Training loss: 3.8867 0.1355 sec/batch
Epoch 1/10  Iteration 14/1780 Training loss: 3.8462 0.1354 sec/batch
Epoch 1/10  Iteration 15/1780 Training loss

In [15]:
tf.train.get_checkpoint_state('checkpoints/anna')

model_checkpoint_path: "checkpoints/anna/i1780_l512_1.263.ckpt"
all_model_checkpoint_paths: "checkpoints/anna/i200_l512_2.503.ckpt"
all_model_checkpoint_paths: "checkpoints/anna/i400_l512_2.027.ckpt"
all_model_checkpoint_paths: "checkpoints/anna/i600_l512_1.790.ckpt"
all_model_checkpoint_paths: "checkpoints/anna/i800_l512_1.631.ckpt"
all_model_checkpoint_paths: "checkpoints/anna/i1000_l512_1.509.ckpt"
all_model_checkpoint_paths: "checkpoints/anna/i1200_l512_1.426.ckpt"
all_model_checkpoint_paths: "checkpoints/anna/i1400_l512_1.347.ckpt"
all_model_checkpoint_paths: "checkpoints/anna/i1600_l512_1.295.ckpt"
all_model_checkpoint_paths: "checkpoints/anna/i1780_l512_1.263.ckpt"
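Since the checkpoint names encode the iteration, the LSTM size, and the validation loss, it's easy to pick out the best one. Here's a small sketch that parses the paths from the listing above; the `parse` helper and the regular expression are made up for this example and assume the `i{}_l{}_{:.3f}.ckpt` naming used in the training cell.

```python
import re

# Checkpoint paths copied from the listing above
paths = [
    "checkpoints/anna/i200_l512_2.503.ckpt",
    "checkpoints/anna/i400_l512_2.027.ckpt",
    "checkpoints/anna/i600_l512_1.790.ckpt",
    "checkpoints/anna/i800_l512_1.631.ckpt",
    "checkpoints/anna/i1000_l512_1.509.ckpt",
    "checkpoints/anna/i1200_l512_1.426.ckpt",
    "checkpoints/anna/i1400_l512_1.347.ckpt",
    "checkpoints/anna/i1600_l512_1.295.ckpt",
    "checkpoints/anna/i1780_l512_1.263.ckpt",
]

# Matches names of the form i<iteration>_l<lstm_size>_<val_loss>.ckpt
pattern = re.compile(r'i(\d+)_l(\d+)_(\d+\.\d+)\.ckpt$')

def parse(path):
    """Return (iteration, lstm_size, validation_loss) for one checkpoint path."""
    m = pattern.search(path)
    return int(m.group(1)), int(m.group(2)), float(m.group(3))

# The checkpoint with the lowest validation loss
best = min(paths, key=lambda p: parse(p)[2])
print(best)
```

In this run the validation loss is still dropping at every checkpoint, so the last checkpoint is also the best one.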

## Sampling

Now that the network is trained, we can use it to generate new text. The idea is that we pass in a character, then the network predicts the next character. We can then feed that new character back in to predict the one after it, and keep doing this to generate entirely new text. I also included some functionality to prime the network with some text by passing in a string and building up a state from that.

The network gives us predictions for each character. To reduce noise and make things a little less random, I'm going to only choose a new character from the top N most likely characters.



In [16]:
def pick_top_n(preds, vocab_size, top_n=5):
    p = np.squeeze(preds)
    p[np.argsort(p)[:-top_n]] = 0
    p = p / np.sum(p)
    c = np.random.choice(vocab_size, 1, p=p)[0]
    return c
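To see what the top-N filtering does to the distribution, here's a small sketch on a made-up prediction vector with five classes. `top_n_distribution` is a helper written for this example. Unlike `pick_top_n`, it works on a copy, so the input array is left unchanged.

```python
import numpy as np

def top_n_distribution(preds, top_n=5):
    """Zero out all but the top_n probabilities and renormalize."""
    # np.array(...) makes a copy, so the caller's array is not modified
    p = np.array(np.squeeze(preds), dtype=np.float64)
    p[np.argsort(p)[:-top_n]] = 0     # indices of everything except the top_n largest
    return p / p.sum()

# Made-up softmax output for a 5-character vocabulary
toy_preds = np.array([[0.05, 0.40, 0.10, 0.30, 0.15]])

p = top_n_distribution(toy_preds, top_n=2)
print(p)                              # only indices 1 and 3 keep nonzero probability

# Sample a character index from the filtered distribution
rng = np.random.RandomState(0)
c = rng.choice(len(p), p=p)
print(c)
```

With `top_n=2`, only the two most likely characters can ever be drawn, and their probabilities are rescaled to sum to 1.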

In [17]:
def sample(checkpoint, n_samples, lstm_size, vocab_size, prime="The "):
    samples = [c for c in prime]
    model = build_rnn(vocab_size, lstm_size=lstm_size, sampling=True)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, checkpoint)
        new_state = sess.run(model.initial_state)
        for c in prime:
            x = np.zeros((1, 1))
            x[0,0] = vocab_to_int[c]
            feed = {model.inputs: x,
                    model.keep_prob: 1.,
                    model.initial_state: new_state}
            preds, new_state = sess.run([model.preds, model.final_state], 
                                         feed_dict=feed)

        c = pick_top_n(preds, vocab_size)
        samples.append(int_to_vocab[c])

        for i in range(n_samples):
            x[0,0] = c
            feed = {model.inputs: x,
                    model.keep_prob: 1.,
                    model.initial_state: new_state}
            preds, new_state = sess.run([model.preds, model.final_state], 
                                         feed_dict=feed)

            c = pick_top_n(preds, vocab_size)
            samples.append(int_to_vocab[c])
        
    return ''.join(samples)

In [24]:
checkpoint = "checkpoints/anna/i1780_l512_1.263.ckpt"
samp = sample(checkpoint, 2000, lstm_size, len(vocab), prime="Far")
print(samp)

Farencial father and her
service; but the prass and who went to her to anyight to him and with
a smile, starding him alone.

"I don't know that the were," said the some window, with a mush on this
tentrance, and as the same of the ride, though she was a might. In all
the presion and his wife, and said, brought that he was in the condition
of the sorn of the prince. Think at the point of all the senver, and
that he cared the complete of answer, who arrowed him would shot a
liver that the summles who was the man, and as the say and he did not
see him and still, which had the position of his strather with her
consideration, that she had a controne, and the clarses of his former was
the come of simply at her through and a smile to see him of and all his
countin, as the some would be a sone of his fas of the child to be a
living of her.

"You call the big about it?"

"I was there and she was it secilling and all this is. He was thinking it
with me in specking intented to me in this are such

In [25]:
checkpoint = "checkpoints/anna/i200_l512_2.503.ckpt"
samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime="Far")
print(samp)

Farnd hor one at e as tirtere ate har har hees tisigen hhe sor io othe at hired the te oring an ot ine hond orad ans iat ase ile tor tho tons hee thon the har oad shan oe ant oe hor te the seered one an ithe wersen ad ase ilo an te ante on ororeris io e as aresang thes he theran io onte antithr at oros ite hor tan hint heere tos hh sithis hin ant eor oo th tor and
ath tee hh san oe as iot il as the whe woto ar and an iso onte his an hin the ant e ton alerile orennton ha te alil oo arerannd ant at he seans on his here anther inth ererind oo oes th arsote ar oo tarinte thes ond oa shar inthon he wan he ase andes oithin the werin ante ter harse one her oned han ho shesis on ared
alit ond tor ot he see an he sor ar inse an it oot tee han asente tir onteed tor to th ito he seas, ald an arid an herinthe tish ean ad tat as oo ansed in oet an thore ter asint hee the whint an toes he wor al tinte wasd and as hee whod ar ile ho inthes an hesing on ind toer and anthid and asid
on as tee hes and a

In [26]:
checkpoint = "checkpoints/anna/i600_l512_1.790.ckpt"
samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime="Far")
print(samp)

Farcine sho head ont that that he westerth of in the
selaly, whrises that heared, the could on a posstere.

"Werl hen shate that his. "I's it that the more, the cound to buten his sace and
here of a torking anse so him.

"I cont all thom. And he said at ant hum and the cout heard and and seet and at on a that that's the troight his with sare.

"Yes
in the camsithens, and the were the
candar in the
prialint and thim her
and that were, wely hered to at this that she was seettersted, and
hig as hid and astediages on was to that seep her a seether.

"Well tithing
here anothars the walled tak the mont a comsant, and to sall and tome, he was
stolther had, but sor herest of she sentthing heathing to and his
starteds of the his hid breed hevine wert of the sistere alle, who whone wares her and, werle that he saind and,
to but a tondere, and and the certere to a the canet of insouting to be
the share oun ofe him a for the canting an the seres in of herserted and the cansantico so bothing anderi

In [28]:
checkpoint = "checkpoints/anna/i1000_l512_1.509.ckpt"
samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime="Far")
print(samp)

Farnige her, but the drown
with with the came
into been one with a lot was
hartion, and said in him, and the mire someted into the
postor to she was some
to the
roment in the dignist was home foor him.

"Why, yes? What's see they?" said Stepan Arkadyevitch was so to tarked.

"You don't could not and troust of though that's a cerst and the
clars, but seading to her. Ho his
coming.

"What's now was stend it to sind a much an and to din the position in the contring to
the
daidsess as his sout, and sore in so mand to the sums to a mance
was that seat of the metting to hard
to the sere with she
canced that he was a starning. The she deeparesed the strang is hear and
the perine and was sone and with his she would not say her, and all the strop of
her ablead, and with her sont to
a convisulle, steping
along to she with the sord of the propining, and her so made alone at him his
canchasion, and that.

"I have to took a latting in the cout on this and which how should him. And
to a condersail, 