[dtensor] fix dtensor _to_copy op for mixed precision #116426
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/116426
Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit 8579064 with merge base ca4df16. This comment was automatically generated by Dr. CI and updates every 15 minutes.
LGTM, might want to add a unit test?
register_op_strategy(
    aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"])
)(default_strategy)
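For context, a minimal repro sketch of the mixed-precision scenario this registration addresses. This assumes a standard two-rank DTensor setup; the torch.distributed._tensor helper names follow the public API of this era rather than anything in this PR:

import torch
from torch.distributed._tensor import DeviceMesh, Replicate, distribute_tensor

# Assumes a 2-GPU process group has already been initialized.
mesh = DeviceMesh("cuda", [0, 1])
dt = distribute_tensor(torch.randn(4, 4), mesh, [Replicate()])

# Each .to(dtype) lowers to aten._to_copy with a different `dtype` kwarg.
# Before this fix, the second call could silently reuse the sharding
# strategy cached for the first, because `dtype` was not part of the
# sharding-prop cache key.
bf16 = dt.to(torch.bfloat16)
fp32 = bf16.to(torch.float32)
assert fp32.dtype == torch.float32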
Mind if I confirm my understanding of the fix (just out of interest :) )? It looks like static_kwargkey ends up getting used in the cache lookup for DTensor's sharding prop: https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/op_schema.py#L299
So was the problem that we had a graph with _to_copy() showing up twice, with different dtypes, and we ended up using the cached sharding strategy when we should have recomputed it for the different dtype?
Yes! This is correct: we should recompute the sharding strategy for the new dtype in this case, but we were reusing the cached one.
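To illustrate the cache-collision mechanics discussed above, here is a toy sketch (not the actual OpSchema code in torch/distributed/_tensor/op_schema.py) of why a kwarg-only difference is invisible to the cache until that kwarg is listed in static_kwargkey:

from dataclasses import dataclass, field

import torch

@dataclass
class ToySchema:
    op: str
    args_spec: tuple                      # stands in for the tensor args' placements
    kwargs: dict = field(default_factory=dict)
    static_kwargkey: tuple = ()           # kwargs that participate in the cache key

    def cache_key(self):
        static = tuple(self.kwargs.get(k) for k in self.static_kwargkey)
        return hash((self.op, self.args_spec, static))

a = ToySchema("aten._to_copy.default", ("Shard(0)",), {"dtype": torch.float32})
b = ToySchema("aten._to_copy.default", ("Shard(0)",), {"dtype": torch.bfloat16})
assert a.cache_key() == b.cache_key()     # collision: cached fp32 strategy gets reused

a.static_kwargkey = b.static_kwargkey = ("dtype",)
assert a.cache_key() != b.cache_key()     # dtype now distinguishes the two calls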
Context: the existing FSDPExtension has a bug when unflattening a tensor involves compute/communication in a CUDA stream. The current FSDPExtension logic runs the unflatten in the unshard stream, so the runtime loses synchronization with the compute stream; if there are dependencies between the compute stream and the unflatten-tensor logic, the missing sync point can lead to NaNs. This PR makes FSDPExtension record the compute stream and lets DTensorExtension use the compute stream directly for unflatten_tensor. Longer term we may want the FSDP runtime to perform only the unshard in the unshard stream and to produce the unsharded views in the compute stream; for now we fix this in the extension directly, since that is the simplest change that does not affect the FSDP runtime logic. Pull Request resolved: #116559 Approved by: https://github.com/awgu, https://github.com/fduwjj, https://github.com/yifuwang ghstack dependencies: #116426
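A minimal sketch of the stream-ordering idea from the paragraph above; unflatten_tensor and the surrounding names are placeholders, not the real FSDPExtension hooks:

import torch

def unshard_and_unflatten(flat_param, unflatten_tensor,
                          compute_stream, unshard_stream):
    # The all-gather of the flat parameter (stubbed here as a clone) runs in
    # the unshard stream so it can overlap with compute.
    with torch.cuda.stream(unshard_stream):
        full = flat_param.clone()
    # The fix: order the compute stream after the unshard stream, then run
    # the extension's unflatten logic *in the compute stream*, so any kernels
    # it depends on (or launches) stay synchronized with ongoing compute.
    compute_stream.wait_stream(unshard_stream)
    with torch.cuda.stream(compute_stream):
        return unflatten_tensor(full)

# Usage sketch:
# compute_stream = torch.cuda.current_stream()   # recorded by the extension
# unshard_stream = torch.cuda.Stream()
# view = unshard_and_unflatten(flat_param, my_unflatten, compute_stream, unshard_stream)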
Disable some runtime assertions for now, as they do not work properly with torch.compile; I'll follow up with a fix in Dynamo and re-enable this check. Pull Request resolved: #116573 Approved by: https://github.com/awgu, https://github.com/XilunWu ghstack dependencies: #116426, #116559
This PR adds devices to register_backend of the multithreaded PG, to avoid seeing tons of warnings. Pull Request resolved: #116678 Approved by: https://github.com/awgu, https://github.com/XilunWu ghstack dependencies: #116426, #116559, #116573
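A hedged sketch of what passing devices to register_backend looks like; the backend name and creator function below are illustrative stand-ins, not the actual multithreaded PG registration from this PR:

import torch.distributed as dist

def _create_threaded_pg(store, rank, world_size, timeout):
    # Hypothetical creator function; the real one lives in the
    # multithreaded process-group test utilities.
    raise NotImplementedError

# Restricting the backend to known device types avoids warnings about the
# backend being registered for devices it does not support.
dist.Backend.register_backend(
    "my_threaded", _create_threaded_pg, devices=["cpu", "cuda"]
)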
Pull Request resolved: pytorch#116426 Approved by: https://github.com/fduwjj
Stack from ghstack (oldest at bottom):
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @fduwjj @wz337 @tianyu-l @wconstab @yf225