[fsdp][torch.compile] FSDP changes #115497
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/115497
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 unrelated failure.) As of commit 2c65d1a with merge base 87ea6fb: UNSTABLE - the following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```
@@ -441,7 +441,9 @@ def _pre_forward_unshard(
    if not handle._prefetched:
        _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
    handle._needs_pre_forward_unshard = False
    state._device_handle.current_stream().wait_stream(state._unshard_stream)
    # Don't wait during trace
```
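A minimal sketch of the kind of guard implied by "Don't wait during trace"; the function name and the `is_compiling` check are illustrative assumptions, not the PR's actual code:

```python
import torch

def _wait_for_unshard_after_prefetch(state) -> None:
    # Eager mode: make the current stream wait for the unshard stream so the
    # all-gathered parameters are ready before the forward body uses them.
    if not torch.compiler.is_compiling():  # illustrative guard, not the PR's exact check
        state._device_handle.current_stream().wait_stream(state._unshard_stream)
    # Under torch.compile the wait is deliberately skipped: dynamo records
    # stream objects in the trace but does not yet model stream synchronization.
```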
How is the trace valid if it contains stream information but not stream synchronization?
I guess we don't care for now since it's ignored by the backend, but it will most likely need fixing.
Yes, this is sound for now - but the longer term version of this will need to teach inductor about .wait_stream and plumb the whole thing through.
Hi @voznesenskym @albanD,

> Yes, this is sound for now - but the longer term version of this will need to teach inductor about .wait_stream and plumb the whole thing through.

May I ask why the stream methods are bypassed here? Does this mean that, in the longer-term plan, streams and other runtime primitives such as events will be compiled through the torch.compile (dynamo + AOTAutograd + inductor) stack? Is there a plan to support streams/events across the whole torch.compile stack?
Thank you. :)
```
        and hasattr(handle, "_ran_pre_backward_hook")
        and handle._ran_pre_backward_hook
    ):
        log.warning("%s %s", id(state), "Not Running pre backward! Already Ran!")
```
Should this really be a warning?
Debug is better
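A minimal sketch of the suggested change; the hook signature and the early return are illustrative assumptions rather than the PR's exact code:

```python
import logging

log = logging.getLogger(__name__)

def _pre_backward_hook(state, handle, *unused) -> None:
    # Sketch only: if this handle already ran its pre-backward work, log at
    # debug level (instead of warning) and skip re-running it.
    if getattr(handle, "_ran_pre_backward_hook", False):
        log.debug("%s %s", id(state), "Not Running pre backward! Already Ran!")
        return
    ...  # the real hook's unshard / prefetch logic would go here
```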
```
@@ -327,3 +333,7 @@ def _replace_by_prefix(
        new_key = new_prefix + key[len(old_prefix) :]
        state_dict[new_key] = state_dict[key]
        del state_dict[key]


def _data_ptr_allocated(tensor: torch.Tensor) -> bool:
```
Is this missing an allow_in_graph? Or is it not used?
It doesn't need it for now. It will in the future, but having it here makes it easier for me to track where I will eventually need to add it.
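For reference, a plausible sketch of the helper and of how an `allow_in_graph` registration could later be attached; the one-line body is inferred from the function name and signature shown in the diff, not quoted from the PR:

```python
import torch

def _data_ptr_allocated(tensor: torch.Tensor) -> bool:
    # A tensor whose storage has been freed (e.g. resized to 0 bytes after
    # resharding) reports a data pointer of 0, so a positive pointer means
    # the storage is still allocated.
    return tensor.untyped_storage().data_ptr() > 0

# If dynamo later needs to treat this check as an opaque, traceable call,
# it could be registered explicitly, e.g.:
# torch._dynamo.allow_in_graph(_data_ptr_allocated)
```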
SGTM!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#115497 Approved by: https://github.com/albanD
flat_param_part_view is unused in the pytorch repo: https://fburl.com/ssaomd7x. It became unused after the refactoring in #115497; before that, the original code was as below. Since flat_param is 1-D, we do not need .view for reshaping:
```
self.flat_param.data = padded_unsharded_flat_param[
    : unsharded_size.numel()
].view(unsharded_size)
```
Unit test: pytest test/distributed/fsdp/test_fsdp_core.py
Pull Request resolved: #117082 Approved by: https://github.com/awgu, https://github.com/wconstab, https://github.com/Skylion007
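As a small sanity check of the claim above, a sketch showing that slicing a 1-D flat parameter already yields the unsharded shape, so the trailing `.view(unsharded_size)` was redundant; the buffer sizes here are made up for illustration:

```python
import torch

# Hypothetical padded all-gather output and true unsharded (1-D) size.
padded_unsharded_flat_param = torch.zeros(16)
unsharded_size = torch.Size([12])

sliced = padded_unsharded_flat_param[: unsharded_size.numel()]
# The slice of a 1-D tensor already has the target shape, so .view() is a no-op.
assert sliced.shape == unsharded_size
assert sliced.view(unsharded_size).shape == sliced.shape
```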
Stack from ghstack (oldest at bottom):
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225