In [1]:
import torch
import torch.nn as nn

In [2]:
inputs = torch.Tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()

inputs = inputs.reshape(S, B, E)
inputs.shape # num_words, batch_size, embedding_size

torch.Size([2, 1, 3])

In [3]:
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(parameter_shape))
beta = nn.Parameter(torch.zeros(parameter_shape))

In [4]:
gamma.shape, beta.shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [5]:
dims = [-(i+1) for i in range(len(parameter_shape))]
dims

[-1, -2]

In [23]:
parameter_shape

torch.Size([1, 3])

In [6]:
mean = inputs.mean(dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [7]:
vari = ((inputs-mean)**2).mean(dims, keepdim=True)
epsilon = 1e-5
std = (vari + epsilon).sqrt()
std

tensor([[[0.0817]],

        [[0.1886]]])

In [8]:
y = (inputs-mean) / std
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [9]:
out = gamma * y + beta
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

# Layer Normalization Class

In [24]:
class LayerNormalization(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        super(LayerNormalization, self).__init__()
        self.hidden_size = hidden_size
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps

    def forward(self, x):
        dim = [-(i+1) for i in range(len(self.hidden_size))]
        print('Dim: ', dim)

        mean = x.mean(dim=dim, keepdim=True)
        print(f"Mean\n ({mean.shape}): \n{mean}")

        var = ((x - mean) ** 2).mean(dim=dim, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standerd Deviation\n ({var.shape}): \n{var}")
        y = (x - mean) / std

        return self.gamma * y + self.beta

In [25]:
batch_size = 3
sentence_length = 5
embedding_dim = 8

input_tensor = torch.randn(sentence_length, batch_size, embedding_dim)
print('input_tensor.shape[-2:] -->', input_tensor.shape[-2:])
ln = LayerNormalization(input_tensor.shape[-2:])
output = ln.forward(input_tensor)
output

input_tensor.shape[-2:] --> torch.Size([3, 8])
Dim:  [-1, -2]
Mean
 (torch.Size([5, 1, 1])): 
tensor([[[ 0.0224]],

        [[ 0.1338]],

        [[-0.2999]],

        [[ 0.5752]],

        [[-0.3409]]])
Standerd Deviation
 (torch.Size([5, 1, 1])): 
tensor([[[0.7733]],

        [[0.9201]],

        [[1.5157]],

        [[1.0057]],

        [[0.8369]]])


tensor([[[ 1.3838, -1.6031, -0.5537,  0.0748, -0.4604, -1.7306,  1.0826,
           0.6129],
         [ 0.4923,  2.0879, -0.3854, -0.5606,  0.4960,  0.0489,  0.5850,
          -1.8244],
         [ 0.0331,  0.3365, -0.0289,  0.8907, -1.0897,  0.3105,  1.1760,
          -1.3743]],

        [[-0.0053,  0.7428,  0.0361, -1.8225, -0.6666,  1.4805,  0.2582,
           0.6601],
         [-1.8962,  1.5772, -0.1669, -0.6593, -0.4069,  0.9902,  0.6890,
           0.9179],
         [ 0.1525, -1.3686,  0.3098, -1.8551, -0.8267,  1.3058,  0.3051,
           0.2490]],

        [[ 0.1462, -0.6165,  2.3118, -0.9694, -1.5421,  1.3778, -0.0656,
          -1.1791],
         [-1.3477,  1.1242, -0.7854,  0.2151,  0.0560,  0.1809, -0.2846,
           0.7513],
         [-0.7121, -0.2379,  0.8543,  0.6686,  0.1919, -1.9130,  1.2085,
           0.5668]],

        [[ 0.3522,  2.5047, -0.0721,  0.7945, -0.4462, -0.6950,  0.3473,
          -0.2804],
         [-1.1647,  1.2981,  1.4910,  1.3706, -1.0406, -0.0195, 