In [8]:
import torch
import torch.nn as nn

In [15]:
inputs = torch.Tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
inputs = inputs.reshape(S, B, E)
inputs.size()

torch.Size([2, 1, 3])

In [18]:
class LayerNormalization():
    def __init__(self, parameter_shape, eps=1e-5):
        self.parameter_shape = parameter_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameter_shape))
        self.beta = nn.Parameter(torch.zeros(parameter_shape))

    def forward(self, input):
        dims = [-(i+1) for i in range(len(self.parameter_shape))]
        mean = inputs.mean(dim = dims, keepdim=True)
        print(f"Mean size: ({mean.size()}): \n {mean}")
        var = ((inputs-mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var+self.eps).sqrt()
        print(f"Std size: ({std.size()}): \n {std}")
        y = (input - mean) / std
        print(f"Y size: ({y.size()}): \n {y}")
        out = self.gamma * y + self.beta
        print(f"Out size: ({out.size()}): \n {out}")
        return out

In [19]:
batch_size = 3
sentence_length = 5
embedding_dim = 8
inputs = torch.randn(sentence_length, batch_size ,embedding_dim)
print(f"Out size: ({inputs.size()}): \n {inputs}")

Out size: (torch.Size([5, 3, 8])): 
 tensor([[[-0.5500, -1.1780,  0.7251, -0.7280,  1.1794,  0.1364, -1.4647,
           0.9952],
         [-1.0367, -1.2310,  0.1948, -1.6604, -0.0228,  0.1269,  0.0654,
           0.3323],
         [ 1.5870, -1.4790, -0.3670,  1.3185, -1.0514, -0.3926,  1.0626,
           2.1141]],

        [[-0.8286, -0.1633, -0.8340,  1.1443, -0.2363, -0.4382, -0.9609,
           0.6907],
         [-1.9499, -0.9663, -1.3738,  1.1423, -0.9968, -2.0265,  0.0133,
          -0.7461],
         [-0.9786, -0.8595,  0.5756, -0.6386,  0.4608, -0.5107, -0.7438,
           0.3069]],

        [[-1.2470, -0.7354,  0.4540,  0.7019, -0.4166, -0.6381, -0.0074,
          -0.7602],
         [-0.3209, -1.3230,  0.3226,  0.0558, -1.0550,  1.5281, -1.5776,
           1.0734],
         [ 0.4035,  1.4981, -0.1282,  0.0438,  0.3550, -0.6199,  0.2412,
           1.2206]],

        [[-0.3736,  1.9303, -0.6937,  0.8747,  0.9640,  0.5960, -0.3872,
          -0.0855],
         [ 0.9654,  0.1855,

In [20]:
layer_norm = LayerNormalization(inputs.size()[-2:])
out = layer_norm.forward(inputs)

Mean size: (torch.Size([5, 1, 1])): 
 tensor([[[-0.0552]],

        [[-0.4549]],

        [[-0.0388]],

        [[-0.1366]],

        [[ 0.0969]]])
Std size: (torch.Size([5, 1, 1])): 
 tensor([[[1.0443]],

        [[0.8302]],

        [[0.8527]],

        [[0.8717]],

        [[0.8443]]])
Y size: (torch.Size([5, 3, 8])): 
 tensor([[[-0.4738, -1.0753,  0.7471, -0.6443,  1.1822,  0.1835, -1.3497,
           1.0058],
         [-0.9399, -1.1260,  0.2393, -1.5372,  0.0310,  0.1744,  0.1155,
           0.3710],
         [ 1.5725, -1.3635, -0.2986,  1.3155, -0.9540, -0.3232,  1.0704,
           2.0773]],

        [[-0.4502,  0.3512, -0.4566,  1.9263,  0.2633,  0.0201, -0.6095,
           1.3799],
         [-1.8008, -0.6159, -1.1068,  1.9240, -0.6527, -1.8931,  0.5640,
          -0.3508],
         [-0.6308, -0.4873,  1.2413, -0.2213,  1.1031, -0.0672, -0.3480,
           0.9176]],

        [[-1.4169, -0.8169,  0.5779,  0.8686, -0.4431, -0.7028,  0.0369,
          -0.8460],
         [-0.3309, -