# Quantization using Brevitas
Brevitas supports different user needs and goals when it comes to defining quantized models. It offers two main workflows:

1. Manual approach: Users can build quantized models directly by using brevitas.nn quantized layers, either from scratch or by adapting an existing PyTorch floating-point model.

2. Automated approach: Users can start with a floating-point model and automatically generate a quantized version based on custom-defined rules.

Once the quantized model is created using either method, it can be used in the following ways:

- Post-Training Quantization (PTQ): Apply quantization to a pretrained floating-point model without additional training.

- Quantization Aware Training (QAT): Train the quantized model from the beginning or fine-tune it from a pretrained floating-point model.

- PTQ followed by QAT: Start with post-training quantization, then fine-tune with QAT to recover accuracy that PTQ alone may lose.

## 1. Defining a quantized model with brevitas.nn layers
### Weights-only quantization, float activations and biases
Suppose we want to evaluate how well a model performs on CIFAR-10 classification using 4-bit weights. In this tutorial, we’ll skip the details of training itself, since training a model with Brevitas follows the same process as training any standard PyTorch model.

To define a quantized model, brevitas.nn offers quantized layers that can replace or be combined with regular torch.nn layers. Specifically, we use brevitas.nn.QuantConv2d and brevitas.nn.QuantLinear instead of their PyTorch counterparts, setting weight_bit_width=4 to specify 4-bit quantization. For activation functions like ReLU and operations like max-pooling, we continue using torch.nn.ReLU and torch.nn.functional.max_pool2d as usual.
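As a rough illustration of what a 4-bit weight quantizer does to a tensor, the sketch below applies symmetric per-tensor quantization in plain Python. It is not Brevitas code: the function name `quantize_weights`, the example weights, and the choice of a narrow signed range [-7, 7] with a scale taken from the largest absolute weight are assumptions made for illustration.

```python
def quantize_weights(weights, bit_width=4):
    """Symmetric per-tensor fake quantization (illustrative, not Brevitas code)."""
    # Assumption: narrow signed range, e.g. [-7, 7] for 4 bits.
    q_max = 2 ** (bit_width - 1) - 1
    # Assumption: a single scale for the whole tensor, from the largest magnitude.
    abs_max = max(abs(w) for w in weights)
    scale = abs_max / q_max if abs_max > 0 else 1.0
    # Round to the nearest integer level and clamp to the representable range.
    int_weights = [max(-q_max, min(q_max, round(w / scale))) for w in weights]
    # Dequantize: these are the float values the layer actually computes with.
    dequant = [q * scale for q in int_weights]
    return int_weights, scale, dequant


weights = [0.70, -0.33, 0.12, -0.02, 0.48]  # hypothetical float weights
int_weights, scale, dequant = quantize_weights(weights)
print(int_weights)  # integer levels, each within [-7, 7]
print(dequant)      # float values restricted to multiples of the scale
```

The integer levels are what a deployment backend would store, while the dequantized floats are what the rest of a PyTorch model sees during training.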

In [1]:
!pip install brevitas



In [9]:
from torch import nn
from torch.nn import Module
import torch.nn.functional as F

import brevitas.nn as qnn


class QuantWeightLeNet(Module):
    ''' LeNet-5 with weight quantization. '''
    def __init__(self):
        super(QuantWeightLeNet, self).__init__()
        self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4)
        self.relu1 = nn.ReLU()
        self.conv2 = qnn.QuantConv2d(6, 16, 5, bias=True, weight_bit_width=4)
        self.relu2 = nn.ReLU()
        self.fc1   = qnn.QuantLinear(16*5*5, 120, bias=True, weight_bit_width=4)
        self.relu3 = nn.ReLU()
        self.fc2   = qnn.QuantLinear(120, 84, bias=True, weight_bit_width=4)
        self.relu4 = nn.ReLU()
        self.fc3   = qnn.QuantLinear(84, 10, bias=True, weight_bit_width=4)

    def forward(self, x):
        out = self.relu1(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = self.relu2(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.reshape(out.shape[0], -1)
        out = self.relu3(self.fc1(out))
        out = self.relu4(self.fc2(out))
        out = self.fc3(out)
        return out

quant_weight_lenet = QuantWeightLeNet()
print(quant_weight_lenet)
# ... training ...

QuantWeightLeNet(
  (conv1): QuantConv2d(
    3, 6, kernel_size=(5, 5), stride=(1, 1)
    (input_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (output_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (weight_quant): WeightQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
      (tensor_quant): RescalingIntQuant(
        (int_quant): IntQuant(
          (float_to_int_impl): RoundSte()
          (tensor_clamp_impl): TensorClampSte()
          (delay_wrapper): DelayWrapper(
            (delay_impl): _NoDelay()
          )
          (input_view_impl): Identity()
        )
        (scaling_impl): StatsFromParameterScaling(
          (parameter_list_stats): _ParameterListStats(
            (first_tracked_param): _ViewParameterWrapper(
              (view_shape_impl): OverTensorView()
            )
            (stats): _Stats(
              (stats_impl): AbsMax()
            )
          )
  

A neural network that uses 4-bit weights while keeping activations in floating point can reduce model storage size. However, it offers no computational benefit, since the low-bit weights must still be converted to floating point before they are multiplied with float activations at inference time. To make the model efficient on resource-constrained hardware, the activations must be quantized as well, so that both storage and computation benefit.
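To put a number on the storage claim, the following back-of-the-envelope sketch counts the weights of the LeNet defined above and compares 32-bit with 4-bit storage. It assumes densely packed values and ignores scale factors and file-format overhead; the helper `storage_bytes` is a hypothetical name, and the result describes a deployed model, not the size of a PyTorch checkpoint (which still holds float tensors).

```python
# Weight counts for the LeNet-5 variant defined above (in * out * kernel elements).
weight_counts = {
    "conv1": 3 * 6 * 5 * 5,
    "conv2": 6 * 16 * 5 * 5,
    "fc1": 16 * 5 * 5 * 120,
    "fc2": 120 * 84,
    "fc3": 84 * 10,
}
# One bias per output channel or feature; biases stay in float32 here.
bias_count = 6 + 16 + 120 + 84 + 10

total_weights = sum(weight_counts.values())


def storage_bytes(num_values, bit_width):
    # Assumption: values are packed densely with no padding or metadata.
    return num_values * bit_width / 8


float_bytes = storage_bytes(total_weights, 32) + storage_bytes(bias_count, 32)
quant_bytes = storage_bytes(total_weights, 4) + storage_bytes(bias_count, 32)

print(f"weights: {total_weights}, biases: {bias_count}")
print(f"float32 model: {float_bytes:.0f} bytes")
print(f"4-bit weights: {quant_bytes:.0f} bytes")
print(f"reduction: {float_bytes / quant_bytes:.2f}x")
```

The reduction is slightly below the ideal 8x because the biases remain in floating point.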

## 2. Weights and activations quantization, float biases
Now, we quantize both the weights and activations to 4 bits, while keeping the biases in floating-point format. To achieve this:

- We replace torch.nn.ReLU with brevitas.nn.QuantReLU, setting bit_width=4 to specify 4-bit activation quantization.

- To quantize the input itself, we add a brevitas.nn.QuantIdentity layer at the start of the network. This ensures the input is also quantized before it flows through the rest of the model.

In [10]:
from torch.nn import Module
import torch.nn.functional as F

import brevitas.nn as qnn


class QuantWeightActLeNet(Module):
    ''' LeNet-5 with weight and activation quantization. '''
    def __init__(self):
        super(QuantWeightActLeNet, self).__init__()
        self.quant_inp = qnn.QuantIdentity(bit_width=4)
        self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4)
        self.relu1 = qnn.QuantReLU(bit_width=4)
        self.conv2 = qnn.QuantConv2d(6, 16, 5, bias=True, weight_bit_width=4)
        self.relu2 = qnn.QuantReLU(bit_width=4)
        self.fc1   = qnn.QuantLinear(16*5*5, 120, bias=True, weight_bit_width=4)
        self.relu3 = qnn.QuantReLU(bit_width=4)
        self.fc2   = qnn.QuantLinear(120, 84, bias=True, weight_bit_width=4)
        self.relu4 = qnn.QuantReLU(bit_width=4)
        self.fc3   = qnn.QuantLinear(84, 10, bias=True, weight_bit_width=4)

    def forward(self, x):
        out = self.quant_inp(x)
        out = self.relu1(self.conv1(out))
        out = F.max_pool2d(out, 2)
        out = self.relu2(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.reshape(out.shape[0], -1)
        out = self.relu3(self.fc1(out))
        out = self.relu4(self.fc2(out))
        out = self.fc3(out)
        return out

quant_weight_act_lenet = QuantWeightActLeNet()
print(quant_weight_act_lenet)
# ... training ...

Here are a few important points to keep in mind when using QuantReLU in Brevitas:

- Statefulness: By default, QuantReLU is stateful, which means that reusing a single instance across different parts of the model will behave differently compared to creating a new instance each time. Each QuantReLU instance maintains its own state (e.g., statistics for quantization), so it's usually best to instantiate separate QuantReLU layers for each activation.

- Unsigned Quantization: QuantReLU applies the ReLU operation first (which ensures non-negative outputs), and then quantizes the result. Since ReLU outputs are always ≥ 0, QuantReLU uses unsigned quantization by default. For 4-bit quantization, this means the outputs are mapped to 16 discrete levels in the integer range [0, 15].

- Dequantized Representation: In Brevitas, quantized data is internally represented in floating-point format for compatibility with PyTorch operations. Although the data has been quantized (e.g., to 4-bit resolution), it’s stored in a float tensor. So, the output of QuantReLU looks like a standard float torch.Tensor, but only takes on a limited set of values. To get a more detailed representation—including quantization metadata—you can set return_quant_tensor=True, which will return a QuantTensor object instead of a plain tensor.
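The unsigned-quantization and dequantized-representation points above can be sketched in plain Python. This is an illustration only, not how Brevitas implements QuantReLU: the function name `quant_relu` and the fixed scale of 0.25 are assumptions, whereas Brevitas determines the activation scale itself.

```python
def quant_relu(values, scale, bit_width=4):
    """Illustrative ReLU followed by unsigned fake quantization (not Brevitas code)."""
    q_max = 2 ** bit_width - 1  # 15 for 4 bits: integer levels 0..15
    outputs = []
    for v in values:
        activated = max(0.0, v)                       # ReLU first
        level = min(q_max, max(0, round(activated / scale)))  # round and clamp
        outputs.append(level * scale)                 # dequantize back to float
    return outputs


scale = 0.25  # assumption: a fixed scale chosen for this example
values = [-1.0, 0.1, 0.6, 1.3, 2.9, 10.0]
outputs = quant_relu(values, scale)
print(outputs)
print(sorted({round(o / scale) for o in outputs}))  # the integer levels in use
```

The outputs are ordinary Python floats, yet each one is an integer multiple of the scale, and inputs beyond the representable range saturate at level 15.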

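## 3. Weights, activations, and biases quantization
Finally, we quantize the biases as well, using the Int32Bias quantizer, and set return_quant_tensor=True on the activation layers so that quantization metadata is passed from layer to layer.
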
In [11]:
from torch.nn import Module
import torch.nn.functional as F

import brevitas.nn as qnn
from brevitas.quant import Int32Bias


class QuantWeightActBiasLeNet(Module):
    ''' LeNet-5 with weight, bias, and activation quantization. '''
    def __init__(self):
        super(QuantWeightActBiasLeNet, self).__init__()
        self.quant_inp = qnn.QuantIdentity(bit_width=4, return_quant_tensor=True)
        self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4, bias_quant=Int32Bias)
        self.relu1 = qnn.QuantReLU(bit_width=4, return_quant_tensor=True)
        self.conv2 = qnn.QuantConv2d(6, 16, 5, bias=True, weight_bit_width=4, bias_quant=Int32Bias)
        self.relu2 = qnn.QuantReLU(bit_width=4, return_quant_tensor=True)
        self.fc1   = qnn.QuantLinear(16*5*5, 120, bias=True, weight_bit_width=4, bias_quant=Int32Bias)
        self.relu3 = qnn.QuantReLU(bit_width=4, return_quant_tensor=True)
        self.fc2   = qnn.QuantLinear(120, 84, bias=True, weight_bit_width=4, bias_quant=Int32Bias)
        self.relu4 = qnn.QuantReLU(bit_width=4, return_quant_tensor=True)
        self.fc3   = qnn.QuantLinear(84, 10, bias=True, weight_bit_width=4, bias_quant=Int32Bias)

    def forward(self, x):
        out = self.quant_inp(x)
        out = self.relu1(self.conv1(out))
        out = F.max_pool2d(out, 2)
        out = self.relu2(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.reshape(out.shape[0], -1)
        out = self.relu3(self.fc1(out))
        out = self.relu4(self.fc2(out))
        out = self.fc3(out)
        return out

quant_weight_act_bias_lenet = QuantWeightActBiasLeNet()
print(quant_weight_act_bias_lenet)
# ... training ...

QuantWeightActBiasLeNet(
  (quant_inp): QuantIdentity(
    (input_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
    )
    (act_quant): ActQuantProxyFromInjector(
      (_zero_hw_sentinel): StatelessBuffer()
      (fused_activation_quant_proxy): FusedActivationQuantProxy(
        (activation_impl): Identity()
        (tensor_quant): RescalingIntQuant(
          (int_quant): IntQuant(
            (float_to_int_impl): RoundSte()
            (tensor_clamp_impl): TensorClamp()
            (delay_wrapper): DelayWrapper(
              (delay_impl): _NoDelay()
            )
            (input_view_impl): Identity()
          )
          (scaling_impl): ParameterFromRuntimeStatsScaling(
            (stats_input_view_shape_impl): OverTensorView()
            (stats): _Stats(
              (stats_impl): AbsPercentile()
            )
            (restrict_scaling_impl): FloatRestrictValue()
            (restrict_scaling): _RestrictValue(
              (restrict_va

Building upon the previous setup, here are the enhancements and their implications:

- Propagating QuantTensor: We now set return_quant_tensor=True for all quantized activation layers. This allows each layer to pass a QuantTensor—a tensor that includes metadata about its quantization—to the next layer. This is crucial because layers like QuantLinear or QuantConv2d use this metadata to understand how their inputs have been quantized.

- What is a QuantTensor? A QuantTensor is a tensor-like object that carries quantization information (e.g., scale, zero-point, bit width) alongside the actual data. It’s conceptually similar to PyTorch’s qint data types but designed to be training-friendly. Importantly, setting return_quant_tensor=True does not change how quantization is computed; it only affects how the output is represented and interpreted by subsequent layers.

- Bias Quantization with Int32Bias: We now enable bias quantization using the Int32Bias quantizer. This quantizer computes the bias scale as:

`bias_scale = input_scale * weight_scale`

This is standard practice in inference frameworks, because a bias at this scale can be added directly to the integer accumulator of the layer. To do this correctly, each layer needs to know the input_scale, which is why we must propagate QuantTensor using return_quant_tensor=True.

- Torch Functions and QuantTensor: Functions like torch.nn.functional.max_pool2d are invariant to quantization: they only select among existing values, so their output has the same scale and bit width as their input. They accept a QuantTensor and return a QuantTensor with the same quantization metadata, so no extra steps are needed for them to work in a quantization-aware pipeline.
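The bias-scale rule and the max-pooling invariance described above can be checked with a small integer-arithmetic sketch. The scales, integer values, and float bias below are made-up numbers chosen so the arithmetic is exact; nothing here calls Brevitas.

```python
# Assumed example scales (powers of two so float arithmetic is exact).
input_scale = 0.5
weight_scale = 0.25
bias_scale = input_scale * weight_scale  # the rule used by Int32Bias

x_int = [3, 0, 7]    # quantized inputs as integers
w_int = [2, -5, 1]   # quantized weights as integers
bias_float = 0.4     # hypothetical float bias

# Quantize the bias at the accumulator's scale.
bias_int = round(bias_float / bias_scale)

# Integer-only dot product: the bias is added directly to the accumulator.
acc_int = sum(x * w for x, w in zip(x_int, w_int)) + bias_int
output = acc_int * bias_scale

# Reference: the same computation on dequantized float values.
x_deq = [x * input_scale for x in x_int]
w_deq = [w * weight_scale for w in w_int]
reference = sum(x * w for x, w in zip(x_deq, w_deq)) + bias_int * bias_scale

print(acc_int, output, reference)

# Max pooling only selects a value, so it commutes with the (positive) scale.
print(max(x_deq) == max(x_int) * input_scale)
```

Because each product of an input and a weight carries the scale `input_scale * weight_scale`, a bias quantized at that same scale needs no rescaling before the addition. Without the input scale, a layer could not compute this value, which is the reason the preceding activation must return a QuantTensor.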

## 4. Export to ONNX
Brevitas itself does not provide any low-precision compute acceleration during training or inference. To take advantage of hardware acceleration for quantized models, the model must first be exported to an inference toolchain using an intermediate format such as ONNX.

A common format for representing 8-bit quantization in ONNX is QDQ (Quantize-Dequantize). Brevitas extends this idea to QCDQ (Quantize-Clip-Dequantize), which adds a Clip node between quantization and dequantization to better support quantization to bit-widths ≤ 8 bits. This extension allows for a more precise representation of quantization behavior, especially for sub-8-bit formats.
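The effect of the extra Clip node can be sketched in plain Python. This is a conceptual model rather than the exported ONNX graph: the function name `qcdq`, the signed int8 container, the full signed 4-bit range [-8, 7], and the example scale are assumptions for illustration.

```python
def qcdq(value, scale, bit_width, zero_point=0):
    """Conceptual Quantize-Clip-Dequantize on a single value (illustrative only)."""
    # Quantize: map to the 8-bit integer grid, saturating at the int8 limits.
    q = round(value / scale) + zero_point
    q = max(-128, min(127, q))
    # Clip: restrict to the narrower signed range of the target bit width.
    lo, hi = -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1
    q = max(lo, min(hi, q))
    # Dequantize: back to a float on the quantization grid.
    return (q - zero_point) * scale


scale = 0.5  # assumed scale
values = [-10.0, -1.2, 0.3, 2.0, 100.0]

clipped = [qcdq(v, scale, bit_width=4) for v in values]
plain_qdq = [qcdq(v, scale, bit_width=8) for v in values]
print(clipped)    # restricted to 16 levels
print(plain_qdq)  # with bit_width=8 the clip is a no-op, i.e. ordinary QDQ
```

With `bit_width=8` the clip range equals the int8 range, so the function reduces to plain QDQ; with a smaller bit width the same 8-bit container represents a narrower quantizer.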

The export process is straightforward: the previously defined Brevitas models can be exported to QCDQ format with brevitas.export.export_onnx_qcdq. Its interface is designed to mirror torch.onnx.export and accepts the same keyword arguments, which makes it easy to integrate into existing PyTorch-to-ONNX export workflows.

In [7]:
!pip install onnx onnxruntime onnxoptimizer


In [12]:
from brevitas.export import export_onnx_qcdq
import torch

# Weight-only model
export_onnx_qcdq(quant_weight_lenet, torch.randn(1, 3, 32, 32), export_path='4b_weight_lenet.onnx')

# Weight-activation model
export_onnx_qcdq(quant_weight_act_lenet, torch.randn(1, 3, 32, 32), export_path='4b_weight_act_lenet.onnx')

# Weight-activation-bias model
export_onnx_qcdq(quant_weight_act_bias_lenet, torch.randn(1, 3, 32, 32), export_path='4b_weight_act_bias_lenet.onnx')

