In [3]:
import torch 

## weight quantization ## 

weights are of the shape [out_features, in_features] 

out_features -> number of neurons in hidden layer
in_features -> number of neurons in previous layer

output is calculated as xW.T + b


so for per channel quantization, quantization is performed separately for each output neuron



In [4]:
a = torch.randn(4,7)
w = a
w

tensor([[ 0.3971, -0.9492, -0.2258,  2.1995,  1.5506,  0.4588,  0.5867],
        [ 0.6155,  0.5954,  0.7435, -1.1585,  0.1466, -0.2693,  0.0341],
        [ 0.4187,  0.9666,  0.7911, -0.5947,  1.1770, -1.1629, -0.2461],
        [-1.4873,  0.7512, -0.6221, -0.3574,  0.5999, -1.4290,  0.2632]])

In [43]:
##################################################################

def quantize_weight_per_tensor_absmax(w, n_bits=8, inplace=True, symmetric=True):
    # w: (out_features, in_features)

    if symmetric:
        s = w.abs().max()
        # print (f"scale: {scales}")
        q_max = 2 ** (n_bits - 1) - 1
        # print (f"q_max: {q_max}")
        if inplace: 
            s.clamp_(min=1e-5).div_(q_max)
            w.div_(s).round_().mul_(s)
            return w
        else: 
            s = s.clamp(1e-5) / q_max 
            w_q = (w / s).round() * s 
            l2_dist = torch.norm(w - w_q, p=2)
            return w_q, l2_dist 
    
    else: 
        q_min = 0 
        q_max = 2 ** n_bits - 1 
        w_min = w.min() 
        w_max = w.max() 

        s = (w_max - w_min).clamp(1e-5) / (q_max - q_min) 
        zp = q_min - (w_min / s).round() # round here because zp should be int because this is where the 0 of input weight tensor is stored in quantization grid
        zp = zp.clamp(q_min, q_max) 
        w_q = torch.round(w / s + zp).clamp(q_min, q_max) 
        w_deq = (w_q - zp) * s 
        l2_dist = torch.norm(w - w_deq, p=2)
        return w_deq, l2_dist

##################################################################

def quantize_weight_per_channel_absmax(w, n_bits=8, inplace=True, symmetric=True):
    # w: (out_features, in_features)
    if symmetric: 
        s = w.abs().max(dim=-1, keepdim=True)[0]
        # print (f"scales: {scales}")
        q_max = 2 ** (n_bits - 1) - 1
        # print (f"q_max: {q_max}")
        if inplace:
            s.clamp_(min=1e-5).div_(q_max)
            w.div_(s).round_().mul_(s)
            return w
        else: 
            s = s.clamp(1e-5) / q_max 
            w_q = (w / s).round() * s 
            l2_dist = torch.norm(w - w_q, p=2) 
            return w_q, l2_dist
    else: 
        q_min = 0
        q_max = 2 ** n_bits - 1 
        w_min = w.min(dim=-1, keepdim=True)[0] 
        print(f"w_min: {w_min}")
        w_max = w.max(dim=-1, keepdim=True)[0] 
        print(f"w_max: {w_max}")

        s = (w_max - w_min).clamp(1e-5) / (q_max - q_min) 
        print(f"s: {s}")
        zp = q_min - torch.round(w_min / s) 
        print(f"zp: {zp}")

        w_q = torch.round(w / s) + zp 
        w_deq = (w_q - zp) * s 
        l2_dist = torch.norm(w - w_deq, p=2) 
        return w_deq, l2_dist 


##################################################################

def quantize_weight_per_column_absmax(w, n_bits=8, inplace=True, symmetric=True):
    # w: (out_features, in_features)
    if symmetric:
        s = w.abs().max(dim=-2, keepdim=True)[0]
        # print (f"scales: {scales}")
        q_max = 2 ** (n_bits - 1) - 1
        # print (f"q_max: {q_max}")
        if inplace: 
            s.clamp_(min=1e-5).div_(q_max)
            w.div_(scales).round_().mul_(scales)
            return w
        else: 
            s = s.clamp(1e-5) / q_max 
            w_q = (w / s).round() * s 
            l2_dist = torch.norm(w - w_q, p=2)
            return w_q, l2_dist 
    else: 
        q_min = 0 
        q_max = 2 ** n_bits - 1 
        w_min = w.min(dim=-2, keepdim=True)[0] 
        w_max = w.max(dim=-2, keepdim=True)[0] 

        s = (w_max - w_min).clamp(1e-5) / (q_max - q_min) 
        zp = q_min - torch.round(w_min / s) 
        zp = zp.clamp(q_min, q_max)

        w_q = torch.round(w / s) + zp 
        w_deq = (w_q - zp) * s 

        l2_dist = torch.norm(w - w_deq, p=2) 
        return w_deq, l2_dist 

    ##################################################################

