# Recurrent Neural Network for MNIST

This notebook implements a (dynamic) multilayer recurrent neural network using long short-term memory (LSTM) units and TensorFlow. It contains detailed explanations for each step.

Author: Anna-Lena Popkes

In [28]:
%matplotlib inline
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


## Explanation: Truncated backpropagation

In a recurrent neural network, the output at a given time step can depend on arbitrarily distant inputs. Backpropagating through the full history therefore requires unrolling the network over the entire input sequence: when building an RNN in TensorFlow, the graph would have to be as wide as the input sequence.
Unfortunately, this makes backpropagation both expensive and ineffective, because gradients propagated over many time steps tend to either vanish (most of the time) or explode.

A common solution to this problem is to create an "unrolled" version of the recurrent network that contains a fixed number (*n_steps*) of RNN inputs and outputs. In other words: backpropagation is "truncated" such that errors are only backpropagated for a fixed number of steps. A higher number of steps enables capturing long-term dependencies but is also more expensive (both regarding memory and computation).

The model is then trained on this finite approximation of the recurrent network. Accordingly, at each training step the network is fed an input block of length *n_steps*, and the backward pass is performed after each block. A short explanation is given on TensorFlow's [website](https://www.tensorflow.org/tutorials/recurrent#truncated-backpropagation).
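
To make the idea of fixed-length input blocks concrete, here is a minimal NumPy sketch that cuts one long sequence into blocks of *n_steps* time steps. The helper `split_into_blocks` and the toy sequence are hypothetical and not part of this notebook. Dropping the incomplete trailing block is one possible choice, not something TensorFlow prescribes.

```python
import numpy as np

def split_into_blocks(sequence, n_steps):
    """Split a (T, n_input) sequence into blocks of n_steps time steps.

    Trailing time steps that do not fill a complete block are dropped.
    """
    n_blocks = sequence.shape[0] // n_steps
    usable = sequence[:n_blocks * n_steps]
    return usable.reshape(n_blocks, n_steps, sequence.shape[1])

# Hypothetical long sequence: 100 time steps with 3 features each
long_sequence = np.arange(300, dtype=np.float32).reshape(100, 3)
blocks = split_into_blocks(long_sequence, n_steps=28)

print(blocks.shape)  # (3, 28, 3): 3 full blocks, 16 trailing steps dropped
```

In this notebook each image is read as 28 rows of 28 pixels, so a single block of *n_steps* = 28 already covers the whole sequence.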

In [25]:
# Global parameters
eta = 0.01 # learning rate
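n_epochs = 4 # number of passes over the training set
n_input = 28 # input size per time step (one image row of 28 pixels)
n_classes = 10 # digits 0-9
batch_size = 100
n_batches = mnist.train.images.shape[0]//batch_size # batches per epoch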

# Network parameters
n_hidden = 20 # number of hidden units per layer
n_layers = 3 # number of layers 
n_steps = 28 # number of truncated backprop steps

## Dynamic vs. static RNN 

TensorFlow provides two RNN functions, namely *tf.nn.rnn* (available as *tf.nn.static_rnn* in TensorFlow 1.x releases) and *tf.nn.dynamic_rnn*.

*tf.nn.rnn* creates an unrolled graph for a *fixed* RNN length. For example, when calling *tf.nn.rnn* with an input sequence of length 200, a static graph with 200 time steps is created. This has the disadvantage that we cannot feed longer or shorter sequences into the network than originally specified.

*tf.nn.dynamic_rnn* solves this problem. It uses a *tf.while_loop* to iterate over the time steps *dynamically* when the graph is executed, instead of adding one copy of the cell per time step to the graph. This makes graph creation faster and allows the sequence length to vary between batches.

One difference between the two functions is the form of the input data. Whereas *tf.nn.rnn* takes a list of tensors as input (namely a list of n_steps tensors, each with shape (batch_size, input_size)), *tf.nn.dynamic_rnn* takes the whole tensor of shape (batch_size, n_steps, input_size).
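
The two input layouts can be compared with a small NumPy sketch. It only illustrates the shapes. In TensorFlow itself, the same conversion is usually done with `tf.unstack(x, axis=1)`, which is mentioned here as a pointer and is not used in this notebook.

```python
import numpy as np

batch_size, n_steps, input_size = 100, 28, 28

# Layout expected by tf.nn.dynamic_rnn: one tensor holding all time steps
dynamic_input = np.zeros((batch_size, n_steps, input_size), dtype=np.float32)

# Layout expected by the static variant: a list with one
# (batch_size, input_size) array per time step
static_input = [dynamic_input[:, t, :] for t in range(n_steps)]

print(len(static_input), static_input[0].shape)  # 28 (100, 28)
```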

## Network architecture

The basic architecture of a recurrent network looks as follows ([source](http://www.deeplearningbook.org/)). 


![title](figures/basic_rnn.png)



As visible in the figure, the state $h^{(t)}$ of the network depends both on the input $x^{(t)}$ and the previous state $h^{(t-1)}$.
It is computed as follows:

$$ h^{(t)} = \sigma(U x^{(t)} + W h^{(t-1)} + b) $$

with $\sigma$ being the $\tanh$ function in our implementation.


The output is computed as: 

$$ o^{(t)} = V h^{(t)} + c $$
$$ \hat{y}^{(t)} = \text{softmax}(o^{(t)}) $$
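
The three equations above can be sketched in NumPy as a single recurrent step. This is an illustration of the basic (vanilla) RNN update only. The network built below uses LSTM cells, whose internal update is more involved. The function names, the weight scale of 0.1 and the random inputs are assumptions made for this example.

```python
import numpy as np

def softmax(o):
    """Numerically stable softmax of a 1-D vector."""
    e = np.exp(o - np.max(o))
    return e / e.sum()

def rnn_step(x_t, h_prev, U, W, b, V, c):
    """One step of the basic RNN: new state and output distribution."""
    h_t = np.tanh(U @ x_t + W @ h_prev + b)  # h(t) = tanh(U x(t) + W h(t-1) + b)
    o_t = V @ h_t + c                        # o(t) = V h(t) + c
    return h_t, softmax(o_t)                 # y_hat(t) = softmax(o(t))

rng = np.random.default_rng(0)
n_input, n_hidden, n_classes = 28, 20, 10
U = rng.normal(scale=0.1, size=(n_hidden, n_input))
W = rng.normal(scale=0.1, size=(n_hidden, n_hidden))
V = rng.normal(scale=0.1, size=(n_classes, n_hidden))
b = np.zeros(n_hidden)
c = np.zeros(n_classes)

# Run the recurrence over a random sequence of 28 inputs
h = np.zeros(n_hidden)
for x_t in rng.normal(size=(28, n_input)):
    h, y_hat = rnn_step(x_t, h, U, W, b, V, c)

print(h.shape, y_hat.shape)  # (20,) (10,)
```

Note that the state `h` is the only thing carried from one step to the next, which is exactly the dependency on $h^{(t-1)}$ shown in the figure.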


In [3]:
# Create placeholder variables for the input and targets
X_placeholder = tf.placeholder(tf.float32, shape=[batch_size, n_steps, n_input])
y_placeholder = tf.placeholder(tf.int32, shape=[batch_size, n_classes])

# Create trainable variables for the final weight matrix and bias vector
V = tf.Variable(tf.random_normal(shape=[n_hidden, n_classes]))
c = tf.Variable(tf.random_normal(shape=[n_classes]))

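# For each LSTM cell we need to specify how many hidden units the cell
# should have. Each layer gets its own cell object: reusing one object
# for every layer (e.g. [cell]*n_layers) can make the layers share
# weights or raise an error in newer TensorFlow 1.x releases.
cells = [tf.contrib.rnn.LSTMCell(num_units=n_hidden) for _ in range(n_layers)]

# To create multiple layers we call the MultiRNNCell function that takes 
# a list of RNN cells as an input and wraps them into a single cell
cell = tf.contrib.rnn.MultiRNNCell(cells)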

# Create a zero-filled state tensor as an initial state
init_state = cell.zero_state(batch_size, tf.float32)

# Create a recurrent neural network specified by "cell", i.e. unroll the
# network.
# Returns the outputs of the top layer at every time step, as a tensor of
# shape (batch_size, n_steps, n_hidden), and the final states.
# final_state contains n_layers LSTMStateTuples that hold both the
# final hidden state and the cell state of the respective layer.
outputs, final_state = tf.nn.dynamic_rnn(cell, 
                                         X_placeholder, 
                                         initial_state=init_state)

## Gather final activations

Because we are performing sequence *classification*, we are only interested in the output activations of the last time step. Since *tf.gather* selects slices along the first axis by default and does not accept negative indices, we first transpose the tensor such that the "n_steps" axis comes first. Then, we use *tf.gather* to select the last slice. This process is illustrated below for the following parameter setting: batch_size=100, n_steps=28, n_hidden=20.

![title](figures/tensor_transformation.jpg)

In [4]:
temp = tf.transpose(outputs, [1,0,2])
last_output = tf.gather(temp, int(temp.get_shape()[0]-1))
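
The same transformation can be checked in NumPy, where negative indexing is available for comparison. This is a minimal sketch with random data standing in for the RNN outputs. The variable names mirror the cell above but operate on plain arrays.

```python
import numpy as np

batch_size, n_steps, n_hidden = 100, 28, 20
rng = np.random.default_rng(0)
outputs = rng.normal(size=(batch_size, n_steps, n_hidden))

# Move the time axis to the front: (n_steps, batch_size, n_hidden)
temp = np.transpose(outputs, (1, 0, 2))

# Select the slice belonging to the last time step
last_output = temp[temp.shape[0] - 1]

print(temp.shape)         # (28, 100, 20)
print(last_output.shape)  # (100, 20)
```

The result is identical to `outputs[:, -1, :]`, i.e. the activations of the last time step for every example in the batch.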

## Network output and loss function

In [None]:
# After gathering the final activations we can easily compute the logits
# using a single matrix multiplication
logits = tf.matmul(last_output, V)+c

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_placeholder,
                                                           logits=logits))

train_step = tf.train.AdamOptimizer(eta).minimize(loss)
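
For intuition about the loss, here is a NumPy sketch of softmax cross-entropy averaged over a batch. It is not TensorFlow's implementation. The function `softmax_cross_entropy` and the toy logits and labels are made up for this illustration.

```python
import numpy as np

def softmax_cross_entropy(labels, logits):
    """Per-example cross-entropy between one-hot labels and softmax(logits)."""
    # Subtract the row maximum for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(labels * log_softmax).sum(axis=1)

# Toy batch: 2 examples, 3 classes
logits = np.array([[2.0, 1.0, 0.1],
                   [0.0, 0.0, 0.0]])
labels = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])

per_example = softmax_cross_entropy(labels, logits)
loss = per_example.mean()  # counterpart of tf.reduce_mean over the batch

print(per_example, loss)
```

The second example has uniform logits, so its loss equals $\log 3$, the cross-entropy of a uniform guess over three classes.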

## Accuracy

We compute the accuracy of our model as follows.
First, we use *tf.argmax*, which gives us the index of the highest entry in a tensor along some axis. For example, *tf.argmax(logits,1)* gives us the label our model considers to be most likely for each input. The true labels are computed using *tf.argmax(y_placeholder,1)*. In a next step, we compare these two tensors using *tf.equal*, resulting in a tensor of boolean values.

To compute the accuracy, we first cast the boolean values to floats using *tf.cast*. In a last step, we take the mean of all values.

In [20]:
correct_prediction = tf.equal(tf.argmax(logits,1), tf.argmax(y_placeholder,1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
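
The accuracy computation can be traced by hand on a tiny batch. The following NumPy sketch mirrors the argmax, equal, cast and mean steps. The four logit rows and one-hot labels are invented for this example.

```python
import numpy as np

# Toy batch: 4 examples, 3 classes
logits = np.array([[0.1, 2.0, 0.3],
                   [1.5, 0.2, 0.1],
                   [0.3, 0.2, 0.9],
                   [0.8, 0.1, 0.4]])
labels = np.array([[0, 1, 0],
                   [1, 0, 0],
                   [0, 1, 0],
                   [1, 0, 0]])

predicted = np.argmax(logits, axis=1)  # [1, 0, 2, 0]
true = np.argmax(labels, axis=1)       # [1, 0, 1, 0]
correct = np.equal(predicted, true)    # [True, True, False, True]
accuracy = correct.astype(np.float32).mean()

print(accuracy)  # 0.75
```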

## Training

We train the network for the specified number of epochs. In each epoch, we run through all training examples and report the loss and accuracy on the current minibatch every 100 batches.
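
As a quick sanity check on the logging frequency, the arithmetic can be spelled out. The sketch assumes that *mnist.train* holds 55,000 examples, which is the default split produced by *input_data.read_data_sets*. The variable `n_train` is introduced only for this illustration.

```python
n_train = 55000      # assumption: default size of mnist.train
batch_size = 100

n_batches = n_train // batch_size
logged = [batch for batch in range(n_batches) if batch % 100 == 0]

print(n_batches)  # 550
print(logged)     # [0, 100, 200, 300, 400, 500]
```

Under this assumption, each epoch produces six log lines.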

In [26]:
with tf.Session() as sess:
    
    # We first have to initialize all variables
    init = tf.global_variables_initializer()
    sess.run(init)
    
    # Train for the specified number of epochs
    for epoch in range(n_epochs):
        
        print()
        print("Epoch: ", epoch)
        
        for batch in range(n_batches):
            
            x_batch, y_batch = mnist.train.next_batch(batch_size)
            x_batch = x_batch.reshape((batch_size, n_steps, n_input))
            
            _train_step = sess.run(train_step, 
                                        feed_dict=
                                        {X_placeholder:x_batch,
                                         y_placeholder:y_batch
                                        })
            
            
            if batch%100 == 0:
                _loss, _accuracy = sess.run([loss, accuracy],
                                 feed_dict={
                                     X_placeholder:x_batch,
                                     y_placeholder:y_batch
                                 })
                print("Minibatch loss: %s  Accuracy: %s" % (_loss, _accuracy))
          
    print()
    print("Optimization done! Let's calculate the test error")
    
    # Evaluate the model on a single batch of "batch_size" test examples.
    # Only loss and accuracy are fetched: running train_step here would
    # update the weights on test data.
    x_test_batch, y_test_batch = mnist.test.next_batch(batch_size)
    x_test_batch = x_test_batch.reshape((batch_size, n_steps, n_input))
    
    test_loss, test_accuracy = sess.run([loss, accuracy],
                                        feed_dict={
                                            X_placeholder:x_test_batch,
                                            y_placeholder:y_test_batch
                                        })
    print()
    print("Loss on test set: ", test_loss)
    print("Accuracy on test set: ", test_accuracy)


Epoch:  0
Minibatch loss: 2.35251  Accuracy: 0.18
Minibatch loss: 0.504062  Accuracy: 0.85
Minibatch loss: 0.324097  Accuracy: 0.92
Minibatch loss: 0.194671  Accuracy: 0.92
Minibatch loss: 0.0865606  Accuracy: 0.99
Minibatch loss: 0.140195  Accuracy: 0.96

Epoch:  1
Minibatch loss: 0.124076  Accuracy: 0.97
Minibatch loss: 0.0696919  Accuracy: 0.99
Minibatch loss: 0.098697  Accuracy: 0.98
Minibatch loss: 0.0694356  Accuracy: 0.99
Minibatch loss: 0.0907172  Accuracy: 0.99
Minibatch loss: 0.0110981  Accuracy: 1.0

Epoch:  2
Minibatch loss: 0.0895157  Accuracy: 0.98
Minibatch loss: 0.0611437  Accuracy: 0.99
Minibatch loss: 0.182449  Accuracy: 0.96
Minibatch loss: 0.0999113  Accuracy: 0.97
Minibatch loss: 0.0848051  Accuracy: 0.98
Minibatch loss: 0.0522574  Accuracy: 0.99

Epoch:  3
Minibatch loss: 0.0710386  Accuracy: 0.99
Minibatch loss: 0.0692343  Accuracy: 0.99
Minibatch loss: 0.070002  Accuracy: 0.98
Minibatch loss: 0.0168356  Accuracy: 1.0
Minibatch loss: 0.0741237  Accuracy: 0.98
Mi