[DTensor] Add a private util for sharding tensor #142288
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/142288
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures — as of commit fb3b2a1 with merge base 61dc5e9.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Stamped to unblock.
torch/distributed/tensor/_api.py
Outdated
Locally shards a full tensor based on indicated sharding arrangement, and
returns a DTensor containing the local shard.
.. warning:: This is a private API purposed to skip the communication
Could we say: "This is a private API that is subject to change. It is ..."?
Thanks, added now.
    for cur_shape, cur_offset in zip(shape, offset)
]
local_tensor = full_tensor[slices]
return DTensor.from_local(local_tensor, device_mesh, placements)
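For context, here is a minimal self-contained sketch of the slicing-based helper shown in the diff excerpt above. The helper name `_shard_tensor_sketch` and the use of the private `compute_local_shape_and_global_offset` utility are assumptions for illustration; the merged code may differ.

```python
# Sketch only: assumes compute_local_shape_and_global_offset(global_shape, mesh,
# placements) returns this rank's local shape and its global offset per dim.
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset


def _shard_tensor_sketch(full_tensor, placements, device_mesh):
    # Figure out which region of the full tensor belongs to this rank.
    shape, offset = compute_local_shape_and_global_offset(
        full_tensor.shape, device_mesh, placements
    )
    # Slice the local shard directly out of the (replicated) full tensor,
    # avoiding the scatter/broadcast that distribute_tensor would perform.
    slices = tuple(
        slice(cur_offset, cur_offset + cur_shape)
        for cur_shape, cur_offset in zip(shape, offset)
    )
    local_tensor = full_tensor[slices]
    # Wrap the local shard; global shape/stride are inferred from it here.
    return DTensor.from_local(local_tensor, device_mesh, placements)
```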
You will need to pass the `shape` and `stride` to `from_local` for an uneven tensor. Otherwise, the shape and stride will be inferred from the local tensor as if it were evenly distributed. See example: https://github.com/pytorch/pytorch/blob/main/torch/distributed/_state_dict_utils.py#L566-L572
Thanks for the heads-up. I don't have the shape or stride info here, and this function assumes the full tensor is the same across ranks, so inferring them would be fine.
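To illustrate the reviewer's point (this is not the code merged in this PR): when the global tensor does not divide evenly across the mesh, the global `shape` and `stride` can be passed to `from_local` explicitly so they are not inferred from the local shard. The chunk-based split below is an assumption chosen to mirror `Shard(0)` semantics.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

# Assumes torchrun has initialized the default process group on 4 ranks,
# with one GPU per rank.
mesh = init_device_mesh("cuda", (4,))

# 10 rows over 4 ranks shards unevenly (chunk-style sizes: 3, 3, 3, 1).
full = torch.randn(10, 4, device="cuda")
local_tensor = full.chunk(4, dim=0)[mesh.get_rank()]

# Pass the global shape/stride so from_local does not infer them from the
# (smaller) local shard on the last rank.
dt = DTensor.from_local(
    local_tensor,
    mesh,
    [Shard(0)],
    shape=full.shape,
    stride=full.stride(),
)
```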
@pytorchbot merge -f "CI was green; minor edits to comments"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#142288. Approved by: https://github.com/wz337
Stack from ghstack (oldest at bottom):
Locally shards a full tensor based on the indicated sharding arrangement, and returns a DTensor containing the local shard.

warning: This is a private API intended to skip the communication otherwise required by `distribute_tensor`. It is only applicable when all ranks have the same `full_tensor`.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @tianyu-l @XilunWu
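A hypothetical usage sketch contrasting the private helper with the public `distribute_tensor`. The name `_shard_tensor` and its argument order are assumptions inferred from the PR title and the diff excerpt; check the merged `torch/distributed/tensor/_api.py` for the actual signature.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor
from torch.distributed.tensor._api import _shard_tensor  # name/location assumed

# Assumes torchrun has initialized the default process group on 4 ranks,
# with one GPU per rank.
mesh = init_device_mesh("cuda", (4,))

# Every rank builds the same full tensor, so no communication is needed to shard it.
full = torch.arange(16, dtype=torch.float32, device="cuda").reshape(4, 4)

# Public path: rank 0's data is scattered/broadcast across the mesh before sharding.
dt_public = distribute_tensor(full, mesh, [Shard(0)])

# Private helper: each rank just slices out its own shard locally.
dt_private = _shard_tensor(full, [Shard(0)], mesh)  # argument order assumed
```

When `full` is identical on every rank, both calls should yield the same sharded DTensor; the private path merely skips the collective.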