In [46]:
w = torch.rand(4,17)
w

tensor([[0.2284, 0.7886, 0.5808, 0.2179, 0.0542, 0.3543, 0.2932, 0.6134, 0.0462,
         0.2177, 0.7858, 0.1044, 0.1523, 0.4921, 0.3982, 0.0437, 0.1084],
        [0.5337, 0.2013, 0.9084, 0.8421, 0.6303, 0.1473, 0.1696, 0.6418, 0.2077,
         0.7126, 0.3547, 0.2940, 0.7034, 0.7461, 0.6865, 0.9587, 0.3734],
        [0.2848, 0.4662, 0.7523, 0.9377, 0.3219, 0.9402, 0.0697, 0.3023, 0.2270,
         0.4377, 0.8794, 0.0976, 0.1356, 0.3573, 0.7558, 0.6483, 0.8741],
        [0.6451, 0.0496, 0.8922, 0.9471, 0.0070, 0.0725, 0.4155, 0.4269, 0.0997,
         0.8235, 0.2724, 0.7945, 0.6350, 0.6971, 0.7308, 0.8927, 0.2584]])

In [47]:
wq, l2 = quantize_weight_per_tensor_absmax(w, inplace=False, symmetric=True) 
wq, l2 

(tensor([[0.2265, 0.7851, 0.5813, 0.2189, 0.0528, 0.3548, 0.2944, 0.6115, 0.0453,
          0.2189, 0.7851, 0.1057, 0.1510, 0.4907, 0.4001, 0.0453, 0.1057],
         [0.5360, 0.2038, 0.9059, 0.8455, 0.6266, 0.1510, 0.1661, 0.6417, 0.2114,
          0.7096, 0.3548, 0.2944, 0.7021, 0.7474, 0.6870, 0.9587, 0.3699],
         [0.2869, 0.4680, 0.7549, 0.9361, 0.3246, 0.9436, 0.0679, 0.3020, 0.2265,
          0.4378, 0.8757, 0.0981, 0.1359, 0.3548, 0.7549, 0.6492, 0.8757],
         [0.6417, 0.0528, 0.8908, 0.9436, 0.0075, 0.0755, 0.4152, 0.4303, 0.0981,
          0.8228, 0.2718, 0.7927, 0.6341, 0.6945, 0.7323, 0.8908, 0.2567]]),
 tensor(0.0172))

In [48]:
wq, l2 = quantize_weight_per_tensor_absmax(w, inplace=False, symmetric=False) 
wq, l2

(tensor([[0.2277, 0.7875, 0.5822, 0.2165, 0.0560, 0.3546, 0.2949, 0.6121, 0.0448,
          0.2165, 0.7875, 0.1045, 0.1530, 0.4927, 0.3994, 0.0448, 0.1082],
         [0.5337, 0.2015, 0.9070, 0.8435, 0.6308, 0.1456, 0.1680, 0.6420, 0.2090,
          0.7129, 0.3546, 0.2949, 0.7017, 0.7465, 0.6868, 0.9518, 0.3732],
         [0.2837, 0.4665, 0.7539, 0.9368, 0.3210, 0.9406, 0.0709, 0.3023, 0.2277,
          0.4367, 0.8808, 0.0970, 0.1344, 0.3583, 0.7539, 0.6494, 0.8734],
         [0.6457, 0.0485, 0.8920, 0.9480, 0.0075, 0.0709, 0.4143, 0.4255, 0.1008,
          0.8249, 0.2725, 0.7950, 0.6345, 0.6980, 0.7315, 0.8920, 0.2575]]),
 tensor(0.0110))

In [50]:
w = torch.randn(4,17)
w

