In [90]:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

In [91]:
# Load MNIST from the local 'MNIST_data' directory; one_hot=True asks the
# loader for one-hot encoded labels.
mnist = input_data.read_data_sets(
    'MNIST_data',
    one_hot=True,
)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


## Convolutional network (conv)

In [92]:
def weight_var(shape):
    """Create a weight variable of the given shape.

    The initial value is drawn from a truncated normal distribution with a
    standard deviation of 0.1.

    Args:
        shape: Shape of the variable, forwarded to ``tf.truncated_normal``.

    Returns:
        A ``tf.Variable`` wrapping the random initial value.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

def bias_var(shape):
    """Create a bias variable of the given shape, with every entry set to 0.1.

    Args:
        shape: Shape of the variable, forwarded to ``tf.constant``.

    Returns:
        A ``tf.Variable`` wrapping the constant initial value.
    """
    return tf.Variable(tf.constant(0.1, shape=shape))

def conv2d(x, W, strides=[1, 1, 1, 1]):
    """Apply a 2-D convolution of ``W`` over ``x`` with 'SAME' padding.

    Args:
        x: Input tensor, passed straight to ``tf.nn.conv2d``.
        W: Filter tensor, passed straight to ``tf.nn.conv2d``.
        strides: Stride specification; defaults to ``[1, 1, 1, 1]``. The
            default list is only read here, never modified.

    Returns:
        The result of ``tf.nn.conv2d``.
    """
    convolved = tf.nn.conv2d(x, W, strides=strides, padding='SAME')
    return convolved

def max_pool_2x2(x):
    """Max-pool ``x`` with a 2x2 window, stride 2 and 'SAME' padding.

    Args:
        x: Input tensor, passed straight to ``tf.nn.max_pool``.

    Returns:
        The result of ``tf.nn.max_pool``.
    """
    window = [1, 2, 2, 1]
    step = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=step, padding='SAME')

In [108]:
# Start from an empty default graph so that re-running this cell does not
# stack a second copy of the model on top of the previous one.
tf.reset_default_graph()

# build computation graph
# Inputs: flattened 784-value images, 10-wide label vectors, and a boolean
# switch that is fed to the dropout layer below.
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
is_training = tf.placeholder(tf.bool)

# Restore the 2-D image layout: (batch, 28, 28, 1).
x_img = tf.reshape(x, [-1, 28, 28, 1])

# Filter size shared by both convolution layers.
kernel_size = [19, 19]

# First convolution: 32 filters, ReLU. The layer is given the explicit name
# 'conv1' so its variables can be looked up by that scope further down.
with tf.name_scope('conv1'):
    conv1 = tf.layers.conv2d(
        inputs=x_img, filters=32, kernel_size=kernel_size,
        padding='same', activation=tf.nn.relu,
        name='conv1'
    )

# 2x2 max pooling with stride 2.
pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=(2, 2), strides=2)

# save images
# Take the first trainable variable under 'conv1' whose name contains
# 'kernel', i.e. the first layer's filter weights.
conv1_filters = [var for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'conv1')
           if 'kernel' in var.name][0]
# The transpose moves the last axis to the front. Assuming the usual
# (height, width, in_channels, filters) kernel layout, that yields one
# single-channel image per filter; max_outputs=32 matches filters=32 above.
tf.summary.image('conv1_weights', tf.transpose(conv1_filters, (3, 0, 1, 2)), max_outputs=32)


# Second convolution: 64 filters, same kernel size, followed by another
# 2x2 / stride-2 max pool.
conv2 = tf.layers.conv2d(
    inputs=pool1, filters=64, kernel_size=kernel_size,
    padding='same', activation=tf.nn.relu
)
pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=(2, 2), strides=2)

# Flatten for the dense layers. The two stride-2 pools take 28 -> 14 -> 7,
# and the second convolution has 64 filters, hence 7*7*64.
pool2flat = tf.reshape(pool2, [-1, 7*7*64])
dense = tf.layers.dense(inputs=pool2flat, units=1024, activation=tf.nn.relu)
# Dropout is only active when is_training is fed as True.
dropout = tf.layers.dropout(inputs=dense, rate=0.5, training=is_training)

# Output layer: 10 units with no activation passed, i.e. raw logits.
y = tf.layers.dense(inputs=dropout, units=10)

# Loss: mean softmax cross-entropy between the labels and the logits.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
# One Adam update step with learning rate 1e-4.
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
# Accuracy: fraction of rows where the arg-max of the logits equals the
# arg-max of the label vector.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Scalar summaries for TensorBoard. merge_all() bundles every summary defined
# above, including the conv1 filter images, into a single op.
tf.summary.scalar('cross_entropy', cross_entropy)
tf.summary.scalar('accuracy', accuracy)
merged = tf.summary.merge_all()

In [109]:
%%time
# Train the network and log TensorBoard summaries.
#
# Per-step schedule (i is the step index):
#   * i % 10 == 0   -> no training; evaluate `merged`/`accuracy` on the full
#                      test set and write the summary to the test writer.
#   * i % 100 == 99 -> training step with a full execution trace; the trace
#                      (run metadata) and the summary go to the train writer.
#   * otherwise     -> plain training step; summary goes to the train writer.
summaries_dir = 'tb/run6-19'
batch_size = 25
num_steps = 2000

# Request a full trace on the steps whose run metadata is recorded. Without
# passing these options and a RunMetadata object to `sesh.run`, there is
# nothing valid to hand to `add_run_metadata`.
# NOTE(review): full tracing on GPU may need the CUPTI library available.
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)

with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sesh:
    train_writer = tf.summary.FileWriter(summaries_dir + '/train', sesh.graph)
    test_writer = tf.summary.FileWriter(summaries_dir + '/test')
    try:
        sesh.run(tf.global_variables_initializer())

        test_xs, test_ys = mnist.test.images, mnist.test.labels

        for i in range(num_steps):
            # A batch is drawn on every step, including evaluation steps where
            # it goes unused; this keeps the sequence of training batches the
            # same as before.
            batch = mnist.train.next_batch(batch_size)
            train_feed = {x: batch[0], y_: batch[1], is_training: True}

            if i % 10 == 0:  # record test set accuracy
                summary, acc = sesh.run(
                    [merged, accuracy],
                    feed_dict={x: test_xs, y_: test_ys, is_training: False})
                test_writer.add_summary(summary, i)
            elif i % 100 == 99:  # record train set accuracy plus run metadata
                # Fresh container per traced step so each 'step%d' tag gets
                # the trace of its own run.
                run_metadata = tf.RunMetadata()
                summary, _ = sesh.run([merged, train_step],
                                      feed_dict=train_feed,
                                      options=run_options,
                                      run_metadata=run_metadata)
                train_writer.add_run_metadata(run_metadata, 'step%d' % i)
                train_writer.add_summary(summary, i)
                print('Adding run metadata for', i)
            else:  # Record a summary
                summary, _ = sesh.run([merged, train_step], feed_dict=train_feed)
                train_writer.add_summary(summary, i)
    finally:
        # Flush and close the event files even if training is interrupted.
        test_writer.close()
        train_writer.close()


Adding run metadata for 99
Adding run metadata for 199
Adding run metadata for 299
Adding run metadata for 399
Adding run metadata for 499
Adding run metadata for 599
Adding run metadata for 699
Adding run metadata for 799
Adding run metadata for 899
Adding run metadata for 999
Adding run metadata for 1099
Adding run metadata for 1199
Adding run metadata for 1299
Adding run metadata for 1399
Adding run metadata for 1499
Adding run metadata for 1599
Adding run metadata for 1699
Adding run metadata for 1799
Adding run metadata for 1899
Adding run metadata for 1999
CPU times: user 2min 32s, sys: 15.4 s, total: 2min 47s
Wall time: 2min 32s
