## Batch Norm
statistic term

NLP: [N, L, C] -> [C]

CV: [N, C, H, W] -> [C]

In [1]:
import torch
import torch.nn as nn

In [2]:
batch_size, time_steps, embedding_dim = 2, 3, 4
eps = 1e-5
num_groups = 2

inputx = torch.randn(batch_size, time_steps, embedding_dim)
inputx  # N * L * C

tensor([[[-1.0911, -0.4127,  0.8924,  0.0615],
         [-0.5815, -1.5187, -0.3445,  0.1576],
         [-1.5346,  1.6328,  0.6028, -0.6954]],

        [[ 0.2712,  0.1874,  0.7269,  0.0604],
         [ 0.5891,  0.3546,  0.3073,  0.1127],
         [ 1.2444,  0.3722, -2.3264,  0.9483]]])

In [3]:
batch_norm_op = nn.BatchNorm1d(embedding_dim, affine=False)
batch_norm_op(inputx.transpose(-1, -2)).transpose(-1, -2)

tensor([[[-0.9350, -0.5442,  0.8298, -0.0968],
         [-0.4098, -1.7123, -0.2907,  0.1052],
         [-1.3921,  1.6161,  0.5675, -1.6874]],

        [[ 0.4688,  0.0896,  0.6799, -0.0991],
         [ 0.7965,  0.2661,  0.2998,  0.0110],
         [ 1.4716,  0.2848, -2.0863,  1.7670]]])

In [4]:
bn_mean = inputx.mean(dim=(0, 1), keepdim=True)
bn_std = inputx.std(dim=(0, 1), keepdim=True, unbiased=False)

(inputx - bn_mean) / (bn_std + eps)

tensor([[[-0.9350, -0.5442,  0.8298, -0.0968],
         [-0.4098, -1.7123, -0.2907,  0.1052],
         [-1.3921,  1.6161,  0.5675, -1.6874]],

        [[ 0.4688,  0.0896,  0.6799, -0.0991],
         [ 0.7965,  0.2661,  0.2998,  0.0110],
         [ 1.4716,  0.2848, -2.0863,  1.7670]]])

## Layer Norm
常用于NLP

statistic term

NLP: [N, L, C] -> [N, L]

CV: [N, C, H, W] -> [N, H, W]

In [5]:
layer_norm_op = nn.LayerNorm(embedding_dim, elementwise_affine=False)
layer_norm_op(inputx)

tensor([[[-1.3208, -0.3811,  1.4263,  0.2756],
         [-0.0160, -1.5565,  0.3736,  1.1989],
         [-1.2682,  1.3469,  0.4966, -0.5753]],

        [[-0.1603, -0.4936,  1.6529, -0.9990],
         [ 1.4636,  0.0804, -0.1983, -1.3457],
         [ 0.8386,  0.2213, -1.6888,  0.6290]]])

In [6]:
ln_mean = inputx.mean(dim=-1, keepdim=True)
ln_std = inputx.std(dim=-1, keepdim=True, unbiased=False)
(inputx - ln_mean) / (ln_std + eps)

tensor([[[-1.3208, -0.3811,  1.4263,  0.2756],
         [-0.0160, -1.5565,  0.3736,  1.1989],
         [-1.2682,  1.3469,  0.4966, -0.5753]],

        [[-0.1603, -0.4937,  1.6529, -0.9990],
         [ 1.4637,  0.0804, -0.1983, -1.3458],
         [ 0.8386,  0.2213, -1.6888,  0.6290]]])

## Instance Norm
常用于风格迁移上

statistic term

NLP: [N, L, C] -> [N, C]

CV: [N, C, H, W] -> [N, C]

In [7]:
ins_norm_op = nn.InstanceNorm1d(embedding_dim)
ins_norm_op(inputx.transpose(-1, -2)).transpose(-1, -2)

