# L4-A - Building Your Own Quantizer: Build a Custom 8-Bit Quantizer

In this lesson, you will learn how to compress any model in 8-bit precision.

### 1.1 - `w8_a16_forward` Function

- Function signature:
```Python
W8A16LinearLayer
                    # 8-bit  # 16-bit         # optional
* w8_a16_forward -> weights, input,   scales, bias=None
                    
```
- Cast the 8-bit `weights` to the same data type as the `input` (the "casted weights"),
- keeping the "casted weights" in the same range as before, [-128, 127].
- Next, compute $$\left(\text{input} \cdot \text{casted weights}^{\top}\right) \times \text{scale} + \text{bias}$$

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [9]:
# Random int8 weights covering the full int8 range [-128, 127]. The upper
# bound of ``torch.randint`` is exclusive, so it must be 128 for 127 to occur.
random_int8 = torch.randint(-128, 128, (32, 16)).to(torch.int8)
# bfloat16 activations plus a per-output scale and bias (32 output features).
random_hidden_state = torch.randn((1, 16), dtype=torch.bfloat16)
random_scale = torch.randn((1, 32), dtype=torch.bfloat16)
bias = torch.randn((1, 32), dtype=torch.bfloat16)


In [10]:
# Multiply the hidden states by the weights cast to the activation dtype
# (no scale or bias applied yet)
matrix_mul = F.linear(
    input=random_hidden_state,
    weight=random_int8.to(dtype=random_hidden_state.dtype),
)
print("Matrix Multiplication between hidden states and random weights: ", matrix_mul)

Matrix Multiplication between hidden states and random weights:  tensor([[-326.0000,   57.5000, -414.0000,  376.0000,   68.5000, -102.5000,
          -52.0000,  468.0000, -232.0000,  179.0000,  116.5000, -149.0000,
         -102.5000,   99.0000, -127.0000,  314.0000,  150.0000, -191.0000,
         -112.5000,   97.5000, -100.5000, -201.0000,   29.0000,  178.0000,
          104.5000,  111.5000, -194.0000,   98.5000,  173.0000,  239.0000,
         -131.0000, -326.0000]], dtype=torch.bfloat16)


In [11]:
# Same product as above, now multiplied by the per-output scale
matrix_scale = F.linear(
    random_hidden_state, random_int8.to(dtype=random_hidden_state.dtype)
).mul(random_scale)
print("Matrix Multiplication with scale: ", matrix_scale)

Matrix Multiplication with scale:  tensor([[-7.1000e+01, -3.4000e+01,  2.5400e+02,  1.2900e+02,  4.1211e-01,
          4.8000e+01, -6.3000e+01, -4.5200e+02,  3.0800e+02, -1.0500e+02,
         -5.8250e+01, -1.5400e+02,  9.7500e+01,  5.9062e+00, -1.7000e+02,
         -2.0400e+02, -8.8125e+00,  3.1800e+02, -1.6000e+02,  1.0650e+02,
          1.6800e+02,  2.5300e+02,  1.6250e+01, -1.6500e+02,  8.8000e+01,
          3.8250e+01,  1.4000e+02,  2.0200e+02, -3.0600e+02, -2.0900e+02,
         -1.4400e+02,  6.0800e+02]], dtype=torch.bfloat16)


In [12]:
# The full computation by hand: (x @ W^T) * scale + bias
matrix_bias = (
    F.linear(random_hidden_state, random_int8.to(dtype=random_hidden_state.dtype))
    .mul(random_scale)
    .add(bias)
)
print("Matrix Multiplication with scale and bias: ", matrix_bias)

Matrix Multiplication with scale and bias:  tensor([[-7.1500e+01, -3.5000e+01,  2.5300e+02,  1.2900e+02, -1.7773e-01,
          4.7750e+01, -6.3250e+01, -4.5200e+02,  3.0800e+02, -1.0500e+02,
         -5.6000e+01, -1.5500e+02,  9.8000e+01,  3.7969e+00, -1.7100e+02,
         -2.0200e+02, -8.7500e+00,  3.1800e+02, -1.5900e+02,  1.0700e+02,
          1.6800e+02,  2.5400e+02,  1.7625e+01, -1.6500e+02,  8.6000e+01,
          3.8000e+01,  1.4100e+02,  2.0200e+02, -3.0600e+02, -2.0900e+02,
         -1.4600e+02,  6.0800e+02]], dtype=torch.bfloat16)


