[FSDP] Fix `nn.Parameter` usage for 2D and `use_orig_params=True` #89782
Conversation
[ghstack-poisoned]
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/89782
Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit 93cccd0. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 5830e8be3148aa7fc3514136ae02f81b7417d50e Pull Request resolved: #89782
@@ -1307,7 +1317,10 @@ def _use_unsharded_views(self, as_params: bool) -> None:
                assert tensor is not None  # mypy
                param_var = tensor
            setattr(module, param_name, param_var)
-       if self._use_orig_params and self._training_state == HandleTrainingState.FORWARD:
+       if (
This change and below is just ufmt.
Thanks for the fix!
@pytorchbot rebase -s
@pytorchbot successfully started a rebase job. Check the current status here.
…s=True`" This ensures that all elements of `FlatParameter._params` and `FlatParameter._shared_params` are `nn.Parameter`s (as expected). This was violated by the local tensor of a `DTensor` when using 2D parallelism. To fix the breakage, we simply wrap with `nn.Parameter` if needed. [ghstack-poisoned]
Successfully rebased
ghstack-source-id: 6282e9a1d4fa734eb46ab8a4f9ed6166208a335f Pull Request resolved: #89782
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…torch#89782) This ensures that all elements of `FlatParameter._params` and `FlatParameter._shared_params` are `nn.Parameter`s (as expected). This was violated by the local tensor of a `DTensor` when using 2D parallelism. To fix the breakage, we simply wrap with `nn.Parameter` if needed. Pull Request resolved: pytorch#89782 Approved by: https://github.com/fduwjj
Stack from ghstack (oldest at bottom):

[FSDP] Fix `nn.Parameter` usage for 2D and `use_orig_params=True` #89782

This ensures that all elements of `FlatParameter._params` and `FlatParameter._shared_params` are `nn.Parameter`s (as expected). This was violated by the local tensor of a `DTensor` when using 2D parallelism. To fix the breakage, we simply wrap with `nn.Parameter` if needed.
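The "wrap with `nn.Parameter` if needed" fix described above can be sketched as a small guard: a tensor is wrapped only when it is not already a parameter, so existing parameters pass through untouched. The snippet below illustrates the pattern with lightweight stand-in classes (`Tensor`, `Parameter`, `ensure_parameter` are all hypothetical names for illustration, not real `torch` or FSDP APIs), so it is a sketch of the idea rather than PyTorch's actual implementation.

```python
# Minimal stand-ins for torch.Tensor and torch.nn.Parameter, used only
# to illustrate the wrap-if-needed pattern from this PR. In real FSDP
# code the inputs would be the entries of FlatParameter._params and
# FlatParameter._shared_params.
class Tensor:
    pass


class Parameter(Tensor):
    """A tensor that a module registers as a trainable parameter."""

    def __init__(self, data=None):
        self.data = data


def ensure_parameter(tensor):
    # The local tensor of a DTensor under 2D parallelism is a plain
    # tensor, not a Parameter, which violated the invariant that every
    # element of FlatParameter._params is an nn.Parameter. The fix:
    # wrap only when needed, so already-wrapped entries are unchanged.
    if isinstance(tensor, Parameter):
        return tensor
    return Parameter(tensor)


plain = Tensor()
wrapped = ensure_parameter(plain)
assert isinstance(wrapped, Parameter)

# Wrapping an existing Parameter is a no-op: the same object comes back.
existing = Parameter(Tensor())
assert ensure_parameter(existing) is existing
```

The idempotence shown in the last assertion is the point of the conditional: unconditionally re-wrapping would replace parameter objects that optimizers or the module already hold references to.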