performance drop due to batch norm params recalculation #8

Closed · ghost opened this issue Jan 29, 2019 · 1 comment

ghost commented Jan 29, 2019

Thanks for the great work!
I have some cases where performance drops when using SWA compared to a single model. In this case the loss goes from 0.67 for the single model to 0.72 for its exact SWA copy.

To debug the problem, I ran SWA for only one epoch and compared the model against its SWA copy. All the parameters are identical except the batch norm running_mean and running_var buffers, and it seems that the deeper you go in the network, the bigger the divergence (see the dump below).

Do you have any tips on how to recalculate the batch norm params more accurately? Or should I just run the training set through the SWA version multiple times so that they converge to the original model's params?
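
For reference, this is roughly what I have in mind (a rough sketch, untested; model and train_loader stand in for my network and training DataLoader):

import torch

# Reset the BN buffers and switch to cumulative averaging, then run
# forward passes over the training data so that running_mean and
# running_var are recomputed for the averaged weights.
for module in model.modules():
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.reset_running_stats()
        module.momentum = None  # None = cumulative (exact) average in PyTorch

model.train()
with torch.no_grad():
    for inputs, _ in train_loader:
        model(inputs.cuda())
model.eval()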

This is the code I use to compare them, where m is the model's state dict and swa is the SWA copy's state dict:

for key in m.keys():
    print(key, (swa[key] - m[key]).sum())
in_c.0.weight tensor(0., device='cuda:0')
in_c.1.weight tensor(0., device='cuda:0')
in_c.1.bias tensor(0., device='cuda:0')
in_c.1.running_mean tensor(-0.0847, device='cuda:0')
in_c.1.running_var tensor(4.5671, device='cuda:0')
in_c.1.num_batches_tracked tensor(0, device='cuda:0')
stage1.block1.conv1.weight tensor(0., device='cuda:0')
stage1.block1.bn1.weight tensor(0., device='cuda:0')
stage1.block1.bn1.bias tensor(0., device='cuda:0')
stage1.block1.bn1.running_mean tensor(-0.0932, device='cuda:0')
stage1.block1.bn1.running_var tensor(-1.3953, device='cuda:0')
stage1.block1.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage1.block1.conv2.weight tensor(0., device='cuda:0')
stage1.block1.bn2.weight tensor(0., device='cuda:0')
stage1.block1.bn2.bias tensor(0., device='cuda:0')
stage1.block1.bn2.running_mean tensor(0.0153, device='cuda:0')
stage1.block1.bn2.running_var tensor(0.4095, device='cuda:0')
stage1.block1.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage1.block2.conv1.weight tensor(0., device='cuda:0')
stage1.block2.bn1.weight tensor(0., device='cuda:0')
stage1.block2.bn1.bias tensor(0., device='cuda:0')
stage1.block2.bn1.running_mean tensor(-0.1347, device='cuda:0')
stage1.block2.bn1.running_var tensor(0.1461, device='cuda:0')
stage1.block2.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage1.block2.conv2.weight tensor(0., device='cuda:0')
stage1.block2.bn2.weight tensor(0., device='cuda:0')
stage1.block2.bn2.bias tensor(0., device='cuda:0')
stage1.block2.bn2.running_mean tensor(0.0590, device='cuda:0')
stage1.block2.bn2.running_var tensor(-0.0815, device='cuda:0')
stage1.block2.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage1.block3.conv1.weight tensor(0., device='cuda:0')
stage1.block3.bn1.weight tensor(0., device='cuda:0')
stage1.block3.bn1.bias tensor(0., device='cuda:0')
stage1.block3.bn1.running_mean tensor(-0.3279, device='cuda:0')
stage1.block3.bn1.running_var tensor(1.1645, device='cuda:0')
stage1.block3.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage1.block3.conv2.weight tensor(0., device='cuda:0')
stage1.block3.bn2.weight tensor(0., device='cuda:0')
stage1.block3.bn2.bias tensor(0., device='cuda:0')
stage1.block3.bn2.running_mean tensor(0.0153, device='cuda:0')
stage1.block3.bn2.running_var tensor(-0.0028, device='cuda:0')
stage1.block3.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage1.block4.conv1.weight tensor(0., device='cuda:0')
stage1.block4.bn1.weight tensor(0., device='cuda:0')
stage1.block4.bn1.bias tensor(0., device='cuda:0')
stage1.block4.bn1.running_mean tensor(-0.0740, device='cuda:0')
stage1.block4.bn1.running_var tensor(1.9955, device='cuda:0')
stage1.block4.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage1.block4.conv2.weight tensor(0., device='cuda:0')
stage1.block4.bn2.weight tensor(0., device='cuda:0')
stage1.block4.bn2.bias tensor(0., device='cuda:0')
stage1.block4.bn2.running_mean tensor(-0.0084, device='cuda:0')
stage1.block4.bn2.running_var tensor(-0.1089, device='cuda:0')
stage1.block4.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage2.block1.conv1.weight tensor(0., device='cuda:0')
stage2.block1.bn1.weight tensor(0., device='cuda:0')
stage2.block1.bn1.bias tensor(0., device='cuda:0')
stage2.block1.bn1.running_mean tensor(-0.5207, device='cuda:0')
stage2.block1.bn1.running_var tensor(13.1180, device='cuda:0')
stage2.block1.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage2.block1.conv2.weight tensor(0., device='cuda:0')
stage2.block1.bn2.weight tensor(0., device='cuda:0')
stage2.block1.bn2.bias tensor(0., device='cuda:0')
stage2.block1.bn2.running_mean tensor(0.0416, device='cuda:0')
stage2.block1.bn2.running_var tensor(0.0256, device='cuda:0')
stage2.block1.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage2.block1.shortcut.conv.weight tensor(0., device='cuda:0')
stage2.block1.shortcut.bn.weight tensor(0., device='cuda:0')
stage2.block1.shortcut.bn.bias tensor(0., device='cuda:0')
stage2.block1.shortcut.bn.running_mean tensor(-0.0403, device='cuda:0')
stage2.block1.shortcut.bn.running_var tensor(9.1353, device='cuda:0')
stage2.block1.shortcut.bn.num_batches_tracked tensor(0, device='cuda:0')
stage2.block2.conv1.weight tensor(0., device='cuda:0')
stage2.block2.bn1.weight tensor(0., device='cuda:0')
stage2.block2.bn1.bias tensor(0., device='cuda:0')
stage2.block2.bn1.running_mean tensor(0.5277, device='cuda:0')
stage2.block2.bn1.running_var tensor(-3.7371, device='cuda:0')
stage2.block2.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage2.block2.conv2.weight tensor(0., device='cuda:0')
stage2.block2.bn2.weight tensor(0., device='cuda:0')
stage2.block2.bn2.bias tensor(0., device='cuda:0')
stage2.block2.bn2.running_mean tensor(-0.0346, device='cuda:0')
stage2.block2.bn2.running_var tensor(-0.5359, device='cuda:0')
stage2.block2.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage2.block3.conv1.weight tensor(0., device='cuda:0')
stage2.block3.bn1.weight tensor(0., device='cuda:0')
stage2.block3.bn1.bias tensor(0., device='cuda:0')
stage2.block3.bn1.running_mean tensor(-0.1441, device='cuda:0')
stage2.block3.bn1.running_var tensor(-9.7162, device='cuda:0')
stage2.block3.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage2.block3.conv2.weight tensor(0., device='cuda:0')
stage2.block3.bn2.weight tensor(0., device='cuda:0')
stage2.block3.bn2.bias tensor(0., device='cuda:0')
stage2.block3.bn2.running_mean tensor(0.0451, device='cuda:0')
stage2.block3.bn2.running_var tensor(-0.1082, device='cuda:0')
stage2.block3.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage2.block4.conv1.weight tensor(0., device='cuda:0')
stage2.block4.bn1.weight tensor(0., device='cuda:0')
stage2.block4.bn1.bias tensor(0., device='cuda:0')
stage2.block4.bn1.running_mean tensor(0.5103, device='cuda:0')
stage2.block4.bn1.running_var tensor(-31.8605, device='cuda:0')
stage2.block4.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage2.block4.conv2.weight tensor(0., device='cuda:0')
stage2.block4.bn2.weight tensor(0., device='cuda:0')
stage2.block4.bn2.bias tensor(0., device='cuda:0')
stage2.block4.bn2.running_mean tensor(-0.0271, device='cuda:0')
stage2.block4.bn2.running_var tensor(-0.5281, device='cuda:0')
stage2.block4.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage3.block1.conv1.weight tensor(0., device='cuda:0')
stage3.block1.bn1.weight tensor(0., device='cuda:0')
stage3.block1.bn1.bias tensor(0., device='cuda:0')
stage3.block1.bn1.running_mean tensor(2.6584, device='cuda:0')
stage3.block1.bn1.running_var tensor(-154.8269, device='cuda:0')
stage3.block1.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage3.block1.conv2.weight tensor(0., device='cuda:0')
stage3.block1.bn2.weight tensor(0., device='cuda:0')
stage3.block1.bn2.bias tensor(0., device='cuda:0')
stage3.block1.bn2.running_mean tensor(-0.0399, device='cuda:0')
stage3.block1.bn2.running_var tensor(-2.9489, device='cuda:0')
stage3.block1.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage3.block2.conv1.weight tensor(0., device='cuda:0')
stage3.block2.bn1.weight tensor(0., device='cuda:0')
stage3.block2.bn1.bias tensor(0., device='cuda:0')
stage3.block2.bn1.running_mean tensor(-0.0263, device='cuda:0')
stage3.block2.bn1.running_var tensor(-6.7252, device='cuda:0')
stage3.block2.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage3.block2.conv2.weight tensor(0., device='cuda:0')
stage3.block2.bn2.weight tensor(0., device='cuda:0')
stage3.block2.bn2.bias tensor(0., device='cuda:0')
stage3.block2.bn2.running_mean tensor(0.2284, device='cuda:0')
stage3.block2.bn2.running_var tensor(0.6274, device='cuda:0')
stage3.block2.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage3.block3.conv1.weight tensor(0., device='cuda:0')
stage3.block3.bn1.weight tensor(0., device='cuda:0')
stage3.block3.bn1.bias tensor(0., device='cuda:0')
stage3.block3.bn1.running_mean tensor(-0.1151, device='cuda:0')
stage3.block3.bn1.running_var tensor(15.2176, device='cuda:0')
stage3.block3.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage3.block3.conv2.weight tensor(0., device='cuda:0')
stage3.block3.bn2.weight tensor(0., device='cuda:0')
stage3.block3.bn2.bias tensor(0., device='cuda:0')
stage3.block3.bn2.running_mean tensor(0.3864, device='cuda:0')
stage3.block3.bn2.running_var tensor(-1.1801, device='cuda:0')
stage3.block3.bn2.num_batches_tracked tensor(0, device='cuda:0')
stage3.block4.conv1.weight tensor(0., device='cuda:0')
stage3.block4.bn1.weight tensor(0., device='cuda:0')
stage3.block4.bn1.bias tensor(0., device='cuda:0')
stage3.block4.bn1.running_mean tensor(-0.2459, device='cuda:0')
stage3.block4.bn1.running_var tensor(29.2794, device='cuda:0')
stage3.block4.bn1.num_batches_tracked tensor(0, device='cuda:0')
stage3.block4.conv2.weight tensor(0., device='cuda:0')
stage3.block4.bn2.weight tensor(0., device='cuda:0')
stage3.block4.bn2.bias tensor(0., device='cuda:0')
stage3.block4.bn2.running_mean tensor(-0.0546, device='cuda:0')
stage3.block4.bn2.running_var tensor(9.7019, device='cuda:0')
stage3.block4.bn2.num_batches_tracked tensor(0, device='cuda:0')
feed_forward.0.weight tensor(0., device='cuda:0')
feed_forward.1.weight tensor(0., device='cuda:0')
feed_forward.1.bias tensor(0., device='cuda:0')
feed_forward.1.running_mean tensor(0.2525, device='cuda:0')
feed_forward.1.running_var tensor(-2.0196, device='cuda:0')
feed_forward.1.num_batches_tracked tensor(0, device='cuda:0')
izmailovpavel (Collaborator) commented

Hi, sorry for the late response. If you are still interested, you should use the utils.bn_update function to update the batch norm statistics.
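
A minimal sketch of how that call might look (assuming utils.bn_update takes the data loader first and the model second, as in this repo's utils.py; train_loader and swa_model are placeholder names):

import utils

# Pass the training data through the SWA model once so that the BN
# running statistics are recomputed for the averaged weights.
utils.bn_update(train_loader, swa_model)
swa_model.eval()  # evaluate using the refreshed statistics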
