# RNN tutorial
- Use tf.contrib.rnn.static_rnn with MNIST dataset

In [1]:
import time
import numpy as np
import tensorflow as tf

from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data
from pprint import pprint

# 0. Load MNIST data

In [2]:
# Read the MNIST splits from the local 'MNIST_idx3/' directory.
# one_hot=False keeps labels as integer class ids, which is what the int32
# label placeholder and the sparse cross-entropy loss below expect.
mnist = input_data.read_data_sets('MNIST_idx3/', one_hot=False)

Extracting MNIST_idx3/train-images-idx3-ubyte.gz
Extracting MNIST_idx3/train-labels-idx1-ubyte.gz
Extracting MNIST_idx3/t10k-images-idx3-ubyte.gz
Extracting MNIST_idx3/t10k-labels-idx1-ubyte.gz


In [3]:
# Pull the image and label arrays out of each predefined split.
X_train = mnist.train.images
Y_train = mnist.train.labels

X_val = mnist.validation.images
Y_val = mnist.validation.labels

X_test = mnist.test.images
Y_test = mnist.test.labels

In [4]:
# Each image is stored flattened, so axis 1 of X_train is the pixel count.
dim_X = X_train.shape[1]
pixel_X = int(np.sqrt(dim_X)) # np.sqrt returns a float, so cast it to int to get the side length

print("Dimension of X: %d (%d x %d)" % (dim_X, pixel_X, pixel_X))

Dimension of X: 784 (28 x 28)


# 1. Parameters

In [5]:
# --- Training hyper-parameters ---
# Step size passed to the Adam optimizer.
LEARNING_RATE  = 0.001
# Number of full passes over the training set.
N_EPOCHS       = 10
# Number of examples per training mini-batch.
BATCH_SIZE     = 128
# Print training cost/accuracy every DISPLAY_STEP batches.
DISPLAY_STEP   = 50
# Evaluate on the validation set every VAL_EPOCH epochs.
VAL_EPOCH      = 5

# --- Model dimensions ---
# Each image is fed to the RNN as a sequence of n_steps vectors of size n_input
# (the flattened image is reshaped to [n_steps, n_input]).
n_input   = pixel_X
n_steps   = pixel_X
# Number of units in the RNN cell.
n_hidden  = 32
# Size of the output logits vector (one entry per digit class).
n_classes = 10

# 2. Placeholders

In [6]:
# Flattened input images: one row of dim_X float values per example; the
# batch dimension is left as None so any batch size can be fed.
X = tf.placeholder(tf.float32, [None, dim_X])
# Integer class labels (not one-hot), one per example.
Y = tf.placeholder(tf.int32, [None,])

In [7]:
# Static shape of the input placeholder; the unknown batch dimension prints as '?'.
print(X.get_shape())

(?, 784)


## Reshape the input for the RNN

In [8]:
# Demonstration only: static_rnn expects a Python list with one
# [batch, n_input] tensor per time step, so the flat input is first reshaped
# to [batch, n_steps, n_input] and then unstacked along the time axis (axis=1).
# NOTE: RNN_static performs the same reshape/unstack internally; the ops
# created here are only inspected via the prints below.
reshaped_X = tf.reshape(tensor=X, shape=[-1, n_steps, n_input])
unstacked_X = tf.unstack(value=reshaped_X, num=n_steps, axis=1)

print("X shpae: ", X.get_shape())
print("X reshape: ", reshaped_X.get_shape())
print("X reshape unstacked list length: ", len(unstacked_X))
print(type(unstacked_X))
pprint(unstacked_X)