tensor([[-0.2234, -0.8111,  0.2828, -1.6396,  0.8526,  1.2311, -0.2821,  0.4726,
          0.9382, -0.3193,  1.6686, -0.5200,  0.9954, -1.1602, -0.6560, -0.4317,
          0.1805],
        [-1.0949,  0.4200, -0.8530, -0.3955,  1.0330,  0.4917, -0.7618, -1.4418,
         -0.4005,  1.3088,  1.6871,  1.2046,  0.3381,  0.5463, -0.1350, -0.1745,
          1.4844],
        [-0.8913, -0.5092, -0.4106, -0.2644,  2.3901, -0.5111, -2.0488,  0.3269,
          1.0665, -0.2785,  1.4403,  0.3653, -0.0092,  0.2465, -2.0544, -0.7793,
         -0.4758],
        [-0.0751,  0.1638, -1.4036,  0.5129,  1.1954,  0.0223,  1.1068, -1.9306,
         -1.3471, -1.4262,  0.3879,  0.1126, -0.4289, -0.1028, -0.3864,  0.9276,
          0.0493]])

In [51]:
wq, l2 = quantize_weight_per_channel_absmax(w, inplace=False, symmetric=True) 
wq, l2


(tensor([[-0.2234, -0.8146,  0.2891, -1.6423,  0.8540,  1.2350, -0.2759,  0.4730,
           0.9328, -0.3153,  1.6686, -0.5255,  0.9985, -1.1562, -0.6569, -0.4336,
           0.1839],
         [-1.0893,  0.4251, -0.8502, -0.3985,  1.0362,  0.4915, -0.7572, -1.4480,
          -0.3985,  1.3151,  1.6871,  1.2088,  0.3321,  0.5446, -0.1328, -0.1727,
           1.4878],
         [-0.8845, -0.5081, -0.4140, -0.2635,  2.3901, -0.5081, -2.0514,  0.3199,
           1.0727, -0.2823,  1.4491,  0.3576, -0.0000,  0.2447, -2.0514, -0.7716,
          -0.4705],
         [-0.0760,  0.1672, -1.3986,  0.5169,  1.2009,  0.0152,  1.1097, -1.9306,
          -1.3530, -1.4290,  0.3952,  0.1064, -0.4256, -0.1064, -0.3800,  0.9273,
           0.0456]]),
 tensor(0.0367))

In [52]:
wq, l2 = quantize_weight_per_channel_absmax(w, inplace=False, symmetric=False) 
wq, l2

w_min: tensor([[-1.6396],
        [-1.4418],
        [-2.0544],
        [-1.9306]])
w_max: tensor([[1.6686],
        [1.6871],
        [2.3901],
        [1.1954]])
s: tensor([[0.0130],
        [0.0123],
        [0.0174],
        [0.0123]])
zp: tensor([[126.],
        [118.],
        [118.],
        [157.]])


(tensor([[-0.2206, -0.8173,  0.2854, -1.6347,  0.8563,  1.2325, -0.2854,  0.4670,
           0.9341, -0.3243,  1.6736, -0.5189,  0.9990, -1.1546, -0.6617, -0.4281,
           0.1816],
         [-1.0920,  0.4172, -0.8589, -0.3926,  1.0307,  0.4908, -0.7607, -1.4479,
          -0.4049,  1.3129,  1.6810,  1.2025,  0.3436,  0.5521, -0.1350, -0.1718,
           1.4847],
         [-0.8889, -0.5055, -0.4183, -0.2614,  2.3879, -0.5055, -2.0567,  0.3312,
           1.0632, -0.2789,  1.4467,  0.3660, -0.0174,  0.2440, -2.0567, -0.7843,
          -0.4706],
         [-0.0736,  0.1594, -1.3975,  0.5149,  1.2014,  0.0245,  1.1033, -1.9247,
          -1.3485, -1.4220,  0.3923,  0.1103, -0.4291, -0.0981, -0.3923,  0.9317,
           0.0490]]),
 tensor(0.0347))

In [53]:
w = torch.randn(4,17)
w

