An implementation of Batch Normalization (Ioffe & Szegedy, 2015): https://arxiv.org/pdf/1502.03167

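$$\widehat{x}^{(k)} = \frac{x^{(k)} - \text{E}[x^{(k)} ]}{\sqrt{\text{Var}[x^{(k)}]}}$$ , where $k$ indexes the dimensions of a $d$-dimensional input $$x = (x^{(1)} \ldots x^{(d)})$$, meaning each dimension is normalized independently, with the mean and variance computed over the batch. In the code below a small constant $\epsilon$ is added to the variance before taking the square root to avoid division by zero.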

In [5]:
import torch
import torch.nn as nn


# Dimensions of the demo input for the 2-D example: rows per batch and
# features per row.
BATCH_SIZE = 32
NUM_FEATURES = 128


def normalize(x: torch.Tensor, eps: float = 1e-5):
    """Standardize every feature column of ``x`` with statistics of the batch.

    ``x`` is laid out as ``[batch_size, n_features]``. The mean and the biased
    variance are reduced over dim 0 (the batch dimension), so each feature is
    centered and scaled independently. ``eps`` is added to the variance before
    the square root to avoid dividing by zero.

    The shapes of the centered and of the normalized tensor are printed as a
    sanity check; the normalized tensor is returned.
    """
    # Reducing over dim 0 collapses all rows into one row of per-feature means.
    per_feature_mean = x.mean(dim=0)
    centered = x - per_feature_mean
    print(f"Numerator tensor shape: {centered.shape}")

    per_feature_std = (x.var(dim=0, unbiased=False) + eps).sqrt()
    standardized = centered / per_feature_std
    print(f"Normalized tensor shape: {standardized.shape}")

    return standardized


# Random demo input of shape [BATCH_SIZE, NUM_FEATURES], uniform in [0, 1).
input_tensor = torch.rand(BATCH_SIZE, NUM_FEATURES)
normalized_tensor = normalize(input_tensor)
# Cell output: normalization is element-wise, so the shape matches the input.
normalized_tensor.shape

Numerator tensor shape: torch.Size([32, 128])
Normalized tensor shape: torch.Size([32, 128])


torch.Size([32, 128])

We then add two learnable parameters per dimension, $\gamma^{(k)}$ (gamma) and $\beta^{(k)}$ (beta), which scale and shift the normalized tensor: $$y^{(k)} = \gamma^{(k)}\widehat{x}^{(k)} + \beta^{(k)} $$.

In [None]:
# Per-feature scale, initialized to ones so that gamma * x_hat starts out equal to x_hat.
gamma = nn.Parameter(torch.ones(NUM_FEATURES))
# Per-feature shift, initialized to zeros so that nothing is added at the start.
beta = nn.Parameter(torch.zeros(NUM_FEATURES))

torch.Size([64])

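Instead of computing the mean and variance of the entire training set after training, we estimate them during training with an exponential moving average (EMA) of the per-batch statistics: $$ \text{EMA} = (1 - \alpha) \cdot \text{old} + \alpha \cdot \text{new} $$, where $\alpha$ is the momentum, "old" is the current running estimate and "new" is the statistic of the current batch.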

In [6]:
class BatchNorm(nn.Module):
    """Batch normalization for 2-D inputs of shape ``[B, n_features]``.

    In training mode every feature is normalized with the mean and the biased
    variance of the current batch, and exponential moving averages (EMA) of
    the batch statistics are updated. In eval mode the stored moving averages
    are used instead. The normalized values are then scaled by ``gamma`` and
    shifted by ``beta``, both learnable per-feature parameters.

    Args:
        n_features: Size of the feature dimension (dim 1 of the input).
        eps: Constant added to the variance before the square root to avoid
            division by zero.
        momentum: Weight of the newest batch statistic in the moving average,
            ``running = (1 - momentum) * running + momentum * batch``.
    """

    def __init__(self, n_features: int, eps: float = 1e-5, momentum: float = 0.1):
        super().__init__()
        self.n_features = n_features
        self.eps = eps
        self.momentum = momentum

        self.gamma = nn.Parameter(torch.ones(self.n_features))
        self.beta = nn.Parameter(torch.zeros(self.n_features))

        # Used to calculate EMA instead of calculating mean and variance of
        # the entire training set after training. Buffers are saved in the
        # state dict and moved with the module, but are not trained.
        self.register_buffer('running_mean', torch.zeros(n_features))
        self.register_buffer('running_var', torch.ones(n_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` per feature, then scale and shift it.

        Args:
            x: Input of shape ``[B, n_features]``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 2-D or its feature dimension does not
                equal ``n_features``.
        """
        # Reject other layouts up front: a 3-D input would otherwise broadcast
        # silently and change the shape of the running-stat buffers.
        if x.dim() != 2 or x.shape[1] != self.n_features:
            raise ValueError(
                f"Expected input of shape [B, {self.n_features}], got {tuple(x.shape)}"
            )

        if self.training:
            batch_mean = torch.mean(x, dim=0)
            batch_var = torch.var(x, dim=0, unbiased=False)

            with torch.no_grad():
                # The running variance stores the unbiased estimate
                # (m / (m - 1) correction), as in the paper and in
                # torch.nn.BatchNorm1d. max() keeps a batch of one finite.
                batch_size = x.shape[0]
                unbiased_var = batch_var * (batch_size / max(batch_size - 1, 1))

                # In-place updates keep the registered buffer objects intact.
                self.running_mean.mul_(1 - self.momentum).add_(batch_mean, alpha=self.momentum)
                self.running_var.mul_(1 - self.momentum).add_(unbiased_var, alpha=self.momentum)

            mean = batch_mean
            var = batch_var
        else:
            mean = self.running_mean
            var = self.running_var

        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        scaled_and_shifted = self.gamma * x_hat + self.beta

        return scaled_and_shifted

In [8]:
# A freshly constructed nn.Module is in training mode, so this call takes the
# `self.training` branch: it normalizes with the batch statistics and updates
# the running averages.
batchnorm = BatchNorm(NUM_FEATURES)

normalized_tensor2 = batchnorm(input_tensor)
normalized_tensor2

tensor([[ 0.6791,  0.9368, -0.8219,  ...,  1.1410,  0.3369, -0.0038],
        [-1.4530, -1.5025, -0.4494,  ...,  0.0060, -1.4156,  0.1819],
        [ 0.6202, -0.5847, -0.9721,  ...,  0.0415,  0.5601,  1.5636],
        ...,
        [ 1.3314, -0.3402, -1.1348,  ..., -0.2988,  1.1874, -1.0064],
        [-1.1594,  1.2680, -0.9108,  ...,  0.2650, -1.0203, -1.4529],
        [ 0.7662, -1.6140,  1.8213,  ...,  0.2648, -1.1049,  0.3148]],
       grad_fn=<AddBackward0>)

# Rebinds `batchnorm` to a new instance, so its running statistics start
# again from their initial values (zeros for the mean, ones for the variance).
batchnorm = BatchNorm(NUM_FEATURES)

In [11]:
import torch
import torch.nn as nn


# Dimensions of the demo input for the image example, laid out as [B, C, H, W].
BATCH_SIZE = 32
NUM_CHANNELS = 3
HEIGHT = 32
WIDTH = 32


def normalize2D(x: torch.Tensor, eps: float = 1e-5):
    """Standardize every channel of a ``[B, C, H, W]`` tensor.

    The mean and the biased variance are reduced over the batch, height and
    width dimensions ``(0, 2, 3)``, giving one statistic per channel.
    ``keepdim=True`` leaves them shaped ``[1, C, 1, 1]`` so they broadcast
    against ``x``. ``eps`` is added to the variance before the square root to
    avoid dividing by zero.

    The shapes of the centered and of the normalized tensor are printed as a
    sanity check; the normalized tensor is returned.
    """
    reduce_dims = (0, 2, 3)  # everything except the channel dimension

    per_channel_mean = x.mean(dim=reduce_dims, keepdim=True)
    centered = x - per_channel_mean
    print(f"Numerator tensor shape: {centered.shape}")

    per_channel_var = x.var(dim=reduce_dims, unbiased=False, keepdim=True)
    per_channel_std = (per_channel_var + eps).sqrt()
    standardized = centered / per_channel_std
    print(f"Normalized tensor shape: {standardized.shape}")

    return standardized


# Random demo image batch of shape [BATCH_SIZE, NUM_CHANNELS, HEIGHT, WIDTH], uniform in [0, 1).
image_input_tensor = torch.rand(BATCH_SIZE, NUM_CHANNELS, HEIGHT, WIDTH)
normalized_image_tensor = normalize2D(image_input_tensor)
# Cell output: normalization is element-wise, so the shape matches the input.
normalized_image_tensor.shape

Numerator tensor shape: torch.Size([32, 3, 32, 32])
Normalized tensor shape: torch.Size([32, 3, 32, 32])


torch.Size([32, 3, 32, 32])

In [None]:
class BatchNorm2D(nn.Module):
    """Batch normalization for 4-D inputs of shape ``[B, C, H, W]``.

    Statistics are computed per channel, reducing over the batch, height and
    width dimensions. In training mode the mean and the biased variance of the
    current batch are used, and exponential moving averages (EMA) of the batch
    statistics are updated. In eval mode the stored moving averages are used
    instead. The normalized values are then scaled by ``gamma`` and shifted by
    ``beta``, both learnable per-channel parameters.

    Args:
        n_channels: Size of the channel dimension (dim 1 of the input).
        eps: Constant added to the variance before the square root to avoid
            division by zero.
        momentum: Weight of the newest batch statistic in the moving average,
            ``running = (1 - momentum) * running + momentum * batch``.
    """

    def __init__(self, n_channels: int, eps: float = 1e-5, momentum: float = 0.1):
        super().__init__()
        self.n_channels = n_channels
        self.eps = eps
        self.momentum = momentum

        self.gamma = nn.Parameter(torch.ones(self.n_channels))
        self.beta = nn.Parameter(torch.zeros(self.n_channels))

        # Used to calculate EMA instead of calculating mean and variance of
        # the entire training set after training. Buffers are saved in the
        # state dict and moved with the module, but are not trained.
        self.register_buffer('running_mean', torch.zeros(n_channels))
        self.register_buffer('running_var', torch.ones(n_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` per channel, then scale and shift it.

        Args:
            x: Input of shape ``[B, n_channels, H, W]``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 4-D or its channel dimension does not
                equal ``n_channels``.
        """
        if x.dim() != 4 or x.shape[1] != self.n_channels:
            raise ValueError(
                f"Expected input of shape [B, {self.n_channels}, H, W], got {tuple(x.shape)}"
            )

        # Per-channel vectors are viewed as [1, C, 1, 1] so they broadcast
        # against [B, C, H, W].
        stat_shape = (1, self.n_channels, 1, 1)

        if self.training:
            batch_mean = torch.mean(x, dim=(0, 2, 3), keepdim=True)
            batch_var = torch.var(x, dim=(0, 2, 3), unbiased=False, keepdim=True)

            with torch.no_grad():
                # The running variance stores the unbiased estimate
                # (m / (m - 1) correction), as in the paper and in
                # torch.nn.BatchNorm2d. m is the number of values reduced per
                # channel; max() keeps a single value per channel finite.
                count = x.shape[0] * x.shape[2] * x.shape[3]
                unbiased_var = batch_var * (count / max(count - 1, 1))

                # reshape(-1) always yields shape [C]; a bare squeeze() would
                # produce a 0-d tensor when there is a single channel.
                # In-place updates keep the registered buffer objects intact.
                self.running_mean.mul_(1 - self.momentum).add_(
                    batch_mean.reshape(-1), alpha=self.momentum
                )
                self.running_var.mul_(1 - self.momentum).add_(
                    unbiased_var.reshape(-1), alpha=self.momentum
                )

            mean = batch_mean
            var = batch_var
        else:
            mean = self.running_mean.view(stat_shape)
            var = self.running_var.view(stat_shape)

        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        scaled_and_shifted = self.gamma.view(stat_shape) * x_hat + self.beta.view(stat_shape)

        return scaled_and_shifted

In [15]:
# A freshly constructed nn.Module is in training mode, so this call takes the
# `self.training` branch and normalizes with per-channel batch statistics.
batchnorm_2d = BatchNorm2D(n_channels=3)
normalized_image_tensor2 = batchnorm_2d(image_input_tensor)
normalized_image_tensor2

tensor([[[[-0.5462,  1.6242, -1.3328,  ...,  1.3732,  1.5921,  0.8680],
          [-0.1094,  0.3836,  0.6358,  ...,  0.1983, -0.6079, -0.7143],
          [ 1.4027,  0.7692, -0.7442,  ...,  0.7702,  1.5141, -0.6277],
          ...,
          [ 0.6075,  1.6465, -0.5005,  ...,  0.7709, -1.3559,  1.0874],
          [ 0.9382, -1.3500,  1.3994,  ..., -1.5106,  1.5770, -0.2618],
          [-1.4785,  1.0848,  0.1401,  ...,  0.4633, -0.3789, -0.5244]],

         [[ 1.5688, -1.2322, -0.6788,  ..., -1.0756,  0.8480, -0.9172],
          [ 0.0134, -0.6316,  1.1542,  ..., -0.6255, -0.1618,  0.0278],
          [ 1.1640,  0.2854, -1.5592,  ..., -0.0286, -0.5903,  0.2796],
          ...,
          [ 1.4502,  1.0023,  1.3632,  ...,  1.7106,  0.2137,  1.4205],
          [ 1.4260,  0.6429,  0.5273,  ..., -0.9019, -1.0301,  1.5476],
          [-0.4451,  0.6664, -1.1424,  ..., -0.8534, -1.5626,  1.2858]],

         [[ 0.0094, -0.3890, -1.0076,  ...,  1.2930, -0.3669, -0.2070],
          [-0.6518,  0.5907,  