An implementation of the normalization step from the Batch Normalization paper: https://arxiv.org/pdf/1502.03167

$$\widehat{x}^{(k)} = \frac{x^{(k)} - \text{E}[x^{(k)}]}{\sqrt{\text{Var}[x^{(k)}]}}$$, where $$k$$ indexes the dimensions of a d-dimensional input $$x = (x^{(1)}, \dots, x^{(d)})$$, meaning each dimension is normalized independently.

In [None]:
import torch
import torch.nn as nn


# Shape of the demo input tensor: (BATCH_SIZE, NUM_FEATURES).
BATCH_SIZE, NUM_FEATURES = 32, 128


def normalize(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Normalize each feature of ``x`` using statistics taken over dim 0.

    Every column is shifted by its mean and divided by the square root of
    its biased variance plus ``eps``. The shapes of the denominator and of
    the result are printed as a debugging aid.

    Args:
        x: Input tensor, expected as ``(batch_size, n_features)``.
        eps: Small constant added to the variance before the square root so
            the division stays finite when a feature has zero variance.

    Returns:
        A tensor with the same shape as ``x`` holding the normalized values.
    """
    # Reduce over dim 0 (the batch) so each feature gets its own statistics;
    # the resulting per-feature row broadcasts back against every sample.
    numerator = x - torch.mean(x, dim=0)
    # The denominator must be computed before it is reported: printing it
    # first raised UnboundLocalError on every call.
    denominator = torch.sqrt(torch.var(x, dim=0, unbiased=False) + eps)
    print(f"Denominator tensor shape: {denominator.shape}")
    normalized_tensor = numerator / denominator
    print(f"Normalized tensor shape: {normalized_tensor.shape}")

    return normalized_tensor


# Demo: draw a random (BATCH_SIZE, NUM_FEATURES) batch and normalize it.
input_tensor = torch.rand((BATCH_SIZE, NUM_FEATURES))
normalized_tensor = normalize(x=input_tensor)
# Trailing bare expression so the notebook displays the resulting shape.
normalized_tensor.shape

Denominator tensor shape: torch.Size([32, 128])
Normalized tensor shape: torch.Size([32, 128])


torch.Size([32, 128])