In [1]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(".", one_hot=True, reshape=False)

import tensorflow as tf

# Parameters
learning_rate = 0.00001
epochs = 10
batch_size = 128

# Number of samples to calculate validation and accuracy
# Decrease this if you're running out of memory to calculate accuracy
test_valid_size = 256

# Network Parameters
n_classes = 10  # MNIST total classes (0-9 digits)
dropout = 0.75  # Dropout, probability to keep units

Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting .\train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting .\train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting .\t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting .\t10k-labels-idx1-ubyte.gz


In [2]:
# Store layers weight & bias
weights = {
    'wc1': tf.Variable(tf.random_normal([5, 5, 1, 32])),
    'wc2': tf.Variable(tf.random_normal([5, 5, 32, 64])),
    'wd1': tf.Variable(tf.random_normal([7*7*64, 1024])),
    'out': tf.Variable(tf.random_normal([1024, n_classes]))}

biases = {
    'bc1': tf.Variable(tf.random_normal([32])),
    'bc2': tf.Variable(tf.random_normal([64])),
    'bd1': tf.Variable(tf.random_normal([1024])),
    'out': tf.Variable(tf.random_normal([n_classes]))}

In [3]:
def conv2d(x, W, b, strides=1):
    x = tf.nn.conv2d(x, W, strides=[1, strides, strides, 1], padding='SAME')
    x = tf.nn.bias_add(x, b)
    return tf.nn.relu(x)

In [4]:
def maxpool2d(x, k=2):
    return tf.nn.max_pool(
        x,
        ksize=[1, k, k, 1],
        strides=[1, k, k, 1],
        padding='SAME')

In [5]:
def conv_net(x, weights, biases, dropout):
    # Layer 1 - 28*28*1 to 14*14*32
    conv1 = conv2d(x, weights['wc1'], biases['bc1'])
    conv1 = maxpool2d(conv1, k=2)

    # Layer 2 - 14*14*32 to 7*7*64
    conv2 = conv2d(conv1, weights['wc2'], biases['bc2'])
    conv2 = maxpool2d(conv2, k=2)

    # Fully connected layer - 7*7*64 to 1024
    fc1 = tf.reshape(conv2, [-1, weights['wd1'].get_shape().as_list()[0]])
    fc1 = tf.add(tf.matmul(fc1, weights['wd1']), biases['bd1'])
    fc1 = tf.nn.relu(fc1)
    fc1 = tf.nn.dropout(fc1, dropout)

    # Output Layer - class prediction - 1024 to 10
    out = tf.add(tf.matmul(fc1, weights['out']), biases['out'])
    return out

In [6]:
# tf Graph input
x = tf.placeholder(tf.float32, [None, 28, 28, 1])
y = tf.placeholder(tf.float32, [None, n_classes])
keep_prob = tf.placeholder(tf.float32)

# Model
logits = conv_net(x, weights, biases, keep_prob)

# Define loss and optimizer
cost = tf.reduce_mean(\
    tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)\
    .minimize(cost)

# Accuracy
correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf. global_variables_initializer()

# Launch the graph
with tf.Session() as sess:
    sess.run(init)

    for epoch in range(epochs):
        for batch in range(mnist.train.num_examples//batch_size):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            sess.run(optimizer, feed_dict={
                x: batch_x,
                y: batch_y,
                keep_prob: dropout})

            # Calculate batch loss and accuracy
            loss = sess.run(cost, feed_dict={
                x: batch_x,
                y: batch_y,
                keep_prob: 1.})
            valid_acc = sess.run(accuracy, feed_dict={
                x: mnist.validation.images[:test_valid_size],
                y: mnist.validation.labels[:test_valid_size],
                keep_prob: 1.})

            print('Epoch {:>2}, Batch {:>3} -'
                  'Loss: {:>10.4f} Validation Accuracy: {:.6f}'.format(
                epoch + 1,
                batch + 1,
                loss,
                valid_acc))

    # Calculate Test Accuracy
    test_acc = sess.run(accuracy, feed_dict={
        x: mnist.test.images[:test_valid_size],
        y: mnist.test.labels[:test_valid_size],
        keep_prob: 1.})
    print('Testing Accuracy: {}'.format(test_acc))

