## Batch Norm
statistic term

NLP: [N, L, C] -> [C]

CV: [N, C, H, W] -> [C]

In [1]:
import torch
import torch.nn as nn

In [2]:
batch_size, time_steps, embedding_dim = 2, 3, 4
eps = 1e-5
num_groups = 2

inputx = torch.randn(batch_size, time_steps, embedding_dim)
inputx  # N * L * C

tensor([[[ 0.1046, -1.4737,  0.4315, -0.9118],
         [ 1.2951, -1.8426,  0.0324,  0.3267],
         [ 0.6766,  1.6677, -0.0437,  1.6402]],

        [[ 1.2373,  0.7291, -1.7653, -0.2277],
         [ 0.5288, -0.5725,  0.8231, -1.0805],
         [ 1.1818,  0.2117,  1.8918, -0.1694]]])

In [3]:
batch_norm_op = nn.BatchNorm1d(embedding_dim, affine=False)
batch_norm_op(inputx.transpose(-1, -2)).transpose(-1, -2)

tensor([[[-1.6766, -1.0302,  0.1849, -0.9362],
         [ 1.0473, -1.3318, -0.1782,  0.4419],
         [-0.3679,  1.5377, -0.2475,  1.9033]],

        [[ 0.9150,  0.7704, -1.8137, -0.1750],
         [-0.7059, -0.2936,  0.5411, -1.1238],
         [ 0.7880,  0.3475,  1.5134, -0.1101]]])

In [4]:
bn_mean = inputx.mean(dim=(0, 1), keepdim=True)
bn_std = inputx.std(dim=(0, 1), keepdim=True, unbiased=False)

(inputx - bn_mean) / (bn_std + eps)

tensor([[[-1.6766, -1.0302,  0.1849, -0.9362],
         [ 1.0473, -1.3318, -0.1782,  0.4419],
         [-0.3679,  1.5377, -0.2475,  1.9033]],

        [[ 0.9150,  0.7704, -1.8137, -0.1750],
         [-0.7059, -0.2936,  0.5411, -1.1238],
         [ 0.7880,  0.3475,  1.5134, -0.1101]]])

## Layer Norm
常用于NLP

statistic term

NLP: [N, L, C] -> [N, L]

CV: [N, C, H, W] -> [N, H, W]

In [5]:
layer_norm_op = nn.LayerNorm(embedding_dim, elementwise_affine=False)
layer_norm_op(inputx)

tensor([[[ 0.7404, -1.3208,  1.1674, -0.5870],
         [ 1.1804, -1.5791,  0.0699,  0.3288],
         [-0.4312,  0.9537, -1.4377,  0.9152]],

        [[ 1.0878,  0.6434, -1.5379, -0.1933],
         [ 0.7751, -0.6380,  1.1527, -1.2898],
         [ 0.4975, -0.7007,  1.3745, -1.1714]]])

In [6]:
ln_mean = inputx.mean(dim=-1, keepdim=True)
ln_std = inputx.std(dim=-1, keepdim=True, unbiased=False)
(inputx - ln_mean) / (ln_std + eps)

tensor([[[ 0.7404, -1.3208,  1.1674, -0.5870],
         [ 1.1804, -1.5791,  0.0699,  0.3288],
         [-0.4312,  0.9536, -1.4377,  0.9152]],

        [[ 1.0878,  0.6434, -1.5379, -0.1933],
         [ 0.7751, -0.6380,  1.1527, -1.2898],
         [ 0.4975, -0.7007,  1.3745, -1.1714]]])

## Instance Norm
常用于风格迁移上

statistic term

NLP: [N, L, C] -> [N, C]

CV: [N, C, H, W] -> [N, C]

In [7]:
ins_norm_op = nn.InstanceNorm1d(embedding_dim)
ins_norm_op(inputx.transpose(-1, -2)).transpose(-1, -2)

tensor([[[-1.2085, -0.5868,  1.3983, -1.2126],
         [ 1.2404, -0.8210, -0.5166, -0.0240],
         [-0.0319,  1.4077, -0.8817,  1.2365]],

        [[ 0.7916,  1.1331, -1.3559,  0.6359],
         [-1.4106, -1.2994,  0.3299, -1.4119],
         [ 0.6191,  0.1663,  1.0260,  0.7760]]])

In [8]:
in_mean = inputx.mean(dim=1, keepdim=True)
in_std = inputx.std(dim=1, keepdim=True, unbiased=False)
(inputx - in_mean) / (in_std + eps)

