In [33]:
import torch
import torch.nn as nn


In [34]:
# Image example
N, C, H, W = 2, 3, 3, 3
torch.manual_seed(42)
input = torch.randn(N, C, H, W)
input

tensor([[[[ 1.9269,  1.4873,  0.9007],
          [-2.1055,  0.6784, -1.2345],
          [-0.0431, -1.6047, -0.7521]],

         [[ 1.6487, -0.3925, -1.4036],
          [-0.7279, -0.5594, -0.7688],
          [ 0.7624,  1.6423, -0.1596]],

         [[-0.4974,  0.4396, -0.7581],
          [ 1.0783,  0.8008,  1.6806],
          [ 1.2791,  1.2964,  0.6105]]],


        [[[ 1.3347, -0.2316,  0.0418],
          [-0.2516,  0.8599, -1.3847],
          [-0.8712, -0.2234,  1.7174]],

         [[ 0.3189, -0.4245, -0.8140],
          [-0.7360, -0.8371, -0.9224],
          [ 1.8113,  0.1606,  0.3672]],

         [[ 0.1754, -1.1845,  1.3835],
          [-1.2024,  0.7078, -1.0759],
          [ 0.5357,  1.1754,  0.5612]]]])

In [35]:
# Normalize over the last three dimensions
# i.e. the channel and spatial dimensions
layer_norm = nn.LayerNorm([C,H,W], elementwise_affine=True)

# (Batch_size, channel, height, width) --> (batch_size, channel, heigh, width)
normalized = layer_norm(input)
normalized

tensor([[[[ 1.5402,  1.1496,  0.6284],
          [-2.0428,  0.4309, -1.2689],
          [-0.2102, -1.5977, -0.8402]],

         [[ 1.2930, -0.5207, -1.4191],
          [-0.8187, -0.6690, -0.8551],
          [ 0.5055,  1.2873, -0.3138]],

         [[-0.6139,  0.2186, -0.8456],
          [ 0.7862,  0.5396,  1.3213],
          [ 0.9646,  0.9800,  0.3705]]],


        [[[ 1.4005, -0.2895,  0.0054],
          [-0.3110,  0.8881, -1.5336],
          [-0.9796, -0.2806,  1.8133]],

         [[ 0.3044, -0.4976, -0.9179],
          [-0.8337, -0.9428, -1.0348],
          [ 1.9147,  0.1336,  0.3566]],

         [[ 0.1496, -1.3176,  1.4531],
          [-1.3370,  0.7241, -1.2004],
          [ 0.5383,  1.2285,  0.5659]]]], grad_fn=<NativeLayerNormBackward0>)

In [37]:
print('Mean and std of input image:')
for i in range(N):
    img_s = input[i].flatten()
    print(f"Image {i}:")
    print(f"Mean: {img_s.mean():.6f}")
    print(f"Std:  {img_s.std():.6f}")

print('Mean and std of input image after layernorm:')
for i in range(N):
    img_stats = normalized[i].flatten()
    print(f"Image {i}:")
    print(f"Mean: {img_stats.mean():.6f}")
    print(f"Std:  {img_stats.std():.6f}")

Mean and std of input image:
Image 0:
Mean: 0.193514
Std:  1.146888
Image 1:
Mean: 0.036716
Std:  0.944496
Mean and std of input image after layernorm:
Image 0:
Mean: -0.000000
Std:  1.019045
Image 1:
Mean: 0.000000
Std:  1.019043
