In [1]:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import random
import datetime

import tensorflow as tf
import numpy as np
import scipy.misc
import scipy
import matplotlib.pyplot as plt
%matplotlib inline

from tensorflow.examples.tutorials.mnist import input_data
slim = tf.contrib.slim

In [2]:
mnist = input_data.read_data_sets('MNIST_data')
x_size, y_size = 28, 28
n_classes = 10
default_collection = 'nodes'

def timestamp():
    d = datetime.datetime.now()
    return d.strftime("%Y/%m/%d/%X")

timestamp()

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


'2016/12/10/15:03:25'

In [3]:
def random_batch_iterator(x, y, *, batch_size):
    n = x.shape[0]
    assert n == y.shape[0]
    
    while True:
        index = np.random.randint(n, size=batch_size)
        x_batch, y_batch = x[index], y[index]
        yield x_batch.copy(), y_batch.copy()
        
def batch_iterator(x, y, batch_size):
    n = x.shape[0]
    assert n == y.shape[0]
    
    for i in range(0, n, batch_size):
        x_batch, y_batch = x[i:i+batch_size], y[i:i+batch_size]        
        yield x_batch.copy(), y_batch.copy()

In [4]:
mnist.train.images.shape

(55000, 784)

In [5]:
def build_cnn(inputs, *, n_conv, conv_base, conv_mul, conv_size, pool_size, collection=default_collection):
    l = inputs
    for i in range(n_conv):
        n_filters = conv_base * conv_mul ** i
        l = slim.conv2d(l, n_filters, [conv_size, conv_size],
                        scope='Conv{}'.format(i+1), outputs_collections=collection)
        l = slim.max_pool2d(l, [pool_size, pool_size], scope='MaxPool{}'.format(i+1),
                            outputs_collections=collection)
    l = slim.flatten(l)
    
    l = slim.dropout(l, 0.5, scope='Dropout', outputs_collections=collection)
    l = slim.fully_connected(l, 10, activation_fn=None, scope='Output',
                             outputs_collections=collection)
    return l

def build_loss(logits, y_true):
    logloss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits, y_true))
    return logloss

In [6]:
batch_size = 512

n_conv = 2
conv_base = 32
conv_mul = 2
conv_size = 5
pool_size = 2

graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope('model') as vs:
        x_ph = tf.placeholder(tf.float32, shape=[batch_size, x_size * y_size])
        x_image = tf.reshape(x_ph, [-1, x_size, y_size, 1])
        y_ph = tf.placeholder(tf.int64, shape=[batch_size])

        logits = build_cnn(x_image, n_conv=n_conv, conv_base=conv_base, conv_mul=conv_mul,
                           conv_size=conv_size, pool_size=pool_size)

        prediction = tf.nn.softmax(logits)

        loss = build_loss(logits, y_ph)

        optimizer = tf.train.AdamOptimizer().minimize(loss)

        correct_prediction = tf.equal(tf.argmax(prediction, 1), y_ph)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        
    # Code to use of tensorboard
    with tf.name_scope('summaries'):
        tf.scalar_summary('log_loss', loss)
        tf.scalar_summary('acc', accuracy)
        merged_summary = tf.merge_all_summaries()

nodes = graph.get_collection(default_collection)

In [7]:
mnist.train.images

array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       ..., 
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.]], dtype=float32)

In [17]:
n_epochs = 1000

train_iterator = random_batch_iterator(mnist.train.images, mnist.train.labels, batch_size=batch_size)
val_iterator = random_batch_iterator(mnist.validation.images, mnist.validation.labels, batch_size=batch_size)
test_iterator = batch_iterator(mnist.test.images, mnist.test.labels, batch_size=batch_size)

best_acc = 0.0
path = '/tmp/tf/' + timestamp()
with tf.Session(graph=graph) as session:
    saver = tf.train.Saver()
    tf.global_variables_initializer().run()
    print('Initialized')
    train_writer = tf.train.SummaryWriter(
        path+'/train', session.graph)
    val_writer = tf.train.SummaryWriter(
        path+'/val', session.graph)
    
    for epoch in range(n_epochs):
        x_batch, y_batch = next(train_iterator)
        _, summary, acc, l = session.run([optimizer, merged_summary, accuracy, loss],
                                         feed_dict={x_ph: x_batch, y_ph: y_batch})
        train_writer.add_summary(summary, epoch)
        
        x_batch, y_batch = next(val_iterator)
        summary, acc, l = session.run([merged_summary, accuracy, loss],
                                      feed_dict={x_ph: x_batch, y_ph: y_batch})
        val_writer.add_summary(summary, epoch)
        
        if acc > best_acc:
            best_acc = acc
            saver.save(session,"Mnist_NLA.ckpt")
        #print("Current Validation accuracy is: {:<4.2%}".format(acc))
        
    test_acc = 0
    n = 0
    for x_batch, y_batch in test_iterator:
        if len(x_batch) != batch_size:
            break
        test_acc += accuracy.eval(feed_dict={x_ph: x_batch, y_ph: y_batch})
        n += 1
    test_acc = test_acc / n
    print("The test accuracy is: {:<4.2%}".format(test_acc))

Initialized
Current Validation accuracy is: 12.70%
Current Validation accuracy is: 25.59%
Current Validation accuracy is: 36.72%
Current Validation accuracy is: 42.77%
Current Validation accuracy is: 51.56%
Current Validation accuracy is: 58.59%
Current Validation accuracy is: 61.72%
Current Validation accuracy is: 63.67%
Current Validation accuracy is: 67.38%
Current Validation accuracy is: 63.87%
Current Validation accuracy is: 71.29%
Current Validation accuracy is: 69.92%
Current Validation accuracy is: 67.19%
Current Validation accuracy is: 70.70%
Current Validation accuracy is: 70.70%
Current Validation accuracy is: 71.88%
Current Validation accuracy is: 76.17%
Current Validation accuracy is: 76.17%
Current Validation accuracy is: 73.05%
Current Validation accuracy is: 75.59%
Current Validation accuracy is: 80.66%
Current Validation accuracy is: 79.69%
Current Validation accuracy is: 78.32%
Current Validation accuracy is: 81.45%
Current Validation accuracy is: 82.03%
Current Valid

In [16]:
test_acc

0.0