Fix several test_ops cuda dtypes tests (#60922)
Summary:
Close #60443

Pull Request resolved: #60922

Reviewed By: jdonald, iramazanli

Differential Revision: D29630122

Pulled By: mruberry

fbshipit-source-id: 441f79828860282e5849a2565facf9e7f72912e8
xwang233 authored and facebook-github-bot committed Jul 9, 2021
1 parent 5e9bcf9 commit c966ce6
Showing 2 changed files with 9 additions and 13 deletions.
torch/testing/_internal/common_cuda.py: 1 addition, 0 deletions
@@ -19,6 +19,7 @@
 CUDA11OrLater = torch.version.cuda and distutils.version.LooseVersion(torch.version.cuda) >= "11.0"
 CUDA9 = torch.version.cuda and torch.version.cuda.startswith('9.')
 SM53OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3)
+SM60OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0)

 TEST_MAGMA = TEST_CUDA
 if TEST_CUDA:
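
The new SM60OrLater flag is evaluated once at import time, alongside the existing SM53OrLater gate. A minimal sketch of what it gates (a standalone illustration, not code from this commit; assumes a CUDA 11 build of PyTorch and an available GPU):

    import torch

    # Same gate as the line added above: compute capability >= 6.0 (Pascal or newer).
    SM60OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0)

    if SM60OrLater:
        # bfloat16 matmul backward is what the OpInfo changes below begin testing
        # on eligible configurations (CUDA 11 builds on SM60+ hardware).
        a = torch.randn(8, 8, device='cuda', dtype=torch.bfloat16, requires_grad=True)
        b = torch.randn(8, 8, device='cuda', dtype=torch.bfloat16)
        (a @ b).sum().backward()
        print(a.grad.dtype)  # torch.bfloat16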
torch/testing/_internal/common_methods_invocations.py: 8 additions, 13 deletions
@@ -23,7 +23,7 @@
 from torch.testing._internal.common_device_type import \
     (skipIf, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfNoCusolver,
      skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIfRocm, precisionOverride, toleranceOverride, tol)
-from torch.testing._internal.common_cuda import CUDA11OrLater, SM53OrLater
+from torch.testing._internal.common_cuda import CUDA11OrLater, SM53OrLater, SM60OrLater
 from torch.testing._internal.common_utils import \
     (is_iterable_of_tensors,
      random_symmetric_matrix, random_symmetric_psd_matrix,
@@ -6104,13 +6104,11 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
            dtypesIfCPU=all_types_and_complex(),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
            dtypesIfROCM=floating_types_and(torch.half, torch.bfloat16),
-           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16),
+           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
+                                                                *[torch.bfloat16] if (SM60OrLater and CUDA11OrLater) else []),
            assert_autodiffed=True,
            sample_inputs_func=sample_inputs_matmul,
            skips=(
-               # FIXME: bfloat16 backward support likely depends on CUDA11+
-               # and SM53+
-               SkipInfo('TestCommon', 'test_dtypes', active_if=IS_WINDOWS),
                # matmul does not correctly warn when resizing out= inputs
                SkipInfo('TestCommon', 'test_out'),
                SkipInfo('TestCommon', 'test_conj_view', device_type='cpu'),
@@ -6627,14 +6625,13 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
            dtypesIfCPU=all_types_and_complex(),
            dtypesIfCUDA=floating_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else [],
                                            torch.complex64, torch.complex128),
-           backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128),
+           backward_dtypesIfCUDA=floating_types_and(torch.float16,
+                                                    *[torch.bfloat16] if (SM60OrLater and CUDA11OrLater) else [],
+                                                    torch.complex64, torch.complex128),
            assert_autodiffed=True,
            sample_inputs_func=sample_inputs_matmul,
            supports_out=False,
            skips=(
-               # FIXME: bfloat16 backward support likely depends on CUDA11+
-               # and SM53+
-               SkipInfo('TestCommon', 'test_dtypes', active_if=IS_WINDOWS),
                SkipInfo('TestJit', 'test_variant_consistency_jit',),
            )),
     OpInfo('__rmod__',
@@ -6986,13 +6983,11 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
            op=lambda tensors, equation: torch.einsum(equation, tensors),
            dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.half, *[torch.bfloat16] if CUDA11OrLater else []),
-           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half),
+           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half,
+                                                                *[torch.bfloat16] if (SM60OrLater and CUDA11OrLater) else []),
            supports_out=False,
            sample_inputs_func=sample_inputs_einsum,
            skips=(
-               # FIXME: bfloat16 backward support likely depends on CUDA11+
-               # and SM53+
-               SkipInfo('TestCommon', 'test_dtypes', active_if=IS_WINDOWS),
                # test does not work with passing lambda for op
                # there's a test `test_einsum` in `test_jit.py` to handle this case
                SkipInfo('TestJit', 'test_variant_consistency_jit'),
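
All three backward_dtypesIfCUDA edits use the same idiom: unpacking a conditionally empty list into the dtype helper, so torch.bfloat16 joins the backward dtype set only when both gates hold. A standalone sketch of the idiom (the helper and gate values below are stand-ins for illustration, not the PyTorch ones):

    # Stand-in for torch.testing's floating_types_and; illustration only.
    def floating_types_and(*extra):
        return ('float32', 'float64') + extra

    # Pretend gate values, for illustration.
    CUDA11OrLater = True
    SM60OrLater = False

    # The conditional expression is evaluated first, then * unpacks the result,
    # so a failed gate contributes nothing to the argument list.
    dtypes = floating_types_and('float16',
                                *['bfloat16'] if (SM60OrLater and CUDA11OrLater) else [])
    print(dtypes)  # ('float32', 'float64', 'float16')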
