diff --git a/BUILD.bazel b/BUILD.bazel index 3c6d5c248207..1b84b21c6fb4 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -376,7 +376,6 @@ filegroup( "aten/src/THC/THCTensorCopy.cu.cc", "aten/src/THC/THCTensorIndex.cu.cc", "aten/src/THC/THCTensorMath.cu.cc", - "aten/src/THC/THCTensorMathBlas.cu.cc", "aten/src/THC/THCTensorMathMagma.cu.cc", "aten/src/THC/THCTensorMathPairwise.cu.cc", "aten/src/THC/THCTensorMathReduce.cu.cc", diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index 7325b8eb88f0..3fa4be81043e 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -361,42 +361,6 @@ - real alpha ]] [[ -[[ - name: _th_bmm - cuda_bfloat16: True - cname: baddbmm - variants: - - function - backends: - - CUDA - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - argument 0 - - THTensor* self - - THTensor* mat2 - - CONSTANT AS_REAL(0) - - CONSTANT AS_REAL(1) -]] -[[ - name: _th_baddbmm - cuda_bfloat16: True - cname: baddbmm - variants: - - function - backends: - - CUDA - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - - THTensor* batch1 - - THTensor* batch2 - - real beta - - real alpha -]] [[ name: _th_gels cname: gels diff --git a/aten/src/ATen/NamedTensorUtils.cpp b/aten/src/ATen/NamedTensorUtils.cpp index f59cbed39abb..668838877123 100644 --- a/aten/src/ATen/NamedTensorUtils.cpp +++ b/aten/src/ATen/NamedTensorUtils.cpp @@ -517,17 +517,16 @@ std::vector compute_bmm_outnames( } std::vector compute_baddbmm_outnames( - TensorImpl* result, - TensorImpl* batch1, - TensorImpl* batch2, - TensorImpl* bias) { - if (!impl::has_names(result) && !impl::has_names(batch1) && - !impl::has_names(batch2) && !impl::has_names(bias)) { + Tensor& result, + const Tensor& self, + const Tensor& other, + const Tensor& bias) { + if (!result.has_names() && !self.has_names() + && !other.has_names() && !bias.has_names()) { return {}; } - auto bmm_names = compute_matmul_outnames( - impl::get_names(batch1), impl::get_names(batch2)); - auto baddbmm_names = unify_from_right(impl::get_names(bias), bmm_names); + auto bmm_names = compute_matmul_outnames(self.names(), other.names()); + auto baddbmm_names = unify_from_right(bias.names(), bmm_names); return baddbmm_names; } diff --git a/aten/src/ATen/NamedTensorUtils.h b/aten/src/ATen/NamedTensorUtils.h index 6777f39f7fcf..47dfd580a189 100644 --- a/aten/src/ATen/NamedTensorUtils.h +++ b/aten/src/ATen/NamedTensorUtils.h @@ -155,10 +155,10 @@ CAFFE2_API void propagate_names_for_addmv( CAFFE2_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2); CAFFE2_API std::vector compute_baddbmm_outnames( - TensorImpl* result, - TensorImpl* self, - TensorImpl* other, - TensorImpl* bias); + Tensor& result, + const Tensor& self, + const Tensor& other, + const Tensor& bias); CAFFE2_API bool are_names_equal(TensorImpl* self, TensorImpl* other); diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index 956569e8386b..ec4d67fb5ca7 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -140,8 +140,6 @@ _(aten, _tan) \ _(aten, _tanh) \ _(aten, _tanh_backward) \ _(aten, _tanh_forward) \ -_(aten, _th_baddbmm) \ -_(aten, _th_bmm) \ _(aten, _th_get_device) \ _(aten, _th_kthvalue) \ _(aten, _th_median) \ @@ -671,7 +669,6 @@ _(aten, tan) \ _(aten, tanh) \ _(aten, tensor) \ _(aten, tensordot) \ -_(aten, th_addmm) \ _(aten, th_clone) \ _(aten, th_norm) \ _(aten, th_pow) \ 
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index cc678ce5d794..0959468ee9c6 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -143,6 +143,110 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) { CUDABLAS_POSINT_CHECK(gemm, ldc); \ } while (0) +template <> +void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) { + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + GEMM_CHECK_ARGVALUES(double); + TORCH_CUDABLAS_CHECK(cublasDgemmStridedBatched( + handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)); +} + +template <> +void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) { + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + GEMM_CHECK_ARGVALUES(float); + TORCH_CUDABLAS_CHECK(cublasSgemmStridedBatched( + handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)); +} + +#ifndef __HIP_PLATFORM_HCC__ + template <> + void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) { + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + GEMM_CHECK_ARGVALUES(c10::complex<double>); + TORCH_CUDABLAS_CHECK(cublasZgemmStridedBatched( + handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a), + lda, stridea, reinterpret_cast<const cuDoubleComplex*>(b), ldb, strideb, reinterpret_cast<const cuDoubleComplex*>(&beta), + reinterpret_cast<cuDoubleComplex*>(c), ldc, stridec, num_batches)); + } + + template <> + void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) { + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + GEMM_CHECK_ARGVALUES(c10::complex<float>); + TORCH_CUDABLAS_CHECK(cublasCgemm3mStridedBatched( + handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a), + lda, stridea, reinterpret_cast<const cuComplex*>(b), ldb, strideb, reinterpret_cast<const cuComplex*>(&beta), + reinterpret_cast<cuComplex*>(c), ldc, stridec, num_batches)); + } +#endif + +template <> +void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) { + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + GEMM_CHECK_ARGVALUES(at::Half); + float fAlpha = alpha; + float fBeta = beta; +#ifdef __HIP_PLATFORM_HCC__ + TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k, + (void*)&fAlpha, a, rocblas_datatype_f16_r, (int)lda, stridea, + b, rocblas_datatype_f16_r, (int)ldb, strideb, + (void*)&fBeta, c, rocblas_datatype_f16_r, (int)ldc, stridec, + c, rocblas_datatype_f16_r, (int)ldc, stridec, + (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, + 0, 0)); +#else +#if defined(CUDA_VERSION) && CUDA_VERSION < 11000 + // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH + // manually to be able
to use tensor cores for FP16. On CUDA 11, this is no longer required. + TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); +#endif // CUDA_VERSION < 11000 + TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx( + handle, opa, opb, m, n, k, (void*)(&alpha), a, CUDA_R_16F, lda, stridea, + b, CUDA_R_16F, ldb, strideb, (void*)(&beta), c, CUDA_R_16F, ldc, stridec, num_batches, + CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); +#if defined(CUDA_VERSION) && CUDA_VERSION < 11000 + // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH + // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required. + TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); +#endif // CUDA_VERSION < 11000 +#endif // __HIP_PLATFORM_HCC__ +} + +#ifdef __HIP_PLATFORM_HCC__ +template <> +void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t opa = _cublasOpFromChar(transa); + cublasOperation_t opb = _cublasOpFromChar(transb); + float falpha = alpha; + float fbeta = beta; + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k, + (void*)&falpha, a, rocblas_datatype_bf16_r, (int)lda, stridea, + b, rocblas_datatype_bf16_r, (int)ldb, strideb, + (void*)&fbeta, c, rocblas_datatype_bf16_r, (int)ldc, stridec, + c, rocblas_datatype_bf16_r, (int)ldc, stridec, + (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, + 0, 0, NULL, NULL)); +} +#endif + + template <> void gemm(CUDABLAS_GEMM_ARGTYPES(double)) { globalContext().alertCuBLASConfigNotDeterministic(); diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index 60696e6674f8..9a8f97e42486 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -69,6 +69,34 @@ template <> void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); #endif +#define CUDABLAS_BGEMM_ARGTYPES(Dtype) \ + char transa, char transb, int64_t m, int64_t n, int64_t k, Dtype alpha, \ + const Dtype *a, int64_t lda, int64_t stridea, \ + const Dtype *b, int64_t ldb, int64_t strideb, \ + Dtype beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches + +template <typename Dtype> +inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { + AT_ERROR("at::cuda::blas::bgemm: not implemented for ", typeid(Dtype).name()); +} + +template <> +void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)); +template <> +void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)); +#ifndef __HIP_PLATFORM_HCC__ + template <> + void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)); + template <> + void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)); +#endif +#ifdef __HIP_PLATFORM_HCC__ + template <> + void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +#endif +template <> +void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)); + /* LEVEL 2 BLAS FUNCTIONS */ #define CUDABLAS_GEMV_ARGTYPES(Dtype) \ @@ -85,17 +113,17 @@ void gemv(CUDABLAS_GEMV_ARGTYPES(double)); template <> void gemv(CUDABLAS_GEMV_ARGTYPES(float)); #ifndef __HIP_PLATFORM_HCC__ -template <> -void gemv>(CUDABLAS_GEMV_ARGTYPES(c10::complex)); -template <> -void gemv>(CUDABLAS_GEMV_ARGTYPES(c10::complex)); + template <> + void gemv>(CUDABLAS_GEMV_ARGTYPES(c10::complex)); + template <> + void gemv>(CUDABLAS_GEMV_ARGTYPES(c10::complex)); #endif -template <> -void gemv(CUDABLAS_GEMV_ARGTYPES(at::Half)); #ifdef __HIP_PLATFORM_HCC__ -template <> -void gemv(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)); + 
template <> + void gemv(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)); #endif +template <> +void gemv(CUDABLAS_GEMV_ARGTYPES(at::Half)); template void ger( diff --git a/aten/src/ATen/native/cuda/LinearAlgebra.cu b/aten/src/ATen/native/cuda/LinearAlgebra.cu index 439699f322dd..163b88ea7721 100644 --- a/aten/src/ATen/native/cuda/LinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/LinearAlgebra.cu @@ -3,33 +3,11 @@ #include #include -namespace at { namespace native { - -Tensor baddbmm_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { - Tensor b_self; - std::tie(b_self) = expand_size(self, {batch1.size(0), batch1.size(1), batch2.size(2)}, "baddbmm"); - return legacy::cuda::_th_baddbmm(b_self, batch1, batch2, beta, alpha); -} - -Tensor& baddbmm_out_cuda(Tensor &result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { - Tensor b_self; - std::tie(b_self) = expand_size(self, {batch1.size(0), batch1.size(1), batch2.size(2)}, "baddbmm_out"); - return legacy::cuda::_th_baddbmm_out(result, b_self, batch1, batch2, beta, alpha); -} - -Tensor& baddbmm__cuda(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { - return baddbmm_out_cuda(self, self, batch1, batch2, beta, alpha); -} +#include +#include +#include -Tensor& bmm_out_cuda(Tensor &result, const Tensor& batch1, const Tensor& batch2) { - result.resize_({ batch1.size(0), batch1.size(1), batch2.size(2) }); - return legacy::cuda::_th_bmm_out(result, batch1, batch2); -} - -Tensor bmm_cuda(const Tensor& self, const Tensor& mat2) { - Tensor result = at::empty({0}, self.options()); - return native::bmm_out_cuda(result, self, mat2); -} +namespace at { namespace native { Tensor prepare_matrix_for_cublas(Tensor& tensor, bool& transpose_tensor) { Tensor tensor_; @@ -50,6 +28,33 @@ Tensor prepare_matrix_for_cublas(Tensor& tensor, bool& transpose_tensor) { return tensor_; } +Tensor prepare_batch_matrix_for_cublas(Tensor& tensor, bool& transpose_tensor, int64_t& ld_tensor, bool transpose_result, int m, int n) { + IntArrayRef tensor_strides = tensor.strides(); + Tensor tensor_; + + if (tensor_strides[transpose_result ? 2 : 1] == 1 && + (tensor_strides[transpose_result ? 1 : 2] >= std::max(1, m))) { + transpose_tensor = false; + tensor_ = tensor; + ld_tensor = tensor_strides[transpose_result ? 1 : 2]; + } else if ((tensor_strides[transpose_result ? 1 : 2] == 1) && + (tensor_strides[transpose_result ? 2 : 1] >= std::max(1, n))) { + transpose_tensor = true; + tensor_ = tensor; + ld_tensor = tensor_strides[transpose_result ? 
2 : 1]; + } else { + transpose_tensor = !transpose_result; + if (tensor.is_contiguous()) { + tensor_ = tensor; + } else { + tensor_ = tensor.clone(at::MemoryFormat::Contiguous); + } + ld_tensor = tensor_strides[1]; + } + + return tensor_; +} + namespace { Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) { @@ -130,6 +135,88 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma return result; } +Tensor& baddmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { + TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor"); + TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor"); + TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor"); + + IntArrayRef batch1_sizes = batch1.sizes(); + IntArrayRef batch2_sizes = batch2.sizes(); + IntArrayRef self_sizes = self.sizes(); + + TORCH_CHECK(self_sizes[0] == batch1_sizes[0], "self dim 0 must match batch1 dim 0"); + TORCH_CHECK(self_sizes[0] == batch2_sizes[0], "self dim 0 must match batch2 dim 0"); + TORCH_CHECK(self_sizes[1] == batch1_sizes[1], "self dim 1 must match batch1 dim 1"); + TORCH_CHECK(self_sizes[2] == batch2_sizes[2], "self dim 2 must match batch2 dim 2"); + TORCH_CHECK(batch1_sizes[2] == batch2_sizes[1], "batch1 dim 2 must match batch2 dim 1"); + + if (!result.is_same(self)) { + at::native::resize_as_(result, self); + if ((beta.isComplex() && beta.to>() != 0.0) || + (beta.to() != 0.0)) { + at::native::copy_(result, self); + } + } + + bool transpose_result = false; + bool transpose_batch1, transpose_batch2; + int64_t lda, ldb, ldc; + Tensor result_; + IntArrayRef result_strides = result.strides(); + IntArrayRef result_sizes = result.sizes(); + + if ((result_strides[1] == 1) && + ((result_sizes[2] == 1) || (result_strides[2] >= std::max(1, result_sizes[0])))) { + result_ = result; + } else if ((result_strides[2] == 1) && + (result_sizes[1] == 1 || (result_strides[1] >= std::max(1, result_sizes[2])))) { + transpose_result = true; + result_ = result; + } else { + auto result_ = result.transpose(1, 2).clone(); + result_ = result_.transpose(1, 2); + } + + ldc = result_.stride(transpose_result ? 1 : 2); + int64_t m = result_sizes[transpose_result ? 2 : 1]; + int64_t n = result_sizes[transpose_result ? 1 : 2]; + int64_t k = batch1.size(transpose_result ? 1 : 2); + + Tensor batch1_ = transpose_result ? batch2 : batch1; + Tensor batch2_ = transpose_result ? batch1 : batch2; + batch1_ = prepare_batch_matrix_for_cublas(batch1_, transpose_batch1, lda, transpose_result, m, k); + batch2_ = prepare_batch_matrix_for_cublas(batch2_, transpose_batch2, ldb, transpose_result, k, n); + + IntArrayRef result__sizes = result.sizes(); + m = result__sizes[transpose_result ? 2 : 1]; + n = result__sizes[transpose_result ? 1 : 2]; + k = batch1_.size(transpose_result ? 1 : 2); + int64_t num_batches = result__sizes[0]; + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "baddmm_cuda", [&] { + scalar_t alpha_val = alpha.to(); + scalar_t beta_val = beta.to(); + scalar_t* batch1_ptr = batch1_.data_ptr(); + scalar_t* batch2_ptr = batch2_.data_ptr(); + scalar_t* result_ptr = result_.data_ptr(); + at::cuda::blas::bgemm( + transpose_batch1 ? 't' : 'n', + transpose_batch2 ? 
't' : 'n', + m, n, k, + alpha_val, + batch1_ptr, lda, batch1_.stride(0), + batch2_ptr, ldb, batch2_.stride(0), + beta_val, + result_ptr, ldc, result_.stride(0), + num_batches + ); + }); + if (!result.is_same(result_)) { + result.copy_(result_); + } + return result; +} + } // anonymous namespace Tensor& mm_out_cuda(Tensor& result, const Tensor& self, const Tensor& mat2) { @@ -166,6 +253,52 @@ Tensor& addmm__cuda(Tensor& self, const Tensor& mat1, const Tensor& mat2, return self; } +Tensor& baddbmm_out_cuda(Tensor &result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { + Tensor self_; + // std::tie(self_) = expand_size(self, {batch1.size(0), batch1.size(1), batch2.size(2)}, "baddbmm_out"); + if (&result != &self) { + std::tie(self_) = expand_size(self, {batch1.size(0), batch1.size(1), batch2.size(2)}, "baddbmm"); + } else { + self_ = self; + } + { + at::NoNamesGuard guard; + baddmm_out_cuda_impl(result, self_, batch1, batch2, beta, alpha); + } + namedinference::propagate_names_if_nonempty( + result, + namedinference::compute_baddbmm_outnames(result, batch1, batch2, self)); + return result; +} + +Tensor baddbmm_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { + Tensor out = at::empty({0}, self.options()); + return baddbmm_out_cuda(out, self, batch1, batch2, beta, alpha); +} + +Tensor& baddbmm__cuda(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) { + return baddbmm_out_cuda(self, self, batch1, batch2, beta, alpha); +} + +Tensor& bmm_out_cuda(Tensor &result, const Tensor& batch1, const Tensor& batch2) { + result.resize_({ batch1.size(0), batch1.size(1), batch2.size(2) }); + Scalar beta(0.0); + Scalar alpha(1.0); + { + NoNamesGuard guard; + baddmm_out_cuda_impl(result, result, batch1, batch2, beta, alpha); + } + namedinference::propagate_names_if_nonempty( + result, + namedinference::compute_bmm_outnames(result, batch1, batch2)); + return result; +} + +Tensor bmm_cuda(const Tensor& self, const Tensor& mat2) { + Tensor result = at::empty({0}, self.options()); + return native::bmm_out_cuda(result, self, mat2); +} + template void addr_impl_ger_cuda(Tensor &out, const Tensor &self, const Tensor& vec1, const Tensor& vec2, diff --git a/aten/src/THC/CMakeLists.txt b/aten/src/THC/CMakeLists.txt index bee2f5b84e50..4ba4a4ce4456 100644 --- a/aten/src/THC/CMakeLists.txt +++ b/aten/src/THC/CMakeLists.txt @@ -48,7 +48,6 @@ set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/THCTensor.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorCopy.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMath.cu - ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathBlas.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathMagma.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathPairwise.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathReduce.cu @@ -141,8 +140,6 @@ install(FILES generic/THCTensorMasked.cu generic/THCTensorMath.h generic/THCTensorMath.cu - generic/THCTensorMathBlas.cu - generic/THCTensorMathBlas.h generic/THCTensorMathMagma.h generic/THCTensorMathMagma.cu generic/THCTensorMathPairwise.h diff --git a/aten/src/THC/THCBlas.cu b/aten/src/THC/THCBlas.cu index 55c8d8b1b958..52b3f44b134e 100644 --- a/aten/src/THC/THCBlas.cu +++ b/aten/src/THC/THCBlas.cu @@ -11,137 +11,14 @@ #include #endif -/* Level 2 */ - -void adjustLdLevel2(int64_t m, int64_t n, int64_t *lda) -{ - // Note: leading dimensions generally are checked that they are > 0 and at least as big the result - // requires (even if the value won't be used). 
- // TODO: why does Level3 check trans but this doesn't? - if (n <= 1) - *lda = std::max(m, 1); -} - -void THCudaBlas_Sger(THCState *state, int64_t m, int64_t n, float alpha, float *x, int64_t incx, float *y, int64_t incy, float *a, int64_t lda) -{ - adjustLdLevel2(m, n, &lda); - - if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) - { - int i_m = (int)m; - int i_n = (int)n; - int i_lda = (int)lda; - int i_incx = (int)incx; - int i_incy = (int)incy; - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasSger(handle, i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda)); - return; - } - THError("Cublas_Sger only supports m, n, lda, incx, incy" - "with the bound [val] <= %d", INT_MAX); -} - -void THCudaBlas_Dger(THCState *state, int64_t m, int64_t n, double alpha, double *x, int64_t incx, double *y, int64_t incy, double *a, int64_t lda) -{ - adjustLdLevel2(m, n, &lda); - - if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) - { - int i_m = (int)m; - int i_n = (int)n; - int i_lda = (int)lda; - int i_incx = (int)incx; - int i_incy = (int)incy; - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasDger(handle, i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda)); - return; - } - THError("Cublas_Dger only supports m, n, lda, incx, incy" - "with the bound [val] <= %d", INT_MAX); -} - - -cublasOperation_t convertTransToCublasOperation(char trans) { - if (trans == 't') return CUBLAS_OP_T; - else if (trans == 'n') return CUBLAS_OP_N; - else if (trans == 'c') return CUBLAS_OP_C; - else { - THError("trans must be one of: t, n, c"); - return CUBLAS_OP_T; - } -} - -void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t *lda, int64_t *ldb, int64_t *ldc) -{ - int transa_ = ((transa == 't') || (transa == 'T')); - int transb_ = ((transb == 't') || (transb == 'T')); - - // Note: leading dimensions generally are checked that they are > 0 and at least as big the result - // requires (even if the value won't be used). - if(n <= 1) - *ldc = std::max(m, 1); - - if(transa_) - { - if(m <= 1) - *lda = std::max(k, 1); - } - else - { - if(k <= 1) - *lda = std::max(m, 1); - } - - if(transb_) - { - if(k <= 1) - *ldb = std::max(n, 1); - } - else - { - if(n <= 1) - *ldb = std::max(k, 1); - } - -} - -// Check https://github.com/pytorch/pytorch/issues/22078 -// for information about the bug. We don't know the exact conditions that trigger it, -// but using Sgemm or Hgemm on Maxwell or Pascal seems to be a -// necessary condition. -static void checkCuda90Bug(int i_m, int i_n, int i_k) -{ -#if CUDA_VERSION < 9200 && CUDA_VERSION >= 9000 - static std::once_flag alreadyWarned; - const int LIMIT = 1 << 21; - if (i_m > LIMIT || i_n > LIMIT || i_k > LIMIT) { - cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); - if (prop->major == 5 || prop->major == 6) { - std::call_once(alreadyWarned, []() { - TORCH_WARN("Matrix multiplication for dimensions larger than 2^21 has known bugs on your combination of CUDA version and device type. 
Please consider upgrading to CUDA 9.2 or later."); - }); - } - } -#endif -} - /* Level 3 */ void THCudaBlas_Sgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, float alpha, float *a, int64_t lda, float *b, int64_t ldb, float beta, float *c, int64_t ldc) { - checkCuda90Bug((int)m, (int)n, (int)k); at::cuda::blas::gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -// In CUDA 8.0, definition of data types for sgemmex changed -#if CUDA_VERSION < 8000 -# define CUDA_R_16F CUBLAS_DATA_HALF -#endif - void THCudaBlas_Hgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, at::Half alpha, at::Half *a, int64_t lda, at::Half *b, int64_t ldb, at::Half beta, at::Half *c, int64_t ldc) { - checkCuda90Bug((int)m, (int)n, (int)k); at::cuda::blas::gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } @@ -156,197 +33,3 @@ void THCudaBlas_Dgemm(THCState *state, char transa, char transb, int64_t m, int6 { at::cuda::blas::gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - -#if CUDA_VERSION >= 9010 || defined __HIP_PLATFORM_HCC__ -void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - at::Half alpha, const at::Half *a, int64_t lda, int64_t strideA, const at::Half *b, int64_t ldb, int64_t strideB, - at::Half beta, at::Half *c, int64_t ldc, int64_t strideC, int64_t batchCount) -{ - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - - { - THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - float fAlpha = alpha; - float fBeta = beta; -#ifdef __HIP_PLATFORM_HCC__ - THCublasCheck(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k, - (void*)&fAlpha, a, rocblas_datatype_f16_r, (int)lda, strideA, - b, rocblas_datatype_f16_r, (int)ldb, strideB, - (void*)&fBeta, c, rocblas_datatype_f16_r, (int)ldc, strideC, - c, rocblas_datatype_f16_r, (int)ldc, strideC, - (int) batchCount, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, - 0, 0)); -#else -#if defined(CUDA_VERSION) && CUDA_VERSION < 11000 - // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH - // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required. - THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); -#endif // CUDA_VERSION < 11000 - THCublasCheck(cublasGemmStridedBatchedEx(handle, - opa, opb, (int)m, (int)n, (int)k, - (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA, - b, CUDA_R_16F, (int)ldb, strideB, - (void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC, - (int)batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -#if defined(CUDA_VERSION) && CUDA_VERSION < 11000 - // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH - // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required. 
- THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); -#endif // CUDA_VERSION < 11000 -#endif // __HIP_PLATFORM_HCC__ -} -#endif // CUDA_VERSION or __HIP_PLATFORM_HCC__ - -#ifdef __HIP_PLATFORM_HCC__ -void THCudaBlas_BgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - at::BFloat16 alpha, const at::BFloat16 *a, int64_t lda, int64_t strideA, const at::BFloat16 *b, int64_t ldb, int64_t strideB, - at::BFloat16 beta, at::BFloat16 *c, int64_t ldc, int64_t strideC, int64_t batchCount) -{ - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - - { - THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - float fAlpha = alpha; - float fBeta = beta; - THCublasCheck(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k, - (void*)&fAlpha, a, rocblas_datatype_bf16_r, (int)lda, strideA, - b, rocblas_datatype_bf16_r, (int)ldb, strideB, - (void*)&fBeta, c, rocblas_datatype_bf16_r, (int)ldc, strideC, - c, rocblas_datatype_bf16_r, (int)ldc, strideC, - (int) batchCount, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, - 0, 0, NULL, NULL)); -} -#endif // __HIP_PLATFORM_HCC__ - -void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb, - float beta, float *c[], int64_t ldc, int64_t batchCount) -{ - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - { - THError("Cublas_SgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - -#ifdef __HIP_PLATFORM_HCC__ - - const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n; - const int64_t strideb = (transb == 'N' || transb == 'n') ? 
ldb*n : ldb*k; - const int64_t stridec = ldc*n; - - THCudaBlas_SgemmStridedBatched(state, transa, transb, m, n, k, alpha, *a, lda, stridea, *b, ldb, strideb, beta, *c, ldc, stridec, batchCount); - -#else - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasSgemmBatched(handle, - opa, opb, (int)m, (int)n, (int)k, - &alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc, - (int)batchCount)); -#endif -} - -#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__ -void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - float alpha, const float *a, int64_t lda, int64_t strideA, const float *b, int64_t ldb, int64_t strideB, - float beta, float *c, int64_t ldc, int64_t strideC, int64_t batchCount) -{ - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - - { - THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasSgemmStridedBatched(handle, - opa, opb, (int)m, (int)n, (int)k, - &alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC, - (int)batchCount)); -} -#endif - -void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - double alpha, const double *a[], int64_t lda, const double *b[], int64_t ldb, - double beta, double *c[], int64_t ldc, int64_t batchCount) -{ - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - { - THError("Cublas_DgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - -#ifdef __HIP_PLATFORM_HCC__ - - const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n; - const int64_t strideb = (transb == 'N' || transb == 'n') ? 
ldb*n : ldb*k; - const int64_t stridec = ldc*n; - - THCudaBlas_DgemmStridedBatched(state, transa, transb, m, n, k, alpha, *a, lda, stridea, *b, ldb, strideb, beta, *c, ldc, stridec, batchCount); - -#else - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasDgemmBatched(handle, - opa, opb, (int)m, (int)n, (int)k, - &alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc, - (int)batchCount)); -#endif -} - -#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__ -void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB, - double beta, double *c, int64_t ldc, int64_t strideC, int64_t batchCount) -{ - at::globalContext().alertCuBLASConfigNotDeterministic(); - if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) - { - THError("Cublas_DgemmBatched only supports m, n, k, lda, ldb, ldc, batchCount" - "with the bound [val] <= %d", INT_MAX); - } - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - THCublasCheck(cublasDgemmStridedBatched(handle, - opa, opb, (int)m, (int)n, (int)k, - &alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC, - (int)batchCount)); -} -#endif - diff --git a/aten/src/THC/THCBlas.h b/aten/src/THC/THCBlas.h index cff3180a974a..984fa83f3f33 100644 --- a/aten/src/THC/THCBlas.h +++ b/aten/src/THC/THCBlas.h @@ -5,10 +5,6 @@ #include #include -/* Level 2 */ -THC_API void THCudaBlas_Sger(THCState *state, int64_t m, int64_t n, float alpha, float *x, int64_t incx, float *y, int64_t incy, float *a, int64_t lda); -THC_API void THCudaBlas_Dger(THCState *state, int64_t m, int64_t n, double alpha, double *x, int64_t incx, double *y, int64_t incy, double *a, int64_t lda); - /* Level 3 */ THC_API void THCudaBlas_Sgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, float alpha, float *a, int64_t lda, float *b, int64_t ldb, float beta, float *c, int64_t ldc); THC_API void THCudaBlas_Dgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, double alpha, double *a, int64_t lda, double *b, int64_t ldb, double beta, double *c, int64_t ldc); @@ -18,31 +14,4 @@ THC_API void THCudaBlas_Hgemm(THCState *state, char transa, char transb, int64_t THC_API void THCudaBlas_Bgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, at::BFloat16 alpha, at::BFloat16 *a, int64_t lda, at::BFloat16 *b, int64_t ldb, at::BFloat16 beta, at::BFloat16 *c, int64_t ldc); #endif -THC_API void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - float alpha, const float *a[], int64_t lda, const float *b[], int64_t ldb, - float beta, float *c[], int64_t ldc, int64_t batchCount); -THC_API void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - double alpha, const double *a[], int64_t lda, const double *b[], int64_t ldb, - double beta, double *c[], 
int64_t ldc, int64_t batchCount); -#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__ -THC_API void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - float alpha, const float *a, int64_t lda, int64_t strideA, const float *b, int64_t ldb, int64_t strideB, - float beta, float *c, int64_t ldc, int64_t strideC, int64_t batchCount); -THC_API void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB, - double beta, double *c, int64_t ldc, int64_t strideC, int64_t batchCount); -#endif - -#if CUDA_VERSION >= 9010 || defined(__HIP_PLATFORM_HCC__) -void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - THHalf alpha, const THHalf *a, int64_t lda, int64_t strideA, const THHalf *b, int64_t ldb, int64_t strideB, - THHalf beta, THHalf *c, int64_t ldc, int64_t strideC, int64_t batchCount); -#endif - -#ifdef __HIP_PLATFORM_HCC__ -void THCudaBlas_BgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, - at::BFloat16 alpha, const at::BFloat16 *a, int64_t lda, int64_t strideA, const at::BFloat16 *b, int64_t ldb, int64_t strideB, - at::BFloat16 beta, at::BFloat16 *c, int64_t ldc, int64_t strideC, int64_t batchCount); -#endif - #endif diff --git a/aten/src/THC/THCTensorMath.h b/aten/src/THC/THCTensorMath.h index 68fbb240afb4..fd316f93ed55 100644 --- a/aten/src/THC/THCTensorMath.h +++ b/aten/src/THC/THCTensorMath.h @@ -13,12 +13,6 @@ #include #include -#include -#include - -#include -#include - #include #include diff --git a/aten/src/THC/THCTensorMathBlas.cu b/aten/src/THC/THCTensorMathBlas.cu deleted file mode 100644 index 383d1ed17b1d..000000000000 --- a/aten/src/THC/THCTensorMathBlas.cu +++ /dev/null @@ -1,13 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include diff --git a/aten/src/THC/generic/THCTensorMathBlas.cu b/aten/src/THC/generic/THCTensorMathBlas.cu deleted file mode 100644 index 3158e0e267ed..000000000000 --- a/aten/src/THC/generic/THCTensorMathBlas.cu +++ /dev/null @@ -1,328 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THC/generic/THCTensorMathBlas.cu" -#else - -#include -#include - -#define ERROR_ONLY_FP_TYPES(func) \ - THError("%s for CUDA tensors only supports floating-point types. 
Try converting the tensors with .float()", func); - -__global__ void createBatchGemmBuffer3(const scalar_t** buffer1, const scalar_t ** buffer2, const scalar_t ** buffer3, scalar_t* data1, - scalar_t * data2, scalar_t * data3, int64_t stride1, int64_t stride2, int64_t stride3, int64_t num_batches) { - const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < num_batches) { - buffer1[idx] = data1 + idx * stride1; - buffer2[idx] = data2 + idx * stride2; - buffer3[idx] = data3 + idx * stride3; - } -} - -void THCTensor_(baddbmm)(THCState *state, THCTensor *result, THCTensor *t, - THCTensor *batch1, THCTensor *batch2, - scalar_t beta, scalar_t alpha) { -#if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_BFLOAT16) - THCAssertSameGPU(THCTensor_(checkGPU)(state, 4, result, t, batch1, batch2)); - THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, t) == 3, 4, "expected 3D tensor"); - THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, batch1) == 3, 6, "expected 3D tensor"); - THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, batch2) == 3, 7, "expected 3D tensor"); - THArgCheck(THCTensor_(size)(state, t, 0) == THCTensor_(size)(state, batch1, 0), 6, - "equal number of batches expected"); - THArgCheck(THCTensor_(size)(state, t, 0) == THCTensor_(size)(state, batch2, 0), 7, - "equal number of batches expected"); - auto maybe_outnames = at::namedinference::compute_baddbmm_outnames(result, batch1, batch2, t); - { - at::NoNamesGuard guard; - THArgCheck(THCTensor_(size)(state, t, 1) == THCTensor_(size)(state, batch1, 1), 6, - "wrong matrix size"); - THArgCheck(THCTensor_(size)(state, t, 2) == THCTensor_(size)(state, batch2, 2), 7, - "wrong matrix size"); - THArgCheck(THCTensor_(size)(state, batch1, 2) == THCTensor_(size)(state, batch2, 1), 6, - "wrong matrix size"); - - if (t != result) { - THCTensor_(resizeAs)(state, result, t); - if (ScalarConvert::to(beta) != 0.0) { - THCTensor_(copy)(state, result, t); - } - } - - bool transpose_result; - char transpose_batch1, transpose_batch2; - int64_t lda, ldb, ldc; - THCTensor *result_, *batch1_, *batch2_; - if (result->stride(1) == 1 && - (result->size(2) == 1 || result->stride(2) >= std::max(1, result->size(1)))) - { - transpose_result = false; - result_ = result; - ldc = result_->stride(2); - } - else if (result->stride(2) == 1 && - (result->size(1) == 1 || result->stride(1) >= std::max(1, result->size(2)))) - { - transpose_result = true; - - THCTensor *swap = batch2; - batch2 = batch1; - batch1 = swap; - - result_ = result; - ldc = result_->stride(1); - } - else - { - transpose_result = false; - - THCTensor *transp_r_ = THCTensor_(newTranspose)(state, result, 1, 2); - result_ = THCTensor_(newClone)(state, transp_r_); - THCTensor_(free)(state, transp_r_); - THCTensor_(transpose)(state, result_, NULL, 1, 2); - - ldc = result_->stride(2); - } - - const int64_t m = result->size(transpose_result ? 2 : 1); - const int64_t n = result->size(transpose_result ? 1 : 2); - const int64_t k = batch1->size(transpose_result ? 1 : 2); - - if (batch1->stride(transpose_result ? 2 : 1) == 1 && - batch1->stride(transpose_result ? 1 : 2) >= std::max(1, m)) - { - transpose_batch1 = 'n'; - batch1_ = batch1; - lda = batch1_->stride(transpose_result ? 1 : 2); - } - else if (batch1->stride(transpose_result ? 1 : 2) == 1 && - batch1->stride(transpose_result ? 2 : 1) >= std::max(1, k)) - { - transpose_batch1 = 't'; - batch1_ = batch1; - lda = batch1_->stride(transpose_result ? 
2 : 1); - } - else - { - transpose_batch1 = transpose_result ? 'n' : 't'; - // batch1_ is later freed if batch1_ != batch1 - if (THCTensor_(isContiguous)(state, batch1)) { - batch1_ = batch1; - } else { - batch1_ = THCTensor_(newContiguous)(state, batch1); - } - lda = batch1_->stride(1); - } - - if (batch2->stride(transpose_result ? 2 : 1) == 1 && - batch2->stride(transpose_result ? 1 : 2) >= std::max(1, k)) - { - transpose_batch2 = 'n'; - batch2_ = batch2; - ldb = batch2_->stride(transpose_result ? 1 : 2); - } - else if (batch2->stride(transpose_result ? 1 : 2) == 1 && - batch2->stride(transpose_result ? 2 : 1) >= std::max(1, n)) - { - transpose_batch2 = 't'; - batch2_ = batch2; - ldb = batch2_->stride(transpose_result ? 2 : 1); - } - else - { - transpose_batch2 = transpose_result ? 'n' : 't'; - // batch2_ is later freed if batch2_ != batch2 - if (THCTensor_(isContiguous)(state, batch2)) { - batch2_ = batch2; - } else { - batch2_ = THCTensor_(newContiguous)(state, batch2); - } - ldb = batch2_->stride(1); - } - int64_t num_batches = result_->size(0); - -#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) - // Compute pointers to matrices in each batch. -#if CUDA_VERSION < 8000 && !defined __HIP_PLATFORM_HCC__ - size_t matrices_size = num_batches * sizeof(scalar_t*); - -// Copy pointers to device. - auto d_matrices1 = static_cast(THCudaMalloc(state, matrices_size)); - auto d_matrices2 = static_cast(THCudaMalloc(state, matrices_size)); - auto d_result_matrices = static_cast(THCudaMalloc(state, matrices_size)); - - const int64_t block = 512; - const int64_t grid = (num_batches + block - 1) / block; - - createBatchGemmBuffer3<<>>( - d_matrices1, d_matrices2, (const scalar_t**)d_result_matrices, THCTensor_(data)(state, batch1_), - THCTensor_(data)(state, batch2_), THCTensor_(data)(state, result_), - batch1_->stride(0), batch2_->stride(0), result_->stride(0), num_batches); - -#ifdef THC_REAL_IS_FLOAT - THCudaBlas_SgemmBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - d_matrices1, lda, - d_matrices2, ldb, - beta, - d_result_matrices, ldc, - num_batches); -#elif defined(THC_REAL_IS_DOUBLE) - THCudaBlas_DgemmBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - d_matrices1, lda, - d_matrices2, ldb, - beta, - d_result_matrices, ldc, - num_batches); -#endif //THC_REAL - - THCudaFree(state, d_matrices1); - THCudaFree(state, d_matrices2); - THCudaFree(state, d_result_matrices); - -#else -#ifdef THC_REAL_IS_FLOAT - THCudaBlas_SgemmStridedBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - THCTensor_(data)(state, batch1_), lda, batch1_->stride(0), - THCTensor_(data)(state, batch2_), ldb, batch2_->stride(0), - beta, - THCTensor_(data)(state, result_), ldc, result_->stride(0), - num_batches); -#elif defined(THC_REAL_IS_DOUBLE) - THCudaBlas_DgemmStridedBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 
1 : 2), - alpha, - THCTensor_(data)(state, batch1_), lda, batch1_->stride(0), - THCTensor_(data)(state, batch2_), ldb, batch2_->stride(0), - beta, - THCTensor_(data)(state, result_), ldc, result_->stride(0), - num_batches); -#endif //THC_REAL -#endif //CUDA_VERSION - -#elif defined(THC_REAL_IS_HALF) - -#if CUDA_VERSION < 9010 && !defined(__HIP_PLATFORM_HCC__) - // Currently no HgemmBatched in Cublas - for (int64_t i = 0; i < num_batches; ++i) { - THCudaBlas_Hgemm( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - THCTensor_(data)(state, batch1_) + i * batch1_->stride(0), lda, - THCTensor_(data)(state, batch2_) + i * batch2_->stride(0), ldb, - beta, - THCTensor_(data)(state, result_) + i * result_->stride(0), ldc); - } -#else -#ifndef __HIP_PLATFORM_HCC__ - cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); - if (prop->major >= 5){ -#endif - - THCudaBlas_HgemmStridedBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - THCTensor_(data)(state, batch1_), lda, batch1_->stride(0), - THCTensor_(data)(state, batch2_), ldb, batch2_->stride(0), - beta, - THCTensor_(data)(state, result_), ldc, result_->stride(0), - num_batches); -#ifndef __HIP_PLATFORM_HCC__ - } else { - for (int64_t i = 0; i < num_batches; ++i) { - THCudaBlas_Hgemm( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 1 : 2), - alpha, - THCTensor_(data)(state, batch1_) + i * batch1_->stride(0), lda, - THCTensor_(data)(state, batch2_) + i * batch2_->stride(0), ldb, - beta, - THCTensor_(data)(state, result_) + i * result_->stride(0), ldc); - } - } -#endif -#endif //CUDA_VERSION - -#elif defined(THC_REAL_IS_BFLOAT16) -#if defined(__HIP_PLATFORM_HCC__) - THCudaBlas_BgemmStridedBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size(transpose_result ? 2 : 1), - result_->size(transpose_result ? 1 : 2), - batch1_->size(transpose_result ? 
1 : 2), - alpha, - THCTensor_(data)(state, batch1_), lda, batch1_->stride(0), - THCTensor_(data)(state, batch2_), ldb, batch2_->stride(0), - beta, - THCTensor_(data)(state, result_), ldc, result_->stride(0), - num_batches); -#endif // __HIP_PLATFORM_HCC__ -#endif - - if (batch1_ != batch1) { - THCTensor_(free)(state, batch1_); - } - - if (batch2_ != batch2) { - THCTensor_(free)(state, batch2_); - } - - if (result_ != result) { - THCTensor_(freeCopyTo)(state, result_, result); - } - -#if defined(THC_REAL_IS_BFLOAT16) && !defined(__HIP_PLATFORM_HCC__) - // To avoid "variable was set but never used" warning - [&transpose_batch1, &transpose_batch2, &lda, &ldb, &ldc]{}(); - TORCH_CHECK(false, "BgemmStridedBatched is not supported with at::BFloat16 type"); -#endif - } -#if !defined(THC_REAL_IS_BFLOAT16) || defined(__HIP_PLATFORM_HCC__) - at::namedinference::propagate_names_if_nonempty(result, maybe_outnames); -#endif - -#else - ERROR_ONLY_FP_TYPES("baddbmm"); -#endif -} - -#endif diff --git a/aten/src/THC/generic/THCTensorMathBlas.h b/aten/src/THC/generic/THCTensorMathBlas.h deleted file mode 100644 index e15baafaca64..000000000000 --- a/aten/src/THC/generic/THCTensorMathBlas.h +++ /dev/null @@ -1,7 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THC/generic/THCTensorMathBlas.h" -#else - -THC_API void THCTensor_(baddbmm)(THCState *state, THCTensor *result, THCTensor *t, THCTensor *batch1, THCTensor *batch2, scalar_t beta, scalar_t alpha); - -#endif diff --git a/test/test_torch.py b/test/test_torch.py index d0b7dc0648f0..d1e6c6a3639b 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -19579,12 +19579,12 @@ def inner(self, device, dtype): 1e-1, 1e-1, 1e-4, _float_types_2_and_complex_if_not_rocm, _cpu_types, True, [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addbmm_? is deprecated")]), ('baddbmm', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)], - 1e-2, 1e-1, 1e-4, _float_types2), + 1e-2, 1e-1, 1e-4, _float_types2 + _complex_types), ('baddbmm', 'scalar', _small_3d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], - 1e-2, 1e-1, 1e-4, _float_types2, _cpu_types, True, + 1e-2, 1e-1, 1e-4, _float_types2 + _complex_types, _cpu_types, True, [_wrap_maybe_warns("This overload of baddbmm_? is deprecated")]), ('baddbmm', 'two_scalars', _small_3d, lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], - 1e-2, 1e-1, 1e-4, _float_types2, _cpu_types, True, + 1e-2, 1e-1, 1e-4, _float_types2 + _complex_types, _cpu_types, True, [_wrap_maybe_warns("This overload of baddbmm_? is deprecated")]), ('bmm', '', _small_3d, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, _float_types_no_half, _cpu_types, False),
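
For readers of this patch, the sketch below illustrates how the new `at::cuda::blas::bgemm` entry point (declared in `CUDABlas.h` above) can drive a batched row-major matrix multiply on a column-major BLAS by swapping the two operands, which is the same trick `baddmm_out_cuda_impl` applies through its `transpose_result` path. This is illustrative only and not part of the patch; the helper name `bmm_rowmajor_sketch` and the assumption of contiguous float CUDA tensors are hypothetical.

```cpp
// Illustrative sketch (not part of this patch): drive the new strided-batched
// GEMM wrapper directly for contiguous, row-major float tensors.
// Helper name and shapes are hypothetical; shape/device checks are omitted.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDABlas.h>

// Computes c = a.bmm(b) for contiguous [batch, m, k] x [batch, k, n] float
// tensors. cuBLAS is column-major, so row-major C = A*B is evaluated as
// column-major C^T = B^T * A^T, i.e. the two operands are swapped.
void bmm_rowmajor_sketch(const at::Tensor& a, const at::Tensor& b, at::Tensor& c) {
  const int64_t m = a.size(1), k = a.size(2), n = b.size(2);
  at::cuda::blas::bgemm<float>(
      'n', 'n',
      n, m, k,                              // dimensions of the column-major product
      1.0f,
      b.data_ptr<float>(), n, b.stride(0),  // B^T: n x k, leading dimension n
      a.data_ptr<float>(), k, a.stride(0),  // A^T: k x m, leading dimension k
      0.0f,
      c.data_ptr<float>(), n, c.stride(0),  // C^T: n x m, leading dimension n
      a.size(0));                           // number of batches
}
```

Swapping the operands avoids any explicit transpose for contiguous inputs; this is also why `prepare_batch_matrix_for_cublas` only clones a batch tensor when neither of the two directly usable stride layouts applies and the tensor is not already contiguous.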