In [1]:


import torch
from torch import nn

In [2]:
# Toy batch: 1 sentence (B) x 2 tokens (S) x 3 embedding features (E).
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Move to sequence-first layout (S, B, E). transpose swaps the two axes;
# reshape would only reinterpret the memory order and scramble tokens
# across sentences as soon as B > 1.
inputs = inputs.transpose(0, 1)
inputs.size()

torch.Size([2, 1, 3])

In [3]:

# Learnable scale (gamma, initialised to 1) and shift (beta, initialised to 0)
# shaped like the trailing two axes of `inputs`.
# NOTE(review): with the sequence-first layout these two axes are (batch, embedding);
# the class-based cell further down passes only the last axis - confirm which is intended.
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(parameter_shape))
beta = nn.Parameter(torch.zeros(parameter_shape))

In [4]:

# Both parameters share the shape of the normalized axes.
gamma.shape, beta.shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [5]:


# Negative indices of the trailing axes to reduce over: one per entry of
# parameter_shape, counted from the last axis backwards (-1, -2, ...).
dims = list(range(-1, -len(parameter_shape) - 1, -1))

In [6]:
# Display the reduction axes.
dims

[-1, -2]

In [7]:

# Mean over the normalized axes; keepdim=True keeps singleton axes so the
# result broadcasts back against `inputs`.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [8]:
# Display the per-slice means.
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [9]:

# epsilon keeps the square root (and the later division) away from zero.
epsilon = 1e-5
# Biased (population) variance over the same axes as the mean.
var = (inputs - mean).pow(2).mean(dim=dims, keepdim=True)
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [10]:
# Standardize: subtract the mean, then divide by the standard deviation.
y = (inputs - mean).div(std)
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [11]:
# Learnable affine transform: shift by beta after scaling by gamma.
out = beta + gamma * y


In [12]:
# Display the normalized output.
out


tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

Layer normalization as a reusable class

In [13]:
import torch
from torch import nn

class LayerNormalization(nn.Module):
    """Layer normalization over the trailing axes of a tensor.

    Each slice spanned by the last ``len(parameters_shape)`` axes is
    standardized to zero mean and unit (biased) variance, then scaled by the
    learnable ``gamma`` and shifted by the learnable ``beta``.

    Args:
        parameters_shape: Shape of the trailing axes to normalize over; also
            the shape of ``gamma`` and ``beta``. A single int is treated as a
            one-element shape.
        eps: Constant added to the variance before the square root so the
            division never hits zero.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        # Subclassing nn.Module registers gamma/beta so that .parameters(),
        # .to(device) and state_dict() see them.
        super().__init__()
        if isinstance(parameters_shape, int):
            # len() in forward needs a sequence, not a bare int.
            parameters_shape = (parameters_shape,)
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        """Normalize ``input`` and print every intermediate tensor.

        Args:
            input: Tensor whose trailing axes match ``parameters_shape``.

        Returns:
            Tensor of the same shape as ``input``.
        """
        # Use the argument, not a notebook-level global, so the layer works on
        # whatever tensor the caller passes.
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = input.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        # Biased (population) variance over the normalized axes.
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (input - mean) / std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out

In [14]:

batch_size = 3
sentence_length = 5
embedding_dim = 8
# Draw the random batch from a locally seeded generator so the cell gives the
# same tensor on every Restart & Run All without touching the global RNG state.
inputs = torch.randn(
    sentence_length,
    batch_size,
    embedding_dim,
    generator=torch.Generator().manual_seed(0),
)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 1.3394,  1.1978,  1.0886,  0.3532, -0.0821,  1.7808, -0.4583,
           0.2895],
         [ 1.4722, -0.3206,  0.5372, -0.2004,  1.1477,  0.4039,  0.1370,
          -0.3159],
         [ 1.0962, -1.2243,  0.9908,  0.3668,  1.1975, -1.0236, -0.8583,
           0.3021]],

        [[-0.7205, -0.8839, -0.9601,  1.4237,  0.3245, -0.3246, -0.3709,
           2.1229],
         [-0.4210,  0.5898,  1.5925, -0.5100,  0.2002,  1.3171, -0.2650,
           0.0031],
         [ 0.0451,  0.1525, -1.3250, -0.0041,  0.9804,  0.1324, -2.0061,
           0.2426]],

        [[-0.4651,  0.8401, -1.2045,  0.5012, -0.6813, -2.1767,  0.2083,
           1.7364],
         [ 0.1049,  0.3159, -1.8046, -0.9896,  0.2334,  1.4384,  0.3297,
           0.8160],
         [ 0.4176,  0.3167,  0.3075,  2.6232, -0.5630,  0.6807,  0.5105,
           0.8698]],

        [[ 0.7936,  0.1628, -1.4500, -1.2782, -0.6968, -0.1072, -0.2945,
           0.1914],
         [-1.6593, -0.2724, 

In [15]:
# Normalize over the last axis only, so gamma/beta get one entry per feature.
layer_norm = LayerNormalization(inputs.shape[-1:])


In [16]:
# Run the layer on the random batch; forward prints each intermediate tensor.
out = layer_norm.forward(inputs)


Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[ 0.6886],
         [ 0.3576],
         [ 0.1059]],

        [[ 0.0764],
         [ 0.3133],
         [-0.2228]],

        [[-0.1552],
         [ 0.0555],
         [ 0.6454]],

        [[-0.3349],
         [-0.0989],
         [-0.0293]],

        [[-0.1648],
         [ 0.7624],
         [ 0.7480]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[0.7260],
         [0.6299],
         [0.9379]],

        [[1.0640],
         [0.7400],
         [0.8976]],

        [[1.1561],
         [0.9500],
         [0.8453]],

        [[0.7173],
         [1.1534],
         [0.8994]],

        [[1.0003],
         [0.6409],
         [0.9961]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 0.8964,  0.7013,  0.5509, -0.4619, -1.0616,  1.5043, -1.5797,
          -0.5497],
         [ 1.7693, -1.0766,  0.2850, -0.8858,  1.2542,  0.0735, -0.3503,
          -1.0692],
         [ 1.0559, -1.4183,  0.9435,  0.2782,  1.1639, -1.2043, -1.0281,
           0.2092]],



In [17]:
# Sanity check on the first slice of the output: mean and standard deviation.
# NOTE(review): Tensor.std() applies Bessel's correction by default, whereas the
# layer divides by the biased standard deviation, and out[0] pools several
# independently normalized rows - so a value slightly above 1 is presumably
# expected here rather than exactly 1; confirm.
out[0].mean(), out[0].std()


(tensor(-9.9341e-09, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))