Skip to content

Commit

Permalink
Port bmm and baddbmm from TH to ATen
Browse files Browse the repository at this point in the history
ghstack-source-id: 61c04d480b16d92ad66fe7246773b48f5f5302fc
Pull Request resolved: #42553
  • Loading branch information
anjali411 committed Aug 12, 2020
1 parent ac93d45 commit 222077a
Show file tree
Hide file tree
Showing 16 changed files with 314 additions and 795 deletions.
1 change: 0 additions & 1 deletion BUILD.bazel
Expand Up @@ -376,7 +376,6 @@ filegroup(
"aten/src/THC/THCTensorCopy.cu.cc",
"aten/src/THC/THCTensorIndex.cu.cc",
"aten/src/THC/THCTensorMath.cu.cc",
"aten/src/THC/THCTensorMathBlas.cu.cc",
"aten/src/THC/THCTensorMathMagma.cu.cc",
"aten/src/THC/THCTensorMathPairwise.cu.cc",
"aten/src/THC/THCTensorMathReduce.cu.cc",
Expand Down
36 changes: 0 additions & 36 deletions aten/src/ATen/Declarations.cwrap
Expand Up @@ -361,42 +361,6 @@
- real alpha
]]
[[
[[
name: _th_bmm
cuda_bfloat16: True
cname: baddbmm
variants:
- function
backends:
- CUDA
return: argument 0
arguments:
- arg: THTensor* result
output: True
- argument 0
- THTensor* self
- THTensor* mat2
- CONSTANT AS_REAL(0)
- CONSTANT AS_REAL(1)
]]
[[
name: _th_baddbmm
cuda_bfloat16: True
cname: baddbmm
variants:
- function
backends:
- CUDA
return: argument 0
arguments:
- arg: THTensor* result
output: True
- arg: THTensor* self
- THTensor* batch1
- THTensor* batch2
- real beta
- real alpha
]]
[[
name: _th_gels
cname: gels
Expand Down
17 changes: 8 additions & 9 deletions aten/src/ATen/NamedTensorUtils.cpp
Expand Up @@ -517,17 +517,16 @@ std::vector<Dimname> compute_bmm_outnames(
}

std::vector<Dimname> compute_baddbmm_outnames(
TensorImpl* result,
TensorImpl* batch1,
TensorImpl* batch2,
TensorImpl* bias) {
if (!impl::has_names(result) && !impl::has_names(batch1) &&
!impl::has_names(batch2) && !impl::has_names(bias)) {
Tensor& result,
const Tensor& self,
const Tensor& other,
const Tensor& bias) {
if (!result.has_names() && !self.has_names()
&& !other.has_names() && !bias.has_names()) {
return {};
}
auto bmm_names = compute_matmul_outnames(
impl::get_names(batch1), impl::get_names(batch2));
auto baddbmm_names = unify_from_right(impl::get_names(bias), bmm_names);
auto bmm_names = compute_matmul_outnames(self.names(), other.names());
auto baddbmm_names = unify_from_right(bias.names(), bmm_names);
return baddbmm_names;
}

Expand Down
8 changes: 4 additions & 4 deletions aten/src/ATen/NamedTensorUtils.h
Expand Up @@ -155,10 +155,10 @@ CAFFE2_API void propagate_names_for_addmv(
CAFFE2_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2);

CAFFE2_API std::vector<Dimname> compute_baddbmm_outnames(
TensorImpl* result,
TensorImpl* self,
TensorImpl* other,
TensorImpl* bias);
Tensor& result,
const Tensor& self,
const Tensor& other,
const Tensor& bias);

CAFFE2_API bool are_names_equal(TensorImpl* self, TensorImpl* other);

Expand Down
3 changes: 0 additions & 3 deletions aten/src/ATen/core/aten_interned_strings.h
Expand Up @@ -140,8 +140,6 @@ _(aten, _tan) \
_(aten, _tanh) \
_(aten, _tanh_backward) \
_(aten, _tanh_forward) \
_(aten, _th_baddbmm) \
_(aten, _th_bmm) \
_(aten, _th_get_device) \
_(aten, _th_kthvalue) \
_(aten, _th_median) \
Expand Down Expand Up @@ -671,7 +669,6 @@ _(aten, tan) \
_(aten, tanh) \
_(aten, tensor) \
_(aten, tensordot) \
_(aten, th_addmm) \
_(aten, th_clone) \
_(aten, th_norm) \
_(aten, th_pow) \
Expand Down
104 changes: 104 additions & 0 deletions aten/src/ATen/cuda/CUDABlas.cpp
Expand Up @@ -143,6 +143,110 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) {
CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldc); \
} while (0)

// Batched strided GEMM, double precision.
// Computes C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i] for each of the
// num_batches problems, where batch i lives at a + i*stridea (etc.).
template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
  cublasHandle_t blas_handle = at::cuda::getCurrentCUDABlasHandle();
  const cublasOperation_t op_a = _cublasOpFromChar(transa);
  const cublasOperation_t op_b = _cublasOpFromChar(transb);
  // Normalize the leading dimensions for this transpose configuration
  // before running the argument checks.
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(double);
  TORCH_CUDABLAS_CHECK(cublasDgemmStridedBatched(
      blas_handle,
      op_a,
      op_b,
      m,
      n,
      k,
      &alpha,
      a,
      lda,
      stridea,
      b,
      ldb,
      strideb,
      &beta,
      c,
      ldc,
      stridec,
      num_batches));
}

// Batched strided GEMM, single precision.
// Computes C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i] for each of the
// num_batches problems, where batch i lives at a + i*stridea (etc.).
template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
  cublasHandle_t blas_handle = at::cuda::getCurrentCUDABlasHandle();
  const cublasOperation_t op_a = _cublasOpFromChar(transa);
  const cublasOperation_t op_b = _cublasOpFromChar(transb);
  // Normalize the leading dimensions for this transpose configuration
  // before running the argument checks.
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(float);
  TORCH_CUDABLAS_CHECK(cublasSgemmStridedBatched(
      blas_handle,
      op_a,
      op_b,
      m,
      n,
      k,
      &alpha,
      a,
      lda,
      stridea,
      b,
      ldb,
      strideb,
      &beta,
      c,
      ldc,
      stridec,
      num_batches));
}

#ifndef __HIP_PLATFORM_HCC__
// Batched strided GEMM for complex<double>, dispatching to
// cublasZgemmStridedBatched. The code relies on c10::complex<double> being
// bit-compatible with cuDoubleComplex, so scalars and device pointers are
// reinterpret_cast straight through to the cuBLAS call.
template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  // Normalize leading dimensions for this transpose configuration before validation.
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(c10::complex<double>);
  TORCH_CUDABLAS_CHECK(cublasZgemmStridedBatched(
      handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
      lda, stridea, reinterpret_cast<const cuDoubleComplex*>(b), ldb, strideb, reinterpret_cast<const cuDoubleComplex*>(&beta),
      reinterpret_cast<cuDoubleComplex*>(c), ldc, stridec, num_batches));
}

// Batched strided GEMM for complex<float>, dispatching to
// cublasCgemm3mStridedBatched (cuBLAS's "3m" complex-multiply variant).
// The code relies on c10::complex<float> being bit-compatible with cuComplex,
// so scalars and device pointers are reinterpret_cast straight through.
template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  // Normalize leading dimensions for this transpose configuration before validation.
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(c10::complex<float>);
  TORCH_CUDABLAS_CHECK(cublasCgemm3mStridedBatched(
      handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
      lda, stridea, reinterpret_cast<const cuComplex*>(b), ldb, strideb, reinterpret_cast<const cuComplex*>(&beta),
      reinterpret_cast<cuComplex*>(c), ldc, stridec, num_batches));
}
#endif

// Batched strided GEMM for fp16.
// FIXES vs. the ported code:
//  * The ROCm branch referenced undefined identifiers `strideC` and
//    `batchCount`; the macro declares `stridec` and `num_batches`, so the
//    branch did not compile under HIP.
//  * The CUDA branch passed `&alpha`/`&beta` (at::Half) while requesting
//    compute type CUDA_R_32F; cublasGemmStridedBatchedEx requires the
//    scaling factors to match the compute type, so the float copies
//    fAlpha/fBeta must be passed instead.
template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(at::Half);
  // Accumulation happens in fp32 (CUDA_R_32F / rocblas_datatype_f32_r), so
  // alpha/beta are widened to float before being handed to the library.
  float fAlpha = alpha;
  float fBeta = beta;