tensor([[-0.7371,  0.5871,  2.3430, -0.9472,  1.5481, -1.8203,  0.0453,  1.2388,
         -0.9581, -0.3545, -1.1569,  0.4803,  0.6892, -0.0628, -1.7181, -1.3986,
         -0.2100],
        [ 0.0687,  1.4421, -1.1840, -0.1821,  0.3414,  0.2949, -0.4815, -0.0241,
         -1.5234, -0.2638,  0.1884, -0.6919,  0.4791, -0.3447,  0.7389, -1.4151,
          0.2188],
        [ 0.3938,  0.9661, -0.0706,  0.1853, -0.1085,  0.9662, -1.3678, -0.2627,
         -1.0061,  0.4234, -1.3399,  0.0443, -0.1558, -1.3252,  0.7670, -1.4057,
         -2.2499],
        [-1.0071, -0.9806, -1.2983, -0.5068, -0.0988,  0.5853, -2.0780,  0.9685,
         -0.8785, -0.4958,  2.4415, -0.4127, -0.5302,  0.0950,  0.6870, -0.3067,
         -0.1257]])

In [54]:
wq, l2 = quantize_weight_per_column_absmax(w, inplace=False, symmetric=True) 
wq, l2

(tensor([[-0.7375,  0.5905,  2.3430, -0.9472,  1.5481, -1.8203,  0.0491,  1.2388,
          -0.9596, -0.3553, -1.1535,  0.4794,  0.6892, -0.0626, -1.7181, -1.4040,
          -0.2126],
         [ 0.0714,  1.4421, -1.1807, -0.1790,  0.3413,  0.3010, -0.4745, -0.0195,
          -1.5234, -0.2655,  0.1922, -0.6919,  0.4775, -0.3443,  0.7441, -1.4151,
           0.2126],
         [ 0.3965,  0.9652, -0.0738,  0.1865, -0.1097,  0.9603, -1.3744, -0.2634,
          -1.0076,  0.4216, -1.3457,  0.0436, -0.1574, -1.3252,  0.7711, -1.4040,
          -2.2499],
         [-1.0071, -0.9765, -1.2914, -0.5072, -0.0975,  0.5876, -2.0780,  0.9657,
          -0.8756, -0.4958,  2.4415, -0.4140, -0.5318,  0.0939,  0.6899, -0.3120,
          -0.1240]]),
 tensor(0.0243))

In [55]:
wq, l2 = quantize_weight_per_column_absmax(w, inplace=False, symmetric=False) 
wq, l2

(tensor([[-0.7362,  0.5890,  2.3419, -0.9460,  1.5461, -1.8249,  0.0416,  1.2365,
          -0.9584, -0.3533, -1.1567,  0.4781,  0.6886, -0.0613, -1.7152, -1.3996,
          -0.2130],
         [ 0.0714,  1.4441, -1.1852, -0.1821,  0.3443,  0.2950, -0.4830, -0.0236,
          -1.5223, -0.2631,  0.1928, -0.6941,  0.4782, -0.3453,  0.7407, -1.4170,
           0.2227],
         [ 0.3956,  0.9691, -0.0714,  0.1865, -0.1104,  0.9616, -1.3656, -0.2650,
          -1.0065,  0.4218, -1.3346,  0.0460, -0.1578, -1.3255,  0.7699, -1.4040,
          -2.2460],
         [-1.0054, -0.9786, -1.2994, -0.5063, -0.0974,  0.5901, -2.0817,  0.9657,
          -0.8775, -0.4975,  2.4468, -0.4137, -0.5308,  0.0947,  0.6822, -0.3086,
          -0.1259]]),
 tensor(0.0195))

### Activation Quantization ### 

in llms, activation tensor can be 2D [batch, hidden_dim] or 3D [batch, seq_len, hidden_dim].

Here,

hidden_dim -> embedding dimension or channel dimension.

seq_len -> token dimension

For per token quantization, [batch, token, channel], first transform shape to [batch x token, channel].

Now rows are number of tokens, and columns are number of channel or embedding size of each token.

For per token quantization, find scale for each row (dim=-1).

For per channel quantization, find scale for each col (dim=-2). 



In [71]:
############################################################

