In [2]:
from __future__ import division, print_function, absolute_import

import tensorflow as tf

from tensorflow.python.framework import graph_util


# Load MNIST with one-hot encoded labels; the dataset files live under /tmp/data/.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# --- Training parameters ---
learning_rate = 0.001
num_steps = 20000
batch_size = 128
display_step = 100

# --- Network parameters ---
num_input = 784      # each image is flattened from 28*28 pixels
num_classes = 10     # one class per digit 0-9
dropout = 0.6        # keep probability for dropout during training

# --- Graph inputs ---
# X and keep_prob are given explicit names so they can be located by name in
# the frozen graph later on.
X = tf.placeholder(dtype=tf.float32, shape=[None, num_input], name="placeholder")
Y = tf.placeholder(dtype=tf.float32, shape=[None, num_classes])
keep_prob = tf.placeholder(dtype=tf.float32, name="keep_prob")  # dropout keep probability

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


In [3]:
def display_nodes(nodes):
    """Print a GraphDef node listing: index, name and op, then each input.

    Args:
        nodes: iterable of NodeDef-like objects exposing ``name``, ``op`` and
            ``input`` (for example ``graph_def.node``).
    """
    for node_idx, node in enumerate(nodes):
        print('%d %s %s' % (node_idx, node.name, node.op))
        # Plain loop rather than a list comprehension built only for its print
        # side effect; a separate index name avoids shadowing node_idx.
        for input_idx, input_name in enumerate(node.input):
            print(u'└─── %d ─ %s' % (input_idx, input_name))

In [4]:
# Create some wrappers for simplicity
def conv2d(x, W, b, strides=1):
    """Convolution + bias + ReLU with 'SAME' padding.

    Args:
        x: 4-D input tensor in NHWC layout.
        W: convolution kernel.
        b: bias added per output channel.
        strides: spatial stride applied to both height and width.

    Returns:
        ReLU-activated convolution output.
    """
    stride_spec = [1, strides, strides, 1]
    convolved = tf.nn.conv2d(x, W, strides=stride_spec, padding='SAME')
    return tf.nn.relu(tf.nn.bias_add(convolved, b))


def maxpool2d(x, k=2):
    """Max-pool with a k x k window and stride k ('SAME' padding)."""
    window = [1, k, k, 1]
    return tf.nn.max_pool(x, ksize=window, strides=[1, k, k, 1], padding='SAME')


# Create model
def conv_net(x, weights, biases, dropout):
    """Two conv/pool stages, one dropout-regularised dense layer, linear output.

    Args:
        x: batch of flattened 784-pixel images; reshaped to [-1, 28, 28, 1].
        weights: dict with kernels/matrices 'wc1', 'wc2', 'wd1', 'wout'.
        biases: dict with biases 'bc1', 'bc2', 'bd1', 'bout'.
        dropout: keep probability passed to ``tf.nn.dropout``.

    Returns:
        Unnormalised class scores (logits); the final op is named 'out'.

    Ops are created in a fixed order, so graph node names (e.g. 'Relu_2',
    'dropout/...') stay predictable for the later graph-surgery cell.
    """
    # Restore the image layout [batch, height, width, channel].
    net = tf.reshape(x, shape=[-1, 28, 28, 1])

    # Conv -> ReLU -> 2x max-pool, twice.
    for kernel_key, bias_key in (('wc1', 'bc1'), ('wc2', 'bc2')):
        net = maxpool2d(conv2d(net, weights[kernel_key], biases[bias_key]), k=2)

    # Flatten to the dense layer's input width and apply it with ReLU.
    flat_width = weights['wd1'].get_shape().as_list()[0]
    hidden = tf.reshape(net, [-1, flat_width])
    hidden = tf.nn.relu(tf.add(tf.matmul(hidden, weights['wd1']), biases['bd1']))
    hidden = tf.nn.dropout(hidden, dropout)

    # Linear class-score layer.
    return tf.add(tf.matmul(hidden, weights['wout']), biases['bout'], name='out')

