[dtensor] add op support for aten.gather.default #118513
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/118513
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 unrelated failure.) As of commit 80826b4 with merge base 0f7e636. FLAKY: the following job failed but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
The CI report says test_dtensor_op_db_take_along_dim_cpu_float32 failed. Another thing you can do in the future is to run the failing test locally.
Left 2 suggestions. I'm not very clear about the main body of gather_strategy and the use of _MaskPartial, though.
input_dt = distribute_tensor(global_input, device_mesh, [Replicate()])
index_dt = distribute_tensor(global_index, device_mesh, [Shard(gather_dim)])
global_output = torch.gather(global_input, gather_dim, global_index)
comm_mode = CommDebugMode()
I don't think we need to instantiate another comm_mode; we can reuse the one instantiated above. We can instantiate one comm_mode at the beginning of the test and reuse it everywhere in the test.
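A minimal sketch of this suggestion, continuing the snippet above (input_dt, index_dt, and gather_dim come from that snippet; the exact sub-cases and the assumption that CommDebugMode's counters reset on each __enter__ are mine, not confirmed by this PR):

import torch
from torch.distributed._tensor.debug import CommDebugMode

# One CommDebugMode for the whole test, reused across sub-cases.
comm_mode = CommDebugMode()

with comm_mode:
    output_dt = torch.gather(input_dt, gather_dim, index_dt)
first_counts = comm_mode.get_total_counts()  # collectives from the first case

with comm_mode:
    output_dt2 = torch.gather(input_dt, 0, index_dt)
second_counts = comm_mode.get_total_counts()  # counts for the second case only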
input_shape = input_strategy.output_shape
index_shape = index_strategy.output_shape
We also need to check whether the input_shape and index_shape are eligible for torch.gather: https://pytorch.org/docs/stable/generated/torch.gather.html
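For reference, a minimal sketch of such a check, derived from the documented torch.gather contract (the helper name _check_gather_shapes is hypothetical, not from this PR):

def _check_gather_shapes(input_shape, index_shape, dim):
    # torch.gather requires input and index to have the same ndim, and
    # index.size(d) <= input.size(d) for every dimension d != dim.
    if len(input_shape) != len(index_shape):
        raise ValueError("input and index must have the same number of dimensions")
    for d in range(len(input_shape)):
        if d != dim and index_shape[d] > input_shape[d]:
            raise ValueError(
                f"index.size({d})={index_shape[d]} exceeds input.size({d})={input_shape[d]}"
            )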
If we don't check in DTensor sharding prop, these errors would pop up from the local tensor ops, which still seems to be the expected behavior?
lgtm, some minor suggestions inlined.
# tensor dim can be equal or larger than the mask dim, respectively.
if tensor.ndim == self.mask_buffer.data.ndim:
    tensor[self.mask_buffer.data] = 0.0
else:
I think the main reason here is that the embedding output produces an additional dimension compared to the input, hence the output masking logic becomes different? Maybe add more clarification in the comment.
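To illustrate the point (the shapes here are made up for the example, not taken from the PR):

import torch
import torch.nn.functional as F

idx = torch.tensor([[1, 2], [0, 3]])          # index, shape (2, 2)
weight = torch.randn(10, 4)

# Embedding lookup appends an embedding dim: output ndim = index ndim + 1.
emb_out = F.embedding(idx, weight)            # shape (2, 2, 4)

# gather keeps the index's ndim: output has the same shape as index.
src = torch.arange(16.0).reshape(4, 4)
gather_idx = torch.zeros(4, 2, dtype=torch.long)
gather_out = torch.gather(src, 1, gather_idx)  # shape (4, 2)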
    tensor[self.mask_buffer.data, :] = 0.0
# NOTE: Depending on the use case (gather op or embedding op),
# tensor dim can be equal or larger than the mask dim, respectively.
if tensor.ndim == self.mask_buffer.data.ndim:
Given that we are reusing the logic, let's put this logic as a common method in MaskBuffer.
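A hedged sketch of the suggested refactor (the method name apply_mask and the exact class layout are assumptions; the merged code may differ):

import torch

class MaskBuffer:
    def __init__(self, data: torch.Tensor):
        self.data = data  # boolean mask tensor

    def apply_mask(self, tensor: torch.Tensor) -> None:
        # Depending on the use case (gather op or embedding op), the tensor's
        # ndim can be equal to the mask's ndim, or one larger (embedding
        # appends a trailing embedding dim), so the indexing differs.
        if tensor.ndim == self.data.ndim:
            tensor[self.data] = 0.0
        else:
            tensor[self.data, :] = 0.0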
This PR doesn't touch …
@tianyu-l check pytorch/torch/_refs/__init__.py, lines 4470 to 4504 at fb8ffba.
LGTM! Thx for the work!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #118513 Approved by: https://github.com/wanchaol, https://github.com/XilunWu
Stack from ghstack (oldest at bottom):
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @wconstab @yf225