tensor([[[-0.0566, -0.2399,  0.9633,  0.5774],
         [ 1.2520, -1.0871, -1.3783,  0.8293],
         [-1.1954,  1.3269,  0.4151, -1.4067]],

        [[-1.0622, -1.4079,  0.8566, -0.7705],
         [-0.2775,  0.5981,  0.5462, -0.6418],
         [ 1.3396,  0.8098, -1.4028,  1.4122]]])

In [8]:
in_mean = inputx.mean(dim=1, keepdim=True)
in_std = inputx.std(dim=1, keepdim=True, unbiased=False)
(inputx - in_mean) / (in_std + eps)

tensor([[[-0.0566, -0.2399,  0.9633,  0.5774],
         [ 1.2520, -1.0871, -1.3783,  0.8293],
         [-1.1954,  1.3269,  0.4151, -1.4067]],

        [[-1.0622, -1.4087,  0.8566, -0.7705],
         [-0.2775,  0.5985,  0.5462, -0.6418],
         [ 1.3396,  0.8103, -1.4028,  1.4122]]])

## Group Norm

statistic term

NLP: [N, G, L, C // G] -> [N, G]

CV: [N, G, C // G, H, W] -> [N, G]

In [9]:
group_norm_op = nn.GroupNorm(num_groups, embedding_dim, affine=False)
group_norm_op(inputx.transpose(-1, -2)).transpose(-1, -2)

tensor([[[-0.4700,  0.1592,  1.4590, -0.0952],
         [ 0.0026, -0.8664, -0.8546,  0.0845],
         [-0.8812,  2.0558,  0.9173, -1.5110]],

        [[-0.6563, -0.8933,  0.7020,  0.0826],
         [ 0.2433, -0.4203,  0.3121,  0.1312],
         [ 2.0971, -0.3704, -2.1358,  0.9078]]])

In [10]:
group_inputxs = torch.split(inputx, split_size_or_sections= embedding_dim // num_groups, dim=-1)
results = []
for g_inputx in group_inputxs:
    gn_mean = g_inputx.mean(dim=(1, 2), keepdim=True)
    gn_std = g_inputx.std(dim=(1, 2), keepdim=True, unbiased=False)
    gn_result = (g_inputx - gn_mean) / (gn_std + eps)
    results.append(gn_result)
torch.cat(results, dim=-1)

tensor([[[-0.4700,  0.1592,  1.4590, -0.0952],
         [ 0.0026, -0.8664, -0.8546,  0.0845],
         [-0.8812,  2.0558,  0.9173, -1.5110]],

        [[-0.6563, -0.8933,  0.7020,  0.0826],
         [ 0.2433, -0.4203,  0.3121,  0.1312],
         [ 2.0971, -0.3704, -2.1358,  0.9078]]])

## Weight Norm

In [11]:
linear = nn.Linear(embedding_dim, 3, bias=False)
wn_linear = nn.utils.weight_norm(linear)
wn_linear(inputx)

tensor([[[-0.1737,  0.1331,  0.2459],
         [ 0.5450,  0.6517, -0.0767],
         [-1.0854, -0.2878,  0.9288]],

        [[-0.0707, -0.3230, -0.1982],
         [-0.0475, -0.3142, -0.2021],
         [ 0.1083,  0.4293,  0.2600]]], grad_fn=<UnsafeViewBackward0>)

In [12]:
weight_direction = linear.weight / linear.weight.norm(dim=1, keepdim=True)
weight_magnitude = wn_linear.weight_g
inputx @ (weight_direction.transpose(-1, -2)) * (weight_magnitude.transpose(-1, -2))

tensor([[[-0.1737,  0.1331,  0.2459],
         [ 0.5450,  0.6517, -0.0767],
         [-1.0854, -0.2878,  0.9288]],

        [[-0.0707, -0.3230, -0.1982],
         [-0.0475, -0.3142, -0.2021],
         [ 0.1083,  0.4293,  0.2600]]], grad_fn=<MulBackward0>)