In [2]:
import os
import numpy as np

import tensorflow as tf
from tensorflow.python.keras.datasets import mnist
from tensorflow.contrib.eager.python import tfe

# Import the BasicLSTM written in TF Eager
from utils.basic_lstm import BasicLSTM

In [3]:
# Turn on eager (imperative) execution. TensorFlow expects this call at program
# startup, so it stays ahead of every other statement that touches TF.
tf.enable_eager_execution()

# Fix the NumPy and TensorFlow random seeds so repeated runs start from the same state.
np.random.seed(0)
tf.set_random_seed(0)

In [4]:
# Create the directory that the final `saver.save('weights/06_02_rnn/weights.ckpt')`
# call writes into. `exist_ok=True` replaces the racy "check, then create" pattern,
# and creating the nested folder (parents included) means the save does not rely on
# the saver creating missing directories itself.
os.makedirs('weights/06_02_rnn/', exist_ok=True)

# constants
units = 128        # size of the LSTM hidden / cell state
batch_size = 100   # samples per gradient update
epochs = 2         # full passes over the training set
num_classes = 10   # number of MNIST digit classes

In [5]:
# Load MNIST, divide the pixel values by 255 and view every image as a sequence
# of 28 timesteps with 28 inputs per timestep.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = (x_train.astype('float32') / 255.).reshape((-1, 28, 28))
x_test = (x_test.astype('float32') / 255.).reshape((-1, 28, 28))

# One-hot encode the labels and convert them straight back to numpy arrays,
# because a combination of numpy arrays and tensors cannot be used as Keras input.
y_train_ohe = tf.one_hot(y_train, depth=num_classes).numpy()
y_test_ohe = tf.one_hot(y_test, depth=num_classes).numpy()

# Report the shape of each array that is fed to the model.
for label, array in (('x train', x_train), ('y train', y_train_ohe),
                     ('x test', x_test), ('y test', y_test_ohe)):
    print(label, array.shape)

x train (60000, 28, 28)
y train (60000, 10)
x test (10000, 28, 28)
y test (10000, 10)


# What is the `BasicLSTM` Model?

The earlier model uses the `LSTM` layer from Keras, whereas this model uses `BasicLSTM`, a model written in TF Eager style that replicates the important parts of the Keras `LSTMCell`.

There is a noticeable speed difference between the two models. The cause is not confirmed, but the slowdown may come from `K.rnn()`, which is used internally by the Keras `RNN` class (the base class of all Keras RNN layers).

In comparison, `BasicLSTM` simply iterates over the timesteps of each batch with a plain Python `for` loop. It can be found in the `utils` folder and is reproduced below as a code snippet.

