
Add support for non-persistent buffers. #37191

Closed
wants to merge 3 commits into from

Conversation

sharvil
Contributor

@sharvil sharvil commented Apr 23, 2020

Issue: #18056
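For context, this PR adds a `persistent` flag to `nn.Module.register_buffer` (released in PyTorch 1.6). A minimal sketch of the resulting API, with hypothetical buffer names:

```python
import torch
import torch.nn as nn

class RunningStats(nn.Module):
    """Toy module illustrating persistent vs. non-persistent buffers."""
    def __init__(self):
        super().__init__()
        # Persistent buffer (the default): serialized in state_dict,
        # like BatchNorm's running_mean.
        self.register_buffer("running_mean", torch.zeros(4))
        # Non-persistent buffer: still tracked by the module (device/dtype
        # moves, named_buffers()) but excluded from state_dict save/load.
        self.register_buffer("scratch", torch.zeros(4), persistent=False)

m = RunningStats()
assert set(m.state_dict().keys()) == {"running_mean"}
```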

@dr-ci

dr-ci bot commented Apr 23, 2020

💊 Build failures summary and remediations

As of commit d68705b (more details on the Dr. CI page):


  • 1/2 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)
  • 1/2 broken upstream at merge base 122d821 since May 06

🚧 1 ongoing upstream failure:

These were probably caused by upstream breakages that are not fixed yet:


Extra GitHub checks: 1 failed



@zhangguanheng66
Contributor

@ssnl Do you have time to review this PR? Thanks

@mruberry mruberry added labels module: nn (Related to torch.nn) and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on May 5, 2020
@mruberry mruberry requested review from albanD and removed request for apaszke May 5, 2020 22:47
Collaborator

@albanD albanD left a comment

Hi,

Sorry for the late review.
Could you make a small update in the doc to make it crystal clear what this can be used for?
The code is good, only minor comments.


 This is typically used to register a buffer that should not be
 considered a model parameter. For example, BatchNorm's ``running_mean``
-is not a parameter, but is part of the persistent state.
+is not a parameter, but is part of the module's state. Buffers, by
+default, are persistent and will be saved alongside parameters. This
Collaborator

Could you make it clearer that the only difference is that non-persistent buffers won't show up in the state dict on save/load, while persistent ones will?
There is no other difference, right?

Contributor Author

Done. Yes, there's no other difference.
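As a concrete check of the point above (state-dict membership really is the only difference), a short sketch assuming PyTorch >= 1.6:

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("kept", torch.ones(2))
        self.register_buffer("dropped", torch.ones(2), persistent=False)

m = M()
# Both buffers are visible to named_buffers() and follow dtype/device moves...
assert {n for n, _ in m.named_buffers()} == {"kept", "dropped"}
m = m.double()
assert m.dropped.dtype == torch.float64
# ...but only the persistent one is serialized.
assert "kept" in m.state_dict()
assert "dropped" not in m.state_dict()
```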

@@ -85,6 +85,7 @@ def __init__(self):
         self.training = True
         self._parameters = OrderedDict()
         self._buffers = OrderedDict()
+        self._persistent_buffers = set()
Collaborator

nit: to make it clearer that this is not a dict like the other attributes, you could name it `self._persistent_buffers_name_set`? Or is that getting too long?

Contributor Author

Done. I'm not a huge fan of appending the type name to the identifier, but don't feel strongly about it.

@@ -618,6 +627,7 @@ def remove_from(*dicts):
             raise AttributeError(
                 "cannot assign parameters before Module.__init__() call")
         remove_from(self.__dict__, self._buffers, self._modules)
+        self._persistent_buffers.discard(name)
Collaborator

nit: I think you can add the discard call inside the remove_from function here to make this slightly simpler.
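The suggested simplification can be sketched generically (an illustration only, not the actual PyTorch source, where `remove_from` closes over `name` inside `__setattr__`):

```python
# Illustrative sketch: a remove_from that handles both dict-like and
# set-like containers, so the caller no longer needs a separate
# discard() call for the persistent-buffer name set.
def remove_from(name, *containers):
    for container in containers:
        if isinstance(container, set):
            container.discard(name)    # sets: no-op if the name is absent
        else:
            container.pop(name, None)  # dicts: drop the entry if present

attrs, buffers, persistent = {"w": 1}, {"buf": 2}, {"buf"}
remove_from("buf", attrs, buffers, persistent)
assert attrs == {"w": 1}
assert buffers == {}
assert persistent == set()
```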

Contributor Author

Done.

Contributor Author

@sharvil sharvil left a comment

Thanks for the code review, @albanD!


Collaborator

@albanD albanD left a comment

Perfect, thanks!

Contributor

@facebook-github-bot facebook-github-bot left a comment

@albanD is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@sharvil
Contributor Author

sharvil commented May 6, 2020

(last push was just a rebase to pytorch/master)


@albanD
Collaborator

albanD commented May 6, 2020

Oh nice, I was going to do it manually, but better this way. Thanks!

@facebook-github-bot
Contributor

@albanD merged this pull request in 594b33e.

ppwwyyxx added a commit to ppwwyyxx/detectron2 that referenced this pull request Feb 22, 2021
Summary:
supported since pytorch 1.6: pytorch/pytorch#37191

This will not save pixel_mean/std in checkpoints. This is consistent with our current model zoo format.

Differential Revision: D26576833

fbshipit-source-id: 65472efcfd71f9383bb47a192e095315037390b5
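The detectron2 change described above can be approximated as follows (a hedged sketch with a hypothetical class name and constructor arguments, not the actual detectron2 code):

```python
import torch
import torch.nn as nn

class Preprocessor(nn.Module):
    """Normalizes images with fixed constants that need not be checkpointed."""
    def __init__(self, pixel_mean, pixel_std):
        super().__init__()
        # persistent=False keeps these constants out of state_dict, so
        # checkpoints stay consistent with a model zoo format that never
        # stored them in the first place.
        self.register_buffer("pixel_mean",
                             torch.tensor(pixel_mean).view(-1, 1, 1),
                             persistent=False)
        self.register_buffer("pixel_std",
                             torch.tensor(pixel_std).view(-1, 1, 1),
                             persistent=False)

    def forward(self, image):
        return (image - self.pixel_mean) / self.pixel_std

p = Preprocessor([0.5, 0.5, 0.5], [0.2, 0.2, 0.2])
assert len(p.state_dict()) == 0  # nothing is saved to the checkpoint
```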
facebook-github-bot pushed a commit to facebookresearch/detectron2 that referenced this pull request Feb 25, 2021
Summary:
supported since pytorch 1.6: pytorch/pytorch#37191

This will not save pixel_mean/std in checkpoints. This is consistent with our current model zoo format.

Reviewed By: theschnitz

Differential Revision: D26576833

fbshipit-source-id: b7143d6ef8b106e873958394f80aba75fc11d2cf
Labels
Merged, module: nn (Related to torch.nn), open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
6 participants