## Batch Norm
statistic term

NLP: [N, L, C] -> [C]

CV: [N, C, H, W] -> [C]

In [1]:
import torch
import torch.nn as nn

In [2]:
batch_size, time_steps, embedding_dim = 2, 3, 4
eps = 1e-5
num_groups = 2

inputx = torch.randn(batch_size, time_steps, embedding_dim)
inputx  # N * L * C

tensor([[[ 1.8588,  0.8554,  0.5678,  0.6646],
         [ 0.3913, -0.5461,  0.7717,  2.9826],
         [-0.1380,  0.3573,  0.3497,  0.8792]],

        [[ 0.8653,  0.4413, -1.6207,  0.2205],
         [-0.8848,  0.6860, -0.6809,  0.9544],
         [-0.0402, -0.3383,  1.3893,  0.1175]]])

In [3]:
batch_norm_op = nn.BatchNorm1d(embedding_dim, affine=False)
batch_norm_op(inputx.transpose(-1, -2)).transpose(-1, -2)

tensor([[[ 1.7595,  1.1925,  0.4398, -0.3205],
         [ 0.0571, -1.5348,  0.6444,  2.1137],
         [-0.5569,  0.2232,  0.2209, -0.0952]],

        [[ 0.6069,  0.3867, -1.7560, -0.7868],
         [-1.4232,  0.8629, -0.8131, -0.0162],
         [-0.4434, -1.1305,  1.2640, -0.8950]]])

In [4]:
bn_mean = inputx.mean(dim=(0, 1), keepdim=True)
bn_std = inputx.std(dim=(0, 1), keepdim=True, unbiased=False)

(inputx - bn_mean) / (bn_std + eps)

tensor([[[ 1.7595,  1.1925,  0.4398, -0.3205],
         [ 0.0571, -1.5348,  0.6444,  2.1137],
         [-0.5569,  0.2232,  0.2209, -0.0952]],

        [[ 0.6069,  0.3867, -1.7560, -0.7868],
         [-1.4232,  0.8629, -0.8131, -0.0162],
         [-0.4434, -1.1305,  1.2640, -0.8950]]])

## Layer Norm
常用于NLP

statistic term

NLP: [N, L, C] -> [N, L]

CV: [N, C, H, W] -> [N, H, W]

In [5]:
layer_norm_op = nn.LayerNorm(embedding_dim, elementwise_affine=False)
layer_norm_op(inputx)

tensor([[[ 1.6966, -0.2553, -0.8148, -0.6265],
         [-0.3929, -1.1169, -0.0990,  1.6088],
         [-1.3900, -0.0132, -0.0343,  1.4375]],

        [[ 0.9346,  0.4887, -1.6798,  0.2565],
         [-1.1149,  0.8235, -0.8633,  1.1547],
         [-0.4884, -0.9401,  1.6779, -0.2494]]])

In [6]:
ln_mean = inputx.mean(dim=-1, keepdim=True)
ln_std = inputx.std(dim=-1, keepdim=True, unbiased=False)
(inputx - ln_mean) / (ln_std + eps)

tensor([[[ 1.6966, -0.2553, -0.8148, -0.6265],
         [-0.3929, -1.1169, -0.0990,  1.6088],
         [-1.3900, -0.0132, -0.0343,  1.4375]],

        [[ 0.9346,  0.4887, -1.6798,  0.2565],
         [-1.1149,  0.8235, -0.8633,  1.1547],
         [-0.4884, -0.9401,  1.6779, -0.2494]]])

## Instance Norm
常用于风格迁移上

statistic term

NLP: [N, L, C] -> [N, C]

CV: [N, C, H, W] -> [N, C]

In [7]:
ins_norm_op = nn.InstanceNorm1d(embedding_dim)
ins_norm_op(inputx.transpose(-1, -2)).transpose(-1, -2)

tensor([[[ 1.3671,  1.0916,  0.0274, -0.8072],
         [-0.3702, -1.3244,  1.2106,  1.4092],
         [-0.9969,  0.2329, -1.2380, -0.6020]],

        [[ 1.2387,  0.4082, -1.0471, -0.5643],
         [-1.2103,  0.9685, -0.2997,  1.4051],
         [-0.0284, -1.3767,  1.3468, -0.8408]]])

In [8]:
in_mean = inputx.mean(dim=1, keepdim=True)
in_std = inputx.std(dim=1, keepdim=True, unbiased=False)
(inputx - in_mean) / (in_std + eps)

tensor([[[ 1.3671,  1.0916,  0.0274, -0.8072],
         [-0.3702, -1.3244,  1.2107,  1.4092],
         [-0.9969,  0.2329, -1.2382, -0.6020]],

        [[ 1.2387,  0.4082, -1.0471, -0.5643],
         [-1.2103,  0.9685, -0.2997,  1.4051],
         [-0.0284, -1.3767,  1.3468, -0.8408]]])

## Group Norm

statistic term

