In [1]:
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [78]:
class BatchNorm1d(nn.Module):
    def __init__(self, num_features: int, momentum: float = 0.1, eps: float = 1e-5):
        super(BatchNorm1d, self).__init__()
        self.num_features = num_features
        self.momentum = momentum
        self.eps = eps
        self.register_buffer('running_mean', torch.zeros((num_features, 1)), persistent=True)
        self.register_buffer('running_var', torch.ones((num_features, 1)), persistent=True)
        self.gamma = nn.Parameter(torch.ones((num_features, 1)), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros((num_features, 1)), requires_grad=True)
        
    def forward(self, inputs):
        batch_size, c, timesteps = inputs.shape
        assert self.num_features == c, f"expected tensor's channels is {slef.num_features}, buf found {c}"
        mean = torch.mean(inputs, dim=[0, -1], keepdims=True)
        var = torch.var(inputs, dim=[0, -1], unbiased=False, keepdims=True)
        if self.training is True:
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = torch.autograd.Variable(self.running_mean)
            var = torch.autograd.Variable(self.running_var)
        out = self.gamma * (inputs - mean) / torch.sqrt(var + self.eps) + self.beta
        return out

In [84]:
torch.manual_seed(2)
batch_size, c, timesteps = 2, 2, 4
bn1 = nn.BatchNorm1d(c)
bn2 = BatchNorm1d(c)
print('bn1 state dict', bn1.state_dict())
print('bn2 state dict', bn2.state_dict())
print('='*40)
inputs = torch.randn((batch_size, c, timesteps))
out1 = bn1(inputs)
out2 = bn2(inputs)
print('bn1 output', out1)
print('bn2 output', out2)

bn1 state dict OrderedDict([('weight', tensor([1., 1.])), ('bias', tensor([0., 0.])), ('running_mean', tensor([0., 0.])), ('running_var', tensor([1., 1.])), ('num_batches_tracked', tensor(0))])
bn2 state dict OrderedDict([('gamma', tensor([[1.],
        [1.]])), ('beta', tensor([[0.],
        [0.]])), ('running_mean', tensor([[0.],
        [0.]])), ('running_var', tensor([[1.],
        [1.]]))])
bn1 output tensor([[[-0.7559,  1.9225, -1.1163, -0.8501],
         [-1.2276,  1.6587, -1.0324,  0.1028]],

        [[-0.5740,  1.1407, -0.1135,  0.3466],
         [ 1.4652, -0.6641, -0.0256, -0.2768]]],
       grad_fn=<NativeBatchNormBackward0>)
bn2 output tensor([[[-0.7559,  1.9225, -1.1163, -0.8501],
         [-1.2276,  1.6587, -1.0324,  0.1028]],

        [[-0.5740,  1.1407, -0.1135,  0.3466],
         [ 1.4652, -0.6641, -0.0256, -0.2768]]], grad_fn=<AddBackward0>)


## 经过train之后的mean与官方相等，但是std与官方存在差异

In [85]:
bn1.eval()
bn2.eval()
out1 = bn1(inputs)
out2 = bn2(inputs)
print('bn1 output:' ,out1)
print('bn2 output:', out2)

bn1 output: tensor([[[-1.0119,  0.9848, -1.2805, -1.0821],
         [-1.2117,  1.2012, -1.0485, -0.0996]],

        [[-0.8762,  0.4020, -0.5330, -0.1900],
         [ 1.0394, -0.7406, -0.2069, -0.4169]]],
       grad_fn=<NativeBatchNormBackward0>)
bn2 output: tensor([[[-1.0159,  0.9887, -1.2857, -1.0864],
         [-1.2178,  1.2072, -1.0538, -0.1001]],

        [[-0.8797,  0.4036, -0.5351, -0.1908],
         [ 1.0446, -0.7444, -0.2079, -0.4190]]], grad_fn=<AddBackward0>)


## 没有经过train（）之后，进行eval（）后结果与官方相同；
## 经过train（）之后，进行eval（）后的结果与官方存在差异

In [86]:
print('Torch API:', bn1.running_mean, bn1.running_var)
print('My Test:', bn2.running_mean, bn2.running_var)

Torch API: tensor([-0.0488, -0.0204]) tensor([0.9610, 0.9781])
My Test: tensor([[[-0.0488],
         [-0.0204]]]) tensor([[[0.9534],
         [0.9684]]])
