In [1]:
import torch
from torch import nn

$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$

$\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$,
    where $ \hat{x}$ is the estimated statistic and $x_t$ is the
    new observed value.

In [2]:
nn.BatchNorm1d?

[0;31mInit signature:[0m
[0mnn[0m[0;34m.[0m[0mBatchNorm1d[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mnum_features[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0meps[0m[0;34m=[0m[0;36m1e-05[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmomentum[0m[0;34m=[0m[0;36m0.1[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0maffine[0m[0;34m=[0m[0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtrack_running_stats[0m[0;34m=[0m[0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
inputs with optional additional channel dimension) as described in the paper
`Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

.. math::

    y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

The mean and standard-deviation are calculated per-dimension 

In [3]:
class BatchNorm1d(nn.Module):
    def __init__(self,
    num_features,
    eps=1e-05,
    momentum=0.1,
    affine=True,
#     track_running_stats=True,
    ):
        super().__init__()
        self.eps = eps
        self.affine = affine
        self.running_var = torch.ones(num_features) #internal mean and bias estimates for use in inference
        self.running_mean = torch.zeros(num_features)
        self.mom = momentum
        if affine:
            self.gamma = nn.Parameter(torch.ones(num_features))
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        if self.training: #Normalize by batch mean/var, and update mean/var estimates
            mean = torch.mean(x, 0, keepdim = True)
            x = x - mean
            var = torch.var(x, 0, keepdim = True, unbiased = False)
            x = x/torch.sqrt(var + self.eps)
            self.update_estimates(mean, var)
        else: #Normalize by mean/var estimates from training
            x = x - self.running_mean
            x = x/torch.sqrt(self.running_var + self.eps)
        if self.affine: #Apply learnable scale/offset
            l = len(x.shape) #number of dimensions
            trail_shape = (l - 2) * (1,) #ones of trailing dimensions after features
            gamma = self.gamma.view((1,-1) + trail_shape) #view so that we multiply/add over features
            beta = self.beta.view((1,-1) + trail_shape)
            x = gamma*x + beta
        return x

    def update_estimates(self, mean, var):
        self.running_mean = (1-self.mom)*self.running_mean + self.mom * mean.flatten()
        self.running_var = (1-self.mom)*self.running_var + self.mom * var.flatten()

## Testing

In [16]:
bn = BatchNorm1d(3, affine=True)
bnt = nn.BatchNorm1d(3, affine=True)

In [17]:
data = torch.rand(4,3)

In [20]:
for f in [bnt, bn]:
    o = f(data)
    print(o)

tensor([[ 0.8313,  0.7915,  1.1609],
        [-1.6295, -0.5567, -0.6334],
        [-0.0099, -1.3494,  0.7636],
        [ 0.8081,  1.1146, -1.2912]], grad_fn=<NativeBatchNormBackward>)
tensor([[ 0.8313,  0.7915,  1.1609],
        [-1.6295, -0.5567, -0.6334],
        [-0.0099, -1.3494,  0.7636],
        [ 0.8081,  1.1146, -1.2912]], grad_fn=<AddBackward0>)


In [21]:
for f in [bnt, bn]:
    print(f.running_mean)
    print(f.running_var)

tensor([0.1123, 0.1677, 0.1336])
tensor([0.7472, 0.7675, 0.7440])
tensor([0.1123, 0.1677, 0.1336])
tensor([0.7427, 0.7578, 0.7403])


Slight difference in running_var. Maybe pytorch implementation updates std instead?

In [22]:
for f in [bnt, bn]:
    f.eval()
    o = f(data)
    print(o)

tensor([[ 0.5656,  0.8098,  0.6908],
        [-0.0738,  0.3077,  0.2669],
        [ 0.3470,  0.0125,  0.5969],
        [ 0.5596,  0.9301,  0.1114]], grad_fn=<NativeBatchNormBackward>)
tensor([[ 0.5673,  0.8149,  0.6925],
        [-0.0740,  0.3096,  0.2675],
        [ 0.3481,  0.0126,  0.5984],
        [ 0.5613,  0.9360,  0.1117]], grad_fn=<AddBackward0>)
