Skip to content

Commit

Permalink
conditionally enable hipsparse const descriptors for version >= 2.4.0 (
Browse files Browse the repository at this point in the history
…#1217)

* conditionally enable hipsparse const descriptors

* update hipsparse const API version condition to 2.4.0
  • Loading branch information
jeffdaily committed Sep 29, 2023
1 parent 359c2a5 commit 6b97905
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
21 changes: 19 additions & 2 deletions aten/src/ATen/cuda/CUDASparse.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>
#if defined(USE_ROCM)
#include <hipsparse/hipsparse-version.h>
#define HIPSPARSE_VERSION ((hipsparseVersionMajor*100000) + (hipsparseVersionMinor*100) + hipsparseVersionPatch)
#endif

// cuSparse Generic API added in CUDA 10.1
// Windows support added in CUDA 11.0
Expand All @@ -25,20 +29,33 @@
#define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 0
#endif

#if defined(USE_ROCM)

// hipSparse const API added in v2.4.0
#if HIPSPARSE_VERSION >= 200400
#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 1
#define AT_USE_HIPSPARSE_GENERIC_52_API() 0
#define AT_USE_HIPSPARSE_GENERIC_API() 1
#else
#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0

// hipSparse Generic API ROCm 5.2
#if defined(USE_ROCM) && ROCM_VERSION >= 50200
#if ROCM_VERSION >= 50200
#define AT_USE_HIPSPARSE_GENERIC_52_API() 1
#else
#define AT_USE_HIPSPARSE_GENERIC_52_API() 0
#endif

// hipSparse Generic API ROCm 5.1
#if defined(USE_ROCM) && ROCM_VERSION >= 50100
#if ROCM_VERSION >= 50100
#define AT_USE_HIPSPARSE_GENERIC_API() 1
#else
#define AT_USE_HIPSPARSE_GENERIC_API() 0
#endif

#endif // HIPSPARSE_VERSION >= 200400
#endif // USE_ROCM

// cuSparse Generic API spsv function was added in CUDA 11.3.0
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500)
#define AT_USE_CUSPARSE_GENERIC_SPSV() 1
Expand Down
19 changes: 9 additions & 10 deletions aten/src/ATen/cuda/CUDASparseDescriptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class CuSparseDescriptor {
std::unique_ptr<T, CuSparseDescriptorDeleter<T, destructor>> descriptor_;
};

#if AT_USE_CUSPARSE_CONST_DESCRIPTORS()
#if AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
template <typename T, cusparseStatus_t (*destructor)(const T*)>
struct ConstCuSparseDescriptorDeleter {
void operator()(T* x) {
Expand All @@ -58,16 +58,15 @@ class ConstCuSparseDescriptor {
protected:
std::unique_ptr<T, ConstCuSparseDescriptorDeleter<T, destructor>> descriptor_;
};
#endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS
#endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS || AT_USE_HIPSPARSE_CONST_DESCRIPTORS

#if defined(USE_ROCM)
// hipSPARSE doesn't define this
using cusparseMatDescr = std::remove_pointer<cusparseMatDescr_t>::type;
using cusparseDnMatDescr = std::remove_pointer<cusparseDnMatDescr_t>::type;
using cusparseDnVecDescr = std::remove_pointer<cusparseDnVecDescr_t>::type;
using cusparseSpMatDescr = std::remove_pointer<cusparseSpMatDescr_t>::type;
using cusparseSpMatDescr = std::remove_pointer<cusparseSpMatDescr_t>::type;
using cusparseSpGEMMDescr = std::remove_pointer<cusparseSpGEMMDescr_t>::type;
using cusparseMatDescr = std::remove_pointer<hipsparseMatDescr_t>::type;
using cusparseDnMatDescr = std::remove_pointer<hipsparseDnMatDescr_t>::type;
using cusparseDnVecDescr = std::remove_pointer<hipsparseDnVecDescr_t>::type;
using cusparseSpMatDescr = std::remove_pointer<hipsparseSpMatDescr_t>::type;
using cusparseSpMatDescr = std::remove_pointer<hipsparseSpMatDescr_t>::type;
using cusparseSpGEMMDescr = std::remove_pointer<hipsparseSpGEMMDescr_t>::type;
#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
using bsrsv2Info = std::remove_pointer<bsrsv2Info_t>::type;
using bsrsm2Info = std::remove_pointer<bsrsm2Info_t>::type;
Expand Down Expand Up @@ -143,7 +142,7 @@ class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor

//AT_USE_HIPSPARSE_GENERIC_52_API() || (AT_USE_CUSPARSE_GENERIC_API() && AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS())

#elif AT_USE_CUSPARSE_CONST_DESCRIPTORS()
#elif AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
: public ConstCuSparseDescriptor<
cusparseDnMatDescr,
Expand Down

0 comments on commit 6b97905

Please sign in to comment.