https://www.youtube.com/watch?v=G45TuC6zRf4&list=PLTl9hO2Oobd97qfWC40gOSU8C0iu0m2l4&index=4

<img src="../pictures/layer_norm.png" alt="pic" width="300">


Here we focus on the first "Add & Norm" layer of the Encoder. This layer takes as input the output matrix of the first multi-head attention block. It also receives a residual connection carrying the Encoder input (embeddings + positional encodings). Residual connections give information a more direct path through a deep network and mitigate vanishing gradients during backpropagation (gradient updates that shrink towards 0). This helps training stay stable and converge better.

In the Add & Norm layer we "Add" the output of the multi-head attention block to the input (embeddings + positional encodings) and pass the result to a normalization layer.
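The Add & Norm step can be sketched in a few lines. This is a minimal illustration, not the notebook's own implementation: the tensor sizes are hypothetical, random tensors stand in for the real encoder input and attention output, and `nn.LayerNorm` is assumed to normalize over the embedding dimension only.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch_size, seq_len, d_model = 2, 4, 8

# Stand-in for the encoder input (embeddings + positional encodings).
x = torch.randn(batch_size, seq_len, d_model)
# Stand-in for the multi-head attention output (same shape as x).
attention_out = torch.randn(batch_size, seq_len, d_model)

# "Add": the residual connection sums the sub-layer output with its input.
residual_sum = x + attention_out

# "Norm": normalize each word's vector over the embedding dimension.
norm = nn.LayerNorm(d_model)
add_norm_out = norm(residual_sum)

print(add_norm_out.shape)  # torch.Size([2, 4, 8])
```

The shape is unchanged by both steps, which is what allows the residual sum in the first place.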

During normalization, the activation values of the neurons in a layer are adjusted so that their mean is 0 and their standard deviation is 1 (computed relative to that layer). For every activation value:

$$a_i = f(W_{i}^{T} x + b_i)$$

We normalize by:
$$y_i = \gamma_l \left[\frac{a_i-\mu}{\sigma}\right] + \beta_l$$

where $i$ represents the $i$th activation neuron, $l$ the $l$th layer, and $\mu$ and $\sigma$ are the mean and standard deviation of the activations in that layer.

$\gamma$ and $\beta$ are learnable parameters; there is one pair per layer. The mean and standard deviation, however, are computed separately for each word.
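As a small worked example of the normalization formula, the sketch below applies it to a single word's activations. The three values are picked for illustration, $\gamma = 1$ and $\beta = 0$ are assumed (the usual initial values), and the small epsilon term is left out to keep the arithmetic exact.

```python
import torch

# One word's activations (illustrative values).
a = torch.tensor([0.2, 0.1, 0.3])

# Statistics are taken over this word's own activations.
mu = a.mean()
sigma = ((a - mu) ** 2).mean().sqrt()  # population std (divides by N)

# gamma = 1 and beta = 0 leave the normalized values unchanged.
gamma, beta = torch.ones(3), torch.zeros(3)
y = gamma * (a - mu) / sigma + beta

# After normalization the mean is ~0 and the standard deviation is ~1.
print(y.mean(), y.std(unbiased=False))
```

Once training starts, $\gamma$ and $\beta$ let the network rescale and shift these normalized values if that helps.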


In [1]:
import torch
from torch import nn

In [58]:
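inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
# B = batch size, S = maximum sequence length (in words), E = embedding size
B, S, E = inputs.size()
# swap the batch and sequence axes: (B, S, E) -> (S, B, E)
# permute is a true transpose; reshape would only give the same result when B == 1
inputs = inputs.permute(1, 0, 2)
inputs.size()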

torch.Size([2, 1, 3])

In [59]:
inputs

tensor([[[0.2000, 0.1000, 0.3000]],

        [[0.5000, 0.1000, 0.1000]]])
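Moving from a `(B, S, E)` layout to `(S, B, E)` is an axis swap. The sketch below, on a hypothetical batch of two sentences, shows that `reshape` and `permute` only agree when one of the swapped dimensions has size 1, as in the example above.

```python
import torch

# Hypothetical batch: 2 sentences, 3 words each, embedding size 1.
x = torch.arange(6.0).reshape(2, 3, 1)  # (B, S, E)
B, S, E = x.size()

reshaped = x.reshape(S, B, E)  # reinterprets memory order, mixing sentences
permuted = x.permute(1, 0, 2)  # true axis swap: permuted[s, b] == x[b, s]

print(torch.equal(reshaped, permuted))       # False
print(torch.equal(permuted[1, 0], x[0, 1]))  # True
```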

In this example, layer normalization is computed over the last two dimensions, batch and embedding. $\gamma$ and $\beta$ therefore take the shape of those two dimensions, giving one $\gamma$, $\beta$ pair for each batch-embedding position:

In [60]:
parameter_shapes = inputs.size()[-2:]
gamma = nn.Parameter(torch.ones(parameter_shapes))
beta = nn.Parameter(torch.zeros(parameter_shapes))

In [61]:
gamma.size(), beta.size()

(torch.Size([1, 3]), torch.Size([1, 3]))

In [62]:
# dimensions over which to normalize: the last two (embedding and batch)
# negative indices count from the end of the shape
dims = [-(i+1) for i in range(len(parameter_shapes))]

In [63]:
dims

[-1, -2]

In [64]:
mean = inputs.mean(dim=dims, keepdim=True)
mean.size()

torch.Size([2, 1, 1])

In [65]:
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [66]:
var = ((inputs - mean)**2).mean(dim=dims, keepdim=True)
# ensure std isn't 0 since it is the denominator
epsilon = 1e-5
std = (var + epsilon).sqrt()

In [67]:
y = (inputs - mean) / std

In [68]:
out = gamma * y + beta

In [75]:
y.size()

torch.Size([2, 1, 3])

In [76]:
gamma.size()

torch.Size([1, 3])
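One way to sanity-check the manual steps above is to compare them with PyTorch's built-in `torch.nn.functional.layer_norm`. The sketch below rebuilds the same toy input and assumes the same `eps` and no scale or shift (equivalent to $\gamma = 1$, $\beta = 0$).

```python
import torch
import torch.nn.functional as F

# Same toy input as above, already in (S, B, E) = (2, 1, 3) layout.
inputs = torch.tensor([[[0.2, 0.1, 0.3]], [[0.5, 0.1, 0.1]]])
normalized_shape = tuple(inputs.size()[-2:])  # (B, E)
eps = 1e-5

# Manual computation over the last two dimensions.
dims = [-1, -2]
mean = inputs.mean(dim=dims, keepdim=True)
var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
manual = (inputs - mean) / (var + eps).sqrt()

# Built-in reference without learnable scale or shift.
reference = F.layer_norm(inputs, normalized_shape, eps=eps)

print(torch.allclose(manual, reference, atol=1e-5))
```

Both versions use the population variance (dividing by N), which is why the results are expected to agree.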

In [83]:
class LayerNormalization:
    def __init__(self, parameters_shape, eps=1e-5):
        # parameters_shape: trailing dimensions to normalize over, here (batch, embedding)
        self.parameters_shape = parameters_shape
        # eps keeps the denominator away from 0
        self.eps = eps
        # learnable scale and shift, initialised so the output starts out equal to y
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, inputs):
        # normalize over the last len(parameters_shape) dimensions
        # (here the embedding and batch dimensions)
        dims = [-(i+1) for i in range(len(self.parameters_shape))]
        mean = inputs.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean})")
        var = ((inputs - mean)**2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Std \n ({std.size()}): \n {std})")
        y = (inputs - mean) / std
        print(f"y \n ({y.size()}): \n {y})")
        out = self.gamma * y + self.beta
        print(f"Output \n ({out.size()}): \n {out})")
        return out

In [84]:
batch_size = 3
sentence_length = 5
embedding_dim = 8
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

In [85]:
nn.Parameter(torch.ones(inputs.size()[-2:])).size()

torch.Size([3, 8])

In [86]:
layer_norm = LayerNormalization(parameters_shape=inputs.size()[-2:])

In [87]:
layer_norm.forward(inputs)

Mean 
 (torch.Size([5, 1, 1])): 
 tensor([[[ 0.1247]],

        [[ 0.1336]],

        [[-0.0469]],

        [[-0.1399]],

        [[ 0.1343]]]))
Std 
 (torch.Size([5, 1, 1])): 
 tensor([[[0.8218]],

        [[1.0480]],

        [[0.9055]],

        [[1.2428]],

        [[1.0522]]]))
y 
 (torch.Size([5, 3, 8])): 
 tensor([[[-0.1071,  1.3647, -0.7985, -1.2234, -1.3445,  0.3726, -0.6062,
          -0.9186],
         [-1.4119,  1.1061,  1.5628, -0.2964,  2.0020,  0.3794, -0.5848,
          -0.5192],
         [ 0.6890, -0.4799, -0.6190,  0.5880,  0.8222,  1.6510, -1.0985,
          -0.5298]],

        [[ 1.0896, -0.8622,  1.3571,  0.5579, -0.4322,  1.1513,  1.4784,
          -0.9235],
         [-1.1008, -0.6334,  0.2749, -0.1975,  0.0471,  0.3669,  0.7266,
           0.4657],
         [-0.8278, -0.1707, -1.5194, -0.2840, -1.8340,  1.5010, -1.5312,
           1.3004]],

        [[ 0.2462, -1.0931, -1.5495, -1.4534,  1.6599,  0.3497,  1.5365,
          -1.3019],
         [-0.1945,  0.5511,  0

tensor([[[-0.1071,  1.3647, -0.7985, -1.2234, -1.3445,  0.3726, -0.6062,
          -0.9186],
         [-1.4119,  1.1061,  1.5628, -0.2964,  2.0020,  0.3794, -0.5848,
          -0.5192],
         [ 0.6890, -0.4799, -0.6190,  0.5880,  0.8222,  1.6510, -1.0985,
          -0.5298]],

        [[ 1.0896, -0.8622,  1.3571,  0.5579, -0.4322,  1.1513,  1.4784,
          -0.9235],
         [-1.1008, -0.6334,  0.2749, -0.1975,  0.0471,  0.3669,  0.7266,
           0.4657],
         [-0.8278, -0.1707, -1.5194, -0.2840, -1.8340,  1.5010, -1.5312,
           1.3004]],

        [[ 0.2462, -1.0931, -1.5495, -1.4534,  1.6599,  0.3497,  1.5365,
          -1.3019],
         [-0.1945,  0.5511,  0.4614, -1.3758, -1.5687,  0.0930,  0.5842,
           0.7743],
         [ 0.6928,  0.4164,  0.4343, -1.5143,  0.2965,  0.8982, -0.1242,
           1.1810]],

        [[ 0.5049,  0.7949, -0.1415,  0.0268, -0.4009,  0.8883, -0.3975,
          -1.7064],
         [ 0.4681,  0.9660, -0.7651, -1.2775, -1.3326,  0.7993, 