In [1]:
%load_ext autoreload
%autoreload 2

## LSTM Cell
The *Long Short-Term Memory* (LSTM) cell was proposed in 1997. It generally performs much better than a basic RNN cell: training tends to converge faster, and the cell is better at detecting long-term dependencies in the data. The architecture of a basic LSTM cell is shown below:
![lstm](figures/lstm.png)

If you don't look at what's inside the box, the LSTM cell looks exactly like a regular cell, except that its state is split into two vectors: $\mathbf{h}_{(t)}$ and $\mathbf{c}_{(t)}$ ("c" stands for "cell"). You can think of $\mathbf{h}_{(t)}$ as the short-term state and $\mathbf{c}_{(t)}$ as the long-term state.

The key idea of LSTM cells is that the network can learn what to store in the long-term state, what to throw away, and what to read from it. As the long-term state $\mathbf{c}_{(t-1)}$ traverses the network from left to right, it first goes through a **forget gate**, dropping some memories, and then picks up new memories via the addition operation (which adds the memories that were selected by an **input gate**). The result $\mathbf{c}_{(t)}$ is sent straight out, without any further transformation. So at each time step, some memories are dropped and some are added.

Moreover, after the addition operation, the long-term state is copied and passed through the `tanh` function, and the result is filtered by the **output gate**. This produces the short-term state $\mathbf{h}_{(t)}$ (which is equal to the cell's output for this time step, $\mathbf{y}_{(t)}$).

Let's look at where the new memories come from and how the gates work:

First, the current input vector $\mathbf{x}_{(t)}$ and the previous short-term state $\mathbf{h}_{(t-1)}$ are fed to four different fully connected layers. Each serves a different purpose:
* The main layer is the one that outputs $\mathbf{g}_{(t)}$. It has the usual role of analysing the current inputs $\mathbf{x}_{(t)}$ and the previous (short-term) state $\mathbf{h}_{(t-1)}$. In a basic cell, there is nothing other than this layer, and its output goes straight out to $\mathbf{y}_{(t)}$ and $\mathbf{h}_{(t)}$. In contrast, in an LSTM cell this layer's output does not go straight out, but instead is partially stored in the long-term state.

* The other three layers are **gate controllers**. Since they use the logistic activation function, their outputs range from 0 to 1. As you can see, their outputs are fed to element-wise multiplication operations, so if they output 0s, they close the gate, and if they output 1s, they open it. Specifically:

    - The **forget gate** (controlled by $\mathbf{f}_{(t)}$) controls which parts of the long-term state should be erased.
    
    - The **input gate** (controlled by $\mathbf{i}_{(t)}$) controls which parts of $\mathbf{g}_{(t)}$ should be added to the long-term state (this is why we said it was only "partially stored")
    
    - The **output gate** (controlled by $\mathbf{o}_{(t)}$) controls which parts of the long-term state should be read and output at this time step (both to $\mathbf{h}_{(t)}$ and $\mathbf{y}_{(t)}$).

LSTM computations (here $\sigma$ is the logistic function and $\otimes$ denotes element-wise multiplication):
\begin{eqnarray}
\mathbf{i}_{(t)}  &= & \sigma(\mathbf{W}_{xi}^T\cdot\mathbf{x}_{(t)} + \mathbf{W}_{hi}^T\cdot \mathbf{h}_{(t-1)} + \mathbf{b}_i)\\
\mathbf{f}_{(t)}  &= & \sigma(\mathbf{W}_{xf}^T\cdot\mathbf{x}_{(t)} + \mathbf{W}_{hf}^T\cdot \mathbf{h}_{(t-1)} + \mathbf{b}_f)\\
\mathbf{o}_{(t)}  &= & \sigma(\mathbf{W}_{xo}^T\cdot\mathbf{x}_{(t)} + \mathbf{W}_{ho}^T\cdot \mathbf{h}_{(t-1)} + \mathbf{b}_o)\\
\mathbf{g}_{(t)}  &= & \mathrm{tanh}(\mathbf{W}_{xg}^T\cdot\mathbf{x}_{(t)} + \mathbf{W}_{hg}^T\cdot \mathbf{h}_{(t-1)} + \mathbf{b}_g)\\
\mathbf{c}_{(t)} & = & \mathbf{f}_{(t)}\otimes \mathbf{c}_{(t-1)} + \mathbf{i}_{(t)}\otimes \mathbf{g}_{(t)}\\
\mathbf{y}_{(t)} & = & \mathbf{h}_{(t)} = \mathbf{o}_{(t)}\otimes \mathrm{tanh}(\mathbf{c}_{(t)})
\end{eqnarray}
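
As a rough illustration of the equations above, here is a minimal NumPy sketch of a single LSTM time step. It is not TensorFlow's implementation: the function name `lstm_step`, the dictionary layout of the weights, and the random initial values are all assumptions made for this example. Instances are stored as rows, so $\mathbf{W}^T\cdot\mathbf{x}$ becomes `x @ W`.

```python
import numpy as np

def sigmoid(z):
    """Logistic function, used by the three gate controllers."""
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W_x, W_h, b):
    """One LSTM time step for a batch (hypothetical helper for illustration).

    x: (batch, n_inputs); h_prev, c_prev: (batch, n_neurons).
    W_x, W_h, b: dicts keyed by 'i', 'f', 'o', 'g'.
    """
    i = sigmoid(x @ W_x['i'] + h_prev @ W_h['i'] + b['i'])  # input gate
    f = sigmoid(x @ W_x['f'] + h_prev @ W_h['f'] + b['f'])  # forget gate
    o = sigmoid(x @ W_x['o'] + h_prev @ W_h['o'] + b['o'])  # output gate
    g = np.tanh(x @ W_x['g'] + h_prev @ W_h['g'] + b['g'])  # main layer
    c = f * c_prev + i * g   # drop some memories, add some new ones
    h = o * np.tanh(c)       # filtered read of the long-term state
    return h, c

# Toy dimensions and random weights (assumed values, for illustration only)
rng = np.random.default_rng(42)
n_inputs, n_neurons, batch_size = 3, 5, 4
W_x = {k: rng.normal(size=(n_inputs, n_neurons)) for k in 'ifog'}
W_h = {k: rng.normal(size=(n_neurons, n_neurons)) for k in 'ifog'}
b = {k: np.zeros(n_neurons) for k in 'ifog'}

x = rng.normal(size=(batch_size, n_inputs))
h0 = np.zeros((batch_size, n_neurons))
c0 = np.zeros((batch_size, n_neurons))

h1, c1 = lstm_step(x, h0, c0, W_x, W_h, b)
print(h1.shape, c1.shape)  # (4, 5) (4, 5)
```

To process a whole sequence, the same function would be called once per time step, feeding the returned `h` and `c` back in as `h_prev` and `c_prev`.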

In [4]:
import numpy as np
import os
import tensorflow as tf

def reset_graph(seed=42):
    tf.reset_default_graph()
    tf.set_random_seed(seed)
    np.random.seed(seed)
    
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12

def save_fig(fig_id, tight_layout=True):
    path = os.path.join('figures', fig_id + '.png')
    print('Saving figure', fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format='png', dpi=300)

## Basic RNNs

In [5]:
reset_graph()
n_inputs = 3
n_neurons = 5

X0 = tf.placeholder(tf.float32, [None, n_inputs])
X1 = tf.placeholder(tf.float32, [None, n_inputs])

# Wx: input-to-hidden weights; Wy: weights applied to the previous step's output
Wx = tf.Variable(tf.random_normal(shape=[n_inputs, n_neurons], dtype=tf.float32))
Wy = tf.Variable(tf.random_normal(shape=[n_neurons, n_neurons], dtype=tf.float32))
b = tf.Variable(tf.zeros([1, n_neurons], dtype=tf.float32))

# The RNN is unrolled manually over two time steps, sharing Wx, Wy and b
Y0 = tf.tanh(tf.matmul(X0, Wx) + b)
Y1 = tf.tanh(tf.matmul(Y0, Wy) + tf.matmul(X1, Wx) + b)

init = tf.global_variables_initializer()

In [6]:
X0_batch = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 0, 1]]) # t = 0
X1_batch = np.array([[9, 8, 7], [0, 0, 0], [6, 5, 4], [3, 2, 1]]) # t = 1

with tf.Session() as sess:
    init.run()
    Y0_val, Y1_val = sess.run([Y0, Y1], feed_dict={X0: X0_batch, X1:X1_batch})

In [7]:
print(Y0_val)
print(Y1_val)

[[-0.0664006   0.9625767   0.68105793  0.7091854  -0.898216  ]
 [ 0.9977755  -0.71978897 -0.9965761   0.9673924  -0.9998972 ]
 [ 0.99999774 -0.99898803 -0.9999989   0.9967762  -0.9999999 ]
 [ 1.         -1.         -1.         -0.99818915  0.9995087 ]]
[[ 1.         -1.         -1.          0.40200275 -0.9999998 ]
 [-0.12210423  0.6280527   0.9671843  -0.9937122  -0.25839362]
 [ 0.9999983  -0.9999994  -0.9999975  -0.8594331  -0.9999881 ]
 [ 0.99928284 -0.99999803 -0.9999058   0.9857963  -0.92205757]]
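
The same two-step computation can be reproduced without TensorFlow. The following is a minimal NumPy sketch of the graph above; the weights come from a NumPy random generator (an assumption of this example), so the printed values would differ from the TensorFlow output shown here. It also shows that the hand-unrolled version is equivalent to a loop starting from an all-zero state.

```python
import numpy as np

rng = np.random.default_rng(42)
n_inputs, n_neurons = 3, 5

# Randomly initialized parameters (not the values TensorFlow generated)
Wx = rng.normal(size=(n_inputs, n_neurons))
Wy = rng.normal(size=(n_neurons, n_neurons))
b = np.zeros((1, n_neurons))

X0_batch = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 0, 1]], dtype=np.float64)  # t = 0
X1_batch = np.array([[9, 8, 7], [0, 0, 0], [6, 5, 4], [3, 2, 1]], dtype=np.float64)  # t = 1

# Manual unrolling, mirroring the TensorFlow graph
Y0 = np.tanh(X0_batch @ Wx + b)
Y1 = np.tanh(Y0 @ Wy + X1_batch @ Wx + b)

# Equivalent loop: start from a zero state and apply the same update each step
state = np.zeros((X0_batch.shape[0], n_neurons))
for X_t in (X0_batch, X1_batch):
    state = np.tanh(state @ Wy + X_t @ Wx + b)

print(Y0.shape, Y1.shape)      # (4, 5) (4, 5)
print(np.allclose(state, Y1))  # True
```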


## Using `static_rnn()`

In [8]:
n_inputs = 3
n_neurons = 5

In [9]:
reset_graph()

X0 = tf.placeholder(tf.float32, [None, n_inputs])
X1 = tf.placeholder(tf.float32, [None, n_inputs])

basic_cell = tf.keras.layers.SimpleRNNCell(units=n_neurons)
output_seqs, states = tf.nn.static_rnn(basic_cell, [X0, X1],
                                      dtype=tf.float32)
Y0, Y1 = output_seqs

In [10]:
init = tf.global_variables_initializer()

X0_batch = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 0, 1]])
X1_batch = np.array([[9, 8, 7], [0, 0, 0], [6, 5, 4], [3, 2, 1]])

with tf.Session() as sess:
    init.run()
    Y0_val, Y1_val = sess.run([Y0, Y1], feed_dict={X0: X0_batch, X1: X1_batch})

In [11]:
print('Y0_val', Y0_val)
print('Y1_val', Y1_val)

Y0_val [[ 0.90012544 -0.8181632  -0.7137506   0.96558857 -0.00169486]
 [ 0.86250913 -0.9449419  -0.9252926   0.9998034   0.63831985]
 [ 0.812126   -0.984102   -0.9821324   0.9999989   0.90735716]
 [-0.9999769   0.99998367  0.9983094  -0.99973863  0.99999195]]
Y1_val [[-0.93277675  0.67243385 -0.65539604  0.99997735  0.9559362 ]
 [ 0.53530896  0.889247    0.7819666  -0.08488655 -0.5707694 ]
 [-0.80390906  0.83610725  0.6096734   0.99554706  0.6496031 ]
 [-0.9343803  -0.9317238   0.16454792  0.13083158  0.634052  ]]


In [3]:
from tensorflow_graph_in_jupyter import show_graph

In [12]:
show_graph(tf.get_default_graph())

## Training a sequence classifier

In [13]:
reset_graph()

n_steps = 28
n_inputs = 28
n_neurons = 150
n_outputs = 10

learning_rate = 0.001

X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.int32, [None])

basic_cell = tf.keras.layers.SimpleRNNCell(units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)

# `states` is the RNN's final state (after the last time step); it feeds the output layer
logits = tf.layers.dense(states, n_outputs)
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)

loss = tf.reduce_mean(xentropy)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(loss)
correct = tf.nn.in_top_k(logits, y, 1)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

init = tf.global_variables_initializer()
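
To make the loss and accuracy nodes above more concrete, here is a small NumPy sketch of what they compute on a toy batch. The helper name `sparse_softmax_xent` and the toy logits and labels are assumptions for illustration, not part of the notebook's graph. With `k=1` and no ties between logits, `in_top_k` amounts to checking whether the arg-max class matches the label.

```python
import numpy as np

def sparse_softmax_xent(logits, labels):
    """Per-instance cross-entropy from raw logits and integer class labels."""
    # Subtract the row-wise max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for each instance
    return -log_probs[np.arange(len(labels)), labels]

# Toy batch: 3 instances, 3 classes (made-up values)
logits_toy = np.array([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0],
                       [1.0, 2.0, 0.0]])
labels_toy = np.array([0, 2, 0])

loss_toy = sparse_softmax_xent(logits_toy, labels_toy).mean()
accuracy_toy = (logits_toy.argmax(axis=1) == labels_toy).mean()
print(loss_toy, accuracy_toy)
```

The first two instances are classified correctly and the third is not, so the accuracy on this toy batch is 2/3.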

In [14]:
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train = X_train.astype(np.float32).reshape(-1, 28*28) / 255.0
X_test = X_test.astype(np.float32).reshape(-1, 28*28) / 255.0
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)
X_valid, X_train = X_train[:5000], X_train[5000:]
y_valid, y_train = y_train[:5000], y_train[5000:]

In [15]:
def shuffle_batch(X, y, batch_size):
    rnd_idx = np.random.permutation(len(X))
    n_batches = len(X) // batch_size
    for batch_idx in np.array_split(rnd_idx, n_batches):
        X_batch, y_batch = X[batch_idx], y[batch_idx]
        yield X_batch, y_batch
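
As a quick check of how this generator behaves, the sketch below runs it on a toy dataset (the arrays `X_toy` and `y_toy` are made up for this example; the function is repeated so the snippet is self-contained). Because `np.array_split` divides the shuffled indices into `len(X) // batch_size` nearly equal parts, some batches can be slightly larger than `batch_size`, and every instance appears exactly once per epoch.

```python
import numpy as np

def shuffle_batch(X, y, batch_size):
    rnd_idx = np.random.permutation(len(X))
    n_batches = len(X) // batch_size
    for batch_idx in np.array_split(rnd_idx, n_batches):
        X_batch, y_batch = X[batch_idx], y[batch_idx]
        yield X_batch, y_batch

np.random.seed(42)
X_toy = np.arange(20).reshape(10, 2)  # 10 instances, 2 features each
y_toy = np.arange(10)                 # label i belongs to row i

batches = list(shuffle_batch(X_toy, y_toy, batch_size=3))
sizes = [len(y_b) for _, y_b in batches]
print(sizes)  # [4, 3, 3]

# Every label shows up exactly once across the batches
seen = np.sort(np.concatenate([y_b for _, y_b in batches]))
print(np.array_equal(seen, y_toy))  # True
```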

In [16]:
X_test = X_test.reshape((-1, n_steps, n_inputs))

In [18]:
n_epochs = 100
batch_size = 150

with tf.Session() as sess:
    init.run()
    for epoch in range(n_epochs):
        for X_batch, y_batch in shuffle_batch(X_train, y_train, batch_size):
            X_batch = X_batch.reshape(-1, n_steps, n_inputs)
            sess.run(training_op, feed_dict={X:X_batch, y:y_batch})
        acc_batch = accuracy.eval(feed_dict={X:X_batch, y:y_batch})
        acc_test = accuracy.eval(feed_dict={X:X_test, y:y_test})
        print('{} Last batch accuracy: {}, Test accuracy: {}'.format(epoch, acc_batch, acc_test))

0 Last batch accuracy: 0.9133333563804626, Test accuracy: 0.9182000160217285
1 Last batch accuracy: 0.9466666579246521, Test accuracy: 0.9484999775886536
2 Last batch accuracy: 0.9599999785423279, Test accuracy: 0.9550999999046326
3 Last batch accuracy: 0.9800000190734863, Test accuracy: 0.9666000008583069
4 Last batch accuracy: 0.9266666769981384, Test accuracy: 0.9526000022888184
5 Last batch accuracy: 0.9933333396911621, Test accuracy: 0.9714000225067139
6 Last batch accuracy: 0.9733333587646484, Test accuracy: 0.9729999899864197
7 Last batch accuracy: 0.9800000190734863, Test accuracy: 0.9726999998092651
8 Last batch accuracy: 0.9733333587646484, Test accuracy: 0.9736999869346619
9 Last batch accuracy: 0.9800000190734863, Test accuracy: 0.972100019454956
10 Last batch accuracy: 0.9800000190734863, Test accuracy: 0.972599983215332
11 Last batch accuracy: 0.9666666388511658, Test accuracy: 0.9725000262260437
12 Last batch accuracy: 0.9800000190734863, Test accuracy: 0.965200006961822

KeyboardInterrupt: 