# Quantizing a Sparse Model (QAT)

Quantization can be combined with the sparsity toolkit (`torch.ao.sparsity`) to achieve even better computational performance. There are two main ways of quantizing a sparse model:

1. Post-Training Quantization: Quantizes a model that was already trained and sparsified
2. Quantization-Aware Training: Trains the model with quantization and sparsity in mind

In this notebook we will focus on quantization-aware training.

## Quantization-Aware Training

Quantization-Aware Training (QAT) can be used in combination with sparsity. The key is to run `torch.quantization.prepare_qat` **BEFORE** `sparsifier.prepare`. The reason is that the `prepare_qat` utility replaces the layers that need to be quantized, while the sparsifier needs to keep track of the model's layers to compute the sparsity. Keeping this order ensures that the sparsifier tracks the same layers that the quantization toolflow has modified.
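
To see why the order matters, here is a minimal sketch on a toy model (not the model used in this notebook). It assumes a QAT `qconfig` is attached to the model before calling `prepare_qat`; the `"fbgemm"` backend and the layer sizes are our own choices for illustration. Under that assumption, a reference to a layer taken before `prepare_qat` no longer points at the layer that lives in the model afterwards.

```python
from torch import nn
import torch.quantization as tq

# Tiny stand-in model; the layer sizes are arbitrary.
toy = nn.Sequential(tq.QuantStub(), nn.Linear(4, 8), tq.DeQuantStub())
toy.train()  # prepare_qat expects a model in training mode

linear_before = toy[1]  # reference taken BEFORE prepare_qat

# Assumption of this sketch: a QAT qconfig is attached so the layer gets swapped.
toy.qconfig = tq.get_default_qat_qconfig("fbgemm")
tq.prepare_qat(toy, inplace=True)

linear_after = toy[1]  # the module that now lives in the model
was_replaced = linear_after is not linear_before

print(f"{type(linear_before).__module__}.{type(linear_before).__name__}")
print(f"{type(linear_after).__module__}.{type(linear_after).__name__}")
print("layer object replaced:", was_replaced)
```

A sparsifier configured with `linear_before` would be tracking a module that is no longer part of the model, which is why the sparsifier should be prepared afterwards.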

### Step 1: Create a model

In [1]:
import torch
from torch import nn
import torch.quantization as tq

in_features = 7
num_classes = 10

def make_model():
    model = nn.Sequential(
        tq.QuantStub(),
        nn.Linear(in_features, 32),
        nn.ReLU(),
        nn.Linear(32, 256),
        nn.ReLU(),
        nn.Linear(256, 32),
        nn.ReLU(),
        nn.Linear(32, num_classes),
        tq.DeQuantStub()
    )
    return model

model = make_model()
print(model)

Sequential(
  (0): QuantStub()
  (1): Linear(in_features=7, out_features=32, bias=True)
  (2): ReLU()
  (3): Linear(in_features=32, out_features=256, bias=True)
  (4): ReLU()
  (5): Linear(in_features=256, out_features=32, bias=True)
  (6): ReLU()
  (7): Linear(in_features=32, out_features=10, bias=True)
  (8): DeQuantStub()
)


### Step 2: Prepare the model for QAT

In [2]:
tq.prepare_qat(model, inplace=True)



Sequential(
  (0): QuantStub()
  (1): Linear(in_features=7, out_features=32, bias=True)
  (2): ReLU()
  (3): Linear(in_features=32, out_features=256, bias=True)
  (4): ReLU()
  (5): Linear(in_features=256, out_features=32, bias=True)
  (6): ReLU()
  (7): Linear(in_features=32, out_features=10, bias=True)
  (8): DeQuantStub()
)

### Step 3: Create a sparsifier (and a scheduler if needed)

In [3]:
from torch.ao import sparsity

sparse_config = [
    {'module': model[1], 'sparsity_level': 0.7, 'sparse_block_shape': (1, 4), 'zeros_per_block': 4},
    {'module': model[3], 'sparsity_level': 0.9, 'sparse_block_shape': (1, 4), 'zeros_per_block': 4},
    # The following layers will take default parameters
    model[5],
]

sparse_defaults = {
    'sparsity_level': 0.8,
    'sparse_block_shape': (1, 4),
    'zeros_per_block': 4
}

# Create a sparsifier and attach a model to it
sparsifier = sparsity.WeightNormSparsifier(**sparse_defaults)
sparsifier.prepare(model, config=sparse_config)

# Create a scheduler
def stepping_lambda(epoch):
    steps = [0.0, 0.5, 0.75, 1.0]
    if epoch >= len(steps):
        return 1.0
    return steps[epoch]
scheduler = sparsity.LambdaSL(sparsifier, stepping_lambda)
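
The effect of the scheduler can be sketched in plain Python. We assume here that `LambdaSL` multiplies each layer's configured sparsity level by the value returned from the lambda (analogous to how `LambdaLR` scales a learning rate). The dictionary keys and the helper `scheduled_levels` are hypothetical names used only for this illustration.

```python
def stepping_lambda(epoch):
    # Same schedule as in the cell above
    steps = [0.0, 0.5, 0.75, 1.0]
    if epoch >= len(steps):
        return 1.0
    return steps[epoch]

# Configured (target) sparsity levels from the config above;
# model[5] falls back to the default of 0.8.
base_levels = {"model[1]": 0.7, "model[3]": 0.9, "model[5]": 0.8}

def scheduled_levels(epoch):
    """Hypothetical helper: scale every target level by the lambda's value."""
    factor = stepping_lambda(epoch)
    return {name: base * factor for name, base in base_levels.items()}

for epoch in range(5):
    print(epoch, scheduled_levels(epoch))
```

Under this assumption, the layers start fully dense and reach their configured sparsity levels from the fourth scheduler step onward.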

In [4]:
# Notice the model has both the parametrizations and FakeSparsity
model

Sequential(
  (0): QuantStub()
  (1): ParametrizedLinear(
    in_features=7, out_features=32, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): FakeSparsity()
      )
    )
  )
  (2): ReLU()
  (3): ParametrizedLinear(
    in_features=32, out_features=256, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): FakeSparsity()
      )
    )
  )
  (4): ReLU()
  (5): ParametrizedLinear(
    in_features=256, out_features=32, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): FakeSparsity()
      )
    )
  )
  (6): ReLU()
  (7): Linear(in_features=32, out_features=10, bias=True)
  (8): DeQuantStub()
)
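
The `ParametrizedLinear` layers above come from PyTorch's parametrization mechanism. The following is a minimal sketch of the idea, using a hypothetical `MaskParametrization` class as a simplified stand-in for `FakeSparsity` (it is not the toolkit's implementation): the original dense weight is kept, and a mask is applied every time the weight is read.

```python
import torch
from torch import nn
from torch.nn.utils import parametrize

class MaskParametrization(nn.Module):
    """Hypothetical stand-in for FakeSparsity: multiplies the weight by a mask."""
    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, weight):
        return self.mask * weight

layer = nn.Linear(4, 2)
mask = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                     [0.0, 1.0, 0.0, 1.0]])
parametrize.register_parametrization(layer, "weight", MaskParametrization(mask))

masked_weight = layer.weight                               # mask applied on access
original_weight = layer.parametrizations.weight.original   # dense weight is preserved

print(type(layer).__name__)
print(masked_weight)
```

Because the dense weight is still stored, the mask can change between steps without losing information.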

### Step 4: Train the model

In [5]:
model.train()
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

for epoch in range(20):
    optimizer.zero_grad()
    
    x = torch.randn(128, in_features)
    y = torch.randint(0, num_classes, size=(128,))
    
    y_hat = model(x)
    loss = criterion(y_hat, y)
    loss.backward()
    
    optimizer.step()
    sparsifier.step()
    scheduler.step()
    
    print(f'Running epoch {epoch + 1:>2} / 20... Loss: {loss.item():.2f}')

Running epoch  1 / 20... Loss: 2.31
Running epoch  2 / 20... Loss: 2.31
Running epoch  3 / 20... Loss: 2.31
Running epoch  4 / 20... Loss: 2.30
Running epoch  5 / 20... Loss: 2.30
Running epoch  6 / 20... Loss: 2.31
Running epoch  7 / 20... Loss: 2.31
Running epoch  8 / 20... Loss: 2.30
Running epoch  9 / 20... Loss: 2.30
Running epoch 10 / 20... Loss: 2.30
Running epoch 11 / 20... Loss: 2.30
Running epoch 12 / 20... Loss: 2.30
Running epoch 13 / 20... Loss: 2.30
Running epoch 14 / 20... Loss: 2.31
Running epoch 15 / 20... Loss: 2.31
Running epoch 16 / 20... Loss: 2.30
Running epoch 17 / 20... Loss: 2.30
Running epoch 18 / 20... Loss: 2.30
Running epoch 19 / 20... Loss: 2.30
Running epoch 20 / 20... Loss: 2.30


### Step 5: Convert the model

The model has now finished training, so we can squash the masks before converting it.

In [6]:
sparsifier.squash_mask()
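
Conceptually, squashing a mask bakes the zeros into the weight and removes the parametrization. The sketch below illustrates this with `torch.nn.utils.parametrize` directly; `MaskParametrization` is a hypothetical stand-in for `FakeSparsity`, and we assume `squash_mask` behaves like `remove_parametrizations(..., leave_parametrized=True)`.

```python
import torch
from torch import nn
from torch.nn.utils import parametrize

class MaskParametrization(nn.Module):
    """Hypothetical stand-in for FakeSparsity: multiplies the weight by a mask."""
    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, weight):
        return self.mask * weight

layer = nn.Linear(8, 4)
mask = torch.ones(4, 8)
mask[:, 4:] = 0.0  # zero out the second half of every row
parametrize.register_parametrization(layer, "weight", MaskParametrization(mask))

# "Squash": keep the masked values as the new weight and drop the parametrization
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)

still_parametrized = parametrize.is_parametrized(layer)
zero_fraction = (layer.weight == 0).float().mean().item()
print(type(layer).__name__, still_parametrized, zero_fraction)
```

After this step the layer is a regular `nn.Linear` again, which is the form the conversion utilities expect.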

In [7]:
import torch.quantization as tq
import torch.ao.nn.sparse.quantized as ao_qnn
from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern

# Create a custom mapping that sends `nn.Linear` to the sparse quantized `Linear`.
# You can also use a dynamic mapping here that maps to `torch.ao.nn.sparse.quantized.dynamic.Linear`.
sparse_mapping = tq.get_default_static_quant_module_mappings()
sparse_mapping[nn.Linear] = ao_qnn.Linear

# Convert the model, using a 1x4 block sparsity pattern
with LinearBlockSparsePattern(1, 4):
    tq.convert(model, inplace=True, mapping=sparse_mapping)
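
The custom mapping is just a dictionary from float module types to their quantized replacements. The short sketch below only inspects the mapping and does not convert anything. It relies on the getter returning a fresh copy on each call, which is our understanding of recent PyTorch versions.

```python
from torch import nn
import torch.quantization as tq
import torch.ao.nn.sparse.quantized as ao_qnn

# What nn.Linear maps to by default
default_mapping = tq.get_default_static_quant_module_mappings()
default_target = default_mapping[nn.Linear]

# Our customized copy, with nn.Linear redirected to the sparse quantized Linear
sparse_mapping = tq.get_default_static_quant_module_mappings()
sparse_mapping[nn.Linear] = ao_qnn.Linear

# A newly fetched mapping is unaffected by the edit above
fresh_mapping = tq.get_default_static_quant_module_mappings()

print("default target:", default_target)
print("custom target: ", sparse_mapping[nn.Linear])
print("defaults unchanged:", fresh_mapping[nn.Linear] is default_target)
```

Only `nn.Linear` is redirected; every other entry keeps its default quantized counterpart.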

After conversion, the model should be quantized and use sparse quantized kernels:

In [8]:
print(model)

Sequential(
  (0): QuantStub()
  (1): Linear(in_features=7, out_features=32, bias=True)
  (2): ReLU()
  (3): Linear(in_features=32, out_features=256, bias=True)
  (4): ReLU()
  (5): Linear(in_features=256, out_features=32, bias=True)
  (6): ReLU()
  (7): Linear(in_features=32, out_features=10, bias=True)
  (8): DeQuantStub()
)


TODO: There is a bug in the conversion step: the printed model above still shows regular `Linear` layers.