In [59]:
import os
import os.path as path
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib
import preprocessing as pre

MODEL_NAME = 'hs_mnist_convnet'
NUM_STEPS = 6000
BATCH_SIZE = 10

def model_input(input_node_name, keep_prob_node_name):
    x = tf.placeholder(tf.float32, shape=[None, 28*28], name=input_node_name)
    keep_prob = tf.placeholder(tf.float32, name=keep_prob_node_name)
    y_ = tf.placeholder(tf.float32, shape=[None, 24])
    return x, keep_prob, y_


def build_model(x, keep_prob, y_, output_node_name):
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    # 28*28*1

    conv1 = tf.layers.conv2d(x_image, 64, 3, 1, 'same', activation=tf.nn.relu)
    # 28*28*64
    pool1 = tf.layers.max_pooling2d(conv1, 2, 2, 'same')
    # 14*14*64

    conv2 = tf.layers.conv2d(pool1, 128, 3, 1, 'same', activation=tf.nn.relu)
    # 14*14*128
    pool2 = tf.layers.max_pooling2d(conv2, 2, 2, 'same')
    # 7*7*128

    conv3 = tf.layers.conv2d(pool2, 256, 3, 1, 'same', activation=tf.nn.relu)
    # 7*7*256
    pool3 = tf.layers.max_pooling2d(conv3, 2, 2, 'same')
    # 4*4*256

    flatten = tf.reshape(pool3, [-1, 4*4*256])
    fc = tf.layers.dense(flatten, 1024, activation=tf.nn.relu)
    dropout = tf.nn.dropout(fc, keep_prob)
    logits = tf.layers.dense(dropout, 24)
    outputs = tf.nn.softmax(logits, name=output_node_name)

    # loss
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))

    # train step
    train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)

    # accuracy
    correct_prediction = tf.equal(tf.argmax(outputs, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    #Prediction node
    prediction = tf.argmax(outputs,1)
    
    tf.summary.scalar("loss", loss)
    tf.summary.scalar("accuracy", accuracy)
    merged_summary_op = tf.summary.merge_all()

    return train_step, loss, accuracy, merged_summary_op,outputs, prediction

def train(x, keep_prob, y_, train_step, loss, accuracy,
        merged_summary_op, saver,prediction):
    print("Training start...")

    #mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    hs_mnist=pre.load_split_scale_data()

    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init_op)

        tf.train.write_graph(sess.graph_def, 'out',
            MODEL_NAME + '.pbtxt', True)

        # op to write logs to Tensorboard
        summary_writer = tf.summary.FileWriter('logs/',
            graph=tf.get_default_graph())

        for step in range(NUM_STEPS):
            batch = hs_mnist.train.next_batch(BATCH_SIZE)#mnist.train.next_batch(15)
            if step % 100 == 0:
                train_accuracy = accuracy.eval(feed_dict={
                    x: batch[0], y_: batch[1], keep_prob: 1.0})
                print('step %d, training accuracy %f' % (step, train_accuracy))
            _, summary = sess.run([train_step, merged_summary_op],
                feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
            summary_writer.add_summary(summary, step)

        saver.save(sess, 'out/' + MODEL_NAME + '.chkp')

        test_accuracy = accuracy.eval(feed_dict={x: hs_mnist.test.images,
                                    y_: hs_mnist.test.labels,
                                    keep_prob: 1.0})
        print('test accuracy %g' % test_accuracy)

    print("training finished!")
    return hs_mnist.test,init_op
    
def export_model(input_node_names, output_node_name):
    freeze_graph.freeze_graph('out/' + MODEL_NAME + '.pbtxt', None, False,
        'out/' + MODEL_NAME + '.chkp', output_node_name, "save/restore_all",
        "save/Const:0", 'out/frozen_' + MODEL_NAME + '.pb', True, "")

    input_graph_def = tf.GraphDef()
    with tf.gfile.Open('out/frozen_' + MODEL_NAME + '.pb', "rb") as f:
        input_graph_def.ParseFromString(f.read())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
            input_graph_def, input_node_names, [output_node_name],
            tf.float32.as_datatype_enum)

    with tf.gfile.FastGFile('out/opt_' + MODEL_NAME + '.pb', "wb") as f:
        f.write(output_graph_def.SerializeToString())

    print("graph saved!")


In [60]:
if not path.exists('out'):
    os.mkdir('out')

input_node_name = 'input'
keep_prob_node_name = 'keep_prob'
output_node_name = 'output'

x, keep_prob, y_ = model_input(input_node_name, keep_prob_node_name)

train_step, loss, accuracy, merged_summary_op,outputs,prediction = build_model(x, keep_prob,y_, output_node_name)
saver = tf.train.Saver()

batch_test,init_op=train(x, keep_prob, y_, train_step, loss, accuracy,merged_summary_op, saver,prediction)

export_model([input_node_name, keep_prob_node_name], output_node_name)

Training start...
step 0, training accuracy 0.000000
step 100, training accuracy 0.100000
step 200, training accuracy 0.100000
step 300, training accuracy 0.100000
step 400, training accuracy 0.400000
step 500, training accuracy 0.300000
step 600, training accuracy 0.800000
step 700, training accuracy 0.500000
step 800, training accuracy 0.700000
step 900, training accuracy 0.800000
step 1000, training accuracy 0.600000
step 1100, training accuracy 0.700000
step 1200, training accuracy 0.700000
step 1300, training accuracy 0.700000
step 1400, training accuracy 1.000000
step 1500, training accuracy 1.000000
step 1600, training accuracy 0.800000
step 1700, training accuracy 0.800000
step 1800, training accuracy 0.800000
step 1900, training accuracy 0.900000
step 2000, training accuracy 1.000000
step 2100, training accuracy 0.900000
step 2200, training accuracy 1.000000
step 2300, training accuracy 0.900000
step 2400, training accuracy 0.900000
step 2500, training accuracy 0.900000
step 2

In [61]:
#tf.reset_default_graph()

In [70]:
with tf.Session() as sess:
        saver.restore(sess, "out/hs_mnist_convnet.chkp")
        print ("Model restored.")
        sess.run(init_op)#Pass a 28* 28 scaled images 1 channel
        print(prediction.eval(feed_dict={x: batch_test.images[100:200],keep_prob:1.0}))
        print(accuracy.eval(feed_dict={x: batch_test.images[100:200],y_:batch_test.labels[100:200],keep_prob:1.0}))

print(batch_test.labels[100:200])

INFO:tensorflow:Restoring parameters from out/hs_mnist_convnet.chkp
Model restored.
[22 22 22 22 13 22 22 13 22 22 13 22 13 22 22 22 13 22 22 22 22 22 22 22 22
 22 22 22 22 22 22 22 22 22 22 22 22 13 13 13 22 22 22 22 22 22 13 22 22 22
 22 13 13 13 22 22 22 13 22 13 13 13 13 22 13 22 22 22 13 22 13 22 22 22 13
 22 13 13 22 22 22 13 22 22 22 13 13 22 22 13 22 22 22 13 22 22 13 22 22 22]
0.09
[[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  1. ...,  0.  0.  0.]
 ..., 
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  1.  0. ...,  0.  0.  0.]]
