# Back to the basics

Let's train a simple CNN using tensorflow. Note that quantization aware training should be used to ensure that the model will perform as intended after the quantization.

In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow import InteractiveSession, ConfigProto
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
tf.logging.set_verbosity(tf.logging.ERROR)

First, we have to define the model:

In [2]:
def cnn_model(x, training=False):
    x = tf.reshape(x, shape=[-1, 28, 28, 1])
    
    x = tf.layers.conv2d(x, 32, 3, activation=tf.nn.relu)
    x = tf.layers.max_pooling2d(x, 2, 2)

    x = tf.layers.conv2d(x, 64, 3, activation=tf.nn.relu)
    x = tf.layers.max_pooling2d(x, 2, 2)

    x = tf.layers.dropout(x, rate=0.5, training=training)
    x = tf.contrib.layers.flatten(x)
    x = tf.layers.dense(x, 1024, activation=tf.nn.relu)
    x = tf.layers.dropout(x, rate=0.5, training=training)
    x = tf.layers.dense(x, 10)
    return x

Let's setup the graph!

In [3]:
sess = tf.Session()

x = tf.placeholder(tf.float32, shape=[None, 784], name='input')
y = tf.placeholder(tf.float32, shape=[None, 10], name='label')

logits = cnn_model(x, True)
y_pred = tf.nn.softmax(logits, name='prob')
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=y_pred))

# Add fake quantization nodes
tf.contrib.quantize.create_training_graph(quant_delay=2000)

# Setup the optimizer
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)
train_step_fine = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy)

In [4]:
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_pred,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

In [5]:
sess.run(tf.global_variables_initializer())

In [6]:
mnist = input_data.read_data_sets('MNIST-data', one_hot=True)

for i in range(10 * 60000//128):
    batch = mnist.train.next_batch(128)
    sess.run(train_step, feed_dict={x: batch[0], y: batch[1]})
    
    if (i + 1) % 500 == 0:
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        print("Test acc = ", acc)

for i in range(10 * 60000//128):
    batch = mnist.train.next_batch(128)
    sess.run(train_step_fine, feed_dict={x: batch[0], y: batch[1]})
    
    if (i + 1) % 500 == 0:
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        print("Test acc = ", acc)
        

Extracting MNIST-data/train-images-idx3-ubyte.gz
Extracting MNIST-data/train-labels-idx1-ubyte.gz
Extracting MNIST-data/t10k-images-idx3-ubyte.gz
Extracting MNIST-data/t10k-labels-idx1-ubyte.gz
Test acc =  0.9555
Test acc =  0.9725
Test acc =  0.9718
Test acc =  0.9792
Test acc =  0.978
Test acc =  0.9795
Test acc =  0.9819
Test acc =  0.9835
Test acc =  0.9797
Test acc =  0.9853
Test acc =  0.9858
Test acc =  0.9875
Test acc =  0.9871
Test acc =  0.9855
Test acc =  0.988
Test acc =  0.988
Test acc =  0.9867
Test acc =  0.9881


Save the variables!

In [7]:
saver = tf.train.Saver()
saver.save(sess, 'data/trained_model.ckpt')

'data/trained_model.ckpt'