In [1]:
import tensorflow as tf
# Silence TensorFlow INFO/WARNING logs. TF 1.x exposes `tf.logging`; TF 2.x
# removed it, so the attribute lookup raises AttributeError and we fall back to
# the `tf.compat.v1` alias. Catch only that error: a bare `except:` would also
# swallow KeyboardInterrupt and any genuine failure inside set_verbosity.
try:
    tf.logging.set_verbosity(tf.logging.ERROR)
except AttributeError:
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import numpy as np

In [2]:
import sys

# Make the local TF_model_pruning checkout importable. The path is relative to
# the notebook's working directory. The membership guard keeps the cell
# idempotent: re-running it does not keep appending duplicate entries.
PRUNING_REPO_PATH = '../../../code_repos/TF_model_pruning/'
if PRUNING_REPO_PATH not in sys.path:
    sys.path.append(PRUNING_REPO_PATH)

In [3]:
from model_zoo.net_frameworks.pruning import segnet_VHA_light
from model_zoo.net_blocks.pruning import conv3D_block, start_77conv_block, SE_block_3d
from pruning_utils import P_node, func_map
from pruning_utils.solver import solver_cfg, solver_lib, prune_solver

In [4]:
# Input shape for the demo model. The last axis is the single input channel:
# the first conv kernel in the variable listing has in-channels = 1.
SEED = 0
shape = [1, 128, 128, 128, 1]
# Dummy input volume (standard normal scaled by 10). A fixed seed makes
# re-runs reproducible. float32 matches the tf.float32 placeholder built below.
np_in = (np.random.RandomState(SEED).randn(*shape) * 10).astype(np.float32)

## build a model

In [5]:
# Build the demo network as a chain of pruning nodes (P_node) wrapping TF ops.
# NOTE: TF uniquifies repeated scope names in construction order (the variable
# listing shows conv3D_block, then conv3D_block_1). `pruning_cfg` below is keyed
# on those names, so do not reorder or insert blocks without updating it.
tf.reset_default_graph()
tf_in = tf.placeholder(tf.float32,shape)
pn_in = P_node(tf_in, y=None, is_head=True) # head pruning node wrapping the input placeholder
pn_1 = conv3D_block(pn_in,8,norm='BN')  # 8 filters, BatchNorm (BatchNorm/* vars in the listing)
pn_2 = conv3D_block(pn_1,8)  # default norm; the listing shows GroupNorm vars for this block
pn_3 = pn_1 + pn_2  # element-wise add; appears as node 'pn_add' in the pruning log
pn_4 = start_77conv_block(pn_3,8)  # kernel shape (7, 7, 7, 8, 8) per the listing
pn_out = SE_block_3d(pn_4)  # presumably squeeze-and-excitation (dense 8->1->8 in the listing)
pn_out.as_output('pn_out')  # register as output 'pn_out'; looked up again after the rebuild
graph = pn_out.graph  # the P_node graph handed to the pruning solver and the rebuild step

## Read the weights stored in the tf.Variables

In [6]:
import os

# Initialise every variable, write a checkpoint, and copy all variable values
# into a plain {variable_name: ndarray} dict. That dict is what the pruning
# solver operates on later.
CKPT_PATH = './ckpt/model2.ckpt'
# tf.train.Saver.save raises when the checkpoint's parent directory is missing,
# so make sure it exists (no-op when it already does).
os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)

weight_dict = {}
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    tf.train.Saver().save(sess, CKPT_PATH)
    # `tf.all_variables` is a deprecated alias of `tf.global_variables`.
    # Fetch every value in one sess.run instead of one round-trip per variable;
    # insertion order (and thus the printout order) is unchanged.
    tf_vars = tf.global_variables()
    weight_dict.update(zip([v.name for v in tf_vars], sess.run(tf_vars)))

# Show each variable's name and shape.
for name, value in weight_dict.items():
    print(name.ljust(50, ' '), value.shape)

