[pytorch] fix blasLt on windows (#125792)
Summary:
It seems the required functions are not available on Windows because of the `_MSC_VER` guard. Does anyone have more context on why this functionality was disabled for Windows?

I'm also unsure how this currently compiles in OSS land on Windows, as there doesn't seem to be any preprocessor protection around `scaled_gemm` getting pulled in.
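
For context, a minimal, self-contained sketch of the failure mode. The names mirror the PyTorch headers, but the guard and bodies here are illustrative stand-ins, not the real declarations:

```cpp
// guard_repro.cpp -- illustrative stand-ins for CUDABlas.h / TunableGemm.h,
// not the real PyTorch declarations.

// Stand-in for CUDABlas.h: the function is compiled out whenever _MSC_VER
// is defined, mirroring the old `!defined(_MSC_VER)` guard.
namespace at::cuda::blas {
#if !defined(_MSC_VER)
inline void scaled_gemm() {}
#endif
}  // namespace at::cuda::blas

// Stand-in for TunableGemm.h: the template body is not guarded, so it still
// names scaled_gemm unconditionally. On MSVC the qualified name cannot be
// resolved (the preprocessor removed its only declaration), which surfaces as
//   error C2039: 'scaled_gemm': is not a member of 'at::cuda::blas'
template <typename T>
struct DefaultScaledGemmOp {
  void run() { at::cuda::blas::scaled_gemm(); }
};

int main() {
  DefaultScaledGemmOp<float> op;  // instantiating the template compiles the broken body
  op.run();
  return 0;
}
```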

Test Plan:
Fixes compilation errors like the following:
```
C:\open\fbsource\xplat\caffe2\aten\src\ATen\cuda\tunable\TunableGemm.h(74): error C2039: 'scaled_gemm': is not a member of 'at::cuda::blas'
C:\open\fbsource\xplat\caffe2\aten\src\ATen\cuda\CUDABlas.h(19): note: see declaration of 'at::cuda::blas'
C:\open\fbsource\xplat\caffe2\aten\src\ATen\cuda\tunable\TunableGemm.h(74): note: the template instantiation context (the oldest one first) is
C:\open\fbsource\xplat\caffe2\aten\src\ATen\cuda\tunable\TunableGemm.h(71): note: while compiling class template 'at::cuda::tunable::DefaultScaledGemmOp'
Action failed: fbsource//xplat/caffe2:ATen_cuda_lib_ovrsource (cxx_compile aten/src/ATen/native/cuda/Blas.cpp)
```

Differential Revision: D57087985

Pull Request resolved: #125792
Approved by: https://github.com/malfet, https://github.com/eqy
EscapeZero authored and pytorchmergebot committed May 9, 2024
1 parent 902a74c commit fdff992
Showing 5 changed files with 48 additions and 26 deletions.
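
At a glance, the change is mechanical: each cuBLASLt-related guard drops the `!defined(_MSC_VER)` clause so the declarations stay visible in Windows builds, while the ROCm version gates are unchanged. The recurring pattern, taken from the hunks below, is:

```cpp
// Before: cuBLASLt paths compiled out on MSVC
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)

// After: no Windows-specific exclusion; ROCm still requires 5.7+
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
```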
8 changes: 4 additions & 4 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -236,7 +236,7 @@ namespace at::cuda::blas {
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, num_batches); \
} while (0)

#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)

#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
// only for rocm 5.7 where we first supported hipblaslt, it was difficult
@@ -375,7 +375,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<

template <typename Dtype>
inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
cudaDataType_t abcType = CUDA_R_32F;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
cudaDataType_t scaleType = CUDA_R_32F;
@@ -1235,7 +1235,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
}
}

#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)

template <typename Dtype>
void gemm_and_bias(
@@ -1745,7 +1745,7 @@ void int8_gemm(
TORCH_CHECK(false, "int8_gemm is only supported for ROCm 6.0 and above");
#endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
}
#endif // (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)

// ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
#if defined(USE_ROCM) && ROCM_VERSION < 50600
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/CUDABlas.h
@@ -82,7 +82,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));

#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
enum GEMMAndBiasActivationEpilogue {
None,
RELU,
4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/CUDAContextLight.h
@@ -9,7 +9,7 @@

// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
// added bf16 support
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#include <cublasLt.h>
#endif

@@ -82,7 +82,7 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
/* Handles */
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
#endif

2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -191,7 +191,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
return handle;
}

#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
cublasLtHandle_t getCurrentCUDABlasLtHandle() {
#ifdef USE_ROCM
c10::DeviceIndex device = 0;
58 changes: 40 additions & 18 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -157,7 +157,7 @@ enum class Activation {
GELU,
};

#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
switch (a) {
case Activation::None:
@@ -236,7 +236,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
at::ScalarType scalar_type = self.scalar_type();
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || (defined(USE_ROCM) && (ROCM_VERSION >= 50700))
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
@@ -334,8 +334,9 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma

TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());

#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && (ROCM_VERSION >= 50700))
if (useLtInterface) {
#if defined(USE_ROCM)
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
@@ -353,28 +354,49 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
#if defined(USE_ROCM)
// This condition is needed for mm case on ROCm for hipblasLt path.
// Passing the bias ptr as null to avoid accuracy issues for mm case.
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
#else
self.const_data_ptr<scalar_t>(),
#endif
args.result->data_ptr<scalar_t>(),
args.result_ld,
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11080) || defined(USE_ROCM)
activation_to_gemm_and_blas_arg(activation)
);
});
#else
// GELU is not supported (and does not compile!) prior
// to CUDA 11.4. Have observed accuracy issues with
// GELU epilogue in 11.4; disabling the GELU epilogue
// path for CUDA version < 11.8.
activation != Activation::GELU
? activation_to_gemm_and_blas_arg(activation)
: cuda::blas::GEMMAndBiasActivationEpilogue::None
auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
#if (defined(CUDA_VERSION) && (CUDA_VERSION < 11080))
// GELU is not supported (and does not compile!) prior
// to CUDA 11.4. Have observed accuracy issues with
// GELU epilogue in 11.4; disabling the GELU epilogue
// path for CUDA version < 11.8.
if (activation == Activation::GELU)
activation_epilogue = cuda::blas::GEMMAndBiasActivationEpilogue::None;
#endif

AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_epilogue
);
});
#endif
} else
#endif
{
@@ -748,7 +770,7 @@ Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result)

TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");

#if (!defined(USE_ROCM) && !defined(_MSC_VER) && defined(CUDA_VERSION) && CUDA_VERSION >= 11070) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
#if (!defined(USE_ROCM) && defined(CUDA_VERSION) && (CUDA_VERSION >= 11070)) || (defined(USE_ROCM) && (ROCM_VERSION >= 60000))
cublasCommonArgs args(self, mat2, result);

at::cuda::blas::int8_gemm(
@@ -768,7 +790,7 @@ Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result)
result.copy_(*args.result);
}
#else
#if !defined(USE_ROCM) && !defined(_MSC_VER) && defined(CUDA_VERSION)
#if !defined(USE_ROCM) && defined(CUDA_VERSION)
TORCH_CHECK(false, "_int_mm_out_cuda not compiled for CUDA ", CUDA_VERSION);
#else
TORCH_CHECK(false, "_int_mm_out_cuda not compiled for this platform.");
@@ -888,7 +910,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
at::native::resize_output(amax, {});

#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
#if !defined(USE_ROCM) || (defined(USE_ROCM) && (ROCM_VERSION >= 60000))
cublasCommonArgs args(mat1, mat2, out);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");