[FSDP][state_dict] Return tensors instead of FlatParameters to avoid pickling errors #94637
…pickling errors After #88913, user-defined parameter states will be pickled. For a FlatParameter, this means `_local_shard` will also be pickled. Since state_dict and load_state_dict only require the tensor, returning the full FlatParameter does not give us any extra benefit. This PR changes the behavior to simply return a view of the FlatParameter. Differential Revision: [D43205127](https://our.internmc.facebook.com/intern/diff/D43205127/) [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94637
❌ 5 Failures as of commit 8b2f0e3 (new failures), as reported by Dr. CI.
thanks for the fix!
    if valid_data_size > 0:
        if flat_param._shard_numel_padded > 0:
            flat_param = flat_param.narrow(0, 0, valid_data_size)
        flat_param = flat_param.view(valid_data_size)
nice, maybe add a comment saying this will make it return a tensor that can be properly serialized?
…pickling errors Pull Request resolved: #94637 After #88913, user-defined parameter states will be pickled. For a FlatParameter, this means `_local_shard` will also be pickled. Since state_dict and load_state_dict only require the tensor, returning the full FlatParameter does not give us any extra benefit. This PR changes the behavior to simply return a view of the FlatParameter. ghstack-source-id: 179983735 Differential Revision: [D43205127](https://our.internmc.facebook.com/intern/diff/D43205127/)
@pytorchbot merge -f "The failing tests are not related."
Merge started
Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
Stack from ghstack (oldest at bottom):
After #88913, user-defined parameter states will be pickled. For a FlatParameter, this means `_local_shard` will also be pickled. Since state_dict and load_state_dict only require the tensor, returning the full FlatParameter does not give us any extra benefit. This PR changes the behavior to simply return a view of the FlatParameter.
Differential Revision: D43205127
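A minimal sketch of why taking a view helps, under stated assumptions: `ToyFlatParam` is a hypothetical stand-in for FlatParameter (just an `nn.Parameter` subclass with an extra `_local_shard` attribute), and `to_plain_tensor` is an illustrative helper that mirrors the lines from the diff, not the actual FSDP hook. Because `nn.Parameter` does not propagate its type through tensor operations, `narrow` and `view` should hand back a plain `torch.Tensor` that shares storage with the parameter but carries none of its extra attributes.

```python
import torch
import torch.nn as nn


class ToyFlatParam(nn.Parameter):
    """Hypothetical stand-in for FlatParameter; not the real FSDP class."""


def to_plain_tensor(flat_param, valid_data_size, shard_numel_padded):
    # Illustrative helper mirroring the diff: drop padding if present,
    # then take a view so that a plain tensor is returned.
    if valid_data_size > 0:
        if shard_numel_padded > 0:
            flat_param = flat_param.narrow(0, 0, valid_data_size)
        flat_param = flat_param.view(valid_data_size)
    return flat_param


# Padded case: 4 valid elements followed by 2 padding elements.
padded = ToyFlatParam(torch.arange(6, dtype=torch.float32))
padded._local_shard = torch.zeros(6)  # extra state that would be pickled
out_padded = to_plain_tensor(padded, valid_data_size=4, shard_numel_padded=2)

# Unpadded case: narrow() is skipped, so view() alone changes the type.
unpadded = ToyFlatParam(torch.arange(4, dtype=torch.float32))
unpadded._local_shard = torch.zeros(4)
out_unpadded = to_plain_tensor(unpadded, valid_data_size=4, shard_numel_padded=0)

print(type(padded).__name__, "->", type(out_padded).__name__)
print(hasattr(out_padded, "_local_shard"), hasattr(out_unpadded, "_local_shard"))
```

The unpadded case is the reason the `view` call sits outside the inner `if`: without it, a shard with no padding would still be returned as the parameter subclass.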