conv3D_block/conv3d/kernel:0                       (3, 3, 3, 1, 8)
conv3D_block/conv3d/bias:0                         (8,)
conv3D_block/BatchNorm/gamma:0                     (8,)
conv3D_block/BatchNorm/beta:0                      (8,)
conv3D_block/BatchNorm/moving_mean:0               (8,)
conv3D_block/BatchNorm/moving_variance:0           (8,)
conv3D_block_1/conv3d/kernel:0                     (3, 3, 3, 8, 8)
conv3D_block_1/conv3d/bias:0                       (8,)
conv3D_block_1/GroupNorm/beta:0                    (8,)
conv3D_block_1/GroupNorm/gamma:0                   (8,)
start_77conv_block/conv3d/kernel:0                 (7, 7, 7, 8, 8)
start_77conv_block/conv3d/bias:0                   (8,)
start_77conv_block/GroupNorm/beta:0                (8,)
start_77conv_block/GroupNorm/gamma:0               (8,)
SE_block_3d/dense/kernel:0                         (8, 1)
SE_block_3d/dense/bias:0                           (1,)
SE_block_3d/dense_1/kernel:0                       (1, 8)
SE_block_3d

## Set up the pruning config

In [7]:
# Per-node pruning settings, keyed by the P_node scope names of the model above.
# Each conv block uses the 'weight_mean' criterion with scale 0.5 (exact meaning
# of 'scale' lives in pruning_utils.solver — TODO confirm). 'pn_add' gets an
# empty entry; the pruning log reports it as skipped. Each node receives its own
# independent dict, and key order matches the graph order.
pruning_cfg = {
    node_name: ({} if node_name == 'pn_add'
                else {'method': 'weight_mean', 'scale': 0.5})
    for node_name in ('conv3D_block', 'conv3D_block_1', 'pn_add', 'start_77conv_block')
}

## prune

In [8]:
# from pruning_utils.solver import prune_solver
# prune_solver(pn_1, weight_dict, pruning_cfg)

In [9]:
# Walk the whole P_node graph and prune the nodes configured in `pruning_cfg`.
# The log below marks nodes that were not pruned (pn_add, SE_block_3d) as SKIPED.
# NOTE(review): the return value is unused, so this presumably updates
# `weight_dict` / `graph` in place — TODO confirm in pruning_utils.solver.
# (`prune_though_gragh` is the library's own spelling.)
from pruning_utils import solver as pruning_solver
pruning_solver.prune_though_gragh(graph, weight_dict, pruning_cfg)

Start pruning though entire graph ...
   - pruning node <conv3D_block> ...               - DONE
   - pruning node <conv3D_block_1> ...             - DONE
   - pruning node <pn_add> ...                     - SKIPED prune unable
   - pruning node <start_77conv_block> ...         - DONE
   - pruning node <SE_block_3d> ...                - SKIPED prune unable
Prune Done!


## rebuild pruned model

In [10]:
# Rebuild a fresh TF graph from the pruned P_node graph. The input placeholder
# reuses `shape` from the dummy-input cell instead of a hard-coded copy, so the
# rebuilt model cannot silently drift from the original input shape.
from pruning_utils.rebuild_ops import rebuild_tf_graph
# Start from an empty default graph. Otherwise TF would uniquify the rebuilt
# scopes (e.g. conv3D_block_2), and they would no longer match weight_dict's names.
tf.reset_default_graph()
tf_in = tf.placeholder(tf.float32, shape)
graph_pruned = rebuild_tf_graph(tf_in, graph)
pn_out = graph_pruned.output_nodes['pn_out']  # name registered via as_output('pn_out')

In [11]:
## load pruned weights

In [13]:
# Push the values in `weight_dict` into the rebuilt graph's variables.
# NOTE(review): assumes the pruning step shrank the arrays in weight_dict in
# place, and that mode='reuse' matches them to variables by name. No tf.Session
# is opened here, so the helper presumably manages its own — TODO confirm
# against pruning_utils.weights.
from pruning_utils import weights as pruning_weights
pruning_weights.load_weights_2_tfsess(weight_dict, mode='reuse')

In [None]:
# tf.all_variables()

## TODO

- [ ] Auto-generate the pruning configs
- [ ] Finish all solvers