# Lab 3 & 4 - DNN optimization: quantization & pruning in PyTorch

In this lab, we will focus on two main techniques for optimizing deep neural networks: quantization and pruning. We will use the `torchao` library to implement these techniques and evaluate their impact on model accuracy, latency and size. Both techniques are crucial for deploying models on resource-constrained devices, such as mobile phones and embedded systems (especially when targeting FPGAs).

In [None]:
import copy
import torch

class ToyLinearModel(torch.nn.Module):
    """Two stacked bias-free linear layers mapping ``m -> n -> k`` features.

    There is no nonlinearity between the layers.

    Args:
        m: Number of input features.
        n: Number of hidden features.
        k: Number of output features.
    """

    def __init__(self, m: int, n: int, k: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def forward(self, x):
        """Apply ``linear1`` and then ``linear2`` to ``x``."""
        return self.linear2(self.linear1(x))

    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
        """Return a 1-tuple with a random ``(batch_size, m)`` input tensor.

        The tuple form allows ``model(*model.example_inputs())``.
        """
        in_features = self.linear1.in_features
        sample = torch.randn(batch_size, in_features, dtype=dtype, device=device)
        return (sample,)

# Build the fp16 model in eval mode. ``model`` gets quantized in place further
# down, so keep an untouched deep copy (``model_fp16``) as the fp16 baseline for
# the size and latency comparisons.
model = ToyLinearModel(1024, 1024, 1024).eval().to(torch.float16)
model_fp16 = copy.deepcopy(model)

In [None]:
# TODO profile initial model (hint: use torch.profiler we saw in the previous lab)
# You may also define it as a function to be able to reuse it later

## Static Quantization

We can quantize a pre-trained model directly, without retraining, and we can quantize both its weights and its activations. Weights are easy to quantize statically because their values are known ahead of time. Activations are trickier because their range depends on the input. To quantize activations accurately, we almost always need calibration: running the model on a representative dataset to determine the quantization parameters (scale and zero point) for each layer. The first example below quantizes only the weights (`Int8WeightOnlyConfig`); activation quantization with calibration follows later.

In [None]:
from torchao.quantization import Int8WeightOnlyConfig, quantize_
# Weight-only int8 quantization, as the config name indicates: only the
# weights are quantized; activations are left in floating point.
# THIS IS AN INPLACE OPERATION! ``model`` itself is modified, which is why an
# untouched copy was kept in ``model_fp16`` beforehand.
quantize_(model, Int8WeightOnlyConfig())

In [None]:
import os
# Serialize both models and compare their on-disk sizes.
# Use the platform's temporary directory instead of a hard-coded "/tmp" so the
# cell also works on Windows.
# NOTE: torch.save on a whole module pickles the class; reloading it with
# torch.load requires weights_only=False on recent PyTorch versions.
import tempfile

_tmp_dir = tempfile.gettempdir()
int8_model_path = os.path.join(_tmp_dir, "int8_model.pt")
float16_model_path = os.path.join(_tmp_dir, "float16_model.pt")
torch.save(model, int8_model_path)
torch.save(model_fp16, float16_model_path)
int8_model_size_mb = os.path.getsize(int8_model_path) / 1024 / 1024
float16_model_size_mb = os.path.getsize(float16_model_path) / 1024 / 1024

print("int8 model size: %.2f MB" % int8_model_size_mb)

print("float16 model size: %.2f MB" % float16_model_size_mb)


In [None]:
from torchao.utils import (
    benchmark_model,
)

num_runs = 100
num_warmup_runs = 3
torch._dynamo.reset()
example_inputs = (torch.randn(8, 1024, dtype=torch.float16, device="cpu"),)

# Benchmark pure inference. Without no_grad, model_fp16 (whose nn.Linear
# parameters still have requires_grad=True) records an autograd graph on every
# forward call. That overhead would be charged to the fp16 model only and skew
# the fp16-vs-int8 comparison.
with torch.no_grad():
    for _ in range(num_warmup_runs):  # warmup
        model(*example_inputs)
        model_fp16(*example_inputs)
    fp16_time = benchmark_model(model_fp16, num_runs, example_inputs)
    int8_time = benchmark_model(model, num_runs, example_inputs)

# NOTE(review): the "ms" label assumes benchmark_model reports milliseconds per
# run; some torchao versions return seconds on CPU -- verify for your version.
# The speedup ratio is unit-independent either way.
print("fp16 mean time: %0.3f ms" % fp16_time)
print("int8 mean time: %0.3f ms" % int8_time)
print("speedup: %0.1fx" % (fp16_time / int8_time))

In [None]:
# TODO profile quantized model

## Calibration 

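To quantize activations as well as weights while keeping the model accurate, we will perform calibration. Observers record the value ranges of weights and activations while representative data flows through the model. These ranges are then used to compute the quantization parameters (scale and zero point).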

### Some utilities for post-training static quantization with calibration


In [None]:
import copy
import torch.nn.functional as F

from torch import Tensor
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

class ObservedLinear(torch.nn.Linear):
    """``nn.Linear`` that passes its input and weight through observer modules.

    On every forward call the layer input goes through ``act_obs`` and the
    weight goes through ``weight_obs``. The output is computed from whatever
    the observers return. This lets observers collect statistics for later
    quantization (e.g. via ``calculate_qparams``).

    Args:
        in_features, out_features, bias, device, dtype: same as ``nn.Linear``.
        act_obs: Module applied to the layer input before the matmul.
        weight_obs: Module applied to ``self.weight`` before the matmul.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        act_obs: torch.nn.Module,
        weight_obs: torch.nn.Module,
        bias: bool = True,
        device=None,
        dtype=None,
    ):
        super().__init__(in_features, out_features, bias, device, dtype)
        # Assigning modules as attributes registers them as submodules.
        self.act_obs = act_obs
        self.weight_obs = weight_obs

    def forward(self, input: Tensor):
        x_seen = self.act_obs(input)
        w_seen = self.weight_obs(self.weight)
        return F.linear(x_seen, w_seen, self.bias)

    @classmethod
    def from_float(cls, float_linear, act_obs, weight_obs):
        """Create an ObservedLinear that shares ``float_linear``'s parameters.

        The module is constructed without a bias. It is then given the very
        same ``weight`` and ``bias`` objects as ``float_linear``, so no
        parameter data is copied.
        """
        src_weight = float_linear.weight
        wrapped = cls(
            float_linear.in_features,
            float_linear.out_features,
            act_obs,
            weight_obs,
            False,
            device=src_weight.device,
            dtype=src_weight.dtype,
        )
        wrapped.weight, wrapped.bias = src_weight, float_linear.bias
        return wrapped
    


def insert_observers_(model, act_obs, weight_obs):
    """Replace the ``nn.Linear`` layers of ``model`` (in place) with ObservedLinear.

    Each replaced layer receives its own deep copies of ``act_obs`` and
    ``weight_obs``, so statistics are tracked per layer rather than shared.
    Layers that are already ObservedLinear are left untouched. Calling this
    function twice therefore does not throw away statistics that were already
    collected.

    Args:
        model: Module to modify in place.
        act_obs: Template observer for layer inputs (deep-copied per layer).
        weight_obs: Template observer for layer weights (deep-copied per layer).
    """

    def _is_linear(m, fqn):
        # ObservedLinear subclasses nn.Linear, so it must be excluded explicitly;
        # otherwise its existing observers would be replaced by fresh copies.
        return isinstance(m, torch.nn.Linear) and not isinstance(m, ObservedLinear)

    def replacement_fn(m):
        copied_act_obs = copy.deepcopy(act_obs)
        copied_weight_obs = copy.deepcopy(weight_obs)
        return ObservedLinear.from_float(m, copied_act_obs, copied_weight_obs)

    _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)


def is_observed_linear(m, fqn):
    """Module filter for ``quantize_``: True only for ObservedLinear modules."""
    return isinstance(m, ObservedLinear)

In [None]:
from torchao.quantization.quant_primitives import (
    MappingType,
)
from torchao.quantization.observer import (
    AffineQuantizedMinMaxObserver,
)
from torchao.quantization.granularity import (
    PerAxis,
    PerTensor,
)

# Asymmetric int8 affine quantization. The activation and weight observers
# share every setting except their granularity.
mapping_type = MappingType.ASYMMETRIC
target_dtype = torch.int8

_observer_kwargs = dict(
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.float32,
)

# Activations: a single (scale, zero_point) pair for the whole tensor.
act_obs = AffineQuantizedMinMaxObserver(
    mapping_type, target_dtype, granularity=PerTensor(), **_observer_kwargs
)
# Weights: one (scale, zero_point) pair per slice along axis 0, i.e. per output
# row of the (out_features, in_features) nn.Linear weight.
weight_obs = AffineQuantizedMinMaxObserver(
    mapping_type, target_dtype, granularity=PerAxis(axis=0), **_observer_kwargs
)

# Demonstrates observer insertion: model_fp16's Linear layers become
# ObservedLinear. The calibration cell later rebuilds model_fp16 from scratch.
insert_observers_(model_fp16, act_obs, weight_obs)


In [None]:
from dataclasses import dataclass
from torchao.quantization.transform_module import register_quantize_module_handler
from torchao.core.config import AOBaseConfig
from torchao.dtypes import to_affine_quantized_intx_static
from torchao.quantization import to_linear_activation_quantized

@dataclass
class StaticQuantConfig(AOBaseConfig):
    """Config for static (calibrated) quantization of ObservedLinear modules.

    Passed to ``quantize_``. The handler registered for this config converts
    each observed linear into a plain ``nn.Linear`` with quantized weights and
    statically quantized input activations.
    """

    # Integer dtype used for both weights and activations; the registered
    # handler only accepts torch.int8 and raises ValueError otherwise.
    target_dtype: torch.dtype


# Convert an observed linear module into a plain nn.Linear with quantized
# weights (and quantized activations) using tensor subclasses.
@register_quantize_module_handler(StaticQuantConfig)
def _apply_static_quant_transform(
    module: torch.nn.Module,
    config: StaticQuantConfig,
):
    """Replace a calibrated ObservedLinear with a statically quantized Linear.

    The weight is quantized with the qparams from ``module.weight_obs``, using
    one block per output row. It is then wrapped so that every input activation
    is quantized with the per-tensor qparams from ``module.act_obs``.

    Args:
        module: An ObservedLinear that has already seen calibration data.
        config: Supplies ``target_dtype``; only ``torch.int8`` is supported.

    Returns:
        A new ``nn.Linear`` that shares ``module``'s bias and holds the
        quantized weight.

    Raises:
        ValueError: If ``config.target_dtype`` is not supported.
    """
    target_dtype = config.target_dtype
    # Validate once, before touching the observers or the weight.
    if target_dtype != torch.int8:
        raise ValueError(f"Unsupported target dtype {target_dtype}")
    observed_linear = module
    weight = observed_linear.weight

    # Weight quantization: the block size (1, in_features) gives one
    # scale/zero-point per output row, matching the per-axis weight observer.
    weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
    quantized_weight = torch.nn.Parameter(
        to_affine_quantized_intx_static(
            weight, weight_scale, weight_zero_point, (1, weight.shape[1]), target_dtype
        ),
        requires_grad=False,
    )

    # Activation quantization: one block covering the whole input tensor,
    # matching the per-tensor activation observer.
    act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()

    def input_quant_func(x):
        return to_affine_quantized_intx_static(
            x, act_scale, act_zero_point, x.shape, target_dtype
        )

    linear = torch.nn.Linear(
        observed_linear.in_features,
        observed_linear.out_features,
        False,
        device=weight.device,
        dtype=weight.dtype,
    )
    linear.bias = observed_linear.bias
    linear.weight = torch.nn.Parameter(
        to_linear_activation_quantized(quantized_weight, input_quant_func),
        requires_grad=False,
    )

    return linear

#### Important! Reinitialize the model before running the calibrated quantization: `model` was already quantized in place, and `model_fp16` had observers inserted in the previous steps!

In [None]:
from torchao.quantization.utils import compute_error
torch.manual_seed(0)
model_fp16 = ToyLinearModel(1024, 1024, 1024).eval().to(torch.float16)
model_int8 = copy.deepcopy(model_fp16)
example_inputs = model_fp16.example_inputs(batch_size=32, device="cpu", dtype=torch.float16)
num_calibration_batches = 10

with torch.no_grad():
    outputs_fp16 = model_fp16(*example_inputs)

insert_observers_(model_int8, act_obs, weight_obs)

# Run calibration data through the model to collect statistics. Use freshly
# sampled batches instead of re-feeding ``example_inputs``. Repeating one batch
# adds no information to min/max observers, and calibrating on the evaluation
# batch itself would make the comparison below look better than on unseen data.
with torch.no_grad():
    for _ in range(num_calibration_batches):
        calibration_inputs = model_fp16.example_inputs(
            batch_size=32, device="cpu", dtype=torch.float16
        )
        model_int8(*calibration_inputs)

quantize_(model_int8, StaticQuantConfig(target_dtype=target_dtype), is_observed_linear)

with torch.no_grad():
    outputs_int8 = model_int8(*example_inputs)

# torchao's compute_error returns the signal-to-quantization-noise ratio (SQNR)
# in dB, so HIGHER means the int8 output is closer to the fp16 reference.
sqnr_db = float(compute_error(outputs_fp16, outputs_int8))
print("SQNR post vs pre quantization (FP16 vs INT8): %.2f dB (higher is better)" % sqnr_db)

## TODO
1. Based on the previous lab, modify our toy model to perform classification on the MNIST dataset (10 classes).
2. Train the model for one epoch on MNIST dataset
3. Evaluate the model on the test set and compute accuracy
4. Perform static quantization with calibration
5. Evaluate the model after quantization
6. Compare the accuracies, latencies and model sizes before and after quantization
7. You may try to work with compilation too to see if it helps with execution speed

In [12]:
# TODO your code here

## Extra TODO
Look at an example of how to export a model into a format that can run directly on embedded devices (e.g. microcontrollers).
Take a look at [this tutorial](https://docs.pytorch.org/executorch/stable/using-executorch-export.html) to see how to export a model to the `.pte` format using the `executorch` library. Then, try to save our model that way.