Enable resetting of batchnorm running moments and cumulative average #6445

Merged
merged 1 commit into from
Apr 27, 2018
Conversation

jma127
Contributor

@jma127 jma127 commented Apr 9, 2018

Duplicate of #5766

@@ -25,15 +25,21 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.LongTensor([0]))
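
The hunk above registers a third buffer, num_batches_tracked, next to the running mean and variance. As a rough illustration of what resetting that state could involve, here is a pure-Python stand-in. The class name `RunningStatsSketch` and its methods are hypothetical and are not PyTorch's implementation. It assumes that a reset restores the same initial values the buffers are registered with: zeros for the mean, ones for the variance, and a zero counter.

```python
class RunningStatsSketch:
    """Hypothetical stand-in for the three batchnorm buffers (not PyTorch code)."""

    def __init__(self, num_features):
        self.num_features = num_features
        self.reset()

    def reset(self):
        # Assumption: reset restores the initial values used in the diff above.
        self.running_mean = [0.0] * self.num_features
        self.running_var = [1.0] * self.num_features
        self.num_batches_tracked = 0

    def track(self, batch_mean, batch_var, factor):
        # Count the batch, then blend its statistics into the running ones.
        self.num_batches_tracked += 1
        self.running_mean = [(1 - factor) * r + factor * b
                             for r, b in zip(self.running_mean, batch_mean)]
        self.running_var = [(1 - factor) * r + factor * b
                            for r, b in zip(self.running_var, batch_var)]


stats = RunningStatsSketch(num_features=2)
stats.track([1.0, 2.0], [0.5, 0.5], factor=0.1)
stats.reset()  # back to zeros, ones, and a zero counter
```

Resetting the counter together with the moments matters in this sketch because any averaging scheme that depends on the batch count would otherwise carry over history from before the reset.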


@jma127
Contributor Author

jma127 commented Apr 9, 2018

cc @soumith

@bddppq
Contributor

bddppq commented Apr 10, 2018

@onnxbot test this please

Contributor

@ezyang ezyang left a comment


Blocking merge of this until BC-breakage is resolved

@ssnl
Collaborator

ssnl commented Apr 16, 2018

fyi, I'm working on load_state_dict to solve the issue.

@jma127
Contributor Author

jma127 commented Apr 16, 2018

Thanks so much -- would be great to have it in time for the next release :)

@jma127
Contributor Author

jma127 commented Apr 26, 2018

@pytorchbot retest this please


@jma127
Contributor Author

jma127 commented Apr 26, 2018

Can someone help me out with this onnx check?

@soumith
Member

soumith commented Apr 26, 2018

@ssnl please have a look

@soumith soumith closed this Apr 26, 2018
@soumith soumith reopened this Apr 26, 2018
@soumith
Member

soumith commented Apr 26, 2018

(closed by mistake)

Collaborator

@ssnl ssnl left a comment


I'm not sure if the new buffer should be in onnx export or not. If it should be, then updating the onnx expected output would be an easy fix.

        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key not in state_dict:
            # Add the missing num_batches_tracked counter
            state_dict[num_batches_tracked_key] = torch.LongTensor([0])
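
The snippet above backfills the counter when a checkpoint saved before the buffer existed is loaded. A minimal sketch of the same idea on a plain dictionary, with an integer standing in for the tensor, might look like this. The function name `backfill_num_batches_tracked` and the `bn1.` prefix are hypothetical and chosen only for illustration.

```python
def backfill_num_batches_tracked(state_dict, prefix):
    """Add a zero counter under `prefix` if an older checkpoint lacks one."""
    key = prefix + 'num_batches_tracked'
    if key not in state_dict:
        # A plain int stands in for torch.LongTensor([0]) in this sketch.
        state_dict[key] = 0
    return state_dict


# An "old" checkpoint that predates the counter buffer.
old_checkpoint = {'bn1.running_mean': [0.1], 'bn1.running_var': [0.9]}
patched = backfill_num_batches_tracked(old_checkpoint, 'bn1.')
```

A checkpoint that already carries the counter is left untouched, so the backfill only affects files written before the buffer was introduced.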



        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / max(1, self.num_batches_tracked.item())
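
When momentum is None, the factor computed above is 1/n for the n-th tracked batch. The following pure-Python sketch illustrates why that yields a cumulative average. It assumes the running statistic is updated as (1 - factor) * running + factor * batch, and the helper name `cumulative_running_mean` is hypothetical.

```python
def cumulative_running_mean(batch_means):
    """Apply the 1/n update factor batch by batch and return the running mean."""
    running = 0.0
    num_batches_tracked = 0
    for batch_mean in batch_means:
        num_batches_tracked += 1
        factor = 1.0 / max(1, num_batches_tracked)
        # Assumed update rule: blend the old running value with the new batch.
        running = (1.0 - factor) * running + factor * batch_mean
    return running


batch_means = [2.0, 4.0, 6.0, 8.0]
result = cumulative_running_mean(batch_means)
plain_average = sum(batch_means) / len(batch_means)
# Up to floating-point rounding, result matches plain_average.
```

Because the first factor is 1, the initial running value is overwritten entirely by the first batch, and every later batch contributes with equal weight.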


@ssnl
Collaborator

ssnl commented Apr 26, 2018

The TestCaffe2Backend.test_resnet failure is at least easy to reproduce locally. Could you see if it works?

@jma127
Contributor Author

jma127 commented Apr 26, 2018

Comments addressed. I'm not sure what the TestCaffe2Backend test is, as grepping for that string returns nothing.

@dzhulgakov
Collaborator

@onnxbot retest this please

@houseroad
Member

@pytorchbot test this please

@houseroad
Member

@ezyang merge this one with onnxbot/onnx-fb-universe#1792 together

@houseroad houseroad merged commit 76d3c30 into pytorch:master Apr 27, 2018
@jma127
Contributor Author

jma127 commented Apr 27, 2018

Woohoo! Thanks to all!

Jorghi12 pushed a commit to wsttiger/pytorch that referenced this pull request May 10, 2018
weiyangfb pushed a commit to weiyangfb/pytorch that referenced this pull request Jun 11, 2018
8 participants