In [1]:
import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.examples.tutorials.mnist import input_data

epochs = 1000
batch_size = 55000 # Entire training set

# Import dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batches = int(len(mnist.train.images) / batch_size)

# Define Placeholders
image = tf.placeholder(tf.float32, [None, 784])
label = tf.placeholder(tf.float32, [None, 10])

# Define the model
layer1 = layers.masked_fully_connected(image, 300)
layer2 = layers.masked_fully_connected(layer1, 100)
logits = layers.masked_fully_connected(layer2, 10)

# Create global step variable (needed for pruning)
global_step = tf.train.get_or_create_global_step()
reset_global_step_op = tf.assign(global_step, 0)

# Loss function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))

# Training op, the global step is critical here, make sure it matches the one used in pruning later
# running this operation increments the global_step
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss, global_step=global_step)

# Accuracy ops
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Get, Print, and Edit Pruning Hyperparameters
pruning_hparams = pruning.get_pruning_hparams()
print("Pruning Hyperparameters:", pruning_hparams)

# Change hyperparameters to meet our needs
pruning_hparams.begin_pruning_step = 0
pruning_hparams.end_pruning_step = 250
pruning_hparams.pruning_frequency = 1
pruning_hparams.sparsity_function_end_step = 250
pruning_hparams.target_sparsity = .9

# Create a pruning object using the pruning specification, sparsity seems to have priority over the hparam
p = pruning.Pruning(pruning_hparams, global_step=global_step, sparsity=.9)
prune_op = p.conditional_mask_update_op()

saver = tf.train.Saver()


with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())

    """
    # Train the model before pruning (optional)
    for epoch in range(epochs):
        for batch in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})

        # Calculate Test Accuracy every 10 epochs
        if epoch % 10 == 0:
            acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
            print("Un-pruned model step %d test accuracy %g" % (epoch, acc_print))

    acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
    print("Pre-Pruning accuracy:", acc_print)
    print("Sparsity of layers (should be 0)", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))

    save_path = saver.save(sess, "./model_before_pruning.ckpt")
    print("Model saved in path: %s" % save_path)
    """
    
    # Reset the global step counter and begin pruning
    for epoch in range(epochs):
        for batch in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Prune and retrain
            sess.run(prune_op)
            sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})

        # Calculate Test Accuracy every 10 epochs
        if epoch % 10 == 0:
            acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
            print("Pruned model step %d test accuracy %g" % (epoch, acc_print))
            print("Weight sparsities:", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))

    # Print final accuracy
    acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
    print("Final accuracy:", acc_print)
    print("Final sparsity by layer (should be 0)", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))
    
    save_path = saver.save(sess, "./model_after_pruning.ckpt")
    print("Model saved in path: %s" % save_path)
    

  from ._conv import register_converters as _register_converters


Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data\train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Colocations handled automatically by placer.
Pruning Hyperparameters: [('begin_pruning_step', 0), ('block_height', 1), ('block_pooling_function', 'AVG'), ('block_width', 1), ('end_pruning_step', -1), ('initial_sparsity', 0.0), ('name', 'model_pruning'), ('nbins', 256), ('pr

Weight sparsities: [0.8998725, 0.8983334, 0.89900005]
Pruned model step 500 test accuracy 0.8281
Weight sparsities: [0.8998725, 0.8983334, 0.89900005]
Pruned model step 510 test accuracy 0.8286
Weight sparsities: [0.8998725, 0.8983334, 0.89900005]
Pruned model step 520 test accuracy 0.83
Weight sparsities: [0.8998725, 0.8983334, 0.89900005]
Pruned model step 530 test accuracy 0.832
Weight sparsities: [0.8998725, 0.8983334, 0.89900005]
Pruned model step 540 test accuracy 0.8332
Weight sparsities: [0.8998725, 0.8983334, 0.89900005]
Pruned model step 550 test accuracy 0.8354
Weight sparsities: [0.8998725, 0.8983334, 0.89900005]
Pruned model step 560 test accuracy 0.8361
Weight sparsities: [0.8998725, 0.8983334, 0.89900005]
Pruned model step 570 test accuracy 0.8373
Weight sparsities: [0.8998725, 0.8983334, 0.89900005]
Pruned model step 580 test accuracy 0.8385
Weight sparsities: [0.8998725, 0.8983334, 0.89900005]
Pruned model step 590 test accuracy 0.8393
Weight sparsities: [0.8998725, 0.