
Port bmm and baddbmm from TH to ATen #42553

Status: Closed · 30 commits

Conversation

@anjali411 (Contributor) commented Aug 4, 2020

Stack from ghstack:

Ports torch.bmm and torch.baddbmm from TH to ATen and adds support for complex dtypes. Also removes dead TH code for Level 2 BLAS functions.

Closes #24539

Differential Revision: D24893511
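
A minimal sketch of the two ported ops, for context; the shapes and the complex dtype here are illustrative, not taken from the PR's tests:

```python
import torch

# torch.bmm: batched matrix multiply of two 3-D tensors.
# With this PR, complex dtypes are supported as well.
a = torch.randn(4, 3, 5, dtype=torch.complex64)
b = torch.randn(4, 5, 2, dtype=torch.complex64)
c = torch.bmm(a, b)                      # shape: (4, 3, 2)

# torch.baddbmm: out = beta * input + alpha * (batch1 @ batch2)
m = torch.randn(4, 3, 2, dtype=torch.complex64)
out = torch.baddbmm(m, a, b, beta=0.5, alpha=2.0)
print(c.shape, out.shape)                # torch.Size([4, 3, 2]) twice
```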

anjali411 added a commit that referenced this pull request Aug 4, 2020
ghstack-source-id: b737e9836ecfa12fcf8e6fd7d29421ca16eb524b
Pull Request resolved: #42553
@dr-ci bot commented Aug 4, 2020

💊 CI failures summary and remediations

As of commit 909b366 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

ci.pytorch.org: 1 failed



@anjali411 anjali411 changed the title Port bmm and baddbmm from TH to ATen [WIP] Port bmm and baddbmm from TH to ATen Aug 4, 2020
anjali411 added a commit that referenced this pull request Aug 5, 2020
ghstack-source-id: 3eba02aebf71537dd1af8841ec2f4cb2f94bf279
Pull Request resolved: #42553
Two inline review threads on aten/src/THC/THCBlas.cu (outdated; resolved).
anjali411 added a commit that referenced this pull request Aug 5, 2020
ghstack-source-id: 047dffe5ba8d7ab12d820bced1072315ed8f3b04
Pull Request resolved: #42553
@anjali411 anjali411 changed the title [WIP] Port bmm and baddbmm from TH to ATen Port bmm and baddbmm from TH to ATen Aug 6, 2020
@zasdfgbnm (Collaborator) left a comment

Not finished yet. Will post more comments later.

Inline review threads on aten/src/ATen/cuda/CUDABlas.h, aten/src/ATen/cuda/CUDABlas.cpp, and aten/src/ATen/native/cuda/LinearAlgebra.cu (mostly outdated; resolved).
anjali411 added a commit that referenced this pull request Aug 7, 2020
ghstack-source-id: f45d4e1d89b776d52fa788d29e2f24a05d143ba4
Pull Request resolved: #42553
anjali411 added a commit that referenced this pull request Aug 12, 2020
ghstack-source-id: 61c04d480b16d92ad66fe7246773b48f5f5302fc
Pull Request resolved: #42553
Further inline review threads on aten/src/ATen/cuda/CUDABlas.h and aten/src/ATen/native/cuda/LinearAlgebra.cu (outdated; resolved).
anjali411 added a commit that referenced this pull request Aug 13, 2020
ghstack-source-id: b43d308cf89ef6b5f42c6adc98118f9facabeb83
Pull Request resolved: #42553
@ngimel (Collaborator) commented Aug 13, 2020
Per @gchanan's request, ports from TH to ATen should also beef up test coverage: in particular, various discontiguity patterns on input/output, and proper runtime errors for arguments on different devices.
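
A minimal sketch of that kind of coverage; the structure and names are illustrative, not the tests this PR actually added:

```python
import torch

def check_bmm_coverage():
    # Discontiguous inputs: transposing the batch matrices makes them non-contiguous.
    a = torch.randn(4, 5, 3).transpose(1, 2)     # shape (4, 3, 5), non-contiguous
    b = torch.randn(4, 5, 2)
    assert not a.is_contiguous()
    res = torch.bmm(a, b)
    ref = torch.bmm(a.contiguous(), b)
    assert torch.allclose(res, ref)

    # Arguments on different devices should raise a clear runtime error.
    if torch.cuda.is_available():
        try:
            torch.bmm(a.cuda(), b)               # CUDA batch1, CPU batch2
        except RuntimeError as err:
            print("raised as expected:", err)

check_bmm_coverage()
```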

@zasdfgbnm (Collaborator) commented

@anjali411 Could you please rebase? Looks like there are lots of flaky tests.

anjali411 added a commit that referenced this pull request Nov 11, 2020
ghstack-source-id: 2a2540ae6445b3f513e83a5a355d6099532621de
Pull Request resolved: #42553
Inline review thread on test/test_torch.py (outdated; resolved).
anjali411 added a commit that referenced this pull request Nov 11, 2020
ghstack-source-id: 3469afd70f10593355aa5e7f85fb983c47ea5714
Pull Request resolved: #42553
Inline review thread on test/test_torch.py (outdated; resolved).
@@ -17938,24 +17938,16 @@ def test_strided_mm_bmm(self, device, dtype):

@skipCUDAIf(torch.version.cuda == "10.1", "flaky on CUDA 10.1")
@onlyOnCPUAndCUDA
@dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes())
@dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=AMPERE_OR_ROCM) +
A collaborator commented:

Please don't do this. We deliberately test all dtypes to make sure every dtype is exercised: if a dtype is supported, it should run correctly; if it is not supported, it should raise an error.

Collaborator:

cc: @ngimel

The author replied:
Synced offline: the CPU bmm and baddbmm implementations have multiple code paths; some support bfloat16 and float16 and some don't, so depending on the input, half and bfloat16 may or may not be supported. https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/LinearAlgebra.cpp#L498

So @zasdfgbnm, @ngimel, and I agreed to add full support for torch.float16 and torch.bfloat16 in a follow-up PR and leave this one as is.
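
To see which dtypes a given build actually accepts, a quick probe like the following works; the output depends on the PyTorch version and on which code path is taken, per the discussion above:

```python
import torch

# Probe CPU bmm support for the dtypes discussed above.
for dt in (torch.float16, torch.bfloat16, torch.complex64):
    a = torch.randn(2, 3, 4).to(dt)
    b = torch.randn(2, 4, 5).to(dt)
    try:
        torch.bmm(a, b)
        print(f"{dt}: supported on CPU")
    except RuntimeError as err:
        print(f"{dt}: not supported ({err})")
```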

anjali411 added a commit that referenced this pull request Nov 11, 2020
ghstack-source-id: 0bbc3aae1e6103e5ffcd3516ef3c72a5f35d006a
Pull Request resolved: #42553
Inline review threads on test/test_torch.py (outdated; resolved).
anjali411 added a commit that referenced this pull request Nov 12, 2020
ghstack-source-id: d7ff7cdb6cf57f4f199d98b600e71d41999a17d6
Pull Request resolved: #42553
anjali411 added a commit that referenced this pull request Nov 12, 2020
ghstack-source-id: c52815b490a9d924af2ec40660c8f20ff1e188c8
Pull Request resolved: #42553
@zasdfgbnm (Collaborator) left a comment
LGTM! Thanks for working on this!

@codecov bot commented Nov 12, 2020

Codecov Report

Merging #42553 (909b366) into gh/anjali411/46/base (4738672) will increase coverage by 0.00%.
The diff coverage is 0.00%.

@@                  Coverage Diff                  @@
##           gh/anjali411/46/base   #42553   +/-   ##
=====================================================
  Coverage                 81.22%   81.22%           
=====================================================
  Files                      1837     1837           
  Lines                    198087   198087           
=====================================================
+ Hits                     160893   160897    +4     
+ Misses                    37194    37190    -4     

@facebook-github-bot (Contributor) commented

@anjali411 merged this pull request in e1ee3bf.

@facebook-github-bot facebook-github-bot deleted the gh/anjali411/46/head branch November 16, 2020 15:17
facebook-github-bot pushed a commit that referenced this pull request Nov 18, 2020
Summary:
Now that #42553 is merged, we can delete a bit of code from the tests and enable some of the previously skipped complex tests.

Unfortunately, `test_pinverse_complex_xfailed` and `test_symeig_complex_xfailed` had bugs, and it wasn't caught automatically that these tests now xpass. We need to be careful with `unittest.expectedFailure` next time.

Pull Request resolved: #47910

Reviewed By: zhangguanheng66

Differential Revision: D25052130

Pulled By: mruberry

fbshipit-source-id: 29512995c024b882f9cb78b7bede77733d5762d0
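
On the xpass point above: with `unittest.expectedFailure`, a test that starts passing is recorded as an "unexpected success", and whether that fails the suite depends on the runner and its configuration. A sketch:

```python
import unittest

class Demo(unittest.TestCase):
    @unittest.expectedFailure
    def test_now_fixed(self):
        # This assertion passes, so the test is recorded as an
        # "unexpected success" (xpass) rather than a failure.
        self.assertEqual(1 + 1, 2)

if __name__ == "__main__":
    # Depending on the runner/configuration, an unexpected success
    # may or may not fail the run, which is how an xpass can slip by.
    unittest.main()
```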
@@ -133,6 +133,56 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) {

/* LEVEL 3 BLAS FUNCTIONS */

#ifndef __HIP_PLATFORM_HCC__
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200
@xwang233 (Collaborator) commented:

Is this macro CUDA_VERSION >= 11200 intended? If you mean CUDA 11.2, it should be 11020. I'm not sure CUDA 11.2 was even a thing back in November 2020. 😅

Collaborator:

No harm done; the workaround is good.

The author replied:

@xwang233 No, my bad! We should fix that to avoid confusion in the future.
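
For reference, CUDA_VERSION encodes major * 1000 + minor * 10, so the guard as written would only match a hypothetical CUDA 112.0. The intended check for CUDA 11.2 is presumably:

```cpp
// CUDA_VERSION = major * 1000 + minor * 10, e.g. CUDA 11.2 -> 11020.
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11020
// ... code path gated on CUDA 11.2 and later ...
#endif
```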
