An implementation of Batch Normalization (Ioffe & Szegedy, 2015): https://arxiv.org/pdf/1502.03167

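$$\widehat{x}^{(k)} = \frac{x^{(k)} - \text{E}[x^{(k)}]}{\sqrt{\text{Var}[x^{(k)}]}}$$, where k indexes the dimensions of a d-dimensional input $$x = (x^{(1)}, \dots, x^{(d)})$$, meaning each dimension is normalized independently, with the expectation and variance computed over the mini-batch. The code below additionally adds a small constant $$\epsilon$$ to the variance inside the square root to avoid division by zero.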

In [3]:
import torch
import torch.nn as nn


# Shape of the demo input built further down in this notebook:
# BATCH_SIZE rows (samples) by NUM_FEATURES columns (features).
BATCH_SIZE = 32
NUM_FEATURES = 128


def normalize(x: torch.Tensor, eps: float = 1e-5):
    """Standardize every feature of ``x`` using statistics taken over dim 0.

    Args:
        x: Input tensor; dim 0 is treated as the batch axis and is the one
            reduced when computing the statistics (the notebook feeds a
            ``(batch_size, n_features)`` tensor).
        eps: Constant added to the variance before the square root, so the
            denominator is never zero for a constant feature.

    Returns:
        ``(x - mean) / sqrt(var + eps)`` where ``mean`` and the biased
        ``var`` are computed along dim 0. Same shape as ``x``.
    """
    # Reducing along dim 0 collapses the batch: one mean / variance per feature.
    feature_mean = x.mean(dim=0)
    feature_var = x.var(dim=0, unbiased=False)

    centered = x - feature_mean
    print(f"Numerator tensor shape: {centered.shape}")

    std = (feature_var + eps).sqrt()
    result = centered / std
    print(f"Normalized tensor shape: {result.shape}")

    return result


# Demo: normalize a random (BATCH_SIZE, NUM_FEATURES) tensor.
input_tensor = torch.rand(BATCH_SIZE, NUM_FEATURES)
normalized_tensor = normalize(input_tensor)
# Bare last expression of the cell: the notebook displays the result's shape,
# which is unchanged by normalization.
normalized_tensor.shape

Numerator tensor shape: torch.Size([32, 128])
Normalized tensor shape: torch.Size([32, 128])


torch.Size([32, 128])

We then add two parameters, gamma and beta, which we use to scale and shift the normalized tensor: $$y^{(k)} = \gamma^{(k)}\widehat{x}^{(k)} + \beta^{(k)} $$.

In [None]:
# Per-feature scale (gamma) and shift (beta). Starting from ones and zeros
# makes y = gamma * x_hat + beta an identity mapping of x_hat initially.
gamma = nn.Parameter(torch.ones(NUM_FEATURES))
beta = nn.Parameter(torch.zeros(NUM_FEATURES))

torch.Size([64])

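We use an exponential moving average (EMA) of the per-batch statistics to estimate the mean and variance of the entire training set: $$ \text{EMA} = (1 - \alpha) \cdot \text{old} + \alpha \cdot \text{new} $$, where $$\alpha$$ is the momentum, "old" is the current running estimate and "new" is the statistic computed on the current batch.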

In [None]:
class BatchNorm(nn.Module):
    """Batch normalization for 2-D inputs of shape ``(batch, n_features)``.

    In training mode every feature is normalized with the mean and biased
    variance of the current mini-batch, and exponential moving averages of
    the batch statistics are maintained in the ``running_mean`` and
    ``running_var`` buffers. In evaluation mode those running estimates are
    used instead, so the output does not depend on the other samples in the
    batch.

    The running variance accumulates the *unbiased* (Bessel-corrected) batch
    variance, matching the inference procedure of the Batch Normalization
    paper and ``torch.nn.BatchNorm1d``.

    Args:
        n_features: Size of dim 1 of the expected input.
        eps: Constant added to the variance before the square root so the
            denominator is never zero.
        momentum: Weight of the newest batch statistic in the moving average,
            ``running = (1 - momentum) * running + momentum * batch``.
    """

    def __init__(self, n_features: int, eps: float = 1e-5, momentum: float = 0.1):
        super().__init__()
        self.n_features = n_features
        self.eps = eps
        self.momentum = momentum

        # Learnable per-feature scale and shift; ones / zeros make the affine
        # step an identity mapping at initialization.
        self.gamma = nn.Parameter(torch.ones(self.n_features))
        self.beta = nn.Parameter(torch.zeros(self.n_features))

        # Running estimates, updated batch by batch as an EMA, so the mean
        # and variance of the entire training set never have to be computed
        # in a separate pass after training.
        self.register_buffer('running_mean', torch.zeros(n_features))
        self.register_buffer('running_var', torch.ones(n_features))

    def _update_running_stats(
        self, batch_mean: torch.Tensor, batch_var: torch.Tensor, batch_size: int
    ) -> None:
        """Fold one batch's statistics into the running buffers, in place."""
        with torch.no_grad():
            if batch_size > 1:
                # Bessel's correction turns the biased batch variance into an
                # unbiased estimate. It is skipped for a single sample, where
                # the factor n / (n - 1) is undefined.
                batch_var = batch_var * (batch_size / (batch_size - 1))

            # copy_ writes into the existing buffers instead of rebinding the
            # attributes, so their identity, dtype and shape stay stable.
            self.running_mean.copy_(
                (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
            )
            self.running_var.copy_(
                (1 - self.momentum) * self.running_var + self.momentum * batch_var
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` per feature, then scale by gamma and shift by beta.

        Args:
            x: Tensor of shape ``(batch, n_features)``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 2-D or dim 1 differs from
                ``n_features``.
        """
        if x.dim() != 2 or x.size(1) != self.n_features:
            # Without this check a (batch, 1) input would silently broadcast
            # against the per-feature parameters and buffers.
            raise ValueError(
                f"Expected input of shape (batch, {self.n_features}), "
                f"got {tuple(x.shape)}"
            )

        if self.training:
            # Normalization itself uses the biased variance of the batch.
            mean = torch.mean(x, dim=0)
            var = torch.var(x, dim=0, unbiased=False)
            self._update_running_stats(mean, var, x.size(0))
        else:
            mean = self.running_mean
            var = self.running_var

        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta