### Quantization from scratch
 - Absmax quantization
 - Zeropoint quantization

In [1]:
## imports
import torch
from torch import tensor
from typing import Optional, Union
import numpy as np

In [20]:
# Create a random weight tensor of shape (B, M, P).
# A seeded generator is used so that re-running the notebook from a fresh
# kernel reproduces the same weights (and therefore the same error figures).
B, M, P = 4, 128, 256
SEED = 0
rng = np.random.default_rng(SEED)
weight_tensor = rng.uniform(-50, 150, (B, M, P))
print(weight_tensor.shape)

# Plant three known values at the start of the tensor so they are easy to
# spot in the printed output: a new maximum, a new minimum and an exact zero.
weight_tensor[0, 0, 0] = weight_tensor.max() + 1
weight_tensor[0, 0, 1] = weight_tensor.min() - 1
weight_tensor[0, 0, 2] = 0


(4, 128, 256)


### Quantization operations

In [16]:
def asymmetric_quantization_tensor(tensor: Union[torch.Tensor, np.ndarray], bits: int):
    """
    Quantize the input tensor/array using asymmetric uniform quantization,
    also known as zeropoint quantization.

    The values are mapped affinely onto the unsigned integer grid
    [0, 2**bits - 1]:

        q = clip(round(x * scale + zeropoint), 0, 2**bits - 1)

    Args:
        tensor: Values to quantize. The input is converted with ``np.asarray``
            before any arithmetic, so anything NumPy can convert is accepted
            (for torch tensors this presumably requires a CPU tensor that
            does not track gradients -- TODO confirm against callers).
        bits: Bit width of the target integer grid.

    Returns:
        A tuple ``(tensor_q, scale, zeropoint)`` where ``tensor_q`` is an
        ``np.int32`` array with the same shape as the input. The original
        values are approximately recovered by ``(tensor_q - zeropoint) / scale``.
    """
    values = np.asarray(tensor)

    # find the quantization range
    alpha = np.max(values)
    beta = np.min(values)
    lower_bound, upper_bound = 0, 2 ** bits - 1

    # quantization constants
    value_range = alpha - beta
    # A constant input has zero range; dividing by it would produce an
    # infinite scale and NaN codes. Fall back to a unit scale, which still
    # round-trips the (single) value through the zeropoint.
    scale = upper_bound / value_range if value_range > 0 else 1.0
    zeropoint = -1 * np.round(beta * scale)

    tensor_q = np.clip(np.round((values * scale) + zeropoint), lower_bound, upper_bound).astype(np.int32)

    return tensor_q, scale, zeropoint


def symmetric_quantization_tensor(tensor: Union[torch.Tensor, np.ndarray], bits: int):
    """
    Quantize the input tensor/array using symmetric uniform quantization,
    also known as absmax quantization.

    The values are scaled so that the largest absolute value maps onto the
    edge of the signed grid [-(2**(bits-1) - 1), 2**(bits-1) - 1]:

        q = clip(round(x * scale), -(2**(bits-1) - 1), 2**(bits-1) - 1)

    Args:
        tensor: Values to quantize. The input is converted with ``np.asarray``
            before any arithmetic, so anything NumPy can convert is accepted
            (for torch tensors this presumably requires a CPU tensor that
            does not track gradients -- TODO confirm against callers).
        bits: Bit width of the target integer grid.

    Returns:
        A tuple ``(tensor_q, scale)`` where ``tensor_q`` is an ``np.int32``
        array with the same shape as the input. The original values are
        approximately recovered by ``tensor_q / scale``.
    """
    values = np.asarray(tensor)

    upper_bound = 2 ** (bits - 1) - 1
    lower_bound = -upper_bound

    # Absmax: the range is set by the largest magnitude, not the largest
    # value, so inputs dominated by negative values are not clipped.
    absmax = np.max(np.abs(values))

    # An all-zero input has no range; use a unit scale instead of dividing
    # by zero (every code is 0 either way).
    scale = upper_bound / absmax if absmax > 0 else 1.0

    tensor_q = np.clip(np.round(values * scale), lower_bound, upper_bound).astype(np.int32)

    return tensor_q, scale


def asymmetric_dequantize(tensor: Union[tensor, np.ndarray], scale: float, zeropoint: float):
    """
    Invert asymmetric (zeropoint) quantization.

    Subtracts ``zeropoint`` from the quantized values and divides the result
    by ``scale``, i.e. returns ``(tensor - zeropoint) / scale``.
    """
    # undo the offset first, then the scaling
    shifted = tensor - zeropoint
    return shifted / scale


def symmetric_dequantize(tensor: Union[tensor, np.ndarray], scale: float) -> Union[tensor, np.ndarray]:
    """
    Invert symmetric (absmax) quantization.

    Returns the quantized values divided by ``scale``.
    """
    dequantized = tensor / scale
    return dequantized


def quantization_error(tensor_deq: Union[tensor, np.ndarray], tensor: Union[tensor, np.ndarray]) -> float:
    """
    Mean squared error between the original values and their dequantized
    reconstruction.

    Computes ``np.mean((tensor - tensor_deq) ** 2)`` over all elements.
    """
    residual = tensor - tensor_deq
    squared_residual = residual ** 2
    return np.mean(squared_residual)


In [21]:
# Asymmetric (zeropoint) quantization of the weights to 8 bits,
# followed by a side-by-side dump of the original and quantized values.
asym_quantized_weight, asym_scale, asym_zeropoint = asymmetric_quantization_tensor(
    weight_tensor, 8
)
print(
    'Original :',
    np.round(weight_tensor, 2),
    "",
    f'Assymetric scale: {np.round(asym_scale, 3)}, Zeropoint: {asym_zeropoint}',
    asym_quantized_weight,
    '',
    sep="\n",
)

Original :
[[[151.   -51.     0.   ... 133.34 140.58 108.55]
  [100.7   27.4  139.06 ... 130.54 121.36  91.81]
  [ 63.11 112.89  19.36 ...  64.66  -0.9  131.01]
  ...
  [139.87  19.23  45.87 ... 144.32 111.11  27.78]
  [-46.54  26.34 -11.64 ...  38.97  17.3  -37.19]
  [ 53.47  20.1   48.86 ...   6.05 -34.25  87.98]]

 [[ 44.63   8.84 148.58 ... -32.08  69.96  31.05]
  [ 48.8   46.9  142.94 ...  86.29  10.   -42.49]
  [ 69.18 113.08  97.8  ...  74.43  69.12   7.23]
  ...
  [ 37.45 124.55 131.54 ...  44.76 113.78  10.45]
  [145.13  97.41  78.14 ...  32.84 119.57  98.94]
  [-19.12 109.69  77.26 ...  -8.9   45.56  11.11]]

 [[126.7   26.34  73.78 ...  85.76  -8.23 118.15]
  [ 13.95 148.42  67.7  ... -10.4   10.31  99.4 ]
  [ 53.88  -9.08   7.43 ...  29.64 -39.71  54.11]
  ...
  [112.28 -36.95  15.41 ... 122.28   1.97  36.65]
  [-41.3   16.26  72.81 ... -15.14  52.82 108.25]
  [148.41  84.35  75.77 ... -21.78 -34.7   45.33]]

 [[-40.31  -1.23  90.2  ... 135.9   27.24 -28.49]
  [-25.47  88.0

In [23]:
# Symmetric (absmax) quantization of the weights to 8 bits,
# followed by a side-by-side dump of the original and quantized values.
sym_quantized_weight, sym_scale = symmetric_quantization_tensor(
    weight_tensor, 8
)
print(
    "Original :",
    np.round(weight_tensor, 2),
    "",
    f"Symmetric scale: {np.round(sym_scale, 3)}",
    sym_quantized_weight,
    "",
    sep="\n",
)

Original :
[[[151.   -51.     0.   ... 133.34 140.58 108.55]
  [100.7   27.4  139.06 ... 130.54 121.36  91.81]
  [ 63.11 112.89  19.36 ...  64.66  -0.9  131.01]
  ...
  [139.87  19.23  45.87 ... 144.32 111.11  27.78]
  [-46.54  26.34 -11.64 ...  38.97  17.3  -37.19]
  [ 53.47  20.1   48.86 ...   6.05 -34.25  87.98]]

 [[ 44.63   8.84 148.58 ... -32.08  69.96  31.05]
  [ 48.8   46.9  142.94 ...  86.29  10.   -42.49]
  [ 69.18 113.08  97.8  ...  74.43  69.12   7.23]
  ...
  [ 37.45 124.55 131.54 ...  44.76 113.78  10.45]
  [145.13  97.41  78.14 ...  32.84 119.57  98.94]
  [-19.12 109.69  77.26 ...  -8.9   45.56  11.11]]

 [[126.7   26.34  73.78 ...  85.76  -8.23 118.15]
  [ 13.95 148.42  67.7  ... -10.4   10.31  99.4 ]
  [ 53.88  -9.08   7.43 ...  29.64 -39.71  54.11]
  ...
  [112.28 -36.95  15.41 ... 122.28   1.97  36.65]
  [-41.3   16.26  72.81 ... -15.14  52.82 108.25]
  [148.41  84.35  75.77 ... -21.78 -34.7   45.33]]

 [[-40.31  -1.23  90.2  ... 135.9   27.24 -28.49]
  [-25.47  88.0

In [14]:
## Quantization errors
# Reconstruct the weights from each quantized form, then report the MSE
# against the original tensor.
asym_dequantized_weight = asymmetric_dequantize(asym_quantized_weight, asym_scale, asym_zeropoint)
asym_error = quantization_error(asym_dequantized_weight, weight_tensor)
print(f'Asymmetric Error: {np.round(asym_error, 3)}')

sym_dequantized_weight = symmetric_dequantize(sym_quantized_weight, sym_scale)
sym_error = quantization_error(sym_dequantized_weight, weight_tensor)
print(f'Symmetric Error: {np.round(sym_error, 3)}')


Asymmetric Error: 1685.414
Symmetric Error: 0.116