X shpae:  (?, 784)
X reshape:  (?, 28, 28)
X reshape unstacked list length:  28
<class 'list'>
[<tf.Tensor 'unstack:0' shape=(?, 28) dtype=float32>,
 <tf.Tensor 'unstack:1' shape=(?, 28) dtype=float32>,
 <tf.Tensor 'unstack:2' shape=(?, 28) dtype=float32>,
 <tf.Tensor 'unstack:3' shape=(?, 28) dtype=float32>,
 <tf.Tensor 'unstack:4' shape=(?, 28) dtype=float32>,
 <tf.Tensor 'unstack:5' shape=(?, 28) dtype=float32>,
 <tf.Tensor 'unstack:6' shape=(?, 28) dtype=float32>,
 <tf.Tensor 'unstack:7' shape=(?, 28) dtype=float32>,
 <tf.Tensor 'unstack:8' shape=(?, 28) dtype=float32>,
 <tf.Tensor 'unstack:9' shape=(?, 28) dtype=float32>,
 <tf.Tensor 'unstack:10' shape=(?, 28) dtype=float32>,
 <tf.Tensor 'unstack:11' shape=(?, 28) dtype=float32>,
 <tf.Tensor 'unstack:12' shape=(?, 28) dtype=float32>,
 <tf.Tensor 'unstack:13' shape=(?, 28) dtype=float32>,
 <tf.Tensor 'unstack:14' shape=(?, 28) dtype=float32>,
 <tf.Tensor 'unstack:15' shape=(?, 28) dtype=float32>,
 <tf.Tensor 'unstack:16' shape=(?, 

# 3. Model

In [9]:
def RNN_static(X, n_steps, n_hidden, n_classes, n_input=None):
    """Build a single-layer static RNN classifier and return its logits.

    Each flattened row of ``X`` is reshaped to ``[n_steps, n_input]``, split
    into a list of ``n_steps`` per-step tensors, run through a
    ``BasicRNNCell`` unrolled by ``rnn.static_rnn``, and the output of the
    last step is linearly projected to ``n_classes`` logits.

    Args:
        X: float tensor of shape ``[batch, n_steps * n_input]``.
        n_steps: number of time steps each example is split into.
        n_hidden: number of units in the RNN cell.
        n_classes: size of the returned logits vector.
        n_input: feature size fed to the cell at every step. When ``None``
            (default) it is inferred as ``X``'s static last dimension divided
            by ``n_steps``, so the function no longer depends on a notebook
            global of the same name.

    Returns:
        Tensor of shape ``[batch, n_classes]`` produced by the op named
        ``"logits"``.

    Raises:
        ValueError: if ``n_input`` is not given and cannot be inferred because
            the last dimension of ``X`` is unknown or not divisible by
            ``n_steps``.

    Note:
        The projection variables are created with ``tf.get_variable`` under
        the names ``"weight"`` and ``"bias"`` in the current variable scope.
        NOTE(review): a second call in the same scope will presumably fail
        unless variable reuse is enabled -- confirm before reusing.
    """
    if n_input is None:
        # Derive the per-step feature size from the statically known width of X.
        flat_dim = X.get_shape().as_list()[-1]
        if flat_dim is None or flat_dim % n_steps != 0:
            raise ValueError(
                "Cannot infer n_input: last dimension of X (%r) must be known "
                "and divisible by n_steps (%r)" % (flat_dim, n_steps))
        n_input = flat_dim // n_steps

    # Reshape the placeholder to [batch, n_steps, n_input]
    reshaped_X = tf.reshape(tensor=X, shape=[-1, n_steps, n_input], name="reshaped")

    # Unstack to get a list of 'n_steps' tensors of shape [batch_size, n_input]
    unstacked_X = tf.unstack(value=reshaped_X, num=n_steps, axis=1, name="unstack")

    # Define RNN cell
    rnn_basic_cell = rnn.BasicRNNCell(num_units=n_hidden)

    # Get RNN cell outputs (a list with one [batch, n_hidden] tensor per step);
    # the final cell state is not needed here.
    outputs, _ = rnn.static_rnn(cell=rnn_basic_cell,
                                inputs=unstacked_X,
                                dtype=tf.float32)

    # Define weights and biases of the output projection
    W = tf.get_variable(name="weight",
                        shape=[n_hidden, n_classes],
                        initializer=tf.contrib.layers.xavier_initializer())
    b = tf.get_variable(name="bias",
                        shape=[n_classes],
                        initializer=tf.constant_initializer(0.0))

    # Calculate logits
    # last rnn output: outputs[-1] (because the type of outputs is "list")
    logits = tf.nn.xw_plus_b(x=outputs[-1], weights=W, biases=b, name="logits")

    return logits

In [10]:
# Build the model graph once; `logits` holds the unnormalized class scores
# returned by RNN_static, one row of n_classes values per example.
logits = RNN_static(X, n_steps, n_hidden, n_classes)
print(logits)

Tensor("logits:0", shape=(?, 10), dtype=float32)


In [11]:
# Mean cross-entropy over the batch. The sparse variant takes integer class
# ids as labels, so Y does not need to be one-hot encoded.
cost = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=Y, logits=logits)
)

# One Adam update step that minimizes the cost.
train_op = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(cost)

In [12]:
# A prediction counts as correct when the true label is the top-1 logit.
correct_prediction = tf.nn.in_top_k(
    predictions=logits,
    targets=Y,
    k=1,
)