Epoch  1, Batch   1 -Loss: 31337.0879 Validation Accuracy: 0.199219
Epoch  1, Batch   2 -Loss: 29685.4238 Validation Accuracy: 0.203125
Epoch  1, Batch   3 -Loss: 29255.9512 Validation Accuracy: 0.203125
Epoch  1, Batch   4 -Loss: 22676.3281 Validation Accuracy: 0.207031
Epoch  1, Batch   5 -Loss: 23378.0254 Validation Accuracy: 0.207031
Epoch  1, Batch   6 -Loss: 19348.8730 Validation Accuracy: 0.238281
Epoch  1, Batch   7 -Loss: 21145.7812 Validation Accuracy: 0.246094
Epoch  1, Batch   8 -Loss: 20410.1934 Validation Accuracy: 0.265625
Epoch  1, Batch   9 -Loss: 19539.2969 Validation Accuracy: 0.261719
Epoch  1, Batch  10 -Loss: 14472.7295 Validation Accuracy: 0.281250
Epoch  1, Batch  11 -Loss: 14790.0957 Validation Accuracy: 0.285156
Epoch  1, Batch  12 -Loss: 17214.8223 Validation Accuracy: 0.285156
Epoch  1, Batch  13 -Loss: 15030.7910 Validation Accuracy: 0.285156
Epoch  1, Batch  14 -Loss: 17579.8203 Validation Accuracy: 0.300781
Epoch  1, Batch  15 -Loss: 14707.7656 Validation

Epoch  1, Batch 122 -Loss:  4515.8306 Validation Accuracy: 0.617188
Epoch  1, Batch 123 -Loss:  3643.0850 Validation Accuracy: 0.617188
Epoch  1, Batch 124 -Loss:  4120.0225 Validation Accuracy: 0.617188
Epoch  1, Batch 125 -Loss:  3528.3799 Validation Accuracy: 0.632812
Epoch  1, Batch 126 -Loss:  3493.2769 Validation Accuracy: 0.621094
Epoch  1, Batch 127 -Loss:  4215.2764 Validation Accuracy: 0.636719
Epoch  1, Batch 128 -Loss:  4277.0835 Validation Accuracy: 0.640625
Epoch  1, Batch 129 -Loss:  4330.4805 Validation Accuracy: 0.636719
Epoch  1, Batch 130 -Loss:  5903.3906 Validation Accuracy: 0.621094
Epoch  1, Batch 131 -Loss:  3332.1428 Validation Accuracy: 0.621094
Epoch  1, Batch 132 -Loss:  3835.9099 Validation Accuracy: 0.625000
Epoch  1, Batch 133 -Loss:  3655.6282 Validation Accuracy: 0.628906
Epoch  1, Batch 134 -Loss:  4326.1523 Validation Accuracy: 0.640625
Epoch  1, Batch 135 -Loss:  4258.9463 Validation Accuracy: 0.621094
Epoch  1, Batch 136 -Loss:  3775.3381 Validation

Epoch  1, Batch 243 -Loss:  2721.5142 Validation Accuracy: 0.703125
Epoch  1, Batch 244 -Loss:  2022.3473 Validation Accuracy: 0.699219
Epoch  1, Batch 245 -Loss:  2584.5840 Validation Accuracy: 0.703125
Epoch  1, Batch 246 -Loss:  2226.0771 Validation Accuracy: 0.699219
Epoch  1, Batch 247 -Loss:  2025.4241 Validation Accuracy: 0.695312
Epoch  1, Batch 248 -Loss:  2881.6658 Validation Accuracy: 0.695312
Epoch  1, Batch 249 -Loss:  1957.9277 Validation Accuracy: 0.695312
Epoch  1, Batch 250 -Loss:  1486.7352 Validation Accuracy: 0.695312
Epoch  1, Batch 251 -Loss:  2101.7688 Validation Accuracy: 0.695312
Epoch  1, Batch 252 -Loss:  2068.1060 Validation Accuracy: 0.691406
Epoch  1, Batch 253 -Loss:  1938.5415 Validation Accuracy: 0.695312
Epoch  1, Batch 254 -Loss:  2251.4626 Validation Accuracy: 0.695312
Epoch  1, Batch 255 -Loss:  1694.7935 Validation Accuracy: 0.699219
Epoch  1, Batch 256 -Loss:  1843.2319 Validation Accuracy: 0.695312
Epoch  1, Batch 257 -Loss:  1831.4111 Validation

