
Non-persistent Module buffers #18056

Closed
f0k opened this issue Mar 15, 2019 · 8 comments
Labels
feature A request for a proper, new feature. module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@f0k
Contributor

f0k commented Mar 15, 2019

🚀 Feature

nn.Module.register_buffer() should get a keyword argument persistent=True. If set to False, the buffer will not be included in the output of state_dict(), and will not be loaded by load_state_dict().
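A minimal sketch of what a call could look like (the `Window` module and the window construction are made up for illustration; this shows the proposed behavior, which per the comments below was later merged):

```python
import torch
import torch.nn as nn

class Window(nn.Module):
    """Illustrative module holding a precomputed analysis window."""
    def __init__(self, size=1024):
        super().__init__()
        # Proposed: kept as a buffer (moved by .to()/.cuda()/.double()),
        # but skipped by state_dict() and load_state_dict().
        self.register_buffer("window", torch.hann_window(size), persistent=False)

    def forward(self, x):
        return x * self.window

m = Window()
assert "window" not in m.state_dict()  # the non-persistent buffer is not serialized
```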

Motivation

I repeatedly come across cases where I want to precompute a tensor at construction time that is then used in every forward() call. Take a Discrete Cosine Transform module as an example: the transform can be computed by generating a DCT matrix once and applying a matrix product. Assuming the input size is fixed at construction time, it would be wasteful to recompute the matrix in every forward() call.
I currently have three options: making it an nn.Parameter, registering it as a buffer, or directly storing it as a plain attribute (a sketch of the buffer approach follows below).
The first two cause it to be included in state_dict. It would be wasteful to store the matrix with every model checkpoint, and it would lock me into that implementation -- if I ever implement the DCT differently, I will have to add a state dict hook that discards the matrix when loading older models.
The third option keeps it out of state_dict, but also means it is not converted by .cuda(), .double(), etc. That can be fixed by overriding _apply(), but I don't want to do that in every Module, and the model would still not work with data_parallel() (which explicitly copies only the parameters and buffers).
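For concreteness, here is a rough sketch of the buffer-based option for the DCT example (the module and the unnormalized DCT-II construction are illustrative only); the precomputed matrix inevitably ends up in every checkpoint:

```python
import math
import torch
import torch.nn as nn

class DCT(nn.Module):
    """Unnormalized DCT-II via a precomputed matrix (illustrative sketch)."""
    def __init__(self, n):
        super().__init__()
        k = torch.arange(n, dtype=torch.float32)
        # basis[i, j] = cos(pi/n * (i + 0.5) * j), computed once at construction
        basis = torch.cos(math.pi / n * (k[:, None] + 0.5) * k[None, :])
        # As a buffer it follows .cuda()/.double() and data_parallel(),
        # but it is also serialized into state_dict() -- the problem above.
        self.register_buffer("dct_matrix", basis)

    def forward(self, x):
        return x @ self.dct_matrix

m = DCT(8)
print("dct_matrix" in m.state_dict())  # True: stored with every checkpoint
```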

Pitch

With register_buffer(name, tensor, persistent=False), I would want a buffer to be registered that is not stored and restored, but otherwise treated like any other buffer. The docstring of register_buffer already speaks of "persistent buffers", so it seems sensible to also allow "non-persistent buffers". From what I understand, this would only require changes in state_dict() and load_state_dict(), plus a way to track which buffers are non-persistent (sketched below).
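A back-of-the-envelope sketch of that bookkeeping (plain Python, attribute names invented; the actual implementation inside nn.Module may look different):

```python
class ModuleSketch:
    """Simplified stand-in for nn.Module, illustrating only the tracking."""
    def __init__(self):
        self._buffers = {}
        self._non_persistent_buffers = set()

    def register_buffer(self, name, tensor, persistent=True):
        self._buffers[name] = tensor
        if not persistent:
            self._non_persistent_buffers.add(name)

    def state_dict(self, destination=None, prefix=""):
        destination = {} if destination is None else destination
        for name, buf in self._buffers.items():
            # Non-persistent buffers are simply skipped when serializing;
            # loading would symmetrically ignore the missing keys.
            if buf is not None and name not in self._non_persistent_buffers:
                destination[prefix + name] = buf
        return destination
```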

Alternatives

My use case is about what would be a constant in graph-based frameworks: something that can be computed once and reused, something that is independent of the input to forward(). It would be possible to have an nn.Constant wrapper class similar to nn.Parameter, registered when assigned to a Module attribute and included whenever a model is moved or replicated. But my impression is that several places in PyTorch assume a model has only parameters and buffers, and they would all need to learn about constants. Furthermore, it is not actually important that the value is constant, so that concept is too narrow.

There once was a proposal for "calculated parameters" (#7313 (comment)) that would also fit my use case, but it would mean more overhead for me, the humble developer -- I would need to implement a Module that computes the DCT matrix. All I want is a way to store Tensors as Module attributes that are not included in the state dict, but are moved across devices just like parameters or buffers.

@vishwakftw vishwakftw added feature A request for a proper, new feature. module: nn Related to torch.nn labels Mar 15, 2019
@liangbright

Another problem: with
nn.Module.register_buffer('some_tensor', None)
some_tensor will not show up as a key in state_dict(), which causes lots of trouble.
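A small reproduction of the behavior described above (the module name is illustrative):

```python
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        # Placeholder buffer, intended to be assigned a real tensor later
        self.register_buffer("some_tensor", None)

m = M()
print("some_tensor" in m.state_dict())  # False: buffers that are None are skipped
```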

@soumith
Member

soumith commented Mar 27, 2019

this seems pretty reasonable to do, we'd accept a PR into core (or we'll get around to it when we can)

@sharvil
Contributor

sharvil commented Apr 23, 2020

I can take this on.

@zhangguanheng66 zhangguanheng66 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Apr 23, 2020
@sharvil
Contributor

sharvil commented Apr 23, 2020

The patch is ready for review. PTAL.

@sharvil
Contributor

sharvil commented Apr 28, 2020

@zhangguanheng66 can you please help me find the right reviewers for this PR? Thanks.

@sharvil
Contributor

sharvil commented May 4, 2020

Ping. Can we move forward with or reject the PR?

facebook-github-bot pushed a commit that referenced this issue May 7, 2020
Summary:
Issue: #18056
Pull Request resolved: #37191

Differential Revision: D21428373

Pulled By: albanD

fbshipit-source-id: a7d367bafb95137e1bc380178b82b08eff5d5a5a
@sharvil
Contributor

sharvil commented May 18, 2020

I think we can close this issue now. The PR has been merged.

@f0k
Contributor Author

f0k commented May 25, 2020

Yes, cool! Thanks for implementing this @sharvil!

@f0k f0k closed this as completed May 25, 2020
JJGO added a commit to JJGO/voxelmorph that referenced this issue Sep 17, 2021
PyTorch added the functionality to have non-persistent buffers that
are not included in the state_dict [1]. This is done by specifying
`persistent=False` when registering it.

This is just an improvement that removes the associated workaround.

[1] pytorch/pytorch#18056