In [1]:
import torch
import torch.nn as nn

In [12]:
# Toy input for exercising the normalization layers below:
# a (batch, channels, height, width) tensor of uniform [0, 1) samples.
# Seeding first makes the sampled values (and every printed statistic) reproducible
# under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

batch_size = 20
num_channels = 10
height = 5
width = 6
a = torch.rand(batch_size, num_channels, height, width)

In [11]:
# Population (biased) variance across the batch axis (dim 0); with keepdim the batch axis
# is kept as size 1, so the shape is (1, num_channels, height, width).
torch.var(a, dim=0, unbiased=False, keepdim=True).shape

tensor([[0.0884, 0.1019, 0.0885, 0.0842, 0.0686, 0.0827, 0.1006, 0.1028, 0.0998,
         0.0741]])

In [None]:
class BatchNorm1d(nn.Module):
    """Batch normalization for ``(N, C)`` or ``(N, C, L)`` input.

    Follows the semantics of ``torch.nn.BatchNorm1d``:
    * training mode normalizes with the biased batch statistics (over the batch
      axis, and the length axis for 3-D input) and updates the running estimates as
      ``running = (1 - momentum) * running + momentum * batch_stat``, where the
      variance fed to ``running_var`` is the unbiased one;
    * eval mode normalizes with ``running_mean`` / ``running_var``.

    Args:
        num_features: number of features / channels ``C``.
        eps: added to the variance for numerical stability.
        momentum: weight of the current batch in the running-statistics update.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1) -> None:
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        # Learnable per-feature affine parameters, shape (1, C).
        self.gamma = nn.Parameter(torch.ones(1, num_features))
        self.beta = nn.Parameter(torch.zeros(1, num_features))
        # Running statistics: saved in state_dict, not trained, used in eval mode.
        self.register_buffer('running_mean', torch.zeros(1, num_features))
        self.register_buffer('running_var', torch.ones(1, num_features))

    def forward(self, x):
        """Normalize ``x`` of shape ``(N, C)`` or ``(N, C, L)``.

        Raises:
            ValueError: if ``x`` is not 2-D/3-D or its channel axis is not ``num_features``.
        """
        if x.dim() not in (2, 3) or x.shape[1] != self.num_features:
            raise ValueError(
                f"expected input of shape (N, {self.num_features}) or "
                f"(N, {self.num_features}, L), got {tuple(x.shape)}"
            )
        reduce_dims = (0,) if x.dim() == 2 else (0, 2)
        # Shape that lets the (1, C) parameters/buffers broadcast against x.
        stat_shape = (1, self.num_features) + (1,) * (x.dim() - 2)

        if self.training:
            mean = x.mean(dim=reduce_dims, keepdim=True)
            var = x.var(dim=reduce_dims, unbiased=False, keepdim=True)
            n = x.numel() // self.num_features  # elements per feature
            # Running stats must not carry the autograd graph of every batch, so update
            # the buffers in place without tracking gradients.
            with torch.no_grad():
                unbiased_var = var * n / (n - 1) if n > 1 else var
                self.running_mean.mul_(1 - self.momentum).add_(
                    self.momentum * mean.reshape(1, -1))
                self.running_var.mul_(1 - self.momentum).add_(
                    self.momentum * unbiased_var.reshape(1, -1))
        else:
            mean = self.running_mean.view(stat_shape)
            var = self.running_var.view(stat_shape)

        x_normalized = (x - mean) / torch.sqrt(var + self.eps)

        out = self.gamma.view(stat_shape) * x_normalized + self.beta.view(stat_shape)
        return out

In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the trailing ``normalized_shape`` axes of the input.

    Matches ``torch.nn.LayerNorm`` with elementwise affine: statistics are the mean and
    biased variance over the last ``len(normalized_shape)`` dimensions (one dimension
    when ``normalized_shape`` is an int).

    Args:
        normalized_shape: int or shape tuple of the trailing dimensions to normalize.
        eps: added to the variance for numerical stability.
    """

    def __init__(self, normalized_shape, eps=1e-5) -> None:
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps

        # Learnable elementwise affine parameters with shape normalized_shape.
        self.gamma = nn.Parameter(torch.ones(normalized_shape))
        self.beta = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        # Reduce over exactly as many trailing axes as the affine parameters span, so a
        # multi-dimensional normalized_shape is normalized jointly (not just the last axis).
        dims = tuple(range(-self.gamma.dim(), 0))
        mean = x.mean(dim=dims, keepdim=True)
        var = x.var(dim=dims, unbiased=False, keepdim=True)

        x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * x_normalized + self.beta
        return out

