Enable BFloat support for gemms on arch other than ampere #50442
Conversation
💊 CI failures summary and remediations

As of commit c33a608 (more details on the Dr. CI page):

🕵️ 1 new failure recognized by patterns

The following CI failure does not appear to be due to upstream breakages:
pytorch_windows_vs2019_py36_cuda11.1_test2 (1/1), Step: "Test"
…torch into ci-all/matmul-bf16-non-ampere
…l-bf16-non-ampere
This should be ready. Test failures are unrelated.
} else {
  TORCH_CHECK(false, "BFloat16 gemm in CUDA requires Ampere or later GPU");
}
TORCH_CUDABLAS_CHECK(cublasGemmEx(
Setting and resetting the cuBLAS math mode is not required if you specify CUBLAS_GEMM_DFALT_TENSOR_OP?
According to https://docs.nvidia.com/cuda/cublas/index.html#cublasmath_t:

CUBLAS_DEFAULT_MATH: This is the default and highest-performance mode that uses compute and intermediate storage precisions with at least the same number of mantissa and exponent bits as requested. Tensor Cores will be used whenever possible.

CUBLAS_TENSOR_OP_MATH: This mode is deprecated and will be removed in a future release. Allows the library to use Tensor Core operations whenever possible. For single precision GEMM routines cuBLAS will use the CUBLAS_COMPUTE_32F_FAST_16F compute type.
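For illustration, here is a minimal sketch of what that implies for a bfloat16 GEMM issued through cublasGemmEx on CUDA 11. This is not the PR's actual code; the helper name, handle/pointer setup, and fp32 accumulation type are assumptions. The point is that the default math mode plus the tensor-op algo selector is enough, with no cublasSetMathMode toggling around the call:

```cpp
// Sketch only: a bf16 GEMM through cublasGemmEx without touching the math mode.
// The helper name, handle/pointer setup, and fp32 accumulation are assumptions,
// not the code added by this PR.
#include <cublas_v2.h>
#include <cuda_bf16.h>

// C = alpha * A * B + beta * C, column-major, all matrices in bfloat16.
cublasStatus_t bf16_gemm(cublasHandle_t handle, int m, int n, int k,
                         const __nv_bfloat16* A, int lda,
                         const __nv_bfloat16* B, int ldb,
                         __nv_bfloat16* C, int ldc) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  // No cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH) / reset around the call:
  // CUBLAS_DEFAULT_MATH already uses Tensor Cores whenever possible, and the
  // algo argument below asks for the tensor-op path explicitly.
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                      &alpha,
                      A, CUDA_R_16BF, lda,
                      B, CUDA_R_16BF, ldb,
                      &beta,
                      C, CUDA_R_16BF, ldc,
                      CUBLAS_COMPUTE_32F,           // accumulate in fp32
                      CUBLAS_GEMM_DFALT_TENSOR_OP); // algo from the question above
}
```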
test/test_linalg.py (outdated diff)
b1 = torch.randn(num_batches, M, N, device=device).to(dtype)
b2 = torch.randn(num_batches, N, O, device=device).to(dtype)
self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.bmm(b1, b2))
if not is_cuda_bfloat:
Is is_supported=False, is_cuda_bfloat=False an impossible situation?
Some ops are supported on SM52 and some are not. I don't think it's worth the maintenance effort to keep a precise list of which op is supported on which SM, so what I implemented here is:
SM >= 53 ---> supported
SM < 53 ---> undefined behavior
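As a rough sketch of that rule (illustrative only; the function name is an assumption and this is not the helper used by the tests), the gate amounts to a compute-capability check:

```cpp
#include <cuda_runtime.h>

// SM >= 53 -> bf16 GEMM treated as supported; SM < 53 -> no guarantee.
// Hypothetical helper, not the code in this PR; return value of the
// runtime call is ignored for brevity.
bool bf16_gemm_supported(int device_index) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_index);
  return prop.major > 5 || (prop.major == 5 && prop.minor >= 3);
}
```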
…l-bf16-non-ampere
Codecov Report

@@            Coverage Diff             @@
##           master   #50442      +/-   ##
==========================================
- Coverage   81.00%   81.00%   -0.01%
==========================================
  Files        1916     1916
  Lines      209481   209484       +3
==========================================
+ Hits       169690   169692       +2
- Misses      39791    39792       +1
Cool! Thanks @zasdfgbnm!
Would you just rebase this? Sorry, PyTorch is especially popular these days.
@mruberry rebased
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Internal builds are failing with:
We typically use LooseVersion for version comparisons. See
for an example.
…l-bf16-non-ampere
@mruberry fixed
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fixes #{issue number}