In [1]:
import torch.nn as nn
import torch

In [2]:
bn = nn.BatchNorm2d(3)

N, C, H, W = 2, 3, 4, 5
mean, std = 10, 4
x = std * torch.randn(N, C, H, W) + mean

In [3]:
# Should have correct `mean` and `std`
print('Before BN:')
print('  Shape: ', x.shape)
print('  Means: ', x.mean(dim=(0, 2, 3)))
print('  Stds: ', x.std(dim=(0, 2, 3)))
print('  Running Mean: ', bn.running_mean)
print('  Running Var: ', bn.running_var)

Before BN:
  Shape:  torch.Size([2, 3, 4, 5])
  Means:  tensor([ 9.2909, 10.8612,  9.9053])
  Stds:  tensor([3.7070, 4.4806, 3.7729])
  Running Mean:  tensor([0., 0., 0.])
  Running Var:  tensor([1., 1., 1.])


In [4]:
# Should have mean = 0, and std = 1,
# and running stats are updated
bn.train()
out = bn(x)
print('After BN Train:')
print('  Shape: ', out.shape)
print('  Means: ', out.mean(dim=(0, 2, 3)))
print('  Stds: ', out.std(dim=(0, 2, 3)))
print('  Running Mean: ', bn.running_mean)
print('  Running Var: ', bn.running_var)

After BN Train:
  Shape:  torch.Size([2, 3, 4, 5])
  Means:  tensor([ 1.4901e-08,  4.9174e-08, -5.0664e-08], grad_fn=<MeanBackward1>)
  Stds:  tensor([1.0127, 1.0127, 1.0127], grad_fn=<StdBackward>)
  Running Mean:  tensor([0.9291, 1.0861, 0.9905])
  Running Var:  tensor([2.2742, 2.9076, 2.3235])


In [5]:
# mean != 0 and std != 1,
# and running stats are not updated
bn.eval()
out = bn(x)
print('After BN Eval:')
print('  Shape: ', out.shape)
print('  Means: ', out.mean(dim=(0, 2, 3)))
print('  Stds: ', out.std(dim=(0, 2, 3)))
print('  Running Mean: ', bn.running_mean)
print('  Running Var: ', bn.running_var)

After BN Eval:
  Shape:  torch.Size([2, 3, 4, 5])
  Means:  tensor([5.5448, 5.7326, 5.8484], grad_fn=<MeanBackward1>)
  Stds:  tensor([2.4581, 2.6277, 2.4752], grad_fn=<StdBackward>)
  Running Mean:  tensor([0.9291, 1.0861, 0.9905])
  Running Var:  tensor([2.2742, 2.9076, 2.3235])


In [6]:
# mean = 0 and std = 1,
# but running stats are not updated
bn.train()
bn.track_running_stats = False
out = bn(x)
print('After BN Train + Disable Running Stats:')
print('  Shape: ', out.shape)
print('  Means: ', out.mean(dim=(0, 2, 3)))
print('  Stds: ', out.std(dim=(0, 2, 3)))
print('  Running Mean: ', bn.running_mean)
print('  Running Var: ', bn.running_var)

After BN Train + Disable Running Stats:
  Shape:  torch.Size([2, 3, 4, 5])
  Means:  tensor([ 1.4901e-08,  4.9174e-08, -5.0664e-08], grad_fn=<MeanBackward1>)
  Stds:  tensor([1.0127, 1.0127, 1.0127], grad_fn=<StdBackward>)
  Running Mean:  tensor([0.9291, 1.0861, 0.9905])
  Running Var:  tensor([2.2742, 2.9076, 2.3235])


In [7]:
# mean = 0 and std = 1,
# and running stats are updated
bn.train()
bn.track_running_stats = True
out = bn(x)
print('After BN Train + Enable Running Stats:')
print('  Shape: ', out.shape)
print('  Means: ', out.mean(dim=(0, 2, 3)))
print('  Stds: ', out.std(dim=(0, 2, 3)))
print('  Running Mean: ', bn.running_mean)
print('  Running Var: ', bn.running_var)

After BN Train + Disable Running Stats:
  Shape:  torch.Size([2, 3, 4, 5])
  Means:  tensor([ 1.4901e-08,  4.9174e-08, -5.0664e-08], grad_fn=<MeanBackward1>)
  Stds:  tensor([1.0127, 1.0127, 1.0127], grad_fn=<StdBackward>)
  Running Mean:  tensor([1.7653, 2.0636, 1.8820])
  Running Var:  tensor([3.4209, 4.6244, 3.5146])


In [8]:
# mean = 0 and std = 1,
# but running stats are not updated
bn.train()
bn.track_running_stats = False
out = bn(x)
print('After BN Train + Disable Running Stats:')
print('  Shape: ', out.shape)
print('  Means: ', out.mean(dim=(0, 2, 3)))
print('  Stds: ', out.std(dim=(0, 2, 3)))
print('  Running Mean: ', bn.running_mean)
print('  Running Var: ', bn.running_var)

After BN Train + Disable Running Stats:
  Shape:  torch.Size([2, 3, 4, 5])
  Means:  tensor([ 1.4901e-08,  4.9174e-08, -5.0664e-08], grad_fn=<MeanBackward1>)
  Stds:  tensor([1.0127, 1.0127, 1.0127], grad_fn=<StdBackward>)
  Running Mean:  tensor([1.7653, 2.0636, 1.8820])
  Running Var:  tensor([3.4209, 4.6244, 3.5146])


In [9]:
# mean = 0 and std = 1,
# and running stats are updated
bn.train()
bn.track_running_stats = True
out = bn(x)
print('After BN Train + Enable Running Stats:')
print('  Shape: ', out.shape)
print('  Means: ', out.mean(dim=(0, 2, 3)))
print('  Stds: ', out.std(dim=(0, 2, 3)))
print('  Running Mean: ', bn.running_mean)
print('  Running Var: ', bn.running_var)

After BN Train + Enable Running Stats:
  Shape:  torch.Size([2, 3, 4, 5])
  Means:  tensor([ 1.4901e-08,  4.9174e-08, -5.0664e-08], grad_fn=<MeanBackward1>)
  Stds:  tensor([1.0127, 1.0127, 1.0127], grad_fn=<StdBackward>)
  Running Mean:  tensor([2.5178, 2.9434, 2.6843])
  Running Var:  tensor([4.4530, 6.1695, 4.5866])
