In [1]:
1

1

In [28]:
import math
import torch
import torch.nn as nn

class LayerNormalization(nn.Module):
    def __init__(self, embed_dim:int, eps:float = 1e-9) -> None:
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(embed_dim)).float()
        self.beta = nn.Parameter(torch.ones(embed_dim)).float()
        self.eps = eps

    def forward(self, input_data:torch.Tensor) -> torch.Tensor:
        # Assume input dim(2, 3, 6)
        mean = torch.mean(input_data, dim=-1, keepdim=True) # (2, 3, 1)
        std = torch.std(input_data, dim=-1, keepdim=True) # (2, 3, 1)

        # to normalize (2, 3, 6) - (2, 3, 1) = (2, 3, 6) due to broadcasting
        normalized_input_data = (input_data - mean) / (std + self.eps)

        # some weights do not require normalized output so alpha learnable parameter is introduced 
        return self.alpha * normalized_input_data + self.beta
        

In [29]:
x = torch.arange(36).reshape(2, 3, 6).float()
x

tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.],
         [ 6.,  7.,  8.,  9., 10., 11.],
         [12., 13., 14., 15., 16., 17.]],

        [[18., 19., 20., 21., 22., 23.],
         [24., 25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34., 35.]]])

In [30]:
mean = torch.mean(x, dim=-1, keepdim=True)
std = torch.std(x, dim= -1, keepdim=True)
mean, std

(tensor([[[ 2.5000],
          [ 8.5000],
          [14.5000]],
 
         [[20.5000],
          [26.5000],
          [32.5000]]]),
 tensor([[[1.8708],
          [1.8708],
          [1.8708]],
 
         [[1.8708],
          [1.8708],
          [1.8708]]]))

In [23]:
(x - mean) / std

tensor([[[-1.3363, -0.8018, -0.2673,  0.2673,  0.8018,  1.3363],
         [-1.3363, -0.8018, -0.2673,  0.2673,  0.8018,  1.3363],
         [-1.3363, -0.8018, -0.2673,  0.2673,  0.8018,  1.3363]],

        [[-1.3363, -0.8018, -0.2673,  0.2673,  0.8018,  1.3363],
         [-1.3363, -0.8018, -0.2673,  0.2673,  0.8018,  1.3363],
         [-1.3363, -0.8018, -0.2673,  0.2673,  0.8018,  1.3363]]])

In [25]:
torch.mean(torch.tensor([-1.3363, -0.8018, -0.2673,  0.2673,  0.8018,  1.3363])), torch.std(torch.tensor([-1.3363, -0.8018, -0.2673,  0.2673,  0.8018,  1.3363]))

(tensor(0.), tensor(1.0000))