In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt



In [None]:
# Fix the RNG so the random input (and every downstream output) is reproducible.
torch.manual_seed(123)
# (2, 5) tensor of standard-normal samples; its last dimension feeds nn.Linear below.
input = torch.randn((2, 5))
input

tensor([[-0.1115,  0.1204, -0.3696, -0.2404, -1.1969],
        [ 0.2093, -0.9724, -0.7550,  0.3239, -0.1085]])

In [None]:
# Width (number of output units) of the hidden layer whose activations get normalized.
embd_dim = 6
# Linear + ReLU mapping each row of `input` (input.shape[1] values) to `embd_dim`
# outputs. Use `embd_dim` instead of a literal so this layer and the
# LayerNormalization(embd_dim) built later stay in sync if the width changes.
layer = nn.Sequential(nn.Linear(input.shape[1], embd_dim),
                      nn.ReLU())
out = layer(input)
out

tensor([[0.2260, 0.3470, 0.0000, 0.2216, 0.0000, 0.0000],
        [0.2133, 0.2394, 0.0000, 0.5198, 0.3297, 0.0000]],
       grad_fn=<ReluBackward0>)

In [None]:
# Per-row statistics over the last dimension. keepdim=True keeps a trailing
# size-1 axis so the results broadcast against `out` in the next cell.
# Note: Tensor.var defaults to the unbiased (N-1, Bessel-corrected) estimator.
mean = torch.mean(out, dim=-1, keepdim=True)
var = torch.var(out, dim=-1, keepdim=True)
(mean, var)

(tensor([[0.1324],
         [0.2170]], grad_fn=<MeanBackward1>),
 tensor([[0.0231],
         [0.0398]], grad_fn=<VarBackward0>))

In [None]:
# Standardize each row: subtract its mean and divide by its standard deviation,
# std = sqrt(var). This is layer normalization without learnable scale/shift.
normed = (out - mean) / torch.sqrt(var)
# Only a cell's last expression is displayed, so return both checks together.
# Each row should have mean ~0 and variance ~1. The variance uses the same
# (default, unbiased) estimator as `var` above, so it is 1 up to rounding.
normed.mean(dim=-1, keepdim=True), normed.var(dim=-1, keepdim=True)

tensor([[1.0000],
        [1.0000]], grad_fn=<VarBackward0>)

In [None]:
# Print tensors in fixed-point rather than scientific notation, so near-zero
# values (such as the per-row means checked below) display as 0.0000.
torch.set_printoptions(sci_mode=False)

## **Layer Normalization Class**

In [None]:
class LayerNormalization(nn.Module):
  """Layer normalization over the last dimension with a learnable affine transform.

  Each slice along the last dimension is standardized with its own mean and
  biased (population) variance, then scaled and shifted:

      y = scale * (x - mean) / sqrt(var + eps) + shift

  Args:
    emb_dim: size of the input's last dimension; sets the length of the
      learnable `scale` (initialized to 1) and `shift` (initialized to 0).
    eps: small constant added to the variance to avoid division by zero.
  """

  def __init__(self, emb_dim, eps=1e-5):
    super().__init__()
    self.eps = eps
    self.scale = nn.Parameter(torch.ones(emb_dim))
    self.shift = nn.Parameter(torch.zeros(emb_dim))

  def forward(self, x):
    """Normalize `x` along its last dimension, then apply scale and shift.

    Args:
      x: tensor whose last dimension has size `emb_dim`.
    Returns:
      Tensor with the same shape as `x`.
    """
    # Statistics must come from `x` itself, never from a notebook-global `mean`.
    mean = x.mean(dim=-1, keepdim=True)
    # Biased estimator (divide by N), matching the usual LayerNorm definition.
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    x_norm = (x - mean) / torch.sqrt(var + self.eps)
    # Intermediates stay available as attributes for inspection. When `x` is part
    # of an autograd graph, they keep that graph alive until the next call.
    self.mean, self.var, self.x_norm = mean, var, x_norm
    return self.scale * x_norm + self.shift

In [None]:
# Apply the custom layer to the ReLU activations from above. At initialization
# scale=1 and shift=0, so each output row should have mean ~0 and (biased)
# variance ~1. The variance is slightly below 1 because eps is added to var.
normLayer = LayerNormalization(embd_dim)
nm_out = normLayer(out)
# unbiased=False matches the estimator LayerNormalization uses internally.
nm_out.mean(dim=-1, keepdim=True), nm_out.var(dim=-1, keepdim=True, unbiased=False)

tensor([[     0.0000],
        [    -0.0000]], grad_fn=<MeanBackward1>)