tensor([[[-1.2085, -0.5868,  1.3983, -1.2126],
         [ 1.2404, -0.8210, -0.5166, -0.0240],
         [-0.0319,  1.4077, -0.8817,  1.2365]],

        [[ 0.7916,  1.1331, -1.3559,  0.6359],
         [-1.4107, -1.2994,  0.3299, -1.4119],
         [ 0.6191,  0.1663,  1.0260,  0.7760]]])

## Group Norm

statistic term

NLP: [N, G, L, C // G] -> [N, G]

CV: [N, G, C // G, H, W] -> [N, G]

In [9]:
group_norm_op = nn.GroupNorm(num_groups, embedding_dim, affine=False)
group_norm_op(inputx.transpose(-1, -2)).transpose(-1, -2)

tensor([[[ 0.0252, -1.1699,  0.2446, -1.5256],
         [ 0.9267, -1.4493, -0.2814,  0.1066],
         [ 0.4583,  1.2089, -0.3817,  1.8375]],

        [[ 1.1109,  0.2862, -1.4031, -0.1169],
         [-0.0388, -1.8259,  0.7621, -0.8302],
         [ 1.0209, -0.5533,  1.6561, -0.0681]]])

In [10]:
group_inputxs = torch.split(inputx, split_size_or_sections= embedding_dim // num_groups, dim=-1)
results = []
for g_inputx in group_inputxs:
    gn_mean = g_inputx.mean(dim=(1, 2), keepdim=True)
    gn_std = g_inputx.std(dim=(1, 2), keepdim=True, unbiased=False)
    gn_result = (g_inputx - gn_mean) / (gn_std + eps)
    results.append(gn_result)
torch.cat(results, dim=-1)

tensor([[[ 0.0252, -1.1699,  0.2446, -1.5256],
         [ 0.9267, -1.4492, -0.2814,  0.1066],
         [ 0.4583,  1.2089, -0.3817,  1.8375]],

        [[ 1.1109,  0.2862, -1.4031, -0.1169],
         [-0.0388, -1.8259,  0.7621, -0.8302],
         [ 1.0209, -0.5533,  1.6561, -0.0681]]])

## Weight Norm

In [11]:
linear = nn.Linear(embedding_dim, 3, bias=False)
wn_linear = nn.utils.weight_norm(linear)
wn_linear(inputx)

tensor([[[ 0.6324, -0.6381, -0.1893],
         [ 1.1542, -1.0547,  0.0614],
         [-0.5847,  0.3673,  0.4489]],

        [[ 0.5107,  0.0457,  0.1138],
         [ 0.3587, -0.6211, -0.0026],
         [-0.0274, -0.7776,  0.3750]]], grad_fn=<UnsafeViewBackward0>)

In [12]:
weight_direction = linear.weight / linear.weight.norm(dim=1, keepdim=True)
weight_magnitude = wn_linear.weight_g
inputx @ (weight_direction.transpose(-1, -2)) * (weight_magnitude.transpose(-1, -2))

tensor([[[ 0.6324, -0.6381, -0.1893],
         [ 1.1542, -1.0547,  0.0614],
         [-0.5847,  0.3673,  0.4489]],

        [[ 0.5107,  0.0457,  0.1138],
         [ 0.3587, -0.6211, -0.0026],
         [-0.0274, -0.7776,  0.3750]]], grad_fn=<MulBackward0>)

关于权重归一化的再次说明

In [13]:
batch_size, feat_dim, hid_dim = 2, 3, 4
inputx = torch.randn(batch_size, feat_dim)
linear = nn.Linear(feat_dim, hid_dim, bias=False)
wn_linear = nn.utils.weight_norm(linear)

In [14]:
weight_magnitude = torch.tensor([linear.weight[i, :].norm() for i in torch.arange(linear.weight.shape[0])], dtype=torch.float32).unsqueeze(-1)

weight_direction = linear.weight / weight_magnitude

print('linear.weight:')
print(linear.weight)

print('weight_magnitude:')
print(weight_magnitude)

print('weight_direction:')
print(weight_direction)

print('magnitude of weight_direction:')
print((weight_direction ** 2).sum(dim=-1))

linear.weight:
tensor([[ 0.4706,  0.3531, -0.3592],
        [ 0.5306, -0.2951, -0.1585],
        [-0.5772, -0.1889,  0.0781],
        [ 0.3203, -0.5723,  0.1163]], grad_fn=<WeightNormInterfaceBackward0>)
weight_magnitude:
tensor([[0.6893],
        [0.6276],
        [0.6124],
        [0.6661]])
weight_direction:
tensor([[ 0.6827,  0.5123, -0.5211],
        [ 0.8456, -0.4703, -0.2526],
        [-0.9426, -0.3084,  0.1276],
        [ 0.4809, -0.8592,  0.1746]], grad_fn=<DivBackward0>)
magnitude of weight_direction:
tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)


In [15]:
print('inputx @ (weight_direction * weight_magnitude).T:')
print(inputx @ (weight_direction * weight_magnitude).T)

print('linear(inputx):')
print(linear(inputx))

print('wn_linear(inputx):')
print(wn_linear(inputx))

print('parameters in wn_linear:')
for n, p in wn_linear.named_parameters():
    print(n, p)

print('construct weight of linear:')
print(wn_linear.weight_g * (wn_linear.weight_v /
                            torch.tensor([wn_linear.weight_v[i, :].norm() for i in torch.arange(wn_linear.weight.shape[0])],
                                         dtype=torch.float32).unsqueeze(-1)))

inputx @ (weight_direction * weight_magnitude).T:
tensor([[-0.0750,  0.0486, -0.0290,  0.1244],
        [-0.5817, -0.3314,  0.5547, -0.0039]], grad_fn=<MmBackward0>)
linear(inputx):
tensor([[-0.0750,  0.0486, -0.0290,  0.1244],
        [-0.5817, -0.3314,  0.5547, -0.0039]], grad_fn=<MmBackward0>)
wn_linear(inputx):
tensor([[-0.0750,  0.0486, -0.0290,  0.1244],
        [-0.5817, -0.3314,  0.5547, -0.0039]], grad_fn=<MmBackward0>)
parameters in wn_linear:
weight_g Parameter containing:
tensor([[0.6893],
        [0.6276],
        [0.6124],
        [0.6661]], requires_grad=True)
weight_v Parameter containing:
tensor([[ 0.4706,  0.3531, -0.3592],
        [ 0.5306, -0.2951, -0.1585],
        [-0.5772, -0.1889,  0.0781],
        [ 0.3203, -0.5723,  0.1163]], requires_grad=True)
construct weight of linear:
tensor([[ 0.4706,  0.3531, -0.3592],
        [ 0.5306, -0.2951, -0.1585],
        [-0.5772, -0.1889,  0.0781],
        [ 0.3203, -0.5723,  0.1163]], grad_fn=<MulBackward0>)


In [16]:
conv1d = nn.Conv1d(feat_dim, hid_dim, kernel_size=1, bias=False)
wn_conv1d = nn.utils.weight_norm(conv1d)

In [17]:
conv1d_weight_magnitude = torch.tensor([conv1d.weight[i, :, :].norm()
                                        for i in torch.arange(conv1d.weight.shape[0])], dtype=torch.float32)
conv1d_weight_direction = conv1d.weight / conv1d_weight_magnitude

print('parameters of wn_conv1d:')
for n, p in wn_conv1d.named_parameters():
    print(n, p, p.shape)

print('construct weight of conv1d:')
print(wn_conv1d.weight_g * (wn_conv1d.weight_v / torch.tensor([wn_conv1d.weight_v[i, :, :].norm()
                                                               for i in torch.arange(conv1d.weight.shape[0])],
                                                               dtype=torch.float32)))

print('conv1d.weight:')
print(conv1d.weight)

print('conv1d_weight_magnitude:')
print(conv1d_weight_magnitude)

print('conv1d_weight_direction:')
print(conv1d_weight_direction)

parameters of wn_conv1d:
weight_g Parameter containing:
tensor([[[0.7590]],

        [[0.6893]],

        [[0.6723]],

        [[0.7030]]], requires_grad=True) torch.Size([4, 1, 1])
weight_v Parameter containing:
tensor([[[ 0.3721],
         [-0.4501],
         [ 0.4847]],

        [[-0.4077],
         [ 0.4126],
         [-0.3724]],

        [[-0.4312],
         [ 0.5121],
         [ 0.0611]],

        [[-0.3843],
         [-0.4427],
         [-0.3879]]], requires_grad=True) torch.Size([4, 3, 1])
construct weight of conv1d:
tensor([[[ 0.3721,  0.4097,  0.4201,  0.4017],
         [-0.4501, -0.4956, -0.5082, -0.4860],
         [ 0.4847,  0.5337,  0.5473,  0.5233]],

        [[-0.3703, -0.4077, -0.4181, -0.3998],
         [ 0.3747,  0.4126,  0.4230,  0.4045],
         [-0.3383, -0.3724, -0.3819, -0.3652]],

        [[-0.3820, -0.4206, -0.4312, -0.4124],
         [ 0.4536,  0.4994,  0.5121,  0.4897],
         [ 0.0542,  0.0596,  0.0611,  0.0585]],

        [[-0.3560, -0.3920, -0.4019, -0.