NLP: [N, G, L, C // G] -> [N, G]

CV: [N, G, C // G, H, W] -> [N, G]

In [9]:
group_norm_op = nn.GroupNorm(num_groups, embedding_dim, affine=False)
group_norm_op(inputx.transpose(-1, -2)).transpose(-1, -2)

tensor([[[ 1.8279,  0.5137, -0.5283, -0.4190],
         [-0.0940, -1.3217, -0.2982,  2.1967],
         [-0.7873, -0.1386, -0.7744, -0.1769]],

        [[ 1.2214,  0.5252, -1.6883,  0.1576],
         [-1.6528,  0.9271, -0.7461,  0.8933],
         [-0.2656, -0.7552,  1.3293,  0.0543]]])

In [10]:
group_inputxs = torch.split(inputx, split_size_or_sections= embedding_dim // num_groups, dim=-1)
results = []
for g_inputx in group_inputxs:
    gn_mean = g_inputx.mean(dim=(1, 2), keepdim=True)
    gn_std = g_inputx.std(dim=(1, 2), keepdim=True, unbiased=False)
    gn_result = (g_inputx - gn_mean) / (gn_std + eps)
    results.append(gn_result)
torch.cat(results, dim=-1)

tensor([[[ 1.8279,  0.5137, -0.5282, -0.4190],
         [-0.0940, -1.3217, -0.2981,  2.1967],
         [-0.7873, -0.1386, -0.7744, -0.1769]],

        [[ 1.2214,  0.5252, -1.6883,  0.1576],
         [-1.6528,  0.9271, -0.7461,  0.8933],
         [-0.2656, -0.7552,  1.3293,  0.0543]]])

## Weight Norm

In [11]:
linear = nn.Linear(embedding_dim, 3, bias=False)
wn_linear = nn.utils.weight_norm(linear)
wn_linear(inputx)

tensor([[[ 0.3235,  0.5420, -0.6963],
         [-1.1378, -0.4317, -0.9247],
         [-0.6015, -0.1038, -0.3725]],

        [[ 0.6753,  1.1038, -0.9312],
         [-0.8022,  0.2282, -0.7469],
         [-0.3648, -0.7307,  0.5381]]], grad_fn=<UnsafeViewBackward0>)

In [12]:
weight_direction = linear.weight / linear.weight.norm(dim=1, keepdim=True)
weight_magnitude = wn_linear.weight_g
inputx @ (weight_direction.transpose(-1, -2)) * (weight_magnitude.transpose(-1, -2))

tensor([[[ 0.3235,  0.5420, -0.6963],
         [-1.1378, -0.4317, -0.9247],
         [-0.6015, -0.1038, -0.3725]],

        [[ 0.6753,  1.1038, -0.9312],
         [-0.8022,  0.2282, -0.7469],
         [-0.3648, -0.7307,  0.5381]]], grad_fn=<MulBackward0>)

关于权重归一化的再次说明

In [13]:
batch_size, feat_dim, hid_dim = 2, 3, 4
inputx = torch.randn(batch_size, feat_dim)
linear = nn.Linear(feat_dim, hid_dim, bias=False)
wn_linear = nn.utils.weight_norm(linear)

In [14]:
weight_magnitude = torch.tensor([linear.weight[i, :].norm() for i in torch.arange(linear.weight.shape[0])], dtype=torch.float32).unsqueeze(-1)

weight_direction = linear.weight / weight_magnitude

print('linear.weight:')
print(linear.weight)

print('weight_magnitude:')
print(weight_magnitude)

print('weight_direction:')
print(weight_direction)

print('magnitude of weight_direction:')
print((weight_direction ** 2).sum(dim=-1))

linear.weight:
tensor([[-0.2551, -0.4248, -0.5152],
        [ 0.2330,  0.1396,  0.2011],
        [-0.2154,  0.2437,  0.5582],
        [-0.1916,  0.1428,  0.0081]], grad_fn=<WeightNormInterfaceBackward0>)
weight_magnitude:
tensor([[0.7148],
        [0.3380],
        [0.6460],
        [0.2391]])
weight_direction:
tensor([[-0.3569, -0.5943, -0.7208],
        [ 0.6895,  0.4129,  0.5951],
        [-0.3335,  0.3772,  0.8640],
        [-0.8014,  0.5972,  0.0338]], grad_fn=<DivBackward0>)
magnitude of weight_direction:
tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)


In [15]:
print('inputx @ (weight_direction * weight_magnitude).T:')
print(inputx @ (weight_direction * weight_magnitude).T)

print('linear(inputx):')
print(linear(inputx))

print('wn_linear(inputx):')
print(wn_linear(inputx))

print('parameters in wn_linear:')
for n, p in wn_linear.named_parameters():
    print(n, p)

print('construct weight of linear:')
print(wn_linear.weight_g * (wn_linear.weight_v /
                            torch.tensor([wn_linear.weight_v[i, :].norm() for i in torch.arange(wn_linear.weight.shape[0])],
                                         dtype=torch.float32).unsqueeze(-1)))

inputx @ (weight_direction * weight_magnitude).T:
tensor([[ 1.0653, -0.4257, -0.5384, -0.1830],
        [-0.7710,  0.1401,  1.1635,  0.3304]], grad_fn=<MmBackward0>)
linear(inputx):
tensor([[ 1.0653, -0.4257, -0.5384, -0.1830],
        [-0.7710,  0.1401,  1.1635,  0.3304]], grad_fn=<MmBackward0>)
wn_linear(inputx):
tensor([[ 1.0653, -0.4257, -0.5384, -0.1830],
        [-0.7710,  0.1401,  1.1635,  0.3304]], grad_fn=<MmBackward0>)
parameters in wn_linear:
weight_g Parameter containing:
tensor([[0.7148],
        [0.3380],
        [0.6460],
        [0.2391]], requires_grad=True)
weight_v Parameter containing:
tensor([[-0.2551, -0.4248, -0.5152],
        [ 0.2330,  0.1396,  0.2011],
        [-0.2154,  0.2437,  0.5582],
        [-0.1916,  0.1428,  0.0081]], requires_grad=True)
construct weight of linear:
tensor([[-0.2551, -0.4248, -0.5152],
        [ 0.2330,  0.1396,  0.2011],
        [-0.2154,  0.2437,  0.5582],
        [-0.1916,  0.1428,  0.0081]], grad_fn=<MulBackward0>)
