From fdff9920f6b7d5b4c36322b331069c40d2675d69 Mon Sep 17 00:00:00 2001 From: Michael Ranieri Date: Thu, 9 May 2024 01:54:25 +0000 Subject: [PATCH] [pytorch] fix blasLt on windows (#125792) Summary: It seems like required functions are not available due to `_MSC_VER` guard. Does anyone have more context why this functionality has been disabled for windows? I'm also unsure how this currently compiles in OSS land on windows, as there doesn't seem to be any preprocessor protection around `scaled_gemm` getting pulled in. Test Plan: Fix compilation errors like this ``` C:\open\fbsource\xplat\caffe2\aten\src\ATen\cuda\tunable\TunableGemm.h(74): error C2039: 'scaled_gemm': is not a member of 'at::cuda::blas' C:\open\fbsource\xplat\caffe2\aten\src\ATen\cuda\CUDABlas.h(19): note: see declaration of 'at::cuda::blas' C:\open\fbsource\xplat\caffe2\aten\src\ATen\cuda\tunable\TunableGemm.h(74): note: the template instantiation context (the oldest one first) is C:\open\fbsource\xplat\caffe2\aten\src\ATen\cuda\tunable\TunableGemm.h(71): note: while compiling class template 'at::cuda::tunable::DefaultScaledGemmOp' Action failed: fbsource//xplat/caffe2:ATen_cuda_lib_ovrsource (cxx_compile aten/src/ATen/native/cuda/Blas.cpp) ``` Differential Revision: D57087985 Pull Request resolved: https://github.com/pytorch/pytorch/pull/125792 Approved by: https://github.com/malfet, https://github.com/eqy --- aten/src/ATen/cuda/CUDABlas.cpp | 8 ++-- aten/src/ATen/cuda/CUDABlas.h | 2 +- aten/src/ATen/cuda/CUDAContextLight.h | 4 +- aten/src/ATen/cuda/CublasHandlePool.cpp | 2 +- aten/src/ATen/native/cuda/Blas.cpp | 58 +++++++++++++++++-------- 5 files changed, 48 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 3efcd23df5a58..bfe6a02741ede 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -236,7 +236,7 @@ namespace at::cuda::blas { CUDABLAS_NONNEGINT_CHECK(bgemm, num_batches); \ } while (0) -#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) +#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) #if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000 // only for rocm 5.7 where we first supported hipblaslt, it was difficult @@ -375,7 +375,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor< template inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { -#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 60000) +#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000) cudaDataType_t abcType = CUDA_R_32F; cublasComputeType_t computeType = CUBLAS_COMPUTE_32F; cudaDataType_t scaleType = CUDA_R_32F; @@ -1235,7 +1235,7 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { } } -#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) +#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) template void gemm_and_bias( @@ -1745,7 +1745,7 @@ void int8_gemm( TORCH_CHECK(false, "int8_gemm is only supported for ROCm 6.0 and above"); #endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000) } -#endif // (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) +#endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) // ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not. #if defined(USE_ROCM) && ROCM_VERSION < 50600 diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index eb12bb350c598..24aad7678ec49 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -82,7 +82,7 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::Half)); template <> void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); -#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) +#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) enum GEMMAndBiasActivationEpilogue { None, RELU, diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h index 4ec35f59a2108..60d09dfaee169 100644 --- a/aten/src/ATen/cuda/CUDAContextLight.h +++ b/aten/src/ATen/cuda/CUDAContextLight.h @@ -9,7 +9,7 @@ // cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also // added bf16 support -#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) +#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) #include #endif @@ -82,7 +82,7 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator(); /* Handles */ TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle(); TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle(); -#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) +#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle(); #endif diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp index 95d1ba2fb4ae9..6a03803720b5a 100644 --- a/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp @@ -191,7 +191,7 @@ cublasHandle_t getCurrentCUDABlasHandle() { return handle; } -#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) +#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) cublasLtHandle_t getCurrentCUDABlasLtHandle() { #ifdef USE_ROCM c10::DeviceIndex device = 0; diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 7929ebd8a255c..df6f470916428 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -157,7 +157,7 @@ enum class Activation { GELU, }; -#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) +#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) { switch (a) { case Activation::None: @@ -236,7 +236,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma at::ScalarType scalar_type = self.scalar_type(); c10::MaybeOwned self_; if (&result != &self) { -#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700 +#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || (defined(USE_ROCM) && (ROCM_VERSION >= 50700)) // Strangely, if mat2 has only 1 row or column, we get // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic. // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] @@ -334,8 +334,9 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj()); -#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) +#if !defined(USE_ROCM) || (defined(USE_ROCM) && (ROCM_VERSION >= 50700)) if (useLtInterface) { +#if defined(USE_ROCM) AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, @@ -353,28 +354,49 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma args.lda, args.matb->const_data_ptr(), args.ldb, -#if defined(USE_ROCM) // This condition is needed for mm case on ROCm for hipblasLt path. // Passing the bias ptr as null to avoid accuracy issues for mm case. (&result != &self) ? self.const_data_ptr() : nullptr, -#else - self.const_data_ptr(), -#endif args.result->data_ptr(), args.result_ld, -#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11080) || defined(USE_ROCM) activation_to_gemm_and_blas_arg(activation) + ); + }); #else - // GELU is not supported (and does not compile!) prior - // to CUDA 11.4. Have observed accuracy issues with - // GELU epilogue in 11.4; disabling the GELU epilogue - // path for CUDA version < 11.8. - activation != Activation::GELU - ? activation_to_gemm_and_blas_arg(activation) - : cuda::blas::GEMMAndBiasActivationEpilogue::None + auto activation_epilogue = activation_to_gemm_and_blas_arg(activation); +#if (defined(CUDA_VERSION) && (CUDA_VERSION < 11080)) + // GELU is not supported (and does not compile!) prior + // to CUDA 11.4. Have observed accuracy issues with + // GELU epilogue in 11.4; disabling the GELU epilogue + // path for CUDA version < 11.8. + if (activation == Activation::GELU) + activation_epilogue = cuda::blas::GEMMAndBiasActivationEpilogue::None; #endif + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + scalar_type, + "addmm_cuda_lt", + [&] { + at::cuda::blas::gemm_and_bias( + args.transa == 't', + args.transb == 't', + args.m, + args.n, + args.k, + alpha.to>(), + args.mata->const_data_ptr(), + args.lda, + args.matb->const_data_ptr(), + args.ldb, + self.const_data_ptr(), + args.result->data_ptr(), + args.result_ld, + activation_epilogue ); }); +#endif } else #endif { @@ -748,7 +770,7 @@ Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result) TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous."); -#if (!defined(USE_ROCM) && !defined(_MSC_VER) && defined(CUDA_VERSION) && CUDA_VERSION >= 11070) || (defined(USE_ROCM) && ROCM_VERSION >= 60000) +#if (!defined(USE_ROCM) && defined(CUDA_VERSION) && (CUDA_VERSION >= 11070)) || (defined(USE_ROCM) && (ROCM_VERSION >= 60000)) cublasCommonArgs args(self, mat2, result); at::cuda::blas::int8_gemm( @@ -768,7 +790,7 @@ Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result) result.copy_(*args.result); } #else -#if !defined(USE_ROCM) && !defined(_MSC_VER) && defined(CUDA_VERSION) +#if !defined(USE_ROCM) && defined(CUDA_VERSION) TORCH_CHECK(false, "_int_mm_out_cuda not compiled for CUDA ", CUDA_VERSION); #else TORCH_CHECK(false, "_int_mm_out_cuda not compiled for this platform."); @@ -888,7 +910,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); at::native::resize_output(amax, {}); -#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && ROCM_VERSION >= 60000) +#if !defined(USE_ROCM) || (defined(USE_ROCM) && (ROCM_VERSION >= 60000)) cublasCommonArgs args(mat1, mat2, out); const auto out_dtype_ = args.result->scalar_type(); TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");