In [27]:
import torch

# Quantization

**Quantization is the process of mapping a large set of values to a smaller set, typically from a higher-precision representation to a lower-precision one.**



![](images/img.png)

**Original Tensor**
|    |     |     |
|----|-----|-----|
| 191.6 | -13.5 | 728.6 |
| 92.14 | 295.5 | -184  |
| 0     | 684.6 | 245.5 |

**Quantized Tensor**
|   |   |   |
|--|--|--|
|-23|-81|127|
|-51|6|-128|
|-77|114|-8|

### Advantages of Quantization

- Smaller model
- Speed gains
    - Lower memory bandwidth
    - Faster operations
        - GEMM: General Matrix Multiplication
        - GEMV: General Matrix-Vector Multiplication
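To make the "smaller model" point concrete, here is a rough back-of-the-envelope sketch. The parameter count is a made-up number used only for illustration, and the sizes ignore the small extra storage needed for scales and zero-points.

```python
# Bytes needed per element for a few common dtypes.
BYTES_PER_ELEMENT = {"fp32": 4, "fp16": 2, "int8": 1}

def model_size_mb(num_params, dtype):
    """Approximate storage for the parameters alone, in megabytes (10**6 bytes)."""
    return num_params * BYTES_PER_ELEMENT[dtype] / 1e6

num_params = 10_000_000  # hypothetical model size, for illustration only

for dtype in ("fp32", "fp16", "int8"):
    print(f"{dtype}: {model_size_mb(num_params, dtype):.1f} MB")
```

Going from FP32 to INT8 cuts the parameter storage by a factor of four, which also reduces the amount of data moved through memory per forward pass.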

### Challenges of Quantization

- Quantization Error
- Retraining (Quantization Aware Training)
- Limited Hardware Support
- Calibration Dataset Needed
- Packing / Unpacking

# Linear Quantization

![](images/img_1.png)

**Idea:** Linear mapping <br>
$$r = s(q-z)$$
**where**<br>
$r$: original value<br>
$s$: scale<br>
$q$: quantized value<br>
$z$: zero-point

**Example**<br>
Assume $s = 2$, $z = 0$, then <br>
$r = 2(q-0) = 2q$ <br>
for $q = 10$, we have $r = 2 \times 10 = 20$

In [28]:
def get_dequantised_tensor(tensor,scale,zero_point):
    return scale * (tensor.float()-zero_point)
    

### How do we get the quantized tensor q ?


Starting from $r = s(q-z)$:<br>
dividing by $s$ gives $r/s = q-z$,<br>
so $q = r/s + z$.<br>
Because $q$ must be an integer, we round: $round(r/s + z)$.<br>
Finally, we cast to the integer type: $q = int(round(r/s + z))$
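As a quick sanity check of the mapping $q = int(round(r/s + z))$ (the form used in the code of this notebook) and its inverse $r = s(q-z)$, here is a minimal scalar sketch in plain Python. The helper names `quantize_value` and `dequantize_value` are made up for this illustration, and the int8 range is assumed as the default.

```python
def quantize_value(r, s, z, q_min=-128, q_max=127):
    """q = int(round(r / s + z)), clamped to the integer range (int8 by default)."""
    q = int(round(r / s + z))
    return max(q_min, min(q_max, q))

def dequantize_value(q, s, z):
    """r = s * (q - z)"""
    return s * (q - z)

# the example with s = 2 and z = 0, in reverse: r = 20 maps back to q = 10
q = quantize_value(20.0, s=2.0, z=0)
r = dequantize_value(q, s=2.0, z=0)
print(q, r)  # 10 20.0

# with a non-zero zero-point, the quantized value is shifted by z
q_shifted = quantize_value(20.0, s=2.0, z=5)
print(q_shifted, dequantize_value(q_shifted, s=2.0, z=5))  # 15 20.0
```

Both Python's `round` and `torch.round` round exact halves to the nearest even integer, so the scalar sketch and the tensor code agree on ties.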

#### With an arbitrary `scale` and `zero_point`

In [29]:
def linear_q_with_scale_and_zero_point(
    tensor, scale, zero_point, dtype = torch.int8):

    scaled_and_shifted_tensor = tensor / scale + zero_point

    rounded_tensor = torch.round(scaled_and_shifted_tensor)

    q_min = torch.iinfo(dtype).min
    q_max = torch.iinfo(dtype).max

    q_tensor = rounded_tensor.clamp(q_min,q_max).to(dtype)
    
    return q_tensor

In [30]:
### a dummy tensor to test the implementation
test_tensor=torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5,  -184],
     [0,     684.6, 245.5]]
)

In [31]:
scale = 2.03
zero_point = 0

In [32]:
quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor, scale, zero_point, dtype = torch.int8)

In [33]:
quantized_tensor

tensor([[ 94,  -7, 127],
        [ 45, 127, -91],
        [  0, 127, 121]], dtype=torch.int8)

In [34]:
dequantized_tensor = get_dequantised_tensor(quantized_tensor,scale,zero_point)

In [35]:
dequantized_tensor

tensor([[ 190.8200,  -14.2100,  257.8100],
        [  91.3500,  257.8100, -184.7300],
        [   0.0000,  257.8100,  245.6300]])

In [36]:
# quantization error (mean squared error)
torch.square(test_tensor-dequantized_tensor).mean()

tensor(45023.9648)
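The error is this large because, with an arbitrarily chosen scale, part of the tensor falls outside the range that int8 can represent and gets clamped. The sketch below computes that representable range from $r = s(q-z)$; the helper name `representable_range` is made up for this illustration.

```python
def representable_range(scale, zero_point, q_min=-128, q_max=127):
    """Real values reachable through r = s * (q - z) for q in [q_min, q_max]."""
    return scale * (q_min - zero_point), scale * (q_max - zero_point)

# the entries of test_tensor, flattened
values = [191.6, -13.5, 728.6, 92.14, 295.5, -184, 0, 684.6, 245.5]

r_low, r_high = representable_range(scale=2.03, zero_point=0)
print(r_low, r_high)  # roughly -259.84 and 257.81

# values outside the range are clamped to q_min or q_max during quantization
saturated = [v for v in values if v < r_low or v > r_high]
print(saturated)  # [728.6, 295.5, 684.6]
```

These three entries are exactly the ones that come back as 257.81 after dequantization, and they account for nearly all of the mean squared error.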

#### With `scale` and `zero_point` computed from the tensor

In [37]:
def get_q_scale_and_zero_point(tensor, dtype=torch.int8):
    
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    r_min, r_max = tensor.min().item(), tensor.max().item()

    scale = (r_max - r_min) / (q_max - q_min)

    zero_point = q_min - (r_min / scale)

    # clip the zero_point to fall in [quantized_min, quantized_max]
    if zero_point < q_min:
        zero_point = q_min
    elif zero_point > q_max:
        zero_point = q_max
    else:
        # round and cast to int
        zero_point = int(round(zero_point))
    
    return scale, zero_point

In [38]:
scale, zero_point = get_q_scale_and_zero_point(test_tensor)

In [39]:
scale

3.578823433670343

In [40]:
zero_point

-77
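The two values above can be reproduced by hand from the minimum and maximum of `test_tensor`. This is a plain-Python sketch of the same arithmetic; the last digits of the scale differ slightly from the notebook output because the tensor stores its values in float32.

```python
q_min, q_max = -128, 127          # int8 range
r_min, r_max = -184.0, 728.6      # min and max of test_tensor

# scale: size of the real range divided by the size of the quantized range
scale = (r_max - r_min) / (q_max - q_min)

# zero-point: the quantized value that r = 0 maps to, before rounding
zero_point_float = q_min - r_min / scale
zero_point = int(round(zero_point_float))

print(scale)             # about 3.5788
print(zero_point_float)  # about -76.59
print(zero_point)        # -77
```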

In [41]:
quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor, scale, zero_point, dtype = torch.int8)

In [42]:
quantized_tensor

tensor([[ -23,  -81,  127],
        [ -51,    6, -128],
        [ -77,  114,   -8]], dtype=torch.int8)

In [43]:
dequantised_tensor = get_dequantised_tensor(quantized_tensor,scale,zero_point)

In [44]:
dequantised_tensor

tensor([[ 193.2565,  -14.3153,  730.0800],
        [  93.0494,  297.0423, -182.5200],
        [   0.0000,  683.5552,  246.9388]])

In [45]:
# quantization error (mean squared error)
torch.square(test_tensor-dequantised_tensor).mean()

tensor(1.5730)

### Let's wrap all the functions in a class

In [46]:
class Quantizer:
        
    def quantize(self, tensor, dtype=torch.int8):
        self.r_tensor = tensor
        # pass dtype through so the scale and zero-point match the target integer range
        self.scale, self.zero_point = self.get_q_scale_and_zero_point(tensor, dtype)
        scaled_and_shifted_tensor = tensor / self.scale + self.zero_point
    
        rounded_tensor = torch.round(scaled_and_shifted_tensor)
    
        q_min = torch.iinfo(dtype).min
        q_max = torch.iinfo(dtype).max
    
        self.q_tensor = rounded_tensor.clamp(q_min,q_max).to(dtype)
        
        return self.q_tensor
        
    
    @staticmethod
    def get_q_scale_and_zero_point(tensor, dtype=torch.int8):
    
        q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
        r_min, r_max = tensor.min().item(), tensor.max().item()
    
        scale = (r_max - r_min) / (q_max - q_min)
    
        zero_point = q_min - (r_min / scale)
    
        # clip the zero_point to fall in [quantized_min, quantized_max]
        if zero_point < q_min:
            zero_point = q_min
        elif zero_point > q_max:
            zero_point = q_max
        else:
            # round and cast to int
            zero_point = int(round(zero_point))
        
        return scale, zero_point
    
    def dequantized(self):
        
        self.dq_tensor = self.scale * (self.q_tensor.float()-self.zero_point)
        return self.dq_tensor
    
    def quantize_error(self):
        return torch.square(self.r_tensor-self.dequantized()).mean().item()

In [47]:
test_tensor

tensor([[ 191.6000,  -13.5000,  728.6000],
        [  92.1400,  295.5000, -184.0000],
        [   0.0000,  684.6000,  245.5000]])

In [48]:
quantizer = Quantizer()
quantizer.quantize(test_tensor)

tensor([[ -23,  -81,  127],
        [ -51,    6, -128],
        [ -77,  114,   -8]], dtype=torch.int8)

In [49]:
quantizer.dequantized()

tensor([[ 193.2565,  -14.3153,  730.0800],
        [  93.0494,  297.0423, -182.5200],
        [   0.0000,  683.5552,  246.9388]])

In [50]:
quantizer.quantize_error()

1.5729731321334839

## Linear Quantization Modes

There are **two** modes in linear quantization:
- **Asymmetric:** We map $[r_{min},r_{max}]$ to $[q_{min},q_{max}]$. This is what we implemented in the previous example and code.
- **Symmetric:** We map $[-r_{max},r_{max}]$ to $[-q_{max},q_{max}]$, where we can set $r_{max} = max(|r\_tensor|)$.
    - We don't need the zero-point ($z=0$).
    - This is because the floating-point range and the quantized range are both symmetric with respect to zero.
    - Hence, we can simplify the equations to:
    $$q = int(round(r/s))$$
    $$s = r_{max}/q_{max}$$
    $$z = 0$$

![](images/img_2.png)

- **Trade-off:**
    - Utilization of the quantized range:
        - In asymmetric mode, the quantized range is fully utilized.
        - In symmetric mode, if the float range is biased towards one side, part of the quantized range is dedicated to values that we'll never see (e.g. the output of ReLU, which is never negative).
    - Simplicity: Symmetric mode is much simpler than asymmetric mode.
    - Memory: We don't store the zero-point in symmetric mode.
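The utilization trade-off can be seen on a small example. The sketch below quantizes a handful of made-up non-negative values (standing in for ReLU outputs) in both modes, using plain Python and the int8 range; `quantize_list` is a helper invented for this illustration.

```python
def quantize_list(values, scale, zero_point, q_min=-128, q_max=127):
    """Apply q = int(round(r / s + z)) with clamping to every value."""
    return [max(q_min, min(q_max, int(round(v / scale + zero_point)))) for v in values]

# hypothetical non-negative activations, e.g. after a ReLU
values = [0.0, 0.3, 1.1, 2.0]
q_min, q_max = -128, 127

# symmetric: s = r_max / q_max, z = 0
sym_scale = max(abs(v) for v in values) / q_max
q_sym = quantize_list(values, sym_scale, 0)

# asymmetric: s = (r_max - r_min) / (q_max - q_min), z = round(q_min - r_min / s)
asym_scale = (max(values) - min(values)) / (q_max - q_min)
asym_zero_point = int(round(q_min - min(values) / asym_scale))
q_asym = quantize_list(values, asym_scale, asym_zero_point)

print(min(q_sym), max(q_sym))    # 0 127: the negative codes are never used
print(min(q_asym), max(q_asym))  # -128 127: the full int8 range is used
```

In symmetric mode half of the int8 codes are wasted on this data, so the step size is about twice as large as in asymmetric mode.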

In [60]:
def get_q_scale_symmetric(tensor, dtype=torch.int8):
    r_max = tensor.abs().max().item()
    q_max = torch.iinfo(dtype).max

    # return the scale
    return r_max/q_max

### Let's update the class

In [53]:
class Quantizer:
        
    def quantize(self, tensor, dtype=torch.int8, mode="symmetric"):
        self.r_tensor = tensor
        # pass dtype through so the scale matches the target integer range
        if mode == "symmetric":
            self.scale, self.zero_point = self.get_q_scale_symmetric(tensor, dtype), 0
        else:
            self.scale, self.zero_point = self.get_q_scale_and_zero_point(tensor, dtype)
        scaled_and_shifted_tensor = tensor / self.scale + self.zero_point
    
        rounded_tensor = torch.round(scaled_and_shifted_tensor)
    
        q_min = torch.iinfo(dtype).min
        q_max = torch.iinfo(dtype).max
    
        self.q_tensor = rounded_tensor.clamp(q_min,q_max).to(dtype)
        
        return self.q_tensor
        
    
    @staticmethod
    def get_q_scale_and_zero_point(tensor, dtype=torch.int8):
    
        q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
        r_min, r_max = tensor.min().item(), tensor.max().item()
    
        scale = (r_max - r_min) / (q_max - q_min)
    
        zero_point = q_min - (r_min / scale)
    
        # clip the zero_point to fall in [quantized_min, quantized_max]
        if zero_point < q_min:
            zero_point = q_min
        elif zero_point > q_max:
            zero_point = q_max
        else:
            # round and cast to int
            zero_point = int(round(zero_point))
        
        return scale, zero_point
    
    @staticmethod
    def get_q_scale_symmetric(tensor, dtype=torch.int8):
        r_max = tensor.abs().max().item()
        q_max = torch.iinfo(dtype).max
    
        # return the scale
        return r_max/q_max
    
    def dequantized(self):
        
        self.dq_tensor = self.scale * (self.q_tensor.float()-self.zero_point)
        return self.dq_tensor
    
    def quantize_error(self):
        return torch.square(self.r_tensor-self.dequantized()).mean().item()

In [51]:
test_tensor

tensor([[ 191.6000,  -13.5000,  728.6000],
        [  92.1400,  295.5000, -184.0000],
        [   0.0000,  684.6000,  245.5000]])

In [54]:
quantizer = Quantizer()
quantizer.quantize(test_tensor)

tensor([[ 33,  -2, 127],
        [ 16,  52, -32],
        [  0, 119,  43]], dtype=torch.int8)

In [55]:
quantizer.dequantized()

tensor([[ 189.3213,  -11.4740,  728.6000],
        [  91.7921,  298.3244, -183.5842],
        [   0.0000,  682.7039,  246.6913]])

In [56]:
quantizer.quantize_error()

2.5091912746429443

# Different Granularities for Quantization
- For simplicity, we'll perform these using symmetric mode.
- In the context of quantization, **"granularity"** refers to how many elements share the same quantization parameters (scale and zero-point). The finer the granularity, the more parameters we store, but the more closely each scale fits its values. We can quantize per tensor (one scale for the whole tensor), per channel (one scale per row or column), or per group (one scale per group of `group_size` consecutive elements).
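One way to compare the granularities is to count how many scales each one stores. The helper below is a hypothetical illustration for a 2-D tensor in symmetric mode, assuming per-channel means one scale per row and per-group means groups of consecutive elements within a row.

```python
def num_scales(shape, granularity, group_size=None):
    """Number of scales stored for a 2-D tensor in symmetric mode (illustrative helper)."""
    rows, cols = shape
    if granularity == "per_tensor":
        return 1
    if granularity == "per_channel":   # one scale per row (dim=0)
        return rows
    if granularity == "per_group":     # groups of `group_size` consecutive elements in a row
        assert cols % group_size == 0
        return rows * cols // group_size
    raise ValueError(f"unknown granularity: {granularity}")

shape = (6, 6)
print(num_scales(shape, "per_tensor"))               # 1
print(num_scales(shape, "per_channel"))              # 6
print(num_scales(shape, "per_group", group_size=3))  # 12
```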



![](images/img_3.png)

## Per tensor
 - Until now, we have been doing per-tensor quantization: a single scale (and zero-point) for the whole tensor.

## Per channel

In [58]:
def linear_q_symmetric_per_channel(r_tensor, dim, dtype=torch.int8):  # for a 2-D tensor, dim=0: one scale per row, dim=1: one scale per column
    
    output_dim = r_tensor.shape[dim]
    # store the scales
    scale = torch.zeros(output_dim)

    for index in range(output_dim):
        sub_tensor = r_tensor.select(dim, index)
        scale[index] = get_q_scale_symmetric(sub_tensor, dtype=dtype)

    # reshape the scale
    scale_shape = [1] * r_tensor.dim()
    scale_shape[dim] = -1
    scale = scale.view(scale_shape)
    quantized_tensor = linear_q_with_scale_and_zero_point(
        r_tensor, scale=scale, zero_point=0, dtype=dtype)
   
    return quantized_tensor, scale

In [69]:
test_tensor

tensor([[ 191.6000,  -13.5000,  728.6000],
        [  92.1400,  295.5000, -184.0000],
        [   0.0000,  684.6000,  245.5000]])

In [62]:
q_tensor,scale = linear_q_symmetric_per_channel(test_tensor,dim =0)

In [63]:
q_tensor

tensor([[ 33,  -2, 127],
        [ 40, 127, -79],
        [  0, 127,  46]], dtype=torch.int8)

In [70]:
scale

tensor([[5.7370],
        [2.3268],
        [5.3906]])

In [83]:
d_quantized_tensor = get_dequantised_tensor(q_tensor,scale,0)

In [84]:
d_quantized_tensor

tensor([[ 189.3213,  -11.4740,  728.6000],
        [  93.0709,  295.5000, -183.8150],
        [   0.0000,  684.6000,  247.9653]])

In [85]:
# quantization error
torch.square(test_tensor-d_quantized_tensor).mean()

tensor(1.8084)
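The loop over channels can also be expressed as a single reduction. The sketch below is a vectorized variant for 2-D tensors only; it is not the function used in this notebook, `symmetric_per_channel_2d` is a name made up here, and it assumes no channel is entirely zero (which would give a scale of 0).

```python
import torch

def symmetric_per_channel_2d(r_tensor, dim, dtype=torch.int8):
    """Vectorized sketch for 2-D tensors: one symmetric scale per index along `dim`."""
    assert r_tensor.dim() == 2
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    # reduce over the *other* dimension; keepdim=True keeps a broadcastable shape
    r_max = r_tensor.abs().amax(dim=1 - dim, keepdim=True)
    scale = r_max / q_max  # assumes r_max > 0 for every channel
    q_tensor = torch.round(r_tensor / scale).clamp(q_min, q_max).to(dtype)
    return q_tensor, scale

test_tensor = torch.tensor([[191.6, -13.5, 728.6],
                            [92.14, 295.5, -184],
                            [0, 684.6, 245.5]])

q_rows, s_rows = symmetric_per_channel_2d(test_tensor, dim=0)  # one scale per row
q_cols, s_cols = symmetric_per_channel_2d(test_tensor, dim=1)  # one scale per column
print(s_rows.shape, s_cols.shape)  # torch.Size([3, 1]) torch.Size([1, 3])
print(q_rows)
```

The scale shapes `(3, 1)` and `(1, 3)` are what make broadcasting divide each row, or each column, by its own scale.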

## Per group

- **Note:**<br>
    - Per-group quantization can require noticeably more memory, because we store one scale per group.
    - Say we quantize a tensor in 4-bit with group_size = 32, symmetric mode (z=0), and the scales stored in FP16.
    - We are then effectively quantizing the tensor in 4.5 bits per element: 4 bits (each element is stored in 4 bits) + 16/32 bits (one 16-bit scale for every 32 elements).
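The 4.5-bit figure is simple arithmetic, sketched below. The function name and the optional `zero_point_bits` parameter are assumptions added for illustration; in symmetric mode the zero-point costs nothing.

```python
def effective_bits_per_element(weight_bits, scale_bits, group_size, zero_point_bits=0):
    """Average storage per element when each group stores one scale (and optionally one zero-point)."""
    return weight_bits + (scale_bits + zero_point_bits) / group_size

print(effective_bits_per_element(4, 16, 32))   # 4.5
print(effective_bits_per_element(4, 16, 128))  # 4.125
```

Larger groups reduce the overhead, at the cost of each scale having to cover more values.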

In [86]:
def linear_q_symmetric_per_group(tensor, group_size,
                                 dtype=torch.int8):
    
    t_shape = tensor.shape
    # check the rank first, since the next assert indexes t_shape[1]
    assert tensor.dim() == 2
    assert t_shape[1] % group_size == 0
    
    tensor = tensor.view(-1, group_size)
    
    quantized_tensor, scale = linear_q_symmetric_per_channel(
                                tensor, dim=0, dtype=dtype)
    
    quantized_tensor = quantized_tensor.view(t_shape)
    
    return quantized_tensor, scale

In [102]:
def linear_dequantization_per_group(quantized_tensor, scale, 
                                    group_size):
    
    q_shape = quantized_tensor.shape
    quantized_tensor = quantized_tensor.view(-1, group_size)
    
    dequantized_tensor =  scale * quantized_tensor.float()
    
    dequantized_tensor = dequantized_tensor.view(q_shape)
    
    return dequantized_tensor

In [94]:
test_tensor = torch.rand((6, 6))

In [95]:
test_tensor

tensor([[0.4197, 0.6648, 0.5409, 0.3064, 0.1431, 0.0918],
        [0.7418, 0.2223, 0.6072, 0.0738, 0.8673, 0.9437],
        [0.8802, 0.7510, 0.4111, 0.5284, 0.0147, 0.2914],
        [0.0059, 0.5684, 0.9731, 0.6982, 0.8171, 0.9633],
        [0.3606, 0.0750, 0.1799, 0.5641, 0.1451, 0.7478],
        [0.2887, 0.5852, 0.2396, 0.2173, 0.3833, 0.9906]])

In [99]:
group_size = 3

In [104]:
quantized_tensor,scale = linear_q_symmetric_per_group(test_tensor,group_size)

In [105]:
dequantized_tensor = linear_dequantization_per_group(quantized_tensor,scale,group_size)

In [106]:
# quantization error
torch.square(test_tensor-dequantized_tensor).mean()

tensor(1.3180e-06)
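Per-group quantization here relies on `view(-1, group_size)` turning each run of `group_size` consecutive elements in a row into its own row, so that the per-channel function with `dim=0` produces one scale per group. A small sketch with an easy-to-read tensor:

```python
import torch

tensor = torch.arange(12).view(2, 6)   # 2 rows, 6 columns
groups = tensor.view(-1, 3)            # group_size = 3 -> 4 groups of 3 consecutive elements

print(groups)
# tensor([[ 0,  1,  2],
#         [ 3,  4,  5],
#         [ 6,  7,  8],
#         [ 9, 10, 11]])

# viewing back with the original shape restores the layout
restored = groups.view(2, 6)
print(torch.equal(restored, tensor))  # True
```

Because the number of columns must be divisible by `group_size`, a group never spans two rows of the original tensor.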

# Quantization in Neural Network

In a neural network, we can quantize the **weights** but also the **activations**. Depending on what we quantize, the **storage** and the **computation** are not the same!

|  | Weight-only quantization | Weight and activation quantization |
|--|--|--|
|Storage| Quantized weights + floating-point activations (e.g. W8A32) | Quantized weights + quantized activations (e.g. W8A8)|
|Computation| Floating-point arithmetic (FP32, FP16, BF16, ...) | Integer-based arithmetic (INT8, INT4, ...)|
|Note| The weights must be dequantized before the floating-point computation! | Not supported by all hardware!|

## Linear Quantization: Inference

- `W8A32` means weights in 8-bits and activations in 32-bits.
- For simplicity, the linear layer will be without bias.

In [107]:
def quantized_linear_W8A32_without_bias(input, q_w, s_w, z_w):
    assert input.dtype == torch.float32
    assert q_w.dtype == torch.int8

    # dequantize with r = s * (q - z), consistent with the linear mapping defined earlier
    dequantized_weight = (q_w.to(torch.float32) - z_w) * s_w
    output = torch.nn.functional.linear(input, dequantized_weight)
    
    return output

In [108]:
input = torch.tensor([1, 2, 3], dtype=torch.float32)

In [109]:
weight = torch.tensor([[-2,   -1.13, 0.42],
                       [-1.51, 0.25, 1.62],
                       [0.23,  1.35, 2.15]])

In [110]:
weight.dtype

torch.float32

In [112]:
q_w, s_w  = linear_q_symmetric_per_channel(weight,dim=0)

In [113]:
q_w.dtype

torch.int8

In [114]:
s_w

tensor([[0.0157],
        [0.0128],
        [0.0169]])

In [115]:
output = quantized_linear_W8A32_without_bias(input,
                                             q_w,
                                             s_w,
                                             0)

In [117]:
print(f"This is the W8A32 output: {output}")

This is the W8A32 output: tensor([-2.9921,  3.8650,  9.3957])


In [118]:
fp32_output = torch.nn.functional.linear(input, weight)

In [119]:
print(f"This is the output if we don't quantize: {fp32_output}")

This is the output if we don't quantize: tensor([-3.0000,  3.8500,  9.3800])
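To put a number on how close the two outputs are, the sketch below compares the printed values in plain Python. It uses the rounded values shown above, so the results are only accurate to about four decimal places.

```python
w8a32_output = [-2.9921, 3.8650, 9.3957]   # values printed above
fp32_output = [-3.0000, 3.8500, 9.3800]

abs_errors = [abs(a - b) for a, b in zip(w8a32_output, fp32_output)]
mse = sum(e ** 2 for e in abs_errors) / len(abs_errors)

print(max(abs_errors))  # largest absolute difference, below 0.02
print(mse)              # mean squared error, below 1e-3
```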
