An implementation of Layer Normalization (Ba, Kiros & Hinton, 2016): https://arxiv.org/pdf/1607.06450

$$ \mu^{l} = \frac{1}{H} \sum_{i=1}^{H} a^{l}_{i} \qquad \sigma^{l} = \sqrt{\frac{1}{H} \sum_{i=1}^{H} \left(a_{i}^{l} - \mu^{l}\right)^{2}} $$
where $H$ denotes the number of hidden units in a layer and $a_{i}^{l}$ is the activation of the $i$-th unit in layer $l$. Both statistics are computed separately for each training example, over the units of the layer.
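The code cells normalize each activation with these statistics. They also add a small constant $\epsilon$ (the `eps` argument, `1e-5` by default) inside the square root so the division stays well defined when all activations of a sample are equal. This $\epsilon$ is an implementation detail of the code here rather than part of the definitions above:

$$ \hat{a}^{l}_{i} = \frac{a^{l}_{i} - \mu^{l}}{\sqrt{\left(\sigma^{l}\right)^{2} + \epsilon}} $$

For example, with $H = 2$ and $a^{l} = (1, 3)$ we get $\mu^{l} = 2$ and $\left(\sigma^{l}\right)^{2} = 1$, so $\hat{a}^{l} \approx (-1, 1)$.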

In [None]:
import torch
import torch.nn as nn


BATCH_SIZE = 32
N_FEATURES = 128


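def layernorm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Per-sample statistics over the feature (last) dimension; keepdim=True keeps
    # a trailing size-1 axis so the results broadcast against x.
    activation_mean = torch.mean(x, dim=-1, keepdim=True)
    print(f"Activation mean shape: {activation_mean.shape}")
    # unbiased=False divides by H, matching the 1/H in the definition of sigma.
    activation_variance = torch.var(x, dim=-1, unbiased=False, keepdim=True)
    print(f"Activation variance shape: {activation_variance.shape}")
    # eps guards against division by zero when a sample's variance is ~0.
    x_hat = (x - activation_mean) / torch.sqrt(activation_variance + eps)
    print(f"x_hat shape: {x_hat.shape}")

    return x_hat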


x = torch.rand([BATCH_SIZE, N_FEATURES])
normalized_x = layernorm(x)

Activation mean shape: torch.Size([32, 1])
Activation variance shape: torch.Size([32, 1])
x_hat shape: torch.Size([32, 128])
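As a quick sanity check, the sketch below re-defines `layernorm` without the shape printouts. The seed and tolerances are arbitrary choices, not part of the original notebook. It confirms that each normalized row has a mean close to 0 and a variance close to 1, and it compares the result with PyTorch's `torch.nn.functional.layer_norm` applied over the last dimension without affine parameters. The variance lands slightly below 1 because $\epsilon$ is added to the denominator.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)


def layernorm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Same computation as the cell above, without the shape printouts.
    mean = torch.mean(x, dim=-1, keepdim=True)
    var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


x = torch.rand([32, 128])
x_hat = layernorm(x)

# Each row (one sample) should now have mean ~0 and biased variance ~1.
row_means = x_hat.mean(dim=-1)
row_vars = x_hat.var(dim=-1, unbiased=False)
print(row_means.abs().max().item())
print((row_vars - 1).abs().max().item())

# PyTorch's functional layer norm over the last dimension (no weight or bias)
# should agree up to floating-point error.
reference = F.layer_norm(x, (128,), eps=1e-5)
print(torch.allclose(x_hat, reference, atol=1e-5))
```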


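As in batch normalization, we add two learnable per-feature parameters, $\gamma$ and $\beta$, which scale and shift the normalized tensor: $y^{(k)} = \gamma^{(k)}\widehat{x}^{(k)} + \beta^{(k)}$, where $k$ indexes the features. Unlike batchnorm, layernorm computes its mean and variance for each sample over the feature dimension rather than over the batch, so a sample's output does not depend on the other samples in the batch. As a result, we don't need to track a running mean and variance as batchnorm does, and the same computation is used for training and inference.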
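To make the batch-independence point concrete, here is a small sketch that normalizes the same random input two ways. The first uses layernorm statistics (per sample, over features). The second uses batchnorm-style statistics (per feature, over the batch). The helper names `normalize_features` and `normalize_batch` are hypothetical. The batchnorm helper is a simplified stand-in for training-mode batch normalization, with no running statistics or affine parameters. The last lines apply an illustrative constant $\gamma$ and $\beta$ to show the scale-then-shift step.

```python
import torch

torch.manual_seed(0)

N_FEATURES = 128
eps = 1e-5


def normalize_features(x: torch.Tensor) -> torch.Tensor:
    # Layernorm statistics: one mean/variance per sample, over the feature dim.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


def normalize_batch(x: torch.Tensor) -> torch.Tensor:
    # Batchnorm-style statistics (for contrast): one mean/variance per feature,
    # computed over the batch dim.
    mean = x.mean(dim=0, keepdim=True)
    var = x.var(dim=0, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


x = torch.rand([32, N_FEATURES])
sample = x[:1]  # the first sample on its own, shape [1, N_FEATURES]

# Layernorm gives the first sample the same output with or without the rest
# of the batch, so nothing has to be remembered for inference.
ln_same = torch.allclose(normalize_features(x)[:1], normalize_features(sample))

# The batch-statistics version changes once the other samples are removed.
bn_same = torch.allclose(normalize_batch(x)[:1], normalize_batch(sample))

# Per-feature gain and bias: gamma scales, beta shifts (broadcast over the batch).
gamma = torch.full((N_FEATURES,), 2.0)
beta = torch.full((N_FEATURES,), 0.5)
y = gamma * normalize_features(x) + beta

print(ln_same, bn_same)
print(y.mean(dim=-1)[:3], y.std(dim=-1, unbiased=False)[:3])
```

With constant $\gamma = 2$ and $\beta = 0.5$, each row of `y` ends up with a mean of about 0.5 and a standard deviation of about 2. In a real layer these parameters are learned per feature instead.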

In [None]:
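class LayerNorm(nn.Module):
    def __init__(self, n_features: int, eps: float = 1e-5):
        super().__init__()
        self.n_features = n_features
        self.eps = eps

        # Learnable per-feature gain (gamma) and bias (beta), initialized so the
        # layer starts out as a plain normalization.
        self.gamma = nn.Parameter(torch.ones(self.n_features))
        self.beta = nn.Parameter(torch.zeros(self.n_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean and biased variance of each sample over its last dimension.
        activation_mean = torch.mean(x, dim=-1, keepdim=True)
        activation_var = torch.var(x, dim=-1, unbiased=False, keepdim=True)

        # Normalize, then scale by gamma and shift by beta; both broadcast over
        # any leading (batch, sequence, ...) dimensions.
        x_hat = (x - activation_mean) / torch.sqrt(activation_var + self.eps)
        scaled_and_shifted = self.gamma * x_hat + self.beta

        return scaled_and_shifted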
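A short usage sketch follows. A compact copy of the class is included so the cell runs on its own. The sketch compares the module against PyTorch's built-in `nn.LayerNorm`, whose weight and bias are initialized to ones and zeros just like `gamma` and `beta` here. It then checks that switching to `eval()` changes nothing, since no running statistics are kept. It also confirms that `gamma` and `beta` receive gradients. The seed, tolerance, and the squared-sum loss are arbitrary illustrative choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class LayerNorm(nn.Module):
    # Compact copy of the class above so this cell is self-contained.
    def __init__(self, n_features: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(n_features))
        self.beta = nn.Parameter(torch.zeros(n_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = torch.mean(x, dim=-1, keepdim=True)
        var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta


ours = LayerNorm(128)
reference = nn.LayerNorm(128, eps=1e-5)  # weight starts at ones, bias at zeros

x = torch.rand([32, 128])
out_ours = ours(x)

# Same result as PyTorch's built-in module at initialization.
matches_builtin = torch.allclose(out_ours, reference(x), atol=1e-5)

# No separate inference mode: eval() has no effect on the output.
ours.eval()
matches_eval = torch.allclose(out_ours, ours(x))

# gamma and beta are ordinary parameters and receive gradients.
ours.train()
ours(x).pow(2).sum().backward()
has_grads = ours.gamma.grad is not None and ours.beta.grad is not None

# Normalizing over the last dim also works for [batch, seq_len, features] input.
seq_out = ours(torch.rand([4, 10, 128]))

print(matches_builtin, matches_eval, has_grads, seq_out.shape)
```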