In [None]:
# This script is based on:
# https://www.tensorflow.org/get_started/mnist/pros

import sys
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework import graph_util as gu
from tensorflow.python.framework.graph_util import remove_training_nodes
from tensorflow.tools.graph_transforms import TransformGraph

# Import training data

In [None]:
# Load MNIST from the MNIST_data/ directory with one-hot encoded labels;
# the splits are used below as mnist.train and mnist.test.
mnist = input_data.read_data_sets(train_dir="MNIST_data/", one_hot=True)

# Define the TensorFlow model

### Fully connected network with two hidden layers

In [None]:
def deepnn(x):
  """Build the fully connected classifier graph.

  Three dense layers are stacked: 784 -> 128 (ReLU) -> 64 (ReLU) -> 10.
  The last layer has no activation, so its output is the raw logits.

  Args:
    x: float tensor fed to the first layer; its last dimension must be 784
      to match the first weight matrix.

  Returns:
    The logits tensor, produced by the op named "prediction" inside the
    "OuputLayer" name scope.
  """
  def dense(inputs, fan_in, fan_out, scope, w_name, b_name, out_name, relu):
    # One affine layer (inputs @ W + b) in its own name scope, with an
    # optional ReLU created inside the same scope.
    with tf.name_scope(scope):
      weights = weight_variable([fan_in, fan_out], name=w_name)
      biases = bias_variable([fan_out], name=b_name)
      # "zscore" is simply the name given to the pre-activation sum.
      pre_activation = tf.add(tf.matmul(inputs, weights), biases, name=out_name)
      if relu:
        return tf.nn.relu(pre_activation)
      return pre_activation

  hidden1 = dense(x, 784, 128, "Layer1", 'W_fc1', 'b_fc1', "zscore", relu=True)
  hidden2 = dense(hidden1, 128, 64, "Layer2", 'W_fc2', 'b_fc2', "zscore",
                  relu=True)
  # NOTE: the "OuputLayer" spelling is kept deliberately. It becomes part of
  # the node names in the graph exported later, so renaming it would change them.
  return dense(hidden2, 64, 10, "OuputLayer", 'W_fc3', 'b_fc3', "prediction",
               relu=False)


def weight_variable(shape, name=None, stddev=0.1):
  """Create a weight variable initialised from a truncated normal.

  Args:
    shape: shape of the variable (e.g. [fan_in, fan_out]).
    name: name for the variable; when None TensorFlow picks a default.
    stddev: standard deviation of the truncated-normal initialiser
      (defaults to 0.1, the value previously hard-coded).

  Returns:
    A `tf.Variable` (trainable by default).
  """
  initial = tf.truncated_normal(shape, stddev=stddev)
  # `name` must be passed by keyword: tf.Variable's second positional
  # parameter is `trainable`, so the old positional call never named the
  # variable and fed the name string into `trainable` instead.
  return tf.Variable(initial, name=name)


def bias_variable(shape, name=None, init_value=0.1):
  """Create a bias variable filled with a constant.

  Args:
    shape: shape of the variable (e.g. [fan_out]).
    name: name for the variable; when None TensorFlow picks a default.
    init_value: constant every element starts at (defaults to 0.1, the
      value previously hard-coded). Pass a float to keep a float dtype.

  Returns:
    A `tf.Variable` (trainable by default).
  """
  initial = tf.constant(init_value, shape=shape)
  # `name` must be passed by keyword: tf.Variable's second positional
  # parameter is `trainable`, so the old positional call never named the
  # variable and fed the name string into `trainable` instead.
  return tf.Variable(initial, name=name)

### Specify inputs, outputs, and a cost function

In [None]:
# Create the model: input placeholder with 784 features per example.
x = tf.placeholder(tf.float32, shape=[None, 784], name="x")

# One-hot target labels (10 classes), fed for the loss and accuracy ops.
y_ = tf.placeholder(tf.float32, shape=[None, 10], name="y")

