In [None]:
import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph

from utensor_cgen.utils import prepare_meta_graph

In [None]:
# Record the TensorFlow version: this notebook relies on TF 1.x graph-mode APIs
# (tf.placeholder, tf.Session, tf.train.*, tf.glorot_normal_initializer).
print(tf.__version__)

In [None]:
def get_conv_filter(width, height, in_channels, out_channels,
                    dtype=tf.float32, initializer=None, seed=None, name=None):
    """Create a trainable 4-D convolution filter variable.

    arguments
    =========
    - width: int, size of the first filter dimension
    - height: int, size of the second filter dimension
    - in_channels: int, number of input channels
    - out_channels: int, number of output channels
    - dtype: data type of the variable
    - initializer: callable invoked as ``initializer(shape=...)`` to build the
      initial value; defaults to Glorot-normal using ``seed`` and ``dtype``
    - seed: random seed for the default initializer (unused when an
      ``initializer`` is supplied)
    - name: optional name of the ``tf.Variable``

    returns
    =======
    A ``tf.Variable`` of shape ``[width, height, in_channels, out_channels]``.

    NOTE(review): ``tf.nn.conv2d`` interprets a filter as
    ``[filter_height, filter_width, in_channels, out_channels]``, so ``width``
    here actually becomes the filter's height. This only matters for
    non-square filters; all call sites visible here use square ones.
    """
    init_fn = initializer
    if init_fn is None:
        init_fn = tf.glorot_normal_initializer(seed=seed, dtype=dtype)
    initial_value = init_fn(shape=[width, height, in_channels, out_channels])
    return tf.Variable(initial_value, name=name, dtype=dtype)

In [None]:
def get_bias(shape, dtype=tf.float32, name=None, initializer=None, seed=None):
    """Create a trainable bias variable.

    arguments
    =========
    - shape: list of int, shape of the bias variable
    - dtype: data type of the variable
    - name: optional name of the ``tf.Variable``
    - initializer: callable invoked as ``initializer(shape=...)``; defaults to
      Glorot-normal using ``seed`` and ``dtype``
    - seed: random seed for the default initializer

    returns
    =======
    A ``tf.Variable``. The default initializer is random (not zeros), so the
    bias starts out non-zero.
    """
    init_fn = (tf.glorot_normal_initializer(seed=seed, dtype=dtype)
               if initializer is None else initializer)
    return tf.Variable(init_fn(shape=shape), name=name, dtype=dtype)

In [None]:
def conv_layer(in_fmap, w_shape, padding='SAME', stride=1, relu=True, name=None):
    """Convolution + bias add (+ optional ReLU) layer.

    arguments
    =========
    - in_fmap: input feature-map tensor, passed to ``tf.nn.conv2d`` with
      strides ``[1, stride, stride, 1]`` (presumably NHWC -- TODO confirm)
    - w_shape: 4-item sequence ``[width, height, in_channel, out_channel]``
      forwarded to ``get_conv_filter``; the filter variable gets this shape
    - padding: padding mode forwarded to ``tf.nn.conv2d`` ('SAME' or 'VALID')
    - stride: int, stride used for both spatial dimensions
    - relu: bool, apply ``tf.nn.relu`` after the bias add
    - name: name scope; defaults to a uniquified 'conv' scope

    returns
    =======
    The output tensor: the ReLU output (op name 'relu') when ``relu`` is
    True, otherwise the bias-add output (op name 'activation').
    """
    width, height, in_channel, out_channel = w_shape
    strides = [1, stride, stride, 1]
    # Filter and bias must share the input's dtype. Previously the filter
    # always used get_conv_filter's float32 default, so any non-float32 input
    # made conv2d fail on a dtype mismatch. base_dtype strips a possible
    # reference dtype; tf.as_dtype also accepts numpy dtypes.
    fmap_dtype = tf.as_dtype(in_fmap.dtype).base_dtype
    with tf.name_scope(name, 'conv'):
        w_filter = get_conv_filter(width, height, in_channel, out_channel,
                                   dtype=fmap_dtype)
        out_fmap = tf.nn.conv2d(in_fmap, w_filter,
                                padding=padding,
                                strides=strides,
                                name='feature_map')
        # One bias value per output channel (the filter's last dimension).
        bias = get_bias(w_filter.shape.as_list()[-1:],
                        dtype=fmap_dtype,
                        name='bias')
        act = tf.add(out_fmap, bias, name='activation')
        if relu:
            act = tf.nn.relu(act, name='relu')
    return act

