In [8]:
import numpy as np
np.random.seed(0)
np.set_printoptions(precision=2)

params = np.random.uniform(low=-50, high=150, size=10) # not this is not symmetric

params[0] = params.max() + 1
params[1] = params.min() - 1
params[2] = 0

print(params)

[143.73  25.69   0.    58.98  34.73  79.18  37.52 128.35 142.73  26.69]


In [9]:
def clamp(q_params: np.array, low_bound: int, upper_bound: int):
    q_params[q_params < low_bound] = low_bound
    q_params[q_params > upper_bound] = upper_bound
    return q_params

def asymmetric_quantisation(params, n_bits, stategy="min-max", percentile=99.99):
    up_bound = 2**n_bits - 1
    if stategy == "percentile":
        alpha = np.percentile(params, percentile)
        beta = np.percentile(params, 100-percentile)
    else:
        alpha = params.max()
        beta = params.min()

    scale = (alpha - beta) / up_bound
    shift = np.round(-beta / scale)
    q_params = np.round(params / scale) + shift
    return clamp(q_params, 0, up_bound).astype(np.int32), scale, shift

def asymmetric_de_quantisation(q_params, scale, shift):
    return (q_params - shift) * scale

def symmetric_quantisation(params, n_bits):
    up_bound = 2**(n_bits-1) - 1
    alpha = params.max()
    beta = params.min()
    alpha = alpha if np.abs(alpha) > np.abs(beta) else beta
    scale = np.abs(alpha) / up_bound
    q_params = np.round(params / scale)
    return clamp(q_params, -up_bound, up_bound).astype(np.int32), scale

def symmetric_de_quantisation(q_params, scale):
    return q_params * scale

def quantization_error(params, dequant_params):
    # MSE
    return np.mean((params - dequant_params) ** 2)

In [10]:
n_bits = 8
# asym
q_params_asym, scale, shift = asymmetric_quantisation(params, n_bits)
params_asym = asymmetric_de_quantisation(q_params_asym, scale, shift)
# sym
q_params_sym, scale = symmetric_quantisation(params, n_bits)
params_sym = symmetric_de_quantisation(q_params_sym, scale)


print(8*"#")
print("Before Quantisation")
print(params)
print("After ASymmetric DeQuantisation")
print(params_asym)
print("After Symmetric DeQuantisation")
print(params_sym)

print("\n"+8*"#")
print("Asymmetric Error")
print(quantization_error(params, params_asym))
print("Symmetric Error")
print(quantization_error(params, q_params_sym))

print("\n"+8*"#")
print("After Asymmetric Quantisation (MinMax)")
print(q_params_asym)
print("After Symmetric Quantisation")
print(q_params_sym)

########
Before Quantisation
[143.73  25.69   0.    58.98  34.73  79.18  37.52 128.35 142.73  26.69]
After ASymmetric DeQuantisation
[143.73  25.93   0.    59.18  34.95  78.91  37.77 128.51 142.61  26.49]
After Symmetric DeQuantisation
[143.73  26.03   0.    58.85  35.08  79.22  37.35 127.89 142.6   27.16]

########
Asymmetric Error
0.03597537537177541
Symmetric Error
97.74258886499265

########
After Asymmetric Quantisation (MinMax)
[255  46   0 105  62 140  67 228 253  47]
After Symmetric Quantisation
[127  23   0  52  31  70  33 113 126  24]


# Percentile Strategy

In [11]:
params = np.random.uniform(low=-50, high=150, size=10000)
params[0] = 1000.0 # outlier

q_params_asym_mm, scale_mm, shift_mm = asymmetric_quantisation(params, n_bits, stategy="min-max")
q_params_asym_p, scale_p, shift_p = asymmetric_quantisation(params, n_bits, stategy="percentile", percentile=99.99)
params_asym_mm = asymmetric_de_quantisation(q_params_asym_mm, scale_mm, shift_mm)
params_asym_p = asymmetric_de_quantisation(q_params_asym_p, scale_p, shift_p)
error_mm = quantization_error(params, params_asym_mm)
error_p = quantization_error(params, params_asym_p)
error_mm_exclude = quantization_error(params[1:], params_asym_mm[1:])
error_p_exclude = quantization_error(params[1:], params_asym_p[1:])

print_limit = 7
print("\n"+8*"#")
print(f"Original, first {print_limit} values: \n{params[:print_limit]}")
print(f"Asym(min-max), first {print_limit} values: \n{params_asym_mm[:print_limit]}")
print(f"Asym(percentile), first {print_limit} values: \n{params_asym_p[:print_limit]}")

print("\n"+8*"#")
print(f"Asym(min-max) error including outlier: \n{error_mm}")
print(f"Asym(percentile) error including outlier: \n{error_p}")
print("\n"+8*"#")
print(f"Asym(min-max) error excluding outlier: \n{error_mm_exclude}")
print(f"Asym(percentile) error excluding outlier: \n{error_p_exclude}") # much less


########
Original, first 7 values: 
[1000.     55.78   63.61  135.12  -35.79  -32.57  -45.96]
Asym(min-max), first 7 values: 
[1000.57   57.65   61.76  135.88  -37.06  -32.94  -45.29]
Asym(percentile), first 7 values: 
[149.85  55.7   63.55 134.95 -36.09 -32.95 -46.29]

########
Asym(min-max) error including outlier: 
1.4153214531795646
Asym(percentile) error including outlier: 
72.32654082673872

########
Asym(min-max) error excluding outlier: 
1.4154299995134934
Asym(percentile) error excluding outlier: 
0.05159551830205784
