In [67]:
import torch

# Draw 10 standard-normal samples, then stretch by 200 and shift by 50.
# NOTE(review): no seed is set in this cell, so the values change on every
# run unless a seed has been set elsewhere.
original_tensor = torch.randn(10) * 200 + 50

# Force the first element to exactly 0 (presumably so the treatment of zero
# can be inspected after quantization).
original_tensor[0] = 0

# Report the range only after the tensor has its final contents, so the
# printed max/min always describe the tensor that the later cells use.
print(f"Max value: {original_tensor.max()}")
print(f"Min value: {original_tensor.min()}")


Max value: 376.07928466796875
Min value: -194.05178833007812


In [68]:
# Display the tensor (first element already set to 0) that the cells below quantize.
original_tensor

tensor([   0.0000,   -7.2955, -174.8314, -194.0518,  174.6870,  376.0793,
         -29.7248, -143.4094,  260.2398,  -57.8868])

In [74]:
def clamp(params_q, lower_bound, upper_bound):
    """Limit every element of ``params_q`` to ``[lower_bound, upper_bound]``.

    Thin wrapper around ``torch.clamp``: elements below ``lower_bound`` become
    ``lower_bound`` and elements above ``upper_bound`` become ``upper_bound``.

    Args:
        params_q: tensor whose elements are to be limited.
        lower_bound: smallest value allowed in the result.
        upper_bound: largest value allowed in the result.

    Returns:
        The tensor produced by ``torch.clamp``.
    """
    bounded = torch.clamp(params_q, min=lower_bound, max=upper_bound)
    return bounded

def asymmetric_quantization(params, bits):
    """Quantize ``params`` onto the unsigned integer grid ``[0, 2**bits - 1]``.

    The observed minimum and maximum of ``params`` define the step size
    ``scale = (max - min) / (2**bits - 1)``; each value is then mapped with
    ``q = round(x / scale + zero)`` and clamped to the grid.

    Args:
        params: tensor of values to quantize.
        bits: bit width of the target grid. Must be at least 1; with
            ``bits == 0`` the grid has a single level and the step is
            computed by dividing by zero.

    Returns:
        A tuple ``(quantized, scale, zero)`` where ``quantized`` is an int64
        tensor of codes, ``scale`` is the step size and ``zero`` is the
        zero-point offset added before rounding.
    """
    levels = 2**bits - 1
    alpha = params.max()
    beta = params.min()
    scale = (alpha - beta) / levels
    if scale == 0:
        # Degenerate range (e.g. every element equal): a zero step would turn
        # the divisions below into NaN. Use |value| as the step, or 1 when the
        # value itself is 0, so the single value still round-trips exactly
        # through (q - zero) * scale.
        scale = scale + (alpha.abs() if alpha != 0 else 1)
    zero = -1 * torch.round(beta / scale)
    shifted = torch.round(params / scale + zero)
    quantized = torch.clamp(shifted, 0, levels).long()
    return quantized, scale, zero

def asymmetric_dequantize(params_q, scale, zero):
    """Map asymmetric integer codes back to floating-point values.

    Args:
        params_q: tensor of quantized codes.
        scale: step size used when quantizing.
        zero: zero-point offset used when quantizing.

    Returns:
        ``(params_q - zero) * scale``, with ``params_q`` first cast to float.
    """
    codes = params_q.float()
    centered = codes - zero
    return centered * scale

def symmetric_dequantize(params_q, scale):
    """Map symmetric integer codes back to floating-point values.

    Args:
        params_q: tensor of quantized codes.
        scale: step size used when quantizing.

    Returns:
        ``params_q`` cast to float and multiplied by ``scale``.
    """
    codes = params_q.float()
    return codes * scale

def symmetric_quantization(params, bits):
    """Quantize ``params`` onto a signed integer grid centred on zero.

    The largest absolute value in ``params`` is mapped to
    ``2**(bits - 1) - 1``; every value is divided by the resulting step,
    rounded, and clamped to ``[-2**(bits - 1), 2**(bits - 1) - 1]``.

    Args:
        params: tensor of values to quantize.
        bits: bit width of the target grid. Must be at least 2; with
            ``bits == 1`` the positive limit is 0 and the step is computed by
            dividing by zero.

    Returns:
        A tuple ``(quantized, scale)`` where ``quantized`` is an int64 tensor
        of codes and ``scale`` is the step size.
    """
    upper_bound = 2**(bits - 1) - 1
    lower_bound = -2**(bits - 1)
    alpha = torch.max(torch.abs(params))
    scale = alpha / upper_bound
    if scale == 0:
        # All-zero input: a zero step would make params / scale NaN. Any
        # positive step represents zeros exactly, so fall back to 1.
        scale = scale + 1
    rounded = torch.round(params / scale)
    quantized = torch.clamp(rounded, lower_bound, upper_bound).long()
    return quantized, scale

def quantization_error(params, params_q):
    """Return the mean squared difference between two tensors.

    Args:
        params: reference tensor.
        params_q: tensor compared against ``params``. Despite the name, the
            calls visible in this notebook pass the dequantized
            reconstruction, not the integer codes.

    Returns:
        The mean of the element-wise squared differences.
    """
    residual = params - params_q
    squared = residual ** 2
    return torch.mean(squared)

In [76]:
# Symmetric scheme: quantize to 8 bits, then show the step size and the codes.
symmetric_q, symmetric_scale = symmetric_quantization(original_tensor, 8)
print(f'Symmetric scale: {symmetric_scale}')
print(symmetric_q)

# Asymmetric scheme: quantize to 8 bits, then show step size, zero point and codes.
asymmetric_q, asymmetric_scale, asymmetric_zero = asymmetric_quantization(original_tensor, 8)
print(f'Asymmetric scale: {asymmetric_scale}, zero: {asymmetric_zero}')
print(asymmetric_q)



Symmetric scale: 2.961254119873047
tensor([  0,  -2, -59, -66,  59, 127, -10, -48,  88, -20])
Asymmetric scale: 2.2358081340789795, zero: 87.0
tensor([ 87,  84,   9,   0, 165, 255,  74,  23, 203,  61])


In [78]:
# Dequantize: rebuild floating-point values from the integer codes of each scheme.
params_deq_asymmetric = asymmetric_dequantize(asymmetric_q, asymmetric_scale, asymmetric_zero)
print(f'Dequantize Asymmetric: {params_deq_asymmetric}')

params_deq_symmetric = symmetric_dequantize(symmetric_q, symmetric_scale)
print(f'Dequantize Symmetric: {params_deq_symmetric}')

Dequantize Asymmetric: tensor([   0.0000,   -6.7074, -174.3930, -194.5153,  174.3930,  375.6158,
         -29.0655, -143.0917,  259.3537,  -58.1310])
Dequantize Symmetric: tensor([   0.0000,   -5.9225, -174.7140, -195.4428,  174.7140,  376.0793,
         -29.6125, -142.1402,  260.5904,  -59.2251])


In [75]:
# Mean squared error between the original tensor and each reconstruction.
asymmetric_error = quantization_error(original_tensor, params_deq_asymmetric)
symmetric_error = quantization_error(original_tensor, params_deq_symmetric)
print(f'Asymmetric error: {asymmetric_error}')
print(f'Symmetric error: {symmetric_error}')

Asymmetric error: 0.24344968795776367
Symmetric error: 0.7371761798858643
