In [52]:
import torch.nn as nn
import torch

# Input: batch size=2, dimension=4
BN = nn.BatchNorm2d(4, momentum=0.001)
LN = nn.LayerNorm([4, 2, 2])
IN = nn.InstanceNorm2d(4, momentum=0.001, track_running_stats=True)
GN = nn.GroupNorm(2, 4)

In [53]:
with torch.no_grad():
    bns = []
    lns = []
    ins = []
    gns = []
    # Input: batch size=2, dimension=4
    for _ in range(1000):
        X = torch.randn(2, 4, 2, 2) * torch.Tensor([1, 2, 3, 4]).view([1, 4, 1, 1])
        X = X + torch.Tensor([[[[0.25]]], [[[0.75]]]])
        bns.append(BN(X))
        lns.append(LN(X))
        ins.append(IN(X))
        gns.append(GN(X))

In [54]:
# Running Stats: (1, 4)
# Mean: (0.25 + 0.75) / 2, std should = [1, 2, 3, 4]
print("BN:", BN.running_mean, BN.running_var**0.5)
def result(x):
    x = torch.concat([b.view([-1, *b.shape]) for b in x])
    return torch.mean(x, 0)[:, :, 0, 0], torch.std(x, 0)[:, :, 0, 0]

print("BN result:\n", *result(bns), sep='\n')

BN: tensor([0.3238, 0.2936, 0.3012, 0.2675]) tensor([1.0145, 1.6870, 2.4865, 3.2944])
BN result:

tensor([[-0.2215, -0.1286, -0.1157, -0.1002],
        [ 0.2815,  0.1597,  0.1325,  0.1138]])
tensor([[0.9609, 0.9979, 1.0320, 0.9907],
        [0.9550, 0.9651, 0.9872, 0.9775]])


In [55]:
# Running Stats: (2, 4)
# Normalize for each sample -> all output should be (0, 1)
print("IN:", IN.running_mean, IN.running_var**0.5)
print("IN result:\n", *result(ins), sep='\n')

IN: tensor([0.3238, 0.2936, 0.3012, 0.2675]) tensor([0.9908, 1.6743, 2.4750, 3.2985])
IN result:

tensor([[ 0.0358,  0.0184, -0.0253, -0.0375],
        [ 0.0144,  0.0089,  0.0657,  0.0599]])
tensor([[0.9902, 0.9971, 1.0313, 0.9848],
        [0.9996, 0.9877, 0.9990, 0.9908]])


In [56]:
# Running Stats: (2, 1)
# Output Mean: Should all be 0. (Sample 1=0.25-0.25, Sample 2=0.75-0.75)
# Output Std: each sample has same variance
print("LN result:\n", *result(lns), sep='\n')

LN result:

tensor([[ 0.0317,  0.0131, -0.0292, -0.0660],
        [ 0.0254,  0.0145,  0.0668,  0.0551]])
tensor([[0.4551, 0.8051, 1.1432, 1.3767],
        [0.4830, 0.7494, 1.1000, 1.3757]])


In [57]:
# Group = 1 -> Layer Norm
# Group = 4 -> Instance Norm
# Group = 2 -> Split 2 blocks and do Layer Norm
print("GN result:\n", *result(gns), sep='\n')

GN result:

tensor([[ 0.0317,  0.0131, -0.0292, -0.0660],
        [ 0.0254,  0.0145,  0.0668,  0.0551]])
tensor([[0.4551, 0.8051, 1.1432, 1.3767],
        [0.4830, 0.7494, 1.1000, 1.3757]])
