Conversation

@mayank31398 mayank31398 (Contributor) commented Jul 11, 2024

Fixes #130549

This PR uses the specific dtype of the grad_input buffer (rather than a hard-coded float32 cast) when computing the gradient update, and fixes the error

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

pytorch-bot bot commented Jul 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130550

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f3eb4ef with merge base 72d9135:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed label Jul 11, 2024
From the review thread on torch/distributed/_tensor/ops/embedding_ops.py:

    # only update grad_input to -1 if not masked
    assert partial_placement.mask_buffer.data is not None
-   grad_update = partial_placement.mask_buffer.data.float() - 1.0
+   grad_update = partial_placement.mask_buffer.data.to(grad_input.dtype) - 1.0
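
For context, a minimal sketch of the dtype mismatch that the one-line change above avoids (hypothetical shapes and names, not code from the PR's test suite): with BF16 logits, the old hard-coded .float() cast produced a float32 grad_update that no longer matched grad_input's dtype.

```python
import torch

# Under loss parallel with BF16 logits, grad_input arrives in bfloat16,
# while the mask buffer holds a boolean tensor (shapes are made up).
grad_input = torch.zeros(4, 8, dtype=torch.bfloat16)
mask = torch.rand(4, 8) > 0.5

# Before the fix: the cast is hard-coded to float32, so the dtypes diverge.
grad_update = mask.float() - 1.0
print(grad_update.dtype)  # torch.float32, not torch.bfloat16

# After the fix: cast to grad_input's own dtype, so they always agree.
grad_update = mask.to(grad_input.dtype) - 1.0
print(grad_update.dtype)  # torch.bfloat16
```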
Collaborator

@tianyu-l I think we should avoid using .data here since its intent can be ambiguous. If you want it to not be part of the autograd graph, then you should use .detach().
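
To illustrate the distinction being raised here (a standalone example, not code from this PR): .detach() shares storage with the original tensor but keeps autograd's version-counter checks, whereas in-place edits through .data go unnoticed and can silently corrupt gradients.

```python
import torch

# In-place edit through .detach(): the version counter records it,
# so backward() fails loudly instead of using clobbered values.
x = torch.ones(3, requires_grad=True)
y = x.exp()         # exp() saves its output for the backward pass
y.detach().zero_()
try:
    y.sum().backward()
except RuntimeError as err:
    print(f"caught: {err}")

# The same edit through .data is invisible to autograd: backward()
# succeeds but silently computes gradients from the zeroed values.
x = torch.ones(3, requires_grad=True)
y = x.exp()
y.data.zero_()
y.sum().backward()
print(x.grad)       # tensor([0., 0., 0.]) instead of exp(1)
```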

Contributor Author

I agree that detach is a safer option, but I'll defer to @tianyu-l.

@tianyu-l tianyu-l (Contributor) Jul 11, 2024

> @tianyu-l I think we should avoid using .data here since its intent can be ambiguous. If you want it to not be part of the autograd graph, then you should use .detach().

Hmm, the .data here is just a class variable (of type torch.Tensor) of the MaskBuffer class, so it shouldn't be a concern?

https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/embedding_ops.py#L28
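
For readers without the source handy, a condensed paraphrase of the class in question (abbreviated; see the link above for the actual definition):

```python
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class MaskBuffer:
    # A plain torch.Tensor attribute that happens to be named "data";
    # it is unrelated to the special Tensor.data accessor debated above.
    data: Optional[torch.Tensor] = None

    def materialize_mask(self, mask: torch.Tensor) -> None:
        if self.data is not None:
            raise RuntimeError("MaskBuffer has already been materialized")
        self.data = mask

    def release_mask(self) -> None:
        if self.data is None:
            raise RuntimeError("MaskBuffer has not been materialized")
        self.data = None
```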

Collaborator

Yeah, the MaskBuffer is not a torch.Tensor and data is just coincidentally a field name of that custom object. We can probably rename that field later to avoid confusion with tensor.data.

Collaborator

My bad!

@awgu awgu requested a review from tianyu-l July 11, 2024 17:10
@tianyu-l tianyu-l (Contributor) left a comment

lgtm! Thanks a lot for fixing the bug!

@mayank31398 mayank31398 (Contributor Author)

All tests passed. Let's merge this, @awgu @tianyu-l.

@awgu awgu (Collaborator) commented Jul 12, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Jul 12, 2024
@pytorchmergebot (Collaborator)

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team: raised by workflow job.

@awgu awgu added the release notes: distributed (dtensor) label Jul 12, 2024
@awgu awgu (Collaborator) commented Jul 12, 2024

I did not see a TP-specific label, so I used the DTensor one.

@awgu awgu (Collaborator) commented Jul 12, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@mayank31398 mayank31398 deleted the fix-tp-loss branch July 12, 2024 16:51
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
Fixes pytorch#130549

This PR uses the specific dtype of the `grad_input` buffer (rather than a hard-coded float32 cast) when computing the gradient update, and fixes the error

Pull Request resolved: pytorch#130550
Approved by: https://github.com/tianyu-l
Labels

ciflow/trunk · Merged · oncall: distributed · open source · release notes: distributed (dtensor)

Development

Successfully merging this pull request may close these issues: Bug with loss_parallel when BF16 logits are passed (#130549)

6 participants