In [1]:
import tensorflow as tf
try:
    tf.logging.set_verbosity(tf.logging.ERROR)
except:
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import numpy as np

In [2]:
import sys
sys.path.append('../../../code_repos/TF_model_pruning/')

In [3]:
from model_zoo.net_frameworks.pruning import segnet_VHA_light
from pruning_utils import P_node, func_map
from pruning_utils.pruning_ops import rebuild_tf_graph, rebuild_tf_graph_rc

In [4]:
shape = [1,128,128,128,1]
np_in = np.random.randn(*shape)*10

## build a network & save ckpt
use class `P_node` to creat nodes and pass though into the network.

Also, a graph was built as well. You can get it in the attr `graph` of each node in the graph.

you can use `graph.print_info()` to check the infos of graph

In [5]:
tf.reset_default_graph()

tf_in = tf.placeholder(tf.float32,shape)
pn_in = P_node(tf_in, y=None, is_head=True, ch_op_type=None)
OP, V_OP, H_OP, [U1, U2, U3,] = segnet_VHA_light(pn_in,16,is_training=False)

# get graph
graph_org = OP.graph
graph_org.print_info()

# output test
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    op_ = sess.run(
        graph_org.output_nodes['OP'].tensor_out[...,0],
        feed_dict = {tf_in:np_in}
    )
    print('output sum is ', op_.sum())
    saver = tf.train.Saver()

Total number of nodes  -> 82
Number of output nodes -> 6
    0th output key ->   OP                   network/OP/tf_identity
    1th output key ->   V_OP                 network/VslAtt/tf_nn_softmax
    2th output key ->   H_OP                 network/HrtAtt/tf_nn_softmax
    3th output key ->   U1                   network/Up1/conv3D_block_1
    4th output key ->   U2                   network/Up2/conv3D_block_1
    5th output key ->   U3                   network/Up3/conv3D_block_1
output sum is  1044418.5


## re-build the original network & load ckpt

you can use func `rebuild_tf_graph` to rebuild the network with the `graph` we built just now.

function `rebuild_tf_graph` and `rebuild_tf_graph_rc` gives the same result. You can use the one you like. `rebuild_tf_graph` is recommonded.

In [6]:
tf.reset_default_graph()
tf_in = tf.placeholder(tf.float32,shape)
graph1 = rebuild_tf_graph(tf_in, graph_org)
graph1.print_info()

# test scope_id
for x in graph1.all_nodes:
    if x.scope_id not in [x.scope_id for x in graph_org.all_nodes]:
        assert 0
        
# output test
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver = tf.train.Saver()
    saver.restore(sess,'./ckpt/model.ckpt')
    op_ = sess.run(
        graph1.output_nodes['OP'].tensor_out[...,0],
        feed_dict = {tf_in:np_in}
    )
    print('output sum is ', op_.sum(), ', which is same as above')
    saver.save(sess, './ckpt/model.ckpt')

Total number of nodes  -> 82
Number of output nodes -> 6
    0th output key ->   V_OP                 network/VslAtt/tf_nn_softmax
    1th output key ->   H_OP                 network/HrtAtt/tf_nn_softmax
    2th output key ->   U3                   network/Up3/conv3D_block_1
    3th output key ->   U2                   network/Up2/conv3D_block_1
    4th output key ->   U1                   network/Up1/conv3D_block_1
    5th output key ->   OP                   network/OP/tf_identity
output sum is  1047786.3 , which is same as above


In [7]:
tf.reset_default_graph()
tf_in = tf.placeholder(tf.float32,shape)
graph2 = rebuild_tf_graph_rc(tf_in, graph_org)
graph2.print_info()

# test scope_id
for x in graph2.all_nodes:
    if x.scope_id not in [x.scope_id for x in graph_org.all_nodes]:
        assert 0
        
# output test
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver = tf.train.Saver()
    saver.restore(sess,'./ckpt/model.ckpt')
    op_ = sess.run(
        graph2.output_nodes['OP'].tensor_out[...,0],
        feed_dict = {tf_in:np_in}
    )
    print('output sum is ', op_.sum(), ', which is same as above')
    saver.save(sess, './ckpt/model.ckpt')

Total number of nodes  -> 82
Number of output nodes -> 6
    0th output key ->   V_OP                 network/VslAtt/tf_nn_softmax
    1th output key ->   H_OP                 network/HrtAtt/tf_nn_softmax
    2th output key ->   U3                   network/Up3/conv3D_block_1
    3th output key ->   U2                   network/Up2/conv3D_block_1
    4th output key ->   U1                   network/Up1/conv3D_block_1
    5th output key ->   OP                   network/OP/tf_identity
output sum is  1047786.3 , which is same as above


## pruning with trained-weights(.ckpt) and graph