# In [None]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.layers import flatten

# Training schedule: number of passes over the training set and batch size.
EPOCHS, BATCH_SIZE = 10, 128

# Weight initialisation: LeNet draws weights from a truncated normal
# distribution with this mean and standard deviation.
mu, sigma = 0, 0.1

# Step size passed to the Adam optimizer.
learning_rate = 0.001

def LeNet(input):
    """Build the LeNet-5 network and return its class logits.

    Args:
        input: float32 tensor of flattened 28x28 single-channel images,
            shape (batch, 784). The parameter name shadows the ``input``
            builtin but is kept so keyword callers keep working.

    Returns:
        Tensor of shape (batch, 10) holding unscaled class scores (logits).

    Weights are initialised from a truncated normal distribution with mean
    ``mu`` and standard deviation ``sigma`` (module-level constants); biases
    start at zero. Every call creates a fresh set of ``tf.Variable`` objects.
    """
    # NOTE(review): no normalization happens here; pixel values are assumed to
    # be pre-scaled by the data loader -- confirm when feeding other data.

    # (batch, 784) -> (batch, 28, 28, 1): conv2d expects a 4-D NHWC tensor.
    images = tf.reshape(input, (-1, 28, 28, 1))
    # Zero-pad 2 pixels on each side of height and width: 28x28 -> 32x32,
    # the input size the layer shapes below are designed for.
    images = tf.pad(images, [[0, 0], [2, 2], [2, 2], [0, 0]], mode="CONSTANT")

    # Conv layer 1: 5x5 kernels, stride 1, VALID padding: 32x32x1 -> 28x28x6.
    # (Stride 2 here would give 14x14 and break every later shape.)
    wc1 = tf.Variable(tf.truncated_normal(shape=[5, 5, 1, 6], mean=mu, stddev=sigma))
    bc1 = tf.Variable(tf.zeros(6))
    conv1 = tf.nn.conv2d(images, wc1, strides=[1, 1, 1, 1], padding="VALID")
    conv1 = tf.nn.relu(tf.add(conv1, bc1))
    # 2x2 max pooling, stride 2: 28x28x6 -> 14x14x6.
    conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                           padding="VALID")

    # Conv layer 2: 5x5 kernels, stride 1, VALID padding: 14x14x6 -> 10x10x16.
    wc2 = tf.Variable(tf.truncated_normal(shape=[5, 5, 6, 16], mean=mu, stddev=sigma))
    bc2 = tf.Variable(tf.zeros(16))
    conv2 = tf.nn.conv2d(conv1, wc2, strides=[1, 1, 1, 1], padding="VALID")
    conv2 = tf.nn.relu(tf.add(conv2, bc2))
    # 2x2 max pooling, stride 2: 10x10x16 -> 5x5x16.
    conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                           padding="VALID")

    # Flatten the 5x5x16 feature maps into 400-element vectors.
    flat = tf.reshape(conv2, [-1, 5 * 5 * 16])

    # Fully connected layer 1: 400 -> 120, ReLU.
    wd1 = tf.Variable(tf.truncated_normal(shape=[400, 120], mean=mu, stddev=sigma))
    bd1 = tf.Variable(tf.zeros(120))
    fc1 = tf.nn.relu(tf.add(tf.matmul(flat, wd1), bd1))

    # Fully connected layer 2: 120 -> 84, ReLU.
    wd2 = tf.Variable(tf.truncated_normal(shape=[120, 84], mean=mu, stddev=sigma))
    bd2 = tf.Variable(tf.zeros(84))
    fc2 = tf.nn.relu(tf.add(tf.matmul(fc1, wd2), bd2))

    # Output layer: 84 -> 10, no activation; raw logits are returned.
    wd3 = tf.Variable(tf.truncated_normal(shape=[84, 10], mean=mu, stddev=sigma))
    bd3 = tf.Variable(tf.zeros(10))
    logits = tf.add(tf.matmul(fc2, wd3), bd3)

    return logits
    

# Load MNIST from MNIST_data/ with one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Flattened 28x28 images and one-hot labels over the 10 digit classes.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

# Unscaled class scores produced by the network (kept under the name fc2).
fc2 = LeNet(x)

# Mean softmax cross-entropy between the true labels `y` and the logits.
# The op requires keyword arguments, and the labels must be supplied.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=fc2))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Fraction of examples whose highest-scoring class matches the label.
correct_prediction = tf.equal(tf.argmax(fc2, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

init = tf.global_variables_initializer()

def _evaluate(sess, dataset):
    """Return ``(mean_loss, mean_accuracy)`` of the model over *dataset*.

    Evaluates ``dataset.num_examples // BATCH_SIZE`` full batches drawn with
    ``dataset.next_batch``; up to ``BATCH_SIZE - 1`` trailing examples are not
    evaluated.

    Raises:
        ValueError: if the dataset holds fewer than ``BATCH_SIZE`` examples,
            since no full batch can be formed.
    """
    num_batches = dataset.num_examples // BATCH_SIZE
    if num_batches == 0:
        raise ValueError(
            "dataset has fewer than BATCH_SIZE={} examples".format(BATCH_SIZE))
    num_examples = num_batches * BATCH_SIZE

    # Accumulators start fresh on every call, so passes never leak into each other.
    total_loss = 0.0
    total_acc = 0.0
    for _ in range(num_batches):
        batch_x, batch_y = dataset.next_batch(BATCH_SIZE)
        # Fetch both metrics in a single graph execution.
        loss, acc = sess.run([cost, accuracy], feed_dict={x: batch_x, y: batch_y})
        # Weight per-batch means by batch size to get a per-example mean.
        total_loss += loss * batch_x.shape[0]
        total_acc += acc * batch_x.shape[0]
    return total_loss / num_examples, total_acc / num_examples


with tf.Session() as sess:
    sess.run(init)
    # Computed once up front so the evaluation pass cannot overwrite it.
    train_steps = mnist.train.num_examples // BATCH_SIZE

    for epoch_iter in range(EPOCHS):
        for _ in range(train_steps):
            batch_x, batch_y = mnist.train.next_batch(BATCH_SIZE)
            sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})

        val_loss, val_acc = _evaluate(sess, mnist.validation)
        print("EPOCH {} ...".format(epoch_iter + 1))
        print("Validation loss = {:.3f}".format(val_loss))
        print("Validation accuracy = {:.3f}".format(val_acc))
        print()

    # Evaluate on the test data
    test_loss, test_acc = _evaluate(sess, mnist.test)
    print("Test loss = {:.3f}".format(test_loss))
    print("Test accuracy = {:.3f}".format(test_acc))