In [None]:
def fc_layer(in_tensor, out_dim, act_func=None, initializer=None, name=None):
    """Fully connected layer: matmul with a weight matrix, then an activation.

    There is no bias term.

    arguments
    =========
    - in_tensor: input tensor (presumably 2-D ``[batch, features]``); its last
      dimension must be statically known because it sizes the weight matrix
    - out_dim: int, number of output units
    - act_func: activation callable accepting ``(x, name=...)``. ``None``
      means ``tf.nn.relu``, so a linear output requires passing e.g.
      ``tf.identity`` explicitly.
    - initializer: weight initializer invoked as
      ``initializer(shape=..., dtype=...)``; defaults to Glorot-normal in
      ``in_tensor``'s dtype
    - name: name scope; defaults to a uniquified 'fully_connect' scope

    returns
    =======
    The activation output tensor (op name 'activation').
    """
    dtype = in_tensor.dtype
    weight_init = initializer
    if weight_init is None:
        weight_init = tf.glorot_normal_initializer(dtype=dtype)
    activation = tf.nn.relu if act_func is None else act_func
    in_dim = in_tensor.shape.as_list()[-1]
    with tf.name_scope(name, 'fully_connect'):
        w_fc = tf.Variable(weight_init(shape=[in_dim, out_dim], dtype=dtype),
                           name='weight')
        logits = tf.matmul(in_tensor, w_fc, name='logit')
        return activation(logits, name='activation')

In [None]:
from functools import reduce

In [None]:
# Build the test CNN in its own graph:
# - three (conv, conv, 2x2/stride-2 max-pool) stages;
# - three fully connected layers with dropout between them;
# - an argmax prediction node and an Adam training op.
# NOTE: op and variable names depend on the exact construction order below.
# The exported .pb files and later cells refer to them by name.
graph = tf.Graph()
with graph.as_default():
    # 32x32x3 images with a dynamic batch dimension (presumably NHWC, which is
    # tf.nn.conv2d's default layout -- TODO confirm against the data source).
    tf_img_batch = tf.placeholder(tf.float32, 
                                  shape=[None, 32, 32, 3], 
                                  name='img_batch')
    # 10-class label distributions (presumably one-hot) for the softmax loss.
    tf_label_batch = tf.placeholder(tf.float32,
                                   shape=[None, 10],
                                   name='label_batch')
    # Stage 1: 3x3 convs to 64 channels, then halve the spatial size.
    relu_1_1 = conv_layer(tf_img_batch, [3, 3, 3, 64], name='conv_1_1')
    relu_1_2 = conv_layer(relu_1_1, [3, 3, 64, 64], name='conv_1_2')
    pool_1 = tf.nn.max_pool(relu_1_2, 
                            ksize=[1, 2, 2, 1],
                            strides=[1, 2, 2, 1], 
                            padding='SAME',
                            name='pool_1')
    # Stage 2: 3x3 convs down to 32 channels, then halve the spatial size.
    relu_2_1 = conv_layer(pool_1, [3, 3, 64, 32], name='conv_2_1')
    relu_2_2 = conv_layer(relu_2_1, [3, 3, 32, 32], name='conv_2_2')
    pool_2 = tf.nn.max_pool(relu_2_2,
                            ksize=[1, 2, 2, 1],
                            strides=[1, 2, 2, 1],
                            padding='SAME',
                            name='pool_2')
    # Stage 3: 5x5 convs at 32 channels, then halve the spatial size.
    relu_3_1 = conv_layer(pool_2, [5, 5, 32, 32], name='conv_3_1')
    relu_3_2 = conv_layer(relu_3_1, [5, 5, 32, 32], name='conv_3_2')
    pool_3 = tf.nn.max_pool(relu_3_2,
                            ksize=[1, 2, 2, 1],
                            strides=[1, 2, 2, 1],
                            padding='SAME',
                            name='pool_3')
    # Flatten. N_dim is the product of pool_3's static non-batch dims:
    # 4 * 4 * 32 = 512 for 32x32 inputs after three SAME/stride-2 pools.
    # NOTE(review): reduce() calls f(accumulator, item), so the lambda's
    # parameter names are swapped. This is harmless only because * is commutative.
    N_dim = reduce(lambda x, acc: acc*x, pool_3.shape.as_list()[1:])
    flat_vec = tf.reshape(pool_3, [-1, N_dim], name='input_vec')
    fc_1 = fc_layer(flat_vec, 256, name='fc_1')
    # Dropout keep probabilities are fed at run time (placeholders with no
    # static shape).
    keep_prob_1 = tf.placeholder(tf.float32, name='keep_prob_1')
    dropout_1 = tf.nn.dropout(fc_1, keep_prob=keep_prob_1, name='dropout_1')
    keep_prob_2 = tf.placeholder(tf.float32, name='keep_prob_2')
    fc_2 = fc_layer(dropout_1, 128, name='fc_2')
    dropout_2 = tf.nn.dropout(fc_2, keep_prob=keep_prob_2, name='dropout_2')
    # NOTE(review): fc_layer defaults to ReLU, so fc_3 is ReLU-clipped. fc_3 is
    # used as the logits for the argmax and the softmax loss below; pass
    # act_func=tf.identity if unclipped logits are intended.
    fc_3 = fc_layer(dropout_2, 10, name='fc_3')
    
    # Predicted class index per example (argmax over the last axis).
    pred_label = tf.argmax(fc_3, -1, name='pred_label')
    
    with tf.name_scope('Loss'):
        # Per-example cross entropy, summed (not averaged) over the batch.
        loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf_label_batch,
                                                          logits=fc_3,
                                                          name='cross_entropy')
        total_loss = tf.reduce_sum(loss, name='total_cross_entropy')
    train_op = tf.train.AdamOptimizer(1e-4).minimize(total_loss, name='train_op')        

