In [1]:
import torch
from torch import nn

In [5]:
# Toy batch: one sentence of two words, each word a 3-dim embedding.
inputs = torch.tensor([
    [[0.2, 0.1, 0.3],
     [0.5, 0.1, 0.1]],
])
B, S, E = inputs.shape  # batch size, sequence length (words), embedding dim

In [6]:
(B, S, E)  # batch size, sequence length (words), embedding dim

(1, 2, 3)

In [7]:
# Swap the batch and sequence axes: (B, S, E) -> (S, B, E).
# transpose (not reshape) is required here: reshape only reinterprets the
# flat buffer, which scrambles tokens across batch items whenever B > 1.
inputs = inputs.transpose(0, 1)
assert inputs.size() == (S, B, E)
inputs.size()

torch.Size([2, 1, 3])

In [12]:
# gamma/beta span the trailing two axes, i.e. (batch, embedding) in the
# sequence-first layout produced above.
# NOTE(review): standard layer norm normalizes each token over the embedding
# axis only (inputs.shape[-1:]); [-2:] also mixes batch items once B > 1.
# The results are identical here only because B == 1 — confirm intent.
parameters_shape = inputs.shape[-2:]
parameters_shape

torch.Size([1, 3])

In [14]:
# Learnable scale (initialized to 1) and shift (initialized to 0), so the
# affine step gamma * y + beta is the identity at initialization.
gamma, beta = (
    nn.Parameter(torch.ones(parameters_shape)),
    nn.Parameter(torch.zeros(parameters_shape)),
)

In [16]:
gamma.shape, beta.shape  # both equal parameters_shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [17]:
# Negative axis indices of the trailing len(parameters_shape) dims: -1, -2, ...
dims = list(range(-1, -len(parameters_shape) - 1, -1))

In [18]:
dims  # axes reduced over when computing the mean and variance below

[-1, -2]

In [21]:
# Mean over the normalized dims; keepdim=True leaves size-1 axes so the
# result broadcasts against `inputs` in the later cells.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [20]:
mean  # one mean per position, with the reduced axes kept as size 1

tensor([[[0.2000]],

        [[0.2333]]])

In [22]:
epsilon = 1e-5  # avoids division by zero in the next step when the variance is 0
# Biased (population) variance: mean squared deviation over the same dims.
var = (inputs - mean).pow(2).mean(dim=dims, keepdim=True)
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [23]:
y = torch.div(inputs - mean, std)  # centre, then scale to ~unit variance

In [24]:
y  # normalized values: zero mean, approximately unit variance (epsilon added to var)

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [25]:
out = beta + gamma * y  # learnable affine step: scale by gamma, shift by beta

In [26]:
out  # gamma is all ones and beta all zeros at init, so out equals y here

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

In [27]:
class LayerNormalization(nn.Module):
    """Layer normalization over the trailing ``len(parameters_shape)`` dims.

    Computes ``gamma * (x - mean) / sqrt(var + eps) + beta``, where the mean
    and (biased) variance are taken over the trailing dims of the input.
    Subclassing ``nn.Module`` registers ``gamma``/``beta`` as parameters, so
    they appear in ``.parameters()`` and can be trained.

    Args:
        parameters_shape: shape of the learnable ``gamma``/``beta``; its
            length decides how many trailing input dims are normalized.
        eps: small constant added to the variance to avoid division by zero.
        verbose: if True (the default), print every intermediate in
            ``forward`` for inspection.
    """

    def __init__(self, parameters_shape, eps=1e-5, verbose=True):
        super().__init__()  # must run before assigning nn.Parameter attributes
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.verbose = verbose
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        """Normalize ``input`` over its trailing dims and apply the affine step.

        Args:
            input: tensor whose trailing dims broadcast with ``parameters_shape``.

        Returns:
            Tensor with the same shape as ``input``.
        """
        # Reduce over the last len(parameters_shape) axes: -1, -2, ...
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = input.mean(dim=dims, keepdim=True)
        if self.verbose:
            print(f"Mean \n ({mean.size()}): \n {mean}")
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        if self.verbose:
            print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (input - mean) / std
        if self.verbose:
            print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y + self.beta
        if self.verbose:
            print(f"out \n ({out.size()}) = \n {out}")
        return out