Epoch  1, Batch 364 -Loss:  1475.3032 Validation Accuracy: 0.738281
Epoch  1, Batch 365 -Loss:  1792.0393 Validation Accuracy: 0.730469
Epoch  1, Batch 366 -Loss:  1722.6764 Validation Accuracy: 0.730469
Epoch  1, Batch 367 -Loss:  1693.8138 Validation Accuracy: 0.730469
Epoch  1, Batch 368 -Loss:  1681.5858 Validation Accuracy: 0.730469
Epoch  1, Batch 369 -Loss:  1770.2434 Validation Accuracy: 0.734375
Epoch  1, Batch 370 -Loss:  1695.0015 Validation Accuracy: 0.734375
Epoch  1, Batch 371 -Loss:  1732.3776 Validation Accuracy: 0.726562
Epoch  1, Batch 372 -Loss:  1156.2373 Validation Accuracy: 0.718750
Epoch  1, Batch 373 -Loss:  1042.3030 Validation Accuracy: 0.722656
Epoch  1, Batch 374 -Loss:  1880.3650 Validation Accuracy: 0.718750
Epoch  1, Batch 375 -Loss:  2386.8540 Validation Accuracy: 0.718750
Epoch  1, Batch 376 -Loss:  1354.9421 Validation Accuracy: 0.722656
Epoch  1, Batch 377 -Loss:  1748.0425 Validation Accuracy: 0.730469
Epoch  1, Batch 378 -Loss:   948.6681 Validation

Epoch  2, Batch  56 -Loss:  1189.1097 Validation Accuracy: 0.765625
Epoch  2, Batch  57 -Loss:  1042.1167 Validation Accuracy: 0.769531
Epoch  2, Batch  58 -Loss:  1480.3318 Validation Accuracy: 0.777344
Epoch  2, Batch  59 -Loss:  1557.0820 Validation Accuracy: 0.777344
Epoch  2, Batch  60 -Loss:   982.5082 Validation Accuracy: 0.777344
Epoch  2, Batch  61 -Loss:  1017.6465 Validation Accuracy: 0.777344
Epoch  2, Batch  62 -Loss:   924.3066 Validation Accuracy: 0.781250
Epoch  2, Batch  63 -Loss:  1542.3916 Validation Accuracy: 0.773438
Epoch  2, Batch  64 -Loss:   870.5031 Validation Accuracy: 0.773438
Epoch  2, Batch  65 -Loss:  1064.3992 Validation Accuracy: 0.769531
Epoch  2, Batch  66 -Loss:  1132.1431 Validation Accuracy: 0.781250
Epoch  2, Batch  67 -Loss:  1132.7651 Validation Accuracy: 0.781250
Epoch  2, Batch  68 -Loss:   831.9274 Validation Accuracy: 0.785156
Epoch  2, Batch  69 -Loss:   971.4529 Validation Accuracy: 0.777344
Epoch  2, Batch  70 -Loss:  1301.4866 Validation

