nn.Module: use swap_tensors for Tensor subclasses (#122755) #123106
Conversation
This fixes a bug when casting a module that has DTensor parameters. The old behavior swapped only the `.data` field of the Tensor subclass, which is incorrect for tensor subclasses that may have multiple child tensors. This change uses the `swap_tensors` method to swap all of the tensors, not just the `.data` field.

Test plan:

```
pytest test/distributed/_tensor/test_api.py -k 'test_distribute_module_casting'
python test/distributed/fsdp/test_wrap.py -k test_auto_wrap_smoke_test_cuda_init_mode1_cpu_offload0_use_device_id_True
```

Pull Request resolved: #122755
Approved by: https://github.com/wanchaol, https://github.com/mikaylagawarecki
(cherry picked from commit e6ee832)
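To illustrate the bug being fixed, here is a minimal, torch-free sketch. The classes and the `__dict__` swap below are toy stand-ins, not PyTorch's actual `Tensor`/`DTensor` or the real `swap_tensors` implementation; they just show why reassigning only `.data` leaves a wrapper subclass's child state stale, while swapping the whole object updates everything.

```python
# Hypothetical sketch (plain Python, no torch): why swapping only `.data`
# breaks wrapper subclasses that hold more than one child tensor.

class PlainTensor:
    def __init__(self, values, device):
        self.data = list(values)
        self.device = device

class ShardedTensor(PlainTensor):
    """Toy stand-in for a DTensor-like subclass with extra child state."""
    def __init__(self, values, device):
        super().__init__(values, device)
        # State beyond `.data` that a dtype/device cast must also update.
        self.local_shard = PlainTensor(values[: len(values) // 2], device)
        self.spec = {"device": device}

def cast_via_data(param, converted):
    # Old behavior: only `.data` is replaced; child state keeps the old device.
    param.data = converted.data

def cast_via_swap(param, converted):
    # swap_tensors-style behavior: exchange the entire object state.
    param.__dict__, converted.__dict__ = converted.__dict__, param.__dict__

p = ShardedTensor([1, 2, 3, 4], device="cpu")
cast_via_data(p, ShardedTensor([1, 2, 3, 4], device="cuda"))
stale = p.local_shard.device   # still "cpu": the bug

q = ShardedTensor([1, 2, 3, 4], device="cpu")
cast_via_swap(q, ShardedTensor([1, 2, 3, 4], device="cuda"))
fixed = q.local_shard.device   # "cuda": all child state moved together
```

The point of the sketch is only the contrast between the two cast strategies; the real `swap_tensors` swaps the full tensor payload, not just a Python `__dict__`.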
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123106

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 unrelated failure) As of commit a49712f with merge base 86a2d67, the following job failed but was likely due to flakiness present on trunk (FLAKY).

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hey, I had initially wanted to gate this behavior behind the flag for 2.3, but if you view this as a critical fix for DTensor/traceable wrapper subclasses this sounds good to me!
Just want to note again the constraints for this path mentioned here, which might not be present for the regular `compute_should_use_set_data` path. Granted, `_apply` was broken for wrapper subclasses, so this is still an improvement to the state of the world.
@mikaylagawarecki Thanks for the context! Yeah, I would indeed say this is a critical fix for subclasses. For the constraints mentioned, I think these look fine to me! Wondering if you think we should submit a cherry-pick PR for #122800?
Sounds good! Yea, we could indeed cherry-pick #122800 so that this feature works with all nn.Modules; I will do this.
I'm not sure about this. This is not a bugfix in the sense that this never worked or had any chance of working. Also, as Mikayla mentioned, there is some risk associated with enabling this; the original plan was to keep it behind a flag for 2.3 to test it out without too much risk.
I think technically this is a bug for wrapper tensor subclasses, given what a user would hit before this enablement. Wondering if you think this makes sense or not :)
Fixes #ISSUE_NUMBER
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang