#### Applying weight pruning to CNN for MNIST
* without pruning, protobuff can be converted by utensor-cli; with pruning, more nodes/ops like mask are in the graph_def, one error I got when using utensor-cli is `ValueError: unsupported op type in uTensor: QuantizedMul`
* with pruning, a mask node is added with 0 and 1. The effective weights should be weights*mask=mask_weights. For example, `fc2_weights = sess.run(sess.graph.get_tensor_by_name("fc2/weights:0"))`, `fc2_weights` won't be sparse. It will be multiplied by the corresponding mask. 
* Once the model is trained, it is necessary to remove the auxiliary variables (mask, threshold) and pruning ops added to the graph in the steps above. This can be accomplished using the strip_pruning_vars utility. This utility generates a binary GraphDef in which the variables have been converted to constants. In particular, the threshold variables are removed from the graph and the mask variable is fused with the corresponding weight tensor to produce a masked_weight tensor. This tensor is sparse, has the same size as the weight tensor, and the sparsity is as set by the target_sparsity or the weight_sparsity_map hyperparameters above. <br>
`$ bazel build -c opt contrib/model_pruning:strip_pruning_vars` <br>
`$ bazel-bin/contrib/model_pruning/strip_pruning_vars --checkpoint_dir=/path/to/checkpoints/ --output_node_names=graph_node1,graph_node2 --output_dir=/tmp --filename=pruning_stripped.pb`
* with pruning, more weights are zeroout but the size of the tensor doens't change. So the size of the pb file won't change. But if the file is compressed (zipped), the size will reduce. 
* But will it work on uTensor? I don't think uTensor has programs to unzip compressed protobuff and special program to handle sparsed weights. I think it just load the weights normally into a matrix in RAM.. If the size of the weights doesn't change, the RAM usage will be the same as non-purned weights. Need to be confirmed. 
* NPU can process the compressed weights?

In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.examples.tutorials.mnist import input_data
from matplotlib import pyplot as plt

In [5]:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [110]:
tf.reset_default_graph()

In [111]:
img_size, num_channels, num_classes = 28, 1, 10
pooling_ksize = [1, 2, 2, 1]
pooling_strides = [1, 2, 2, 1]
x = tf.placeholder(tf.float32, shape=[None, img_size, img_size, num_channels], name='x')
y_ = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_')

In [112]:
with tf.variable_scope('conv1') as scope:
    kernel_shape = [5, 5, 1, 16]
    strides = [1, 1, 1, 1]
    bias_shape = kernel_shape[-1]
    kernel = tf.get_variable("weights", shape=kernel_shape, initializer=tf.contrib.layers.xavier_initializer_conv2d(), 
                             dtype=tf.float32, trainable=True)
    conv = tf.nn.conv2d(x, pruning.apply_mask(kernel, scope), strides=strides, padding="SAME", name="conv_map")
    bias = tf.get_variable("bias", shape=bias_shape, initializer=tf.contrib.layers.xavier_initializer(), 
                           dtype=tf.float32, trainable=True)
    pre_activation = tf.add(conv, bias)
    conv1 = tf.nn.relu(pre_activation, name="activation")


pool1 = tf.nn.max_pool(value=conv1, ksize=pooling_ksize, strides=pooling_strides, padding="SAME", name="max_pool1")

print(x)
print(kernel)
print(conv)
print(conv1)
print(pool1)

Tensor("x:0", shape=(?, 28, 28, 1), dtype=float32)
<tf.Variable 'conv1/weights:0' shape=(5, 5, 1, 16) dtype=float32_ref>
Tensor("conv1/conv_map:0", shape=(?, 28, 28, 16), dtype=float32)
Tensor("conv1/activation:0", shape=(?, 28, 28, 16), dtype=float32)
Tensor("max_pool1:0", shape=(?, 14, 14, 16), dtype=float32)


In [113]:
with tf.variable_scope("conv2") as scope:
    kernel_shape = [5, 5, 16, 32]
    strides = [1, 1, 1, 1]
    bias_shape = kernel_shape[-1]
    kernel = tf.get_variable(name="weights", shape=kernel_shape, dtype=tf.float32,
                             initializer=tf.contrib.layers.xavier_initializer(), trainable=True)
    conv = tf.nn.conv2d(input=pool1, filters=pruning.apply_mask(kernel, scope), strides=strides, padding="SAME", name='conv_map')
    bias = tf.get_variable(name="bias", shape=bias_shape, dtype=tf.float32,
                           initializer=tf.contrib.layers.xavier_initializer(), trainable=True)
    pre_activation = tf.add(conv, bias)
    conv2 = tf.nn.relu(pre_activation, name="relu")

pool2 = tf.nn.max_pool(value=conv2, ksize=pooling_ksize, strides=pooling_strides, padding="SAME", name="max_pooling2")

print(kernel)
print(conv)
print(conv2)
print(pool2)

<tf.Variable 'conv2/weights:0' shape=(5, 5, 16, 32) dtype=float32_ref>
Tensor("conv2/conv_map:0", shape=(?, 14, 14, 32), dtype=float32)
Tensor("conv2/relu:0", shape=(?, 14, 14, 32), dtype=float32)
Tensor("max_pooling2:0", shape=(?, 7, 7, 32), dtype=float32)


In [114]:
with tf.variable_scope("conv3") as scope:
    kernel_shape = [5, 5, 32, 64]
    strides = [1, 1, 1, 1]
    bias_shape = kernel_shape[-1]
    kernel = tf.get_variable(name="weights", shape=kernel_shape, dtype=tf.float32,
                             initializer=tf.contrib.layers.xavier_initializer_conv2d(), trainable=True)
    conv = tf.nn.conv2d(input=pool2, filters=pruning.apply_mask(kernel), strides=strides, padding="SAME", name="conv_map")
    bias = tf.get_variable(name="bias", shape=bias_shape, dtype=tf.float32,
                           initializer=tf.contrib.layers.xavier_initializer(), trainable=True)
    pre_activation = tf.add(conv, bias)
    conv3 = tf.nn.relu(pre_activation, name="relu")

pool3 = tf.nn.max_pool(value=conv3, ksize=pooling_ksize, strides=pooling_strides, padding="SAME", name="max_pooling3")

print(kernel)
print(conv)
print(conv3)
print(pool3)

<tf.Variable 'conv3/weights:0' shape=(5, 5, 32, 64) dtype=float32_ref>
Tensor("conv3/conv_map:0", shape=(?, 7, 7, 64), dtype=float32)
Tensor("conv3/relu:0", shape=(?, 7, 7, 64), dtype=float32)
Tensor("max_pooling3:0", shape=(?, 4, 4, 64), dtype=float32)


In [115]:
pool3_shape = pool3.shape.as_list()
pool3_flat = tf.reshape(pool3, [-1, pool3_shape[1]*pool3_shape[2]*pool3_shape[3]], name="flatten")
print(pool3_flat)

Tensor("flatten:0", shape=(?, 1024), dtype=float32)


In [116]:
pool3_flat.shape.as_list()

[None, 1024]

In [117]:
pool3_flat.get_shape()[1].value

1024

In [118]:
with tf.variable_scope("fc1") as scope:
    weights_shape = [pool3_flat.shape.as_list()[1], 128]
    bias_shape = weights_shape[-1]
    weights = tf.get_variable(name="weights", shape=weights_shape, dtype=tf.float32,
                              initializer=tf.contrib.layers.xavier_initializer(), trainable=True)
    bias = tf.get_variable(name="bias", shape=bias_shape, dtype=tf.float32,
                           initializer=tf.contrib.layers.xavier_initializer(), trainable=True)
    fc = tf.matmul(pool3_flat, pruning.apply_mask(weights, scope), name="matmul")
    pre_activation = tf.add(fc, bias)
    fc1 = tf.nn.relu(fc, name="relu")

print(weights)
print(fc1)

<tf.Variable 'fc1/weights:0' shape=(1024, 128) dtype=float32_ref>
Tensor("fc1/relu:0", shape=(?, 128), dtype=float32)


In [119]:
with tf.variable_scope("fc2") as scope:
    weights_shape = [fc1.get_shape()[1].value, num_classes]
    bias_shape = weights_shape[-1]
    weights = tf.get_variable(name="weights", shape=weights_shape, dtype=tf.float32,
                              initializer=tf.contrib.layers.xavier_initializer(), trainable=True)
    bias = tf.get_variable(name="bias", shape=bias_shape, dtype=tf.float32,
                           initializer=tf.contrib.layers.xavier_initializer(), trainable=True)
    fc = tf.matmul(fc1, pruning.apply_mask(weights), name="matmul")
    logits = tf.add(fc, bias, name="logits")
    y_pred = tf.argmax(logits, axis=1, name="y_pred")

print(weights)
print(logits)
print(y_pred)

<tf.Variable 'fc2/weights:0' shape=(128, 10) dtype=float32_ref>
Tensor("fc2/logits:0", shape=(?, 10), dtype=float32)
Tensor("fc2/y_pred:0", shape=(?,), dtype=int64)


