In [None]:
# !pip install accelerate==0.26.1 seaborn==0.13.1 torch==2.1.1 transformers==4.35.0

In [None]:
import torch

## linear quantization

In [None]:
sample_tensor = torch.tensor([[191.6, -13.5, 728.6],
                              [92.14, 295.5, -184],
                              [0, 684.6, 245.5]
                              ])

In [None]:
def linear_quantization_with_scale_and_zero_point(tensor, scale, zero_point, dtype = torch.int8):
  scaled_tensor = (tensor/scale) + zero_point   # linear transformation
  rounded_tensor = torch.round(scaled_tensor)

  q_min = torch.iinfo(torch.int8).min #-128
  q_max = torch.iinfo(torch.int8).max # 127

  quantized_tensor = torch.clamp(rounded_tensor, q_min, q_max).to(dtype) # clipping outliers

  return quantized_tensor

In [None]:
# let's try the quantization with ramdom values
scale = 3.5
zero_point = -70
quantized_tensor = linear_quantization_with_scale_and_zero_point(sample_tensor, scale, zero_point)
quantized_tensor

tensor([[ -15,  -74,  127],
        [ -44,   14, -123],
        [ -70,  126,    0]], dtype=torch.int8)

## Dequantization

In [None]:
def linear_dequantization_with_scale_and_zero_point(quantized_tensor, scale, zero_point):
  dequantized_tensor = scale*(quantized_tensor.float() - zero_point)
  return dequantized_tensor

In [None]:
dequantized_tensor = linear_dequantization_with_scale_and_zero_point(quantized_tensor, scale, zero_point)
dequantized_tensor

tensor([[ 192.5000,  -14.0000,  689.5000],
        [  91.0000,  294.0000, -185.5000],
        [   0.0000,  686.0000,  245.0000]])

## Quantization Error(MSE)


In [None]:
(dequantized_tensor - sample_tensor).square().mean()

tensor(170.8753)



---



### Finding scale and zero point

In [None]:
q_min = torch.iinfo(torch.int8).min
q_max = torch.iinfo(torch.int8).max

r_min = sample_tensor.min().item()
r_max = sample_tensor.max().item()

In [None]:
q_min, q_max, r_min, r_max

(-128, 127, -184.0, 728.5999755859375)

In [None]:
scale = (r_max - r_min)/(q_max - q_min)
scale

3.578823433670343

In [None]:
# r = s* (q-zero_point)
zero_point = int(round(q_min-(r_min/scale)))
zero_point

-77

In [None]:
def get_scale_and_zero_point(tensor, dtype=torch.int8):
  q_min = torch.iinfo(torch.int8).min
  q_max = torch.iinfo(torch.int8).max

  r_min = sample_tensor.min().item()
  r_max = sample_tensor.max().item()

  scale = (r_max - r_min)/(q_max - q_min)
  zero_point = q_min-(r_min/scale)

  # clip the zero point to fall in the range
  if zero_point < q_min:
    zero_point = q_min
  elif zero_point > q_max:
    zero_point = q_max
  else:
    zero_point = int(round(zero_point))

  return scale, zero_point

In [None]:
scale, zero_point = get_scale_and_zero_point(sample_tensor)

In [None]:
scale, zero_point

(3.578823433670343, -77)

In [None]:
# quantization
quantized_tensor = linear_quantization_with_scale_and_zero_point(sample_tensor, scale, zero_point)
dequantized_tensor = linear_dequantization_with_scale_and_zero_point(quantized_tensor, scale, zero_point)
sample_tensor, quantized_tensor, dequantized_tensor

(tensor([[ 191.6000,  -13.5000,  728.6000],
         [  92.1400,  295.5000, -184.0000],
         [   0.0000,  684.6000,  245.5000]]),
 tensor([[ -23,  -81,  127],
         [ -51,    6, -128],
         [ -77,  114,   -8]], dtype=torch.int8),
 tensor([[ 193.2565,  -14.3153,  730.0800],
         [  93.0494,  297.0423, -182.5200],
         [   0.0000,  683.5552,  246.9388]]))

In [None]:
quantization_error = (dequantized_tensor - sample_tensor).square().mean()
quantization_error

tensor(1.5730)

## Linear quantizer

In [None]:
def linear_quantization(tensor, dtype=torch.int8):
  scale, zero_point = get_scale_and_zero_point(tensor)
  quantized_tensor = linear_quantization_with_scale_and_zero_point(tensor, scale, zero_point)
  return quantized_tensor, scale, zero_point

In [None]:
test_tensor = torch.randn((4,4))
test_tensor

tensor([[-1.1377,  1.2975,  0.5796, -1.1617],
        [ 1.7812, -0.1997,  0.7769,  0.7538],
        [-0.7402, -0.8930, -0.5111, -0.7521],
        [ 0.8354,  1.6310,  0.1411, -0.0541]])

In [None]:
quantized_tensor

tensor([[ -23,  -81,  127],
        [ -51,    6, -128],
        [ -77,  114,   -8]], dtype=torch.int8)

In [None]:
quantized_tensor, scale, zero_point = linear_quantization(test_tensor)
dequantized_tensor = linear_dequantization_with_scale_and_zero_point(quantized_tensor, scale, zero_point)
test_tensor, quantized_tensor, dequantized_tensor

(tensor([[-1.1377,  1.2975,  0.5796, -1.1617],
         [ 1.7812, -0.1997,  0.7769,  0.7538],
         [-0.7402, -0.8930, -0.5111, -0.7521],
         [ 0.8354,  1.6310,  0.1411, -0.0541]]),
 tensor([[-77, -77, -77, -77],
         [-77, -77, -77, -77],
         [-77, -77, -77, -77],
         [-77, -77, -77, -77]], dtype=torch.int8),
 tensor([[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]))

In [None]:
quantization_error=(dequantized_tensor - test_tensor).square().mean()
quantization_error

tensor(0.9125)

In [None]:
class linear_quantizer:
  def __init__(self, dtype=torch.int8):
    self.dtype = dtype

  def linear_quantization(self, tensor):
    scale, zero_point = get_scale_and_zero_point(tensor)
    quantized_tensor = linear_quantization_with_scale_and_zero_point(tensor, scale, zero_point)
    return quantized_tensor, scale, zero_point

  def get_scale_and_zero_point(self, tensor):
    q_min = torch.iinfo(self.dtype).min
    q_max = torch.iinfo(self.dtype).max

    r_min = tensor.min().item()
    r_max = tensor.max().item()

    scale = (r_max - r_min)/(q_max - q_min)
    zero_point = q_min-(r_min/scale)

    # clip the zero point to fall in the range
    if zero_point < q_min:
      zero_point = q_min
    elif zero_point > q_max:
      zero_point = q_max
    else:
      zero_point = int(round(zero_point))

    return scale, zero_point

  def linear_quantization_with_scale_and_zero_point(self, tensor, scale, zero_point):
    scaled_tensor = (tensor/scale) + zero_point   # linear transformation
    rounded_tensor = torch.round(scaled_tensor)

    q_min = torch.iinfo(self.dtype).min #-128
    q_max = torch.iinfo(self.dtype).max # 127

    quantized_tensor = torch.clamp(rounded_tensor, q_min, q_max).to(self.dtype) # clipping outliers

    return quantized_tensor



In [None]:
l_quantizer = linear_quantizer()
q_tensor, scale, zero_point = l_quantizer.linear_quantization(test_tensor)

In [None]:
deq_tensor = linear_dequantization_with_scale_and_zero_point(q_tensor, scale, zero_point)

In [None]:
quantization_error=(deq_tensor - test_tensor).square().mean()
quantization_error

tensor(0.9125)

In [None]:
test_tensor, q_tensor, deq_tensor

(tensor([[-1.1377,  1.2975,  0.5796, -1.1617],
         [ 1.7812, -0.1997,  0.7769,  0.7538],
         [-0.7402, -0.8930, -0.5111, -0.7521],
         [ 0.8354,  1.6310,  0.1411, -0.0541]]),
 tensor([[-77, -77, -77, -77],
         [-77, -77, -77, -77],
         [-77, -77, -77, -77],
         [-77, -77, -77, -77]], dtype=torch.int8),
 tensor([[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]))

In [None]:
scale, zero_point

(3.578823433670343, -77)

In [None]:
quantization_error=(dequantized_tensor - test_tensor).square().mean()
quantization_error

tensor(0.9125)



---



### Symmetric mode

In [None]:
test_tensor.abs().max().item()

1.7811648845672607

In [None]:
def get_symmetric_scale(tensor, dtype= torch.int8):
  r_max = tensor.abs().max().item()
  q_max = torch.iinfo(torch.int8).max
  scale = r_max/q_max
  return scale


In [None]:
get_symmetric_scale(test_tensor)

0.014024920350923313

In [None]:
def linear_symmetric_quantization(tensor, dtype=torch.int8):
  scale = get_symmetric_scale(tensor)
  zero_point = 0
  quantized_tensor = linear_quantization_with_scale_and_zero_point(tensor, scale, zero_point=0)
  return quantized_tensor, scale

## This is a per tensor quantization

In [None]:
quantized_tensor, scale = linear_symmetric_quantization(test_tensor)
dequantized_tensor = linear_dequantization_with_scale_and_zero_point(quantized_tensor, scale, 0)

In [None]:
test_tensor, quantized_tensor, dequantized_tensor

(tensor([[-1.1377,  1.2975,  0.5796, -1.1617],
         [ 1.7812, -0.1997,  0.7769,  0.7538],
         [-0.7402, -0.8930, -0.5111, -0.7521],
         [ 0.8354,  1.6310,  0.1411, -0.0541]]),
 tensor([[-81,  93,  41, -83],
         [127, -14,  55,  54],
         [-53, -64, -36, -54],
         [ 60, 116,  10,  -4]], dtype=torch.int8),
 tensor([[-1.1360,  1.3043,  0.5750, -1.1641],
         [ 1.7812, -0.1963,  0.7714,  0.7573],
         [-0.7433, -0.8976, -0.5049, -0.7573],
         [ 0.8415,  1.6269,  0.1402, -0.0561]]))

In [None]:
quantization_error=(dequantized_tensor - test_tensor).square().mean()
quantization_error

tensor(1.7827e-05)

In [None]:
class symmetric_per_tensor_quantizer:
  def __init__(self, dtype=torch.int8):
    self.dtype = dtype

  def linear_symmetric_quantization(self, tensor):
    scale = get_symmetric_scale(tensor)
    zero_point = 0
    quantized_tensor = linear_quantization_with_scale_and_zero_point(tensor, scale, zero_point=0)
    return quantized_tensor, scale

  def get_symmetric_scale(self, tensor):
    r_max = tensor.abs().max().item()
    q_max = torch.iinfo(self.dtype).max
    scale = r_max/q_max
    return scale

  def linear_quantization_with_scale_and_zero_point(self, tensor, scale, zero_point):
    scaled_tensor = (tensor/scale) + zero_point   # linear transformation
    rounded_tensor = torch.round(scaled_tensor)

    q_min = torch.iinfo(self.dtype).min #-128
    q_max = torch.iinfo(self.dtype).max # 127

    quantized_tensor = torch.clamp(rounded_tensor, q_min, q_max).to(self.dtype) # clipping outliers

    return quantized_tensor



In [None]:
sym_quantizer = symmetric_per_tensor_quantizer()
quantized_tensor, scale =sym_quantizer.linear_symmetric_quantization(test_tensor)
quantized_tensor, scale

(tensor([[-81,  93,  41, -83],
         [127, -14,  55,  54],
         [-53, -64, -36, -54],
         [ 60, 116,  10,  -4]], dtype=torch.int8),
 0.014024920350923313)

In [None]:
dequantized_tensor = linear_dequantization_with_scale_and_zero_point(quantized_tensor, scale, 0)
dequantized_tensor

tensor([[-1.1360,  1.3043,  0.5750, -1.1641],
        [ 1.7812, -0.1963,  0.7714,  0.7573],
        [-0.7433, -0.8976, -0.5049, -0.7573],
        [ 0.8415,  1.6269,  0.1402, -0.0561]])

In [None]:
quantization_error=(dequantized_tensor - test_tensor).square().mean()
quantization_error

tensor(1.7827e-05)



---



## This is a per channel quantization

In [None]:
sample_tensor

tensor([[ 191.6000,  -13.5000,  728.6000],
        [  92.1400,  295.5000, -184.0000],
        [   0.0000,  684.6000,  245.5000]])

In [None]:
dim = 0 # along the rows
sample_tensor.shape[dim]

3

In [None]:
def linear_per_channel_quantization(tensor, dim, dtype=torch.int8):
  output_dim = tensor.shape[dim]
  scale = torch.zeros(output_dim)

  for i in range(output_dim):
    sub_tensor = tensor.select(dim, i)
    scale[i] = get_symmetric_scale(sub_tensor)

  scale_shape = [1]* tensor.dim()
  scale_shape[dim] = -1
  scale = scale.view(scale_shape)
  quantized_tensor = linear_quantization_with_scale_and_zero_point(tensor, scale, zero_point=0)
  return quantized_tensor, scale


In [None]:
test_tensor

tensor([[-1.1377,  1.2975,  0.5796, -1.1617],
        [ 1.7812, -0.1997,  0.7769,  0.7538],
        [-0.7402, -0.8930, -0.5111, -0.7521],
        [ 0.8354,  1.6310,  0.1411, -0.0541]])

In [None]:
quantized_tensor_0 , scale_0 = linear_per_channel_quantization(test_tensor, dim=0)
quantized_tensor_0 , scale_0

(tensor([[-111,  127,   57, -114],
         [ 127,  -14,   55,   54],
         [-105, -127,  -73, -107],
         [  65,  127,   11,   -4]], dtype=torch.int8),
 tensor([[0.0102],
         [0.0140],
         [0.0070],
         [0.0128]]))

In [None]:
dequantized_tensor_0 = linear_dequantization_with_scale_and_zero_point(quantized_tensor_0, scale_0, zero_point=0)
quantization_error = (dequantized_tensor_0-test_tensor).square().mean()
quantization_error

tensor(6.3167e-06)

In [None]:
quantized_tensor_1 , scale_1 = linear_per_channel_quantization(test_tensor, dim=1)
quantized_tensor_1 , scale_1

(tensor([[ -81,  101,   95, -127],
         [ 127,  -16,  127,   82],
         [ -53,  -70,  -84,  -82],
         [  60,  127,   23,   -6]], dtype=torch.int8),
 tensor([[0.0140, 0.0128, 0.0061, 0.0091]]))

In [None]:
dequantized_tensor_1 = linear_dequantization_with_scale_and_zero_point(quantized_tensor_1, scale_1, zero_point=0)
quantization_error = (dequantized_tensor_1 - test_tensor).square().mean()
quantization_error

tensor(9.1553e-06)

In [None]:
class per_channel_quantizer:
  def __init__(self, dtype=torch.int8):
    self.dtype = dtype

  def linear_per_channel_quantization(self, tensor, dim):
    output_dim = tensor.shape[dim]
    scale = torch.zeros(output_dim)

    for i in range(output_dim):
      sub_tensor = tensor.select(dim, i)
      scale[i] = self.get_symmetric_scale(sub_tensor)

    scale_shape = [1]* tensor.dim()
    scale_shape[dim] = -1
    scale = scale.view(scale_shape)
    quantized_tensor = self.linear_quantization_with_scale_and_zero_point(tensor, scale, zero_point=0)
    return quantized_tensor, scale

  def get_symmetric_scale(self, tensor):
    r_max = tensor.abs().max().item()
    q_max = torch.iinfo(self.dtype).max
    scale = r_max/q_max
    return scale

  def linear_quantization_with_scale_and_zero_point(self, tensor, scale, zero_point):
    scaled_tensor = (tensor/scale) + zero_point   # linear transformation
    rounded_tensor = torch.round(scaled_tensor)

    q_min = torch.iinfo(self.dtype).min #-128
    q_max = torch.iinfo(self.dtype).max # 127

    quantized_tensor = torch.clamp(rounded_tensor, q_min, q_max).to(self.dtype) # clipping outliers

    return quantized_tensor



In [None]:
per_c_quantizer = per_channel_quantizer()
quantized_tensor , scale = per_c_quantizer.linear_per_channel_quantization(test_tensor, dim=0)
quantized_tensor , scale

(tensor([[-111,  127,   57, -114],
         [ 127,  -14,   55,   54],
         [-105, -127,  -73, -107],
         [  65,  127,   11,   -4]], dtype=torch.int8),
 tensor([[0.0102],
         [0.0140],
         [0.0070],
         [0.0128]]))

In [None]:
dequantized_tensor = linear_dequantization_with_scale_and_zero_point(quantized_tensor_0, scale_0, zero_point=0)
dequantized_tensor

tensor([[-1.1340,  1.2975,  0.5823, -1.1647],
        [ 1.7812, -0.1963,  0.7714,  0.7573],
        [-0.7383, -0.8930, -0.5133, -0.7524],
        [ 0.8347,  1.6310,  0.1413, -0.0514]])

In [None]:
quantization_error = (dequantized_tensor - test_tensor).square().mean()
quantization_error

tensor(6.3167e-06)



---



## This is a per group quantization

In [None]:
def linear_per_group_quantization(tensor, group_size, dtype=torch.int8):
  tensor_shape = tensor.shape
  assert tensor_shape[1]% group_size == 0
  assert tensor.dim() == 2

  tensor = tensor.view(-1, group_size)
  quantized_tensor, scale = linear_per_channel_quantization(tensor, dim=0)

  quantized_tensor = quantized_tensor.view(tensor_shape)
  return quantized_tensor, scale

In [None]:
def linear_per_group_dequantization(quantized_tensor, scale, group_size):
  quantized_tensor_shape = quantized_tensor.shape
  quantized_tensor = quantized_tensor.view(-1, group_size)
  dequantized_tensor = linear_dequantization_with_scale_and_zero_point(quantized_tensor, scale, 0)
  dequantized_tensor = dequantized_tensor.view(quantized_tensor_shape)
  return dequantized_tensor


In [None]:
test_tensor = torch.rand((6,6))
test_tensor

tensor([[0.7977, 0.8257, 0.5962, 0.5998, 0.6907, 0.8333],
        [0.3198, 0.5195, 0.7537, 0.2716, 0.0857, 0.2507],
        [0.3562, 0.8249, 0.0028, 0.7432, 0.0482, 0.0051],
        [0.4960, 0.7146, 0.3630, 0.5088, 0.8020, 0.1363],
        [0.8739, 0.4375, 0.0498, 0.5100, 0.7065, 0.9855],
        [0.3782, 0.5187, 0.7559, 0.3816, 0.2202, 0.1814]])

In [None]:
test_tensor

tensor([[0.7977, 0.8257, 0.5962, 0.5998, 0.6907, 0.8333],
        [0.3198, 0.5195, 0.7537, 0.2716, 0.0857, 0.2507],
        [0.3562, 0.8249, 0.0028, 0.7432, 0.0482, 0.0051],
        [0.4960, 0.7146, 0.3630, 0.5088, 0.8020, 0.1363],
        [0.8739, 0.4375, 0.0498, 0.5100, 0.7065, 0.9855],
        [0.3782, 0.5187, 0.7559, 0.3816, 0.2202, 0.1814]])

In [None]:
group_size = 2
quantized_tensor, scale = linear_per_group_quantization(test_tensor, group_size)
dequantized_tensor = linear_per_group_dequantization(quantized_tensor, scale, group_size)
quantized_tensor, scale, dequantized_tensor

(tensor([[123, 127, 126, 127, 105, 127],
         [ 78, 127, 127,  46,  43, 127],
         [ 55, 127,   0, 127, 127,  13],
         [ 88, 127,  91, 127, 127,  22],
         [127,  64,  12, 127,  91, 127],
         [ 93, 127, 127,  64, 127, 105]], dtype=torch.int8),
 tensor([[0.0065],
         [0.0047],
         [0.0066],
         [0.0041],
         [0.0059],
         [0.0020],
         [0.0065],
         [0.0059],
         [0.0004],
         [0.0056],
         [0.0040],
         [0.0063],
         [0.0069],
         [0.0040],
         [0.0078],
         [0.0041],
         [0.0060],
         [0.0017]]),
 tensor([[0.7997, 0.8257, 0.5951, 0.5998, 0.6890, 0.8333],
         [0.3190, 0.5195, 0.7537, 0.2730, 0.0849, 0.2507],
         [0.3572, 0.8249, 0.0000, 0.7432, 0.0482, 0.0049],
         [0.4951, 0.7146, 0.3646, 0.5088, 0.8020, 0.1389],
         [0.8739, 0.4404, 0.0482, 0.5100, 0.7062, 0.9855],
         [0.3798, 0.5187, 0.7559, 0.3809, 0.2202, 0.1821]]))

In [None]:
quantization_error = (dequantized_tensor - test_tensor).square().mean()
quantization_error

tensor(1.2599e-06)

In [None]:
class per_group_quantizer:
  def __init__(self, dtype=torch.int8):
    self.dtype = dtype

  def linear_per_group_quantization(self, tensor, group_size):
    tensor_shape = tensor.shape
    assert tensor_shape[1]% group_size == 0
    assert tensor.dim() == 2

    tensor = tensor.view(-1, group_size)
    quantized_tensor, scale = self.linear_per_channel_quantization(tensor, dim=0)

    quantized_tensor = quantized_tensor.view(tensor_shape)
    return quantized_tensor, scale

  def linear_per_channel_quantization(self, tensor, dim):
    output_dim = tensor.shape[dim]
    scale = torch.zeros(output_dim)

    for i in range(output_dim):
      sub_tensor = tensor.select(dim, i)
      scale[i] = self.get_symmetric_scale(sub_tensor)

    scale_shape = [1]* tensor.dim()
    scale_shape[dim] = -1
    scale = scale.view(scale_shape)
    quantized_tensor = self.linear_quantization_with_scale_and_zero_point(tensor, scale, zero_point=0)
    return quantized_tensor, scale

  def get_symmetric_scale(self, tensor):
    r_max = tensor.abs().max().item()
    q_max = torch.iinfo(self.dtype).max
    scale = r_max/q_max
    return scale

  def linear_quantization_with_scale_and_zero_point(self, tensor, scale, zero_point):
    scaled_tensor = (tensor/scale) + zero_point   # linear transformation
    rounded_tensor = torch.round(scaled_tensor)

    q_min = torch.iinfo(self.dtype).min #-128
    q_max = torch.iinfo(self.dtype).max # 127

    quantized_tensor = torch.clamp(rounded_tensor, q_min, q_max).to(self.dtype) # clipping outliers

    return quantized_tensor




In [None]:
  def linear_per_group_dequantization(quantized_tensor, scale, group_size):
    quantized_tensor_shape = quantized_tensor.shape
    quantized_tensor = quantized_tensor.view(-1, group_size)
    dequantized_tensor = linear_dequantization_with_scale_and_zero_point(quantized_tensor, scale, 0)
    dequantized_tensor = dequantized_tensor.view(quantized_tensor_shape)
    return dequantized_tensor

In [None]:
per_grp_quantizer = per_group_quantizer()


In [None]:
group_size = 2
quantized_tensor, scale = per_grp_quantizer.linear_per_group_quantization(test_tensor, group_size)
quantized_tensor, scale

(tensor([[123, 127, 126, 127, 105, 127],
         [ 78, 127, 127,  46,  43, 127],
         [ 55, 127,   0, 127, 127,  13],
         [ 88, 127,  91, 127, 127,  22],
         [127,  64,  12, 127,  91, 127],
         [ 93, 127, 127,  64, 127, 105]], dtype=torch.int8),
 tensor([[0.0065],
         [0.0047],
         [0.0066],
         [0.0041],
         [0.0059],
         [0.0020],
         [0.0065],
         [0.0059],
         [0.0004],
         [0.0056],
         [0.0040],
         [0.0063],
         [0.0069],
         [0.0040],
         [0.0078],
         [0.0041],
         [0.0060],
         [0.0017]]))

In [None]:
dequantized_tensor = linear_per_group_dequantization(quantized_tensor, scale, group_size)
dequantized_tensor

tensor([[0.7997, 0.8257, 0.5951, 0.5998, 0.6890, 0.8333],
        [0.3190, 0.5195, 0.7537, 0.2730, 0.0849, 0.2507],
        [0.3572, 0.8249, 0.0000, 0.7432, 0.0482, 0.0049],
        [0.4951, 0.7146, 0.3646, 0.5088, 0.8020, 0.1389],
        [0.8739, 0.4404, 0.0482, 0.5100, 0.7062, 0.9855],
        [0.3798, 0.5187, 0.7559, 0.3809, 0.2202, 0.1821]])

In [None]:
quantization_error = (dequantized_tensor - test_tensor).square().mean()
quantization_error

tensor(1.2599e-06)



---



## Quantizing weights and activations for inference

In [None]:
def quantized_linear_w8A32_layer_without_bias(input_activations, quantized_weights, scale, zero_point):
  assert input_activations.dtype == torch.float32
  assert quantized_weights.dtype == torch.int8

  dequantized_weights = scale*(quantized_weights.float() - zero_point)
  output = torch.nn.functional.linear(input_activations, dequantized_weights)

  return output

In [None]:
activations = torch.tensor([1,2,3], dtype=torch.float32)
activations

tensor([1., 2., 3.])

In [None]:
weights = torch.tensor([[-2,   -1.13, 0.42],
                       [-1.51, 0.25, 1.62],
                       [0.23,  1.35, 2.15]])
weights

tensor([[-2.0000, -1.1300,  0.4200],
        [-1.5100,  0.2500,  1.6200],
        [ 0.2300,  1.3500,  2.1500]])

In [None]:
quantized_weights, scale = linear_symmetric_quantization(weights)
quantized_weights, scale

(tensor([[-118,  -67,   25],
         [ -89,   15,   96],
         [  14,   80,  127]], dtype=torch.int8),
 0.016929134609192376)

output with quantization

In [None]:
output = quantized_linear_w8A32_layer_without_bias(activations, quantized_weights, scale, 0)
output

tensor([-2.9965,  3.8768,  9.3957])

Output without quantization

In [None]:
output1 = torch.nn.functional.linear(activations, weights)
output1

tensor([-3.0000,  3.8500,  9.3800])



---



# During inference

In [None]:
def linear_W8A32_layer(inputs, quantized_weights, scales, zero_point):
  assert inputs.dtype == torch.float32
  assert quantized_weights.dtype == torch.int8

  dequantized_weights = (quantized_weights.to(torch.float32) - zero_point) * scales
  output = torch.nn.functional.linear(inputs, dequantized_weights)

  return output

In [None]:
inputs = torch.tensor([1,2,3,4,5], dtype=torch.float32)
inputs

tensor([1., 2., 3., 4., 5.])

In [None]:
weights = torch.tensor([[-2,   -1.13, 0.42, 1.2, 0.56],
                       [-1.51, 0.25, 1.62, 0.34, -0.98],
                       [0.23,  1.35, 2.15, 0.67, -0.56]])

In [None]:
q_weights, scale = linear_symmetric_quantization(weights)
q_weights, scale

(tensor([[-118,  -67,   25,   71,   33],
         [ -89,   15,   96,   20,  -58],
         [  14,   80,  127,   40,  -33]], dtype=torch.int8),
 0.016929134609192376)

In [None]:
output = linear_W8A32_layer(inputs, q_weights, scale, 0)
output

tensor([4.6047, 0.3217, 9.3110])

In [None]:
print(f"The W8A32 output : {output}")

The W8A32 output : tensor([4.6047, 0.3217, 9.3110])


In [None]:
output_1 = torch.nn.functional.linear(inputs, weights)
print(f"Output if no quantization step : {output_1}")

Output if no quantization step : tensor([4.6000, 0.3100, 9.2600])


In [None]:
error = (output_1 - output).abs().mean()
print(f"Deviation in the output because of the quantization : {error}")

Deviation in the output because of the quantization : 0.022466978058218956




---



## Building custom W8A16 Linear layer Quantizer

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class W8A16LinearLayer(nn.Module):
  def __init__(self, input_features, output_features, bias=True, dtype=torch.float32):
    super().__init__()

    self.register_buffer("int8_weights", torch.randint(-128,127, (output_features,input_features), dtype=torch.int8))
    self.register_buffer("scales", torch.randn((output_features), dtype= dtype))

    if bias:
      self.register_buffer("bias", torch.randn((1, output_features), dtype = dtype))
    else:
      self.bias = None

  def forward(self, inputs):
    converted_weights = self.int8_weights.to(inputs.dtype)
    output = F.linear(inputs, converted_weights) * self.scales

    if self.bias is not None:
      output = output + self.bias

    return output

  def quantize(self, weights):
    w_fp32 = weights.clone().to(torch.float32)

    scales = w_fp32.abs().max(dim=-1).values/127
    scales = scales.to(weights.dtype)

    int8_weights = torch.round(weights/scales.unsqueeze(1)).to(torch.int8)

    self.int8_weights  = int8_weights
    self.scales = scales



In [None]:
module = W8A16LinearLayer(16, 32)
dummy_hidden_states = torch.randn(2, 6, 16)
module(dummy_hidden_states).shape

torch.Size([2, 6, 32])

In [None]:
dummy_hidden_states

tensor([[[ 5.0864e-01,  9.4294e-01, -5.0666e-01, -1.8538e-01,  9.4969e-01,
          -9.1660e-01,  9.4054e-01,  2.0055e+00,  2.0107e+00, -7.3471e-01,
           8.4933e-01,  7.6130e-01, -1.7449e+00, -1.5734e+00,  1.3508e+00,
          -4.6354e-01],
         [-5.1628e-01,  1.3770e+00, -9.5952e-01, -2.5794e-01,  1.1945e+00,
           1.9859e-01,  1.1551e+00,  9.0630e-01,  1.2083e+00,  2.1418e+00,
           2.8370e-01, -3.0197e-01, -3.8528e-01, -4.6498e-01,  1.3024e+00,
          -8.7642e-01],
         [-1.0415e+00, -8.0999e-01, -4.4456e-01, -6.3457e-01, -8.0390e-01,
           1.1102e+00, -8.0562e-01, -5.9924e-01,  8.4165e-01,  8.2397e-01,
          -3.5747e-01,  8.3150e-01,  5.2564e-01, -5.4652e-01,  9.6413e-01,
           1.5457e+00],
         [ 1.2524e-01,  1.4643e-01,  1.5920e+00,  2.4337e+00, -8.2539e-01,
           6.6588e-01,  1.8606e+00, -2.7514e+00,  2.0577e+00, -7.0306e-01,
           1.1762e+00, -8.3200e-01,  1.8268e+00, -1.0748e-01,  1.9035e+00,
           6.9365e-01],
    

In [None]:
module.int8_weights.shape

torch.Size([32, 16])

In [None]:
module(dummy_hidden_states).shape

torch.Size([2, 6, 32])

In [None]:
instance1 = W8A16LinearLayer(4, 8)

In [None]:
instance1.int8_weights

tensor([[ -64,   60,   87, -113],
        [-110,   58,  -30,    7],
        [  -9,   37,  -95, -128],
        [  68,  -77,   66,  100],
        [-105,  -52,  122,   12],
        [-106,  -87,  -66, -111],
        [ -96,   89,   61,   50],
        [ -16,  101,   -9,   72]], dtype=torch.int8)

In [None]:
input = torch.randn((4,8), dtype=torch.bfloat16)
instance1.quantize(input)


In [None]:
input

tensor([[ 0.3848,  0.7461,  0.0166,  0.3594, -0.5078, -2.8281,  0.3008, -0.9922],
        [-0.8828, -0.4082,  0.6953,  1.0859, -0.0233,  0.5859, -1.8281, -0.5625],
        [ 1.0859,  0.9453, -0.8789, -1.9531,  1.8828,  0.5000,  0.5547,  1.0078],
        [-0.1729, -0.4082,  0.2695,  0.8047,  0.1396,  1.0703,  0.6133,  0.0248]],
       dtype=torch.bfloat16)

In [None]:
instance1.int8_weights

tensor([[  17,   34,    1,   16,  -23, -128,   14,  -45],
        [ -61,  -28,   48,   76,   -2,   41, -127,  -39],
        [  70,   62,  -57, -127,  122,   32,   36,   66],
        [ -20,  -48,   32,   96,   17,  127,   73,    3]], dtype=torch.int8)

In [None]:
instance1.scales

tensor([0.0222, 0.0144, 0.0154, 0.0084], dtype=torch.bfloat16)

In [None]:
# dequantized weights
dequantized_weights = instance1.int8_weights * instance1.scales.unsqueeze(1)

## Replace PyTorch layers with Quantized Layers

In [None]:
def replace_linear_layer_with_W8A16Linear_layer(module, target , exclude_list):
  for name, child in module.named_children():
    if isinstance(child, nn.Linear) and not any([x == name for x in exclude_list]):
      old_bias = child.bias

      new_module = target(child.in_features, child.out_features, old_bias is not None, child.weight.dtype)
      setattr(module, name, new_module)

      if old_bias is not None:
        getattr(module, name).bias = old_bias

    else:
      replace_linear_layer_with_W8A16Linear_layer(child, target, exclude_list)


In [None]:
class neural_network(torch.nn.Module):
  def __init__(self):
    super().__init__()

    self.embd = torch.nn.Embedding(4, 8)
    self.linear_1 = nn.Linear(8, 16)
    self.linear_2 = nn.Linear(16, 4, bias = False)
    self.lm_head = nn.Linear(4, 6, bias = False)

  def forward(self, x):
    x = self.embd(x)
    x = self.linear_1(x)
    x = self.linear_2(x)
    x = self.lm_head(x)

    return x

In [None]:
model_1 = neural_network()
model_2 = neural_network()

In [None]:
model_1

neural_network(
  (embd): Embedding(4, 8)
  (linear_1): Linear(in_features=8, out_features=16, bias=True)
  (linear_2): Linear(in_features=16, out_features=4, bias=False)
  (lm_head): Linear(in_features=4, out_features=6, bias=False)
)

In [None]:
for name, child in model_1.named_children():
  print(name, child)

embd Embedding(4, 8)
linear_1 Linear(in_features=8, out_features=16, bias=True)
linear_2 Linear(in_features=16, out_features=4, bias=False)
lm_head Linear(in_features=4, out_features=6, bias=False)


In [13]:
input = torch.randint(0, 4, (1, 2), dtype=torch.long)
input

tensor([[1, 1]])

In [None]:
model_1(input)

tensor([[[ 0.0729, -0.0059, -0.0542, -0.0091, -0.0573,  0.0130],
         [ 0.6495, -0.1006, -0.2716, -0.3222, -0.4652, -0.4569]]],
       grad_fn=<UnsafeViewBackward0>)

In [None]:
replace_linear_layer_with_W8A16Linear_layer(model_1, W8A16LinearLayer, ["lm_head"])
model_1

neural_network(
  (embd): Embedding(4, 8)
  (linear_1): W8A16LinearLayer()
  (linear_2): W8A16LinearLayer()
  (lm_head): Linear(in_features=4, out_features=6, bias=False)
)

In [None]:
replace_linear_layer_with_W8A16Linear_layer(model_2, W8A16LinearLayer, [])
model_2

neural_network(
  (embd): Embedding(4, 8)
  (linear_1): W8A16LinearLayer()
  (linear_2): W8A16LinearLayer()
  (lm_head): W8A16LinearLayer()
)

### Linear layer replacement with quantization

In [None]:
def replace_linear_layer_with_W8A16Linear_layer_and_quantization(module, target , exclude_list):
  for name, child in module.named_children():
    if isinstance(child, nn.Linear) and not any([x == name for x in exclude_list]):
      old_bias = child.bias
      old_weights = child.weight

      new_module = target(child.in_features, child.out_features, old_bias is not None, child.weight.dtype)
      setattr(module, name, new_module)
      getattr(module, name).quantize(old_weights)

      if old_bias is not None:
        getattr(module, name).bias = old_bias

    else:
      replace_linear_layer_with_W8A16Linear_layer(child, target, exclude_list)


In [None]:
model_3 = neural_network()


In [None]:
replace_linear_layer_with_W8A16Linear_layer_and_quantization(model_3, W8A16LinearLayer, ["lm_head"])

In [None]:
model_3

neural_network(
  (embd): Embedding(4, 8)
  (linear_1): W8A16LinearLayer()
  (linear_2): W8A16LinearLayer()
  (lm_head): Linear(in_features=4, out_features=6, bias=False)
)

In [None]:
for name, child in model_3.named_children():
  print(name, child)
  if isinstance(child, W8A16LinearLayer):
    print(child.int8_weights, child.scales.dtype)
  else:
    print(child.weight)


embd Embedding(4, 8)
Parameter containing:
tensor([[ 1.0006, -0.4327, -0.8884, -0.5999, -1.0869, -1.2392,  0.8722, -2.0542],
        [-0.7918,  1.5889, -1.0925,  0.8230,  0.5623, -0.0047, -0.8779,  1.7216],
        [ 0.1103, -1.2674,  0.5033, -0.3895, -0.6331, -1.0500,  0.9856,  0.2398],
        [ 1.3366,  0.6915, -1.2141, -1.2603,  0.8708,  1.7693,  1.2354, -0.2309]],
       requires_grad=True)
linear_1 W8A16LinearLayer()
tensor([[   5,   23,  -33,  127,  107,  -24,   -1,  -57],
        [  12, -119,   96, -117,   75,  127,   36,   99],
        [  66,  -60,    1, -127,  -18,  -35,   39,   38],
        [ 118,  -60, -117,  -63, -123, -107,  -93, -127],
        [   3,  127,   86,   50,   18, -127,  119,   19],
        [-124,   73,  -79,  101, -127, -106,   32,    8],
        [  65,  127,   10,   40, -111,  -17,  -85,   98],
        [ -24,   90,  123, -126, -124,   54,   55,  127],
        [  98,  -69, -125,   -6,  127,  -29,  -83,  -47],
        [  52,  -80,  127,   12,  -84,   89,  -57, 

In [None]:
for name, child in model_3.named_children():
  print(name, child)


embd Embedding(4, 8)
linear_1 W8A16LinearLayer()
linear_2 W8A16LinearLayer()
lm_head Linear(in_features=4, out_features=6, bias=False)


In [None]:
g = getattr(model_3, name)
print(g.weight)

Parameter containing:
tensor([[ 0.4117,  0.2183, -0.1079,  0.3153],
        [ 0.2494,  0.1546, -0.2000, -0.2972],
        [ 0.4826, -0.4088, -0.3705,  0.3727],
        [ 0.3053, -0.0941,  0.0614, -0.2570],
        [-0.1868, -0.2262, -0.4779,  0.0295],
        [-0.0114, -0.4945,  0.3408,  0.0426]], requires_grad=True)


## Quantize a open source model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

In [None]:
model_id = "Salesforce/codegen-350M-mono"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/999 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/797M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/240 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/1.00k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]



In [None]:
memory_footprint_before_quantization = model.get_memory_footprint()/1e+6
print(f"model size before quantization : {memory_footprint_before_quantization} MB")

model size before quantization : 797.310976 MB


In [None]:
model

CodeGenForCausalLM(
  (transformer): CodeGenModel(
    (wte): Embedding(51200, 1024)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-19): 20 x CodeGenBlock(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): CodeGenAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (qkv_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (mlp): CodeGenMLP(
          (fc_in): Linear(in_features=1024, out_features=4096, bias=True)
          (fc_out): Linear(in_features=4096, out_features=1024, bias=True)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=51200, bias=True)
)

In [None]:
replace_linear_layer_with_W8A16Linear_layer_and_quantization(model,
                                        W8A16LinearLayer, ["lm_head"])

In [None]:
model

CodeGenForCausalLM(
  (transformer): CodeGenModel(
    (wte): Embedding(51200, 1024)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-19): 20 x CodeGenBlock(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): CodeGenAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (qkv_proj): W8A16LinearLayer()
          (out_proj): W8A16LinearLayer()
        )
        (mlp): CodeGenMLP(
          (fc_in): W8A16LinearLayer()
          (fc_out): W8A16LinearLayer()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=51200, bias=True)
)

In [None]:
import numpy as np

In [None]:
memory_footprint_after_quantization = model.get_memory_footprint()/1e+6
print(f"model size after quantization : {np.round(memory_footprint_after_quantization,2)} MB")

model size after quantization : 546.02 MB


In [None]:
print(f"Memory saved : {np.round((memory_footprint_before_quantization - memory_footprint_after_quantization), 2)} MB")

Memory saved : 251.29 MB


## Load quantizer weights from Hugging Face

In [None]:
quantized_model_state_dict = model.state_dict()
quantized_model_state_dict.keys()

odict_keys(['transformer.wte.weight', 'transformer.h.0.ln_1.weight', 'transformer.h.0.ln_1.bias', 'transformer.h.0.attn.qkv_proj.int8_weights', 'transformer.h.0.attn.qkv_proj.scales', 'transformer.h.0.attn.out_proj.int8_weights', 'transformer.h.0.attn.out_proj.scales', 'transformer.h.0.mlp.fc_in.bias', 'transformer.h.0.mlp.fc_in.int8_weights', 'transformer.h.0.mlp.fc_in.scales', 'transformer.h.0.mlp.fc_out.bias', 'transformer.h.0.mlp.fc_out.int8_weights', 'transformer.h.0.mlp.fc_out.scales', 'transformer.h.1.ln_1.weight', 'transformer.h.1.ln_1.bias', 'transformer.h.1.attn.qkv_proj.int8_weights', 'transformer.h.1.attn.qkv_proj.scales', 'transformer.h.1.attn.out_proj.int8_weights', 'transformer.h.1.attn.out_proj.scales', 'transformer.h.1.mlp.fc_in.bias', 'transformer.h.1.mlp.fc_in.int8_weights', 'transformer.h.1.mlp.fc_in.scales', 'transformer.h.1.mlp.fc_out.bias', 'transformer.h.1.mlp.fc_out.int8_weights', 'transformer.h.1.mlp.fc_out.scales', 'transformer.h.2.ln_1.weight', 'transformer.

In [None]:
torch.save(quantized_model_state_dict, "quantized_model_state_dict.pt")

## Weight packing

In [2]:
import torch

In [3]:
packed_tensor = torch.randint(0,4,(12,), dtype=torch.uint8)
packed_tensor

tensor([3, 2, 3, 0, 1, 0, 1, 2, 1, 1, 2, 3], dtype=torch.uint8)

In [33]:
def weight_pack(int8weight_tensor , bits):
  assert (int8weight_tensor.shape[0] * bits) % 8 == 0

  no_of_packing = int8weight_tensor.shape[0] * bits // 8
  shifts = 8 // bits
  index = 0
  weight_packed_tensor = torch.zeros((no_of_packing), dtype=torch.uint8)

  for i in range(no_of_packing):
    for j in range(shifts):
      weight_packed_tensor[i] |= int8weight_tensor[index] << (bits * j)
      print(f"weight_packed_tensor[{i}] : {weight_packed_tensor[i] }   int8weight_tensor[{index}] : {int8weight_tensor[index]} index : {index} (bits * j) : ({bits * j})")
      index += 1

  return weight_packed_tensor

In [34]:
weight_pack(packed_tensor, 2)

weight_packed_tensor[0] : 3   int8weight_tensor[0] : 3 index : 0 (bits * j) : (0)
weight_packed_tensor[0] : 11   int8weight_tensor[1] : 2 index : 1 (bits * j) : (2)
weight_packed_tensor[0] : 59   int8weight_tensor[2] : 3 index : 2 (bits * j) : (4)
weight_packed_tensor[0] : 59   int8weight_tensor[3] : 0 index : 3 (bits * j) : (6)
weight_packed_tensor[1] : 1   int8weight_tensor[4] : 1 index : 4 (bits * j) : (0)
weight_packed_tensor[1] : 1   int8weight_tensor[5] : 0 index : 5 (bits * j) : (2)
weight_packed_tensor[1] : 17   int8weight_tensor[6] : 1 index : 6 (bits * j) : (4)
weight_packed_tensor[1] : 145   int8weight_tensor[7] : 2 index : 7 (bits * j) : (6)
weight_packed_tensor[2] : 1   int8weight_tensor[8] : 1 index : 8 (bits * j) : (0)
weight_packed_tensor[2] : 5   int8weight_tensor[9] : 1 index : 9 (bits * j) : (2)
weight_packed_tensor[2] : 37   int8weight_tensor[10] : 2 index : 10 (bits * j) : (4)
weight_packed_tensor[2] : 229   int8weight_tensor[11] : 3 index : 11 (bits * j) : (6)


tensor([ 59, 145, 229], dtype=torch.uint8)

In [35]:
unpacked_tensor = torch.tensor([1, 0, 3, 2], dtype=torch.uint8)

In [36]:
weight_pack(unpacked_tensor, 2)

weight_packed_tensor[0] : 1   int8weight_tensor[0] : 1 index : 0 (bits * j) : (0)
weight_packed_tensor[0] : 1   int8weight_tensor[1] : 0 index : 1 (bits * j) : (2)
weight_packed_tensor[0] : 49   int8weight_tensor[2] : 3 index : 2 (bits * j) : (4)
weight_packed_tensor[0] : 177   int8weight_tensor[3] : 2 index : 3 (bits * j) : (6)


tensor([177], dtype=torch.uint8)

In [27]:
a = torch.zeros((2) , dtype = torch.uint8)
a

tensor([0, 0], dtype=torch.uint8)

In [20]:
b = torch.tensor([1, 2], dtype=torch.uint8)

In [29]:
a |= b
a

tensor([1, 2], dtype=torch.uint8)

In [30]:
a

tensor([1, 2], dtype=torch.uint8)

## Weight unpacking

In [37]:
import torch

In [41]:
def weight_unpack(int8weight_tensor , bits):
  unpacked_values = int8weight_tensor.shape[0] * 8 // bits

  shifts = 8 // bits
  index = 0
  weight_unpacked_tensor = torch.zeros((unpacked_values), dtype=torch.uint8)

  mask = 2**bits-1

  for i in range(int8weight_tensor.shape[0]):
    for j in range(shifts):
      weight_unpacked_tensor[index] |= int8weight_tensor[i] >> (bits * j)
      print(f"weight_unpacked_tensor[{index}] : {weight_unpacked_tensor[index]}   int8weight_tensor[{i}] : {int8weight_tensor[i]} index : {index} (bits * j) : ({bits * j})")
      index += 1

  weight_unpacked_tensor &= mask

  return weight_unpacked_tensor



In [42]:
weight_unpack(torch.tensor([ 59, 145, 229], dtype=torch.uint8), 2)

weight_unpacked_tensor[0] : 59   int8weight_tensor[0] : 59 index : 0 (bits * j) : (0)
weight_unpacked_tensor[1] : 14   int8weight_tensor[0] : 59 index : 1 (bits * j) : (2)
weight_unpacked_tensor[2] : 3   int8weight_tensor[0] : 59 index : 2 (bits * j) : (4)
weight_unpacked_tensor[3] : 0   int8weight_tensor[0] : 59 index : 3 (bits * j) : (6)
weight_unpacked_tensor[4] : 145   int8weight_tensor[1] : 145 index : 4 (bits * j) : (0)
weight_unpacked_tensor[5] : 36   int8weight_tensor[1] : 145 index : 5 (bits * j) : (2)
weight_unpacked_tensor[6] : 9   int8weight_tensor[1] : 145 index : 6 (bits * j) : (4)
weight_unpacked_tensor[7] : 2   int8weight_tensor[1] : 145 index : 7 (bits * j) : (6)
weight_unpacked_tensor[8] : 229   int8weight_tensor[2] : 229 index : 8 (bits * j) : (0)
weight_unpacked_tensor[9] : 57   int8weight_tensor[2] : 229 index : 9 (bits * j) : (2)
weight_unpacked_tensor[10] : 14   int8weight_tensor[2] : 229 index : 10 (bits * j) : (4)
weight_unpacked_tensor[11] : 3   int8weight_ten

tensor([3, 2, 3, 0, 1, 0, 1, 2, 1, 1, 2, 3], dtype=torch.uint8)

In [43]:
weight_unpack(torch.tensor([ 177], dtype=torch.uint8), 2)

weight_unpacked_tensor[0] : 177   int8weight_tensor[0] : 177 index : 0 (bits * j) : (0)
weight_unpacked_tensor[1] : 44   int8weight_tensor[0] : 177 index : 1 (bits * j) : (2)
weight_unpacked_tensor[2] : 11   int8weight_tensor[0] : 177 index : 2 (bits * j) : (4)
weight_unpacked_tensor[3] : 2   int8weight_tensor[0] : 177 index : 3 (bits * j) : (6)


tensor([1, 0, 3, 2], dtype=torch.uint8)

In [9]:
import torch
class WeightPack:
  def __init__(self):
    pass

  def weight_pack(self, int8weight_tensor , bits):
    assert (int8weight_tensor.shape[0] * bits) % 8 == 0

    no_of_packing = int8weight_tensor.shape[0] * bits // 8
    shifts = 8 // bits
    index = 0
    weight_packed_tensor = torch.zeros((no_of_packing), dtype=torch.uint8)

    for i in range(no_of_packing):
      for j in range(shifts):
        weight_packed_tensor[i] |= int8weight_tensor[index] << (bits * j)
        print(f"weight_packed_tensor[{i}] : {weight_packed_tensor[i] }   int8weight_tensor[{index}] : {int8weight_tensor[index]} index : {index} (bits * j) : ({bits * j})")
        index += 1

    return weight_packed_tensor

  def weight_unpack(self, int8weight_tensor , bits):
    unpacked_values = int8weight_tensor.shape[0] * 8 // bits

    shifts = 8 // bits
    index = 0
    weight_unpacked_tensor = torch.zeros((unpacked_values), dtype=torch.uint8)

    mask = 2**bits-1

    for i in range(int8weight_tensor.shape[0]):
      for j in range(shifts):
        weight_unpacked_tensor[index] |= int8weight_tensor[i] >> (bits * j)
        print(f"weight_unpacked_tensor[{index}] : {weight_unpacked_tensor[index]}   int8weight_tensor[{i}] : {int8weight_tensor[i]} index : {index} (bits * j) : ({bits * j})")
        index += 1
    print(f"Before masking : {weight_unpacked_tensor}")

    weight_unpacked_tensor &= mask

    print(f"After masking : {weight_unpacked_tensor}")

    return weight_unpacked_tensor



In [10]:
unpacked_tensor = torch.tensor([3, 2, 3, 0, 1, 0, 1, 2, 1, 1, 2, 3], dtype=torch.uint8)

In [11]:
instance = WeightPack()
instance.weight_pack(unpacked_tensor, 2)

weight_packed_tensor[0] : 3   int8weight_tensor[0] : 3 index : 0 (bits * j) : (0)
weight_packed_tensor[0] : 11   int8weight_tensor[1] : 2 index : 1 (bits * j) : (2)
weight_packed_tensor[0] : 59   int8weight_tensor[2] : 3 index : 2 (bits * j) : (4)
weight_packed_tensor[0] : 59   int8weight_tensor[3] : 0 index : 3 (bits * j) : (6)
weight_packed_tensor[1] : 1   int8weight_tensor[4] : 1 index : 4 (bits * j) : (0)
weight_packed_tensor[1] : 1   int8weight_tensor[5] : 0 index : 5 (bits * j) : (2)
weight_packed_tensor[1] : 17   int8weight_tensor[6] : 1 index : 6 (bits * j) : (4)
weight_packed_tensor[1] : 145   int8weight_tensor[7] : 2 index : 7 (bits * j) : (6)
weight_packed_tensor[2] : 1   int8weight_tensor[8] : 1 index : 8 (bits * j) : (0)
weight_packed_tensor[2] : 5   int8weight_tensor[9] : 1 index : 9 (bits * j) : (2)
weight_packed_tensor[2] : 37   int8weight_tensor[10] : 2 index : 10 (bits * j) : (4)
weight_packed_tensor[2] : 229   int8weight_tensor[11] : 3 index : 11 (bits * j) : (6)


tensor([ 59, 145, 229], dtype=torch.uint8)

In [12]:
packed_tensor = torch.tensor([ 59, 145, 229], dtype=torch.uint8)

In [13]:
instance.weight_unpack(packed_tensor, 2)

weight_unpacked_tensor[0] : 59   int8weight_tensor[0] : 59 index : 0 (bits * j) : (0)
weight_unpacked_tensor[1] : 14   int8weight_tensor[0] : 59 index : 1 (bits * j) : (2)
weight_unpacked_tensor[2] : 3   int8weight_tensor[0] : 59 index : 2 (bits * j) : (4)
weight_unpacked_tensor[3] : 0   int8weight_tensor[0] : 59 index : 3 (bits * j) : (6)
weight_unpacked_tensor[4] : 145   int8weight_tensor[1] : 145 index : 4 (bits * j) : (0)
weight_unpacked_tensor[5] : 36   int8weight_tensor[1] : 145 index : 5 (bits * j) : (2)
weight_unpacked_tensor[6] : 9   int8weight_tensor[1] : 145 index : 6 (bits * j) : (4)
weight_unpacked_tensor[7] : 2   int8weight_tensor[1] : 145 index : 7 (bits * j) : (6)
weight_unpacked_tensor[8] : 229   int8weight_tensor[2] : 229 index : 8 (bits * j) : (0)
weight_unpacked_tensor[9] : 57   int8weight_tensor[2] : 229 index : 9 (bits * j) : (2)
weight_unpacked_tensor[10] : 14   int8weight_tensor[2] : 229 index : 10 (bits * j) : (4)
weight_unpacked_tensor[11] : 3   int8weight_ten

tensor([3, 2, 3, 0, 1, 0, 1, 2, 1, 1, 2, 3], dtype=torch.uint8)