<a href="https://colab.research.google.com/github/samitha278/CoreLlama/blob/main/norm_methods.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

### Check `torch.mean` against a manual computation

In [None]:
torch.manual_seed(278)  # fixed seed so the random input below is reproducible
x = torch.randn(2, 4, 12)  # toy input of shape (2, 4, 12)
print(x)

tensor([[[-0.3113, -1.6257, -0.4428,  0.7869, -2.3081, -2.4534,  2.5515,
           1.5013, -0.4279,  0.0149,  0.6168,  2.5252],
         [-0.4348,  1.8983, -0.4243,  1.1160,  1.3476, -1.8999,  0.2999,
           2.0132, -0.0537,  0.0273, -0.9289,  2.4260],
         [ 0.5965, -0.9634,  0.6497,  0.3516, -0.5396,  0.5949, -0.8981,
           1.1714, -0.2333,  0.9272,  1.0551,  0.7002],
         [-0.7530,  0.3305, -1.3149, -0.8310, -0.9323, -0.2117,  0.3902,
           0.9124, -0.6891,  0.4506,  3.0519, -0.5101]],

        [[-0.3543,  2.1348,  1.1454,  0.4737, -0.6503,  2.7713, -1.4388,
           1.0588, -1.2221, -0.6614, -2.8858, -0.3586],
         [-0.3820, -1.2520,  0.3543,  1.0647,  0.1902,  0.0061, -0.1596,
          -0.1823,  0.1748, -0.0332,  1.1809,  1.9225],
         [-0.1120,  1.1314, -0.1735, -0.5383, -2.2424, -0.7450, -1.2321,
          -0.8446,  0.6800,  1.0753,  0.1750,  0.8562],
         [ 2.9217, -0.5220,  2.1552,  2.1454, -2.8136, -0.4572,  0.6009,
          -0.8725, -1.

In [None]:
x.mean(dim=-1, keepdim=True)  # mean over the last dim, kept as a size-1 axis for broadcasting

tensor([[[ 0.0356],
         [ 0.4489],
         [ 0.2843],
         [-0.0089]],

        [[ 0.0011],
         [ 0.2404],
         [-0.1642],
         [ 0.3047]]])

In [None]:
l = x[0, 0]  # 1-D slice along the last dimension: first row of the first batch element

In [None]:
# mean by hand: sum of the elements divided by how many there are
mean_1 = l.sum(dim=-1) / len(l)
print(mean_1)

tensor(0.0356)


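### Check `torch.std` (square root of the variance) against a manual computation

By default `torch.std` / `torch.var` divide by $n-1$ (Bessel's correction, i.e. the sample estimate).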

In [None]:
x.std(dim=-1, keepdim=True)  # default correction=1: sample std (divides by n - 1)

tensor([[[1.6676],
         [1.3184],
         [0.7508],
         [1.1773]],

        [[1.5926],
         [0.8255],
         [1.0218],
         [1.7292]]])

In [None]:
var_1 = ((l - mean_1.item()) ** 2).sum(dim=-1) / (len(l) - 1)  # n - 1 -> sample variance (Bessel's correction)
print(var_1 ** 0.5)  # standard deviation = sqrt(variance); should match x.std above

tensor(1.6676)


In [None]:
l_norm = (l - mean_1) / var_1.sqrt()  # subtract the mean, divide by the sample std

l_norm.var(dim=-1)  # ~1.0: the standardized row has unit (sample) variance

tensor(1.0000)

In [None]:
l_temp = 5 * l_norm  # scaling by a constant c multiplies the variance by c**2
l_temp.var(dim=-1)   # ~25

tensor(25.0000)

## Layer Normalization from scratch

$$
\mu = \frac{1}{n}\sum_{i=1}^{n} x_i
$$

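$$
\sigma^2 = \frac{1}{n}\sum_{i=1}^{n} (x_i - \mu)^2
$$

Note: this is the population (biased) variance — divide by $n$, not $n-1$ — which is what `torch.nn.LayerNorm` uses.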

$$
\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}
$$

$$
y = \gamma \odot \hat{x} + \beta
$$



In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, written from scratch.

    Computes ``y = gamma * (x - mean) / sqrt(var + ln_eps) + beta``. The mean
    and variance are taken over the last dimension. The variance is the
    *population* (biased, divide-by-n) variance, as in ``torch.nn.LayerNorm``.

    Args:
        dim: size of the last (feature) dimension; also the length of the
            learnable ``gamma`` / ``beta`` vectors.
        ln_eps: constant added to the variance for numerical stability.
            ``torch.nn.LayerNorm`` defaults to ``eps=1e-5``; pass ``eps=1e-6``
            there for a like-for-like comparison with this module's default.
    """

    def __init__(self, dim, ln_eps=1e-6):
        super().__init__()

        self.dim = dim
        self.ln_eps = ln_eps

        # learnable affine parameters, initialised to the identity transform
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        """Normalize ``x`` over its last dimension, then scale and shift.

        Args:
            x: tensor whose last dimension has size ``dim``
               (e.g. ``[B, T, C]`` with ``C == dim``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        mean = x.mean(dim=-1, keepdim=True)
        # correction=0 -> divide by n (population variance), i.e. NO Bessel's
        # correction. This matches torch.nn.LayerNorm.
        var = x.var(dim=-1, correction=0, keepdim=True)

        x_norm = (x - mean) / torch.sqrt(var + self.ln_eps)

        return x_norm * self.gamma + self.beta


In [None]:
torch.manual_seed(278)
x = torch.randn(2, 4, 12)  # re-seeded, so this is the same input as in the first cell

ln_custom = LayerNorm(12)
x_norm_custom = ln_custom(x)  # normalize each length-12 row over the last dimension
x_norm_custom

tensor([[[-0.2173, -1.0405, -0.2996,  0.4705, -1.4679, -1.5590,  1.5757,
           0.9180, -0.2903, -0.0130,  0.3640,  1.5593],
         [-0.7000,  1.1482, -0.6918,  0.5285,  0.7120, -1.8607, -0.1181,
           1.2393, -0.3981, -0.3340, -1.0915,  1.5663],
         [ 0.4343, -1.7358,  0.5082,  0.0936, -1.1462,  0.4320, -1.6449,
           1.2339, -0.7201,  0.8943,  1.0722,  0.5785],
         [-0.6602,  0.3011, -1.1587, -0.7294, -0.8192, -0.1799,  0.3541,
           0.8173, -0.6035,  0.4077,  2.7155, -0.4447]],

        [[-0.2331,  1.3994,  0.7505,  0.3100, -0.4272,  1.8168, -0.9443,
           0.6937, -0.8022, -0.4344, -1.8933, -0.2359],
         [-0.7875, -1.8882,  0.1442,  1.0430, -0.0635, -0.2963, -0.5060,
          -0.5348, -0.0829, -0.3462,  1.1899,  2.1283],
         [ 0.0533,  1.3242, -0.0096, -0.3824, -2.1242, -0.5937, -1.0915,
          -0.6955,  0.8628,  1.2669,  0.3467,  1.0429],
         [ 1.5807, -0.4994,  1.1177,  1.1119, -1.8836, -0.4602,  0.1789,
          -0.7111, -0.

In [None]:
# nn.LayerNorm initialises weight=1 and bias=0 deterministically, so no seed is
# needed here. Use the custom module's default eps (1e-6) instead of torch's
# default (1e-5) so both implementations compute exactly the same formula.
ln_torch = torch.nn.LayerNorm(12, eps=1e-6)
x_norm_torch = ln_torch(x)

x_norm_torch

tensor([[[-0.2173, -1.0405, -0.2996,  0.4705, -1.4679, -1.5589,  1.5757,
           0.9180, -0.2903, -0.0130,  0.3640,  1.5593],
         [-0.7000,  1.1482, -0.6918,  0.5285,  0.7120, -1.8607, -0.1181,
           1.2393, -0.3981, -0.3340, -1.0915,  1.5663],
         [ 0.4343, -1.7358,  0.5082,  0.0936, -1.1462,  0.4320, -1.6449,
           1.2339, -0.7201,  0.8942,  1.0722,  0.5785],
         [-0.6602,  0.3011, -1.1587, -0.7294, -0.8192, -0.1799,  0.3541,
           0.8173, -0.6035,  0.4077,  2.7155, -0.4447]],

        [[-0.2331,  1.3994,  0.7505,  0.3100, -0.4272,  1.8167, -0.9443,
           0.6937, -0.8022, -0.4344, -1.8933, -0.2359],
         [-0.7875, -1.8882,  0.1442,  1.0430, -0.0635, -0.2963, -0.5060,
          -0.5348, -0.0829, -0.3462,  1.1899,  2.1283],
         [ 0.0533,  1.3242, -0.0096, -0.3824, -2.1242, -0.5937, -1.0915,
          -0.6955,  0.8628,  1.2668,  0.3467,  1.0429],
         [ 1.5807, -0.4994,  1.1177,  1.1119, -1.8836, -0.4602,  0.1789,
          -0.7111, -0.

In [None]:
torch.allclose(x_norm_torch, x_norm_custom, rtol=1e-05, atol=1e-08)  # default tolerances, spelled out

True

## RMS Norm from scratch

$$
\text{RMS}(x) = \sqrt{\epsilon + \frac{1}{d}\sum_{i=1}^{d} x_i^2}
$$

$$
\hat{x} = \frac{x}{\text{RMS}(x)}
$$

$$
y = \gamma \odot \hat{x}
$$

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension, from scratch.

    Unlike LayerNorm, the input is not re-centred (no mean subtraction) and
    there is no bias: ``y = gamma * x / sqrt(sum(x**2) / dim + rn_eps)``.

    Args:
        dim: size of the last (feature) dimension; also the length of the
            learnable scale ``gamma``.
        rn_eps: constant added inside the square root for numerical stability.
    """

    def __init__(self, dim, rn_eps=1e-6):
        super().__init__()

        self.dim = dim
        self.rn_eps = rn_eps
        # learnable per-feature scale, starts as the identity (all ones)
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Divide ``x`` by the RMS of its last dimension, then scale by ``gamma``.

        The mean of squares is computed as a sum divided by ``self.dim``. The
        last dimension of ``x`` is therefore expected to have size ``dim``.
        """
        mean_square = (x ** 2).sum(dim=-1, keepdim=True) / self.dim
        rms = torch.sqrt(mean_square + self.rn_eps)
        return self.gamma * (x / rms)


In [None]:
torch.manual_seed(278)
x = torch.randn(2, 4, 12)  # re-seeded, so this is the same input as in the first cell

rn_custom = RMSNorm(12)
x_norm_custom = rn_custom(x)  # rescale each length-12 row by its inverse RMS
x_norm_custom

tensor([[[-0.1949, -1.0179, -0.2773,  0.4927, -1.4453, -1.5363,  1.5976,
           0.9401, -0.2679,  0.0093,  0.3862,  1.5812],
         [-0.3245,  1.4169, -0.3167,  0.8330,  1.0059, -1.4181,  0.2238,
           1.5027, -0.0401,  0.0204, -0.6933,  1.8108],
         [ 0.7717, -1.2463,  0.8404,  0.4548, -0.6980,  0.7695, -1.1618,
           1.5153, -0.3018,  1.1994,  1.3649,  0.9058],
         [-0.6680,  0.2932, -1.1665, -0.7372, -0.8271, -0.1878,  0.3462,
           0.8094, -0.6113,  0.3998,  2.7075, -0.4525]],

        [[-0.2324,  1.4001,  0.7512,  0.3107, -0.4265,  1.8175, -0.9436,
           0.6944, -0.8014, -0.4337, -1.8926, -0.2351],
         [-0.4624, -1.5155,  0.4289,  1.2888,  0.2302,  0.0074, -0.1932,
          -0.2207,  0.2116, -0.0402,  1.4294,  2.3272],
         [-0.1129,  1.1405, -0.1749, -0.5426, -2.2604, -0.7510, -1.2420,
          -0.8514,  0.6855,  1.0839,  0.1764,  0.8630],
         [ 1.7356, -0.3101,  1.2803,  1.2745, -1.6714, -0.2716,  0.3569,
          -0.5183, -0.

In [None]:
# nn.RMSNorm initialises its weight to ones deterministically, so no seed is
# needed here. Its default eps=None means "machine epsilon of the input dtype".
# Pass the custom module's default (1e-6) so both compute the same formula.
rn_torch = torch.nn.RMSNorm(12, eps=1e-6)
x_norm_torch = rn_torch(x)

x_norm_torch

tensor([[[-0.1949, -1.0179, -0.2773,  0.4927, -1.4453, -1.5363,  1.5976,
           0.9401, -0.2679,  0.0093,  0.3862,  1.5812],
         [-0.3245,  1.4169, -0.3167,  0.8330,  1.0059, -1.4181,  0.2238,
           1.5027, -0.0401,  0.0204, -0.6933,  1.8108],
         [ 0.7717, -1.2463,  0.8404,  0.4548, -0.6980,  0.7695, -1.1618,
           1.5153, -0.3018,  1.1994,  1.3649,  0.9058],
         [-0.6680,  0.2932, -1.1665, -0.7372, -0.8271, -0.1878,  0.3462,
           0.8094, -0.6113,  0.3998,  2.7075, -0.4525]],

        [[-0.2324,  1.4001,  0.7512,  0.3107, -0.4265,  1.8175, -0.9436,
           0.6944, -0.8014, -0.4337, -1.8926, -0.2351],
         [-0.4624, -1.5155,  0.4289,  1.2888,  0.2302,  0.0074, -0.1932,
          -0.2207,  0.2116, -0.0402,  1.4294,  2.3272],
         [-0.1129,  1.1405, -0.1749, -0.5426, -2.2604, -0.7510, -1.2420,
          -0.8514,  0.6855,  1.0839,  0.1764,  0.8630],
         [ 1.7356, -0.3101,  1.2803,  1.2745, -1.6714, -0.2716,  0.3569,
          -0.5183, -0.

In [None]:
torch.allclose(x_norm_torch, x_norm_custom, rtol=1e-05, atol=1e-08)  # default tolerances, spelled out

True

**What the learnable scale $\gamma$ does**

Initially $\gamma = [1, 1, \dots, 1]$ (all ones). During pre-training, $\gamma$ learns to:

- **Rescale dimensions** – some feature dimensions may be more important than others.
- **Recover lost scale** – normalization removes the original scale; $\gamma$ restores whatever scaling is useful.
- **Adapt to downstream tasks** – different dimensions get different importance weights.

Key observations from trained models:

- **Not all 1s anymore** – values deviate from the initial 1.0.
- **Some dimensions up-weighted** (values > 1.0) – these features are amplified, i.e. treated as more important.
- **Some dimensions down-weighted** (values < 1.0) – these features are attenuated, i.e. treated as less important.
- **Patterns emerge** – similar dimensions tend to get similar scaling factors.