# Fraction of correct predictions in the fed batch.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype=tf.float32))

In [13]:
def make_batch_step(Y, batch_size, allow_small_batch=True):
    """Return ``(start, end)`` slice bounds that split ``Y`` into mini-batches.

    Args:
        Y: any sized sequence; only ``len(Y)`` is used.
        batch_size: positive number of samples per batch.
        allow_small_batch: when True, a final batch shorter than
            ``batch_size`` is kept; when False, it is dropped.

    Returns:
        A one-shot ``zip`` iterator of ``(start, end)`` pairs, each usable as
        ``data[start:end]``. No pair is ever empty.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer, got %r" % (batch_size,))

    num_points = len(Y)

    # Every batch start, including the start of a possible short last batch.
    start_idx = list(range(0, num_points, batch_size))
    # Clamp each end to the data length so the last batch never overruns.
    end_idx = [min(start + batch_size, num_points) for start in start_idx]

    # Drop the trailing short batch when only full batches are wanted.
    if not allow_small_batch and end_idx and end_idx[-1] - start_idx[-1] < batch_size:
        start_idx.pop()
        end_idx.pop()

    return zip(start_idx, end_idx)

# 4. Run

In [14]:
# Wall-clock reference used by the elapsed-time report printed after training.
start_time = time.time()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch in range(N_EPOCHS):
        # The batch iterator is single-use, so it is rebuilt for every epoch.
        batch_step = make_batch_step(Y=Y_train,
                                     batch_size=BATCH_SIZE,
                                     allow_small_batch=True)

        # Train
        for step, (start, end) in enumerate(batch_step):
            batch_xs = X_train[start:end]
            batch_ys = Y_train[start:end]

            _, cost_batch, accuracy_batch = sess.run([train_op, cost, accuracy],
                                                     feed_dict={X: batch_xs, Y: batch_ys})
            # Compare by value (==); `is` tests object identity and only
            # happens to work for small ints on CPython.
            if step % DISPLAY_STEP == 0:
                print("[%3d step / %3d epoch] cost = %.6f, accuracy = %.6f" % (step, epoch, cost_batch, accuracy_batch))

        # Validation (whole validation set in one run; train_op is not run)
        if epoch % VAL_EPOCH == 0:
            cost_val, accuracy_val = sess.run([cost, accuracy], feed_dict={X: X_val, Y: Y_val})
            print()
            print("\t[%3d epoch] validation cost = %.6f, validation accuracy = %.6f" % (epoch, cost_val, accuracy_val))
            print()

        print()

    # Test (must run inside the session, while the trained variables exist)
    cost_test, accuracy_test = sess.run([cost, accuracy], feed_dict={X: X_test, Y: Y_test})
    print("\n\t[Test] test cost = %.6f, test accuracy = %.6f" % (cost_test, accuracy_test))

[  0 step /   0 epoch] cost = 2.355705, accuracy = 0.078125
[ 50 step /   0 epoch] cost = 1.965081, accuracy = 0.273438
[100 step /   0 epoch] cost = 1.772215, accuracy = 0.335938
[150 step /   0 epoch] cost = 1.438617, accuracy = 0.531250
[200 step /   0 epoch] cost = 1.257221, accuracy = 0.609375
[250 step /   0 epoch] cost = 1.252578, accuracy = 0.562500
[300 step /   0 epoch] cost = 0.973298, accuracy = 0.648438
[350 step /   0 epoch] cost = 1.020648, accuracy = 0.632812
[400 step /   0 epoch] cost = 1.040885, accuracy = 0.664062

	[  0 epoch] validation cost = 0.904405, validation accuracy = 0.727200


[  0 step /   1 epoch] cost = 0.837739, accuracy = 0.734375
[ 50 step /   1 epoch] cost = 0.868781, accuracy = 0.687500
[100 step /   1 epoch] cost = 0.996840, accuracy = 0.664062
[150 step /   1 epoch] cost = 0.860663, accuracy = 0.726562
[200 step /   1 epoch] cost = 0.824177, accuracy = 0.750000
[250 step /   1 epoch] cost = 0.947818, accuracy = 0.726562
[300 step /   1 epoch] co

In [15]:
# Elapsed wall-clock time since start_time was set at the top of the training
# cell. Because this runs in a separate cell, the figure also includes any
# idle time between executing the two cells.
print(time.time() - start_time, "sec")

24.483845949172974 sec
