Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix several test_ops cuda dtypes tests #60922

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions torch/testing/_internal/common_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
CUDA11OrLater = torch.version.cuda and distutils.version.LooseVersion(torch.version.cuda) >= "11.0"
CUDA9 = torch.version.cuda and torch.version.cuda.startswith('9.')
SM53OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3)
SM60OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0)

TEST_MAGMA = TEST_CUDA
if TEST_CUDA:
Expand Down
21 changes: 8 additions & 13 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torch.testing._internal.common_device_type import \
(skipIf, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfNoCusolver,
skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIfRocm, precisionOverride, toleranceOverride, tol)
from torch.testing._internal.common_cuda import CUDA11OrLater, SM53OrLater
from torch.testing._internal.common_cuda import CUDA11OrLater, SM53OrLater, SM60OrLater
from torch.testing._internal.common_utils import \
(is_iterable_of_tensors,
random_symmetric_matrix, random_symmetric_psd_matrix,
Expand Down Expand Up @@ -6069,13 +6069,11 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
dtypesIfCPU=all_types_and_complex(),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
dtypesIfROCM=floating_types_and(torch.half, torch.bfloat16),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
*[torch.bfloat16] if (SM60OrLater and CUDA11OrLater) else []),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ping @mruberry . Seems like the CI is happy with the new flags

assert_autodiffed=True,
sample_inputs_func=sample_inputs_matmul,
skips=(
# FIXME: bfloat16 backward support likely depends on CUDA11+
# and SM53+
SkipInfo('TestCommon', 'test_dtypes', active_if=IS_WINDOWS),
# matmul does not correctly warn when resizing out= inputs
SkipInfo('TestCommon', 'test_out'),
SkipInfo('TestCommon', 'test_conj_view', device_type='cpu'),
Expand Down Expand Up @@ -6545,14 +6543,13 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
dtypesIfCPU=all_types_and_complex(),
dtypesIfCUDA=floating_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else [],
torch.complex64, torch.complex128),
backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128),
backward_dtypesIfCUDA=floating_types_and(torch.float16,
*[torch.bfloat16] if (SM60OrLater and CUDA11OrLater) else [],
torch.complex64, torch.complex128),
assert_autodiffed=True,
sample_inputs_func=sample_inputs_matmul,
supports_out=False,
skips=(
# FIXME: bfloat16 backward support likely depends on CUDA11+
# and SM53+
SkipInfo('TestCommon', 'test_dtypes', active_if=IS_WINDOWS),
SkipInfo('TestJit', 'test_variant_consistency_jit',),
)),
OpInfo('__rmod__',
Expand Down Expand Up @@ -6904,13 +6901,11 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
op=lambda tensors, equation: torch.einsum(equation, tensors),
dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.half, *[torch.bfloat16] if CUDA11OrLater else []),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half,
*[torch.bfloat16] if (SM60OrLater and CUDA11OrLater) else []),
supports_out=False,
sample_inputs_func=sample_inputs_einsum,
skips=(
# FIXME: bfloat16 backward support likely depends on CUDA11+
# and SM53+
SkipInfo('TestCommon', 'test_dtypes', active_if=IS_WINDOWS),
# test does not work with passing lambda for op
# there's a test `test_einsum` in `test_jit.py` to handle this case
SkipInfo('TestJit', 'test_variant_consistency_jit'),
Expand Down