In [None]:
# Recreate an empty ckpt/cnn directory for the checkpoint written in the next cell.
!rm -rf ckpt && mkdir -p ckpt/cnn

In [None]:
# Initialize every variable of `graph` and write a checkpoint. `ckpt` holds
# the checkpoint path prefix returned by Saver.save; later cells use it.
with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    # Created inside the session context, where `graph` is the default graph,
    # so the Saver covers this graph's variables.
    saver = tf.train.Saver()
    ckpt = saver.save(sess, 'ckpt/cnn/model')

In [None]:
# Build an export-ready GraphDef from the checkpoint's .meta file, keeping
# `pred_label` as the output node. What "prepare" entails is defined in
# utensor_cgen.utils; presumably it freezes variables and prunes to the
# outputs -- TODO confirm.
graph_def = prepare_meta_graph(ckpt+'.meta', output_nodes=[pred_label.op.name])

In [None]:
# Serialize the prepared float graph and write it to disk.
float_pb_bytes = graph_def.SerializeToString()
with open('test_cnn_float.pb', 'wb') as fid:
    fid.write(float_pb_bytes)

In [None]:
# Recreate an empty logs/ directory for the TensorBoard event files written below.
!rm -rf logs && mkdir logs

In [None]:
# Write the original (training) graph to logs/ori_graph for TensorBoard.
tf.summary.FileWriter(logdir='logs/ori_graph', graph=graph).close()

In [None]:
# Display the checkpoint path prefix returned by saver.save.
ckpt

In [None]:
# Path of the meta graph saved alongside the checkpoint. It is the same path
# passed to prepare_meta_graph above.
# NOTE(review): not used by any later cell visible here.
meta_path = ckpt + '.meta'

In [None]:
# Print the node names of the prepared float graph.
# Use a dedicated loop variable. A bare `node` would leak into the notebook
# namespace, and the later inspection cells (`type(node)`, `node.attr`, ...)
# would then silently show this graph's last node whenever they do not get a
# QuantizedReshape node.
for float_node in graph_def.node:
    print(float_node.name)

In [None]:
# Quantize the float graph with the Graph Transform Tool's `quantize_weights`
# and `quantize_nodes` transforms, keeping `pred_label` as the output.
# NOTE(review): `inputs` is empty. These two transforms presumably do not need
# input names, but listing tf_img_batch.op.name (and the keep_prob
# placeholders) would be more explicit -- confirm before changing.
trans_graph_def = TransformGraph(input_graph_def=graph_def,
                                 inputs=[],
                                 outputs=[pred_label.op.name],
                                 transforms=["quantize_weights", "quantize_nodes"])

In [None]:
# Import the quantized GraphDef into a fresh graph. name='' keeps the original
# node names without an 'import/' prefix. Then write the graph to
# logs/quant_graph for TensorBoard.
new_graph = tf.Graph()
with new_graph.as_default():
    tf.import_graph_def(trans_graph_def, name='')
tf.summary.FileWriter(logdir='logs/quant_graph', graph=new_graph).close()

In [None]:
# Serialize the quantized graph and write it to disk.
quant_pb_bytes = trans_graph_def.SerializeToString()
with open('test_cnn.pb', 'wb') as fid:
    fid.write(quant_pb_bytes)

In [None]:
from utensor_cgen.operators import OperatorFactory

In [None]:
# Scan the quantized GraphDef produced by TransformGraph (`trans_graph_def`):
# - print every op that utensor_cgen's OperatorFactory does not register
#   (Const and Placeholder are skipped);
# - keep a QuantizedReshape node in `node` for the inspection cells below.
# Previously this iterated `new_graph_def`, which is never defined, so the
# cell raised NameError.
# NOTE(review): relies on the private `OperatorFactory._operators` mapping.
node = None
for n in trans_graph_def.node:
    if n.op not in {'Const', 'Placeholder'} and \
       n.op not in OperatorFactory._operators:
        print(n.name, n.op)
    if n.op == 'QuantizedReshape':
        # If there are several, the last one wins (same as before).
        node = n
# Fail loudly instead of letting later cells inspect a stale `node`.
assert node is not None, 'no QuantizedReshape node found in trans_graph_def'

In [None]:
# Show the Python type of the node selected earlier (entries of GraphDef.node
# are NodeDef protobuf messages).
type(node)

In [None]:
# Names of the node's inputs.
node.input

In [None]:
# Attribute names defined on the node.
list(node.attr.keys())

In [None]:
# The node's 'T' attribute (for TF ops this is conventionally the element
# dtype -- confirm for this op).
node.attr['T']

In [None]:
# Op type of the inspected node.
node.op