#ifdef __HIP_PLATFORM_HCC__
  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
                                   (void*)&fAlpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
                                   b, rocblas_datatype_f16_r, (int)ldb, strideb,
                                   (void*)&fBeta, c, rocblas_datatype_f16_r, (int)ldc, stridec,
                                   c, rocblas_datatype_f16_r, (int)ldc, stridec,
                                   (int)num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
                                   0, 0));
#else
  #if defined(CUDA_VERSION) && CUDA_VERSION < 11000
    // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
    // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  #endif  // CUDA_VERSION < 11000
  TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(
      handle, opa, opb, m, n, k, (void*)(&fAlpha), a, CUDA_R_16F, lda, stridea,
      b, CUDA_R_16F, ldb, strideb, (void*)(&fBeta), c, CUDA_R_16F, ldc, stridec, num_batches,
      CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  #if defined(CUDA_VERSION) && CUDA_VERSION < 11000
    // Restore the default math mode so later cuBLAS calls are unaffected.
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  #endif  // CUDA_VERSION < 11000
#endif  // __HIP_PLATFORM_HCC__
}

#ifdef __HIP_PLATFORM_HCC__
// Batched strided GEMM for bfloat16 (ROCm only in this build).
// FIX vs. the ported code: the body declared `falpha`/`fbeta` but then used
// undefined `fAlpha`/`fBeta`, and referenced undefined `strideA`/`strideB`/
// `strideC`/`batchCount` instead of the macro's `stridea`/`strideb`/`stridec`/
// `num_batches` — the function could not compile.
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  // Accumulation happens in fp32 (rocblas_datatype_f32_r), so alpha/beta are
  // widened to float before being handed to rocBLAS.
  float falpha = alpha;
  float fbeta = beta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
                                   (void*)&falpha, a, rocblas_datatype_bf16_r, (int)lda, stridea,
                                   b, rocblas_datatype_bf16_r, (int)ldb, strideb,
                                   (void*)&fbeta, c, rocblas_datatype_bf16_r, (int)ldc, stridec,
                                   c, rocblas_datatype_bf16_r, (int)ldc, stridec,
                                   (int)num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
                                   0, 0, NULL, NULL));
}
#endif

template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
globalContext().alertCuBLASConfigNotDeterministic();
Expand Down
44 changes: 36 additions & 8 deletions aten/src/ATen/cuda/CUDABlas.h
Expand Up @@ -69,6 +69,34 @@ template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
#endif

// Shared parameter list for all batched-GEMM (bgemm) specializations:
// transpose op codes for A/B, problem sizes (m, n, k), scaling factors
// alpha/beta, device pointers with leading dimensions, the per-batch element
// strides for A/B/C, and the batch count.
#define CUDABLAS_BGEMM_ARGTYPES(Dtype)                                          \
  char transa, char transb, int64_t m, int64_t n, int64_t k, Dtype alpha,       \
      const Dtype *a, int64_t lda, int64_t stridea,                             \
      const Dtype *b, int64_t ldb, int64_t strideb,                             \
      Dtype beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches

// Generic fallback: bgemm is only implemented for the dtypes specialized
// below; any other instantiation fails loudly at runtime.
template <typename Dtype>
inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
  AT_ERROR("at::cuda::blas::bgemm: not implemented for ", typeid(Dtype).name());
}

template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double));
template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float));
// Complex batched GEMM is only declared for the CUDA (non-ROCm) build here.
#ifndef __HIP_PLATFORM_HCC__
template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
#endif
// BFloat16 batched GEMM is only declared for the ROCm build.
#ifdef __HIP_PLATFORM_HCC__
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
#endif
template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));

/* LEVEL 2 BLAS FUNCTIONS */

#define CUDABLAS_GEMV_ARGTYPES(Dtype) \
Expand All @@ -85,17 +113,17 @@ void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double));
template <>
void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float));
#ifndef __HIP_PLATFORM_HCC__
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>));
template <>
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>));
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>));
template <>
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>));
#endif
template <>
void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half));
#ifdef __HIP_PLATFORM_HCC__
template <>
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16));
template <>
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16));
#endif
template <>
void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half));

template <typename Dtype>
void ger(
Expand Down

0 comments on commit 222077a

Please sign in to comment.