In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
import cv2

In [2]:
# Root directory for TensorBoard logs (logs/train, logs/test) and checkpoints (models/).
train_dir = "train"

In [3]:
# Load MNIST from this directory with one-hot encoded labels.
mnist_data_dir = "MNIST_data/"
mnist = input_data.read_data_sets(mnist_data_dir, one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [4]:
def variable_summaries(var):
    """Register TensorBoard summaries describing the values of ``var``.

    Under a ``summaries`` name scope this adds scalar summaries for the mean,
    standard deviation, max and min of ``var``, plus a histogram of its values.

    Args:
        var: the tensor (variable or activation) to summarise.
    """
    with tf.name_scope('summaries'):
        var_mean = tf.reduce_mean(var)
        tf.summary.scalar('mean', var_mean)

        # Population standard deviation: sqrt(mean((var - mean)^2)).
        with tf.name_scope('stddev'):
            centered = var - var_mean
            var_stddev = tf.sqrt(tf.reduce_mean(tf.square(centered)))
        tf.summary.scalar('stddev', var_stddev)

        for tag, reducer in (('max', tf.reduce_max), ('min', tf.reduce_min)):
            tf.summary.scalar(tag, reducer(var))

        tf.summary.histogram('histogram', var)

In [5]:
with tf.name_scope('input'):
    # x: flattened 784-pixel images; yt: one-hot targets over 10 classes.
    x = tf.placeholder(dtype=tf.float32, shape=[None, 784], name='x')
    yt = tf.placeholder(dtype=tf.float32, shape=[None, 10], name='yt')

# Single zero-initialised linear layer (softmax regression). The scope is
# called 'hidden', but this is the only layer feeding the softmax.
with tf.name_scope('hidden'):
    W = tf.Variable(tf.zeros([784, 10]), name='W')
    b = tf.Variable(tf.zeros([10]), name='b')
    for param in (W, b):
        variable_summaries(param)

with tf.name_scope('activations'):
    logits = tf.matmul(x, W) + b
    y = tf.nn.softmax(logits, name='y')
    variable_summaries(y)

with tf.name_scope('test_images'):
    # Fed with the labelled 256x256 single-channel previews built by draw_labels().
    image_shaped_input = tf.placeholder(
        dtype=tf.float32, shape=[None, 256, 256, 1], name='image_shaped_input'
    )
    tf.summary.image('input', image_shaped_input, max_outputs=10)

In [6]:
with tf.name_scope('cross_entropy'):
    # tf.log(y) is -inf wherever a softmax probability underflows to 0, which
    # turns the loss and its gradients into NaN and silently ruins training.
    # Clamp probabilities away from 0 before taking the log. (A logits-based
    # loss such as tf.nn.softmax_cross_entropy_with_logits is more robust still.)
    y_clipped = tf.clip_by_value(y, 1e-10, 1.0)
    cross_entropy = tf.reduce_mean(
        -tf.reduce_sum(yt * tf.log(y_clipped), axis=1),
        name='cross_entropy',
    )
# Kept outside the name scope so the TensorBoard tag stays 'cross_entropy'.
tf.summary.scalar('cross_entropy', cross_entropy)

with tf.name_scope('train'):
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

with tf.name_scope('accuracy'):
    with tf.name_scope('correct_prediction'):
        # A prediction is correct when the most probable class matches the one-hot target.
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(yt, 1), name='correct_prediction')
    with tf.name_scope('accuracy'):
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')

tf.summary.scalar('accuracy', accuracy)

# Single op that evaluates every summary registered so far.
merged = tf.summary.merge_all()

In [7]:
# Checkpoint only the trainable variables (the W and b defined above).
saver = tf.train.Saver(tf.trainable_variables())
sess = tf.InteractiveSession()

# Separate event directories so TensorBoard shows train and test runs side by side;
# the train writer also records the graph.
train_log_dir = train_dir + '/logs/train'
test_log_dir = train_dir + '/logs/test'
train_writer = tf.summary.FileWriter(train_log_dir, sess.graph)
test_writer = tf.summary.FileWriter(test_log_dir)

# Register tensors under collection names so they can be looked up again after
# the saved meta graph is re-imported.
for collection_name, tensor in (('x', x), ('yt', yt), ('accuracy', accuracy)):
    tf.add_to_collection(collection_name, tensor)