Epoch  2, Batch 177 -Loss:  1171.1964 Validation Accuracy: 0.796875
Epoch  2, Batch 178 -Loss:   641.1357 Validation Accuracy: 0.792969
Epoch  2, Batch 179 -Loss:  1251.2698 Validation Accuracy: 0.796875
Epoch  2, Batch 180 -Loss:   664.3895 Validation Accuracy: 0.796875
Epoch  2, Batch 181 -Loss:  1024.8413 Validation Accuracy: 0.800781
Epoch  2, Batch 182 -Loss:  1322.5944 Validation Accuracy: 0.796875
Epoch  2, Batch 183 -Loss:  1375.9254 Validation Accuracy: 0.800781
Epoch  2, Batch 184 -Loss:  1078.8390 Validation Accuracy: 0.796875
Epoch  2, Batch 185 -Loss:   926.2111 Validation Accuracy: 0.792969
Epoch  2, Batch 186 -Loss:  1489.4587 Validation Accuracy: 0.796875
Epoch  2, Batch 187 -Loss:   779.2691 Validation Accuracy: 0.792969
Epoch  2, Batch 188 -Loss:  1235.0836 Validation Accuracy: 0.792969
Epoch  2, Batch 189 -Loss:   938.0582 Validation Accuracy: 0.792969
Epoch  2, Batch 190 -Loss:  1364.9949 Validation Accuracy: 0.804688
Epoch  2, Batch 191 -Loss:  1192.6144 Validation

Epoch  2, Batch 298 -Loss:  1071.8414 Validation Accuracy: 0.816406
Epoch  2, Batch 299 -Loss:  1361.9932 Validation Accuracy: 0.812500
Epoch  2, Batch 300 -Loss:  1189.6099 Validation Accuracy: 0.812500
Epoch  2, Batch 301 -Loss:   760.7672 Validation Accuracy: 0.812500
Epoch  2, Batch 302 -Loss:  1167.2301 Validation Accuracy: 0.816406
Epoch  2, Batch 303 -Loss:   901.1097 Validation Accuracy: 0.816406
Epoch  2, Batch 304 -Loss:   782.5875 Validation Accuracy: 0.816406
Epoch  2, Batch 305 -Loss:   619.2935 Validation Accuracy: 0.816406
Epoch  2, Batch 306 -Loss:  1336.6375 Validation Accuracy: 0.808594
Epoch  2, Batch 307 -Loss:   739.9867 Validation Accuracy: 0.808594
Epoch  2, Batch 308 -Loss:   863.9836 Validation Accuracy: 0.808594
Epoch  2, Batch 309 -Loss:   947.4402 Validation Accuracy: 0.812500
Epoch  2, Batch 310 -Loss:   646.4612 Validation Accuracy: 0.812500
Epoch  2, Batch 311 -Loss:   843.8563 Validation Accuracy: 0.808594
Epoch  2, Batch 312 -Loss:   768.6362 Validation

Epoch  2, Batch 419 -Loss:   879.0887 Validation Accuracy: 0.832031
Epoch  2, Batch 420 -Loss:   989.7128 Validation Accuracy: 0.828125
Epoch  2, Batch 421 -Loss:   826.2894 Validation Accuracy: 0.832031
Epoch  2, Batch 422 -Loss:  1314.4731 Validation Accuracy: 0.828125
Epoch  2, Batch 423 -Loss:   596.0858 Validation Accuracy: 0.832031
Epoch  2, Batch 424 -Loss:   964.3960 Validation Accuracy: 0.835938
Epoch  2, Batch 425 -Loss:  1079.3396 Validation Accuracy: 0.832031
Epoch  2, Batch 426 -Loss:   751.2534 Validation Accuracy: 0.832031
Epoch  2, Batch 427 -Loss:   806.1843 Validation Accuracy: 0.832031
Epoch  2, Batch 428 -Loss:   888.8420 Validation Accuracy: 0.828125
Epoch  2, Batch 429 -Loss:   693.4893 Validation Accuracy: 0.832031
Epoch  3, Batch   1 -Loss:  1122.1191 Validation Accuracy: 0.835938
Epoch  3, Batch   2 -Loss:   434.5948 Validation Accuracy: 0.832031
Epoch  3, Batch   3 -Loss:   636.6942 Validation Accuracy: 0.835938
Epoch  3, Batch   4 -Loss:  1198.0491 Validation

KeyboardInterrupt: 