# MNIST handwritten digits classification with RNNs

In this notebook, we'll train a recurrent neural network (RNN) to classify MNIST digits using Keras (with either Theano or TensorFlow as the compute backend). Keras version 2 or later is required.

This notebook builds on the MNIST-MLP notebook, so the recommended order is to go through the MNIST-MLP notebook before starting with this one. 

First, the needed imports. Note that there are a few new recurrent layers compared to the MNIST-MLP notebook: SimpleRNN, LSTM, GRU.

In [None]:
%matplotlib inline

from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout
from keras.layers.recurrent import SimpleRNN, LSTM, GRU 
from keras.utils import np_utils
from keras import backend as K

from distutils.version import LooseVersion as LV
from keras import __version__

from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

print('Using Keras version:', __version__, 'backend:', K.backend())
assert(LV(__version__) >= LV("2.0.0"))

Next, let's load and process the MNIST dataset. The first time this runs, the data may have to be downloaded, which can take a while.

In [None]:
from keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
nb_classes = 10
img_rows, img_cols = 28, 28

X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255

# one-hot encoding:
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)

print()
print('MNIST data loaded: train:',len(X_train),'test:',len(X_test))
print('X_train:', X_train.shape)
print('y_train:', y_train.shape)
print('Y_train:', Y_train.shape)
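
The one-hot encoding step above turns each integer label into a 10-dimensional indicator vector. A minimal NumPy sketch of the same idea follows; `one_hot` is a hypothetical helper written for illustration and is not part of Keras:

```python
import numpy as np

def one_hot(labels, n_classes):
    """Return an array of shape (len(labels), n_classes) with a single 1.0 per row."""
    encoded = np.zeros((len(labels), n_classes), dtype="float32")
    # Set the column given by each label to 1.0, one row per label.
    encoded[np.arange(len(labels)), labels] = 1.0
    return encoded

labels = np.array([5, 0, 4])   # made-up example labels
encoded = one_hot(labels, 10)
print(encoded.shape)
print(encoded[0])
```

Taking `np.argmax` along axis 1 recovers the original integer labels.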

### Images as sequences

Note that in this notebook we are using *a sequence model* for image classification.  Therefore, we consider here an image to be a sequence of (pixel) input vectors.

More exactly, we consider each MNIST digit image (of size 28x28 pixels) to be a sequence of length 28 (number of image rows) with a 28-dimensional input vector (each image row, having 28 columns) associated with each time step. 
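
To make the sequence view concrete, here is a small sketch using a synthetic 28x28 array in place of a real digit (the pixel values are arbitrary placeholders):

```python
import numpy as np

# A stand-in for one MNIST digit: 28 rows x 28 columns (values are arbitrary).
image = np.arange(28 * 28, dtype="float32").reshape(28, 28)

# Read as a sequence, the image has 28 time steps, one per image row.
sequence = [image[t] for t in range(image.shape[0])]

print("time steps:", len(sequence))
print("features per time step:", sequence[0].shape[0])

# A batch of such images has shape (batch, time steps, features),
# which is the layout X_train already has after loading.
batch = np.stack([image, image])
print("batch shape:", batch.shape)
```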

### Initialization

Now we are ready to create a recurrent model.  Keras contains three types of recurrent layers:

 * `SimpleRNN`, a fully-connected RNN where the output is fed back to the input.
 * `LSTM`, the Long Short-Term Memory layer.
 * `GRU`, the Gated Recurrent Unit layer.

See https://keras.io/layers/recurrent/ for more information.
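
As a rough illustration of what "the output is fed back to the input" means, the following NumPy sketch runs the basic recurrence h_t = tanh(x_t W + h_{t-1} U + b) over one sequence. The function name, the random weights, and the input are assumptions made for this example; it is a simplified picture of a `SimpleRNN` forward pass with the default `tanh` activation, not the Keras implementation:

```python
import numpy as np

def simple_rnn_forward(x_seq, W, U, b):
    """Apply h_t = tanh(x_t W + h_{t-1} U + b) step by step; return the last state."""
    h = np.zeros(U.shape[0])
    for x_t in x_seq:
        # The previous output h enters the computation of the next output.
        h = np.tanh(x_t @ W + h @ U + b)
    return h

rng = np.random.default_rng(0)
n_steps, n_features, n_units = 28, 28, 50

W = rng.normal(scale=0.1, size=(n_features, n_units))  # input weights
U = rng.normal(scale=0.1, size=(n_units, n_units))     # recurrent weights
b = np.zeros(n_units)                                   # bias

x_seq = rng.random((n_steps, n_features))  # one random "image" as 28 rows
h_last = simple_rnn_forward(x_seq, W, U, b)
print(h_last.shape)
```

Only the final state is returned here, which corresponds to the default `return_sequences=False` behaviour used in the model below.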

In [None]:
# Number of hidden units to use:
nb_units = 50

model = Sequential()

# Recurrent layers supported: SimpleRNN, LSTM, GRU:
model.add(SimpleRNN(nb_units,
                    input_shape=(img_rows, img_cols)))

# To stack multiple RNN layers, all RNN layers except the last one need
# to have "return_sequences=True".  An example of using two RNN layers:
#model.add(SimpleRNN(16,
#                    input_shape=(img_rows, img_cols),
#                    return_sequences=True))
#model.add(SimpleRNN(32))

model.add(Dense(units=nb_classes))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# summary() prints the table itself and returns None, so no print() is needed.
model.summary()

In [None]:
SVG(model_to_dot(model, show_shapes=True).create(prog='dot', format='svg'))

### Learning

Now let's train the RNN model. Unlike in the MLP case, we do not need `reshape()` here: the recurrent layer takes each image in its original 28x28 shape, as a sequence of 28 rows.

This is a relatively complex model, so training (especially with LSTM and GRU layers) can be considerably slower than with MLPs. 
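
One contributor to the slowdown is the number of trainable weights in the recurrent layer. The sketch below estimates it for the three layer types, assuming the textbook formulation in which each gate has an input kernel, a recurrent kernel, and one bias vector. `rnn_param_count` is a hypothetical helper, and some Keras versions use a GRU variant with additional bias terms, so `model.summary()` may report a slightly larger number for GRU:

```python
def rnn_param_count(n_units, n_features, n_gates=1):
    """Weights per gate: input kernel + recurrent kernel + bias."""
    per_gate = n_units * (n_features + n_units + 1)
    return n_gates * per_gate

n_units, n_features = 50, 28  # values used in this notebook

simple_rnn_params = rnn_param_count(n_units, n_features, n_gates=1)
gru_params = rnn_param_count(n_units, n_features, n_gates=3)
lstm_params = rnn_param_count(n_units, n_features, n_gates=4)

print("SimpleRNN:", simple_rnn_params)
print("GRU:      ", gru_params)
print("LSTM:     ", lstm_params)
```

In addition, all three layers process the 28 time steps one after another, which limits how much of the work can run in parallel.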

In [None]:
%%time

epochs = 3

history = model.fit(X_train, 
                    Y_train, 
                    epochs=epochs, 
                    batch_size=128,
                    verbose=2)

In [None]:
plt.figure(figsize=(5,3))
plt.plot(history.epoch,history.history['loss'])
plt.title('loss')

plt.figure(figsize=(5,3))
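# The metric key is 'acc' in older Keras versions and 'accuracy' in newer ones.
acc_key = 'acc' if 'acc' in history.history else 'accuracy'
plt.plot(history.epoch, history.history[acc_key])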
plt.title('accuracy');

### Inference

With enough training epochs and a large enough model, the test accuracy should exceed 98%.  

You can compare your result with the state of the art [here](http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html). Even more results can be found [here](http://yann.lecun.com/exdb/mnist/).

In [None]:
%%time
scores = model.evaluate(X_test, Y_test, verbose=2)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))

We can again take a closer look at the results. Let's begin by defining
a helper function to show the failure cases of our classifier.
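
The helper relies on two NumPy steps: `argmax` turns each row of class probabilities into a predicted label, and an elementwise comparison with the true labels gives a boolean mask of the errors. A small sketch with made-up probabilities for three samples and three classes:

```python
import numpy as np

# Made-up softmax outputs: one row per sample, one column per class.
probs = np.array([[0.1, 0.7, 0.2],
                  [0.6, 0.3, 0.1],
                  [0.2, 0.2, 0.6]])
labels = np.array([1, 2, 2])  # made-up true classes

predicted = np.argmax(probs, axis=1)   # most probable class per sample
errors = predicted != labels           # True where the prediction is wrong
error_indices = np.flatnonzero(errors) # positions of the misclassified samples

print("predicted:", predicted)
print("errors:", errors)
print("misclassified indices:", error_indices)
```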

In [None]:
def show_failures(predictions, trueclass=None, predictedclass=None, maxtoshow=10):
    rounded = np.argmax(predictions, axis=1)
    errors = rounded!=y_test
    print('Showing max', maxtoshow, 'first failures. '
          'The predicted class is shown first and the correct class in parentheses.')
    ii = 0
    plt.figure(figsize=(maxtoshow, 1))
    for i in range(X_test.shape[0]):
        if ii>=maxtoshow:
            break
        if errors[i]:
            if trueclass is not None and y_test[i] != trueclass:
                continue
            if predictedclass is not None and rounded[i] != predictedclass:
                continue
            plt.subplot(1, maxtoshow, ii+1)
            plt.axis('off')
            plt.imshow(X_test[i,:,:], cmap="gray")
            plt.title("%d (%d)" % (rounded[i], y_test[i]))
            ii = ii + 1

Here are the first 10 test digits that the RNN misclassified:

In [None]:
predictions = model.predict(X_test)

show_failures(predictions)

We can use `show_failures()` to inspect failures in more detail. For example, here are failures in which the true class was "6":

In [None]:
show_failures(predictions, trueclass=6)