## Accelerate Inference: Neural Network Pruning

In [None]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import pickle

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, models, regularizers
from tensorflow.keras.layers import *

# Report which TensorFlow build this notebook is running against.
print(tf.__version__)

In [None]:
# list the working directory, then unpack the dataset archive
!ls
!tar -xvzf dataset.tar.gz
# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only unpickle archives that come from a trusted source.
def _load_pickle(path):
  """Return the object stored in the pickle file at `path`.

  The file is opened in a `with` block so the handle is closed as soon as
  the object has been read, instead of being left for the garbage collector.
  """
  with open(path, 'rb') as handle:
    return pickle.load(handle)


# load train
train_images = _load_pickle('train_images.pkl')
train_labels = _load_pickle('train_labels.pkl')
# load val
val_images = _load_pickle('val_images.pkl')
val_labels = _load_pickle('val_labels.pkl')

In [None]:
# Define the neural network architecture (don't change this)
#
# A small CNN: two conv blocks (conv-conv-pool-dropout) followed by a
# 512-unit dense layer and a 5-way softmax output. The four Conv2D kernels
# carry an L2 penalty of 1e-5; the Dense layers have no regularizer.
# Spatial sizes noted below follow from the Keras defaults
# (stride 1, 'valid' padding unless padding='same' is given).

model = models.Sequential()
# Block 1: 32 filters. Input is 25x25x3; 'same' padding keeps 25x25.
model.add(Conv2D(32, (3, 3), padding='same', kernel_regularizer=regularizers.l2(1e-5), input_shape=(25,25,3)))
model.add(Activation('relu'))
# No padding argument here, so the 3x3 kernel shrinks the map to 23x23.
model.add(Conv2D(32, (3, 3), kernel_regularizer=regularizers.l2(1e-5)))
model.add(Activation('relu'))
# 2x2 pooling: 23x23 -> 11x11.
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
# Block 2: 64 filters. 'same' padding keeps 11x11.
model.add(Conv2D(64, (3, 3), padding='same', kernel_regularizer=regularizers.l2(1e-5)))
model.add(Activation('relu'))
# Unpadded 3x3 convolution: 11x11 -> 9x9.
model.add(Conv2D(64, (3, 3), kernel_regularizer=regularizers.l2(1e-5)))
model.add(Activation('relu'))
# 2x2 pooling: 9x9 -> 4x4.
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
# Classifier head: flatten (4*4*64 = 1024 features), dense 512, 5-way softmax.
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(5))
# Softmax output means the model emits probabilities, not logits.
model.add(Activation('softmax'))

In [None]:
# summary() prints the layer table itself and returns None, so wrapping it in
# print() only appended a stray "None" line to the output.
model.summary()

In [None]:
# Baseline training with the default hyper-parameters
# (val accuracy ~72% after 50 epochs).

# The model ends in a softmax layer, so the loss consumes probabilities
# (from_logits=False) and integer class labels (sparse cross-entropy).
optimizer = keras.optimizers.Adam(learning_rate=0.0001, weight_decay=1e-6)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])

# train for 50 epochs, with batch size 32, validating after every epoch
history = model.fit(
    x=train_images,
    y=train_labels,
    batch_size=32,
    epochs=50,
    validation_data=(val_images, val_labels),
)

In [None]:
# Snapshot the trained weights (test_pruning restores them before every
# experiment), then evaluate the unpruned baseline on the validation set.
original_weights = model.get_weights()
results = model.evaluate(x=val_images, y=val_labels, batch_size=128)

In [None]:
class Pruner:
  """L1-norm filter pruning for the convolution kernels of a Keras model.

  Usage: construct with a built model, call `prune(rate)` to compute masks and
  pruned weights, load `self.weights` into the model, then optionally call
  `fine_tune()` to recover accuracy while keeping the pruned filters at zero.

  Attributes:
    model: the Keras model being pruned.
    kernel_indices: positions of the 4-D (convolution) kernels inside
      `model.get_weights()`.
    kernels: snapshot of those kernels taken at construction time.
    masks: per-weight 0/1 arrays, same order as `model.get_weights()`
      (None until `prune` has run).
    weights: the masked weights (None until `prune` has run).
    total_parameters: total number of scalar weights in the model.
    num_zero_weights: number of scalar weights that are zero.
  """

  def __init__(self, model):
    self.model = model
    weights = model.get_weights()
    # Convolution kernels are the only 4-D weight arrays; biases are 1-D and
    # Dense kernels 2-D. Detecting them by rank avoids hard-coding positions.
    self.kernel_indices = [i for i, w in enumerate(weights) if np.ndim(w) == 4]
    self.kernels = [weights[i] for i in self.kernel_indices]
    self.weights = None
    self.masks = None
    self.total_parameters = 0
    self.num_zero_weights = 0

  def prune(self, prune_rate):
    """Zero every conv filter whose L1 norm is not above a global percentile.

    Args:
      prune_rate: percentile in [0, 100] (passed to `np.percentile`). Filter
        norms from all conv layers are pooled, and a filter is kept only if
        its norm is strictly greater than that percentile.

    Side effects: sets `self.masks`, `self.weights`, `self.total_parameters`
    and `self.num_zero_weights`. The model itself is not modified; the caller
    loads `self.weights` into it.
    """
    all_weights = self.model.get_weights()
    all_masks = [np.ones(w.shape, dtype=w.dtype) for w in all_weights]

    if self.kernels:
      # Keras stores Conv2D kernels as (height, width, in_channels, filters),
      # so one filter is a slice along the LAST axis. Summing |w| over the
      # first three axes gives one L1 norm per filter.
      filter_norms = [np.abs(kernel).sum(axis=(0, 1, 2)) for kernel in self.kernels]

      # Layers have different filter counts, so pool the norms by
      # concatenation rather than stacking them into a (ragged) 2-D array.
      threshold = np.percentile(np.concatenate(filter_norms), prune_rate)

      for index, norms in zip(self.kernel_indices, filter_norms):
        keep = (norms > threshold).astype(all_masks[index].dtype)
        # `keep` has one entry per filter and broadcasts over the last axis.
        all_masks[index] = all_masks[index] * keep

    self.masks = all_masks
    self.weights = [w * m for w, m in zip(all_weights, all_masks)]

    # Sparsity bookkeeping: every mask entry that is 0 is a pruned weight.
    self.total_parameters = sum(int(m.size) for m in self.masks)
    kept = sum(int(np.count_nonzero(m)) for m in self.masks)
    self.num_zero_weights = self.total_parameters - kept

  def fine_tune(self, epochs=20, batch_size=32):
    """Fine-tune the pruned model while holding pruned weights at zero.

    training loop adapted from keras documentation:
    https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch

    Reads the notebook-level globals `train_images`, `train_labels`,
    `val_images` and `val_labels`.

    Args:
      epochs: number of passes over the training set.
      batch_size: mini-batch size for both training and validation.

    Raises:
      RuntimeError: if `prune` has not been called yet.

    Side effects: updates the model in place, prints validation accuracy per
    epoch and the final sparsity, and refreshes `self.weights` and
    `self.num_zero_weights`.
    """
    if self.masks is None:
      raise RuntimeError("prune() must be called before fine_tune()")

    # Make sure training starts from the pruned weights even if the caller
    # has not loaded them yet (a no-op when they already match).
    self.model.set_weights(self.weights)

    # Instantiate an optimizer and a loss function.
    optimizer = keras.optimizers.Adam(learning_rate=1e-6, weight_decay=1e-8)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

    # Prepare the training and validation datasets.
    train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
    val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
    val_dataset = val_dataset.batch(batch_size)

    val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

    # Pair every variable that has pruned entries with its mask.
    # `model.weights` is in the same order as `model.get_weights()`, which is
    # the order the masks were built in.
    masked_variables = [
        (variable, tf.constant(mask, dtype=variable.dtype))
        for variable, mask in zip(self.model.weights, self.masks)
        if not mask.all()
    ]

    for epoch in range(epochs):
      for x_batch_train, y_batch_train in train_dataset:
        # Record the forward pass so the tape can differentiate the loss.
        with tf.GradientTape() as tape:
          predictions = self.model(x_batch_train, training=True)
          loss_value = loss_fn(y_batch_train, predictions)

        grads = tape.gradient(loss_value, self.model.trainable_weights)
        optimizer.apply_gradients(zip(grads, self.model.trainable_weights))

        # Re-apply the masks after every step. Otherwise the optimizer
        # (momentum, weight decay) moves pruned weights away from zero and the
        # rest of the epoch trains a network that is no longer pruned.
        for variable, mask in masked_variables:
          variable.assign(variable * mask)

      # Run a validation loop at the end of each epoch.
      for x_batch_val, y_batch_val in val_dataset:
        val_predictions = self.model(x_batch_val, training=False)
        val_acc_metric.update_state(y_batch_val, val_predictions)
      val_acc = val_acc_metric.result()
      val_acc_metric.reset_state()
      print("Epoch %d" % (epoch,), "Validation acc: %.4f" % (float(val_acc),))

    # The model already holds masked weights; read them back and count zeros.
    self.weights = self.model.get_weights()
    self.num_zero_weights = sum(
        int(w.size - np.count_nonzero(w)) for w in self.weights
    )
    print("sparsity:", self.num_zero_weights/self.total_parameters)


In [None]:
def test_pruning(prune_rate):
  """Prune at `prune_rate`, fine-tune, and score the result.

  Restores the global `model` to `original_weights`, runs a `Pruner` over it,
  then evaluates on the validation set.

  Returns:
    A `(score, weights)` tuple. `score` is the mean of validation accuracy and
    sparsity when accuracy exceeds 0.6 and `prune_rate` is positive, else 0.
    `weights` is the pruner's final weight list.
  """
  # Every experiment starts from the same trained baseline.
  model.set_weights(original_weights)

  # Prune, load the pruned weights, fine-tune, then load the tuned weights.
  pruner = Pruner(model)
  pruner.prune(prune_rate)
  model.set_weights(pruner.weights)
  pruner.fine_tune()
  model.set_weights(pruner.weights)

  # evaluate() returns [loss, accuracy] for the metrics compiled above.
  eval_results = model.evaluate(val_images, val_labels, batch_size=128)
  accuracy = eval_results[1]

  sparsity = pruner.num_zero_weights / pruner.total_parameters
  print("sparsity", sparsity)

  # Runs that miss the accuracy floor, or that pruned nothing, score zero.
  qualifies = accuracy > 0.6 and prune_rate > 0
  if not qualifies:
    return 0, pruner.weights
  return (accuracy + sparsity) / 2, pruner.weights


In [None]:
# Prune rates (percentiles) to try; each run restarts from the trained baseline.
new_finetune_tests = [5]
for rate in new_finetune_tests:
  outcome = test_pruning(rate)
  score_metric, finetune_weights = outcome
  # Leave the model holding the weights produced by this run.
  model.set_weights(finetune_weights)
  print("\n---post fine-tuning---\n")
  print("best score:", score_metric, "\nbest prune rate:", rate)

In [None]:
# Save the model's current weights to disk.
# NOTE(review): the assignment text asks for 'my_model_weights.h5', but this
# cell writes and downloads 'my_model_weights_fil.h5' - confirm which name is
# expected before submitting.
weights_file = "my_model_weights_fil.h5"
model.save_weights(weights_file)

# In Colab, running this cell immediately downloads the file saved above.
from google.colab import files
files.download(weights_file)

In [None]:
# Pareto frontier plot: validation accuracy against sparsity for filter pruning.
import matplotlib.pyplot as plt
import numpy as np

# Hard-coded measurements stored as fractions - presumably copied from
# earlier pruning runs; verify before reuse.
sparsity_fractions = [
    0.00048572098365245314,
    0.0009714419673049063,
    0.0014571629509573594,
    0.004315833323495235,
    0.0066381867765835266,
]
accuracy_fractions = [0.4317, 0.3873, 0.1921, 0.1921, 0.1921]

# Convert both axes to percentages.
x = np.array(sparsity_fractions) * 100  # X-axis points
y = np.array(accuracy_fractions) * 100

plt.plot(x, y, 'o')  # one marker per run, no connecting line
plt.title("Filter Pruning Pareto Frontier: Accuracy vs. Sparsity")
plt.xlabel("Sparsity (%)")
plt.ylabel("Accuracy (%)")
plt.grid()
plt.show()  # display