In [None]:
import torch
import torch.nn as nn

class batchNorm(nn.Module):
    """Batch normalization over the channel dimension (dim 1).

    Accepts input of shape ``(N, C)`` or ``(N, C, *)``. Statistics are taken
    per channel over every other dimension.

    In training mode the batch statistics normalize the input, and
    exponential running averages are updated. In eval mode the running
    averages are used instead.

    Args:
        num_features: number of channels ``C``.
        eps: added to the variance for numerical stability.
        momentum: weight of the current batch in the running-average update.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.beta = nn.Parameter(torch.zeros(num_features))
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.eps = eps
        self.momentum = momentum
        # Buffers, not parameters: saved in state_dict() and moved by
        # .to()/.cuda(), but never touched by the optimizer.
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` per channel, then apply the affine ``gamma``/``beta``."""
        # Reduce over every dim except the channel dim (1).
        reduce_dims = (0,) + tuple(range(2, x.dim()))
        # Shape that broadcasts a (C,) vector against x, e.g. (1, C, 1, 1).
        bshape = (1, -1) + (1,) * (x.dim() - 2)

        if self.training:
            mean = x.mean(dim=reduce_dims)
            # Biased variance for normalization (as in the BN paper / PyTorch).
            var = x.var(dim=reduce_dims, unbiased=False)
            # Running stats must not carry autograd history, otherwise the
            # graph grows across iterations and eval outputs depend on it.
            with torch.no_grad():
                count = x.numel() // x.size(1)
                unbiased_var = var * count / max(count - 1, 1)
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var)
        else:
            mean = self.running_mean
            var = self.running_var

        x_hat = (x - mean.view(bshape)) / torch.sqrt(var.view(bshape) + self.eps)
        return self.gamma.view(bshape) * x_hat + self.beta.view(bshape)

In [None]:
class layerNormCNN(nn.Module):
    """Layer normalization for convolutional feature maps.

    Each sample of ``x`` with shape ``(N, C, *spatial)`` (typically
    ``(N, C, H, W)``) is normalized with a single mean/variance computed over
    all of its non-batch dims. A per-channel affine ``gamma``/``beta`` is then
    applied.

    Args:
        num_channels: number of channels ``C``.
        eps: added to the variance for numerical stability.
    """

    def __init__(self, num_channels, eps=1e-5):
        # Must run before assigning nn.Parameter attributes.
        super().__init__()
        self.num_channels = num_channels
        self.eps = eps
        self.beta = nn.Parameter(torch.zeros(num_channels))
        self.gamma = nn.Parameter(torch.ones(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize each sample over (C, *spatial), then scale/shift per channel."""
        reduce_dims = tuple(range(1, x.dim()))
        # keepdim=True so the stats broadcast back against x.
        mean = x.mean(dim=reduce_dims, keepdim=True)
        var = x.var(dim=reduce_dims, keepdim=True, unbiased=False)

        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        # Align the (C,) affine params with dim 1, not the last (width) dim.
        bshape = (1, -1) + (1,) * (x.dim() - 2)
        return self.gamma.view(bshape) * x_hat + self.beta.view(bshape)

In [None]:
class layerNormTransform(nn.Module):
    """Layer normalization over the trailing dimension(s), as used in transformers.

    For input such as ``(B, seq_len, embed)`` with ``normalized_shape=embed``,
    each token vector is normalized over its last dim. ``normalized_shape`` may
    also be a tuple, in which case the last ``len(normalized_shape)`` dims are
    normalized jointly. This matches ``nn.LayerNorm``.

    Args:
        normalized_shape: int or tuple giving the trailing shape to normalize.
        eps: added to the variance for numerical stability.
    """

    def __init__(self, normalized_shape, eps=1e-5):
        # Must run before assigning nn.Parameter attributes.
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.beta = nn.Parameter(torch.zeros(self.normalized_shape))
        self.gamma = nn.Parameter(torch.ones(self.normalized_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize over the trailing ``normalized_shape`` dims, then scale/shift."""
        dims = tuple(range(-len(self.normalized_shape), 0))
        # keepdim=True so the stats broadcast back against x.
        mean = x.mean(dim=dims, keepdim=True)
        var = x.var(dim=dims, keepdim=True, unbiased=False)

        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta

In [None]:
class InstanceNorm(nn.Module):
    """Instance normalization: every (sample, channel) plane is normalized alone.

    Input is expected as ``(N, C, H, W)``. Mean and biased variance are taken
    over the spatial dims ``H`` and ``W`` only, then a per-channel affine
    (``gamma`` scale, ``beta`` shift) is applied.

    Args:
        num_channels: number of channels ``C``.
        eps: added to the variance for numerical stability.
    """

    def __init__(self, num_channels, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(num_channels))
        self.beta = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        spatial_dims = (2, 3)
        mu = x.mean(dim=spatial_dims, keepdim=True)
        sigma_sq = x.var(dim=spatial_dims, unbiased=False, keepdim=True)
        normalized = (x - mu) / torch.sqrt(sigma_sq + self.eps)

        # Lift the (C,) parameters to (1, C, 1, 1) so they broadcast per channel.
        scale = self.gamma[None, :, None, None]
        shift = self.beta[None, :, None, None]
        return scale * normalized + shift


In [None]:
class GroupNorm(nn.Module):
    """Group normalization over ``(N, C, *spatial)`` input.

    The ``C`` channels are split into ``num_groups`` contiguous groups. Each
    (sample, group) slice is normalized with its own mean and biased variance,
    taken over the group's channels and all spatial positions. A per-channel
    affine ``gamma``/``beta`` is then applied.

    Args:
        num_groups: number of groups ``G``; must divide ``num_channels``.
        num_channels: number of channels ``C``.
        eps: added to the variance for numerical stability.

    Raises:
        ValueError: if ``num_channels`` is not divisible by ``num_groups``.
    """

    def __init__(self, num_groups, num_channels, eps=1e-5):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if num_channels % num_groups != 0:
            raise ValueError(
                f"num_channels ({num_channels}) must be divisible by "
                f"num_groups ({num_groups})"
            )
        self.num_groups = num_groups
        self.gamma = nn.Parameter(torch.ones(num_channels))
        self.beta = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize per (sample, group), then scale/shift per channel."""
        N, C = x.shape[0], x.shape[1]
        G = self.num_groups

        # (N, C, *spatial) -> (N, G, C // G, *spatial); reshape (not view)
        # is safe for any input layout.
        grouped = x.reshape(N, G, C // G, *x.shape[2:])
        reduce_dims = tuple(range(2, grouped.dim()))
        mean = grouped.mean(dim=reduce_dims, keepdim=True)
        var = grouped.var(dim=reduce_dims, keepdim=True, unbiased=False)

        grouped = (grouped - mean) / torch.sqrt(var + self.eps)
        x_hat = grouped.reshape(x.shape)

        bshape = (1, -1) + (1,) * (x.dim() - 2)
        return self.gamma.view(bshape) * x_hat + self.beta.view(bshape)
