In [1]:
from __future__ import absolute_import, division, print_function

import tensorflow as tf
import numpy as np

In [2]:
num_classes = 10

learning_rate = 0.001
training_steps = 200
batch_size = 128
display_step = 10

conv1_filters = 32
conv2_filters = 64
fc1_units = 1024

In [3]:
from tensorflow.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)
x_train, x_test = x_train / 255., x_test / 255.

In [4]:
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)

In [5]:
def conv2d(x, W, b, strides=1):    
    x = tf.nn.conv2d(x, W, strides=[1, strides, strides, 1], padding='SAME')
    x = tf.nn.bias_add(x, b)
    return tf.nn.relu(x)

def maxpool2d(x, k=2):    
    return tf.nn.max_pool(x, ksize=[1, k, k, 1], strides=[1, k, k, 1], padding='SAME')

In [6]:
random_normal = tf.initializers.RandomNormal()

weights = {    
    'wc1': tf.Variable(random_normal([5, 5, 1, conv1_filters])),    
    'wc2': tf.Variable(random_normal([5, 5, conv1_filters, conv2_filters])),    
    'wd1': tf.Variable(random_normal([7*7*64, fc1_units])),    
    'out': tf.Variable(random_normal([fc1_units, num_classes]))
}

biases = {
    'bc1': tf.Variable(tf.zeros([conv1_filters])),
    'bc2': tf.Variable(tf.zeros([conv2_filters])),
    'bd1': tf.Variable(tf.zeros([fc1_units])),
    'out': tf.Variable(tf.zeros([num_classes]))
}

In [7]:
def conv_net(x):
        
    x = tf.reshape(x, [-1, 28, 28, 1])
    
    conv1 = conv2d(x, weights['wc1'], biases['bc1'])
        
    conv1 = maxpool2d(conv1, k=2)
    
    conv2 = conv2d(conv1, weights['wc2'], biases['bc2'])
        
    conv2 = maxpool2d(conv2, k=2)
    
    fc1 = tf.reshape(conv2, [-1, weights['wd1'].get_shape().as_list()[0]])
        
    fc1 = tf.add(tf.matmul(fc1, weights['wd1']), biases['bd1'])    
    fc1 = tf.nn.relu(fc1)
    
    out = tf.add(tf.matmul(fc1, weights['out']), biases['out'])    
    return tf.nn.softmax(out)

In [8]:
def cross_entropy(y_pred, y_true):    
    y_true = tf.one_hot(y_true, depth=num_classes)    
    y_pred = tf.clip_by_value(y_pred, 1e-9, 1.)    
    return tf.reduce_mean(-tf.reduce_sum(y_true * tf.math.log(y_pred)))

def accuracy(y_pred, y_true):    
    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))
    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32), axis=-1)

optimizer = tf.optimizers.Adam(learning_rate)

In [9]:
def run_optimization(x, y):    
    with tf.GradientTape() as g:
        pred = conv_net(x)
        loss = cross_entropy(pred, y)
            
    trainable_variables = list(weights.values()) + list(biases.values())
    
    gradients = g.gradient(loss, trainable_variables)
        
    optimizer.apply_gradients(zip(gradients, trainable_variables))

In [10]:
for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):    
    run_optimization(batch_x, batch_y)
    
    if step % display_step == 0:
        pred = conv_net(batch_x)
        loss = cross_entropy(pred, batch_y)
        acc = accuracy(pred, batch_y)
        print("step: %i, loss: %f, accuracy: %f" % (step, loss, acc))

step: 10, loss: 79.027191, accuracy: 0.796875
step: 20, loss: 38.818451, accuracy: 0.898438
step: 30, loss: 36.494530, accuracy: 0.906250
step: 40, loss: 22.048792, accuracy: 0.945312
step: 50, loss: 26.919022, accuracy: 0.921875
step: 60, loss: 13.365584, accuracy: 0.960938
step: 70, loss: 18.899092, accuracy: 0.968750
step: 80, loss: 10.188875, accuracy: 0.968750
step: 90, loss: 10.616709, accuracy: 0.968750
step: 100, loss: 12.370735, accuracy: 0.968750
step: 110, loss: 16.033121, accuracy: 0.976562
step: 120, loss: 11.403316, accuracy: 0.976562
step: 130, loss: 8.424355, accuracy: 0.976562
step: 140, loss: 10.148264, accuracy: 0.976562
step: 150, loss: 11.201101, accuracy: 0.960938
step: 160, loss: 6.522997, accuracy: 0.992188
step: 170, loss: 8.217388, accuracy: 0.976562
step: 180, loss: 6.852281, accuracy: 0.984375
step: 190, loss: 9.211034, accuracy: 0.976562
step: 200, loss: 5.986761, accuracy: 0.992188
