In [4]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST (downloaded into ./MNIST_data on first run). one_hot=True gives
# each label as a 10-element vector with a 1 at the digit's index, matching the
# y_expected placeholder below.
mnist_data = input_data.read_data_sets('MNIST_data', one_hot=True)

# Tunable hyper-parameter for the optimizer below.
LEARNING_RATE = 0.01

# Make this cell safe to re-run in the same kernel. Without a reset, a second
# run appends another copy of the model whose nodes get suffixed names
# (x_input_1, W_1, y_actual_1, ...). The freeze step looks nodes up by their
# plain names ('x_input', 'y_actual'), so it would pick up the first copy -- which
# the new session re-initialises to zeros and never trains -- instead of the one
# just trained. The InteractiveSession from a previous run is closed first
# because it is still bound to the old graph.
if 'session' in globals():
    session.close()
tf.reset_default_graph()

# Linear model: y = xW + b, with one row of x per image.
# Input: a batch of flattened images, 784 pixel values each. The leading None
# lets any number of images be fed at once.
x_input = tf.placeholder(dtype=tf.float32, shape=[None, 784], name='x_input')
# Weights mapping the 784 pixel values to 10 class scores, initialised to zero.
W = tf.Variable(initial_value=tf.zeros(shape=[784, 10]), name='W')
# One bias per class, added to x_input @ W.
b = tf.Variable(initial_value=tf.zeros(shape=[10]), name='b')
# Model output: 10 unnormalised class scores (logits) per image -- not
# probabilities; softmax is applied inside the loss below. 'y_actual' is the
# output node name used by the freeze step, so keep it stable.
y_actual = tf.add(x=tf.matmul(a=x_input, b=W, name='matmul'), y=b, name='y_actual')
# Correct answers as one-hot vectors, fed in during training and evaluation.
y_expected = tf.placeholder(dtype=tf.float32, shape=[None, 10], name='y_expected')

# Mean softmax cross-entropy over the batch. The op expects raw logits and
# applies softmax itself, which is why y_actual is passed in unnormalised.
cross_entropy_loss = tf.reduce_mean(
    input_tensor=tf.nn.softmax_cross_entropy_with_logits(labels=y_expected, logits=y_actual),
    name='cross_entropy_loss')
# Plain gradient descent on the loss with a fixed learning rate.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=LEARNING_RATE, name='optimizer')
train_step = optimizer.minimize(loss=cross_entropy_loss, name='train_step')

# Saver covering every variable in the graph (here W and b); used at the end
# of this section to write the trained values to a checkpoint.
saver = tf.train.Saver()

# Training schedule.
NUM_STEPS = 1000   # number of gradient-descent updates (not epochs)
BATCH_SIZE = 100   # images per update

# An InteractiveSession installs itself as the default session, so
# `train_step.run()` and `Tensor.eval()` work without passing a session.
session = tf.InteractiveSession()
session.run(tf.global_variables_initializer())

# Export the graph structure (not the trained values) for the freeze step.
# NOTE: as_text=False writes a *binary* GraphDef even though the file name ends
# in '.pbtxt'; the freeze step reads it with input_binary=True. Keep the two in
# sync if either one changes.
tf.train.write_graph(graph_or_graph_def=session.graph_def,
                     logdir='.',
                     name='mnist_model.pbtxt',
                     as_text=False)

# Train: each step takes the next BATCH_SIZE images and labels from the
# training split and applies one gradient-descent update. NUM_STEPS counts
# mini-batch updates, not full passes (epochs) over the training set.
for _ in range(NUM_STEPS):
    batch_images, batch_labels = mnist_data.train.next_batch(BATCH_SIZE)
    train_step.run(feed_dict={x_input: batch_images, y_expected: batch_labels})

# Persist the trained W and b (files prefixed 'mnist_model.ckpt') for the freeze step.
saver.save(sess=session,
           save_path='mnist_model.ckpt')

# Evaluate on the test split. A prediction counts as correct when the index of
# the largest logit equals the index of the 1 in the one-hot label; accuracy is
# the mean of those 0/1 outcomes.
predicted_digit = tf.argmax(y_actual, 1)
expected_digit = tf.argmax(y_expected, 1)
correct_prediction = tf.equal(x=predicted_digit, y=expected_digit)
accuracy = tf.reduce_mean(tf.cast(x=correct_prediction, dtype=tf.float32))
test_feed = {x_input: mnist_data.test.images, y_expected: mnist_data.test.labels}
print(session.run(accuracy, feed_dict=test_feed))

# Run the model on the first test image alone (a batch of one); the printed
# row holds the raw, unnormalised score for each of the 10 classes.
first_image_batch = mnist_data.test.images[:1]
print(y_actual.eval(feed_dict={x_input: first_image_batch}))


Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting MNIST_data\train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
0.87
[[-0.2917421  -2.3547919  -0.61536527  0.112896   -0.2517305  -0.6600764
  -1.9688445   4.894759   -0.41793722  1.5528336 ]]


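### Freeze Graph

Combine the exported graph (`mnist_model.pbtxt`) and the trained checkpoint (`mnist_model.ckpt`) into one self-contained `frozen_mnist_model.pb`, in which the variables `W` and `b` become constants. Then strip training-only nodes to produce `optimized_frozen_mnist_model.pb` for inference.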

In [2]:
# Import tf here too so this cell can be re-run on its own (it only reads the
# files written by the training cell).
import tensorflow as tf
from tensorflow.python.tools import freeze_graph, optimize_for_inference_lib

# File and node names shared with the training cell; they must match what it wrote.
GRAPH_FILE = 'mnist_model.pbtxt'   # binary GraphDef despite the .pbtxt suffix
CHECKPOINT_PREFIX = 'mnist_model.ckpt'
FROZEN_GRAPH_FILE = 'frozen_mnist_model.pb'
OPTIMIZED_GRAPH_FILE = 'optimized_frozen_mnist_model.pb'
INPUT_NODE = 'x_input'
OUTPUT_NODE = 'y_actual'

# Merge the graph structure with the checkpointed variable values, turning the
# variables into constants. Run inside a fresh, empty graph so the result does
# not depend on whatever the training cell left in this kernel's default graph.
with tf.Graph().as_default():
    freeze_graph.freeze_graph(input_graph=GRAPH_FILE,
                              input_saver='',
                              input_binary=True,  # training cell writes with as_text=False
                              input_checkpoint=CHECKPOINT_PREFIX,
                              output_node_names=OUTPUT_NODE,
                              restore_op_name='save/restore_all',
                              filename_tensor_name='save/Const:0',
                              output_graph=FROZEN_GRAPH_FILE,
                              clear_devices=True,
                              initializer_nodes='')

# Load the frozen graph back in as a GraphDef protobuf.
input_graph_def = tf.GraphDef()
with tf.gfile.GFile(FROZEN_GRAPH_FILE, 'rb') as f:
    input_graph_def.ParseFromString(f.read())

# Prune the graph for inference, keeping x_input (float32) as the input and
# y_actual as the output.
output_graph_def = optimize_for_inference_lib.optimize_for_inference(
    input_graph_def=input_graph_def,
    input_node_names=[INPUT_NODE],
    output_node_names=[OUTPUT_NODE],
    placeholder_type_enum=tf.float32.as_datatype_enum)

# Write the serialized bytes in binary mode. The `with` block closes the handle
# so everything is flushed to disk. GFile replaces the deprecated FastGFile.
with tf.gfile.GFile(OPTIMIZED_GRAPH_FILE, 'wb') as f:
    f.write(output_graph_def.SerializeToString())


Instructions for updating:
Use tf.gfile.GFile.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from mnist_model.ckpt
Instructions for updating:
Use tf.compat.v1.graph_util.convert_variables_to_constants
Instructions for updating:
Use tf.compat.v1.graph_util.extract_sub_graph
INFO:tensorflow:Froze 2 variables.
INFO:tensorflow:Converted 2 variables to const ops.
Instructions for updating:
Use tf.compat.v1.graph_util.remove_training_nodes
