[MPS] Fix embedding backward with scalar index #82809
Conversation
🔗 Helpful links
❌ 8 New Failures, 9 Pending as of commit 2410db0 (more details on the Dr. CI page).
🕵️ 8 new failures recognized by patterns. The following CI failures do not appear to be due to upstream breakage:
trunk / win-vs2019-cuda11.6-py3 / test (default, 3, 5, windows.8xlarge.nvidia.gpu) (1/8)
@pytorchbot label "ciflow/trunk" "module: mps"

@kulinseth @albanD can you please take a look?

/easycla As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details. This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.

/easycla

@pytorchbot rebase

You don't have permissions to rebase this PR; only people with write permissions may rebase PRs.
Force-pushed from 2410db0 to ade9dfc
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/82809
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 8da9071.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Cancelled rebase, looks good to me

Force-pushed from ade9dfc to 8da9071
@pytorchbot merge - f "MPS tests are green"

❌ 🤖 pytorchbot command failed; the space in `- f` makes the flag invalid.

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
### Description

Previously, the embedding backward pass always expanded a `-1` dim onto the indices, resulting in the following error when the indices tensor is a scalar:

```
error: Rank of data array must equal number of outer dimensions in indices array + rank of slice to update, 2 != 1 + 0
-:8:10: note: see current operation: %5 = "mps.scatter_nd"(%0, %arg1, %4) {batch_dims = 0 : ui32, mode = 0 : i32} : (tensor<10x5xf16>,
```

The expansion is now conditional. Reproducer:

```python
import torch

def repro():
    w = torch.tensor([[-2.6465,  2.5859,  0.4688,  1.7949,  3.2676],
                      [-3.1641,  8.9375,  5.7578, -2.9453, -6.5469],
                      [ 2.0469,  1.3516, -8.7344,  6.0000,  1.3906],
                      [ 6.5781,  7.8438,  6.9766,  3.2891, -5.1172],
                      [-7.9414,  7.7344,  4.1875,  2.8574,  2.9531],
                      [-0.4844, -5.6328, -6.8359, -4.5156,  3.7891],
                      [ 4.9375,  6.6094,  6.7031,  0.6719, -6.4219],
                      [ 7.0469,  8.2031,  4.4453,  1.7129, -2.4688],
                      [ 1.2207, -3.3750, -2.4531,  7.4062, -6.0469],
                      [-8.9688,  2.2656,  2.4160, -1.0176,  8.4531]],
                     dtype=torch.float32, requires_grad=True)
    x = torch.tensor(5)  # 0-dim (scalar) index tensor
    out = torch.nn.functional.embedding(x, w)
    out.sum().backward()

    w_mps = w.detach().clone().to("mps").requires_grad_()
    x_mps = x.to("mps")
    out = torch.nn.functional.embedding(x_mps, w_mps)
    out.sum().backward()  # error
```

Pull Request resolved: pytorch#82809
Approved by: https://github.com/malfet
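The actual fix lives in the MPS backend code, but the shape logic can be sketched in Python. The following is a minimal illustration, not the backend implementation, and the helper name `prepare_indices_for_scatter` is hypothetical. It shows the now-conditional expansion: the trailing size-1 dimension is appended only when the indices tensor has nonzero rank, while a 0-dim (scalar) index is lifted to 1-D so its rank matches what the scatter expects.

```python
import torch

def prepare_indices_for_scatter(indices: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch of the PR's idea, not the actual MPS backend code.
    # Before the fix, the trailing size-1 dim was appended unconditionally,
    # which produced the rank mismatch in mps.scatter_nd for 0-dim indices.
    if indices.dim() == 0:
        return indices.reshape(1)   # scalar -> 1-D; skip the extra dim
    return indices.unsqueeze(-1)    # N-D -> (N+1)-D with a trailing size-1 dim

print(prepare_indices_for_scatter(torch.tensor(5)).shape)       # torch.Size([1])
print(prepare_indices_for_scatter(torch.tensor([3, 7])).shape)  # torch.Size([2, 1])
```

With rank-matched indices, the scatter that accumulates gradients into the weight matrix sees consistent shapes for both scalar and batched lookups.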
cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev