ROCm enable sparse_sampled_addmm (#86401)
Enables:
test_comprehensive_sparse_sampled_addmm_cuda_complex128
test_comprehensive_sparse_sampled_addmm_cuda_complex64
test_comprehensive_sparse_sampled_addmm_cuda_float32
test_comprehensive_sparse_sampled_addmm_cuda_float64
test_dispatch_meta_sparse_sampled_addmm_cuda_complex128
test_dispatch_meta_sparse_sampled_addmm_cuda_complex64
test_dispatch_meta_sparse_sampled_addmm_cuda_float32
test_dispatch_meta_sparse_sampled_addmm_cuda_float64
test_meta_sparse_sampled_addmm_cuda_complex128
test_meta_sparse_sampled_addmm_cuda_complex64
test_meta_sparse_sampled_addmm_cuda_float32
test_meta_sparse_sampled_addmm_cuda_float64

Pull Request resolved: #86401
Approved by: https://github.com/ngimel
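
Note: the tests above exercise torch.sparse.sampled_addmm, which on GPU was previously gated to CUDA builds only. A minimal usage sketch of the op, not part of the change itself; shapes and values are illustrative, and a GPU build (CUDA, or ROCm >= 5.2 after this commit) is assumed:

import torch

# sampled_addmm computes alpha * (mat1 @ mat2), sampled at the sparsity
# pattern of the CSR input, plus beta * input.
if torch.cuda.is_available():
    inp = torch.eye(3, device="cuda").to_sparse_csr()   # (3, 3) sparse CSR
    mat1 = torch.randn(3, 4, device="cuda")             # dense (3, 4)
    mat2 = torch.randn(4, 3, device="cuda")             # dense (4, 3)
    out = torch.sparse.sampled_addmm(inp, mat1, mat2, beta=0.5, alpha=2.0)
    print(out)  # sparse CSR result with inp's sparsity pattern
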
jpvillam-amd authored and pytorchmergebot committed Oct 26, 2022
1 parent 123b103 commit 38dd4cb
Showing 4 changed files with 16 additions and 3 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
@@ -1401,7 +1401,7 @@ void sampled_addmm_out_sparse_csr(
     const Scalar& beta,
     const Scalar& alpha,
     const at::sparse_csr::SparseCsrTensor& C) {
-#if !AT_USE_CUSPARSE_GENERIC_SDDMM()
+#if !(AT_USE_CUSPARSE_GENERIC_SDDMM() || AT_USE_HIPSPARSE_GENERIC_52_API())
   TORCH_CHECK(
       false,
       "Calling sampled_addmm with sparse GPU tensors requires compiling ",
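
Note: on builds where neither cuSPARSE generic SDDMM nor the hipSPARSE 5.2 generic API is compiled in, the amended guard keeps the existing error path. A hedged sketch of how that failure surfaces in Python, assuming such a build:

import torch

# TORCH_CHECK failures surface as RuntimeError on the Python side.
a = torch.eye(2, device="cuda").to_sparse_csr()
b = torch.randn(2, 2, device="cuda")
try:
    torch.sparse.sampled_addmm(a, b, b)
except RuntimeError as err:
    # Message begins: "Calling sampled_addmm with sparse GPU tensors
    # requires compiling ..."
    print(err)
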
7 changes: 7 additions & 0 deletions torch/testing/_internal/common_cuda.py
@@ -173,6 +173,13 @@ def _get_torch_cuda_version():
     cuda_version = str(torch.version.cuda)
     return tuple(int(x) for x in cuda_version.split("."))
 
+def _get_torch_rocm_version():
+    if not TEST_WITH_ROCM:
+        return (0, 0)
+    rocm_version = str(torch.version.hip)
+    rocm_version = rocm_version.split("-")[0]  # ignore git sha
+    return tuple(int(x) for x in rocm_version.split("."))
+
 def _check_cusparse_generic_available():
     version = _get_torch_cuda_version()
     min_supported_version = (10, 1)
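
Note: the new helper mirrors _get_torch_cuda_version, turning torch.version.hip strings into comparable integer tuples. A usage sketch; the version string shown in the comment is illustrative:

from torch.testing._internal.common_cuda import _get_torch_rocm_version

# On a ROCm build, torch.version.hip looks like "5.2.21151-d29865dbb";
# the helper drops the git sha and returns (5, 2, 21151). On non-ROCm
# builds it returns (0, 0), so tuple comparisons stay well-defined.
rocm_version = _get_torch_rocm_version()
if rocm_version >= (5, 2):
    print("hipSPARSE generic SDDMM path available")
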
6 changes: 4 additions & 2 deletions torch/testing/_internal/common_methods_invocations.py
@@ -26,7 +26,7 @@
     toleranceOverride, tol)
 from torch.testing._internal.common_cuda import (
     CUDA11OrLater, SM53OrLater, SM60OrLater, with_tf32_off, TEST_CUDNN,
-    _get_torch_cuda_version)
+    _get_torch_cuda_version, _get_torch_rocm_version)
 from torch.testing._internal.common_utils import (
     make_fullrank_matrices_with_distinct_singular_values,
     TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
@@ -9392,7 +9392,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
            supports_autograd=True,
            sample_inputs_func=sample_inputs_sparse_sampled_addmm,
            decorators=[
-               skipCUDAIf(_get_torch_cuda_version() < (11, 3), "cusparseSDDMM was added in 11.2.1"),
+               skipCUDAIf(not ((_get_torch_cuda_version() >= (11, 3))
+                               or (_get_torch_rocm_version() >= (5, 2))),
+                          "cusparseSDDMM was added in 11.2.1"),
                skipCPUIfNoMklSparse, ],
            skips=(
                # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous
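
Note: the reworked decorator skips only when both toolkits are too old; by De Morgan, not (cuda >= (11, 3) or rocm >= (5, 2)) equals cuda < (11, 3) and rocm < (5, 2). A standalone restatement of the predicate, with illustrative version tuples and (0, 0) standing in for an absent platform:

def should_skip(cuda_version, rocm_version):
    # Skip unless at least one toolkit supports SDDMM.
    return not (cuda_version >= (11, 3) or rocm_version >= (5, 2))

assert should_skip((11, 2), (0, 0))          # CUDA too old, no ROCm
assert not should_skip((11, 3), (0, 0))      # new-enough CUDA
assert not should_skip((0, 0), (5, 2))       # new-enough ROCm
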
4 changes: 4 additions & 0 deletions torch/utils/hipify/cuda_to_hip_mappings.py
@@ -7920,6 +7920,9 @@
     ("cusparseSpGEMM_createDescr", ("hipsparseSpGEMM_createDescr", CONV_MATH_FUNC, API_SPARSE)),
     ("cusparseDnMatSetStridedBatch", ("hipsparseDnMatSetStridedBatch", CONV_MATH_FUNC, API_SPARSE)),
     ("cusparseSpGEMM_copy", ("hipsparseSpGEMM_copy", CONV_MATH_FUNC, API_SPARSE)),
+    ("cusparseSDDMM_bufferSize", ("hipsparseSDDMM_bufferSize", CONV_MATH_FUNC, API_SPARSE)),
+    ("cusparseSDDMM_preprocess", ("hipsparseSDDMM_preprocess", CONV_MATH_FUNC, API_SPARSE)),
+    ("cusparseSDDMM", ("hipsparseSDDMM", CONV_MATH_FUNC, API_SPARSE)),
     ("cusparseSpGEMM_compute", ("hipsparseSpGEMM_compute", CONV_MATH_FUNC, API_SPARSE)),
     ("cusparseSpGEMM_workEstimation", ("hipsparseSpGEMM_workEstimation", CONV_MATH_FUNC, API_SPARSE)),
     ("cusparseSpMatGetSize", ("hipsparseSpMatGetSize", CONV_MATH_FUNC, API_SPARSE)),
@@ -7947,6 +7950,7 @@
     ("CUSPARSE_COOMV_ALG", ("HIPSPARSE_COOMV_ALG", CONV_NUMERIC_LITERAL, API_SPARSE)),
     ("CUSPARSE_CSRMM_ALG1", ("HIPSPARSE_CSRMM_ALG1", CONV_NUMERIC_LITERAL, API_SPARSE)),
     ("CUSPARSE_SPGEMM_DEFAULT", ("HIPSPARSE_SPGEMM_DEFAULT", CONV_NUMERIC_LITERAL, API_SPARSE)),
+    ("CUSPARSE_SDDMM_ALG_DEFAULT", ("HIPSPARSE_SDDMM_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPARSE)),
     (
         "CUSPARSE_STATUS_SUCCESS",
         ("HIPSPARSE_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_SPARSE),
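
Note: these mapping tuples feed PyTorch's hipify source translator, which rewrites CUDA identifiers to HIP equivalents when building for ROCm. A rough standalone sketch of the substitution these entries enable; the real driver in torch/utils/hipify matches identifiers with a single pattern rather than chained str.replace calls:

# Subset of the table added above, as a plain dict for illustration.
mappings = {
    "cusparseSDDMM_bufferSize": "hipsparseSDDMM_bufferSize",
    "cusparseSDDMM_preprocess": "hipsparseSDDMM_preprocess",
    "cusparseSDDMM": "hipsparseSDDMM",
    "CUSPARSE_SDDMM_ALG_DEFAULT": "HIPSPARSE_SDDMM_ALG_DEFAULT",
}
src = "cusparseSDDMM_bufferSize(handle, opA, opB, &alpha, A, B, &beta, C, ty, CUSPARSE_SDDMM_ALG_DEFAULT, &size);"
# Longest names first, so "cusparseSDDMM" cannot clobber its
# _bufferSize and _preprocess variants.
for cuda_name in sorted(mappings, key=len, reverse=True):
    src = src.replace(cuda_name, mappings[cuda_name])
print(src)  # hipsparseSDDMM_bufferSize(..., HIPSPARSE_SDDMM_ALG_DEFAULT, &size);
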
