In [1]:
#essential imports
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import rnn

In [2]:
# Load MNIST with one-hot labels (read from / extracted into MNIST_data/).
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trY, teY = mnist.train.labels, mnist.test.labels
# View every image as a sequence: (batch, sequence_length, input_size) = (-1, 28, 28).
trX = mnist.train.images.reshape(-1, 28, 28)
teX = mnist.test.images.reshape(-1, 28, 28)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [3]:
#essential functions
def init_weights(shape):
    """Return a `tf.Variable` of the given shape, initialised with
    normally distributed random values (stddev 0.01)."""
    initial_values = tf.random_normal(shape, stddev=0.01)
    return tf.Variable(initial_values)

In [4]:
def prepareData(X, lstm_size):
    """Split a batch of sequences into the per-time-step list that
    `rnn.static_rnn` consumes.

    Args:
        X: 3-D tensor laid out as (batch, time, input_size). The time
            dimension must be statically known.
        lstm_size: Unused. Kept only so existing callers keep working; the
            number of LSTM units has no bearing on how the input is sliced.

    Returns:
        A Python list with one entry per time step; entry ``t`` is the 2-D
        tensor ``X[:, t, :]`` of shape (batch, input_size).

    Raises:
        ValueError: if the time dimension of ``X`` is not statically known.
    """
    # Read the number of steps from the tensor itself rather than from a
    # notebook-level global, so this function does not depend on which cells
    # happen to have been executed before it is called.
    time_steps = X.get_shape().as_list()[1]
    if time_steps is None:
        raise ValueError("prepareData needs a statically known time dimension")

    # Unstacking along axis 1 yields the same slices as
    # transpose -> reshape -> split, without assuming that the input width
    # equals lstm_size.
    return tf.unstack(X, num=time_steps, axis=1)


#create the model
def model(X, W, B, lstm_size):
    """Build the LSTM classifier graph.

    Args:
        X: input batch, passed to `prepareData` to obtain the per-time-step
            list of inputs.
        W: weight matrix applied to the final LSTM output.
        B: bias added after the matrix multiplication.
        lstm_size: number of units in the `BasicLSTMCell`.

    Returns:
        A tuple ``(logits, state_size)``: the linear projection of the last
        time step's output, and the cell's `state_size`.
    """
    step_inputs = prepareData(X, lstm_size)

    # A single LSTM cell; its state is kept as a (c, h) tuple.
    cell = rnn.BasicLSTMCell(lstm_size, forget_bias=1.0, state_is_tuple=True)

    # static_rnn returns one output per time step plus the final state, which
    # is not needed here.
    step_outputs, _ = rnn.static_rnn(cell, step_inputs, dtype=tf.float32)

    # Linear activation on the last time step's output only.
    last_output = step_outputs[-1]
    logits = tf.matmul(last_output, W) + B
    return logits, cell.state_size



In [10]:
# Configuration variables
batch_size = 128  # training examples consumed per optimisation step
test_size = 256  # test examples sampled for each per-epoch accuracy check
input_vector_size = 28  # presumably the width of one per-step input vector (one image row)
lstm_size = 28  # number of units in the LSTM cell
time_step_size = 28  # number of time steps each image is split into

In [6]:
# Build the graph: placeholders, output-layer parameters, loss, optimiser and
# prediction op.
OUTPUT_CLASSES = 10

# The input shape comes from the configuration cell instead of repeating the
# literal 28s, so the graph cannot silently drift from the configured
# sequence length / input width. The batch dimension is left open.
X = tf.placeholder("float", shape=[None, time_step_size, input_vector_size])
Y = tf.placeholder("float", shape=[None, OUTPUT_CLASSES])

# Output layer mapping the final LSTM output to class scores.
W = init_weights([lstm_size, OUTPUT_CLASSES])
B = init_weights([OUTPUT_CLASSES])

py_x, state_size = model(X, W, B, lstm_size)

# Mean softmax cross-entropy between the logits and the one-hot labels.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
# Predicted class = index of the largest logit.
predict_op = tf.argmax(py_x, 1)

In [12]:
# Launch the graph in a session
with tf.Session() as sess:
    # Variables must be initialised before any op that reads them is run.
    tf.global_variables_initializer().run()

    num_train = len(trX)
    for i in range(10):
        # One pass over the training set in mini-batches. Iterating over the
        # start offsets (rather than zipping two ranges) keeps the final,
        # shorter batch: slicing clamps at the end of the array, so no
        # training examples are silently dropped.
        for start in range(0, num_train, batch_size):
            end = start + batch_size
            sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})

        # Estimate accuracy on a random sample of test_size test examples.
        test_indices = np.random.permutation(len(teX))[:test_size]
        predictions = sess.run(predict_op, feed_dict={X: teX[test_indices]})
        actuals = np.argmax(teY[test_indices], axis=1)
        accuracy = np.mean(predictions == actuals)
        print("Epoch:- " + str(i) + " Accuracy:- " + str(accuracy))

Epoch:- 0 Accuracy:- 0.6484375
Epoch:- 1 Accuracy:- 0.7890625
Epoch:- 2 Accuracy:- 0.82421875
Epoch:- 3 Accuracy:- 0.8828125
Epoch:- 4 Accuracy:- 0.93359375
Epoch:- 5 Accuracy:- 0.94140625
Epoch:- 6 Accuracy:- 0.9375
Epoch:- 7 Accuracy:- 0.93359375
Epoch:- 8 Accuracy:- 0.9375
Epoch:- 9 Accuracy:- 0.93359375
