##### Copyright 2021 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Structural pruning M by N

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_sparsity_2_by_4"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_sparsity_2_by_4.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_sparsity_2_by_4.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/model-optimization/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_sparsity_2_by_4.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

Welcome to the guide on structural M by N pruning.

Before reading this tutorial, we recommend getting familiar with the concept of pruning and with the APIs for unstructured pruning:
*  For a general overview of pruning as a model optimization technique, see the [overview](https://www.tensorflow.org/model_optimization/guide/pruning).
*  For usage of the APIs in a single end-to-end example, see the [pruning example](https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras).

In this tutorial, you will:
* Define and train a model on the MNIST dataset with 2 by 4 structural sparsity
* Convert the pruned model to TFLite format
* Visualize the structure of the pruned weights


## Setup

You can run this section without reading it; it is not needed for finding the APIs you need or for understanding the tutorial.

In [None]:
! pip install -q tensorflow
! pip install -q tensorflow-model-optimization
! pip install -q matplotlib

In [None]:
import tensorflow as tf
from tensorflow import keras

import tensorflow_model_optimization as tfmot

# Short alias for the Keras pruning wrapper applied to individual layers below.
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

## Define the model and train it with 2 by 4 structural pruning

In [None]:
# Load the MNIST train/test splits through Keras.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Rescale raw 0-255 pixel intensities so every value lies in [0, 1].
train_images, test_images = train_images / 255.0, test_images / 255.0

Define the pruning parameters and specify the type of structural pruning to use: (2, 4).
This means that in every block of four elements, the two with the lowest magnitude are set to zero.

We don't set the `pruning_schedule` parameter. By default, the pruning mask is defined at the first step and is not updated during training.

In [None]:
# 2 by 4 structural sparsity: within each block of four elements, the two with
# the lowest magnitude are zeroed. No `pruning_schedule` is passed, so the
# default behavior described above applies.
pruning_params_2_by_4 = dict(sparsity_m_by_n=(2, 4))

Define parameters for unstructured pruning with the same target sparsity: 50%.

In [None]:
# Unstructured pruning baseline with the same 50% target sparsity, starting at
# step 0 with a pruning frequency of 100 steps.
pruning_params_unstructured = dict(
    pruning_schedule=tfmot.sparsity.keras.ConstantSparsity(
        target_sparsity=0.5, begin_step=0, frequency=100),
)

Define the model architecture and specify which layers to prune. Structural pruning is applied selectively to the model.

In the example below, we prune only some of the layers: the `Conv2D` layer with the largest number of parameters and an internal `Dense` layer.

Note that even if we marked the first `Conv2D` layer for structural pruning, it would not be structurally pruned, because it has only one input channel. Therefore, we prune the first `Conv2D` layer with unstructured pruning.


In [None]:
# The first Conv2D layer has a single input channel, so it cannot carry a 2 by 4
# pattern and is pruned with unstructured pruning instead. The second Conv2D
# layer and the hidden Dense layer are pruned with 2 by 4 structural sparsity.
# The layer names are used later to locate these weights in the TFLite file.
model = keras.Sequential([
    prune_low_magnitude(
        keras.layers.Conv2D(
            32, 5, padding='same', activation='relu',
            input_shape=(28, 28, 1),
            name="unstructured_pruning"),
        **pruning_params_unstructured),
    keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
    prune_low_magnitude(
        keras.layers.Conv2D(
            64, 5, padding='same',
            name="structural_pruning"),
        **pruning_params_2_by_4),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),
    keras.layers.Flatten(),
    prune_low_magnitude(
        keras.layers.Dense(
            1024, activation='relu',
            name="structural_pruning_dense"),
        **pruning_params_2_by_4),
    keras.layers.Dropout(0.4),
    keras.layers.Dense(10, activation='softmax')
])

# The output layer already applies softmax, so the loss receives probabilities,
# not logits. from_logits must therefore be False; True would apply softmax a
# second time inside the loss.
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])

model.summary()

Train and evaluate the model.

In [None]:
batch_size = 128
epochs = 2

