[dtensor][4/N] have row-wise sharding always use LocalShardsWrapper #122843
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/122843
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit cc5525a with merge base e70bf23.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Please address the previous PR comment as it might change how this PR is implemented, and re-request review once done :)
LGTM!
…ticipating ranks only (#122853)

**Summary**
We wrap DTensor's local tensor in `LocalShardsWrapper` for torchrec's table-wise sharding. The exception is non-participating ranks: there, the local tensor is an empty `torch.Tensor` object. This design avoids the complexity of supporting the empty-tensor case in `LocalShardsWrapper`.

**Test**
`torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/torchrec_sharding_example.py -e table-wise`

Pull Request resolved: #122853
Approved by: https://github.com/wz337
ghstack dependencies: #120265, #121392, #122843
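The rank-dependent wrapping described in that commit boils down to a small branch on participation. The sketch below is an illustration under assumptions (a hypothetical `maybe_wrap_local` helper, a `wrap` callable standing in for constructing a `LocalShardsWrapper`, and an already-initialized process group), not the actual torchrec/DTensor code.

```python
import torch
import torch.distributed as dist


def maybe_wrap_local(shard, participating_ranks, wrap):
    """Wrap the local shard on participating ranks; keep a plain empty tensor otherwise.

    Keeping an empty torch.Tensor on non-participating ranks avoids teaching
    the wrapper about the empty-shard case.
    """
    if dist.get_rank() in participating_ranks:
        return wrap(shard)  # e.g. wrap into a LocalShardsWrapper
    return torch.empty(0, dtype=shard.dtype, device=shard.device)
```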
…ytorch#122843)

**Summary**
Always wrap the local tensor in a `LocalShardsWrapper`. This keeps the representation uniform and makes it easier to adopt DTensor as a wrapper for the local shard(s). To support more tensor ops over `LocalShardsWrapper`, users need to extend its `__torch_dispatch__`.

**Test**
`torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/torchrec_sharding_example.py -e row-wise-even`

**Result**
```
Row-wise even sharding example in DTensor
         Col 0-15
-------  ----------
Row 0-1  cuda:0
Row 2-3  cuda:1
Row 4-5  cuda:2
Row 6-7  cuda:3
```

Pull Request resolved: pytorch#122843
Approved by: https://github.com/wz337
ghstack dependencies: pytorch#120265, pytorch#121392
Stack from ghstack (oldest at bottom):
**Summary**

Always wrap the local tensor in a `LocalShardsWrapper`. This keeps the representation uniform and makes it easier to adopt DTensor as a wrapper for the local shard(s). To support more tensor ops over `LocalShardsWrapper`, users need to extend its `__torch_dispatch__` (a rough sketch of that pattern follows the result below).

**Test**

`torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/torchrec_sharding_example.py -e row-wise-even`

**Result**

```
Row-wise even sharding example in DTensor
         Col 0-15
-------  ----------
Row 0-1  cuda:0
Row 2-3  cuda:1
Row 4-5  cuda:2
Row 6-7  cuda:3
```
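For context on the `__torch_dispatch__` extension point mentioned in the summary, the sketch below shows the general pattern with a hypothetical `ShardsWrapperSketch` class: each supported aten op gets a handler that applies the op shard by shard and rewraps the result. This is a simplified illustration under assumptions (single-shard metadata, only two ops handled), not the actual `LocalShardsWrapper` implementation.

```python
import torch


class ShardsWrapperSketch(torch.Tensor):
    """Hypothetical stand-in for LocalShardsWrapper holding a list of local shards."""

    @staticmethod
    def __new__(cls, shards):
        # Derive wrapper metadata from the first shard; a real implementation
        # would compute the overall size from the shard sizes and offsets.
        elem = shards[0]
        wrapper = torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            dtype=elem.dtype,
            device=elem.device,
            requires_grad=elem.requires_grad,
        )
        wrapper._shards = list(shards)
        return wrapper

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Each supported aten op is handled explicitly: apply it shard by shard,
        # then rewrap. Anything else raises, which is the signal that
        # __torch_dispatch__ needs to be extended for that op.
        if func is torch.ops.aten.detach.default:
            (self_,) = args
            return cls([s.detach() for s in self_._shards])
        if func is torch.ops.aten.clone.default:
            self_ = args[0]
            return cls([s.clone() for s in self_._shards])
        raise NotImplementedError(f"{func} is not supported by this sketch")
```

With this pattern, `ShardsWrapperSketch([torch.ones(2, 16)]).detach()` works, while an op without a handler fails loudly until one is added.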
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang