In [1]:
import tensorflow as tf
# Only show TensorFlow log messages at ERROR level and above.
# TF 1.x exposes `tf.logging`; TF 2.x removed that alias, which surfaces as an
# AttributeError, so fall back to the `tf.compat.v1` location in that case.
# Catch only AttributeError: a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and hide unrelated failures.
try:
    tf.logging.set_verbosity(tf.logging.ERROR)
except AttributeError:
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import numpy as np
import pickle
import os

In [2]:
import sys
# Make the local TF_model_pruning checkout importable (model_zoo, pruning_utils).
# The path is relative to the notebook's working directory, so resolve it once:
# a later os.chdir() then cannot break imports. Only append it when it is not
# already present, so re-running this cell does not keep growing sys.path.
PRUNING_REPO = os.path.abspath('../../../code_repos/TF_model_pruning/')
if PRUNING_REPO not in sys.path:
    sys.path.append(PRUNING_REPO)

In [3]:
from model_zoo.net_frameworks.pruning import segnet_VHA_light
from model_zoo.net_blocks.pruning import conv3D_block, start_77conv_block, SE_block_3d, upsample3D_block
from pruning_utils import P_node, func_map, save_graph, load_graph
from pruning_utils.solver import prune_though_gragh, get_complete_cfg
from pruning_utils.rebuild_ops import rebuild_tf_graph
from pruning_utils.weights import load_weights_2_tfsess

In [4]:
# Input shape fed to every tf.placeholder in this notebook.
# Presumably (batch, depth, height, width, channels) for the 3D net — TODO confirm.
shape = [1, 128, 128, 128, 1]

# Dummy input scaled by 10. A local, seeded RandomState makes it reproducible
# across Restart & Run All without touching NumPy's global random state.
# Cast to float32 to match the tf.float32 placeholders it would be fed to.
SEED = 0
np_in = (np.random.RandomState(SEED).randn(*shape) * 10).astype(np.float32)

## build a model

In [5]:
# Start from an empty default TF graph before building the network.
tf.reset_default_graph()

with tf.variable_scope('scopetest1/scopetest2'):
    tf_in = tf.placeholder(tf.float32, shape)

    # Wrap the input tensor in a pruning node that heads the pruning graph.
    pn_in = P_node(tf_in, y=None, is_head=True)

    # Build the light segnet variant (16 passed as its second argument) in
    # inference mode; the fourth output unpacks into three items.
    seg_outputs = segnet_VHA_light(pn_in, 16, is_training=False)
    OP, V_OP, H_OP, (U1, U2, U3) = seg_outputs

    # Graph object reached from the output node; it is what save_graph()
    # serialises in the next cell.
    graph = OP.graph

## save graph & ckpt

In [6]:
# Output directories for the checkpoint and the serialised graph
# (no error if they already exist).
for out_dir in ('./ckpt/', './graph/'):
    os.makedirs(out_dir, exist_ok=True)

# Initialise every variable of the model built above and write it to a TF
# checkpoint; the pruning cell below reads the weights back from this file.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(sess, './ckpt/model.ckpt')

# Serialise the graph description to a .pkl file next to the checkpoint.
save_graph(graph, './graph/graph.pkl')

## read weights in ckpt & run pruning

In [7]:
# Pruning configuration keyed by block type: which importance `method` to use
# and, where given, the `scale` to apply when pruning that block type.
# Each entry is its own dict, so mutating one never affects another.
cfg = dict(
    conv3D_block=dict(method='weight_mean', scale=0.5),
    start_77conv_block=dict(method='weight_mean', scale=0.5),
    SE_block_3d=dict(method='weight_mean'),
    upsample3D_block=dict(method='weight_mean', scale=0.5),
    va_attention=dict(method='norm'),
)

In [8]:
# Start from an empty default graph; the weights are read directly from the
# checkpoint file below rather than from live TF variables.
tf.reset_default_graph()

# Copy every tensor stored in the checkpoint into a plain dict. Each key gets a
# ':0' suffix so it reads like a TF tensor name (variable 'a/b' -> 'a/b:0'),
# presumably the key format prune_though_gragh / load_weights_2_tfsess expect
# — TODO confirm.
reader = tf.train.NewCheckpointReader("./ckpt/model.ckpt")
variables = reader.get_variable_to_shape_map()
weight_dict = {name + ':0': reader.get_tensor(name) for name in variables}

# Load the serialised description of the (unpruned) network saved above.
graph = load_graph('./graph/graph.pkl')

# Re-build the TF ops for the loaded graph inside a variable scope derived from
# its stored base scope.
# NOTE(review): base_scope[:-1] presumably strips a trailing '/' from the stored
# scope name — TODO confirm against load_graph / P_node.
with tf.variable_scope(graph.base_scope[:-1]):
    tf_in = tf.placeholder(tf.float32,shape)
    graph = rebuild_tf_graph(tf_in, graph)

# Expand the per-block-type `cfg` into the complete pruning configuration used
# for this graph.
pruning_cfg = get_complete_cfg(graph,cfg)

# Prune through every node of the graph. The pruned weights presumably end up
# in `weight_dict` in place: it is the dict loaded into the session and saved
# as model_p.ckpt in the next step.
prune_though_gragh(graph, weight_dict, pruning_cfg)

# Discard the TF ops built for pruning; from here on only `weight_dict` and the
# (presumably pruned) `graph` description are used.
tf.reset_default_graph()

# save .ckpt & graph
with tf.Session() as sess:
    # NOTE(review): the default graph was just reset, so this initializer has no
    # variables to initialise. The variables saved below are presumably created
    # from weight_dict by load_weights_2_tfsess(..., mode='new') — TODO confirm,
    # and drop this line if so.
    tf.global_variables_initializer().run()
    load_weights_2_tfsess(weight_dict, sess, mode='new')
    # Saver is created after the load so it covers whatever variables now exist.
    saver = tf.train.Saver()
    saver.save(sess, './ckpt/model_p.ckpt')
# Serialise the pruned graph description alongside the pruned checkpoint.
save_graph(graph,'./graph/graph_p.pkl')

Start pruning though entire graph ...
   - pruning node <network/tf_identity> ...                            - SKIPED <prune unable>
   - pruning node <network/start_77conv_block> ...                     - DONE
   - pruning node <network/Down1/conv3D_block> ...                     - DONE
   - pruning node <network/Down1/SE_block_3d> ...                      - DONE
   - pruning node <network/Down2/tf_layers_max_pooling3d> ...          - SKIPED <prune unable>
   - pruning node <network/Down2/conv3D_block> ...                     - DONE
   - pruning node <network/Down2/SE_block_3d> ...                      - DONE
   - pruning node <network/Down3/tf_layers_max_pooling3d> ...          - SKIPED <prune unable>
   - pruning node <network/Down3/conv3D_block> ...                     - DONE
   - pruning node <network/Down3/SE_block_3d> ...                      - DONE
   - pruning node <network/Middle/tf_layers_max_pooling3d> ...         - SKIPED <prune unable>
   - pruning node <network/Middle/co

## load pruned model

In [9]:
# Clear the default TF graph so the pruned model can be rebuilt from scratch in
# the next cell.
tf.reset_default_graph()

In [10]:
# Rebuild the pruned network from its saved graph description, then restore the
# pruned weights into it. Reset the default graph first so this cell is
# idempotent: re-running it on its own would otherwise add a second copy of the
# placeholder and network to the existing default graph.
tf.reset_default_graph()

graph = load_graph('./graph/graph_p.pkl')
# NOTE(review): base_scope[:-1] presumably strips a trailing '/' — TODO confirm.
with tf.variable_scope(graph.base_scope[:-1]):
    tf_in = tf.placeholder(tf.float32, shape)
    graph = rebuild_tf_graph(tf_in, graph)

with tf.Session() as sess:
    # Initialise all variables, then overwrite the checkpointed ones with the
    # pruned weights; restore() fails if the rebuilt variables do not match.
    tf.global_variables_initializer().run()
    saver = tf.train.Saver()
    saver.restore(sess, './ckpt/model_p.ckpt')