Enable BFloat support for gemms on arch other than ampere #50442

Closed
wants to merge 20 commits

Changes from 16 commits
54 changes: 21 additions & 33 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -327,7 +327,6 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(prop->major >= 8, "BFloat16 bgemm in CUDA requires Ampere or later GPU");
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
@@ -343,7 +342,7 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
(int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
0, 0, NULL, NULL));
#else
TORCH_CHECK(false, "BFloat16 bgemm in CUDA requires Ampere or later GPU");
TORCH_CHECK(false, "CUDA BFloat16 bgemm requires CUDA 11 or later");
#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000
}
#endif // __HIP_PLATFORM_HCC__
@@ -550,37 +549,26 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
float fbeta = beta;
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(at::BFloat16);
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 8) {
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
opa,
opb,
m,
n,
k,
&falpha,
a,
CUDA_R_16BF,
lda,
b,
CUDA_R_16BF,
ldb,
&fbeta,
c,
CUDA_R_16BF,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DFALT_TENSOR_OP));
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
} else {
TORCH_CHECK(false, "BFloat16 gemm in CUDA requires Ampere or later GPU");
}
TORCH_CUDABLAS_CHECK(cublasGemmEx(
    handle,
    opa,
    opb,
    m,
    n,
    k,
    &falpha,
    a,
    CUDA_R_16BF,
    lda,
    b,
    CUDA_R_16BF,
    ldb,
    &fbeta,
    c,
    CUDA_R_16BF,
    ldc,
    CUDA_R_32F,
    CUBLAS_GEMM_DFALT_TENSOR_OP));
}
#endif

Collaborator

setting and resetting cublas MathMode is not required if you specify CUBLAS_GEMM_DFALT_TENSOR_OP?

Collaborator Author

According to https://docs.nvidia.com/cuda/cublas/index.html#cublasmath_t:

CUBLAS_DEFAULT_MATH: This is the default and highest-performance mode that uses compute and intermediate storage precisions with at least the same number of mantissa and exponent bits as requested. Tensor Cores will be used whenever possible.
CUBLAS_TENSOR_OP_MATH: This mode is deprecated and will be removed in a future release. Allows the library to use Tensor Core operations whenever possible. For single precision GEMM routines cuBLAS will use the CUBLAS_COMPUTE_32F_FAST_16F compute type.

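To make the effect of the CUDABlas.cpp change concrete, here is a minimal, hypothetical Python probe (not part of this PR; the helper name is invented) that mirrors the new behavior: the hard Ampere check is gone, the only hard requirement on CUDA builds is CUDA 11, and configurations that cuBLAS cannot handle surface as a runtime error instead of being rejected up front.

import torch

def bf16_gemm_probe():
    # Hypothetical helper, for illustration only.
    # On CUDA (non-ROCm) builds, BFloat16 gemm still requires CUDA 11 or later.
    if torch.version.hip is None:
        if torch.version.cuda is None or float(torch.version.cuda) < 11:
            return False, "CUDA BFloat16 bgemm requires CUDA 11 or later"
    a = torch.randn(4, 4, device="cuda").to(torch.bfloat16)
    b = torch.randn(4, 4, device="cuda").to(torch.bfloat16)
    try:
        torch.mm(a, b)  # dispatches to gemm<at::BFloat16> in CUDABlas.cpp
        return True, "supported"
    except RuntimeError as err:
        # e.g. CUBLAS_STATUS_NOT_SUPPORTED on devices where cuBLAS
        # does not provide BFloat16 gemm
        return False, str(err)

if torch.cuda.is_available():
    print(bf16_gemm_probe())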
74 changes: 45 additions & 29 deletions test/test_linalg.py
@@ -24,7 +24,7 @@
skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyOnCPUAndCUDA, dtypesIfCUDA,
onlyCUDA)
from torch.testing import floating_and_complex_types, floating_types, all_types
from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32
from torch.testing._internal.common_cuda import SM53OrLater, tf32_on_and_off, CUDA11OrLater, CUDA9
from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args
from torch.autograd import gradcheck, gradgradcheck

@@ -36,9 +36,6 @@
if TEST_SCIPY:
import scipy

# TODO: make this common and import it
AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()

class TestLinalg(TestCase):
exact_dtype = True

@@ -3959,12 +3956,12 @@ def call_torch_fn(*args, **kwargs):
self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape])

@skipCUDAIfRocm
@dtypesIfCUDA(*(torch.float, torch.double, torch.cfloat, torch.cdouble) +
# This test is disabled on CUDA 9, due to:
# See: https://github.com/pytorch/pytorch/issues/31006
((torch.half,) if torch.version.cuda and not torch.version.cuda.startswith('9.') else ()))
@dtypesIfCUDA(torch.cfloat, torch.cdouble,
*torch.testing.get_all_fp_dtypes(include_half=not CUDA9, include_bfloat16=(CUDA11OrLater and SM53OrLater)))
@dtypes(*(set(torch.testing.get_all_dtypes()) - {torch.half, torch.bool}))
def test_blas_alpha_beta_empty(self, device, dtype):
# This test is disabled on CUDA 9 due to:
# See: https://github.com/pytorch/pytorch/issues/31006
if dtype is torch.bfloat16 and self.device_type == 'xla':
# TODO (@zasdfgbnm): this causes the following error on test
# TestTorchDeviceTypeXLA.test_blas_alpha_beta_empty_xla_bfloat16:
@@ -4500,8 +4497,8 @@ def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=
@precisionOverride({torch.bfloat16: 1e-0, torch.half: 5e-4, torch.float: 1e-4, torch.double: 1e-8,
torch.cfloat: 1e-4, torch.cdouble: 1e-8})
@dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(),
*([torch.float32, torch.float64, torch.bfloat16]
if TEST_WITH_ROCM else torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)))
*torch.testing.get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)),
include_half=(not TEST_WITH_ROCM)))
@dtypes(torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_addmv(self, device, dtype):
# have to use torch.randn(...).to(bfloat16) instead of
@@ -4535,8 +4532,7 @@ def test_addmv(self, device, dtype):
for m, v in itertools.product(ms, vs):
self._test_addmm_addmv(torch.addmv, t, m, v, beta=0)

@dtypesIfCUDA(*([torch.half, torch.float, torch.double]
+ ([torch.bfloat16] if TEST_WITH_ROCM else [])))
@dtypesIfCUDA(*torch.testing.get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater))))
@dtypes(torch.float, torch.double)
def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype):
# tests (o, s)*(s). o is output size, s is summed size.
@@ -4567,7 +4563,8 @@ def _test(row_major, incx, incy, lda_tail):

@precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
@dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
@dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(),
*torch.testing.get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater))))
@dtypes(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes())
@tf32_on_and_off(0.05)
def test_addmm(self, device, dtype):
@@ -4742,19 +4739,25 @@ def test_strided_mm_bmm(self, device, dtype):
@dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes())
@tf32_on_and_off(0.05)
def test_bmm(self, device, dtype):
if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater:
# cuBLAS does not guarantee BFloat16 support on SM < 53,
# so in PyTorch we consider BFloat16 support on SM < 53 as
# undefined behavior.
return

num_batches = 10
M, N, O = 23, 8, 12
numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32

if self.device_type == 'cpu':
is_supported = True
elif self.device_type == 'cuda':
is_supported = True if dtype != torch.bfloat16 else AMPERE_OR_ROCM
is_supported = True
if dtype == torch.bfloat16 and self.device_type == 'cuda':
is_supported = TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)

if not is_supported:
b1 = torch.randn(num_batches, M, N, device=device).to(dtype)
b2 = torch.randn(num_batches, N, O, device=device).to(dtype)
self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.bmm(b1, b2))
self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
lambda: torch.bmm(b1, b2))
return

def invert_perm(p):
@@ -4917,21 +4920,28 @@ def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor):
@dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes())
@tf32_on_and_off(0.05)
def test_addbmm(self, device, dtype):
if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater:
# cuBLAS does not guarantee BFloat16 support on SM < 53,
# so in PyTorch we consider BFloat16 support on SM < 53 as
# undefined behavior.
return

num_batches = 2
M, N, O = 2, 3, 4

if self.device_type == 'cpu':
is_supported = True
if dtype == torch.bfloat16:
is_supported = True
if dtype == torch.bfloat16:
if self.device_type == 'cpu':
self.precision = 1 # 43 vs 43.75
else:
is_supported = (dtype != torch.bfloat16 or AMPERE_OR_ROCM)
else:
is_supported = TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)

if not is_supported:
b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1)
b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1)
t = make_tensor((M, O), device, dtype, low=-1, high=1)
self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.addbmm(t, b1, b2))
self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
lambda: torch.addbmm(t, b1, b2))
return

def invert_perm(p):
@@ -4983,19 +4993,25 @@ def generate_tensor():
@dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes())
@tf32_on_and_off(0.05)
def test_baddbmm(self, device, dtype):
if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater:
# cuBLAS does not guarantee BFloat16 support on SM < 53,
# so in PyTorch we consider BFloat16 support on SM < 53 as
# undefined behavior.
return

num_batches = 10
M, N, O = 12, 8, 5

if self.device_type == 'cpu':
is_supported = True
elif self.device_type == 'cuda':
is_supported = True if dtype != torch.bfloat16 else AMPERE_OR_ROCM
is_supported = True
if dtype == torch.bfloat16 and self.device_type == 'cuda':
is_supported = TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)

if not is_supported:
b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1)
b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1)
t = make_tensor((num_batches, M, O), device, dtype, low=-1, high=1)
self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.baddbmm(t, b1, b2))
self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
lambda: torch.baddbmm(t, b1, b2))
return

def invert_perm(p):
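For reference, the updated tests above all apply the same support rule. Below is a hypothetical helper (not added by the PR; the function name and import paths are assumptions based on the imports shown in the diff) that condenses that rule:

import torch
from torch.testing._internal.common_cuda import CUDA11OrLater, SM53OrLater
from torch.testing._internal.common_utils import TEST_WITH_ROCM

def bfloat16_gemm_supported(device_type: str, dtype: torch.dtype) -> bool:
    # CPU BFloat16 gemm is always exercised; on CUDA it is exercised only on
    # ROCm builds or on CUDA 11+ with an SM 5.3+ device. Everything else is
    # expected to raise (e.g. CUBLAS_STATUS_NOT_SUPPORTED).
    if device_type == 'cpu' or dtype is not torch.bfloat16:
        return True
    return TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)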
4 changes: 4 additions & 0 deletions torch/testing/_internal/common_cuda.py
@@ -15,6 +15,10 @@
TEST_CUDNN = TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE))
TEST_CUDNN_VERSION = torch.backends.cudnn.version() if TEST_CUDNN else 0

CUDA11OrLater = torch.version.cuda and float(torch.version.cuda) >= 11
CUDA9 = torch.version.cuda and torch.version.cuda.startswith('9.')
SM53OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3)

TEST_MAGMA = TEST_CUDA
if TEST_CUDA:
torch.ones(1).cuda() # has_magma shows up after cuda is initialized
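A quick worked example (assumed toolkit versions and compute capabilities, not taken from the PR) of how these gates evaluate: float(torch.version.cuda) parses version strings such as "11.1", and device capabilities compare lexicographically as tuples.

# Illustrative only: evaluate the new gates for a few assumed configurations.
def gates(cuda_version: str, capability: tuple) -> bool:
    cuda11_or_later = float(cuda_version) >= 11   # "11.1" -> 11.1
    sm53_or_later = capability >= (5, 3)          # lexicographic tuple comparison
    return cuda11_or_later and sm53_or_later

assert gates("11.1", (8, 0))        # Ampere on CUDA 11: BFloat16 gemm runs
assert gates("11.0", (7, 0))        # Volta: exercised now that the Ampere check is gone
assert not gates("11.0", (5, 0))    # sm_50 is below SM 5.3: tests skip BFloat16
assert not gates("10.2", (8, 0))    # CUDA 10.x: BFloat16 gemm needs CUDA 11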
5 changes: 3 additions & 2 deletions torch/testing/_internal/common_methods_invocations.py
@@ -18,7 +18,7 @@
from torch.testing._internal.common_device_type import \
(skipCUDAIfNoMagma, skipCPUIfNoLapack, skipCPUIfNoMkl, skipCUDAIfRocm,
expectedAlertNondeterministic, precisionOverride, onlyCPU)
from torch.testing._internal.common_cuda import tf32_is_not_fp32
from torch.testing._internal.common_cuda import CUDA11OrLater
from torch.testing._internal.common_utils import \
(prod_single_zero, random_square_matrix_of_rank,
random_symmetric_matrix, random_symmetric_psd_matrix,
@@ -883,8 +883,9 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
OpInfo('addmm',
dtypes=floating_types(),
dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16),
# BFloat16 support on CUDA requires CUDA 11 and SM53
dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
*[torch.bfloat16] if tf32_is_not_fp32() else []),
*[torch.bfloat16] if CUDA11OrLater else []),
dtypesIfROCM=floating_types_and(torch.half),
assert_autodiffed=True,
autodiff_nonfusible_nodes=['aten::add', 'aten::mm'],