Port bmm and baddbmm from TH to ATen #42553
Conversation
[ghstack-poisoned]
ghstack-source-id: b737e9836ecfa12fcf8e6fd7d29421ca16eb524b Pull Request resolved: #42553
💊 CI failures summary (as of commit 909b366): ci.pytorch.org: 1 failed. Details on the Dr. CI page.
Not finished yet; will post more comments later.
Ports `torch.bmm` and `torch.baddbmm` from TH to ATen, as well as adds support for complex dtypes. Also removes dead TH code for Level 2 functions. [ghstack-poisoned]
ghstack-source-id: f45d4e1d89b776d52fa788d29e2f24a05d143ba4 Pull Request resolved: #42553
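For readers unfamiliar with the two ops being ported: `torch.bmm` performs a per-batch matrix multiply, and `torch.baddbmm` adds a scaled batched matmul to a scaled input. Here is a minimal NumPy sketch of those semantics (my own illustration, not the PR's code; the `bmm`/`baddbmm` helpers below are hypothetical names mirroring the Torch ops):

```python
import numpy as np

rng = np.random.default_rng(0)
b1 = rng.standard_normal((4, 2, 3))   # batch of 4 matrices, each 2x3
b2 = rng.standard_normal((4, 3, 5))   # batch of 4 matrices, each 3x5
inp = rng.standard_normal((4, 2, 5))  # same shape as the bmm result

def bmm(batch1, batch2):
    # Per-batch matrix multiply: out[i] = batch1[i] @ batch2[i]
    return np.einsum('bij,bjk->bik', batch1, batch2)

def baddbmm(input_, batch1, batch2, beta=1.0, alpha=1.0):
    # out = beta * input + alpha * (batch1 @ batch2), batched over dim 0
    return beta * input_ + alpha * bmm(batch1, batch2)

out = baddbmm(inp, b1, b2, beta=0.5, alpha=2.0)
```

Since the port adds complex dtype support, the same identities hold when the arrays above are complex.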
Per @gchanan's request, ports from TH to ATen should also beef up test coverage (in particular, various discontiguity patterns on input/output, and proper runtime errors for arguments on different devices).
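The discontiguity coverage requested above boils down to: run the op on strided (non-contiguous) views and check the result matches the contiguous baseline. A sketch of that pattern using NumPy (an illustration of the idea, not the PR's actual test code):

```python
import numpy as np

rng = np.random.default_rng(1)
big = rng.standard_normal((3, 8, 4))
b = rng.standard_normal((3, 4, 5))

# Slicing with a step produces a non-contiguous view with the same
# logical values; the op must handle its strides correctly.
a_noncontig = big[:, ::2, :]              # shape (3, 4, 4), strided
a_contig = np.ascontiguousarray(a_noncontig)

out_noncontig = np.matmul(a_noncontig, b)
out_contig = np.matmul(a_contig, b)
```

The device-mismatch half of the request (CPU tensor passed with a CUDA tensor must raise a clear runtime error) follows the same shape: feed the bad combination and assert the expected exception.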
@anjali411 Could you please rebase? Looks like there are lots of flaky tests.
Ports `torch.bmm` and `torch.baddbmm` from TH to ATen, as well as adds support for complex dtypes. Also removes dead TH code for Level 2 functions. Closes #24539 [ghstack-poisoned]
test/test_torch.py (Outdated)
@@ -17938,24 +17938,16 @@ def test_strided_mm_bmm(self, device, dtype):

@skipCUDAIf(torch.version.cuda == "10.1", "flaky on CUDA 10.1")
@onlyOnCPUAndCUDA
@dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes())
@dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=AMPERE_OR_ROCM) +
Please don't do so. We test all dtypes on purpose, to make sure every dtype is exercised: if a dtype is supported, it should run correctly; if it is not supported, it should raise an error.
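The policy stated above can be sketched as a simple harness: iterate over every dtype and require that the op either succeeds or raises, never gets silently skipped. A hedged illustration with NumPy's matmul standing in for the op under test (`check_op_over_dtypes` is a hypothetical helper, not PyTorch test infrastructure):

```python
import numpy as np

def check_op_over_dtypes(op, dtypes):
    """Run `op` on batched inputs of each dtype; record success or error."""
    results = {}
    for dt in dtypes:
        a = np.ones((2, 3, 4), dtype=dt)
        b = np.ones((2, 4, 5), dtype=dt)
        try:
            out = op(a, b)
            results[np.dtype(dt).name] = ('ok', out.dtype.name)
        except TypeError as exc:
            # Unsupported dtype: an explicit error is acceptable,
            # a silent skip is not.
            results[np.dtype(dt).name] = ('raised', type(exc).__name__)
    return results

dtypes = [np.float16, np.float32, np.float64, np.complex64, np.complex128]
results = check_op_over_dtypes(np.matmul, dtypes)
```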
cc: @ngimel
Synced offline: the CPU bmm and baddbmm have multiple code paths, some of which support bfloat16 and float16 and some of which don't. So depending on the input, half and bfloat16 may or may not be supported. https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/LinearAlgebra.cpp#L498
So @zasdfgbnm, @ngimel and I agreed to add full support for `torch.float16` and `torch.bfloat16` in a follow-up PR and leave this one as is.
Ports `torch.bmm` and `torch.baddbmm` from TH to ATen, as well as adds support for complex dtypes. Also removes dead TH code for Level 2 functions. Closes #24539 Differential Revision: [D24893511](https://our.internmc.facebook.com/intern/diff/D24893511) [ghstack-poisoned]
ghstack-source-id: 0bbc3aae1e6103e5ffcd3516ef3c72a5f35d006a Pull Request resolved: #42553
LGTM! Thanks for working on this!
Codecov Report
@@ Coverage Diff @@
## gh/anjali411/46/base #42553 +/- ##
=====================================================
Coverage 81.22% 81.22%
=====================================================
Files 1837 1837
Lines 198087 198087
=====================================================
+ Hits 160893 160897 +4
+ Misses     37194    37190    -4
@anjali411 merged this pull request in e1ee3bf.
Summary: Now that #42553 is merged, we can delete a bit of code from the tests and enable some of the skipped complex tests. Unfortunately, `test_pinverse_complex_xfailed` and `test_symeig_complex_xfailed` had bugs, and it wasn't caught automatically that these tests xpass. We need to be careful with `unittest.expectedFailure` next time. Pull Request resolved: #47910 Reviewed By: zhangguanheng66 Differential Revision: D25052130 Pulled By: mruberry fbshipit-source-id: 29512995c024b882f9cb78b7bede77733d5762d0
@@ -133,6 +133,56 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) {

/* LEVEL 3 BLAS FUNCTIONS */

#ifndef __HIP_PLATFORM_HCC__
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200
Is this `CUDA_VERSION >= 11200` check intended? If you mean CUDA 11.2, it should be 11020. I'm not sure CUDA 11.2 was even a thing back in November 2020. 😅
No harm done, workaround is good.
@xwang233 No, my bad! We should fix that to avoid confusion in the future.
Stack from ghstack:

Ports `torch.bmm` and `torch.baddbmm` from TH to ATen, as well as adds support for complex dtypes. Also removes dead TH code for Level 2 functions. Closes #24539

Differential Revision: D24893511