Put all of this together in a forward function, `w8_a16_forward` (8-bit weights, 16-bit activations):

In [13]:
def w8_a16_forward(weight, input, scale, bias=None):
    """Linear forward pass with 8-bit weights and higher-precision activations.

    Computes ``F.linear(input, weight cast to input.dtype) * scale + bias``.
    Casting keeps the weight values in their original integer range; only the
    dtype changes so the matmul runs in the activation dtype.

    Args:
        weight: 8-bit weight matrix laid out as ``(out_features, in_features)``
            (the layout ``F.linear`` expects).
        input: activations whose last dimension is ``in_features``.
        scale: scale(s) broadcastable against the matmul result.
        bias: optional bias broadcastable against the matmul result.

    Returns:
        The scaled (and, if ``bias`` is given, shifted) output tensor.
    """
    casted_weight = weight.to(input.dtype)
    # Matrix multiplication followed by the dequantization scale
    output = F.linear(input, casted_weight) * scale
    if bias is not None:
        # Out-of-place add: an in-place ``+=`` raises when ``bias`` broadcasts
        # to a larger shape than ``output`` (e.g. an unbatched 1-D input with a
        # ``(1, out_features)`` bias).
        output = output + bias
    return output

In [14]:
# Compare the helper's output with and without the optional bias term
print("With bias: \n\n", w8_a16_forward(
    weight=random_int8, input=random_hidden_state, scale=random_scale, bias=bias))
print("Without bias: \n\n", w8_a16_forward(
    weight=random_int8, input=random_hidden_state, scale=random_scale))

With bias: 

 tensor([[-7.1500e+01, -3.5000e+01,  2.5300e+02,  1.2900e+02, -1.7773e-01,
          4.7750e+01, -6.3250e+01, -4.5200e+02,  3.0800e+02, -1.0500e+02,
         -5.6000e+01, -1.5500e+02,  9.8000e+01,  3.7969e+00, -1.7100e+02,
         -2.0200e+02, -8.7500e+00,  3.1800e+02, -1.5900e+02,  1.0700e+02,
          1.6800e+02,  2.5400e+02,  1.7625e+01, -1.6500e+02,  8.6000e+01,
          3.8000e+01,  1.4100e+02,  2.0200e+02, -3.0600e+02, -2.0900e+02,
         -1.4600e+02,  6.0800e+02]], dtype=torch.bfloat16)
Without bias: 

 tensor([[-7.1000e+01, -3.4000e+01,  2.5400e+02,  1.2900e+02,  4.1211e-01,
          4.8000e+01, -6.3000e+01, -4.5200e+02,  3.0800e+02, -1.0500e+02,
         -5.8250e+01, -1.5400e+02,  9.7500e+01,  5.9062e+00, -1.7000e+02,
         -2.0400e+02, -8.8125e+00,  3.1800e+02, -1.6000e+02,  1.0650e+02,
          1.6800e+02,  2.5300e+02,  1.6250e+01, -1.6500e+02,  8.8000e+01,
          3.8250e+01,  1.4000e+02,  2.0200e+02, -3.0600e+02, -2.0900e+02,
         -1.4400e+02,  6.0800e+02]], dtype=torch.bfloat16)

#### Put it all together to a class

Note:
- This is the `__init__` signature of the [PyTorch Linear layer](https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear):
```Python
def __init__(self, in_features, out_features, bias=True,
             device=None, dtype=None)
```


The code below fails: `nn.Parameter` sets `requires_grad=True` by default, and PyTorch does not support gradients for int8 tensors, so creating the parameter raises an error.

In [None]:
### running this will result in an error
class W8A16LinearLayer(nn.Module):
    """Naive attempt that stores int8 weights as an ``nn.Parameter`` (fails on purpose)."""

    def __init__(self, in_features, out_features,
                 bias=True, dtype=torch.float32):
        super().__init__()
        # nn.Parameter requests gradients by default, which is what triggers
        # the error for an int8 tensor.
        int8_values = torch.Tensor([0, 1]).to(dtype=torch.int8)
        self.int8_weights = nn.Parameter(int8_values)


