I am experimenting with PyTorch's nn.BatchNorm2d and your manual implementation. I tried your implementation in MobileNetV3 and the performance seems similar. However, I found that the gradient values are not the same, and I am not sure why. Below is my test code:
"""
Comparison of manual BatchNorm2d layer implementation in Python and
nn.BatchNorm2d
@author: ptrblck
"""
import torch
import torch.nn as nn
def compare_bn(bn1, bn2):
err = False
if not torch.allclose(bn1.running_mean, bn2.running_mean):
print('Diff in running_mean: {} vs {}'.format(
bn1.running_mean, bn2.running_mean))
err = True
if not torch.allclose(bn1.running_var, bn2.running_var):
print('Diff in running_var: {} vs {}'.format(
bn1.running_var, bn2.running_var))
err = True
if bn1.affine and bn2.affine:
if not torch.allclose(bn1.weight, bn2.weight):
print('Diff in weight: {} vs {}'.format(
bn1.weight, bn2.weight))
err = True
# compare weight gradient here
if not torch.allclose(bn1.weight.grad, bn2.weight.grad):
print('Diff in weight gradient: {} vs {}'.format(
bn1.weight.grad, bn2.weight.grad))
err = True
if not torch.allclose(bn1.bias, bn2.bias):
print('Diff in bias: {} vs {}'.format(
bn1.bias, bn2.bias))
err = True
# compare bias gradient here
if not torch.allclose(bn1.bias.grad, bn2.bias.grad):
print('Diff in bias gradient: {} vs {}'.format(
bn1.bias.grad, bn2.bias.grad))
err = True
if not err:
print('All parameters are equal!')
class MyBatchNorm2d(nn.BatchNorm2d):
def __init__(self, num_features, eps=1e-5, momentum=0.1,
affine=True, track_running_stats=True):
super(MyBatchNorm2d, self).__init__(
num_features, eps, momentum, affine, track_running_stats)
def forward(self, input):
self._check_input_dim(input)
exponential_average_factor = 0.0
if self.training and self.track_running_stats:
if self.num_batches_tracked is not None:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
# calculate running estimates
if self.training:
mean = input.mean([0, 2, 3])
# use biased var in train
var = input.var([0, 2, 3], unbiased=False)
n = input.numel() / input.size(1)
with torch.no_grad():
self.running_mean = exponential_average_factor * mean\
+ (1 - exponential_average_factor) * self.running_mean
# update running_var with unbiased var
self.running_var = exponential_average_factor * var * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_var
else:
mean = self.running_mean
var = self.running_var
input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
if self.affine:
input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]
return input
# Init BatchNorm layers
my_bn = MyBatchNorm2d(3, affine=True)
bn = nn.BatchNorm2d(3, affine=True)
# Load weight and bias
my_bn.load_state_dict(bn.state_dict())
# Run train
for _ in range(10):
scale = torch.randint(1, 10, (1,)).float()
bias = torch.randint(-10, 10, (1,)).float()
x = torch.randn(10, 3, 100, 100) * scale + bias
out1 = my_bn(x)
out2 = bn(x)
# calculate gradient for leaf
out1.sum().backward()
out2.sum().backward()
compare_bn(my_bn, bn)
torch.allclose(out1, out2)
print('Max diff: ', (out1 - out2).abs().max())
# Run eval
my_bn.eval()
bn.eval()
for _ in range(10):
scale = torch.randint(1, 10, (1,)).float()
bias = torch.randint(-10, 10, (1,)).float()
x = torch.randn(10, 3, 100, 100) * scale + bias
out1 = my_bn(x)
out2 = bn(x)
# calculate gradient for leaf
out1.sum().backward()
out2.sum().backward()
compare_bn(my_bn, bn)
torch.allclose(out1, out2)
print('Max diff: ', (out1 - out2).abs().max())
Thanks in advance.
I guess you are running into the expected numerical precision errors due to the usage of sum as I've described here: pytorch/pytorch#82493 (comment)
TL;DR: small, expected numerical mismatches of ~1e-7 or 1e-6 will be accumulated num_elements times if you sum the output. My post in the upstream issue describes it in more detail.
If I use out.mean().backward() instead, I see the expected small differences of ~1e-8; closing.
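For illustration, here is a minimal sketch of that check (my own addition, not part of the original thread). It reuses the MyBatchNorm2d class from the script above and backpropagates once through sum() and once through mean(), so the effect of the loss reduction on the gradient mismatch becomes visible:

    # Sketch only: re-run the gradient comparison with MyBatchNorm2d defined above.
    import torch
    import torch.nn as nn

    torch.manual_seed(0)

    my_bn = MyBatchNorm2d(3, affine=True)      # manual implementation from the script above
    ref_bn = nn.BatchNorm2d(3, affine=True)
    my_bn.load_state_dict(ref_bn.state_dict())

    x = torch.randn(10, 3, 100, 100) * 5.0 + 2.0

    # Loss = sum of the output: per-element mismatches are added up over all
    # 10 * 3 * 100 * 100 elements, so the weight.grad difference looks large.
    my_bn(x).sum().backward()
    ref_bn(x).sum().backward()
    print('sum  -> max weight.grad diff:',
          (my_bn.weight.grad - ref_bn.weight.grad).abs().max().item())

    # Loss = mean of the output: the same per-element mismatches are scaled down
    # by the element count, and the gradients should agree to roughly 1e-8.
    my_bn.zero_grad()
    ref_bn.zero_grad()
    my_bn(x).mean().backward()
    ref_bn(x).mean().backward()
    print('mean -> max weight.grad diff:',
          (my_bn.weight.grad - ref_bn.weight.grad).abs().max().item())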