In [5]:
import numpy as np

# suppress scientific notation
np.set_printoptions(suppress=True)

params = np.random.uniform(low=-50, high=150, size=20)

params[0] = params.max() + 1
params[1] = params.min() - 1
params[3] = 0

# round each number to the second decimal place
params = np.round(params, 2)

print(params)

[128.03 -46.88 -13.11   0.   101.37 -16.99 127.03  41.95 -45.88  75.77
  32.16  17.29  66.55 -43.09  49.38 -36.    14.69  44.08 101.44  90.32]


# Define the quantization methods and quantize


In [6]:
def clamp(params_q: np.array, lower_bound: int, upper_bound: int) -> np.array:
    params_q[params_q < lower_bound] = lower_bound
    params_q[params_q > upper_bound] = upper_bound
    return params_q

def asymmetric_quantization(params: np.array, bits: int) -> tuple[np.array, float, int]:
    alpha = np.max(params)
    beta = np.min(params)
    scale = (alpha - beta) / (2**bits - 1)
    zeros = -1*np.round(beta/scale)
    lower_bound = 0
    upper_bound = 2**bits - 1
    
    # quantize the parameters
    quantized = clamp(np.round((params / scale) + zeros), lower_bound, upper_bound).astype(np.int32)
    return quantized, scale, zeros

def asymmetric_dequantize(params_q: np.array, scale: float, zero: int) -> np.array:
    return scale*(params_q - zero)

def symmetric_quantization(params: np.array, bits: int) -> tuple[np.array, float]:
    alpha = np.max(np.abs(params))
    scale = abs(alpha) / (2**(bits-1) - 1)    
    lower_bound = -(2**(bits-1) - 1)
    upper_bound = 2**(bits-1) - 1 
    
    # quantize the parameters
    quantized = clamp(np.round(params/scale), lower_bound, upper_bound).astype(np.int32)
    return quantized, scale

def symmetric_dequantize(params_q: np.array, scale: float) -> np.array:
    return params_q * scale

def quantization_error(params: np.array, parmas_q: np.array):
    # calculate the MSE
    return np.mean((params - parmas_q)**2)


(asymmetric_q, asymmetric_scale, asymmetric_zeros) = asymmetric_quantization(params, 8)
(symmetric_q, symmetric_scale) = symmetric_quantization(params, 8)


print(f'Original:')
print(np.round(params, 2))
print('')
print(f'Asymmetric scale: {asymmetric_scale}, zero: {asymmetric_zeros}')
print(asymmetric_q)
print('')
print(f'Symmetric scale: {symmetric_scale}')
print(symmetric_q)

Original:
[128.03 -46.88 -13.11   0.   101.37 -16.99 127.03  41.95 -45.88  75.77
  32.16  17.29  66.55 -43.09  49.38 -36.    14.69  44.08 101.44  90.32]

Asymmetric scale: 0.685921568627451, zero: 68.0
[255   0  49  68 216  43 253 129   1 178 115  93 165   5 140  16  89 132
 216 200]

Symmetric scale: 1.0081102362204724
[127 -47 -13   0 101 -17 126  42 -46  75  32  17  66 -43  49 -36  15  44
 101  90]


In [7]:
## Dequantize the parametes back to 32 bits
params_deq_asymmetric = asymmetric_dequantize(asymmetric_q, asymmetric_scale, asymmetric_zeros)
params_deq_symmetric = symmetric_dequantize(symmetric_q, symmetric_scale)



print(f'Original:')
print(np.round(params, 2))
print('')
print(f'Dequantize Asymmetric :')
print(params_deq_asymmetric)
print('')
print(f'Deqantize Symmetric:')
print(params_deq_symmetric)

Original:
[128.03 -46.88 -13.11   0.   101.37 -16.99 127.03  41.95 -45.88  75.77
  32.16  17.29  66.55 -43.09  49.38 -36.    14.69  44.08 101.44  90.32]

Dequantize Asymmetric :
[128.26733333 -46.64266667 -13.0325098    0.         101.51639216
 -17.14803922 126.8954902   41.84121569 -45.9567451   75.45137255
  32.23831373  17.14803922  66.53439216 -43.21305882  49.38635294
 -35.66792157  14.40435294  43.89898039 101.51639216  90.54164706]

Deqantize Symmetric:
[128.03       -47.3811811  -13.10543307   0.         101.81913386
 -17.13787402 127.02188976  42.34062992 -46.37307087  75.60826772
  32.25952756  17.13787402  66.53527559 -43.34874016  49.39740157
 -36.2919685   15.12165354  44.35685039 101.81913386  90.72992126]


In [8]:
print(f'{"Asymmetric error: ":>20}{np.round(quantization_error(params, params_deq_asymmetric), 2)}')
print(f'{"Symmetric error: ":>20}{np.round(quantization_error(params, params_deq_symmetric), 2)}')

  Asymmetric error: 0.03
   Symmetric error: 0.08