# Parameter initializers: Glorot-uniform for all weights, zeros for all biases.
# Built once and shared; each get_variable call still creates its own init op.
glorot_init = tf.contrib.keras.initializers.glorot_uniform(seed=None)
zeros_init = tf.zeros_initializer()

# Store layers weight & bias. The variable names ('wc1', 'bc1', ...) become
# node names in the frozen graph, so keep them stable.
weights = {
    # 5x5 conv, 1 input channel, 32 output channels
    'wc1': tf.get_variable('wc1', [5, 5, 1, 32], initializer=glorot_init),
    # 5x5 conv, 32 input channels, 64 output channels
    'wc2': tf.get_variable('wc2', [5, 5, 32, 64], initializer=glorot_init),
    # fully connected: 7*7*64 inputs (28x28 image after two 2x poolings), 1024 outputs
    'wd1': tf.get_variable('wd1', [7 * 7 * 64, 1024], initializer=glorot_init),
    # 1024 inputs, num_classes outputs (class scores)
    'wout': tf.get_variable('wout', [1024, num_classes], initializer=glorot_init),
}

biases = {
    'bc1': tf.get_variable('bc1', [32], initializer=zeros_init),
    'bc2': tf.get_variable('bc2', [64], initializer=zeros_init),
    'bd1': tf.get_variable('bd1', [1024], initializer=zeros_init),
    'bout': tf.get_variable('bout', [num_classes], initializer=zeros_init),
}

# Construct model; 'output' is the node name used when freezing the graph.
logits = conv_net(X, weights, biases, keep_prob)
prediction = tf.nn.softmax(logits, name='output')

# Define loss and optimizer
loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    logits=logits, labels=Y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)

# Evaluate model: fraction of samples whose argmax prediction matches the label
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()

In [5]:
# Split MNIST's training set into train/validation (first 90% / last 10%) and
# shift pixels by PIXEL_OFFSET. NOTE(review): assumes read_data_sets' default
# float scaling to [0, 1], which makes this roughly zero-centred - confirm.
TRAIN_FRACTION = 0.9
PIXEL_OFFSET = 0.5

# Derive the split point from the data rather than a hard-coded 55000.
n_train = int(len(mnist.train.images) * TRAIN_FRACTION)

X_train = mnist.train.images[:n_train] - PIXEL_OFFSET
y_train = mnist.train.labels[:n_train]
X_val = mnist.train.images[n_train:] - PIXEL_OFFSET
y_val = mnist.train.labels[n_train:]
X_test = mnist.test.images - PIXEL_OFFSET
y_test = mnist.test.labels

In [6]:
# Train for up to `epochs` in-order passes over X_train, stopping as soon as
# the validation loss gets worse. Each time it improves, the graph is frozen
# (variables folded into constants) and test metrics are printed; the best
# frozen graph is written to model.pb at the end.
epochs = 10
output_names = ['output']   # node kept when freezing: the softmax prediction
best_val_loss = float("Inf")
constant_graph = None       # frozen GraphDef of the best epoch so far

with tf.Session() as sess:
    # Run the initializer
    sess.run(init)

    for epoch in range(epochs):
        # One pass over the training data; the final slice may be shorter
        # than batch_size (slicing past the end is safe).
        for start in range(0, len(X_train), batch_size):
            batch_x = X_train[start:start + batch_size]
            batch_y = y_train[start:start + batch_size]
            # Run optimization op (backprop) with the configured keep probability
            sess.run(train_op, feed_dict={X: batch_x, Y: batch_y, keep_prob: dropout})

        # Loss and accuracy on the last minibatch, dropout disabled
        loss, acc = sess.run([loss_op, accuracy], feed_dict={X: batch_x,
                                                             Y: batch_y,
                                                             keep_prob: 1.0})
        print("Epoch " + str(epoch) + ", Minibatch Loss= " + \
              "{:.4f}".format(loss) + ", Training Accuracy= " + \
              "{:.4f}".format(acc))

        val_loss, val_acc = sess.run([loss_op, accuracy], feed_dict={X: X_val,
                                                                     Y: y_val,
                                                                     keep_prob: 1.0})
        print("Epoch " + str(epoch) + ", Validation Loss= " + \
              "{:.4f}".format(val_loss) + ", Validation Accuracy= " + \
              "{:.4f}".format(val_acc))

        if val_loss < best_val_loss:
            # New best: freeze the current weights and report test metrics.
            best_val_loss = val_loss
            constant_graph = graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_names)
            test_loss, test_acc = sess.run([loss_op, accuracy], feed_dict={X: X_test,
                                                                           Y: y_test,
                                                                           keep_prob: 1.0})
            print("Test loss:", test_loss)
            print("Test acc", test_acc)
        elif val_loss > best_val_loss:
            # Validation loss got worse: early stop.
            break
        # Equal (or NaN) loss: keep training without re-freezing.

    print("Optimization Finished!")

    # Guard: with no improving epoch (e.g. a NaN loss) there is nothing to save.
    if constant_graph is None:
        print("No improving epoch; model.pb was not written.")
    else:
        with tf.gfile.GFile('model.pb', mode='wb') as f:
            f.write(constant_graph.SerializeToString())
    