In [8]:
def draw_labels(images, labels):
    """Upscale images and stamp each one with its predicted class for TensorBoard.

    Args:
        images: iterable of flattened images; each is reshaped to (28, 28),
            multiplied by 255 and resized to 256x256.
        labels: 2-D array of per-class scores (the callers pass softmax
            outputs); ``labels.argmax(1)[k]`` is the text drawn on image ``k``.

    Returns:
        ``np.ndarray`` stacking the labelled images, each shaped (256, 256, 1).
    """
    class_ids = labels.argmax(1)

    def _stamp(flat_image, class_id):
        # Scale by 255 (assumes pixel values in [0, 1] — TODO confirm) so the
        # white text drawn below is on the same scale as the image.
        canvas = cv2.resize(flat_image.reshape((28, 28)) * 255, (256, 256))
        cv2.putText(
            canvas,
            str(class_id),
            (0, 28),                   # bottom-left corner of the text
            cv2.FONT_HERSHEY_SIMPLEX,
            1,                         # font scale
            (255, 255, 255),           # text colour
            2,                         # stroke thickness
            cv2.LINE_AA,
        )
        return canvas.reshape((256, 256, 1))

    return np.asarray([_stamp(flat_image, class_ids[k]) for k, flat_image in enumerate(images)])

In [None]:
tf.global_variables_initializer().run()

num_steps = 1000
batch_size = 100
log_every = 50  # write summaries and a checkpoint every `log_every` steps

# Checkpoints are written to train_dir/models, so that directory (not just
# train_dir) must exist before Saver.save is called.
checkpoint_dir = os.path.join(train_dir, 'models')
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, 'mnist.ckpt')

for i in range(num_steps):
    batch_xs, batch_ys = mnist.train.next_batch(batch_size)
    sess.run(train_step, feed_dict={x: batch_xs, yt: batch_ys})

    if i % log_every != log_every - 1:
        continue

    # Summaries are only written on logging steps, so the extra forward pass,
    # the labelled preview images and the merged summary are computed only
    # here instead of on every training step.
    train_labels = sess.run(y, feed_dict={x: batch_xs})
    train_summary = sess.run(
        merged,
        feed_dict={
            x: batch_xs,
            yt: batch_ys,
            image_shaped_input: draw_labels(batch_xs, train_labels),
        },
    )
    train_writer.add_summary(train_summary, i)

    saver.save(sess, checkpoint_path, global_step=i)

    test_batch_xs, test_batch_ys = mnist.test.next_batch(batch_size)
    test_labels = sess.run(y, feed_dict={x: test_batch_xs})
    test_summary = sess.run(
        merged,
        feed_dict={
            x: test_batch_xs,
            yt: test_batch_ys,
            image_shaped_input: draw_labels(test_batch_xs, test_labels),
        },
    )
    test_writer.add_summary(test_summary, i)

# FileWriter buffers events; flush so TensorBoard sees the final summaries.
train_writer.flush()
test_writer.flush()

In [None]:
# Checkpoints are saved under train_dir/models, so look for the latest one there.
checkpoint_dir = os.path.join(train_dir, 'models')
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is None:
    raise FileNotFoundError(
        'No checkpoint found in %r; run the training cell first.' % checkpoint_dir
    )

# Import into a fresh graph. Importing into the default graph (which still
# holds the training graph) would append to the existing 'x'/'yt'/'accuracy'
# collections, so index [0] would return the original tensors, whose
# variables are uninitialised in a new session. Distinct local names also
# avoid shadowing sess/x/yt/accuracy/saver, which later cells still use.
restored_graph = tf.Graph()
with restored_graph.as_default(), tf.Session(graph=restored_graph) as restored_sess:
    restorer = tf.train.import_meta_graph(latest_checkpoint + '.meta')
    restorer.restore(restored_sess, latest_checkpoint)

    restored_x = tf.get_collection('x')[0]
    restored_yt = tf.get_collection('yt')[0]
    restored_accuracy = tf.get_collection('accuracy')[0]

    test_accuracy = restored_sess.run(
        restored_accuracy,
        feed_dict={restored_x: mnist.test.images, restored_yt: mnist.test.labels},
    )

print(test_accuracy)

In [None]:
# Grid of the first test digits, each annotated "label=prediction"
# (black) or "label≠prediction" (red).
num_rows, num_cols = 10, 10
num_images = num_rows * num_cols
fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))

# One batched forward pass instead of a separate sess.run per image.
test_images = mnist.test.images[:num_images]
true_labels = mnist.test.labels[:num_images].argmax(1)
predicted_labels = sess.run(y, feed_dict={x: test_images}).argmax(1)

for ax, image, label, prediction in zip(axes.flat, test_images, true_labels, predicted_labels):
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.imshow(image.reshape((28, 28)), cmap='Greys', interpolation='none')

    correct = label == prediction
    ax.text(
        0,
        0,
        '%i%s%i' % (label, '=' if correct else '≠', prediction),
        bbox={'facecolor': 'white', 'pad': 5},
        fontdict={'size': 14, 'weight': 'bold', 'color': 'black' if correct else 'red'},
    )

plt.show()