In [1]:
import torch
import torch.nn as nn

In [5]:
# Seed first so the random batch below is identical on every run.
torch.manual_seed(42)

# Toy batch layout: N images, C channels, H x W pixels each.
N, C, H, W = 2, 3, 4, 4

# NOTE: `input` shadows the Python builtin of the same name; the name is kept
# because later cells refer to it.
input = torch.randn((N, C, H, W))
input

tensor([[[[ 1.9269,  1.4873,  0.9007, -2.1055],
          [ 0.6784, -1.2345, -0.0431, -1.6047],
          [-0.7521,  1.6487, -0.3925, -1.4036],
          [-0.7279, -0.5594, -0.7688,  0.7624]],

         [[ 1.6423, -0.1596, -0.4974,  0.4396],
          [-0.7581,  1.0783,  0.8008,  1.6806],
          [ 1.2791,  1.2964,  0.6105,  1.3347],
          [-0.2316,  0.0418, -0.2516,  0.8599]],

         [[-1.3847, -0.8712, -0.2234,  1.7174],
          [ 0.3189, -0.4245,  0.3057, -0.7746],
          [-1.5576,  0.9956, -0.8798, -0.6011],
          [-1.2742,  2.1228, -1.2347, -0.4879]]],


        [[[-0.9138, -0.6581,  0.0780,  0.5258],
          [-0.4880,  1.1914, -0.8140, -0.7360],
          [-1.4032,  0.0360, -0.0635,  0.6756],
          [-0.0978,  1.8446, -1.1845,  1.3835]],

         [[ 1.4451,  0.8564,  2.2181,  0.5232],
          [ 0.3466, -0.1973, -1.0546,  1.2780],
          [-0.1722,  0.5238,  0.0566,  0.4263],
          [ 0.5750, -0.6417, -2.2064, -0.7508]],

         [[ 0.0109, -0.3387,

In [7]:

# One BatchNorm2d feature per channel. affine=True (the default) attaches a
# learnable per-channel scale and shift, which start out as 1 and 0.
# A freshly constructed module is in training mode, so this forward pass
# normalizes with the statistics of the current batch.
batch_norm = nn.BatchNorm2d(num_features=C, affine=True)
normalized = batch_norm(input)
normalized

tensor([[[[ 1.8870,  1.4753,  0.9259, -1.8897],
          [ 0.7177, -1.0740,  0.0420, -1.4206],
          [-0.6221,  1.6265, -0.2853, -1.2323],
          [-0.5994, -0.4417, -0.6378,  0.7964]],

         [[ 1.3543, -0.5901, -0.9546,  0.0565],
          [-1.2360,  0.7457,  0.4463,  1.3957],
          [ 0.9624,  0.9811,  0.2409,  1.0224],
          [-0.6678, -0.3728, -0.6893,  0.5100]],

         [[-1.2570, -0.7293, -0.0636,  1.9308],
          [ 0.4937, -0.2703,  0.4801, -0.6300],
          [-1.4346,  1.1891, -0.7381, -0.4518],
          [-1.1434,  2.3474, -1.1028, -0.3354]]],


        [[[-0.7736, -0.5341,  0.1554,  0.5748],
          [-0.3748,  1.1981, -0.6801, -0.6070],
          [-1.2320,  0.1160,  0.0228,  0.7151],
          [-0.0093,  1.8099, -1.0271,  1.3781]],

         [[ 1.1416,  0.5063,  1.9757,  0.1467],
          [-0.0438, -0.6308, -1.5559,  0.9612],
          [-0.6037,  0.1473, -0.3568,  0.0421],
          [ 0.2026, -1.1104, -2.7988, -1.2281]],

         [[ 0.1771, -0.1821,

In [9]:
# Per-channel statistics before and after BatchNorm. BatchNorm2d normalizes
# each channel using all N*H*W values of that channel, so those are the
# values pooled here.
#
# Why the post-norm std prints ~1.016 instead of 1.0: Tensor.std() applies
# Bessel's correction (divides by n - 1) by default, while BatchNorm
# normalizes with the biased variance (divides by n). With n = N*H*W = 32
# values per channel the ratio is sqrt(32 / 31) ~= 1.016. Calling
# .std(unbiased=False) instead would report ~1.0.
def _report_channel_stats(heading, batch):
    """Print `heading`, then the mean and std of every channel of `batch`.

    `batch` is indexed as (N, C, H, W); each channel's values are pooled
    across the batch and both spatial dimensions before computing stats.
    """
    print(heading)
    for ch in range(C):
        pooled = batch[:, ch].flatten()  # every value belonging to channel `ch`
        print(f"\nChannel {ch}:")
        print(f"Mean: {pooled.mean():.6f}")
        print(f"Std:  {pooled.std():.6f}")


_report_channel_stats('Mean and std of input image per channel:', input)
_report_channel_stats('\nMean and std after BatchNorm per channel:', normalized)



Mean and std of input image per channel:

Channel 0:
Mean: -0.087866
Std:  1.084790

Channel 1:
Mean: 0.387244
Std:  0.941521

Channel 2:
Mean: -0.161516
Std:  0.988676

Mean and std after BatchNorm per channel:

Channel 0:
Mean: 0.000000
Std:  1.015997

Channel 1:
Mean: 0.000000
Std:  1.015995

Channel 2:
Mean: -0.000000
Std:  1.015996


In [10]:
# Contrast with LayerNorm: pool all C*H*W values of each image instead of
# pooling per channel. BatchNorm does not normalize along this grouping, so
# the per-image mean and std are only approximately 0 and 1.
print('\nFor comparison - stats per image:')
for idx in range(N):
    flat = normalized[idx].flatten()  # every value of image `idx`
    mu, sigma = flat.mean(), flat.std()
    print(f"\nImage {idx}:")
    print(f"Mean: {mu:.6f}")
    print(f"Std:  {sigma:.6f}")


For comparison - stats per image:

Image 0:
Mean: 0.015779
Std:  1.040310

Image 1:
Mean: -0.015779
Std:  0.979683
