
[BatchNorm] Unexpected behaviour with track_running_stats #37823

Closed

frgfm opened this issue May 5, 2020 · 7 comments

frgfm commented May 5, 2020

🐛 Strange behaviour when changing track_running_stats after instantiation

When track_running_stats is set to False after instantiation, the number of batches tracked is indeed no longer updated, but the running mean and variance still are.

I understand that this attribute is not meant to be changed after instantiation. But here I was freezing the BN stats of layers whose parameters were all temporarily frozen for training. Over the course of training, model.eval() and model.train() are called often, which means that freezing the BN stats of the frozen layers has to be redone every epoch.
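A minimal sketch of that workaround (the toy model below is purely illustrative): the BN layers have to be switched back to eval mode after every call to model.train().

import torch
import torch.nn as nn

# Toy model standing in for the frozen part of a real network.
model = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4), nn.ReLU())

def freeze_bn_stats(module: nn.Module) -> None:
    # Put every BatchNorm layer into eval mode so that its running stats
    # are no longer updated by forward passes.
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()

for epoch in range(2):
    model.train()            # resets the BN layers to training mode...
    freeze_bn_stats(model)   # ...so the freeze has to be reapplied every epoch
    _ = model(torch.rand(32, 8))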

To Reproduce

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
bn.track_running_stats = False

# Store initial values
num_batches = bn.num_batches_tracked.clone() 
running_mean = bn.running_mean.clone() 
running_var = bn.running_var.clone()

# Forward random tensor
_ = bn(torch.rand(32, 4))

# Check which stats were updated
print(torch.equal(num_batches, bn.num_batches_tracked))
print(torch.equal(running_mean, bn.running_mean))
print(torch.equal(running_var, bn.running_var))

yields:

True
False
False

Expected behavior

A forward pass through a BN layer should not have any influence on the BN stats if track_running_stats is set to `False` (even after instantiation).

Environment

  • PyTorch Version: 1.5.0
  • OS: Ubuntu 18.04 LTS
  • How you installed PyTorch: conda
  • Python version: 3.7
  • CUDA/cuDNN version: 10.1.243 (cuDNN 7.6.5)
  • GPU models and configuration: GeForce GTX 1080 Ti (driver 430.50)
@mruberry added the module: nn (Related to torch.nn) and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels on May 5, 2020
mruberry commented May 5, 2020

Thanks for reporting this issue, @frgfm! And sorry to hear it's causing you trouble. I'm also surprised at this behavior -- I wonder if it was intended?

ssnl commented May 6, 2020

Hi, author of track_running_stats here.

@mruberry @frgfm The root cause of this is that the self.running_* buffers are either created or set to None in the constructor, depending on track_running_stats. BatchNorm*D then passes these attributes to F.batch_norm, which does a nullity check to decide whether they should be updated. So effectively, setting that attribute on BatchNorm*D after creation doesn't do anything (except that num_batches_tracked is no longer updated). I'd be fine with accepting a patch that fixes BatchNorm*d::forward to check self.num_batches_tracked and decide whether to pass None or self.running_* to F.batch_norm.
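To make that mechanism concrete, here is an illustrative sketch (TinyBN is a toy class, not the actual PyTorch implementation): the buffers only exist if track_running_stats was True at construction time, and forward() relies on their nullity rather than on the flag itself.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyBN(nn.Module):
    def __init__(self, num_features, track_running_stats=True):
        super().__init__()
        self.track_running_stats = track_running_stats
        if track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
        else:
            self.running_mean = None
            self.running_var = None

    def forward(self, x):
        # F.batch_norm updates the running buffers only when they are not
        # None and the training argument is True; flipping
        # self.track_running_stats after construction therefore changes
        # nothing about whether the buffers get updated.
        return F.batch_norm(
            x, self.running_mean, self.running_var, None, None,
            self.training or not self.track_running_stats, 0.1, 1e-5)

bn = TinyBN(4)
bn.track_running_stats = False   # too late: the buffers already exist
_ = bn(torch.rand(32, 4))        # running_mean / running_var still updated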

frgfm commented May 6, 2020

@ssnl I figured it would be a nullity check under the hood, but thanks for the specifics.
I'll work on a PR then, and keep you both posted!

frgfm commented May 7, 2020

@ssnl I realize I might have misunderstood your suggestion. To clarify, what did you mean by

decide whether to pass None or self.running_* to F.batch_norm.
?

As I understand the values passed to torch.batch_norm [here], if I pass None instead of self.running_*, they will indeed not be updated, but they also won't be used in the output computation. My issue is that I have existing self.running_* buffers that differ from their values at initialization (the module was created with track_running_stats=True), and I cannot prevent those buffers from being updated, whatever the future value of self.training.

My intuition was to pass the self.running_* attributes anyway, so that the computation takes them into account, but to add a dynamic check of self.track_running_stats to decide whether these attributes have to be updated.

Let me know your thoughts on this, I may have completely missed your point earlier!

ssnl commented May 7, 2020

@frgfm Oh good points. I must admit that I didn't think this through when I wrote the original reply.

Here is the current behavior:

|  | train | eval | supported use case |
| --- | --- | --- | --- |
| track_running_stats=True && buffer is not None | normalize & update buffer with batch stats | normalize with buffer stats | Y |
| track_running_stats=True && buffer is None | normalize with batch stats | ERROR | N |
| track_running_stats=False && buffer is not None | normalize & update buffer with batch stats | normalize & update buffer with batch stats | N |
| track_running_stats=False && buffer is None | normalize with batch stats | normalize with batch stats | Y |

The problem with a dynamic check of track_running_stats to decide whether the buffers should be updated is that the computation is handled by F.batch_norm, which should remain agnostic of track_running_stats: its behavior is completely controlled by the training flag and the nullity of the running buffers. Adding another flag would be unnecessary, as it doesn't introduce new behavior.

Really we want to change the second to last row in the above table to always normalize (but not update buffer) with buffer stats. In F.batch_norm language, we want running_* is not None and training is False.

Note that the last row corresponds to the behavior with running_* is None and training is True.
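For concreteness, here is a small sketch of those two F.batch_norm call patterns (the tensors are just placeholders):

import torch
import torch.nn.functional as F

x = torch.rand(32, 4)
running_mean, running_var = torch.zeros(4), torch.ones(4)

# Desired behavior for the second-to-last row: buffers provided and
# training=False -> normalize with the buffers, but leave them untouched.
_ = F.batch_norm(x, running_mean, running_var, training=False)
print(torch.equal(running_mean, torch.zeros(4)))  # True: buffer not updated

# Last row: no buffers and training=True -> normalize with batch stats only.
_ = F.batch_norm(x, None, None, training=True)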

So at the end of the day, we probably just want to change this line

self.training or not self.track_running_stats,

to

(self.running_mean is None and self.running_var is None) if self.track_running_stats else self.training

(probably worth expanding into an if-block for readability, with comments.)

frgfm commented May 7, 2020

@ssnl yes, I agree with the change in behaviour.
Correct me if I'm wrong, but shouldn't it be:

(self.running_mean is None and self.running_var is None) if not self.track_running_stats else self.training

which is equivalent to

self.training if self.track_running_stats else (self.running_mean is None and self.running_var is None)

?
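For readability, that corrected one-liner could be expanded into an if-block along these lines (a sketch of what could go in _BatchNorm.forward; the bn_training name is illustrative, not necessarily the exact code that gets merged):

# Sketch only: expanded form of the corrected condition above.
if self.track_running_stats:
    # Buffers exist and follow the usual train/eval switch.
    bn_training = self.training
else:
    # Tracking disabled: fall back to batch statistics only when no
    # buffers are available; otherwise use (but never update) them.
    bn_training = (self.running_mean is None) and (self.running_var is None)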

I opened a PR with my suggestion above #38084, let me know what you think!

facebook-github-bot pushed a commit that referenced this issue Jun 22, 2020
…alse (#38084)

Summary:
This PR aims at tackling #37823 by:
- ensuring that buffers will be used for normalization computation but won't be updated, when buffers are not None, and `track_running_stats=False`
- adding a corresponding unittest to ensure expected behaviour

Any feedback is welcome!

_Note: we might want to update the docstrings of  `BatchNorm*d`, feel free to share any suggestion!_
Pull Request resolved: #38084

Differential Revision: D22047871

Pulled By: ezyang

fbshipit-source-id: 5acbcad9773e7901f26d625db71d43d7dc236d3e
frgfm commented Jun 23, 2020

Closed by #38084
