# Sparsifying a Pre-Trained Model for Inference

In this document we show how to prune a model using the `torch.ao.sparsity` toolkit.

Before going into details, let us define the model.

In [1]:
import os

import torch
from torch import nn
from torch.ao import sparsity

In [2]:
def make_model():
    model = nn.Sequential(
        nn.Sequential(
            nn.Linear(128, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256),
            nn.ReLU(),
        ),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, 10)
    )
    return model

model = make_model()

print(model)

Sequential(
  (0): Sequential(
    (0): Linear(in_features=128, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=256, bias=True)
    (3): ReLU()
  )
  (1): Linear(in_features=256, out_features=128, bias=True)
  (2): ReLU()
  (3): Linear(in_features=128, out_features=10, bias=True)
)


Given the model above, here are the requirements for its pruning:

- `model[0][0]`:
    - `sparsity_level = 0.7`
    - `sparse_block_shape = (4, 1)`
- `model[0][2]`:
    - `sparsity_level = 0.9`
    - `sparse_block_shape = (1, 8)`
- All other `nn.Linear` layers:
    - `sparsity_level = 0.8`
    - `sparse_block_shape = (1, 4)`
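`nn.Linear` stores its weight as `(out_features, in_features)`, so each requirement above splits a weight matrix into a grid of blocks. The pure-Python sketch below counts the blocks per layer and roughly how many must be zeroed to reach each target. The helper `block_grid` is hypothetical (it is not part of `torch.ao.sparsity`), and treating `sparse_block_shape` as `(rows, cols)` of the weight is an assumption.

```python
# Hypothetical helper: how many blocks of `block_shape` tile a weight matrix.
# Assumes block_shape is (rows, cols) of the weight; this orientation is an assumption.
def block_grid(weight_shape, block_shape):
    rows, cols = weight_shape
    br, bc = block_shape
    if rows % br or cols % bc:
        raise ValueError(f"{block_shape} does not tile {weight_shape}")
    return (rows // br) * (cols // bc)


# nn.Linear(in_features, out_features).weight has shape (out_features, in_features)
layers = {
    "0.0": ((1024, 128), (4, 1), 0.7),
    "0.2": ((256, 1024), (1, 8), 0.9),
    "1": ((128, 256), (1, 4), 0.8),
    "3": ((10, 128), (1, 4), 0.8),
}

for name, (shape, block, level) in layers.items():
    n_blocks = block_grid(shape, block)
    print(f"{name}: {n_blocks} blocks of {block}, ~{level * n_blocks:.0f} to zero")
```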

In [3]:
model = make_model()

sparse_config = [
    {'module': model[0][0], 'sparsity_level': 0.7, 'sparse_block_shape': (4, 1), 'zeros_per_block': 4},
    {'module': model[0][2], 'sparsity_level': 0.9, 'sparse_block_shape': (1, 8), 'zeros_per_block': 8},
    # The following layers will take default parameters
    model[1],
    model[3]
]

sparse_defaults = {
    'sparsity_level': 0.8,
    'sparse_block_shape': (1, 4),
    'zeros_per_block': 4
}

## Step 1. Create a sparsifier

Before we can attach a model to the sparsifier, we have to create a sparsifier with the default arguments. Notice that although the sparsifier is instantiated, it has no layers attached to it yet.

In [4]:
sparsifier = sparsity.WeightNormSparsifier(**sparse_defaults)

This creates a sparsifier and gives it the default settings it will apply to any layer that does not have its own sparsity configuration (here, `model[1]` and `model[3]`).
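Conceptually, each layer ends up with the defaults overlaid by whatever per-layer keys it specifies. The sketch below mimics that resolution in plain Python. `resolve_config` is a hypothetical helper and the layer names stand in for the module objects, so treat it as an illustration of the idea, not the sparsifier's actual code.

```python
# Hypothetical illustration: per-layer settings override the sparsifier defaults.
sparse_defaults = {"sparsity_level": 0.8, "sparse_block_shape": (1, 4), "zeros_per_block": 4}

# Strings stand in for modules; a bare entry means "use the defaults only".
sparse_config = [
    {"module": "0.0", "sparsity_level": 0.7, "sparse_block_shape": (4, 1), "zeros_per_block": 4},
    {"module": "0.2", "sparsity_level": 0.9, "sparse_block_shape": (1, 8), "zeros_per_block": 8},
    "1",
    "3",
]


def resolve_config(config, defaults):
    resolved = {}
    for entry in config:
        if not isinstance(entry, dict):
            entry = {"module": entry}
        # Start from the defaults, then let the per-layer keys win.
        overrides = {k: v for k, v in entry.items() if k != "module"}
        resolved[entry["module"]] = {**defaults, **overrides}
    return resolved


for name, settings in resolve_config(sparse_config, sparse_defaults).items():
    print(name, settings)
```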

## Step 2. Prepare the model for sparsification

Now that the sparsifier is instantiated, we need to attach a model to it.

**Note:** Once you `prepare` the model, it is modified in place by attaching weight parametrizations to it. If you need the original model, make a deep copy of it (for example, with `copy.deepcopy`) before calling `prepare`.
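A minimal sketch of that precaution: deep-copy the model before anything modifies it. To keep the example independent of the sparsifier, the in-place modification is done directly with `torch.nn.utils.parametrize.register_parametrization` and a hypothetical identity parametrization `KeepAll`, standing in for what `prepare` attaches. Only the copy-then-modify pattern is the point here.

```python
import copy

import torch
from torch import nn
from torch.nn.utils import parametrize


class KeepAll(nn.Module):
    """Hypothetical stand-in for the sparsifier's parametrization (identity)."""

    def forward(self, x):
        return x


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# Keep an untouched copy BEFORE the model is modified in place.
original_model = copy.deepcopy(model)

parametrize.register_parametrization(model[0], "weight", KeepAll())

print(parametrize.is_parametrized(model[0]))           # True
print(parametrize.is_parametrized(original_model[0]))  # False
```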

In [5]:
sparsifier.prepare(model, config=sparse_config)

After running `prepare`, both the sparsifier and the model are modified: the sparsifier holds information about the model, while the model has the mechanisms (weight parametrizations) to apply sparsity to the appropriate layers.
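The output below shows a `FakeSparsity` parametrization on the weight, exposing a `mask`. The sketch here reproduces that mechanism with the public `torch.nn.utils.parametrize` API. `MaskedWeight` is a hypothetical class and only approximates `FakeSparsity`; the assumption is that it keeps a mask buffer and multiplies it elementwise into the weight.

```python
import torch
from torch import nn
from torch.nn.utils import parametrize


class MaskedWeight(nn.Module):
    """Hypothetical approximation of FakeSparsity: returns weight * mask."""

    def __init__(self, mask):
        super().__init__()
        # A buffer is saved in the state_dict but is not a trainable parameter.
        self.register_buffer("mask", mask)

    def forward(self, weight):
        return weight * self.mask


torch.manual_seed(0)
layer = nn.Linear(128, 1024)
parametrize.register_parametrization(layer, "weight", MaskedWeight(torch.ones_like(layer.weight)))

print(type(layer).__name__)                         # ParametrizedLinear
print(layer.parametrizations.weight[0].mask.shape)  # torch.Size([1024, 128])

# Editing the mask immediately changes what `layer.weight` returns:
# here the first half of the rows is masked out.
layer.parametrizations.weight[0].mask[:512] = 0
print((layer.weight == 0).float().mean().item())
```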

In [6]:
print("===> First sparsified group:")
print(sparsifier.module_groups[0])

print()

print("===> First linear layer:")
print(model[0][0])

print()

print("===> Sparsity mask for the first linear layer:")
print(sparsifier.module_groups[0]['module'].parametrizations.weight[0].mask)

===> First sparsified group:
{'sparsity_level': 0.7, 'sparse_block_shape': (4, 1), 'zeros_per_block': 4, 'module': ParametrizedLinear(
  in_features=128, out_features=1024, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): FakeSparsity()
    )
  )
), 'fqn': '0.0'}

===> First linear layer:
ParametrizedLinear(
  in_features=128, out_features=1024, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): FakeSparsity()
    )
  )
)

===> Sparsity mask for the first linear layer:
tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])


## Step 3. Take a Sparsification Step

Now that the sparsifier is aware of the model and the model has the weight parametrizations attached to it, you can take a step that computes the sparsity masks. Although you can call the `step` method as many times as you want, in this example we call it once to show how it affects the weight tensors within the model.
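To make the idea of a step concrete, here is a simplified, self-contained sketch of norm-based block pruning: split the weight into blocks of `block_shape`, rank the blocks by their L2 norm, and zero the lowest-ranked fraction. `block_norm_mask` is hypothetical. `WeightNormSparsifier` also takes `zeros_per_block` into account, and its exact norm and rounding choices may differ, so this illustrates the technique rather than the library's implementation.

```python
import torch


def block_norm_mask(weight, block_shape, sparsity_level):
    """Hypothetical sketch: zero the `sparsity_level` fraction of blocks with the smallest L2 norm."""
    rows, cols = weight.shape
    br, bc = block_shape
    assert rows % br == 0 and cols % bc == 0, "block shape must tile the weight"

    # (rows, cols) -> (rows//br, br, cols//bc, bc), then take the norm of each block.
    blocks = weight.reshape(rows // br, br, cols // bc, bc)
    block_norms = blocks.pow(2).sum(dim=(1, 3)).sqrt()  # shape: (rows//br, cols//bc)

    num_blocks = block_norms.numel()
    num_pruned = int(sparsity_level * num_blocks)

    block_mask = torch.ones(num_blocks, dtype=weight.dtype)
    if num_pruned > 0:
        smallest = torch.topk(block_norms.flatten(), num_pruned, largest=False).indices
        block_mask[smallest] = 0
    block_mask = block_mask.reshape(block_norms.shape)

    # Expand every block entry back into a full (br, bc) tile.
    return block_mask.repeat_interleave(br, dim=0).repeat_interleave(bc, dim=1)


torch.manual_seed(0)
weight = torch.randn(1024, 128)  # same shape as model[0][0].weight
mask = block_norm_mask(weight, block_shape=(4, 1), sparsity_level=0.7)
print(f"zeroed fraction: {(mask == 0).float().mean().item():.2%}")
```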

In [7]:
# Show the level of sparsity BEFORE step:
for name, layer in model.named_modules():
    if isinstance(layer, nn.Linear):
        weight_sparsity_level = (layer.weight == 0).float().mean()
        sparsity_target = [mg['sparsity_level'] for mg in sparsifier.module_groups if mg['fqn'] == name][0]
        has_mask = hasattr(layer, 'parametrizations')
        print(f'Sparsity in layer {name} = {weight_sparsity_level:.2%} (target = {sparsity_target:.2%}, has_mask = {has_mask})')

Sparsity in layer 0.0 = 0.00% (target = 70.00%, has_mask = True)
Sparsity in layer 0.2 = 0.00% (target = 90.00%, has_mask = True)
Sparsity in layer 1 = 0.00% (target = 80.00%, has_mask = True)
Sparsity in layer 3 = 0.00% (target = 80.00%, has_mask = True)


In [8]:
# Take a step
sparsifier.step()



In [9]:
# Show the level of sparsity AFTER step:
for name, layer in model.named_modules():
    if isinstance(layer, nn.Linear):
        weight_sparsity_level = (layer.weight == 0).float().mean()
        sparsity_target = [mg['sparsity_level'] for mg in sparsifier.module_groups if mg['fqn'] == name][0]
        print(f'Sparsity in layer {name} = {weight_sparsity_level:.2%} (target = {sparsity_target:.2%})')

Sparsity in layer 0.0 = 70.00% (target = 70.00%)
Sparsity in layer 0.2 = 90.00% (target = 90.00%)
Sparsity in layer 1 = 80.00% (target = 80.00%)
Sparsity in layer 3 = 80.00% (target = 80.00%)
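The loop above checks the overall fraction of zeros, but not whether the zeros respect each layer's `sparse_block_shape`. A small check such as the hypothetical `is_block_sparse` below (one possible way to verify it, not a library function) confirms that every zero belongs to a fully zeroed block; it could be applied to `layer.weight` for each layer in that loop.

```python
import torch


def is_block_sparse(weight, block_shape):
    """Hypothetical check: True if the zeros of `weight` form whole `block_shape` tiles."""
    rows, cols = weight.shape
    br, bc = block_shape
    is_zero = (weight == 0).reshape(rows // br, br, cols // bc, bc)
    # Count the blocks whose entries are ALL zero...
    zero_blocks = is_zero.all(dim=3).all(dim=1).sum().item()
    # ...they must account for every zero in the tensor.
    return zero_blocks * br * bc == is_zero.sum().item()


torch.manual_seed(0)
w = torch.randn(8, 8)
w[0:4, 0] = 0  # one full (4, 1) block
print(is_block_sparse(w, (4, 1)))  # True

w[5, 3] = 0  # a stray zero that is not a full block
print(is_block_sparse(w, (4, 1)))  # False
```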


## Step 4. "Squash" the Mask into the Weight

Now that the masks are computed and the model is ready to be deployed for inference, we can get rid of the mask tensors by "squashing" them into the weight tensors.
This is done with the sparsifier's `.squash_mask()` method.
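Assuming the mask is implemented as a weight parametrization (as the `FakeSparsity` printout in Step 2 suggests), squashing amounts to removing the parametrization while keeping its output. The sketch below does that with `torch.nn.utils.parametrize.remove_parametrizations(..., leave_parametrized=True)`. `MaskedWeight` is a hypothetical stand-in for `FakeSparsity`, and this is not necessarily how `squash_mask` is implemented internally.

```python
import torch
from torch import nn
from torch.nn.utils import parametrize


class MaskedWeight(nn.Module):
    """Hypothetical stand-in for FakeSparsity: multiplies the weight by a mask buffer."""

    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, weight):
        return weight * self.mask


torch.manual_seed(0)
layer = nn.Linear(8, 4)
mask = torch.ones_like(layer.weight)
mask[:, :4] = 0  # prune the first half of the input columns
parametrize.register_parametrization(layer, "weight", MaskedWeight(mask))
masked_weight = layer.weight.detach().clone()

# "Squash": bake weight * mask into a plain parameter and drop the parametrization.
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)

print(type(layer).__name__)                      # Linear
print(sorted(layer.state_dict().keys()))         # ['bias', 'weight']
print(torch.equal(layer.weight, masked_weight))  # True
```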

To demonstrate the difference that squashing makes, we can save the model and check its size before and after squashing.

In [10]:
# Save and check the size
torch.save(model.state_dict(), "model_before_squash.pt")
model_size = os.stat("model_before_squash.pt").st_size
print(f'Model size BEFORE squashing: {model_size / 1_000_000:.2f}MB')


Model size BEFORE squashing: 3.43MB


In [11]:
sparsifier.squash_mask()

Notice that "squashing" multiplies the weight by the mask and then deletes the mask. If you would like to keep the mask, make a copy of it before squashing.
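A minimal sketch of keeping the masks: clone them into a `{name: mask}` dictionary, then squash. It uses a hypothetical `MaskedWeight` parametrization as a stand-in for `FakeSparsity` and squashes with `torch.nn.utils.parametrize.remove_parametrizations`; the attribute path `parametrizations.weight[0].mask` matches the one printed in Step 2.

```python
import torch
from torch import nn
from torch.nn.utils import parametrize


class MaskedWeight(nn.Module):
    """Hypothetical stand-in for FakeSparsity: multiplies the weight by a mask buffer."""

    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, weight):
        return weight * self.mask


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
for layer in (model[0], model[2]):
    mask = (torch.rand_like(layer.weight) > 0.5).float()
    parametrize.register_parametrization(layer, "weight", MaskedWeight(mask))

# Clone the masks BEFORE squashing; squashing discards them with the parametrization.
saved_masks = {
    name: module.parametrizations.weight[0].mask.clone()
    for name, module in model.named_modules()
    if parametrize.is_parametrized(module, "weight")
}

for layer in (model[0], model[2]):
    parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)

print(sorted(saved_masks))  # ['0', '2']
```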

In [12]:
# Save and check the size
torch.save(model.state_dict(), "model_after_squash.pt")
model_size = os.stat("model_after_squash.pt").st_size
print(f'Model size AFTER squashing: {model_size / 1_000_000:.2f}MB')

Model size AFTER squashing: 1.72MB
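The roughly 2x reduction comes from what is stored, not from the zeros themselves: before squashing, the state dict holds each weight plus a same-shape float mask (consistent with the float mask printed in Step 2); after squashing, only weights and biases remain, still stored densely with their zeros. The pure-Python arithmetic below reproduces the estimate, assuming 4-byte float32 values and ignoring serialization overhead, which presumably accounts for the small gap to the measured sizes.

```python
# (in_features, out_features) of each nn.Linear in make_model()
linear_shapes = [(128, 1024), (1024, 256), (256, 128), (128, 10)]
BYTES_PER_FLOAT32 = 4

weights = sum(i * o for i, o in linear_shapes)  # one mask entry per weight entry
biases = sum(o for _, o in linear_shapes)

after = (weights + biases) * BYTES_PER_FLOAT32  # weights + biases only
before = after + weights * BYTES_PER_FLOAT32    # ... plus one mask per weight

print(f"estimated BEFORE squashing: {before / 1_000_000:.2f}MB")  # 3.42MB
print(f"estimated AFTER squashing:  {after / 1_000_000:.2f}MB")   # 1.71MB
```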


In [13]:
# Show the sparsity per layer
for name, layer in model.named_modules():
    if isinstance(layer, nn.Linear):
        weight_sparsity_level = (layer.weight == 0).float().mean()
        has_mask = hasattr(layer, 'parametrizations')
        print(f'Sparsity in layer {name} = {weight_sparsity_level:.2%} (has mask = {has_mask})')

Sparsity in layer 0.0 = 70.00% (has mask = False)
Sparsity in layer 0.2 = 90.00% (has mask = False)
Sparsity in layer 1 = 80.00% (has mask = False)
Sparsity in layer 3 = 80.00% (has mask = False)
