
Optionally use hipblaslt
The hipblaslt package is not available on Fedora.
Instead of requiring the package, make it optional.
If it is found, define the preprocessor variable HIPBLASLT.
Convert the checks for ROCM_VERSION >= 50700 to HIPBLASLT checks.

Signed-off-by: Tom Rix <trix@redhat.com>
trixirt committed Feb 24, 2024
1 parent 5c5b71b commit 2c029b1
Showing 8 changed files with 22 additions and 16 deletions.
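
For orientation, here is a minimal sketch of the compile-time gate this commit standardizes on. The HIPBLASLT macro is assumed to be supplied by the build system when the hipblaslt package is found (as the cmake/Dependencies.cmake hunk below arranges); it is not defined by any ROCm header.

```cpp
// The gate used throughout the C++ changes below: cublasLt/hipBLASLt
// code paths compile on non-Windows CUDA builds, and on ROCm builds
// only when the build system detected hipblaslt and defined HIPBLASLT.
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
// ... cublasLt / hipBLASLt-backed implementations ...
#endif
```

Previously the ROCm side of the gate was keyed off ROCM_VERSION >= 50700, which breaks builds on distributions such as Fedora that ship ROCm 5.7+ without the hipblaslt package.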
5 changes: 3 additions & 2 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -778,7 +778,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
}
}

-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))

#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
// only for rocm 5.7 where we first supported hipblaslt, it was difficult
@@ -909,6 +909,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
};
} // namespace

+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
template <typename Dtype>
void gemm_and_bias(
bool transpose_mat1,
@@ -1121,7 +1122,7 @@ template void gemm_and_bias(
at::BFloat16* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
-
+#endif
void scaled_gemm(
char transa,
char transb,
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/CUDABlas.h
@@ -82,7 +82,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));

-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
enum GEMMAndBiasActivationEpilogue {
None,
RELU,
4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/CUDAContextLight.h
@@ -9,7 +9,7 @@

// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
// added bf16 support
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
#include <cublasLt.h>
#endif

@@ -82,7 +82,7 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
/* Handles */
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
#endif

4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -28,7 +28,7 @@ namespace at::cuda {

namespace {

-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
void createCublasLtHandle(cublasLtHandle_t *handle) {
TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
}
@@ -177,7 +177,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
return handle;
}

-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
cublasLtHandle_t getCurrentCUDABlasLtHandle() {
#ifdef USE_ROCM
c10::DeviceIndex device = 0;
6 changes: 3 additions & 3 deletions aten/src/ATen/cuda/tunable/TunableGemm.h
@@ -11,7 +11,7 @@

#include <ATen/cuda/tunable/GemmCommon.h>
#ifdef USE_ROCM
-#if ROCM_VERSION >= 50700
+#ifdef HIPBLASLT
#include <ATen/cuda/tunable/GemmHipblaslt.h>
#endif
#include <ATen/cuda/tunable/GemmRocblas.h>
@@ -166,7 +166,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
}
#endif

-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (env == nullptr || strcmp(env, "1") == 0) {
// disallow tuning of hipblaslt with c10::complex
@@ -240,7 +240,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
}
#endif

-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (env == nullptr || strcmp(env, "1") == 0) {
// disallow tuning of hipblaslt with c10::complex
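
The two TunableGemm.h hunks above leave the existing runtime opt-out intact; restated as a standalone helper, it looks roughly like the sketch below (the function name is mine, the environment variable comes from the diff):

```cpp
#include <cstdlib>
#include <cstring>

// hipblaslt tuning stays enabled when the variable is unset or set to "1",
// matching the `env == nullptr || strcmp(env, "1") == 0` check above.
static bool hipblaslt_tuning_enabled() {
  static const char* env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
  return env == nullptr || std::strcmp(env, "1") == 0;
}
```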
10 changes: 6 additions & 4 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -153,7 +153,7 @@ enum class Activation {
GELU,
};

-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
switch (a) {
case Activation::None:
@@ -191,6 +191,7 @@ static bool getDisableAddmmCudaLt() {

#ifdef USE_ROCM
static bool isSupportedHipLtROCmArch(int index) {
+#if defined(HIPBLASLT)
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
std::string device_arch = prop->gcnArchName;
static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
@@ -201,6 +202,7 @@ static bool isSupportedHipLtROCmArch(int index) {
}
}
TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
+#endif
return false;
}
#endif
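
With the two added lines in place, the patched helper reads roughly as follows (the loop body is reconstructed from the context the diff viewer collapses, so treat it as a sketch):

```cpp
static bool isSupportedHipLtROCmArch(int index) {
#if defined(HIPBLASLT)
  hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
  std::string device_arch = prop->gcnArchName;
  static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
  for (std::string arch : archs) {
    if (device_arch.find(arch) != std::string::npos) {
      return true;
    }
  }
  TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
#endif
  // Without hipblaslt, every architecture is reported as unsupported.
  return false;
}
```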
@@ -226,7 +228,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
at::ScalarType scalar_type = self.scalar_type();
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && defined(HIPBLASLT)
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
@@ -269,7 +271,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
}
self__sizes = self_->sizes();
} else {
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
useLtInterface = !disable_addmm_cuda_lt &&
result.dim() == 2 && result.is_contiguous() &&
isSupportedHipLtROCmArch(self.device().index()) &&
@@ -320,7 +322,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma

TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());

-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
if (useLtInterface) {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
3 changes: 3 additions & 0 deletions cmake/Dependencies.cmake
@@ -1273,6 +1273,9 @@ if(USE_ROCM)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "6.0.0")
list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
endif()
+if(hipblaslt_FOUND)
+list(APPEND HIP_CXX_FLAGS -DHIPBLASLT)
+endif()
if(HIPBLASLT_CUSTOM_DATA_TYPE)
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_DATA_TYPE)
endif()
4 changes: 2 additions & 2 deletions cmake/public/LoadHIP.cmake
@@ -156,7 +156,7 @@ if(HIP_FOUND)
find_package_and_print_version(rocblas REQUIRED)
find_package_and_print_version(hipblas REQUIRED)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
-find_package_and_print_version(hipblaslt REQUIRED)
+find_package_and_print_version(hipblaslt)
endif()
find_package_and_print_version(miopen REQUIRED)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
@@ -191,7 +191,7 @@ if(HIP_FOUND)
# roctx is part of roctracer
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)

-if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
+if(hipblaslt_FOUND)
# check whether hipblaslt is using its own datatype
set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_data_type.cc")
file(WRITE ${file} ""
