
[functorch][nn] Refactor NN stateless APIs by swapping module tensors #92536

Closed
wants to merge 16 commits

Conversation

@pytorch-bot

pytorch-bot bot commented Jan 18, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/92536

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 56ce796:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@XuehaiPan XuehaiPan changed the title Refactor NN stateless APIs by swapping module tensors [functorch][nn] Refactor NN stateless APIs by swapping module tensors Jan 18, 2023
@XuehaiPan XuehaiPan force-pushed the swap-tensors branch 2 times, most recently from 8909f5e to 9ac946e on January 18, 2023 16:41
@jbschlosser
Contributor

Hey @XuehaiPan, thanks for the contribution! This is certainly something we'd consider merging. Do you mind answering a few questions to help our understanding on the review side?

  • Can you describe at a somewhat detailed level how this approach works and differs from what is currently done in master?
  • Are there benefits to this approach beyond addressing the issues linked in the PR description (which is of course very valuable)?
  • What are the observable differences from the user side (beyond bug fixes)? Do you have a sense for where BC-breaking aspects will arise?
  • Do you expect any performance improvements or regressions?

Apologies in advance if it takes us a while to review this; we're a bit busy on our side, but this is something we will work to get reviewed!

@XuehaiPan
Collaborator Author

XuehaiPan commented Jan 18, 2023

Can you describe at a somewhat detailed level how this approach works and differs from what is currently done in master?

I have described the implementation details of functorch.make_functional and the current torch.func.functional_call approach in #92295 (comment).

I can summarize them as:

  • functorch.make_functional: swaps module tensors with delattr followed by setattr. That ends up storing all tensors in module._buffers and leaves module._parameters empty.

    While doing a functional call, for example, given name='weight' and tensor=Tensor(..., requires_grad=True), delattr first unregisters the stateless module's parameter, and then setattr registers the user-given tensor as a buffer rather than a parameter, because nn.Module.__setattr__ only registers values of type nn.Parameter as parameters.

  • torch.func.functional_call: changes the module's class by overriding the __getattribute__ and __setattr__ methods. That does not change the contents of module._parameters and module._buffers.

    The current torch.func.functional_call implementation is quite tricky. It lets self.weight resolve to the user-given tensor through a customized __getattribute__ method, but it does not change the contents of module._parameters and module._buffers. So users can access the parameters/buffers via self.<attr-name>, yet the other module methods, such as self.get_parameter(...), self.parameters(), self.buffers(), self.state_dict(), etc., are unaware of the substitution: they return the tensors of the original module rather than the user-given ones.

Both of the existing approaches above limit the usage and flexibility of a user-defined forward pass. Users can only access the parameters/buffers via self.<attr-name>; if they use anything else from the nn.Module API, they get errors or unexpected values.
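
For illustration, here is a deliberately simplified sketch of the attribute-interception idea described above (the helper name intercepted_call is hypothetical; this is not the actual torch.func implementation):

import torch
import torch.nn as nn


def intercepted_call(module, tensors, *args, **kwargs):
    # Temporarily swap the module's class for a subclass whose __getattribute__
    # redirects the given names to the user-given tensors. Note that
    # module._parameters and module._buffers are left untouched.
    orig_cls = type(module)

    class Interposed(orig_cls):
        def __getattribute__(self, name):
            if name in tensors:
                return tensors[name]
            return super().__getattribute__(name)

    module.__class__ = Interposed
    try:
        return module(*args, **kwargs)
    finally:
        module.__class__ = orig_cls


model = nn.Linear(3, 2)
replacements = {
    'weight': torch.randn(2, 3, requires_grad=True),
    'bias': torch.zeros(2, requires_grad=True),
}
out = intercepted_call(model, replacements, torch.randn(4, 3))
# Inside forward(), self.weight resolved to the replacement tensor, but the
# registered parameters were never touched, so self.parameters() and
# self.state_dict() would still have returned the original tensors during the call.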

In this PR, the new approach is closer to the functorch.make_functional approach (specifically, its _swap_state function), but without using delattr + setattr. Instead, it swaps the tensors in the module by accessing the module._parameters and module._buffers dicts directly. All user-given tensors are registered as "true" parameters and buffers, so users can access them via self.<attr-name>, self.get_parameter(...), self.parameters(), self.state_dict(), and also self.to(...).

Conceptually, it is similar to the following (similar in spirit, but not the actual implementation):

def functional_call(module, parameters_and_buffers, args, kwargs):
    # Conceptual sketch only: temporarily load the user-given tensors into the
    # module, run the forward pass, then restore the original state.
    orig_state_dict = module.state_dict()
    try:
        module.load_state_dict(parameters_and_buffers)
        return module(*args, **kwargs)
    finally:
        module.load_state_dict(orig_state_dict)
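
And here is a minimal sketch of the dict-swapping idea itself (the helper name swapped_call is hypothetical; the actual PR additionally handles nested submodules, tied weights, missing keys, and the strict flag):

import torch
import torch.nn as nn


def swapped_call(module, tensors, *args, **kwargs):
    # Swap entries of module._parameters / module._buffers directly, run the
    # forward pass, then restore the original tensors. Only flat (non-nested)
    # parameter/buffer names are handled here.
    originals = {}
    try:
        for name, tensor in tensors.items():
            if name in module._parameters:
                originals[name] = module._parameters[name]
                module._parameters[name] = tensor
            elif name in module._buffers:
                originals[name] = module._buffers[name]
                module._buffers[name] = tensor
        return module(*args, **kwargs)
    finally:
        for name, tensor in originals.items():
            if name in module._parameters:
                module._parameters[name] = tensor
            else:
                module._buffers[name] = tensor


model = nn.Linear(3, 2)
out = swapped_call(
    model,
    {'weight': torch.randn(2, 3, requires_grad=True), 'bias': torch.zeros(2, requires_grad=True)},
    torch.randn(4, 3),
)
# During the call, self.weight, self.parameters(), and self.state_dict() all
# see the user-given tensors, because they all read from module._parameters.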

Are there benefits to this approach beyond addressing the issues linked in the PR description (which is of course very valuable)?

All user-given tensors are registered as "true" parameters and buffers. Users can access them via self.<attr-name>, self.parameters(), self.state_dict(), and also self.to(...). This gives users a smooth transition from the object-oriented pattern (OOP) to the stateless functional pattern: all the methods that nn.Module users are familiar with remain available during the functional call.
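
For example, with a toy module whose forward pass also calls self.parameters() (an illustrative example, not taken from the PR):

import torch
import torch.nn as nn


class Scale(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(()))

    def forward(self, x):
        # A generic nn.Module API used inside forward():
        total = sum(p.sum() for p in self.parameters())
        return x * self.weight + total


model = Scale()
out = torch.func.functional_call(model, {'weight': torch.full((), 3.0)}, (torch.ones(2),))
# With the tensor-swapping approach, both self.weight and self.parameters()
# see 3.0 during the call, so out is tensor([6., 6.]); a pure attribute-
# interception approach would give tensor([4., 4.]), because self.parameters()
# would still see the original value of 1.0.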

What are the observable differences from the user side (beyond bug fixes)? Do you have a sense for where BC-breaking aspects will arise?

No, there is no observable difference from the user side. Beyond the bug fixes, it is backward-compatible.

Do you expect any performance improvements or regressions?

I think it will have a similar performance to the current torch.func.functional_call implementation.

In my opinion, the performance would be:

calls per second (higher is better): functorch.make_functional > torch.func.functional_call (this PR) ≈ torch.func.functional_call (in master)

functorch.make_functional traverses the module and extracts the tensor names once up front, while torch.func.functional_call traverses the module and performs many checks on every call, so I expect the former to be slightly faster. The true bottleneck, though, is the forward pass itself rather than the call-preparation overhead.
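
If one wants to check, a rough micro-benchmark of the per-call overhead might look like this (an illustrative sketch, not from the PR):

import timeit

import torch
import torch.nn as nn

model = nn.Linear(64, 64)
params = dict(model.named_parameters())
x = torch.randn(32, 64)

# Compare a plain forward pass against a stateless call with the same tensors.
plain = timeit.timeit(lambda: model(x), number=1000)
stateless = timeit.timeit(lambda: torch.func.functional_call(model, params, (x,)), number=1000)
print(f'plain forward: {plain:.3f}s, functional_call: {stateless:.3f}s')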

@zou3519
Contributor

zou3519 commented Jan 18, 2023

Just a quick fyi @XuehaiPan -- I'll be out until 3/7 so probably won't respond very much. You can cc Chillee for general functorch matters and he'll route them to the right places

@ain-soph
Contributor

Shall we add some unit tests for the new strict flag?

@XuehaiPan XuehaiPan force-pushed the swap-tensors branch 2 times, most recently from 87fd85f to d12c000 on January 19, 2023 09:14
@XuehaiPan XuehaiPan marked this pull request as draft January 19, 2023 10:45
@XuehaiPan XuehaiPan force-pushed the swap-tensors branch 11 times, most recently from 863fd55 to 70c1291 on January 19, 2023 14:09
@XuehaiPan
Collaborator Author

Shall we add some unittest for the new strict flag?

Added.

Contributor

@jbschlosser jbschlosser left a comment


Beautiful :)

@jbschlosser
Contributor

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@XuehaiPan XuehaiPan deleted the swap-tensors branch February 9, 2023 08:37
pytorchmergebot pushed a commit that referenced this pull request Feb 9, 2023
Follows #92536 (comment). It has been 10 months since `torch.nn.utils._stateless` was marked as deprecated.

This PR also changes `tie_weights` in `_reparametrize_module` to a keyword-only argument, since it is a private API and was only imported by `torch.nn.utils._stateless` (which is removed).
Pull Request resolved: #94498
Approved by: https://github.com/jbschlosser
@jbschlosser jbschlosser removed their assignment Mar 31, 2023
Labels
ciflow/trunk Trigger trunk jobs on your pull request · Merged · open source · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Projects
Status: Done
8 participants