# Build the graph for the deep net; y_pred holds the raw logits.
y_pred = deepnn(x)

with tf.name_scope("Loss"):
    # Per-example softmax cross-entropy, then its mean over the batch.
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
        labels=y_, logits=y_pred)
    loss = tf.reduce_mean(cross_entropy, name="cross_entropy_loss")

# Adam with learning rate 1e-4 minimises the mean loss.
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
train_step = optimizer.minimize(loss, name="train_step")

with tf.name_scope("Prediction"):
    # Compare the index of the largest logit with the index of the label.
    predicted_class = tf.argmax(y_pred, axis=1, name='y_pred')
    true_class = tf.argmax(y_, axis=1)
    correct_prediction = tf.equal(predicted_class, true_class)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32),
                              name="accuracy")

# Configure the TensorFlow session

In [None]:
# Saver for the variables defined above; used later to write a checkpoint.
saver = tf.train.Saver()

# Op that assigns every variable its initial value.
init = tf.global_variables_initializer()

# One session shared by the following cells (closed in the last cell),
# starting from freshly initialised variables.
sess = tf.Session()
sess.run(init)

# Train the model

In [None]:
# 20000 Adam steps on mini-batches of 50 examples.
for step in range(20000):
  images, labels = mnist.train.next_batch(50)
  feed = {x: images, y_: labels}
  if step % 100 == 0:
    # Every 100 steps, report accuracy on the current mini-batch
    # before the update for that batch is applied.
    train_accuracy = sess.run(accuracy, feed_dict=feed)
    print('step %d, training accuracy %g' % (step, train_accuracy))
  sess.run(train_step, feed_dict=feed)

## What is the final test accuracy?

In [None]:
# Evaluate accuracy on the full held-out test split in a single run.
test_accuracy = sess.run(accuracy,
                         feed_dict={x: mnist.test.images,
                                    y_: mnist.test.labels})
print('test accuracy %g' % test_accuracy)

# Freeze the graph

In [None]:
# Saver.save can fail if the checkpoint's parent directory does not exist,
# so create it first (MakeDirs is a no-op when the directory is already there).
tf.gfile.MakeDirs("./my-model")
saver.save(sess, "./my-model/model.ckpt")

# Names of the result nodes to keep when freezing the graph in the next
# cell. "Prediction/y_pred" (the argmax) is an input of correct_prediction,
# so it is kept as well.
out_nodes = [y_pred.op.name, y_.op.name, cross_entropy.op.name,
             correct_prediction.op.name, accuracy.op.name]

### Freeze Constants

In [None]:
# Freeze: replace the variables with constants holding their current values
# in `sess`; out_nodes names the result nodes of the frozen graph.
sub_graph_def = gu.convert_variables_to_constants(
    sess=sess,
    input_graph_def=sess.graph_def,
    output_node_names=out_nodes)

### Remove unnecessary training nodes

In [None]:
# Prune nodes that are only needed for training; the result keeps the
# name sub_graph_def because the quantization cell reads it.
sub_graph_def = remove_training_nodes(input_graph=sub_graph_def)

### Quantize the graph

In [None]:
# Run the quantize_weights and quantize_nodes graph transforms over the
# frozen graph. No input nodes are declared; the argmax op
# "Prediction/y_pred" is the declared output.
transformed_graph_def = TransformGraph(
    input_graph_def=sub_graph_def,
    inputs=[],
    outputs=["Prediction/y_pred"],
    transforms=["quantize_weights", "quantize_nodes"])

### Save the graph to a .pb file

In [None]:
# Serialise the quantised graph as a binary protobuf in ./my-model,
# alongside the checkpoint.
graph_path = tf.train.write_graph(transformed_graph_def,
                                  logdir="./my-model",
                                  name="deep_mlp.pb",
                                  as_text=False)
print('written graph to: %s' % graph_path)

In [None]:
# Release the session opened during setup now that the graph has been exported.
sess.close()