<h1>Layer normalization</h1>

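<p>For each row (i.e. over the trailing dimension(s) given by <code>normalized_shape</code>), subtract the row's mean and divide by the square root of its variance, i.e. its standard deviation. A small epsilon is added to the variance to avoid dividing by 0: $y = \dfrac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}}$.</p>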
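<p>If <code>elementwise_affine</code> is <code>True</code> (the default), the normalized result is then scaled and shifted by a trainable per-element weight $\gamma$ and bias $\beta$: $y \cdot \gamma + \beta$. This is an elementwise affine transform, not a full linear layer.</p>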
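<p>Purpose: keep activations at a stable scale. Each row is re-centered to zero mean, which removes its offset ("bias"), and re-scaled to unit variance.</p>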

In [38]:
import torch
from torch import nn

In [39]:
def manual_layer_norm(x, normalized_shape, eps=1e-5, weight=None, bias=None, verbose=True):
    """Normalize ``x`` over its trailing dimensions, mirroring ``nn.LayerNorm``.

    Each slice over the last ``len(normalized_shape)`` dimensions is shifted to
    zero mean and scaled to unit (biased) variance. The result is then
    optionally multiplied by ``weight`` and shifted by ``bias``.

    Args:
        x: Input tensor.
        normalized_shape: An int or a sequence. Only its length is used here:
            it is the number of trailing dimensions of ``x`` to normalize over.
            An int counts as one dimension, as in ``nn.LayerNorm(embedding_dim)``.
        eps: Added to the variance before the square root to avoid dividing by 0.
        weight: Optional scale, multiplied into the normalized output.
        bias: Optional shift, added after ``weight``.
        verbose: If True (default), print the reduced dims, mean and variance.

    Returns:
        The normalized tensor. It has the shape of ``x``, assuming ``weight`` and
        ``bias`` broadcast against the normalized dims.

    Raises:
        ValueError: If ``normalized_shape`` is empty.
    """
    if isinstance(normalized_shape, int):
        # nn.LayerNorm accepts a bare int; len() on an int would raise TypeError.
        normalized_shape = (normalized_shape,)
    if len(normalized_shape) == 0:
        # An empty dim list does not mean "normalize nothing" for torch
        # reductions (it may reduce over every dim). nn.LayerNorm also requires
        # at least one normalized dimension.
        raise ValueError("normalized_shape must contain at least one dimension")

    dims = list(range(-len(normalized_shape), 0))  # e.g. [-1], or [-3, -2, -1]
    mu = x.mean(dim=dims, keepdim=True)
    # Biased (divide by n) variance, the same estimator LayerNorm uses.
    var = x.var(dim=dims, unbiased=False, keepdim=True)
    if verbose:
        print("dims:", dims)
        print("mean:", mu)
        print("var:", var)

    # keepdim=True above keeps mu/var broadcastable against x.
    x_norm = (x - mu) / torch.sqrt(var + eps)
    if weight is not None:
        x_norm = x_norm * weight
    if bias is not None:
        x_norm = x_norm + bias
    return x_norm

In [40]:
# Variance = average *squared* difference from the mean.
# Note: Tensor.var() defaults to the unbiased estimator (divides by n - 1), which
# is why [1, 0, -1] prints 1.0 below. LayerNorm (and manual_layer_norm, via
# unbiased=False) divides by n instead, which would give 0.0, 0.6667 and 0.2222.
variance_examples = [
    [1.0, 1.0, 1.0],
    [1.0, 0.0, -1.0],
    [1.0, 1.0, 0.0],
]
for values in variance_examples:
    print(torch.tensor(values).var())

tensor(0.)
tensor(1.)
tensor(0.3333)


In [41]:
# Toy input: a batch of 2 "sentences", each 2 tokens long, with 4-dim embeddings.
batch, sentence_length, embedding_dim = 2, 2, 4
embedding_shape = (batch, sentence_length, embedding_dim)

torch.manual_seed(0)  # fixed seed so the printed values are reproducible
embedding = torch.randn(embedding_shape)

print(embedding)

tensor([[[-1.1258, -1.1524, -0.2506, -0.4339],
         [ 0.8487,  0.6920, -0.3160, -2.1152]],

        [[ 0.3223, -1.2633,  0.3500,  0.3081],
         [ 0.1198,  1.2377,  1.1168, -0.2473]]])


In [42]:
# normalized_shape is the shape of the trailing dims to normalize (here just the
# embedding dimension), like nn.LayerNorm(embedding_dim). A value such as [-1]
# only happens to work because manual_layer_norm uses len(normalized_shape).
print(manual_layer_norm(embedding, [embedding_dim]))

dims: [-1]
mean: tensor([[[-0.7407],
         [-0.2226]],

        [[-0.0707],
         [ 0.5567]]])
var: tensor([[[0.1630],
         [1.3937]],

        [[0.4743],
         [0.4037]]])
tensor([[[-0.9539, -1.0196,  1.2137,  0.7598],
         [ 0.9075,  0.7747, -0.0791, -1.6031]],

        [[ 0.5706, -1.7316,  0.6109,  0.5501],
         [-0.6877,  1.0717,  0.8815, -1.2655]]])


In [43]:
# PyTorch's built-in module normalizes over the last `embedding_dim` features.
# A freshly created LayerNorm has an affine weight of ones and a bias of zeros,
# so its output should match manual_layer_norm's.
layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
normalized = layer_norm(embedding)
print(normalized)

tensor([[[-0.9539, -1.0196,  1.2137,  0.7598],
         [ 0.9075,  0.7747, -0.0791, -1.6031]],

        [[ 0.5706, -1.7316,  0.6109,  0.5501],
         [-0.6877,  1.0717,  0.8815, -1.2655]]],
       grad_fn=<NativeLayerNormBackward0>)
