In [1]:
import torch

In [26]:
def quantize_linear_W8A32_without_bias(input, quantized_w, scale_w, zeropoint_w):
  """Apply a bias-free linear layer with int8 weights and float32 activations.

  The int8 weights are first mapped back to float32 with the affine
  dequantization rule ``(quantized_w - zeropoint_w) * scale_w`` and the
  result is used as the weight of ``torch.nn.functional.linear``.

  Args:
    input: float32 tensor of activations.
    quantized_w: int8 tensor holding the quantized weights.
    scale_w: quantization scale (a number, or a tensor that broadcasts
      against ``quantized_w``).
    zeropoint_w: quantization zero point (0 for symmetric quantization).

  Returns:
    The tensor produced by ``torch.nn.functional.linear(input, W)`` where
    ``W`` is the dequantized float32 weight.

  Raises:
    AssertionError: if ``input`` is not float32 or ``quantized_w`` is not int8.
  """
  # The activations stay in full precision in the W8A32 scheme.
  assert input.dtype == torch.float32, "Input should be float32"
  # Only the weights are stored quantized.
  assert quantized_w.dtype == torch.int8, "Weights should be int8"
  # Subtract the zero point BEFORE scaling: r = s * (q - z). Adding the zero
  # point after scaling is only correct when it is 0.
  dequantized_weights = (quantized_w.float() - zeropoint_w) * scale_w
  # Plain floating-point linear layer on the reconstructed weights.
  output = torch.nn.functional.linear(input, dequantized_weights)
  return output

In [35]:
# Fixed (hand-written, not random) test data for the inference function:
# a float32 input vector and a 3x3 floating-point weight matrix.
input  = torch.tensor([1, 2, 3], dtype=torch.float32)
print("input: ", input, input.type())
# These are the original, unquantized weights; quantization happens in a
# later cell, so they are labelled "weights" rather than "quantized_w".
weights = torch.tensor([[-2, -1.13, 0.42],
                            [1.2, 3.1, 2.1],
                            [0.1, 0.3, -0.2]])
print("weights: ", weights)

input:  tensor([1., 2., 3.]) torch.FloatTensor
quantized_w:  tensor([[-2.0000, -1.1300,  0.4200],
        [ 1.2000,  3.1000,  2.1000],
        [ 0.1000,  0.3000, -0.2000]])


In [37]:
# Quantize the float weights with the project helper. Judging by its name it
# performs symmetric linear quantization — the helper's source is not shown
# here, so confirm against helpers.py.
from helpers import linear_q_symmetric
# The helper's result is unpacked as (quantized tensor, scale).
quantized_w, scale_w = linear_q_symmetric(weights)
print("scale_w: ", scale_w)
print("quantized_w: ", quantized_w)

scale_w:  0.024409448067972978
quantized_w:  tensor([[-82, -46,  17],
        [ 49, 127,  86],
        [  4,  12,  -8]], dtype=torch.int8)


In [38]:
# Run W8A32 inference on the test vector with the quantized weights.
# A zero point of 0 is passed explicitly (presumably because the weights were
# quantized symmetrically — confirm against the helper that produced them).
output = quantize_linear_W8A32_without_bias(
    input=input,
    quantized_w=quantized_w,
    scale_w=scale_w,
    zeropoint_w=0,
)
print("quantized inference output: ", output)

quantized inference output:  tensor([-3.0024, 13.6937,  0.0976])


In [39]:
# Reference result: the same bias-free linear layer evaluated with the
# original floating-point weights, for comparison with the quantized output.
output = torch.nn.functional.linear(input=input, weight=weights, bias=None)
print("floating point inference output: ", output)

floating point inference output:  tensor([-3.0000, 13.7000,  0.1000])