def quantize_activation_per_tensor_absmax(a, n_bits=8, inplace=True, symmetric=True): 

    if symmetric:
        q_max = 2 ** (n_bits - 1) - 1 
        # print(f"q_max: {q_max}")
        s = a.abs().max() 
        # print(f"scale: {s}")
        if inplace: 
            s.clamp_(1e-5).div_(q_max)
            a.div_(s).round_().mul_(s)
            return a
        else: 
            s = s.clamp(1e-5) / q_max 
            a_q = (a / s).round() * s 
            l2_dist = torch.norm(a - a_q, p=2) 
            return a_q, l2_dist 
    else: 
        q_min = 0 
        q_max = 2 ** n_bits - 1
        a_min = a.min() 
        a_max = a.max() 

        s = (a_max - a_min).clamp(1e-5) / (q_max - q_min) 
        print(f"s: {s}")
        zp = q_min - torch.round(a_min / s) 
        zp = zp.clamp(q_min, q_max) 
        print(f"zp: {zp}")
        a_q = torch.round(a / s) + zp 
        a_deq = (a_q - zp) * s 

        l2_dist = torch.norm(a - a_deq, p=2) 
        return a_deq, l2_dist 

############################################################

def quantize_activation_per_token_absmax(a, n_bits=8, inplace=True, symmetric=True): 

    if symmetric:
        q_max = 2 ** (n_bits - 1) - 1 
        # print(f"q_max: {q_max}")
        s = a.abs().max(dim=-1, keepdim=True)[0] 
        # print(f"scale: {s}")
        if inplace: 
            s.clamp_(1e-5).div_(q_max) 
            a.div_(s).round_().mul_(s) 
            return a
        else: 
            s = s.clamp(1e-5) / q_max
            a_q = (a / s).round() * s 
            l2_dist = torch.norm(a - a_q, p=2) 
            return a_q, l2_dist 
    else: 
        q_min = 0 
        q_max = 2 ** n_bits - 1 
        a_min = a.min(dim=-1, keepdim=True)[0] 
        a_max = a.max(dim=-1, keepdim=True)[0] 

        s = (a_max - a_min).clamp(1e-5) / (q_max - q_min) 
        print(f"s: {s}")
        zp = (q_min - torch.round(a_min / s) ).clamp(q_min, q_max) 
        print(f"zp: {zp}")
        a_q = torch.round(a / s) + zp 
        a_deq = (a_q - zp) * s 

        l2_dist = torch.norm(a - a_deq, p=2) 
        return a_deq, l2_dist

############################################################ 

def quantize_activation_per_channel_absmax(a, n_bits=8, inplace=True, symmetric=True): 

    a_transform = a.view(-1, a.shape[-1]) #[batch, token, channel] -> [batch x token, channel] 

    if symmetric: 
        q_max = 2 ** (n_bits - 1) - 1 
        s = a_transform.abs().max(dim=-2, keepdim=True)[0] # max per column i.e., channels
        if inplace: 
            s.clamp_(1e-5).div_(q_max) 
            a.div_(s).round_().mul_(s) 
            return a 
        else: 
            s = s.clamp(1e-5) / q_max 
            a_q = (a / s).round() * s 
            l2_dist = torch.norm(a - a_q, p=2) 
            return a_q, l2_dist 
    else: 
        q_min = 0 
        q_max = 2 ** n_bits - 1 
        a_min = a_transform.min(dim=-2, keepdim=True)[0] 
        a_max = a_transform.max(dim=-2, keepdim=True)[0] 

        s = (a_max - a_min).clamp(1e-5) / (q_max - q_min) 
        print(f"s: {s}")
        zp = (q_min - torch.round(a_min / s)).clamp(q_min, q_max) 
        print(f"zp: {zp}")
        a_q = torch.round(a / s) + zp 
        a_deq = (a_q - zp) * s 
        l2_dist = torch.norm(a - a_deq, p=2) 
        return a_deq, l2_dist 


 

In [75]:
a = torch.randn((2,3,7))
a

tensor([[[-1.0994e+00, -1.5468e+00,  4.3185e-01, -1.3401e+00, -5.1583e-01,
          -1.8047e-03, -1.3345e+00],
         [-1.0263e+00,  4.1658e-02,  2.1585e-01, -1.5945e+00,  2.6714e-01,
          -3.1076e-01,  4.4648e-02],
         [-2.7162e-01,  5.3547e-02,  1.4429e+00,  1.3999e+00, -1.5108e+00,
           1.4921e+00,  2.0892e-01]],

        [[ 1.4114e-01,  2.5016e-01,  1.1074e+00,  1.2074e+00,  1.6801e+00,
           1.0770e+00,  1.9161e-01],
         [ 2.4013e+00, -3.3593e+00,  1.0788e+00,  1.7354e+00, -4.2466e-01,
           8.6096e-01,  8.8684e-01],
         [ 1.3117e+00,  1.8079e-01,  9.6708e-01, -1.5179e+00, -6.3877e-01,
          -7.0382e-01,  2.5925e-01]]])

