
nn.Module: use swap_tensors for Tensor subclasses (#122755) #123106

Merged
merged 1 commit into release/2.3 on Apr 2, 2024

Conversation

wanchaol
Contributor

@wanchaol wanchaol commented Apr 1, 2024

This fixes a bug when casting a module that has DTensor parameters. The old behavior swapped the `.data` field of the Tensor subclass, which is incorrect for tensor subclasses that may have multiple child tensors.

This uses the `swap_tensors` method to swap all of the tensors, not just the `.data` field.
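
For illustration, a minimal sketch of the difference, with plain tensors standing in for DTensor parameters and assuming a PyTorch build where `torch.utils.swap_tensors` is available:

```python
import torch

# Stand-ins for a module parameter and its converted counterpart
# (e.g. the result of .to(torch.bfloat16)).
old_param = torch.ones(4)
new_param = old_param.to(torch.bfloat16)

# A .data swap would only replace the outer tensor's data field; a wrapper
# subclass with child tensors would be left with its children uncast.
# swap_tensors exchanges the two Python tensor objects wholesale, so a
# wrapper subclass stays internally consistent after the cast.
torch.utils.swap_tensors(old_param, new_param)
print(old_param.dtype)  # torch.bfloat16
```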

Test plan:

pytest test/distributed/_tensor/test_api.py -k 'test_distribute_module_casting'
python test/distributed/fsdp/test_wrap.py -k test_auto_wrap_smoke_test_cuda_init_mode1_cpu_offload0_use_device_id_True

Pull Request resolved: #122755
Approved by: https://github.com/wanchaol, https://github.com/mikaylagawarecki

(cherry picked from commit e6ee832)

Fixes #ISSUE_NUMBER

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang

pytorch-bot bot commented Apr 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123106

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit a49712f with merge base 86a2d67:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed label (Add this issue/PR to distributed oncall triage queue) on Apr 1, 2024
@Skylion007 Skylion007 requested a review from malfet April 1, 2024 18:02
Contributor

@mikaylagawarecki mikaylagawarecki left a comment

Hey, I had initially wanted to gate this behavior behind the flag for 2.3, but if you view this as a critical fix for DTensor/traceable wrapper subclasses, this sounds good to me!

Just want to note again the constraints for this path mentioned here, which might not be present for the regular compute_should_use_set_data path.

Granted, _apply was broken for wrapper subclasses, so this is still an improvement to the state of the world.
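
For concreteness, the 2.3 opt-in flag being referred to would be used roughly like this (a minimal sketch assuming the `torch.__future__` API; the module here is just illustrative):

```python
import torch
import torch.nn as nn

# Opt in to the swap_tensors path for all module conversions
# (assumes torch.__future__.set_swap_module_params_on_conversion exists in 2.3).
torch.__future__.set_swap_module_params_on_conversion(True)

m = nn.Linear(4, 4)
m.to(torch.bfloat16)  # parameters are swapped rather than having .data set

print(torch.__future__.get_swap_module_params_on_conversion())  # True
```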

@wanchaol
Contributor Author

wanchaol commented Apr 1, 2024


@mikaylagawarecki Thanks for the context! Yeah, I would indeed say this is a critical fix for subclasses: Module._apply was broken for subclasses like DTensor/Float8Tensor without the swap_tensors feature, so this is at least an improvement for subclasses. Getting this fix into a stable release sooner would be nice since it has been shown to work for traceable subclasses :) Let me know if you have any concerns though!

For the constraints mentioned, these look fine to me! Wondering if you think we should submit a cherry-pick PR for #122800?

@mikaylagawarecki
Contributor

mikaylagawarecki commented Apr 1, 2024

Sounds good! Yeah, we could indeed cherry-pick #122800 so that this feature works with all nn.Modules; I will do this.

@albanD
Collaborator

albanD commented Apr 2, 2024

I'm not sure about this. This is not a bugfix in the sense that this never worked or had any chance to work before. Also, as Mikayla mentioned, there is some risk associated with enabling this; the original plan was to keep it behind a flag for 2.3 to test it out without too much risk.
I would personally prefer to keep the safe approach and not enable this by default.

@wanchaol
Contributor Author

wanchaol commented Apr 2, 2024


I think technically this is a bug for wrapper tensor subclasses. Before this enablement, if a user called module.to to move wrapper-subclass parameters to a different dtype/device, the behavior was always silently wrong. So the swap_tensors change definitely fixes the bug for users of wrapper subclasses (i.e. DTensor, Float8Tensor). Given that the PR only enables the swap_tensors path for wrapper tensor subclasses and not for any other paths, I think the risk is relatively low: it fixes the issue for wrapper tensor subclasses, while the feature can stay behind the flag for all other paths in 2.3.

Wondering if you think this makes sense or not :)
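
To spell out the gating I have in mind, a rough sketch (not the actual diff) of how the per-parameter decision could look, assuming `is_traceable_wrapper_subclass` from `torch.utils._python_dispatch` and the 2.3 `torch.__future__` flag:

```python
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def should_use_swap_tensors(param: torch.Tensor) -> bool:
    # Wrapper subclasses (DTensor, Float8Tensor, ...) must be swapped wholesale,
    # since a .data assignment would leave their child tensors uncast.
    if is_traceable_wrapper_subclass(param):
        return True
    # Every other parameter keeps the flag-gated behavior for 2.3.
    return torch.__future__.get_swap_module_params_on_conversion()
```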

@huydhn huydhn merged commit ef38d05 into release/2.3 Apr 2, 2024
97 of 98 checks passed
@github-actions github-actions bot deleted the subclass_swap_tensor_2.3 branch May 3, 2024 01:54