In [122]:
global_step = tf.contrib.framework.get_or_create_global_step()

In [124]:
with tf.name_scope("Loss"):
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, 
                                                               logits=logits)
    loss = tf.reduce_mean(cross_entropy, name="cross_entropy_loss")
train_step = tf.train.AdamOptimizer(1e-4).minimize(loss, global_step=global_step, name="train_step")
  
# Here we specify the output as "Prediction/y_pred", this will be important later
with tf.name_scope("Prediction"): 
    correct_prediction = tf.equal(y_pred, 
                                  tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")

In [125]:
print(cross_entropy)
print(loss)
print(accuracy)

Tensor("Loss/softmax_cross_entropy_with_logits/Reshape_2:0", shape=(?,), dtype=float32)
Tensor("Loss/cross_entropy_loss:0", shape=(), dtype=float32)
Tensor("Prediction/accuracy:0", shape=(), dtype=float32)


In [130]:
batch_size = 100
n_epochs = 2
n_batches = int(mnist.train.num_examples / batch_size)
end_step = n_epochs * n_batches
print(n_batches)
print(end_step)

550
1100


In [129]:
pruning_hparams = pruning.get_pruning_hparams()
print(pruning_hparams)

name=model_pruning,begin_pruning_step=0,end_pruning_step=-1,weight_sparsity_map=[''],threshold_decay=0.0,pruning_frequency=10,nbins=256,block_height=1,block_width=1,block_pooling_function=AVG,initial_sparsity=0.0,target_sparsity=0.5,sparsity_function_begin_step=0,sparsity_function_end_step=100,sparsity_function_exponent=3.0,use_tpu=False


In [131]:
pruning_hparams.begin_pruning_step = 0
pruning_hparams.end_pruning_step = end_step
pruning_hparams.frequency = 100
pruning_hparams.target_sparsity = 0.9

In [132]:
pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)
make_update_op = pruning_obj.conditional_mask_update_op()

INFO:tensorflow:Updating masks.


In [133]:
sess = tf.Session()
# Initialize the variables (i.e. assign their default value)
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()

In [134]:
train_loss, train_accuracy = sess.run([loss, accuracy], 
                                      feed_dict={x: mnist.train.images.reshape(mnist.train.num_examples, img_size, img_size, num_channels), 
                                                 y_: mnist.train.labels})
print('Epoch %d, training loss: %g, training accuracy: %g' % (0, train_loss, train_accuracy))
val_loss, val_accuracy = sess.run([loss, accuracy], 
                                  feed_dict={x: mnist.validation.images.reshape(mnist.validation.num_examples, img_size, img_size, num_channels),
                                             y_: mnist.validation.labels})
print('Epoch %d, validation loss: %g, validation accuracy %g' % (0, val_loss, val_accuracy))

Epoch 0, training loss: 2.34491, training accuracy: 0.103909
Epoch 0, validation loss: 2.34243, validation accuracy 0.11


In [135]:
print("Weight sparsities: ", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))

Weight sparsities:  [0.0, 0.0, 0.0, 0.0, 0.0]


In [138]:
for i in range(n_epochs):
    for j in range(n_batches):
        batch_images, batch_labels = mnist.train.next_batch(batch_size)
        sess.run(train_step, feed_dict={x: batch_images.reshape(batch_size, img_size, img_size, num_channels), 
                                        y_: batch_labels})
        sess.run(make_update_op)
    if i % 2 == 0:
        train_loss, train_accuracy = sess.run([loss, accuracy], 
                                              feed_dict={x: mnist.train.images.reshape(mnist.train.num_examples, img_size, img_size, num_channels), 
                                                         y_: mnist.train.labels})
        print('Epoch %d, training loss: %g, training accuracy: %g' % (i, train_loss, train_accuracy))
        val_loss, val_accuracy = sess.run([loss, accuracy], 
                                          feed_dict={x: mnist.validation.images.reshape(mnist.validation.num_examples, img_size, img_size, num_channels),
                                                     y_: mnist.validation.labels})
        print('Epoch %d, validation loss: %g, validation accuracy %g' % (i, val_loss, val_accuracy))
        print("Weight sparsity: ", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))

Epoch 0, training loss: 1.46961, training accuracy: 0.619109
Epoch 0, validation loss: 1.45551, validation accuracy 0.6352
Weight sparsity:  [0.9, 0.9, 0.9, 0.9000015, 0.90000004]


In [192]:
print('test accuracy %g' % sess.run(accuracy, 
                                    feed_dict={x: mnist.validation.images.reshape(mnist.validation.num_examples, img_size, img_size, num_channels), 
                                               y_: mnist.validation.labels}))

test accuracy 0.7616


In [142]:
saver.save(sess, './chkps_mnist_weight_prune/mnist_prune')

'./chkps_mnist_weight_prune/mnist_prune'

In [193]:
from tensorflow.python.framework.graph_util import remove_training_nodes
from tensorflow.tools.graph_transforms import TransformGraph
sub_graph_def = remove_training_nodes(sess.graph_def)

In [195]:
output_nodes_list = [sess.graph.get_operation_by_name("fc2/y_pred").name]
print(output_nodes_list)

['fc2/y_pred']


In [196]:
from tensorflow.python.framework import graph_util as gu
sub_graph_def = gu.convert_variables_to_constants(sess=sess, 
                                                  input_graph_def=sub_graph_def,
                                                  output_node_names=output_nodes_list)

INFO:tensorflow:Froze 14 variables.
INFO:tensorflow:Converted 14 variables to const ops.


In [197]:
[node.name for node in sub_graph_def.node]

['x',
 'conv1/weights',
 'conv1/mask',
 'conv1/weights/masked_weight',
 'conv1/conv_map',
 'conv1/bias',
 'conv1/Add',
 'conv1/activation',
 'max_pool1',
 'conv2/weights',
 'conv2/mask',
 'conv2/weights/masked_weight',
 'conv2/conv_map',
 'conv2/bias',
 'conv2/Add',
 'conv2/relu',
 'max_pooling2',
 'conv3/weights',
 'conv3//mask',
 'conv3/weights/masked_weight',
 'conv3/conv_map',
 'conv3/bias',
 'conv3/Add',
 'conv3/relu',
 'max_pooling3',
 'flatten/shape',
 'flatten',
 'fc1/weights',
 'fc1/mask',
 'fc1/weights/masked_weight',
 'fc1/matmul',
 'fc1/relu',
 'fc2/weights',
 'fc2/bias',
 'fc2//mask',
 'fc2/weights/masked_weight',
 'fc2/matmul',
 'fc2/logits',
 'fc2/y_pred/dimension',
 'fc2/y_pred']

In [198]:
tf.train.write_graph(sub_graph_def, "./mnist_cnn_0to9", "mnist_cnn_weight_prune.pb", as_text=False)

'./mnist_cnn_0to9/mnist_cnn_weight_prune.pb'

In [187]:
all_tensors = [op.values() for op in sess.graph.get_operations()]
print(len(all_tensors))

1113


In [189]:
all_tensors[::100]

[(<tf.Tensor 'x:0' shape=(?, 28, 28, 1) dtype=float32>,),
 (<tf.Tensor 'conv3/bias/Assign:0' shape=(64,) dtype=float32_ref>,),
 (<tf.Tensor 'Loss/softmax_cross_entropy_with_logits/concat_1/values_0:0' shape=(1,) dtype=int32>,),
 (<tf.Tensor 'gradients/conv3/weights/masked_weight_grad/Mul_1:0' shape=(5, 5, 32, 64) dtype=float32>,),
 (<tf.Tensor 'conv3/weights/Adam_1/Initializer/zeros/Const:0' shape=() dtype=float32>,),
 (<tf.Tensor 'model_pruning_2/LogicalOr:0' shape=() dtype=bool>,),
 (<tf.Tensor 'cond/model_pruning/conv3//mask_assign/Switch:0' shape=(5, 5, 32, 64) dtype=float32_ref>,
  <tf.Tensor 'cond/model_pruning/conv3//mask_assign/Switch:1' shape=(5, 5, 32, 64) dtype=float32_ref>),
 (<tf.Tensor 'save/Assign_20:0' shape=(64,) dtype=float32_ref>,),
 (<tf.Tensor 'zero_fraction_2/fraction:0' shape=() dtype=float32>,),
 (<tf.Tensor 'zero_fraction_6/counts_to_fraction/sub:0' shape=() dtype=int64>,),
 (<tf.Tensor 'zero_fraction_10/cond/count_nonzero_1/Cast:0' shape=(5, 5, 1, 16) dtype=in

In [204]:
fc2_weights = sess.run(sess.graph.get_tensor_by_name("fc2/weights:0"))

In [205]:
fc2_weights.shape

(128, 10)

In [207]:
(fc2_weights == 0).sum()

0

In [213]:
sess.run(tf.contrib.model_pruning.get_weight_sparsity())

[0.9, 0.9, 0.9, 0.9000015, 0.90000004]

In [None]:
sess.close()