In [76]:
aq, l2 = quantize_activation_per_tensor_absmax(a, inplace=False, symmetric=True)
aq, l2


(tensor([[[-1.1109, -1.5342,  0.4232, -1.3490, -0.5290, -0.0000, -1.3225],
          [-1.0316,  0.0529,  0.2116, -1.5871,  0.2645, -0.3174,  0.0529],
          [-0.2645,  0.0529,  1.4548,  1.4019, -1.5077,  1.4813,  0.2116]],
 
         [[ 0.1323,  0.2381,  1.1109,  1.2167,  1.6929,  1.0845,  0.1852],
          [ 2.4070, -3.3593,  1.0845,  1.7458, -0.4232,  0.8729,  0.8993],
          [ 1.3225,  0.1852,  0.9787, -1.5077, -0.6348, -0.7142,  0.2645]]]),
 tensor(0.0550))

In [77]:
aq, l2 = quantize_activation_per_tensor_absmax(a, inplace=False, symmetric=False)
aq, l2

s: 0.022590549662709236
zp: 149.0


(tensor([[[-1.1069, -1.5362,  0.4292, -1.3328, -0.5196,  0.0000, -1.3328],
          [-1.0166,  0.0452,  0.2259, -1.6039,  0.2711, -0.3163,  0.0452],
          [-0.2711,  0.0452,  1.4458,  1.4006, -1.5136,  1.4910,  0.2033]],
 
         [[ 0.1355,  0.2485,  1.1069,  1.1973,  1.6717,  1.0843,  0.1807],
          [ 2.3946, -3.3660,  1.0843,  1.7395, -0.4292,  0.8584,  0.8810],
          [ 1.3103,  0.1807,  0.9714, -1.5136, -0.6325, -0.7003,  0.2485]]]),
 tensor(0.0387))

In [68]:
a = torch.rand((2,3,7))
a

tensor([[[0.6862, 0.9213, 0.2398, 0.2512, 0.2083, 0.5578, 0.0448],
         [0.3540, 0.3802, 0.0582, 0.8646, 0.8948, 0.7454, 0.5220],
         [0.8517, 0.8722, 0.7148, 0.4096, 0.4370, 0.7026, 0.6147]],

        [[0.2200, 0.4913, 0.8703, 0.9941, 0.3274, 0.7329, 0.9646],
         [0.0307, 0.6621, 0.1406, 0.7983, 0.2755, 0.8204, 0.4318],
         [0.3658, 0.2045, 0.2352, 0.3311, 0.2605, 0.5256, 0.4939]]])

In [69]:
aq, l2 = quantize_activation_per_token_absmax(a, inplace=False, symmetric=True)
aq, l2

(tensor([[[0.6892, 0.9213, 0.2394, 0.2539, 0.2104, 0.5586, 0.0435],
          [0.3523, 0.3805, 0.0564, 0.8666, 0.8948, 0.7468, 0.5214],
          [0.8516, 0.8722, 0.7142, 0.4120, 0.4395, 0.7005, 0.6181]],
 
         [[0.2192, 0.4932, 0.8689, 0.9941, 0.3288, 0.7358, 0.9628],
          [0.0323, 0.6589, 0.1421, 0.8010, 0.2778, 0.8204, 0.4328],
          [0.3642, 0.2028, 0.2359, 0.3311, 0.2607, 0.5256, 0.4925]]]),
 tensor(0.0110))

In [70]:
aq, l2 = quantize_activation_per_token_absmax(a, inplace=False, symmetric=False)
aq, l2

s: tensor([[[0.0034],
         [0.0033],
         [0.0018]],

        [[0.0030],
         [0.0031],
         [0.0013]]])
