In [2]:
# !pip install torch
import torch

In [3]:
def asymmetric_quantization(original_weight):
  """Quantize a floating-point weight tensor to INT8 with asymmetric (affine) quantization.

  The mapping is ``q = clamp(round(w / S + Z), Qmin, Qmax)`` with
  ``S = (Wmax - Wmin) / (Qmax - Qmin)`` and ``Z = round(Qmin - Wmin / S)``
  clamped to ``[Qmin, Qmax]``.

  Args:
    original_weight: Tensor to quantize. It must contain at least one element,
      because its ``max()`` and ``min()`` are taken.

  Returns:
    A tuple ``(quantized_weight, S, Z)``:
      quantized_weight: ``torch.int8`` tensor with the same shape as the input.
      S: the scale, a Python float.
      Z: the zero point, a Python int within ``[Qmin, Qmax]``.
  """
  # define the data type that you want to quantize to. In our example, it's INT8
  quantized_data_type = torch.int8

  # Get the Wmax and Wmin value from the original weight which is in FP32
  Wmax = original_weight.max().item()
  Wmin = original_weight.min().item()

  # A constant tensor has Wmax == Wmin, which would make the scale 0 and the
  # zero point formula divide by zero. Widen the range so it includes 0; the
  # constant value then lands exactly on one end of the quantized range.
  if Wmax == Wmin:
    Wmin = min(Wmin, 0.0)
    Wmax = max(Wmax, 0.0)

  # Get the Qmax and Qmin value from the quantized data type.
  Qmax = torch.iinfo(quantized_data_type).max
  Qmin = torch.iinfo(quantized_data_type).min

  # Calculate the scale value using the scale formula. Datatype - FP32
  S = (Wmax - Wmin)/(Qmax - Qmin)
  if S == 0:
    # Only reachable when every weight is exactly 0; any non-zero scale
    # reproduces the tensor, so use 1.0.
    S = 1.0

  # Calculate the zero point value using the zero point formula, keep it
  # within [Qmin, Qmax], and round it to an integer so it has the same
  # datatype family as the quantized values.
  Z = Qmin - (Wmin/S)
  Z = int(round(min(max(Z, Qmin), Qmax)))

  # We have original_weight, scale and zero_point, now we can calculate the quantized weight.
  quantized_weight = (original_weight/S) + Z

  # Round, then clamp so the quantized weight stays within Qmin and Qmax.
  quantized_weight = torch.clamp(torch.round(quantized_weight), Qmin, Qmax)

  # finally cast the datatype to INT8
  quantized_weight = quantized_weight.to(quantized_data_type)

  # return the final quantized weight together with the scale and zero point.
  return quantized_weight, S, Z

def asymmetric_dequantization(quantized_weight, scale, zero_point):
  """Map quantized values back to floating point: ``scale * (q - zero_point)``.

  Args:
    quantized_weight: Tensor of quantized values.
    scale: Scale used during quantization.
    zero_point: Zero point used during quantization.

  Returns:
    The dequantized tensor.
  """
  # Cast to float32 first: subtracting the zero point while still in INT8
  # would give unwanted results.
  as_float = quantized_weight.to(torch.float32)
  shifted = as_float - zero_point
  return scale * shifted



In [5]:
# Original weight matrix(4,4) parameters (DataType: FP32).
# The values are written out explicitly instead of being drawn with torch.randn:
# an unseeded random draw changes on every run, so it could not stay the same as
# the matrix in the diagram above or reproduce the outputs recorded below.
original_weight = torch.tensor([[ 0.9125, -1.9589, -0.2497, -0.6999],
                                [ 1.5274, -1.3953,  0.1922,  0.0445],
                                [ 1.4998,  1.3379,  0.2339,  0.8801],
                                [-1.6090,  2.3914, -0.1345, -0.9948]])
print(original_weight)

tensor([[ 0.9125, -1.9589, -0.2497, -0.6999],
        [ 1.5274, -1.3953,  0.1922,  0.0445],
        [ 1.4998,  1.3379,  0.2339,  0.8801],
        [-1.6090,  2.3914, -0.1345, -0.9948]])


In [7]:
# Quantize the weights, then print the tensor, the scale and the zero point,
# each separated by blank lines.
quantized_weight, scale, zero_point = asymmetric_quantization(original_weight)
for report_line in (
    f"quantized weight: {quantized_weight}",
    "\n",
    f"scale: {scale}",
    "\n",
    f"zero point: {zero_point}",
):
  print(report_line)

quantized weight: tensor([[  40, -128,  -28,  -54],
        [  77,  -95,   -2,  -10],
        [  75,   65,    1,   39],
        [-107,  127,  -21,  -71]], dtype=torch.int8)


scale: 0.017060000288720224


zero point: -13


In [36]:
# Show the quantized tensor as this cell's output value.
quantized_weight

tensor([[  40, -128,  -28,  -54],
        [  77,  -95,   -2,  -10],
        [  75,   65,    1,   39],
        [-107,  127,  -21,  -71]], dtype=torch.int8)

In [8]:
# Recover floating-point weights from the quantized tensor with the scale and
# zero point returned by the quantization step.
dequantized_weight = asymmetric_dequantization(
    quantized_weight=quantized_weight,
    scale=scale,
    zero_point=zero_point,
)
print(dequantized_weight)

tensor([[ 0.9042, -1.9619, -0.2559, -0.6995],
        [ 1.5354, -1.3989,  0.1877,  0.0512],
        [ 1.5013,  1.3307,  0.2388,  0.8871],
        [-1.6036,  2.3884, -0.1365, -0.9895]])


In [38]:
# Show the dequantized tensor as this cell's output value.
dequantized_weight

tensor([[ 0.9042, -1.9619, -0.2559, -0.6995],
        [ 1.5354, -1.3989,  0.1877,  0.0512],
        [ 1.5013,  1.3307,  0.2388,  0.8871],
        [-1.6037,  2.3884, -0.1365, -0.9895]])

In [9]:
# Quantization error: mean of the squared element-wise difference between the
# dequantized weights and the original weights.
quantization_error = torch.mean(torch.square(dequantized_weight - original_weight))
print(quantization_error)

tensor(2.8572e-05)