Epoch 0, Minibatch Loss= 0.0151, Training Accuracy= 1.0000
Epoch 0, Minibatch Loss= 0.0513, Validation Accuracy= 0.9856
INFO:tensorflow:Froze 8 variables.
Converted 8 variables to const ops.
Test loss: 0.042307965
Test acc 0.9867
Epoch 1, Minibatch Loss= 0.0045, Training Accuracy= 1.0000
Epoch 1, Minibatch Loss= 0.0447, Validation Accuracy= 0.9873
INFO:tensorflow:Froze 8 variables.
Converted 8 variables to const ops.
Test loss: 0.032558862
Test acc 0.9894
Epoch 2, Minibatch Loss= 0.0056, Training Accuracy= 1.0000
Epoch 2, Minibatch Loss= 0.0369, Validation Accuracy= 0.9904
INFO:tensorflow:Froze 8 variables.
Converted 8 variables to const ops.
Test loss: 0.033148535
Test acc 0.989
Epoch 3, Minibatch Loss= 0.0014, Training Accuracy= 1.0000
Epoch 3, Minibatch Loss= 0.0424, Validation Accuracy= 0.9873
Optimization Finished!


In [7]:
def test_graph(graph_path, use_dropout, input_name='placeholder:0',
               output_name='output:0', num_samples=256, pixel_offset=0.5):
    """Load a frozen GraphDef and measure its accuracy on leading MNIST test images.

    Note: this resets the default graph, so tensors from the training graph
    (``X``, ``accuracy``, ...) are stale afterwards.

    Args:
        graph_path: path to a serialized GraphDef (e.g. 'model.pb').
        use_dropout: True if the graph still has the ``keep_prob`` placeholder;
            it is then fed 1.0 so dropout is disabled at inference.
        input_name: image-input tensor name. Defaults to the
            ``name="placeholder"`` given to ``X`` in this notebook; pass the
            actual name if the graph was built with a uniquified name
            (e.g. 'placeholder_1:0').
        output_name: class-probability tensor name (``prediction`` is created
            with ``name='output'``).
        num_samples: number of leading test images to evaluate.
        pixel_offset: value subtracted from the images, matching the
            ``- 0.5`` applied to the training data.

    Returns:
        Fraction (float) of evaluated images whose argmax prediction matches
        the argmax of the one-hot label.
    """
    import numpy as np

    tf.reset_default_graph()
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(graph_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

    # Apply the same preprocessing as training, otherwise inputs are shifted.
    images = mnist.test.images[:num_samples] - pixel_offset
    labels = mnist.test.labels[:num_samples]

    feed_dict = {input_name: images}
    if use_dropout:
        feed_dict['keep_prob:0'] = 1.0

    # Context manager closes the session instead of leaking it.
    with tf.Session() as sess:
        prediction_tensor = sess.graph.get_tensor_by_name(output_name)
        predictions = sess.run(prediction_tensor, feed_dict)

    # The notebook's global `accuracy` is a tensor of the discarded training
    # graph, not a function, so compute the metric directly with numpy.
    return float(np.mean(np.argmax(predictions, 1) == np.argmax(labels, 1)))

In [9]:
from tensorflow.core.framework import graph_pb2

In [13]:
# Load the frozen model written by the training cell and list its nodes.
# Reuses the `display_nodes` helper defined near the top of the notebook
# instead of redefining (and silently shadowing) it here.
graph = tf.GraphDef()
with tf.gfile.Open('./model.pb', 'rb') as f:
    graph.ParseFromString(f.read())

display_nodes(graph.node)

0 placeholder_1 Placeholder
1 keep_prob Placeholder
2 wc1 Const
3 wc1/read Identity
└─── 0 ─ wc1
4 wc2 Const
5 wc2/read Identity
└─── 0 ─ wc2
6 wd1 Const
7 wd1/read Identity
└─── 0 ─ wd1
8 wout Const
9 wout/read Identity
└─── 0 ─ wout
10 bc1 Const
11 bc1/read Identity
└─── 0 ─ bc1
12 bc2 Const
13 bc2/read Identity
└─── 0 ─ bc2
14 bd1 Const
15 bd1/read Identity
└─── 0 ─ bd1
16 bout Const
17 bout/read Identity
└─── 0 ─ bout
18 Reshape/shape Const
19 Reshape Reshape
└─── 0 ─ placeholder_1
└─── 1 ─ Reshape/shape
20 Conv2D Conv2D
└─── 0 ─ Reshape
└─── 1 ─ wc1/read
21 BiasAdd BiasAdd
└─── 0 ─ Conv2D
└─── 1 ─ bc1/read
22 Relu Relu
└─── 0 ─ BiasAdd
23 MaxPool MaxPool
└─── 0 ─ Relu
24 Conv2D_1 Conv2D
└─── 0 ─ MaxPool
└─── 1 ─ wc2/read
25 BiasAdd_1 BiasAdd
└─── 0 ─ Conv2D_1
└─── 1 ─ bc2/read
26 Relu_1 Relu
└─── 0 ─ BiasAdd_1
27 MaxPool_1 MaxPool
└─── 0 ─ Relu_1
28 Reshape_1/shape Const
29 Reshape_1 Reshape
└─── 0 ─ MaxPool_1
└─── 1 ─ Reshape_1/shape
30 MatMul MatMul
└─── 0 ─ Reshape_1
└─── 1 ─ w

In [14]:
# Strip the dropout subgraph from the frozen model so inference no longer needs
# a `keep_prob` feed. Nodes are matched by name, not list position: positions
# shift whenever the graph changes (e.g. an extra node from a re-run cell), so
# hardcoded indices would silently drop or rewire the wrong nodes.
# Note: `graph` (loaded in the previous cell) is modified in place.
DROPOUT_SCOPE = 'dropout/'   # name scope of the ops created by tf.nn.dropout
DROPOUT_INPUT = 'Relu_2'     # activation fed into tf.nn.dropout (fc1 ReLU)
KEEP_PROB_NODE = 'keep_prob'

# Keep everything except the dropout ops and the now-unused keep_prob placeholder.
nodes = [node for node in graph.node
         if not node.name.startswith(DROPOUT_SCOPE) and node.name != KEEP_PROB_NODE]
assert any(node.name == DROPOUT_INPUT for node in nodes), \
    'dropout input %r not found in graph' % DROPOUT_INPUT

# Rewire surviving consumers of the dropout output (MatMul_1) to the
# pre-dropout activation.
rewired = 0
for node in nodes:
    for input_idx, input_name in enumerate(list(node.input)):
        if input_name.startswith(DROPOUT_SCOPE):
            node.input[input_idx] = DROPOUT_INPUT
            rewired += 1
assert rewired > 0, 'no node consumed the dropout output; check DROPOUT_SCOPE'

# Save graph (binary mode: SerializeToString returns bytes)
output_graph = graph_pb2.GraphDef()
output_graph.node.extend(nodes)
with tf.gfile.GFile('./model_tf_without_dropout.pb', 'wb') as f:
    f.write(output_graph.SerializeToString())

In [8]:
# Scratch cell: draw one minibatch of `batch_size` examples from MNIST's training split.
x_new, y_new = mnist.train.next_batch(batch_size=batch_size)