zp: tensor([[[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])


(tensor([[[0.6875, 0.9213, 0.2406, 0.2509, 0.2097, 0.5569, 0.0447],
          [0.3543, 0.3805, 0.0591, 0.8661, 0.8956, 0.7447, 0.5216],
          [0.8526, 0.8726, 0.7147, 0.4100, 0.4372, 0.7020, 0.6150]],
 
         [[0.2186, 0.4918, 0.8712, 0.9927, 0.3279, 0.7316, 0.9653],
          [0.0310, 0.6627, 0.1394, 0.7990, 0.2756, 0.8207, 0.4305],
          [0.3651, 0.2040, 0.2355, 0.3311, 0.2606, 0.5250, 0.4936]]]),
 tensor(0.0050))

In [78]:
a = torch.rand((2,3,7))
a

tensor([[[0.3129, 0.9934, 0.8751, 0.9903, 0.1319, 0.1711, 0.1677],
         [0.6785, 0.0528, 0.2087, 0.2965, 0.2752, 0.9361, 0.8337],
         [0.2526, 0.0846, 0.6035, 0.3533, 0.2393, 0.0378, 0.5802]],

        [[0.9465, 0.2929, 0.7141, 0.4343, 0.9606, 0.2727, 0.2263],
         [0.5849, 0.8796, 0.7057, 0.7981, 0.7380, 0.2425, 0.1302],
         [0.6876, 0.8398, 0.2638, 0.1423, 0.6277, 0.2213, 0.3849]]])

In [79]:
aq, l2 = quantize_activation_per_channel_absmax(a, inplace=False, symmetric=True)
aq, l2

(tensor([[[0.3130, 0.9934, 0.8751, 0.9903, 0.1286, 0.1695, 0.1707],
          [0.6782, 0.0548, 0.2067, 0.2963, 0.2723, 0.9361, 0.8337],
          [0.2534, 0.0860, 0.6064, 0.3509, 0.2420, 0.0369, 0.5776]],
 
         [[0.9465, 0.2894, 0.7166, 0.4367, 0.9606, 0.2727, 0.2232],
          [0.5813, 0.8761, 0.7028, 0.7954, 0.7413, 0.2432, 0.1313],
          [0.6857, 0.8370, 0.2618, 0.1404, 0.6278, 0.2211, 0.3873]]]),
 tensor(0.0135))

In [80]:
aq, l2 = quantize_activation_per_channel_absmax(a, inplace=False, symmetric=False)
aq, l2

s: tensor([[0.0027, 0.0037, 0.0026, 0.0033, 0.0032, 0.0035, 0.0028]])
zp: tensor([[0., 0., 0., 0., 0., 0., 0.]])


(tensor([[[0.3130, 0.9923, 0.8755, 0.9910, 0.1332, 0.1726, 0.1683],
          [0.6776, 0.0516, 0.2091, 0.2960, 0.2762, 0.9370, 0.8331],
          [0.2531, 0.0848, 0.6037, 0.3525, 0.2405, 0.0387, 0.5793]],
 
         [[0.9470, 0.2914, 0.7134, 0.4356, 0.9620, 0.2712, 0.2262],
          [0.5851, 0.8779, 0.7056, 0.7981, 0.7377, 0.2431, 0.1297],
          [0.6885, 0.8411, 0.2639, 0.1430, 0.6272, 0.2219, 0.3862]]]),
 tensor(0.0057))

### W8A8 Class ###

Now for the main class, this is a nn.module subclass which will replace the linear layer of fp model with a custom layer class. 

This class will quantize the weights and activation before matmul. 

Forward pass takes input x and quantizes it then pass it through linear layer with weight which were initialized as random values in buffer but from float method loaded and quantized the weights and stored in buffer.

Every layer is a subclass of nn.Module. A module can contain paramters (weights, biases), buffers (non-trainable states), and other submodules (layers)</br>
e.g., nn.Linear, nn.Conv2d, nn.Sequential all belong to nn.Module

all the layers defined using this subclass are called instances of these modules.

layer1 = nn.Linear(in_features=4, out_features=3)

isinstance(layer1, nn.Module) -> True 

In [2]:
import torch 
import torch.nn as nn 


layer1 = nn.Linear(in_features=4, out_features=3) 

print(isinstance(layer1, nn.Module))
print(isinstance(layer1, nn.Linear)) 

True
True


  from .autonotebook import tqdm as notebook_tqdm


Module class is like a car model. When you buy one, you get an instance (your specific car with specific serial number etc). Similarly with layer1 you are instantiating a specific linear layer with its own parameters.

We can define our own subclass from nn.Module and create an instance of that subclass.

Defined layer will be instance of all parent classes.

In [3]:
class MyLinear(nn.Module): 

    def __init__(self, in_features, out_features): 
        super().__init__() 
        self.weight = nn.Parameter(torch.randn(out_features, in_features)) 
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x): 
        return x @ self.weight.T + self.bias


layer_1 = MyLinear(3, 5)
print(layer_1)

MyLinear()


In [5]:
layer_1.weight, layer_1.bias 

(Parameter containing:
 tensor([[ 0.3800, -1.4295,  0.6623],
         [-0.6595, -0.0987, -0.5230],
         [-1.8856, -0.4804, -0.6995],
         [ 1.1160, -1.4982,  0.2491],
         [ 0.3730, -0.7380, -0.1484]], requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0., 0., 0.], requires_grad=True))

