# Customizing the Sparsity Workflow

In this section we show how to write your own sparsifier and sparsity scheduler. Because this is independent of model training and inference, we do not build a full model here; a single `nn.Linear` layer serves as a stand-in.

## Custom Sparsifier

Suppose you need a custom sparsifier that sets individual elements of a weight tensor to zero if they are too far from the mean of that tensor (either too large or too small).

$$
w^\star_{ij} =
\begin{cases}
w_{ij} & \text{if}~ w_{ij} \stackrel{\rightarrow}{\approx} \mathbb{E}(w) \\
0 & \text{otherwise}
\end{cases}
$$

In practice, we can rank all elements by their distance from the mean and remove the furthest ones.
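The ranking idea can be sketched without PyTorch. The snippet below is a minimal illustration in plain Python, not part of the sparsity API: the helper name `closest_to_mean_mask` and the sample values are hypothetical, and it assumes the number of zeroed elements is the rounded product of the sparsity level and the element count.

```python
from statistics import fmean


def closest_to_mean_mask(values, sparsity_level):
    """Return a 0/1 mask that zeroes the entries furthest from the mean."""
    mean = fmean(values)
    # Rank indices from furthest to closest to the mean.
    ranked = sorted(
        range(len(values)),
        key=lambda i: abs(values[i] - mean),
        reverse=True,
    )
    # Number of elements to zero out (assumed rounding rule).
    num_zeroed = int(round(sparsity_level * len(values)))
    mask = [1] * len(values)
    for i in ranked[:num_zeroed]:
        mask[i] = 0
    return mask


values = [0.1, -2.0, 0.3, 5.0, 0.2]  # mean is 0.72
mask = closest_to_mean_mask(values, sparsity_level=0.4)
print(mask)  # the two outliers (-2.0 and 5.0) are masked out
```

Here the two values furthest from the mean are dropped, while the three values clustered around it are kept.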

To implement a custom sparsifier, you need to implement two methods:

- `__init__(self, **kwargs)` -- the constructor, which specifies the arguments required by your logic. It has to call `super().__init__(defaults=...)`. This ensures that the default arguments are propagated to any parts of the model that don't have their own sparsity configuration.
- `update_mask(self, layer, **kwargs)` -- the method that holds the main logic for changing a single layer. Generally, you would want it to accept the same keyword arguments as the constructor. However, this is optional, as all the default configurations are passed through.

While writing `update_mask`, you can access the mask that needs to be modified through `layer.parametrizations.weight[0].mask`. If you also need the original (non-sparsified) weight to compute the next mask, you can access it through `layer.parametrizations.weight.original`.

In [1]:
import torch
from torch.ao.sparsity import BaseSparsifier

class ClosestToMeanSparsifier(BaseSparsifier):
    def __init__(self, sparsity_level):
        defaults = {
            'sparsity_level': sparsity_level
        }
        super().__init__(defaults=defaults)
        
    def update_mask(self, layer, sparsity_level, **kwargs):
        # Step 1: get the weight and the mask from the parametrizations
        mask = layer.parametrizations.weight[0].mask
        weight = layer.parametrizations.weight.original
        # Step 2: implement the mask update logic
        ## Step 2a: Compute the mean and the distance
        weight_flat = weight.flatten()
        weight_mean = weight_flat.mean()
        weight_distance_to_mean = (weight_flat - weight_mean).abs()
        ## Step 2b: Rank the elements, furthest from the mean first
        _, sorted_idx = torch.sort(weight_distance_to_mean, descending=True)
        threshold_idx = int(round(sparsity_level * len(sorted_idx)))
        sorted_idx = sorted_idx[:threshold_idx]
        ## Step 2c: Create a mask that zeroes the furthest elements
        new_mask = torch.ones_like(mask)
        new_mask = new_mask.flatten()
        new_mask[sorted_idx] = 0
        new_mask = new_mask.reshape(mask.shape)
        # Step 3: Reassign back to the mask
        mask.data = new_mask

In [2]:
from torch.ao import sparsity
from torch import nn


model = nn.Sequential(nn.Linear(128, 128))

sparsifier = ClosestToMeanSparsifier(sparsity_level=0.8)
sparsifier.prepare(model, config=None)
sparsifier.step()

In [3]:
sparsity_level = (model[0].weight == 0).float().mean()
print(f"Level of sparsity: {sparsity_level.item():.2%}")

Level of sparsity: 80.00%


In [4]:
# Don't forget to squash the mask
sparsifier.squash_mask()

In [5]:
model

Sequential(
  (0): Linear(in_features=128, out_features=128, bias=True)
)

## Custom Scheduler

Suppose you need a custom scheduler that turns sparsity on after some number of epochs, resets it to 0 after some more, and finally turns it on again. Such a scheduler has 3 epoch points at which the sparsity level switches between 0.0 and the target sparsity.

$$
\text{current sparsity} =
\begin{cases}
0 & \text{if epoch} <e_0 \\
\text{target sparsity} & \text{if} ~e_0 \le \text{epoch} < e_1\\
0 & \text{if} ~e_1 \le \text{epoch} < e_2 \\
\text{target sparsity} & \text{otherwise}
\end{cases}
$$
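The piecewise rule can be checked in isolation before wiring it into a scheduler. The function below is a minimal sketch in plain Python; the name `on_off_sparsity` is hypothetical, and it assumes each interval includes its lower boundary, matching the `<=` comparisons used in `get_sl` in this section.

```python
def on_off_sparsity(epoch, epoch_points, target_sparsity):
    """Return the sparsity level for `epoch` under the on/off schedule."""
    e0, e1, e2 = epoch_points
    if epoch < e0:
        return 0.0  # before the first switch: dense
    if epoch < e1:
        return target_sparsity  # first sparse phase
    if epoch < e2:
        return 0.0  # sparsity switched off again
    return target_sparsity  # final sparse phase


# Evaluate the schedule for 12 epochs with switch points 3, 6 and 9.
levels = [on_off_sparsity(epoch, (3, 6, 9), 0.8) for epoch in range(12)]
print(levels)
```

With these switch points the schedule is dense for epochs 0 to 2, sparse for 3 to 5, dense for 6 to 8, and sparse from epoch 9 onward.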

To implement a custom scheduler, you only need to implement the method `get_sl`, which returns the list of sparsity levels. Inside the scheduler, you have access to the following members:

- `self.last_epoch` -- the last epoch at which the scheduler was called
- `self.base_sl` -- list of all the target sparsity levels
- `self.get_last_sl()` -- method that returns the last sparsity levels that were set

In [6]:
from torch.ao.sparsity import BaseScheduler

class OnOffSchedulerSL(BaseScheduler):
    def __init__(self, sparsifier, epoch_points, **kwargs):
        self.epoch_points = epoch_points
        super().__init__(sparsifier, **kwargs)
        
    def get_sl(self):
        if self.last_epoch < self.epoch_points[0]:
            return [0.0] * len(self.base_sl)
        elif self.epoch_points[0] <= self.last_epoch < self.epoch_points[1]:
            return self.base_sl
        elif self.epoch_points[1] <= self.last_epoch < self.epoch_points[2]:
            return [0.0] * len(self.base_sl)
        else:
            return self.base_sl

In [7]:
model = nn.Sequential(nn.Linear(128, 128))

sparsifier = ClosestToMeanSparsifier(sparsity_level=0.8)
sparsifier.prepare(model, config=None)
scheduler = OnOffSchedulerSL(sparsifier, epoch_points=[3, 6, 9])

In [8]:
for idx in range(15):
    sparsifier.step()
    scheduler.step()
    sparsity_level = (model[0].weight == 0).float().mean()
    print(f"Level of sparsity @ epoch {idx:>2}: {sparsity_level.item():0.2%}")

Level of sparsity @ epoch  0: 0.00%
Level of sparsity @ epoch  1: 0.00%
Level of sparsity @ epoch  2: 0.00%
Level of sparsity @ epoch  3: 80.00%
Level of sparsity @ epoch  4: 80.00%
Level of sparsity @ epoch  5: 80.00%
Level of sparsity @ epoch  6: 0.00%
Level of sparsity @ epoch  7: 0.00%
Level of sparsity @ epoch  8: 0.00%
Level of sparsity @ epoch  9: 80.00%
Level of sparsity @ epoch 10: 80.00%
Level of sparsity @ epoch 11: 80.00%
Level of sparsity @ epoch 12: 80.00%
Level of sparsity @ epoch 13: 80.00%
Level of sparsity @ epoch 14: 80.00%
