[functorch][nn] Refactor NN stateless APIs by swapping module tensors #92536
Hey @XuehaiPan, thanks for the contribution! This is certainly something we'd consider merging. Do you mind answering a few questions to help our understanding on the review side?
Apologies in advance if it takes us a while to review this; we're a bit busy on our side, but this is something we will work to get reviewed!
I have described the implementation details of … I can summarize them as:
Both of the existing approaches above limit the usage and flexibility of a user-defined forward pass. Users can only access the parameters/buffers by … In this PR, the new approach is more like the … It is conceptually similar to the following (similar in concept only, not the same):

    def functional_call(module, parameters_and_buffers, args, kwargs):
        orig_state_dict = module.state_dict()
        try:
            module.load_state_dict(parameters_and_buffers)
            return module(*args, **kwargs)
        finally:
            module.load_state_dict(orig_state_dict)

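The swap-and-restore idea in the snippet above can be illustrated without PyTorch. The following is only a toy sketch: `ToyModule`, `swapped_tensors`, and `toy_functional_call` are hypothetical names invented for illustration, and plain floats stand in for tensors. It swaps references rather than copying values, which is one way to read "swapping module tensors"; it is not the PR's actual implementation.

```python
from contextlib import contextmanager


class ToyModule:
    """Hypothetical stand-in for nn.Module: named 'tensors' live in a dict."""

    def __init__(self):
        self._tensors = {"weight": 2.0, "bias": 1.0}

    def parameters(self):
        # Reads whatever is currently registered, including swapped-in values.
        return list(self._tensors.values())

    def __call__(self, x):
        return self._tensors["weight"] * x + self._tensors["bias"]


@contextmanager
def swapped_tensors(module, replacements):
    # Remember the original objects for every name being replaced
    # (an unknown name raises KeyError here).
    originals = {name: module._tensors[name] for name in replacements}
    module._tensors.update(replacements)
    try:
        yield module
    finally:
        # Restore the originals even if the forward pass raised.
        module._tensors.update(originals)


def toy_functional_call(module, replacements, args=(), kwargs=None):
    with swapped_tensors(module, replacements):
        return module(*args, **(kwargs or {}))
```

Because the replacements are registered on the module itself for the duration of the call, anything the forward pass reads through the module (including `parameters()` in this toy) sees the swapped values, and the originals come back afterwards.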
All user-given tensors are registered as "true" parameters and buffers. Users can access them by …
No. It is imperceptible from the user's side. Beyond bug fixes, it is backward-compatible.
I think it will have similar performance to the current … In my opinion, the performance would be, in calls per second (higher is better): … Because …
Just a quick FYI @XuehaiPan -- I'll be out until 3/7, so I probably won't respond very much. You can cc Chillee for general functorch matters and he'll route them to the right places.
Shall we add some unit tests for the new …
Added.
This reverts commit 79449e6.
Beautiful :)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours).
Follows #92536 (comment). It has been 10 months since `torch.nn.utils._stateless` was marked as deprecated. This PR also changes `tie_weights` in `_reparametrize_module` to a keyword-only argument, since it is a private API and was only imported by `torch.nn.utils._stateless` (now removed). Pull Request resolved: #94498. Approved by: https://github.com/jbschlosser
Fixes [Bug][functorch] `self.parameters()` is not available during forward pass with either `functorch.make_functional` nor `torch.func.functional_call` #92295
Resolves Make stateless.functional_call support weight tying #86708
Resolves Add strict mode to functional_call #92153
Closes [WIP] Add `strict` flag to `functional_call` #92401
Closes [functorch] Cache intermediate submodules for submodule lookup in stateless functional calls #92218
Requires Fix and update type hints for `make_functional.py` #91579

Refactor NN stateless APIs by swapping module tensors.

cc @albanD @zou3519
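As a rough illustration of the behaviour the linked issues ask for, the sketch below calls `torch.func.functional_call` on a small module whose forward pass iterates `self.parameters()`, and then tries the `strict` flag with an incomplete dict. It assumes a PyTorch release that includes this change; `PenalizedLinear` is a hypothetical module made up for the example, and weight tying is not covered.

```python
import torch
from torch import nn
from torch.func import functional_call


class PenalizedLinear(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x):
        # The forward pass reads its own parameters via self.parameters().
        penalty = sum(p.pow(2).sum() for p in self.parameters())
        return self.linear(x).sum() + penalty


model = PenalizedLinear()
x = torch.ones(2, 3)
weight_before = model.linear.weight.detach().clone()

# Replace every parameter with zeros for the duration of one call.
zeros = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
out = functional_call(model, zeros, (x,))

# With strict=True, a dict that omits "linear.bias" should be rejected.
try:
    functional_call(
        model, {"linear.weight": zeros["linear.weight"]}, (x,), strict=True
    )
    strict_error_raised = False
except RuntimeError:
    strict_error_raised = True
```

If the swapped tensors are visible through `self.parameters()`, both the linear output and the penalty are computed from zeros, and the module's own parameters are left untouched after the call.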