try:
    W8A16LinearLayer(1, 1)
except Exception as exc:
    print("\033[91m", type(exc).__name__, ": ", exc, "\033[0m")

#### Correct class implementation
We can avoid the gradient requirement by storing the int8 weights as a buffer (via `register_buffer`),
since we are not training and only run inference.

In [22]:
class W8A16LinearLayer(nn.Module):
    """Linear layer with int8 weights and per-output-row scales (inference only).

    Weights, scales and bias are stored as buffers rather than parameters:
    int8 tensors cannot require gradients, and this layer is not trained.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if true, register a random ``(1, out_features)`` bias buffer;
            otherwise ``self.bias`` is ``None``.
        dtype: dtype of the initial ``scale`` and ``bias`` buffers.
    """

    def __init__(self, in_features, out_features,
                 bias=True, dtype=torch.float32):
        super().__init__()

        # ``torch.randint``'s upper bound is exclusive: 128 makes 127 reachable.
        self.register_buffer(
            "int8_weights",
            torch.randint(-128, 128, (out_features, in_features)).to(torch.int8)
        )

        self.register_buffer(
            "scale",
            torch.randn((out_features,), dtype=dtype)
        )
        if bias:
            self.register_buffer(
                "bias",
                torch.randn((1, out_features), dtype=dtype)
            )
        else:
            self.bias = None

    def quantize(self, weights):
        """Symmetrically quantize ``weights`` to int8 with one scale per row.

        The scale of each row is ``max(|row|) / 127``, stored in
        ``weights.dtype``. Replaces ``self.int8_weights`` and ``self.scale``
        in place; returns ``None``.

        Args:
            weights: 2-D floating-point weight matrix, one scale per row
                (last dimension reduced).
        """
        w_fp32 = weights.to(torch.float32)
        # Absolute max of each row, mapped onto the int8 value 127
        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(weights.dtype)
        # All-zero rows (or scales that underflow in a low-precision dtype)
        # would divide by zero; any non-zero scale maps such rows to 0.
        scales = scales.masked_fill(scales == 0, 1.0)

        # Divide in float32 against the stored scale so the rounding is not
        # done in bf16/fp16, and clamp so a value rounding to 128 cannot wrap
        # around when cast to int8.
        quantized = torch.round(w_fp32 / scales.to(torch.float32).unsqueeze(1))
        self.int8_weights = quantized.clamp(-128, 127).to(torch.int8)

        self.scale = scales

    def forward(self, input):
        """Apply ``w8_a16_forward`` with this layer's buffers."""
        return w8_a16_forward(self.int8_weights, input, self.scale, self.bias)

In [21]:
# Smoke-test a freshly constructed 16 -> 32 layer on a (1, 6, 16) batch
module = W8A16LinearLayer(in_features=16, out_features=32)
dummy_hidden_state = torch.randn(1, 6, 16)
print(module.int8_weights.shape)  # (out_features, in_features)
print(module.scale.shape)         # one scale per output feature
print(module(dummy_hidden_state).dtype)


torch.Size([32, 16])
torch.Size([32])
torch.float32


In [24]:
# A small 4 -> 8 layer whose randomly initialised int8 weights we can inspect
module = W8A16LinearLayer(in_features=4, out_features=8)
print("Weights before: ", module.int8_weights)

Weights before:  tensor([[  51,   55,   35,  -89],
        [  94,    3,   47,   27],
        [  37,   44,   73,  117],
        [   8,  -88,  -36,  -42],
        [-111,  -25,   68, -107],
        [ -30,   -2,   -3,   48],
        [ -46,   28,    2, -102],
        [ 108,  112, -112, -115]], dtype=torch.int8)


In [27]:
# Pass random float weights in as the "original" weights and quantize them;
# the resulting int8 weights lie between -128 and 127.
# ``quantize`` stores its results on the module and returns None, so there
# is no return value to keep.
# NOTE(review): this (4, 8) matrix does not match the module's
# (out_features, in_features) = (8, 4) layout; after quantizing, the module
# maps 8 -> 4 features and its (1, 8) bias no longer broadcasts in forward().
random_matrix = torch.randn((4, 8))
module.quantize(random_matrix)
print("Weights after: ", module.int8_weights)