```python
import tensorflow as tf


class BasicLSTM(tf.keras.Model):
    """LSTM layer written as a `tf.keras.Model` that steps through time with a Python loop.

    The input projection (`kernel`) and the recurrent projection (`recurrent_kernel`)
    are both `Dense` layers producing `4 * units` values per sample; their sum is
    sliced into the input, forget, candidate and output gate pre-activations.

    Args:
        units: size of the hidden state and of the cell state.
        return_sequence: when True, `call` returns the hidden output of every
            timestep instead of only the last one.
        return_states: when True, `call` additionally returns a two-element list
            of states (see `call` for what it contains).
        **kwargs: forwarded to `tf.keras.Model.__init__`.
    """

    def __init__(self, units, return_sequence=False, return_states=False, **kwargs):
        """Store the configuration and create the two projection layers."""
        super(BasicLSTM, self).__init__(**kwargs)
        self.units = units
        self.return_sequence = return_sequence
        self.return_states = return_states

        def bias_initializer(_, *args, **kwargs):
            """Build the `4 * units` bias vector; the requested shape (first argument) is ignored."""
            # Unit forget bias from the paper
            # - [Learning to forget: Continual prediction with LSTM](http://www.mitpressjournals.org/doi/pdf/10.1162/089976600300015015)
            return tf.keras.backend.concatenate([
                tf.keras.initializers.Zeros()((self.units,), *args, **kwargs),  # input gate
                tf.keras.initializers.Ones()((self.units,), *args, **kwargs),  # forget gate
                tf.keras.initializers.Zeros()((self.units * 2,), *args, **kwargs),  # candidate (context) and output gates
            ])

        # Input projection carries no bias; the single bias vector for all four
        # gates lives in the recurrent projection below.
        self.kernel = tf.keras.layers.Dense(4 * units, use_bias=False)
        self.recurrent_kernel = tf.keras.layers.Dense(4 * units, kernel_initializer='glorot_uniform', bias_initializer=bias_initializer)

    def call(self, inputs, training=None, mask=None, initial_states=None):
        """Run the LSTM over every timestep of `inputs`.

        Args:
            inputs: tensor indexed as `inputs[:, t, :]`; axis 0 is used as the
                batch size and axis 1 as the number of timesteps.
            training: not used by this implementation.
            mask: not used by this implementation.
            initial_states: optional sequence of exactly two tensors `(h, c)` used
                as the starting hidden and cell state; zeros are used when omitted.
                Presumably each is shaped (batch, units) - this is not validated.

        Returns:
            Depends on the constructor flags:
            - neither flag: hidden output of the last timestep.
            - `return_sequence` only: hidden outputs of all timesteps, stacked on axis 1.
            - `return_states` only: `(last hidden output, [last h, last c])`.
            - both: `(all hidden outputs, [all hidden outputs, all cell states])`.

        Raises:
            AssertionError: if `initial_states` is given and does not have length 2.
        """
        # LSTM Cell in pure TF Eager code
        # reset the states initially if not provided, else use those
        if initial_states is None:
            h_state = tf.zeros((inputs.shape[0], self.units))
            c_state = tf.zeros((inputs.shape[0], self.units))
        else:
            assert len(initial_states) == 2, "Must pass a list of 2 states when passing 'initial_states'"
            h_state, c_state = initial_states

        # Per-timestep hidden outputs and cell states, collected for stacking below.
        h_list = []
        c_list = []

        for t in range(inputs.shape[1]):
            # LSTM gate steps
            # Project the current timestep and the previous hidden state, then sum.
            ip = inputs[:, t, :]
            z = self.kernel(ip)
            z += self.recurrent_kernel(h_state)

            # Slice the summed projection into four blocks of `units` columns:
            # input gate, forget gate, candidate values, output gate.
            z0 = z[:, :self.units]
            z1 = z[:, self.units: 2 * self.units]
            z2 = z[:, 2 * self.units: 3 * self.units]
            z3 = z[:, 3 * self.units:]

            # gate updates
            i = tf.keras.activations.sigmoid(z0)
            f = tf.keras.activations.sigmoid(z1)
            # New cell state: keep part of the old state, add the gated candidate.
            c = f * c_state + i * tf.nn.tanh(z2)

            # state updates
            o = tf.keras.activations.sigmoid(z3)
            h = o * tf.nn.tanh(c)

            h_state = h
            c_state = c

            h_list.append(h_state)
            c_list.append(c_state)

        # Stack along a new axis 1 (time). Note that `hidden_states` holds the
        # cell states (c), while `hidden_outputs` holds the hidden states (h).
        hidden_outputs = tf.stack(h_list, axis=1)
        hidden_states = tf.stack(c_list, axis=1)

        if self.return_states and self.return_sequence:
            # NOTE: the returned states here are the full per-timestep sequences,
            # not just the final h / c as in the branch below.
            return hidden_outputs, [hidden_outputs, hidden_states]
        elif self.return_states and not self.return_sequence:
            return hidden_outputs[:, -1, :], [h_state, c_state]
        elif self.return_sequence and not self.return_states:
            return hidden_outputs
        else:
            return hidden_outputs[:, -1, :]
```

In [6]:
class BasicLSTMModel(tf.keras.Model):
    """Sequence classifier: a `BasicLSTM` encoder followed by a dense softmax head.

    Args:
        units: size of the LSTM state, forwarded to `BasicLSTM`.
        num_classes: number of output classes (width of the final dense layer).
    """

    def __init__(self, units, num_classes):
        super().__init__()
        self.units = units
        # Sub-layers are created in the order in which `call` applies them.
        self.lstm = BasicLSTM(units)
        self.classifier = tf.keras.layers.Dense(num_classes)

    def call(self, inputs, training=None, mask=None):
        """Return class probabilities for a batch of input sequences.

        `training` and `mask` are accepted but not used by this model.
        """
        logits = self.classifier(self.lstm(inputs))

        # The softmax is pinned to the CPU: per the original author's note, the
        # softmax op was not available on the GPU.
        with tf.device('/cpu:0'):
            return tf.nn.softmax(logits)

In [7]:
# Use the first GPU when TF reports at least one, otherwise stay on the CPU.
device = '/gpu:0' if tfe.num_gpus() > 0 else '/cpu:0'

with tf.device(device):
    # Build the model and attach the optimizer, loss and metric.
    model = BasicLSTMModel(units, num_classes)
    model.compile(
        optimizer=tf.train.AdamOptimizer(0.01),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    # Without this step, TF Keras tries to use the entire dataset to determine
    # shapes when .fit() is called. The fix is to feed exactly one dummy sample
    # so the model can infer its input/output shapes from it.
    dummy_x = tf.zeros((1, 28, 28))
    model._set_inputs(dummy_x)

    # Train, using the test split as validation data.
    model.fit(
        x_train,
        y_train_ohe,
        batch_size=batch_size,
        epochs=epochs,
        validation_data=(x_test, y_test_ohe),
        verbose=1,
    )

    # Evaluate on the test set.
    scores = model.evaluate(x_test, y_test_ohe, batch_size=batch_size, verbose=1)
    print("Final test loss and accuracy :", scores)

    # Write the trained variables to a checkpoint.
    saver = tfe.Saver(model.variables)
    saver.save('weights/06_02_rnn/weights.ckpt')

Train on 60000 samples, validate on 10000 samples
Epoch 1/2
Epoch 2/2
Final test loss and accuracy : [0.09331267344765365, 0.9722000068426132]
