Enable resetting of batchnorm running moments and cumulative average #6445
Conversation
torch/nn/modules/batchnorm.py
Outdated
@@ -25,15 +25,21 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.LongTensor([0]))
cc @soumith
@onnxbot test this please
Blocking merge of this until BC-breakage is resolved
FYI, I'm working on `load_state_dict` to solve the issue.
Thanks so much -- it would be great to have this in time for the next release :)
@pytorchbot retest this please
Can someone help me out with this onnx check?
@ssnl please have a look
(closed by mistake)
I'm not sure whether the new buffer should be included in the ONNX export. If it should be, then updating the expected ONNX output would be an easy fix.
torch/nn/modules/batchnorm.py
Outdated
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key not in state_dict:
            # Add the missing num_batches_tracked counter
            state_dict[num_batches_tracked_key] = torch.LongTensor([0])
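The hunk above keeps old checkpoints loadable: when a state dict saved before this PR is missing the new `num_batches_tracked` buffer, a zero counter is injected before loading. A minimal standalone sketch of that backward-compatibility shim (the dict layout and `prefix` value here are illustrative assumptions, and a plain `int` stands in for the `torch.LongTensor([0])` used in the real code):

```python
# Sketch of the backward-compat patching done in _load_from_state_dict.
# Assumption: a flat dict keyed by '<prefix><buffer_name>' mimics a
# PyTorch state_dict; 0 stands in for torch.LongTensor([0]).

def patch_legacy_state_dict(state_dict, prefix='bn1.'):
    key = prefix + 'num_batches_tracked'
    if key not in state_dict:
        # Old checkpoints predate the counter; start it at zero.
        state_dict[key] = 0
    return state_dict

old_checkpoint = {'bn1.running_mean': [0.0], 'bn1.running_var': [1.0]}
patched = patch_legacy_state_dict(old_checkpoint)
print(patched['bn1.num_batches_tracked'])  # 0
```

An existing counter is left untouched, so checkpoints saved after this PR round-trip unchanged.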
torch/nn/modules/batchnorm.py
Outdated
        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / max(1, self.num_batches_tracked.item())
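With `momentum=None`, the hunk above swaps the usual exponential moving average for a factor of `1/n`, which makes the running statistic the plain cumulative average of all batch statistics seen so far. A small sketch of that math, independent of torch (the function name and scalar inputs are illustrative, not from the PR):

```python
# Sketch: the update  running = (1 - f) * running + f * batch
# with f = 1 / num_batches_tracked reduces to the cumulative mean
# of the per-batch statistics.

def update_running_mean(running, batch_means):
    n = 0  # num_batches_tracked
    for m in batch_means:
        n += 1
        f = 1.0 / max(1, n)  # exponential_average_factor
        running = (1 - f) * running + f * m
    return running

batches = [2.0, 4.0, 6.0]
print(update_running_mean(0.0, batches))  # cumulative mean of the batches: 4.0
```

After `n` batches the result equals `sum(batch_means) / n`, regardless of the initial `running` value once at least one batch has been seen, since the first update uses `f = 1`.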
The TestCaffe2Backend.test_resnet failure is at least easy to reproduce locally. Could you see if it works?
Comments addressed. I'm not sure what the TestCaffe2Backend test is, as grepping for that string returns nothing.
@onnxbot retest this please
…") moving average
@pytorchbot test this please
@ezyang merge this one with onnxbot/onnx-fb-universe#1792 together
Woohoo! Thanks to all!
…") moving average (pytorch#6445)
Duplicate of #5766