Optionally use hipblaslt #120551

Open · wants to merge 1 commit into base: main
7 changes: 4 additions & 3 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -14,7 +14,7 @@
#include <c10/util/irange.h>

#ifdef USE_ROCM
#if ROCM_VERSION >= 60000
#ifdef HIPBLASLT
#include <hipblaslt/hipblaslt-ext.hpp>
#endif
// until hipblas has an API to accept flags, we must use rocblas here
@@ -788,7 +788,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
}
}

#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))

#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
// only for rocm 5.7 where we first supported hipblaslt, it was difficult
@@ -919,6 +919,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
};
} // namespace

#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
template <typename Dtype>
void gemm_and_bias(
bool transpose_mat1,
@@ -1131,7 +1132,7 @@ template void gemm_and_bias(
at::BFloat16* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);

#endif
void scaled_gemm(
char transa,
char transb,
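These hunks gate the hipBLASLt code paths on a compile-time HIPBLASLT define instead of a ROCM_VERSION comparison. A minimal sketch of how that guard composes with the existing CUDA/MSVC checks (the helper macro and caller are hypothetical, not part of the diff):

```cpp
// Sketch only: HIPBLASLT is assumed to be defined by the build system when
// hipblaslt was found (see the cmake/Dependencies.cmake hunk below).
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
  #define AT_BLASLT_ENABLED 1   // hypothetical helper macro, not in the PR
#else
  #define AT_BLASLT_ENABLED 0
#endif

void example_dispatch() {
#if AT_BLASLT_ENABLED
  // cublasLt / hipBLASLt fused-epilogue path (gemm_and_bias, scaled_gemm, ...)
#else
  // plain cuBLAS / rocBLAS fallback path
#endif
}
```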
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/CUDABlas.h
@@ -82,7 +82,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));

#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
enum GEMMAndBiasActivationEpilogue {
None,
RELU,
4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/CUDAContextLight.h
@@ -9,7 +9,7 @@

// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
// added bf16 support
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
#include <cublasLt.h>
#endif

@@ -82,7 +82,7 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
/* Handles */
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
#endif

4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -29,7 +29,7 @@ namespace at::cuda {

namespace {

#if defined(USE_ROCM) && ROCM_VERSION >= 50700
#if defined(USE_ROCM) && defined(HIPBLASLT)
void createCublasLtHandle(cublasLtHandle_t *handle) {
TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
}
@@ -190,7 +190,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
return handle;
}

#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
cublasLtHandle_t getCurrentCUDABlasLtHandle() {
#ifdef USE_ROCM
c10::DeviceIndex device = 0;
6 changes: 3 additions & 3 deletions aten/src/ATen/cuda/tunable/TunableGemm.h
@@ -11,7 +11,7 @@

#include <ATen/cuda/tunable/GemmCommon.h>
#ifdef USE_ROCM
#if ROCM_VERSION >= 50700
#ifdef HIPBLASLT
#include <ATen/cuda/tunable/GemmHipblaslt.h>
#endif
#include <ATen/cuda/tunable/GemmRocblas.h>
@@ -166,7 +166,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
}
#endif

#if defined(USE_ROCM) && ROCM_VERSION >= 50700
#if defined(USE_ROCM) && defined(HIPBLASLT)
static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (env == nullptr || strcmp(env, "1") == 0) {
// disallow tuning of hipblaslt with c10::complex
@@ -240,7 +240,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
}
#endif

#if defined(USE_ROCM) && ROCM_VERSION >= 50700
#if defined(USE_ROCM) && defined(HIPBLASLT)
static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (env == nullptr || strcmp(env, "1") == 0) {
// disallow tuning of hipblaslt with c10::complex
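The runtime gate that these TunableGemm.h hunks keep compiling only under HIPBLASLT reads an environment variable. A self-contained sketch of the same check (the helper name is made up for illustration):

```cpp
#include <cstdlib>
#include <cstring>

// Mirrors the logic visible in the TunableGemm.h hunks: hipBLASLt candidates
// are considered unless PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED is set to a value
// other than "1" (unset counts as enabled).
static bool hipblaslt_tuning_enabled() {
  const char* env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
  return env == nullptr || std::strcmp(env, "1") == 0;
}
```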
14 changes: 8 additions & 6 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -155,7 +155,7 @@ enum class Activation {
GELU,
};

#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
switch (a) {
case Activation::None:
@@ -193,6 +193,7 @@ static bool getDisableAddmmCudaLt() {

#ifdef USE_ROCM
static bool isSupportedHipLtROCmArch(int index) {
#if defined(HIPBLASLT)
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
std::string device_arch = prop->gcnArchName;
static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
@@ -203,6 +204,7 @@ static bool isSupportedHipLtROCmArch(int index) {
}
}
TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
#endif
return false;
}
#endif
@@ -228,7 +230,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
at::ScalarType scalar_type = self.scalar_type();
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && defined(HIPBLASLT)
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
@@ -271,7 +273,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
}
self__sizes = self_->sizes();
} else {
#if defined(USE_ROCM) && ROCM_VERSION >= 50700
#if defined(USE_ROCM) && defined(HIPBLASLT)
useLtInterface = !disable_addmm_cuda_lt &&
result.dim() == 2 && result.is_contiguous() &&
isSupportedHipLtROCmArch(self.device().index()) &&
@@ -322,7 +324,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma

TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());

#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
if (useLtInterface) {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
@@ -876,7 +878,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
at::native::resize_output(amax, {});

#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && defined(HIPBLASLT))
cublasCommonArgs args(mat1, mat2, out);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
@@ -906,7 +908,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
TORCH_CHECK(false, "_scaled_mm_out_cuda is not compiled for this platform.");
#endif

#if defined(USE_ROCM) && ROCM_VERSION >= 60000
#if defined(USE_ROCM) && defined(HIPBLASLT)
// rocm's hipblaslt does not yet support amax, so calculate separately
auto out_float32 = out.to(kFloat);
out_float32.abs_();
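The isSupportedHipLtROCmArch change above now falls through to `return false` when HIPBLASLT is not compiled in, rather than hitting the TORCH_CHECK failure. A standalone sketch of that shape (the prefix comparison against gcnArchName is an assumption, since the middle of the function is collapsed in this view):

```cpp
#include <string>
#include <vector>

// Sketch of the guard added in Blas.cpp: with hipBLASLt unavailable at build
// time, the function simply reports "unsupported" so callers take the rocBLAS
// addmm path instead of the Lt interface.
static bool is_supported_hiplt_rocm_arch(const std::string& gcn_arch_name) {
#if defined(HIPBLASLT)
  static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
  for (const auto& arch : archs) {
    if (gcn_arch_name.rfind(arch, 0) == 0) {  // assumed prefix match on gcnArchName
      return true;
    }
  }
  // In the actual diff an unsupported arch with HIPBLASLT defined trips a
  // TORCH_CHECK(false, ...); that branch is omitted from this sketch.
#endif
  return false;
}
```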
3 changes: 3 additions & 0 deletions cmake/Dependencies.cmake
@@ -1273,6 +1273,9 @@ if(USE_ROCM)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "6.0.0")
list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
endif()
if(hipblaslt_FOUND)
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT)
endif()
if(HIPBLASLT_CUSTOM_DATA_TYPE)
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_DATA_TYPE)
endif()
4 changes: 2 additions & 2 deletions cmake/public/LoadHIP.cmake
@@ -156,7 +156,7 @@ if(HIP_FOUND)
find_package_and_print_version(rocblas REQUIRED)
find_package_and_print_version(hipblas REQUIRED)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
find_package_and_print_version(hipblaslt REQUIRED)
find_package_and_print_version(hipblaslt)
endif()
find_package_and_print_version(miopen REQUIRED)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
@@ -191,7 +191,7 @@ if(HIP_FOUND)
# roctx is part of roctracer
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)

if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
if(hipblaslt_FOUND)
# check whether hipblaslt is using its own datatype
set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_data_type.cc")
file(WRITE ${file} ""
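The HIPBLASLT_CUSTOM_DATA_TYPE probe whose guard changes here compiles a tiny source file to see whether the installed hipblaslt still uses its own datatype enum. The probe body is not shown in this diff; a plausible sketch, with the exact type and enum names taken as an assumption, is:

```cpp
// Hypothetical body of hipblaslt_test_data_type.cc: if this translation unit
// compiles, hipblaslt still exposes its own hipblasltDatatype_t, and the build
// adds -DHIPBLASLT_CUSTOM_DATA_TYPE.
#include <hipblaslt/hipblaslt.h>

int main() {
  hipblasltDatatype_t datatype = HIPBLASLT_R_16F;  // assumed enum name
  (void)datatype;
  return 0;
}
```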