In [1]:
import torch
from torch import nn
from d2l import torch as d2l

In [16]:
# A (1, 2, 2, 2) tensor; averaging over dims 0 and 3 with keepdim=True
# collapses those axes to size 1 instead of dropping them.
X = torch.tensor([[[[1., 2.],
                    [3., 4.]],
                   [[5., 6.],
                    [7., 8.]]]])
X = torch.mean(X, dim=(0, 3), keepdim=True)
(X, X.shape)

(tensor([[[[1.5000],
           [3.5000]],
 
          [[5.5000],
           [7.5000]]]]),
 torch.Size([1, 2, 2, 1]))

In [13]:
# Column-wise mean of a 2x2 matrix; keepdim=True keeps the result 2-D (1, 2).
X = torch.tensor([[1., 2.],
                  [3., 4.]])
torch.mean(X, dim=0, keepdim=True)

tensor([[2., 3.]])

In [23]:
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Apply batch normalization to ``X`` and update the running statistics.

    The mode is chosen from the autograd state: when gradients are disabled
    the input is normalized with the supplied running statistics; otherwise
    it is normalized with the statistics of the current batch and the running
    statistics are updated.

    Args:
        X: Input of rank 2 or rank 4 (checked only in the gradient-enabled
            branch). Rank 2 is reduced over dim 0; rank 4 is reduced over
            dims (0, 2, 3), i.e. one statistic per index of dim 1.
        gamma: Scale applied to the normalized input.
        beta: Shift added after scaling.
        moving_mean: Running mean from previous calls.
        moving_var: Running variance from previous calls.
        eps: Constant added to the variance before the square root.
        momentum: Weight of the old running statistic in the update
            ``momentum * old + (1 - momentum) * batch``.

    Returns:
        Tuple ``(Y_hat, moving_mean, moving_var)``. The running statistics
        are returned unchanged when gradients are disabled; otherwise they
        are the updated values, detached from the autograd graph.

    Raises:
        AssertionError: If gradients are enabled and ``X`` is not rank 2 or 4.
    """
    if not torch.is_grad_enabled():
        # Prediction mode: normalize with the accumulated running statistics.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        # 2 is a fully connected layer, 4 is a conv layer.
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:  # fully connected layer
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # Per-channel statistics; keepdim=True keeps the reduced axes so
            # that `X - mean` broadcasts the mean to every element of the
            # channel.
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # The running statistics are bookkeeping, not part of the loss.
        # Updating them under no_grad keeps them out of the autograd graph;
        # otherwise every call would chain onto the previous batch's graph
        # and memory would grow with each training step.
        with torch.no_grad():
            moving_mean = momentum * moving_mean + (1 - momentum) * mean
            moving_var = momentum * moving_var + (1 - momentum) * var
    Y_hat = gamma * X_hat + beta
    return Y_hat, moving_mean, moving_var

In [25]:
# Sanity check on a tiny 2x2 batch. With momentum=1 the running statistics
# keep their initial values (mean 1, var 0), broadcast to the feature shape.
X = torch.tensor([[1., 1.],
                  [0., 0.]])
batch_norm(X, gamma=1, beta=0, moving_mean=1, moving_var=0, eps=1, momentum=1)

(tensor([[ 0.4472,  0.4472],
         [-0.4472, -0.4472]]),
 tensor([1., 1.]),
 tensor([0., 0.]))

In [26]:
class myBatchNorm(nn.Module):
    """Batch normalization layer built on the ``batch_norm`` function.

    Holds a learnable scale (``gamma``) and shift (``beta``) plus running
    mean/variance. The running statistics are registered as buffers so that
    they follow the module across devices (``.to(device)``) and are included
    in ``state_dict()``.

    Args:
        num_features: Number of features (2-D input) or channels (4-D input).
        num_dims: Rank of the expected input; 2 gives parameters of shape
            ``(1, num_features)``, 4 gives ``(1, num_features, 1, 1)``.
        eps: Constant added to the variance before the square root.
        momentum: Weight of the old running statistic in each update.

    Raises:
        AssertionError: If ``num_dims`` is neither 2 nor 4.
    """

    def __init__(self, num_features, num_dims, eps=1e-5, momentum=0.9):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            assert num_dims == 4
            shape = (1, num_features, 1, 1)
        self.eps = eps
        self.momentum = momentum
        # Shaped so that they broadcast element-wise against the input.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # Buffers, not plain attributes: non-trainable state that must move
        # with the module and be saved with it.
        self.register_buffer("moving_mean", torch.zeros(shape))
        self.register_buffer("moving_var", torch.ones(shape))

    def forward(self, X):
        """Normalize ``X`` and store the running statistics that come back."""
        Y, moving_mean, moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean, self.moving_var,
            eps=self.eps, momentum=self.momentum)
        # Store detached statistics so the buffers never hold on to the
        # autograd graph of a past batch.
        self.moving_mean = moving_mean.detach()
        self.moving_var = moving_var.detach()
        return Y
        

In [28]:
# LeNet-style network with a myBatchNorm layer between each conv/linear layer
# and its sigmoid activation. The 16*4*4 flatten size corresponds to a
# 28x28 single-channel input (28 -> 24 -> 12 -> 8 -> 4).
net = nn.Sequential(
    # Conv block 1: 1 -> 6 channels.
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
    myBatchNorm(6, num_dims=4),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # Conv block 2: 6 -> 16 channels.
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
    myBatchNorm(16, num_dims=4),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    # Dense block 1.
    nn.Linear(in_features=16 * 4 * 4, out_features=120),
    myBatchNorm(120, num_dims=2),
    nn.Sigmoid(),
    # Dense block 2.
    nn.Linear(in_features=120, out_features=84),
    myBatchNorm(84, num_dims=2),
    nn.Sigmoid(),
    # Output layer: 10 outputs.
    nn.Linear(in_features=84, out_features=10),
)

### Train on GPU later