Weights after:  tensor([[  71,   21,  127,  120,  -75,   -4, -102,   55],
        [ -23,   -2,  -54,   25,  -11,   30, -127,    5],
        [ -59,  -77,   14,   -6,  -13, -127,   59,   28],
        [  -4,  -22,  127,  -31,    1,   41,  -23, -125]], dtype=torch.int8)


In [30]:
# Inspect the per-row scales that quantize() stored on the module
print("Scales: ", module.scale)
print("Scale shape: ", module.scale.size())

Scales:  tensor([0.0109, 0.0199, 0.0205, 0.0117])
Scale shape:  torch.Size([4])


In [31]:
# To dequantize, multiply the int8 (quantized) weights by their per-row scale.
# The scale is 1-D, so give it a trailing axis to broadcast across each row.
scale = module.scale[:, None]
print("Scale shape after unsqueeze: ", scale.shape)
print("Weights shape: ", random_matrix.shape)


Scale shape after unsqueeze:  torch.Size([4, 1])
Weights shape:  torch.Size([4, 8])


In [35]:
# With the trailing axis the shapes broadcast: dequantize and compare the
# result against the original float matrix
q_weights = torch.mul(module.int8_weights, scale)
print("Weights * scale: ", q_weights)
print("Comparing with random matrix: ", random_matrix)

Weights * scale:  tensor([[ 0.7711,  0.2281,  1.3792,  1.3032, -0.8145, -0.0434, -1.1077,  0.5973],
        [-0.4581, -0.0398, -1.0754,  0.4979, -0.2191,  0.5975, -2.5292,  0.0996],
        [-1.2076, -1.5761,  0.2866, -0.1228, -0.2661, -2.5995,  1.2076,  0.5731],
        [-0.0469, -0.2578,  1.4880, -0.3632,  0.0117,  0.4804, -0.2695, -1.4646]])
Comparing with random matrix:  tensor([[ 0.7718,  0.2334,  1.3792,  1.3077, -0.8145, -0.0439, -1.1104,  0.5937],
        [-0.4559, -0.0325, -1.0828,  0.5024, -0.2117,  0.5930, -2.5292,  0.0959],
        [-1.2057, -1.5770,  0.2922, -0.1297, -0.2734, -2.5995,  1.2076,  0.5824],
        [-0.0484, -0.2610,  1.4880, -0.3655,  0.0063,  0.4819, -0.2655, -1.4615]])


In [39]:
# we can compute quantization error
def quantization_error(original, quantized):
    """Return the element-wise error and its mean absolute value.

    Args:
        original: reference tensor (e.g. the unquantized weights).
        quantized: tensor to compare against (e.g. the dequantized weights).

    Returns:
        A tuple ``(original - quantized, mean(|original - quantized|))``.
    """
    delta = original - quantized
    return delta, torch.abs(delta).mean()

In [40]:
# Measure how far the dequantized weights drift from the originals
error_matrix, mean_error = quantization_error(original=random_matrix, quantized=q_weights)
print("Quantization error: ", error_matrix)
print("Mean error: ", mean_error)

Quantization error:  tensor([[ 7.7575e-04,  5.2895e-03,  0.0000e+00,  4.4931e-03,  3.6538e-05,
         -4.6511e-04, -2.6394e-03, -3.6582e-03],
        [ 2.1196e-03,  7.3543e-03, -7.4043e-03,  4.5610e-03,  7.3683e-03,
         -4.4237e-03,  0.0000e+00, -3.6940e-03],
        [ 1.9451e-03, -9.8431e-04,  5.6031e-03, -6.9324e-03, -7.2791e-03,
          0.0000e+00,  8.7023e-06,  9.2616e-03],
        [-1.5544e-03, -3.1916e-03,  0.0000e+00, -2.2379e-03, -5.4338e-03,
          1.5437e-03,  3.9859e-03,  3.1277e-03]])
Mean error:  tensor(0.0034)
