Train the model in keras first to note the accuracy values, compare these with those obtained by training the same model in tensorflow. This is to ensure that there are no implementation errors.

In [91]:
import tensorflow as tf

In [95]:
# Load the mnist dataset for keras
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Normalize the pixel values
train_images = train_images.reshape((60000, 28, 28, 1))
train_images = train_images.astype('float32') / 255
test_images = test_images.reshape((10000, 28, 28, 1))
test_images = test_images.astype('float32') / 255

# Prepare the labels
train_labels = tf.keras.utils.to_categorical(train_labels)
test_labels = tf.keras.utils.to_categorical(test_labels)

In [96]:
# Design the network architecture
# conv + maxpool + conv + maxpool + dense + softmax
from tensorflow.python.keras.layers import Input, Dense, Conv2D, MaxPooling2D, Flatten
from tensorflow.python.keras.models import Model

inputs = Input(shape=(28, 28, 1))
x = Conv2D(8, (3, 3), activation='relu')(inputs)
x = MaxPooling2D((2, 2))(x)
x = Conv2D(8, (3, 3), activation='relu')(x)
x = MaxPooling2D((2, 2))(x)
x = Flatten()(x)
x = Dense(16, activation='relu')(x)
outputs = Dense(10, activation='softmax')(x)

model = Model(inputs, outputs)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=1, batch_size=128)

test_loss, test_acc = model.evaluate(test_images, test_labels)
print('\nTest set accuracy: ', test_acc)

Epoch 1/1

Test set accuracy:  0.926


In [88]:
# Load the mnist dataset
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [18]:
# defines
def weight_variable(shape):
    # truncated_normal so that weights are not too far away from 0.0.
    initial = tf.truncated_normal( shape=shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    # small positive bias value so that we dont end with a lot of dead neurons using ReLU
    return tf.Variable(tf.constant(0.1, shape=shape))

In [61]:
# Design the network architecture
# conv + maxpool + conv + maxpool + Dense + Softmax
from tensorflow.python.keras.layers import Input, Dense, Conv2D, MaxPooling2D, Flatten
from tensorflow.python.keras.models import Model

inputs = tf.placeholder(tf.float32, [None, 784])
labels = tf.placeholder(tf.float64, [None, 10])
labels = tf.cast(labels, tf.float32)

# Use the keras funcional API to make the syntax simpler
train_images = tf.reshape(inputs, [-1,28,28,1])
x = Conv2D(8, (3, 3), activation='relu')(train_images)
x = MaxPooling2D((2, 2))(x)
x = Conv2D(8, (3, 3), activation='relu')(x)
x = MaxPooling2D((2, 2))(x)
x = Flatten()(x)
x = Dense(16, activation='relu')(x)
# outputs = Dense(10, activation='softmax')(x)
Wout = weight_variable([16, 10])
biasOut = bias_variable([10])
logits = tf.matmul(x, Wout) + biasOut
outputs = tf.nn.softmax(logits)

In [76]:
# Define loss
from tensorflow.python.keras.losses import categorical_crossentropy
cross_entropy = tf.reduce_mean(categorical_crossentropy(labels, outputs))

In [79]:
# Training 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
sess = tf.Session()
with sess.as_default():
    init_var = tf.global_variables_initializer()
    init_var.run()
    # 500 steps, little more than 1 epoch of training
    for i in range(500):
        batch = mnist.train.next_batch(128)
        train_step.run({inputs:batch[0], labels:batch[1]})

In [80]:
from tensorflow.python.keras.metrics import categorical_accuracy as accuracy

acc_value = tf.reduce_mean(accuracy(labels, outputs))
print(acc_value)
with sess.as_default():
    print(acc_value.eval(feed_dict={inputs: mnist.test.images[:500],
                                    labels: mnist.test.labels[:500]}))

Tensor("Mean_28:0", shape=(), dtype=float32)
0.52
