-
Couldn't load subscription status.
- Fork 25.7k
[mxfp8 torch._scaled_grouped_mm] fix meta registration for 3d tensor #162765
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162765
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 1f78541 with merge base 36338fc ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…ytorch#162765) Meta registration checks for torch._scaled_grouped_mm has a bug for 3d "B" tensors. Namely, the scale shape for such a tensor should be 2d with shape (G, blocked_K * blocked_N), but it currently enforces an expected 3d shape of (G, blocked_K, blocked_N). See Blas.cpp for correct validation logic [here](https://github.com/pytorch/pytorch/blob/8e217a9f6dc81e3d12697b04c3e611d82d9d866a/aten/src/ATen/native/cuda/Blas.cpp#L1622). Pull Request resolved: pytorch#162765 Approved by: https://github.com/ngimel
…ytorch#162765) Meta registration checks for torch._scaled_grouped_mm has a bug for 3d "B" tensors. Namely, the scale shape for such a tensor should be 2d with shape (G, blocked_K * blocked_N), but it currently enforces an expected 3d shape of (G, blocked_K, blocked_N). See Blas.cpp for correct validation logic [here](https://github.com/pytorch/pytorch/blob/8e217a9f6dc81e3d12697b04c3e611d82d9d866a/aten/src/ATen/native/cuda/Blas.cpp#L1622). Pull Request resolved: pytorch#162765 Approved by: https://github.com/ngimel
…ytorch#162765) Meta registration checks for torch._scaled_grouped_mm has a bug for 3d "B" tensors. Namely, the scale shape for such a tensor should be 2d with shape (G, blocked_K * blocked_N), but it currently enforces an expected 3d shape of (G, blocked_K, blocked_N). See Blas.cpp for correct validation logic [here](https://github.com/pytorch/pytorch/blob/8e217a9f6dc81e3d12697b04c3e611d82d9d866a/aten/src/ATen/native/cuda/Blas.cpp#L1622). Pull Request resolved: pytorch#162765 Approved by: https://github.com/ngimel
…ytorch#162765) Meta registration checks for torch._scaled_grouped_mm has a bug for 3d "B" tensors. Namely, the scale shape for such a tensor should be 2d with shape (G, blocked_K * blocked_N), but it currently enforces an expected 3d shape of (G, blocked_K, blocked_N). See Blas.cpp for correct validation logic [here](https://github.com/pytorch/pytorch/blob/8e217a9f6dc81e3d12697b04c3e611d82d9d866a/aten/src/ATen/native/cuda/Blas.cpp#L1622). Pull Request resolved: pytorch#162765 Approved by: https://github.com/ngimel
Meta registration checks for torch._scaled_grouped_mm has a bug for 3d "B" tensors. Namely, the scale shape for such a tensor should be 2d with shape (G, blocked_K * blocked_N), but it currently enforces an expected 3d shape of (G, blocked_K, blocked_N).
See Blas.cpp for correct validation logic here.