From 38dd4cbdf1dc982492a0cc94a54eb2f71c31d8fe Mon Sep 17 00:00:00 2001 From: jpvillam Date: Wed, 26 Oct 2022 19:39:21 +0000 Subject: [PATCH] ROCm enable sparse_sampled_addmm (#86401) Enables: test_comprehensive_sparse_sampled_addmm_cuda_complex128 test_comprehensive_sparse_sampled_addmm_cuda_complex64 test_comprehensive_sparse_sampled_addmm_cuda_float32 test_comprehensive_sparse_sampled_addmm_cuda_float64 test_dispatch_meta_sparse_sampled_addmm_cuda_complex128 test_dispatch_meta_sparse_sampled_addmm_cuda_complex64 test_dispatch_meta_sparse_sampled_addmm_cuda_float32 test_dispatch_meta_sparse_sampled_addmm_cuda_float64 test_meta_sparse_sampled_addmm_cuda_complex128 test_meta_sparse_sampled_addmm_cuda_complex64 test_meta_sparse_sampled_addmm_cuda_float32 test_meta_sparse_sampled_addmm_cuda_float64 Pull Request resolved: https://github.com/pytorch/pytorch/pull/86401 Approved by: https://github.com/ngimel --- aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp | 2 +- torch/testing/_internal/common_cuda.py | 7 +++++++ torch/testing/_internal/common_methods_invocations.py | 6 ++++-- torch/utils/hipify/cuda_to_hip_mappings.py | 4 ++++ 4 files changed, 16 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp index 379640bad56b9..833fd41eb6a02 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp +++ b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp @@ -1401,7 +1401,7 @@ void sampled_addmm_out_sparse_csr( const Scalar& beta, const Scalar& alpha, const at::sparse_csr::SparseCsrTensor& C) { -#if !AT_USE_CUSPARSE_GENERIC_SDDMM() +#if !(AT_USE_CUSPARSE_GENERIC_SDDMM() || AT_USE_HIPSPARSE_GENERIC_52_API()) TORCH_CHECK( false, "Calling sampled_addmm with sparse GPU tensors requires compiling ", diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 1ee8e40ebd062..b226c7af58e51 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -173,6 +173,13 @@ def _get_torch_cuda_version(): cuda_version = str(torch.version.cuda) return tuple(int(x) for x in cuda_version.split(".")) +def _get_torch_rocm_version(): + if not TEST_WITH_ROCM: + return (0, 0) + rocm_version = str(torch.version.hip) + rocm_version = rocm_version.split("-")[0] # ignore git sha + return tuple(int(x) for x in rocm_version.split(".")) + def _check_cusparse_generic_available(): version = _get_torch_cuda_version() min_supported_version = (10, 1) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 00f454bdf454a..94c12f5bc93d0 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -26,7 +26,7 @@ toleranceOverride, tol) from torch.testing._internal.common_cuda import ( CUDA11OrLater, SM53OrLater, SM60OrLater, with_tf32_off, TEST_CUDNN, - _get_torch_cuda_version) + _get_torch_cuda_version, _get_torch_rocm_version) from torch.testing._internal.common_utils import ( make_fullrank_matrices_with_distinct_singular_values, TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY, @@ -9392,7 +9392,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_autograd=True, sample_inputs_func=sample_inputs_sparse_sampled_addmm, decorators=[ - skipCUDAIf(_get_torch_cuda_version() < (11, 3), "cusparseSDDMM was added in 11.2.1"), + skipCUDAIf(not ((_get_torch_cuda_version() >= (11, 3)) + or (_get_torch_rocm_version() >= (5, 2))), + "cusparseSDDMM was added in 11.2.1"), skipCPUIfNoMklSparse, ], skips=( # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index 3b3a047a7f9b5..33e14e9e0572e 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -7920,6 +7920,9 @@ ("cusparseSpGEMM_createDescr", ("hipsparseSpGEMM_createDescr", CONV_MATH_FUNC, API_SPARSE)), ("cusparseDnMatSetStridedBatch", ("hipsparseDnMatSetStridedBatch", CONV_MATH_FUNC, API_SPARSE)), ("cusparseSpGEMM_copy", ("hipsparseSpGEMM_copy", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseSDDMM_bufferSize", ("hipsparseSDDMM_bufferSize", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseSDDMM_preprocess", ("hipsparseSDDMM_preprocess", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseSDDMM", ("hipsparseSDDMM", CONV_MATH_FUNC, API_SPARSE)), ("cusparseSpGEMM_compute", ("hipsparseSpGEMM_compute", CONV_MATH_FUNC, API_SPARSE)), ("cusparseSpGEMM_workEstimation", ("hipsparseSpGEMM_workEstimation", CONV_MATH_FUNC, API_SPARSE)), ("cusparseSpMatGetSize", ("hipsparseSpMatGetSize", CONV_MATH_FUNC, API_SPARSE)), @@ -7947,6 +7950,7 @@ ("CUSPARSE_COOMV_ALG", ("HIPSPARSE_COOMV_ALG", CONV_NUMERIC_LITERAL, API_SPARSE)), ("CUSPARSE_CSRMM_ALG1", ("HIPSPARSE_CSRMM_ALG1", CONV_NUMERIC_LITERAL, API_SPARSE)), ("CUSPARSE_SPGEMM_DEFAULT", ("HIPSPARSE_SPGEMM_DEFAULT", CONV_NUMERIC_LITERAL, API_SPARSE)), + ("CUSPARSE_SDDMM_ALG_DEFAULT", ("HIPSPARSE_SDDMM_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPARSE)), ( "CUSPARSE_STATUS_SUCCESS", ("HIPSPARSE_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_SPARSE),