In [7]:
isinstance(layer_1, nn.Module), isinstance(layer_1, MyLinear) 

(True, True)

In OPT's modeling code we see the class class OPTDecoderLayer which has the following init method:

In [None]:
'''
class OPTDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: OPTConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.embed_dim = config.hidden_size

        self.self_attn = OPTAttention(config=config, layer_idx=layer_idx)

        self.do_layer_norm_before = config.do_layer_norm_before
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
        )
        self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias)
        self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
'''



here we see fc1 and fc2 are instances of nn.Linear with some in and out features. Hence we want to take these instances and replace with our custom layers which perform the forward after quantization.

Every nn.Module class that we define can contain other submodules i.e., other layers or structure of model. Like we define a model as nn.Module subclass and within its init method we define other submodules like 1 linear layer another linear layer and 1 conv2d layer.

model.named_modules() gives an iterator of all submodules with their names. 

In [31]:
class MyNet(nn.Module): 
    def __init__(self): 
        super().__init__() 
        self.fc1 = nn.Linear(4, 8) 
        self.relu = nn.ReLU() 
        self.fc2 = nn.Linear(8, 2) 
        self.fc3 = MyLinear(3, 5)
        self.seq = nn.Sequential(
            nn.Linear(2, 4),
            nn.ReLU(),
            nn.Linear(4, 2), 
            MyLinear(2, 9)) 

            
    def forward(self, x): 
        return self.fc2(self.relu(self.fc1(x))) 


net = MyNet() 

In [33]:
for name, module in net.named_modules(): 
    print(f"name: {name}")
    print(f"type(name): {type(name)}")
    print(f"module: {module}")
    print(f"type(module): {type(module)}")
    
    print(isinstance(module, nn.Module))
    print(isinstance(module, nn.Linear))
    print(isinstance(module, nn.ReLU))
    print(isinstance(module, MyLinear))
    print(isinstance(module, nn.Sequential))
    
    
    print("---------------")

name: 
type(name): <class 'str'>
module: MyNet(
  (fc1): Linear(in_features=4, out_features=8, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=8, out_features=2, bias=True)
  (fc3): MyLinear()
  (seq): Sequential(
    (0): Linear(in_features=2, out_features=4, bias=True)
    (1): ReLU()
    (2): Linear(in_features=4, out_features=2, bias=True)
    (3): MyLinear()
  )
)
type(module): <class '__main__.MyNet'>
True
False
False
False
False
---------------
name: fc1
type(name): <class 'str'>
module: Linear(in_features=4, out_features=8, bias=True)
type(module): <class 'torch.nn.modules.linear.Linear'>
True
True
False
False
False
---------------
name: relu
type(name): <class 'str'>
module: ReLU()
type(module): <class 'torch.nn.modules.activation.ReLU'>
True
False
True
False
False
---------------
name: fc2
type(name): <class 'str'>
module: Linear(in_features=8, out_features=2, bias=True)
type(module): <class 'torch.nn.modules.linear.Linear'>
True
True
False
False
False
---------------

So named modules goes through the entire init method and returns everything we initialized. Starts by model itself where there is no name and module contains everything. Then one by one if goes through every layer where name is variable name we assigned and module is from what class it was instantiated.

If it encounters nn.Sequential then after naming seq layer and whole sequential list as its module, it goes through each list element where name is seq_[index] and module is defined nn module.