In [None]:
class InstanceNorm(nn.Module):
    """Instance normalization: each (sample, channel) slice is normalized independently.

    Statistics are the mean and biased variance over all spatial axes (every axis after
    the channel axis), so input of shape ``(N, C, *spatial)`` with at least one spatial
    axis is accepted. The affine parameters are stored with shape ``(1, C, 1, 1)``.

    Args:
        num_features: number of channels ``C``.
        eps: added to the variance for numerical stability.
    """

    def __init__(self, num_features, eps=1e-5) -> None:
        super().__init__()

        self.num_features = num_features
        self.eps = eps

        # Learnable per-channel affine parameters.
        self.gamma = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, num_features, 1, 1))

    def forward(self, x):
        """Normalize ``x`` of shape ``(N, C, *spatial)``.

        Raises:
            ValueError: if ``x`` has no spatial axis.
        """
        if x.dim() < 3:
            raise ValueError(
                f"expected input of shape (N, C, *spatial), got {tuple(x.shape)}")
        reduce_dims = tuple(range(2, x.dim()))
        # Reshape the (1, C, 1, 1) parameters to broadcast for any number of spatial axes.
        affine_shape = (1, self.num_features) + (1,) * (x.dim() - 2)

        mean = x.mean(dim=reduce_dims, keepdim=True)
        var = x.var(dim=reduce_dims, unbiased=False, keepdim=True)

        x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma.view(affine_shape) * x_normalized + self.beta.view(affine_shape)
        return out

In [None]:
class GroupNorm(nn.Module):
    """Group normalization: channels are split into ``num_groups`` contiguous groups and
    each (sample, group) is normalized with its own mean and biased variance.

    Accepts input of shape ``(N, C, *spatial)``. The affine parameters are per-channel
    and stored with shape ``(1, C, 1, 1)``.

    Args:
        num_groups: number of channel groups; must divide ``num_channels``.
        num_channels: number of channels ``C``.
        eps: added to the variance for numerical stability.

    Raises:
        ValueError: if ``num_channels`` is not divisible by ``num_groups``.
    """

    def __init__(self, num_groups, num_channels, eps=1e-5) -> None:
        super().__init__()
        if num_channels % num_groups != 0:
            raise ValueError(
                f"num_channels ({num_channels}) must be divisible by num_groups ({num_groups})")
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps

        # Learnable per-channel affine parameters.
        self.gamma = nn.Parameter(torch.ones(1, num_channels, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1))

    def forward(self, x):
        batch_size, num_channels = x.shape[:2]
        # Flatten each group's channels and spatial positions into one axis. reshape
        # (unlike view) also works for non-contiguous input, e.g. a permuted tensor.
        grouped = x.reshape(batch_size, self.num_groups, -1)

        # compute mean and variance per (sample, group)
        mean = grouped.mean(dim=-1, keepdim=True)
        var = grouped.var(dim=-1, unbiased=False, keepdim=True)

        # normalize within each group and restore the original shape
        x_normalized = ((grouped - mean) / torch.sqrt(var + self.eps)).reshape(x.shape)

        # scale and shift using learned parameters, broadcast over any spatial rank
        affine_shape = (1, num_channels) + (1,) * (x.dim() - 2)
        out = x_normalized * self.gamma.view(affine_shape) + self.beta.view(affine_shape)
        return out
