# **Embedded AI workshop**
## **Basics of Pruning**
### *Mohammad Ali Zamani*

*Senior Machine Learning Scientist at HITeC e.V.*

homepage: [zamani.ai](https://zamani.ai/)

This notebook is a simplified version of the [PyTorch pruning tutorial](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html).


Pruning Tutorial
================

**Author**: [Michela Paganini](https://github.com/mickypaganini)

State-of-the-art deep learning techniques rely on over-parametrized
models that are hard to deploy. By contrast, biological neural
networks are known to use efficient sparse connectivity. Identifying
optimal techniques to compress models by reducing the number of
parameters in them is important in order to reduce memory, battery, and
hardware consumption without sacrificing accuracy. This in turn allows
you to deploy lightweight models on device and to preserve privacy through
private on-device computation. On the research front, pruning is used to
investigate the differences in learning dynamics between
over-parametrized and under-parametrized networks, to study the role of
lucky sparse subnetworks and initializations ("[lottery
tickets](https://arxiv.org/abs/1803.03635)") as a destructive neural
architecture search technique, and more.

In this tutorial, you will learn how to use `torch.nn.utils.prune` to
sparsify your neural networks, and how to extend it to implement your
own custom pruning technique.

Requirements
------------

`"torch>=1.4.0a0+8e8a5e0"`


```python
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
```

Create a model
==============

In this tutorial, we use the
[LeNet](http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf) architecture
from LeCun et al., 1998.


```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 channels x 5x5 feature map (for a 32x32 input)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))  # flatten all dims except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```
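The `16 * 5 * 5` input size of `fc1` only works if the network receives a 32x32 single-channel image, as in the original LeNet. The sketch below assumes that input size and uses standalone copies of the layers (stand-ins for the ones inside `LeNet`) to trace the spatial size through each conv/pool stage: 32 → 28 → 14 → 10 → 5.

```python
import torch
from torch import nn
import torch.nn.functional as F

# Standalone layers with the same shapes as LeNet's conv1, conv2 and fc1
conv1 = nn.Conv2d(1, 6, 5)
conv2 = nn.Conv2d(6, 16, 5)
fc1 = nn.Linear(16 * 5 * 5, 120)

# Assumption: a batch of one grayscale 32x32 image
x = torch.randn(1, 1, 32, 32)

shapes = []
x = F.relu(conv1(x))      # 5x5 conv, no padding: 32 -> 28
shapes.append(tuple(x.shape))
x = F.max_pool2d(x, 2)    # 2x2 max pool: 28 -> 14
shapes.append(tuple(x.shape))
x = F.relu(conv2(x))      # 5x5 conv, no padding: 14 -> 10
shapes.append(tuple(x.shape))
x = F.max_pool2d(x, 2)    # 2x2 max pool: 10 -> 5
shapes.append(tuple(x.shape))

flat = x.flatten(start_dim=1)  # 16 * 5 * 5 = 400 features per image
out = fc1(flat)                # fits fc1's expected input size

for s in shapes:
    print(s)
print(tuple(flat.shape), tuple(out.shape))
```

With a 28x28 input (e.g. raw MNIST) the flattened size would be 16 * 4 * 4 = 256 instead, and `fc1` would raise a shape error.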

Inspect a Module
================

Let's inspect the (unpruned) `conv1` layer in our LeNet model. For now,
it contains two parameters, `weight` and `bias`, and no buffers.


```python
model = LeNet().to(device=device)
module = model.conv1
print("weight:")
print(module.weight)
print("bias:")
print(module.bias)
```
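To check the claim about parameters and buffers directly, you can list what the module has registered. This sketch uses a standalone `nn.Conv2d(1, 6, 5)` as a stand-in for `model.conv1` so that it runs on its own.

```python
import torch
from torch import nn

# Stand-in for model.conv1: 1 input channel, 6 output channels, 5x5 kernel
conv = nn.Conv2d(1, 6, 5)

param_names = [name for name, _ in conv.named_parameters()]
buffer_names = [name for name, _ in conv.named_buffers()]

print(param_names)                          # ['weight', 'bias']
print(buffer_names)                         # []
print(conv.weight.shape, conv.bias.shape)   # torch.Size([6, 1, 5, 5]) torch.Size([6])
```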

Pruning a Module
================

To prune a module (in this example, the `conv1` layer of our LeNet
architecture), first select a pruning technique among those available in
`torch.nn.utils.prune` (or
[implement](#extending-torch-nn-utils-pruning-with-custom-pruning-functions)
your own by subclassing `BasePruningMethod`). Then, specify the module
and the name of the parameter to prune within that module. Finally,
specify the pruning parameters using the keyword arguments required by
the selected pruning technique.
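As a minimal sketch of the subclassing route, modeled on the example in the official PyTorch tutorial, a custom method needs a `PRUNING_TYPE` and a `compute_mask` that returns a mask with the same shape as the tensor; the classmethod `apply` then attaches it to a module. The class `EveryOtherPruningMethod`, which prunes every other entry, is a hypothetical toy example, not a useful technique.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

class EveryOtherPruningMethod(prune.BasePruningMethod):
    """Hypothetical toy method: prune every other entry of the flattened tensor."""
    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        # Start from the existing mask so any earlier pruning is preserved
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

conv = nn.Conv2d(1, 6, 5)  # stand-in for model.conv1 (6 bias entries)
# apply() registers bias_orig, the bias_mask buffer and a forward pre-hook
EveryOtherPruningMethod.apply(conv, "bias")
print(conv.bias_mask)  # tensor([0., 1., 0., 1., 0., 1.])
```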

In this example, we randomly prune 30% of the connections in the
parameter named `weight` in the `conv1` layer. The module is passed as
the first argument to the function; `name` identifies the parameter
within that module by its string identifier; and `amount` indicates
either the fraction of connections to prune (if it is a float between
0.0 and 1.0) or the absolute number of connections to prune (if it is a
non-negative integer).


```python
prune.random_unstructured(module, name="weight", amount=0.3)
```

Pruning affected only `weight`: 30% of its entries (45 of the 150 in `conv1`) were set to zero at random. The `bias` was not pruned, so it remains intact.


```python
print("weight:")
print(module.weight)
print("bias:")
print(module.bias)
```
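Under the hood, `random_unstructured` does not overwrite the weight in place. As a sketch on a standalone `nn.Conv2d(1, 6, 5)` (a stand-in for `model.conv1`), you can see that the original values move to a parameter `weight_orig`, a binary buffer `weight_mask` is registered, and `weight` becomes a plain attribute equal to `weight_orig * weight_mask`, recomputed before each forward pass by a forward pre-hook. `_forward_pre_hooks` is a private attribute, inspected here only for illustration. With `amount=0.3` and 150 entries, exactly `round(0.3 * 150) = 45` entries are masked.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(1, 6, 5)  # stand-in for model.conv1 (6 * 1 * 5 * 5 = 150 weights)
prune.random_unstructured(conv, name="weight", amount=0.3)

print([n for n, _ in conv.named_parameters()])  # ['bias', 'weight_orig']
print([n for n, _ in conv.named_buffers()])     # ['weight_mask']
print(len(conv._forward_pre_hooks))             # 1 (the hook that recomputes `weight`)

# `weight` is the elementwise product of the original values and the mask
print(torch.equal(conv.weight, conv.weight_orig * conv.weight_mask))  # True
print(int((conv.weight_mask == 0).sum()))                             # 45
```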

For completeness, we can now prune the `bias` too, to see how the
parameters, buffers, hooks, and attributes of the `module` change. Just
to try out another pruning technique, here we prune the 3 entries of the
bias with the smallest absolute value (i.e. by L1 norm), as implemented
in the `l1_unstructured` pruning function.


```python
prune.l1_unstructured(module, name="bias", amount=3)
```

```python
print(module.bias)
```
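To confirm which bias entries `l1_unstructured` removed, one can compare the mask with the indices of the smallest absolute values. This sketch uses a standalone `nn.Conv2d(1, 6, 5)` as a stand-in for `model.conv1`; with the integer `amount=3`, exactly 3 of the 6 bias entries are masked, and they should be the 3 with the smallest magnitude in `bias_orig`.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(1, 6, 5)  # stand-in for model.conv1 (6 bias entries)
prune.l1_unstructured(conv, name="bias", amount=3)  # integer amount = absolute count

# Indices of the 3 smallest |b| among the original (unpruned) bias values
expected = set(
    torch.topk(conv.bias_orig.detach().abs(), k=3, largest=False).indices.tolist()
)
# Indices that the mask actually zeroed out
pruned = set((conv.bias_mask == 0).nonzero().flatten().tolist())

print(pruned == expected)  # True
print(conv.bias)           # three of the six entries are exactly zero
```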

Structured Pruning
==================

Next, we prune the `weight` of the `conv1` layer in a freshly created
model, this time using structured pruning along the 0th axis of the tensor
(the 0th axis corresponds to the output channels of the convolutional layer
and has dimensionality 6 for `conv1`), based on the channels' L2 norm.
This can be achieved using the `ln_structured` function, with `n=2` and
`dim=0`.


```python
model = LeNet()
module = model.conv1
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this zeroes out all the connections of 50% (3 out of 6)
# of the output channels, namely the ones with the smallest L2 norm.
print("after pruning:")
print(module.weight)

# TODO: prune the weights with a different amount and dimension
```
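One way to verify the structured result is to sum the absolute weights of each output channel: a channel has been fully pruned exactly when that sum is zero. A minimal sketch on a standalone `nn.Conv2d(1, 6, 5)` (a stand-in for `model.conv1`): `amount=0.5` of 6 channels gives `round(3.0) = 3` pruned channels, which should be the three with the smallest L2 norm.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(1, 6, 5)  # stand-in for model.conv1: weight shape (6, 1, 5, 5)
# L2 norm of each output channel, computed before pruning
channel_norms = conv.weight.detach().flatten(start_dim=1).norm(p=2, dim=1)

prune.ln_structured(conv, name="weight", amount=0.5, n=2, dim=0)

# A channel is pruned iff every entry in it is zero
channel_sums = conv.weight.detach().abs().sum(dim=(1, 2, 3))
zeroed = (channel_sums == 0).nonzero().flatten().tolist()
weakest = torch.topk(channel_norms, k=3, largest=False).indices.tolist()

print(sorted(zeroed))                     # indices of the 3 pruned channels
print(sorted(zeroed) == sorted(weakest))  # True
```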

Pruning multiple parameters in a model
======================================

By specifying the desired pruning technique and parameters, we can
easily prune multiple tensors in a network, perhaps according to their
type, as we will see in this example.


```python
new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of the output channels (by L2 norm) in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.ln_structured(module, name='weight', amount=0.2, n=2, dim=0)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(new_model.conv1.weight)
print(new_model.fc3.weight)
```
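Because the conv layers above are pruned by whole output channels while the linear layers are pruned entry by entry, the resulting sparsities differ per layer. The sketch below (the helper `sparsity` is hypothetical) applies the same per-type rule to standalone layers with LeNet's shapes. For `conv1`, `round(0.2 * 6) = 1` of 6 channels is removed (16.67% of its weights); for `conv2`, `round(0.2 * 16) = 3` of 16 channels (18.75%); each linear layer ends up at exactly 40%.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

def sparsity(t):
    """Hypothetical helper: percentage of entries in `t` that are exactly zero."""
    return 100.0 * float((t == 0).sum()) / t.nelement()

# Standalone layers with the same shapes as LeNet's (stand-ins for new_model.*)
layers = {
    "conv1": nn.Conv2d(1, 6, 5),
    "conv2": nn.Conv2d(6, 16, 5),
    "fc1": nn.Linear(16 * 5 * 5, 120),
    "fc2": nn.Linear(120, 84),
    "fc3": nn.Linear(84, 10),
}

# Same per-type rule as the loop above
for layer in layers.values():
    if isinstance(layer, nn.Conv2d):
        prune.ln_structured(layer, name="weight", amount=0.2, n=2, dim=0)
    elif isinstance(layer, nn.Linear):
        prune.l1_unstructured(layer, name="weight", amount=0.4)

results = {name: round(sparsity(layer.weight), 2) for name, layer in layers.items()}
for name, value in results.items():
    print(f"{name}.weight sparsity: {value:.2f}%")
```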

Global pruning
==============

So far, we only looked at what is usually referred to as "local"
pruning, i.e. the practice of pruning tensors in a model one by one, by
comparing the statistics (weight magnitude, activation, gradient, etc.)
of each entry exclusively to the other entries in that tensor. However,
a common and perhaps more powerful technique is to prune the model all
at once, by removing (for example) the lowest 20% of connections across
the whole model, instead of removing the lowest 20% of connections in
each layer. This is likely to result in different pruning percentages
per layer. Let's see how to do that using `global_unstructured` from
`torch.nn.utils.prune`.


```python
model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
```

Now we can check the sparsity induced in every pruned parameter, which
will generally not be 20% in each layer. However, the global sparsity
will be (approximately) 20%.


```python
# Per-layer sparsity: percentage of zeroed entries in each pruned weight
for layer_name in ["conv1", "conv2", "fc1", "fc2", "fc3"]:
    weight = getattr(model, layer_name).weight
    print(
        "Sparsity in {}.weight: {:.2f}%".format(
            layer_name,
            100. * float(torch.sum(weight == 0)) / float(weight.nelement()),
        )
    )

# Global sparsity: zeroed entries across all pruned weights combined
weights = [m.weight for m, _ in parameters_to_prune]
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(sum(torch.sum(w == 0) for w in weights))
        / float(sum(w.nelement() for w in weights))
    )
)
```
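Per-layer sparsities differ because `global_unstructured` with `L1Unstructured` ranks the magnitudes of all listed weights together and masks the smallest `round(amount * total)` of them, i.e. everything at or below a single global threshold. The sketch below checks this on a hypothetical toy pair of `nn.Linear` layers (not LeNet) whose weight scales differ by a factor of 10. The explicit threshold computed with `torch.kthvalue` is only an illustration; PyTorch itself selects the entries with `torch.topk`. The layer with larger weights should end up far less sparse.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
# Hypothetical toy setup: two layers whose weights live on very different scales
small = nn.Linear(20, 10)  # 200 weights
large = nn.Linear(10, 5)   # 50 weights
with torch.no_grad():
    large.weight.mul_(10.0)  # scale the second layer's weights up 10x

layers = [small, large]

# Manual global L1 threshold: the k-th smallest |w| over all pruned weights
all_abs = torch.cat([m.weight.detach().abs().flatten() for m in layers])
k = round(0.2 * all_abs.numel())  # 20% of 250 weights = 50
threshold = torch.kthvalue(all_abs, k).values

prune.global_unstructured(
    [(m, "weight") for m in layers],
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

for name, m in [("small", small), ("large", large)]:
    pruned = m.weight_mask == 0
    # Pruned entries should be exactly those at or below the global threshold
    below = m.weight_orig.detach().abs() <= threshold
    share = 100.0 * float(pruned.sum()) / pruned.numel()
    print(name, torch.equal(pruned, below), f"{share:.1f}% pruned")
```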