# Anna KaRNNa

In this notebook, I'll build a character-wise RNN trained on Anna Karenina, one of my all-time favorite books. It'll be able to generate new text based on the text from the book.

This network is based on Andrej Karpathy's [post on RNNs](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) and [implementation in Torch](https://github.com/karpathy/char-rnn). It also draws on information [here at r2rt](http://r2rt.com/recurrent-neural-networks-in-tensorflow-ii.html) and from [Sherjil Ozair](https://github.com/sherjilozair/char-rnn-tensorflow) on GitHub. Below is the general architecture of the character-wise RNN.

<img src="assets/charseq.jpeg" width="500">

In [1]:
import time
from collections import namedtuple

import numpy as np
import tensorflow as tf

First we'll load the text file and convert it into integers for our network to use.

In [2]:
with open('anna.txt', 'r') as f:
    text = f.read()
# Build lookup tables between characters and integer codes
vocab = set(text)
vocab_to_int = {c: i for i, c in enumerate(vocab)}
int_to_vocab = dict(enumerate(vocab))
# Encode the whole book as an array of integers
chars = np.array([vocab_to_int[c] for c in text], dtype=np.int32)

In [3]:
text[:100]

'Chapter 1\n\n\nHappy families are all alike; every unhappy family is unhappy in its own\nway.\n\nEverythin'

In [4]:
chars[:100]

array([73, 45, 31,  2, 62, 25, 17, 38, 44, 46, 46, 46, 18, 31,  2,  2, 66,
       38,  1, 31, 81, 39, 61, 39, 25, 57, 38, 31, 17, 25, 38, 31, 61, 61,
       38, 31, 61, 39, 75, 25, 49, 38, 25, 22, 25, 17, 66, 38, 54,  0, 45,
       31,  2,  2, 66, 38,  1, 31, 81, 39, 61, 66, 38, 39, 57, 38, 54,  0,
       45, 31,  2,  2, 66, 38, 39,  0, 38, 39, 62, 57, 38, 20, 30,  0, 46,
       30, 31, 66, 82, 46, 46, 63, 22, 25, 17, 66, 62, 45, 39,  0], dtype=int32)
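One caveat with `set(text)`: Python randomizes string hashing per process (unless `PYTHONHASHSEED` is fixed), so the iteration order of the set, and therefore the integer assigned to each character, can change from one session to the next. That matters if you restore a checkpoint in a fresh session. A minimal sketch of a reproducible mapping is to sort the vocabulary first; the codes will differ from the ones printed above, and the short string here is only a stand-in for `text`:

```python
import numpy as np

text = "Happy families are all alike"  # toy stand-in for the book's text

# Sorting makes the character -> integer mapping identical in every session
vocab = sorted(set(text))
vocab_to_int = {c: i for i, c in enumerate(vocab)}
int_to_vocab = dict(enumerate(vocab))

# Encode, then decode again to check the round trip
encoded = np.array([vocab_to_int[c] for c in text], dtype=np.int32)
decoded = ''.join(int_to_vocab[int(i)] for i in encoded)
print(encoded)
print(decoded)
```

With sorted codes, the space character (the smallest code point in this toy string) always maps to 0.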

Now I need to split the data into batches and into training and validation sets. I should be making a test set here, but I'm not going to worry about that. My test will be whether the network can generate new text.

Here I'll make both input and target arrays. The targets are the same as the inputs, except shifted one character over. I'll also drop the last few characters so that every batch is completely full.

The idea here is to make a 2D matrix where the number of rows is equal to the batch size. Each row will be one long, contiguous string of characters from the text, and each batch is a block of `num_steps` columns. We'll split this data into a training set and a validation set using the `split_frac` keyword. With the default `split_frac=0.9`, 90% of the batches go into the training set and the other 10% into the validation set.
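To see the layout concretely, here's the same reshaping on a toy array of 20 integers with `batch_size=2` and `num_steps=3` (toy values, not the real data). The rows are the batch dimension, the targets are the inputs shifted by one, and the first `split_frac` of the batches (columns) go to training:

```python
import numpy as np

# Toy stand-in for the encoded text: the integers 0..19
chars = np.arange(20, dtype=np.int32)
batch_size, num_steps, split_frac = 2, 3, 0.9

slice_size = batch_size * num_steps            # 6 characters per batch
n_batches = (len(chars) - 1) // slice_size     # 3 full batches, with a spare char for the targets

x = chars[:n_batches * slice_size]             # 0..17
y = chars[1:n_batches * slice_size + 1]        # 1..18, the inputs shifted by one

# One row per sequence in a batch; each row is a contiguous run of the text
x = np.stack(np.split(x, batch_size))
y = np.stack(np.split(y, batch_size))

# Keep the first split_frac of the batches (blocks of num_steps columns) for training
split_idx = int(n_batches * split_frac)        # int(2.7) == 2
train_x, val_x = x[:, :split_idx * num_steps], x[:, split_idx * num_steps:]
print(x)
```

Here `x` is `[[0 ... 8], [9 ... 17]]`, so two batches (six columns) go to training and one batch (three columns) to validation.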

In [5]:
def split_data(chars, batch_size, num_steps, split_frac=0.9):
    """
    Split character data into training and validation sets, inputs and targets for each set.
    
    Arguments
    ---------
    chars: character array
    batch_size: Number of sequences in each batch
    num_steps: Number of sequence steps to keep in the input and pass to the network
    split_frac: Fraction of batches to keep in the training set
    
    Returns train_x, train_y, val_x, val_y
    """
    slice_size = batch_size * num_steps
    # Reserve one extra character so the shifted targets always line up with the inputs
    n_batches = (len(chars) - 1) // slice_size
    
    # Drop the last few characters to make only full batches
    x = chars[: n_batches*slice_size]
    y = chars[1: n_batches*slice_size + 1]
    
    # Split the data into batch_size slices, then stack them into a 2D matrix
    x = np.stack(np.split(x, batch_size))
    y = np.stack(np.split(y, batch_size))
    
    # Now x and y are arrays with dimensions batch_size x n_batches*num_steps
    
    # Split into training and validation sets, keep the first split_frac batches for training
    split_idx = int(n_batches*split_frac)
    train_x, train_y = x[:, :split_idx*num_steps], y[:, :split_idx*num_steps]
    val_x, val_y = x[:, split_idx*num_steps:], y[:, split_idx*num_steps:]
    
    return train_x, train_y, val_x, val_y

In [6]:
train_x, train_y, val_x, val_y = split_data(chars, 10, 200)

In [7]:
train_x.shape

(10, 178400)

In [8]:
train_x[:,:10]

array([[73, 45, 31,  2, 62, 25, 17, 38, 44, 46],
       [27,  0,  4, 38, 45, 25, 38, 81, 20, 22],
       [38, 47, 31, 62, 47, 45, 39,  0, 77, 38],
       [20, 62, 45, 25, 17, 38, 30, 20, 54, 61],
       [38, 62, 45, 25, 38, 61, 31,  0,  4, 64],
       [38, 55, 45, 17, 20, 54, 77, 45, 38, 61],
       [62, 38, 62, 20, 46,  4, 20, 82, 46, 46],
       [20, 38, 45, 25, 17, 57, 25, 61,  1, 78],
       [45, 31, 62, 38, 39, 57, 38, 62, 45, 25],
       [25, 17, 57, 25, 61,  1, 38, 31,  0,  4]], dtype=int32)

I'll write another function to grab batches out of the arrays made by `split_data`. Each batch is a window on these arrays of size `batch_size x num_steps`. For example, if we want our network to train on sequences of 100 characters, `num_steps = 100`. For the next batch, we shift this window over to the next `num_steps` characters, so consecutive batches don't overlap. In this way we can feed batches to the network and the cell states will carry over from one batch to the next.

In [9]:
def get_batch(arrs, num_steps):
    """Yield consecutive, non-overlapping num_steps-wide column slices of each array in arrs."""
    batch_size, slice_size = arrs[0].shape
    
    n_batches = slice_size // num_steps
    for b in range(n_batches):
        yield [x[:, b*num_steps: (b+1)*num_steps] for x in arrs]
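Here's the windowing on a toy pair of arrays (2 rows of 9 made-up "characters", with `num_steps=3`), to show that each row simply continues from one batch to the next. The function is repeated so the block runs on its own:

```python
import numpy as np

def get_batch(arrs, num_steps):
    """Same windowing as above: consecutive, non-overlapping column slices."""
    batch_size, slice_size = arrs[0].shape
    n_batches = slice_size // num_steps
    for b in range(n_batches):
        yield [a[:, b*num_steps:(b+1)*num_steps] for a in arrs]

# Toy data: 2 rows of 9 "characters"; targets are the inputs shifted by one
x = np.arange(18).reshape(2, 9)
y = x + 1

batches = list(get_batch([x, y], num_steps=3))
for bx, by in batches:
    print(bx.tolist(), by.tolist())
```

The first batch is `[[0, 1, 2], [9, 10, 11]]` and the second is `[[3, 4, 5], [12, 13, 14]]`: row 0 of the second batch starts right where row 0 of the first batch ended, which is why the final LSTM state from one batch is a sensible initial state for the next.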

In [10]:
def build_rnn(num_classes, batch_size=50, num_steps=50, lstm_size=128, num_layers=2,
              learning_rate=0.001, grad_clip=5, sampling=False):
    
    # When sampling, we feed in one character at a time
    if sampling:
        batch_size, num_steps = 1, 1

    tf.reset_default_graph()
    
    # Declare placeholders we'll feed into the graph
    inputs = tf.placeholder(tf.int32, [batch_size, num_steps], name='inputs')
    x_one_hot = tf.one_hot(inputs, num_classes, name='x_one_hot')

    targets = tf.placeholder(tf.int32, [batch_size, num_steps], name='targets')
    y_one_hot = tf.one_hot(targets, num_classes, name='y_one_hot')
    y_reshaped = tf.reshape(y_one_hot, [-1, num_classes])
    
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    
    # Build the RNN layers. Create a separate cell for each layer: [drop] * num_layers
    # would reuse a single cell object for every layer, which breaks in later TensorFlow 1.x releases
    cells = []
    for _ in range(num_layers):
        lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
        drop = tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob)
        cells.append(drop)
    cell = tf.contrib.rnn.MultiRNNCell(cells)

    initial_state = cell.zero_state(batch_size, tf.float32)

    # Run the data through the RNN layers
    outputs, state = tf.nn.dynamic_rnn(cell, x_one_hot, initial_state=initial_state)
    final_state = state
    
    # Reshape output so it's a bunch of rows, one row for each cell output
    # (outputs has shape batch_size x num_steps x lstm_size)
    output = tf.reshape(outputs, [-1, lstm_size], name='graph_output')
    
    # Now connect the RNN outputs to a softmax layer and calculate the cost
    softmax_w = tf.Variable(tf.truncated_normal((lstm_size, num_classes), stddev=0.1),
                            name='softmax_w')
    softmax_b = tf.Variable(tf.zeros(num_classes), name='softmax_b')
    logits = tf.matmul(output, softmax_w) + softmax_b

    preds = tf.nn.softmax(logits, name='predictions')
    
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_reshaped, name='loss')
    cost = tf.reduce_mean(loss, name='cost')

    # Optimizer for training, using gradient clipping to control exploding gradients
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), grad_clip)
    train_op = tf.train.AdamOptimizer(learning_rate)
    optimizer = train_op.apply_gradients(zip(grads, tvars))

    # Export the nodes 
    export_nodes = ['inputs', 'targets', 'initial_state', 'final_state',
                    'keep_prob', 'cost', 'preds', 'optimizer']
    Graph = namedtuple('Graph', export_nodes)
    local_dict = locals()
    graph = Graph(*[local_dict[each] for each in export_nodes])
    
    return graph
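The reshape-then-softmax step is easier to follow with concrete shapes. Here's a NumPy sketch (not TensorFlow) with toy sizes and random numbers standing in for the real ones: the `batch_size x num_steps x lstm_size` outputs are flattened to one row per time step, the one-hot targets are flattened the same way so row `i` of the logits lines up with row `i` of the labels, and the cost is the mean softmax cross-entropy over all rows:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, num_steps, lstm_size, num_classes = 2, 3, 4, 5   # toy sizes

outputs = rng.normal(size=(batch_size, num_steps, lstm_size))          # stand-in for dynamic_rnn outputs
targets = rng.integers(0, num_classes, size=(batch_size, num_steps))   # stand-in for target codes

# Flatten to one row per (sequence, time step), in row-major order
output = outputs.reshape(-1, lstm_size)                                # (6, 4)
y_reshaped = np.eye(num_classes)[targets].reshape(-1, num_classes)     # one-hot, (6, 5)

softmax_w = rng.normal(scale=0.1, size=(lstm_size, num_classes))
softmax_b = np.zeros(num_classes)
logits = output @ softmax_w + softmax_b                                # (6, 5)

# Numerically stable softmax cross-entropy, averaged over every row
shifted = logits - logits.max(axis=1, keepdims=True)
log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
loss = -(y_reshaped * log_probs).sum(axis=1)
cost = loss.mean()
print(logits.shape, cost)
```

Row `num_steps` of `output` is the first time step of the second sequence, and the same row of `y_reshaped` is that step's target, so the two flattenings stay aligned.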

## Hyperparameters

Here I'm defining the hyperparameters for the network. The two you probably haven't seen before are `lstm_size` and `num_layers`. These set the number of hidden units in the LSTM layers and the number of LSTM layers, respectively. Making these bigger can improve the network's performance, but you'll have to watch out for overfitting. If your validation loss is much larger than the training loss, you're probably overfitting; decrease the size of the network or decrease the dropout keep probability (that is, apply more dropout).

In [11]:
batch_size = 100
num_steps = 100
lstm_size = 512
num_layers = 2
learning_rate = 0.001
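To get a feel for how big this network is, here's a rough parameter count, assuming the standard LSTM layout where each layer has four gates, each with weights over the concatenated [input, previous hidden state] plus a bias. The first layer sees one-hot characters, so its input size is the vocabulary size. The `vocab_size=83` below is only a placeholder; use `len(vocab)` for your copy of the text:

```python
def lstm_param_count(vocab_size, lstm_size, num_layers):
    """Rough parameter count for stacked LSTM layers plus the softmax output layer."""
    total = 0
    input_size = vocab_size  # first layer receives one-hot characters
    for _ in range(num_layers):
        # 4 gates, each with weights over [input, hidden] and a bias
        total += 4 * lstm_size * (input_size + lstm_size + 1)
        input_size = lstm_size  # deeper layers receive the previous layer's output
    total += lstm_size * vocab_size + vocab_size  # softmax_w and softmax_b
    return total

# Placeholder vocabulary size; substitute len(vocab)
print(lstm_param_count(vocab_size=83, lstm_size=512, num_layers=2))
```

That comes to a few million parameters, compared with 1,780,000 training characters for these settings (178 training batches of 100 x 100, matching the 178 iterations in the training log), which is one reason to keep an eye on the validation loss.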

## Write out the graph for TensorBoard

In [12]:
model = build_rnn(len(vocab),
                  batch_size=batch_size,
                  num_steps=num_steps,
                  learning_rate=learning_rate,
                  lstm_size=lstm_size,
                  num_layers=num_layers)

with tf.Session() as sess:
    
    sess.run(tf.global_variables_initializer())
    file_writer = tf.summary.FileWriter('./logs/1', sess.graph)
    # Close the writer so the graph is flushed to disk
    file_writer.close()

## Training

Time for training, which is pretty straightforward. Here I pass in some data and get an LSTM state back. Then I pass that state back into the network so the next batch can continue the state from the previous batch. Every so often (set by `save_every_n`) I calculate the validation loss and save a checkpoint.
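Why passing the state back in works: for a recurrent network, running a long sequence in one go gives the same final state as running it in chunks, as long as each chunk starts from the previous chunk's final state. Here's a toy check with a plain tanh RNN in NumPy standing in for the LSTM (made-up weights and inputs). In the real training loop the state is fed back as plain values, so gradients still only flow within a single batch of `num_steps`.

```python
import numpy as np

rng = np.random.default_rng(0)
input_size, hidden_size = 3, 4
W_x = rng.normal(size=(input_size, hidden_size))         # made-up weights
W_h = 0.5 * rng.normal(size=(hidden_size, hidden_size))

def run_rnn(xs, h):
    """Plain tanh RNN (a stand-in for the LSTM); returns the final hidden state."""
    for x in xs:
        h = np.tanh(x @ W_x + h @ W_h)
    return h

seq = rng.normal(size=(10, input_size))   # one sequence of 10 made-up inputs
h0 = np.zeros(hidden_size)

full = run_rnn(seq, h0)                   # whole sequence at once
h_mid = run_rnn(seq[:5], h0)              # first "batch"
chunked = run_rnn(seq[5:], h_mid)         # second "batch", starting from the carried-over state
restarted = run_rnn(seq[5:], h0)          # second "batch", restarting from zeros
print(np.allclose(full, chunked), np.allclose(full, restarted))
```

Carrying the state over reproduces the full-sequence result; restarting from zeros does not, because the network forgets everything it saw in the first chunk.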

In [13]:
!mkdir -p checkpoints/anna

In [14]:
epochs = 1
save_every_n = 200
train_x, train_y, val_x, val_y = split_data(chars, batch_size, num_steps)

model = build_rnn(len(vocab), 
                  batch_size=batch_size,
                  num_steps=num_steps,
                  learning_rate=learning_rate,
                  lstm_size=lstm_size,
                  num_layers=num_layers)

saver = tf.train.Saver(max_to_keep=100)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    # Use the line below to load a checkpoint and resume training
    #saver.restore(sess, 'checkpoints/anna20.ckpt')
    
    n_batches = int(train_x.shape[1]/num_steps)
    iterations = n_batches * epochs
    for e in range(epochs):
        
        # Train network
        new_state = sess.run(model.initial_state)
        loss = 0
        for b, (x, y) in enumerate(get_batch([train_x, train_y], num_steps), 1):
            iteration = e*n_batches + b
            start = time.time()
            feed = {model.inputs: x,
                    model.targets: y,
                    model.keep_prob: 0.5,
                    model.initial_state: new_state}
            batch_loss, new_state, _ = sess.run([model.cost, model.final_state, model.optimizer], 
                                                 feed_dict=feed)
            loss += batch_loss
            end = time.time()
            print('Epoch {}/{} '.format(e+1, epochs),
                  'Iteration {}/{}'.format(iteration, iterations),
                  'Training loss: {:.4f}'.format(loss/b),
                  '{:.4f} sec/batch'.format((end-start)))
        
            
            if (iteration%save_every_n == 0) or (iteration == iterations):
                # Check performance, notice dropout has been set to 1.
                # Use a separate val_state so the training state in new_state isn't overwritten.
                val_loss = []
                val_state = sess.run(model.initial_state)
                for x, y in get_batch([val_x, val_y], num_steps):
                    feed = {model.inputs: x,
                            model.targets: y,
                            model.keep_prob: 1.,
                            model.initial_state: val_state}
                    batch_loss, val_state = sess.run([model.cost, model.final_state], feed_dict=feed)
                    val_loss.append(batch_loss)
                    val_loss.append(batch_loss)

                print('Validation loss:', np.mean(val_loss),
                      'Saving checkpoint!')
                saver.save(sess, "checkpoints/anna/i{}_l{}_{:.3f}.ckpt".format(iteration, lstm_size, np.mean(val_loss)))

Epoch 1/1  Iteration 1/178 Training loss: 4.4172 7.4726 sec/batch
Epoch 1/1  Iteration 2/178 Training loss: 4.3646 6.8088 sec/batch
Epoch 1/1  Iteration 3/178 Training loss: 4.1658 6.3709 sec/batch
Epoch 1/1  Iteration 4/178 Training loss: 4.0509 6.5977 sec/batch
Epoch 1/1  Iteration 5/178 Training loss: 3.9489 6.7300 sec/batch
Epoch 1/1  Iteration 6/178 Training loss: 3.8610 7.1383 sec/batch
Epoch 1/1  Iteration 7/178 Training loss: 3.7905 7.1475 sec/batch
Epoch 1/1  Iteration 8/178 Training loss: 3.7346 7.6222 sec/batch
Epoch 1/1  Iteration 9/178 Training loss: 3.6882 8.0197 sec/batch
Epoch 1/1  Iteration 10/178 Training loss: 3.6498 7.5991 sec/batch
Epoch 1/1  Iteration 11/178 Training loss: 3.6150 7.0193 sec/batch
Epoch 1/1  Iteration 12/178 Training loss: 3.5843 6.5204 sec/batch
Epoch 1/1  Iteration 13/178 Training loss: 3.5573 6.7677 sec/batch
Epoch 1/1  Iteration 14/178 Training loss: 3.5349 6.2226 sec/batch
Epoch 1/1  Iteration 15/178 Training loss: 3.5142 6.3748 sec/batch
...
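A handy sanity check for the first numbers in that log: a model that spreads its probability evenly over V characters has a cross-entropy of exactly ln V, and a freshly initialized network with small output weights should start close to that. The first logged loss, 4.4172, is about ln 83 (roughly 4.42). The largest code in `chars[:100]` above is 82, so the vocabulary has at least 83 characters, which fits. A minimal sketch, with 83 as a placeholder for `len(vocab)`:

```python
import math
import numpy as np

def uniform_cross_entropy(vocab_size):
    """Cross-entropy of a prediction that is uniform over vocab_size classes."""
    p = np.full(vocab_size, 1.0 / vocab_size)
    target = 0  # any target index gives the same loss under a uniform prediction
    return -math.log(p[target])

print(uniform_cross_entropy(83))  # placeholder vocab size; use len(vocab)
```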

In [15]:
tf.train.get_checkpoint_state('checkpoints/anna')

model_checkpoint_path: "checkpoints/anna/i178_l512_2.364.ckpt"
all_model_checkpoint_paths: "checkpoints/anna/i178_l512_2.364.ckpt"

## Sampling

Now that the network is trained, we can use it to generate new text. The idea is that we pass in a character, and the network predicts the next character. We then feed that new character back in to predict the one after it, and keep doing this to generate all-new text. I also included some functionality to prime the network with some text by passing in a string and building up a state from that.

The network gives us predictions for each character. To reduce noise and make things a little less random, I'm going to only choose a new character from the top N most likely characters.
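Here's the top-N filtering on its own, with a made-up distribution over six characters and `top_n=3`: everything outside the three most likely characters gets zero probability, and the rest are rescaled to sum to one, so characters 0, 2 and 4 can never be drawn.

```python
import numpy as np

rng = np.random.default_rng(0)
preds = np.array([0.05, 0.40, 0.10, 0.30, 0.02, 0.13])  # made-up predictions for 6 characters
top_n = 3

p = preds.copy()                   # work on a copy so preds is left untouched
p[np.argsort(p)[:-top_n]] = 0      # zero all but the top_n largest
p = p / p.sum()                    # renormalize to a probability distribution
draws = rng.choice(len(p), size=1000, p=p)
print(p.round(3), np.bincount(draws, minlength=len(p)))
```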



In [16]:
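def pick_top_n(preds, vocab_size, top_n=5):
    # Copy so the zeroing below doesn't modify preds in place
    p = np.squeeze(preds).copy()
    # Zero out everything except the top_n most likely characters, then renormalize
    p[np.argsort(p)[:-top_n]] = 0
    p = p / np.sum(p)
    c = np.random.choice(vocab_size, 1, p=p)[0]
    return c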

In [17]:
def sample(checkpoint, n_samples, lstm_size, vocab_size, prime="The "):
    samples = list(prime)
    model = build_rnn(vocab_size, lstm_size=lstm_size, sampling=True)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, checkpoint)
        new_state = sess.run(model.initial_state)
        x = np.zeros((1, 1), dtype=np.int32)
        # Prime the network: run the prime text through it to build up the state
        for c in prime:
            x[0,0] = vocab_to_int[c]
            feed = {model.inputs: x,
                    model.keep_prob: 1.,
                    model.initial_state: new_state}
            preds, new_state = sess.run([model.preds, model.final_state], 
                                         feed_dict=feed)

        c = pick_top_n(preds, vocab_size)
        samples.append(int_to_vocab[c])

        # Generate new text, feeding each sampled character back in as the next input
        for i in range(n_samples):
            x[0,0] = c
            feed = {model.inputs: x,
                    model.keep_prob: 1.,
                    model.initial_state: new_state}
            preds, new_state = sess.run([model.preds, model.final_state], 
                                         feed_dict=feed)

            c = pick_top_n(preds, vocab_size)
            samples.append(int_to_vocab[c])
        
    return ''.join(samples)
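To see the control flow of `sample` without a trained checkpoint, here's the same prime-then-generate loop with the TensorFlow model swapped for a hypothetical character-bigram count table built from a short toy string. It's a much weaker model than the LSTM (its only "state" is the previous character), but the loop has the same shape: run the prime text through, then repeatedly sample a character and feed it back in.

```python
import numpy as np
from collections import Counter, defaultdict

corpus = "happy families are all alike; every unhappy family is unhappy in its own way. "
# Hypothetical stand-in model: next-character counts given the current character.
# Wrapping around (corpus[1:] + corpus[0]) guarantees every character has a successor.
counts = defaultdict(Counter)
for a, b in zip(corpus, corpus[1:] + corpus[0]):
    counts[a][b] += 1

def predict(c):
    """Return candidate next characters and their probabilities after c."""
    candidates = list(counts[c])
    total = sum(counts[c].values())
    return candidates, np.array([counts[c][n] / total for n in candidates])

def generate(prime, n_samples, seed=0):
    rng = np.random.default_rng(seed)
    samples = list(prime)
    # Prime: feed in the prime text (this toy model's "state" is just the last character)
    for c in prime:
        candidates, p = predict(c)
    # Generate: sample a character, then feed it back in as the next input
    for _ in range(n_samples):
        c = candidates[rng.choice(len(candidates), p=p)]
        samples.append(c)
        candidates, p = predict(c)
    return ''.join(samples)

print(generate("ha", 60))
```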

In [20]:
checkpoint = "checkpoints/anna/i178_l512_2.364.ckpt"
samp = sample(checkpoint, 2000, lstm_size, len(vocab), prime="Far")
print(samp)

Farlying thees wot ang al he ans thes he alt or wethan side sorint of hererensis wimeresintin his wos the soren of ant oner wan tasdet ils has thas whes ane he worin tho hin wing, wothing the touthe hed the simthe with sithe arer ous wo wins ans tothe he sede the wim he his sh so whe the to she toun hised hot he to sher on whand te he sotin hither ans hese silt tor th aretas on al tore th whase here tor ther afeter sasithe and ant tins af hee hase we wot te sanseres ot at tar ins an wertint thind tim han sate tous an ane soustond wase al af oto tote he the hered
at oulid,. she tor of to he soulle the toren the saned af ond ting he tas then an he sorerer the with whe ther thar heas he the sares or whas,
ade sime hond sote tangerigg althasese tan sot or wothe sersithe he serin he sheres thens thingerens an he wit the sereran ad ato sin one te hor anses ons, tastithe his tor wand the the ale wat hos torens ans out that he she sering ond an tire ther ans an the tharesd ande wore and terers

In [21]:
checkpoint = "checkpoints/anna/i178_l512_2.364.ckpt"
samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime="Far")
print(samp)

Fard tifhes afdrerer ins of he cis the wer sont ane has he and, to sothes and ond to tir or af the sasd and we tit the thet tout hed somtin thes he hed afe an him an her wos the her atinntis hor he astot hes on thin ter alse tho cheresed the cer ife we sha san tare shes thind and the woud siting wat herisser of thint, the se toud son to the the wese to tot heres thering onthe the horend. Taut whe sarsad ote ton tous hor ther heres and ale and thereres at orate ant if ther soril san ont it wher had ansed tha carin oth the ansed in tor ile tor of wout there sot out aner wer athin thice sothe hot ha the hered he cererat ind, hit wer outhe sot te we he soril othe he shinsed wass and whis he te an whe wared had and on his the cat af oustha he hime shon he afis har was an he ther tans tas itis the wotho ter whon we whing the sas he the werens ofe se the he seser ate
toun so han sing th an tore sorigh sas tor ithe won hes thot he were sit and wot hos ther wass intore tort te more hareringe wh

In [22]:
checkpoint = "checkpoints/anna/i178_l512_2.364.ckpt"
samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime="Far")
print(samp)

Farde nfarin soton to th the sin on tore to misg tous tont ot tere ta sete sont af her teretis and wim this,
aft if hate hise the tan an tome herereren tho ho seres ind thot thes hime whe the tin whe hersere anthe the to the sing whate sore has and he thend
tithin so tarit ha therarer touthed. And had the che he hha sit in hit hirint to thet his se toth thing one tor at he and and af han serithen he whan that hh there aresin as timhese hid and and thins, ho sarise sis than at te wat ha cerase sath ase with the her he hors whor wher an her hot oned af ter ifersentin tho se thon wot hed and wes tere tire the withan toot hot an winge at anes the then sis an oned he sartin an hit ale touth sas ore seter and he he sosinn ans hestin the thes af wor he are al af tous, whit wald witthe wath we whe tha sar afd out te sete an tho hhande whes hes of ont out hos and and on thit hout hin and tant, thin tho the sose sithe and wheng, whins tans withan af an ane thes tha wos herite shase hor an onte h

In [23]:
checkpoint = "checkpoints/anna/i178_l512_2.364.ckpt"
samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime="Far")
print(samp)

Farlingig ante hithes arad in he her hare an he serestorer the sis hon se won wha this seterit welsensint the alse ther wise th ane wors, whad, sed ante thimisgis hat hhe thing ande an of wout, tha ge hed wans oth wime the who he than hesd anderers ind an tho the sont he the sedrand af his hime wered hims
ansore the sas therersand ons hise tit an thes thass
oned and the hed sor on the he wond whis sosisg thete har afe te hire thet ant oneteressens ont ther taring tho he ariled
tat he the ton the soud she hint to sothed ha the sisles ofe het her wis sh tomerat an af the cereres tore her tate wersent herseding ther the tise thictesen whet al he tothe ho har who cone so hith sore tas he so she son hadedidet in he has and ot win him whe wo f hit hite sith sosthim af tint of tat he win has he hed hins here sas the thes af the sast ato ho te sot the se soreren whand hins at he sers hhe he the hor wat teule to who chan wint has
tond wos toun hit the an tore tout, whin had and wer there the wh