
Conversation

@tianyu-l (Contributor) commented Jan 29, 2024

pytorch-bot (bot) commented Jan 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/118513

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 80826b4 with merge base 0f7e636:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions bot added the oncall: distributed and ciflow/inductor labels Jan 29, 2024
@tianyu-l requested a review from wanchaol January 29, 2024 08:27
cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 wconstab yf225

[ghstack-poisoned]
tianyu-l added a commit that referenced this pull request Jan 29, 2024
ghstack-source-id: 83c9536
Pull Request resolved: #118513
@tianyu-l added the ciflow/trunk and release notes: distributed (dtensor) labels Jan 29, 2024
@XilunWu (Contributor) commented Jan 29, 2024

The CI report says "test_dtensor_op_db_take_along_dim_cpu_float32 in test_dtensor_ops.py has unexpected success". This means your change has made it work correctly. You can remove "xfail(take_along_dim)" from the file, which will mark this test as a passing test instead of an expected failure.

Another thing you can do in the future is to run pytest test/distributed/_tensor/test_dtensor_ops.py locally to see whether the test results change with your PR.

@XilunWu (Contributor) left a comment

Left 2 suggestions. I'm not very clear on the main body of gather_strategy and the use of _MaskPartial, though.

```python
input_dt = distribute_tensor(global_input, device_mesh, [Replicate()])
index_dt = distribute_tensor(global_index, device_mesh, [Shard(gather_dim)])
global_output = torch.gather(global_input, gather_dim, global_index)
comm_mode = CommDebugMode()
```

I don't think we need to instantiate another comm_mode; we can reuse the one instantiated above. We can instantiate a single comm_mode at the beginning of the test and reuse it everywhere in the test.
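A minimal sketch of that refactor, reusing the names from the snippet above (the loop structure and the way the counts are inspected are illustrative assumptions, not the PR's actual test):

```python
import torch
from torch.distributed._tensor import Replicate, Shard, distribute_tensor
from torch.distributed._tensor.debug import CommDebugMode

comm_mode = CommDebugMode()  # instantiate once at the beginning of the test
for gather_dim in range(global_index.ndim):
    input_dt = distribute_tensor(global_input, device_mesh, [Replicate()])
    index_dt = distribute_tensor(global_index, device_mesh, [Shard(gather_dim)])
    with comm_mode:  # reuse the same instance for every case
        output_dt = torch.gather(input_dt, gather_dim, index_dt)
    # inspect the collectives triggered by this case
    print(comm_mode.get_comm_counts())
```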

Comment on lines +343 to +344
```python
input_shape = input_strategy.output_shape
index_shape = index_strategy.output_shape
```

We also need to check whether input_shape and index_shape are eligible for torch.gather:
https://pytorch.org/docs/stable/generated/torch.gather.html
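For illustration, a hedged sketch of the constraints from the torch.gather docs, written against the input_shape/index_shape names in the snippet above (the error type and wording are assumptions, not the PR's actual check):

```python
def check_gather_shapes(input_shape, index_shape, dim):
    # torch.gather requires input and index to have the same number of dimensions...
    if len(input_shape) != len(index_shape):
        raise RuntimeError("input and index must have the same number of dimensions")
    # ...and index.size(d) <= input.size(d) for every dimension d != dim
    for d, (in_size, idx_size) in enumerate(zip(input_shape, index_shape)):
        if d != dim and idx_size > in_size:
            raise RuntimeError(
                f"index size {idx_size} exceeds input size {in_size} at dim {d}"
            )
```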

@tianyu-l (Author):

If we don't check in DTensor sharding prop, these errors would pop up from the local tensor ops, which still seems to be the expected behavior?

@wanchaol (Collaborator) left a comment

lgtm, some minor suggestions inlined.

```python
# tensor dim can be equal or larger than the mask dim, respectively.
if tensor.ndim == self.mask_buffer.data.ndim:
    tensor[self.mask_buffer.data] = 0.0
else:
```

I think the main reason here is that the embedding output produces an additional dimension compared to the input, hence the output masking logic becomes different? Maybe add more clarification in the comment.
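For illustration, a small standalone sketch of the dimensionality difference being discussed (the shapes are made up, not taken from the PR):

```python
import torch

index = torch.tensor([[1, 3], [0, 2]])                  # a mask built from this has shape (2, 2)
weight = torch.randn(5, 8)
emb_out = torch.nn.functional.embedding(index, weight)  # shape (2, 2, 8): one extra dim vs. the mask
src = torch.randn(2, 4)
gather_out = torch.gather(src, 1, index)                 # shape (2, 2): same ndim as the mask
```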

```python
    tensor[self.mask_buffer.data, :] = 0.0
# NOTE: Depending on the use case (gather op or embedding op),
# tensor dim can be equal or larger than the mask dim, respectively.
if tensor.ndim == self.mask_buffer.data.ndim:
```

Given that we are reusing the logic, let's factor it out as a common method on MaskBuffer.
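A hedged sketch of what that refactor might look like (the method name apply_mask and the exact class layout are assumptions, not the PR's final code; the masking branches mirror the diff excerpts above):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MaskBuffer:
    data: Optional[torch.Tensor] = None

    def apply_mask(self, tensor: torch.Tensor) -> None:
        if self.data is None:
            raise RuntimeError("mask buffer has not been materialized")
        # NOTE: Depending on the use case (gather op or embedding op),
        # tensor dim can be equal or larger than the mask dim, respectively.
        if tensor.ndim == self.data.ndim:
            tensor[self.data] = 0.0
        else:
            tensor[self.data, :] = 0.0
```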

@tianyu-l (Author) commented

> The CI report says "test_dtensor_op_db_take_along_dim_cpu_float32 in test_dtensor_ops.py has unexpected success". This means your change has made it work correctly. You can remove "xfail(take_along_dim)" from the file, which will mark this test as a passing test instead of an expected failure.
>
> Another thing you can do in the future is to run pytest test/distributed/_tensor/test_dtensor_ops.py locally to see whether the test results change with your PR.

This PR doesn't touch take_along_dim. Not sure how it could make the test pass... Should I still remove it?

tianyu-l added a commit that referenced this pull request Feb 1, 2024
ghstack-source-id: f9fc955
Pull Request resolved: #118513
@tianyu-l requested a review from XilunWu February 1, 2024 02:17
@XilunWu (Contributor) commented Feb 1, 2024

> > The CI report says "test_dtensor_op_db_take_along_dim_cpu_float32 in test_dtensor_ops.py has unexpected success". This means your change has made it work correctly. You can remove "xfail(take_along_dim)" from the file, which will mark this test as a passing test instead of an expected failure.
> > Another thing you can do in the future is to run pytest test/distributed/_tensor/test_dtensor_ops.py locally to see whether the test results change with your PR.
>
> This PR doesn't touch take_along_dim. Not sure how it could make the test pass... Should I still remove it?

@tianyu-l Check the reference decomposition of take_along_dim below; it lowers to torch.gather, which is why your gather change makes that test pass:

```python
@out_wrapper()
def take_along_dim(
    a: torch.Tensor, indices: torch.Tensor, dim: Optional[int] = None
) -> torch.Tensor:
    torch._check(
        a.ndim == indices.ndim,
        lambda: (
            "torch.take_along_dim(): input and indices should have the same "
            f"number of dimensions, but got {a.ndim} dimensions for input, and "
            f"{indices.ndim} dimensions for indices"
        ),
    )
    torch._check(
        utils.is_integer_dtype(indices.dtype),
        lambda: (
            "torch.take_along_dim(): dtype of indices should be int but got "
            f"{indices.dtype} instead"
        ),
    )
    if dim is None:
        return torch.gather(a.view(-1), 0, indices.view(-1))
    else:
        self_sizes = list(a.shape)
        self_sizes[dim] = indices.size(dim)
        broadcast_shape = utils.infer_size_shapes(self_sizes, indices.size())
        indices_broadcast = broadcast_to(indices, broadcast_shape)
        indices_sizes = list(indices.shape)
        indices_sizes[dim] = a.size(dim)
        broadcast_shape = utils.infer_size_shapes(indices_sizes, a.size())
        self_broadcast = broadcast_to(a, broadcast_shape)
        return torch.gather(self_broadcast, dim, indices_broadcast)
```
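(Not part of the reference above; just an illustrative check that take_along_dim reduces to gather in a simple case, with made-up tensors:)

```python
import torch

a = torch.arange(12).reshape(3, 4)
idx = torch.tensor([[0], [2], [1]])
# take_along_dim broadcasts its arguments and then dispatches to gather
assert torch.equal(torch.take_along_dim(a, idx, dim=1), torch.gather(a, 1, idx))
```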

@XilunWu (Contributor) left a comment

LGTM! Thx for the work!

@tianyu-l (Author) commented Feb 2, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot deleted the gh/tianyu-l/3/head branch February 5, 2024 15:23