# UpdatePruningStep is the callback tfmot requires when training a model that
# contains pruning wrappers. Keras documents `callbacks` as a list.
model.fit(
    train_images,
    train_labels,
    batch_size=batch_size,
    epochs=epochs,
    verbose=0,
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],
    validation_split=0.1)

_, model_for_pruning_accuracy = model.evaluate(test_images, test_labels, verbose=0)
print('Pruned test accuracy:', model_for_pruning_accuracy)

Strip the pruning wrapper.

In [None]:
# Remove the pruning wrappers, keeping the underlying layers with their pruned weights.
model = tfmot.sparsity.keras.strip_pruning(model=model)

## Convert the model to TFLite format

In [None]:
import os
import tempfile

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# mkstemp creates the file and returns an already-open OS-level descriptor.
# Write through that descriptor, and close it via the `with` block, instead of
# discarding it (which leaks it) and reopening the path.
tflite_fd, tflite_file = tempfile.mkstemp(suffix='.tflite')
with os.fdopen(tflite_fd, 'wb') as f:
  f.write(tflite_model)
print('Saved converted pruned model to:', tflite_file)

## Visualize and check the weights

Now let's visualize the weight structure of the `Dense` layer pruned with 2/4 sparsity. First, we need to extract these weights from the TFLite file.

In [None]:
# Load the TFLite file containing the pruned model and allocate its tensors.
interpreter = tf.lite.Interpreter(model_path=tflite_file)
interpreter.allocate_tensors()

details = interpreter.get_tensor_details()

# Weights of the Dense layer that was pruned with 2 by 4 sparsity.
tensor_name = 'structural_pruning_dense/MatMul'
detail = [x for x in details if tensor_name in x["name"]]
# Fail with a descriptive message instead of an IndexError if the converter
# named the tensors differently (e.g. after a TensorFlow upgrade).
if not detail:
    raise ValueError(
        f"No tensor whose name contains '{tensor_name}' was found in {tflite_file}.")

# Use the first tensor whose name matches.
tensor_data = interpreter.tensor(detail[0]["index"])()

To check that we selected the layer that has been pruned, let us check the shape of the weight tensor.

In [None]:
print("Shape of Dense layer is {}".format(tensor_data.shape))

Now we visualize the structure of a small subset of the weight tensor. The weight tensor is sparse along its last dimension with a (2, 4) pattern: two out of every four elements are zero. To make the visualization clearer, we replace all non-zero values with ones.

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

width = height = 24

subset_values_to_display = tensor_data[0:height, 0:width]

val_ones = np.ones([height, width])
val_zeros = np.zeros([height, width])
subset_values_to_display = np.where(abs(subset_values_to_display) > 0, val_ones, val_zeros)

Let us define an auxiliary function that draws separation lines, so that the block structure is clearly visible.

In [None]:
def plot_separation_lines(height, width, block_size=(1, 4)):
    """Draws white separation lines that outline sparsity blocks on the current plot.

    Meant to be drawn on the same figure as an `imshow` of a `height` x `width`
    array, so that every block of `block_size` elements is framed by lines.

    Args:
      height: Number of rows of the displayed array.
      width: Number of columns of the displayed array.
      block_size: `(rows, cols)` size of one block. Defaults to `(1, 4)`:
        four consecutive elements along a row, matching 2 by 4 sparsity.

    Raises:
      ValueError: If either block dimension is not positive.
    """
    block_rows, block_cols = block_size
    if block_rows <= 0 or block_cols <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    # Only interior block boundaries get a line; none is drawn after the last
    # row/column.
    num_hlines = int((height - 1) / block_rows)
    num_vlines = int((width - 1) / block_cols)
    line_y_pos = [y * block_rows for y in range(1, num_hlines + 1)]
    line_x_pos = [x * block_cols for x in range(1, num_vlines + 1)]

    # The -0.5 offset places each line between two adjacent cells of the image.
    for y_pos in line_y_pos:
        plt.plot([-0.5, width], [y_pos - 0.5, y_pos - 0.5], color='w')

    for x_pos in line_x_pos:
        plt.plot([x_pos - 0.5, x_pos - 0.5], [-0.5, height], color='w')

Now let us visualize the subset of the weight tensor.

In [None]:
# Draw the 1x4 block grid, then the binarized Dense weights with a colorbar.
plot_separation_lines(height, width)
plt.title("Structural pruning for Dense layer")
plt.axis('off')
plt.colorbar(plt.imshow(subset_values_to_display))
plt.show()

Let us visualize the weights of the `Conv2D` layer. The structural sparsity is applied along the last dimension, in the same way as for the `Dense` layer. As pointed out above, only the second `Conv2D` layer is structurally pruned.

In [None]:
# Get the weights of the convolutional layer that was pruned with 2 by 4 sparsity.
tensor_name = 'structural_pruning/Conv2D'
detail = [x for x in details if tensor_name in x["name"]]
# The second matching tensor is used below. Fail with a descriptive message
# rather than an IndexError if fewer tensors match than expected.
# NOTE(review): which match holds the kernel depends on the converter's tensor
# naming; confirm by inspecting the listed names after a TensorFlow upgrade.
if len(detail) < 2:
    raise ValueError(
        f"Expected at least 2 tensors whose name contains '{tensor_name}', "
        f"found {len(detail)}: {[x['name'] for x in detail]}")
tensor_data = interpreter.tensor(detail[1]["index"])()
print(f"Shape of the weight tensor is {tensor_data.shape}")

Similar to the weights of the `Dense` layer, the last dimension of the kernel has a (2, 4) structure.

In [None]:
# Collapse all leading dimensions of the kernel into rows, so the last
# dimension (the one carrying the (2, 4) structure) becomes the columns.
weights_to_display = np.reshape(tensor_data, (-1, tensor_data.shape[-1]))
# Rows are bounded by `height` and columns by `width`, matching the
# [height, width] masks below.
weights_to_display = weights_to_display[0:height, 0:width]

val_ones = np.ones([height, width])
val_zeros = np.zeros([height, width])
# Binarize: 1 for weights above the tiny threshold, 0 for pruned (zero) ones.
subset_values_to_display = np.where(abs(weights_to_display) > 1e-9, val_ones, val_zeros)

plot_separation_lines(height, width)

plt.axis('off')
plt.imshow(subset_values_to_display)
plt.colorbar()
plt.title("Structurally pruned weights for Conv2D layer")
plt.show()

Let's see how unstructured weights look. We extract them and display a subset of the weight tensor.

In [None]:
# Get the weights of the convolutional layer that was pruned with unstructured pruning.
tensor_name = 'unstructured_pruning/Conv2D'
detail = [x for x in details if tensor_name in x["name"]]
# Fail with a descriptive message instead of an IndexError if no tensor matches.
if not detail:
    raise ValueError(
        f"No tensor whose name contains '{tensor_name}' was found in the TFLite model.")
tensor_data = interpreter.tensor(detail[0]["index"])()
print(f"Shape of the weight tensor is {tensor_data.shape}")

In [None]:
# Keep the first dimension as rows and flatten all remaining dimensions into columns.
weights_to_display = np.reshape(tensor_data, (tensor_data.shape[0], -1))
# Rows are bounded by `height` and columns by `width`, matching the
# [height, width] masks below.
weights_to_display = weights_to_display[0:height, 0:width]

val_ones = np.ones([height, width])
val_zeros = np.zeros([height, width])
# Binarize: 1 for non-zero weights, 0 for pruned ones.
subset_values_to_display = np.where(abs(weights_to_display) > 0, val_ones, val_zeros)

plot_separation_lines(height, width)

plt.axis('off')
plt.imshow(subset_values_to_display)
plt.colorbar()
plt.title("Unstructured pruned weights for Conv2D layer")
plt.show()

The TensorFlow Model Optimization Toolkit includes a Python script, [`check_sparsity_m_by_n.py`](https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/core/sparsity/keras/tools/check_sparsity_m_by_n.py), that can be used to check which layers in the model from a given TFLite file have structurally pruned weights.