Port bmm and baddbmm from TH to ATen
ghstack-source-id: c52815b490a9d924af2ec40660c8f20ff1e188c8
Pull Request resolved: #42553
anjali411 committed Nov 12, 2020
1 parent 4738672 commit 4f8c62b
Showing 17 changed files with 408 additions and 1,133 deletions.
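For reference (illustrative only, not part of this diff): bmm and baddbmm are the batched matrix-multiply operators being ported. A minimal sketch of their semantics through the ATen C++ API, with arbitrary shapes:

#include <ATen/ATen.h>

int main() {
  auto batch1 = at::randn({10, 3, 4});   // (b, n, m)
  auto batch2 = at::randn({10, 4, 5});   // (b, m, p)
  auto input  = at::randn({10, 3, 5});   // (b, n, p)

  // bmm: per-batch matrix product, out[i] = batch1[i] @ batch2[i]
  auto out1 = at::bmm(batch1, batch2);

  // baddbmm: out = beta * input + alpha * (batch1 @ batch2), per batch
  auto out2 = at::baddbmm(input, batch1, batch2, /*beta=*/1.0, /*alpha=*/1.0);
  return 0;
}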
1 change: 0 additions & 1 deletion BUILD.bazel
@@ -378,7 +378,6 @@ filegroup(
"aten/src/THC/THCTensorCopy.cu.cc",
"aten/src/THC/THCTensorIndex.cu.cc",
"aten/src/THC/THCTensorMath.cu.cc",
"aten/src/THC/THCTensorMathBlas.cu.cc",
"aten/src/THC/THCTensorMathMagma.cu.cc",
"aten/src/THC/THCTensorMathPairwise.cu.cc",
"aten/src/THC/THCTensorMathReduce.cu.cc",
4 changes: 0 additions & 4 deletions aten/src/ATen/LegacyTHFunctionsCUDA.h
@@ -44,10 +44,6 @@ Tensor & _th_fmod_(Tensor & self, Scalar other);
Tensor & _th_fmod_(Tensor & self, const Tensor & other);
Tensor & _th_cross_kernel_out(Tensor & result, const Tensor & self, const Tensor & other, int64_t dim);
Tensor _th_cross_kernel(const Tensor & self, const Tensor & other, int64_t dim);
Tensor & _th_bmm_out(Tensor & result, const Tensor & self, const Tensor & mat2);
Tensor _th_bmm(const Tensor & self, const Tensor & mat2);
Tensor & _th_baddbmm_out(Tensor & result, const Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha);
Tensor _th_baddbmm(const Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha);
std::tuple<Tensor &,Tensor &> _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A);
std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A);
std::tuple<Tensor &,Tensor &> _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors);
17 changes: 8 additions & 9 deletions aten/src/ATen/NamedTensorUtils.cpp
@@ -517,17 +517,16 @@ std::vector<Dimname> compute_bmm_outnames(
}

std::vector<Dimname> compute_baddbmm_outnames(
TensorImpl* result,
TensorImpl* batch1,
TensorImpl* batch2,
TensorImpl* bias) {
if (!impl::has_names(result) && !impl::has_names(batch1) &&
!impl::has_names(batch2) && !impl::has_names(bias)) {
Tensor& result,
const Tensor& self,
const Tensor& other,
const Tensor& bias) {
if (!result.has_names() && !self.has_names()
&& !other.has_names() && !bias.has_names()) {
return {};
}
auto bmm_names = compute_matmul_outnames(
impl::get_names(batch1), impl::get_names(batch2));
auto baddbmm_names = unify_from_right(impl::get_names(bias), bmm_names);
auto bmm_names = compute_matmul_outnames(self.names(), other.names());
auto baddbmm_names = unify_from_right(bias.names(), bmm_names);
return baddbmm_names;
}

8 changes: 4 additions & 4 deletions aten/src/ATen/NamedTensorUtils.h
@@ -155,10 +155,10 @@ CAFFE2_API void propagate_names_for_addmv(
CAFFE2_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2);

CAFFE2_API std::vector<Dimname> compute_baddbmm_outnames(
TensorImpl* result,
TensorImpl* self,
TensorImpl* other,
TensorImpl* bias);
Tensor& result,
const Tensor& self,
const Tensor& other,
const Tensor& bias);

CAFFE2_API bool are_names_equal(TensorImpl* self, TensorImpl* other);

3 changes: 0 additions & 3 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -133,8 +133,6 @@ _(aten, _sum_cuda) \
_(aten, _tan) \
_(aten, _tanh) \
_(aten, _tanh_forward) \
_(aten, _th_baddbmm) \
_(aten, _th_bmm) \
_(aten, _th_get_device) \
_(aten, _th_kthvalue) \
_(aten, _th_mode) \
@@ -669,7 +667,6 @@ _(aten, tanh) \
_(aten, tensor) \
_(aten, tensordot) \
_(aten, tensor_split) \
_(aten, th_addmm) \
_(aten, th_clone) \
_(aten, th_norm) \
_(aten, th_pow) \
205 changes: 205 additions & 0 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -133,6 +133,56 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) {

/* LEVEL 3 BLAS FUNCTIONS */

#ifndef __HIP_PLATFORM_HCC__
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200
#define cublasGemmStridedBatchedExFix cublasGemmStridedBatchedEx
#else
// Workaround for https://github.com/pytorch/pytorch/issues/45724
cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const void *alpha,
const void *A,
cudaDataType Atype,
int lda,
long long int strideA,
const void *B,
cudaDataType Btype,
int ldb,
long long int strideB,
const void *beta,
void *C,
cudaDataType Ctype,
int ldc,
long long int strideC,
int64_t batchCount,
cudaDataType computeType,
cublasGemmAlgo_t algo)
{
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major != 7) {
return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, batchCount, computeType, algo);
}
cublasStatus_t result;
constexpr int64_t split = 63 * 1024;
for(int64_t i = 0; i < batchCount; i += split) {
int64_t count = std::min<int64_t>(split, batchCount - i);
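    // The char* casts make the offsets below byte offsets: element strides are
    // multiplied by 2 because this workaround path is only used for the 2-byte
    // (fp16/bf16) element types.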
result = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha,
(char *)A + i * strideA * 2, Atype, lda, strideA,
(char *)B + i * strideB * 2, Btype, ldb, strideB,
beta,
(char *)C + i * strideC * 2, Ctype, ldc, strideC,
(int)count, computeType, algo);
TORCH_CUDABLAS_CHECK(result);
}
return result;
}
#endif
#endif

#define GEMM_CHECK_ARGVALUES(Dtype) \
do { \
CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, m); \
@@ -143,6 +193,161 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) {
CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldc); \
} while (0)

#define BGEMM_CHECK_ARGVALUES(Dtype) \
do { \
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, m); \
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, n); \
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, k); \
CUDABLAS_POSINT_CHECK(bgemm<Dtype>, lda); \
CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldb); \
CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldc); \
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, num_batches); \
} while (0)

template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(double);
TORCH_CUDABLAS_CHECK(cublasDgemmStridedBatched(
handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}

template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(float);
TORCH_CUDABLAS_CHECK(cublasSgemmStridedBatched(
handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}

template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(c10::complex<double>);
TORCH_CUDABLAS_CHECK(cublasZgemmStridedBatched(
handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
lda, stridea, reinterpret_cast<const cuDoubleComplex*>(b), ldb, strideb, reinterpret_cast<const cuDoubleComplex*>(&beta),
reinterpret_cast<cuDoubleComplex*>(c), ldc, stridec, num_batches));
}

template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(c10::complex<float>);
TORCH_CUDABLAS_CHECK(cublasCgemmStridedBatched(
handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
lda, stridea, reinterpret_cast<const cuComplex*>(b), ldb, strideb, reinterpret_cast<const cuComplex*>(&beta),
reinterpret_cast<cuComplex*>(c), ldc, stridec, num_batches));
}

template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(at::Half);
float falpha = alpha;
float fbeta = beta;
#ifdef __HIP_PLATFORM_HCC__
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
b, rocblas_datatype_f16_r, (int)ldb, strideb,
(void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec,
c, rocblas_datatype_f16_r, (int)ldc, stridec,
(int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
0, 0));
#else
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#endif // CUDA_VERSION < 11000

cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 5){
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(
handle, opa, opb, m, n, k,
(void*)(&falpha), a, CUDA_R_16F, lda, stridea,
b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta),
c, CUDA_R_16F, ldc, stridec,
num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
for (int64_t i = 0; i < num_batches; ++i) {
at::cuda::blas::gemm<at::Half>(
transa, transb,
m, n, k,
alpha, (a + i * stridea), lda,
(b + i * strideb), ldb, beta,
(c + i * stridec), ldc);
}
}
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif // CUDA_VERSION < 11000
#endif // __HIP_PLATFORM_HCC__
}

#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
BGEMM_CHECK_ARGVALUES(at::BFloat16);
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
float falpha = alpha;
float fbeta = beta;
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(prop->major >= 8, "BFloat16 bgemm in CUDA requires Ampere or later GPU");
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
b, CUDA_R_16BF, (int)ldb, strideb,
(void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
(int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
#elif defined(__HIP_PLATFORM_HCC__)
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, rocblas_datatype_bf16_r, (int)lda, stridea,
b, rocblas_datatype_bf16_r, (int)ldb, strideb,
(void*)&fbeta, c, rocblas_datatype_bf16_r, (int)ldc, stridec,
c, rocblas_datatype_bf16_r, (int)ldc, stridec,
(int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
0, 0, NULL, NULL));
#else
TORCH_CHECK(false, "BFloat16 bgemm in CUDA requires Ampere or later GPU");
#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000
}
#endif // __HIP_PLATFORM_HCC__

template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
37 changes: 25 additions & 12 deletions aten/src/ATen/cuda/CUDABlas.h
@@ -69,6 +69,31 @@ template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
#endif

#define CUDABLAS_BGEMM_ARGTYPES(Dtype) \
char transa, char transb, int64_t m, int64_t n, int64_t k, Dtype alpha, \
const Dtype *a, int64_t lda, int64_t stridea, \
const Dtype *b, int64_t ldb, int64_t strideb, \
Dtype beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches

template <typename Dtype>
inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
AT_ERROR("at::cuda::blas::bgemm: not implemented for ", typeid(Dtype).name());
}

template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double));
template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float));
template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
#endif
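A hedged usage sketch (not part of this commit) of the batched interface declared above, assuming the at::cuda::blas namespace and cuBLAS's column-major per-matrix layout; the shapes, names, and packed buffer layout are illustrative:

#include <ATen/ATen.h>
#include <ATen/cuda/CUDABlas.h>

// Multiply `batches` column-major m x k matrices A by k x n matrices B,
// writing m x n results into C: C[i] = 1.0f * A[i] * B[i] + 0.0f * C[i].
void bgemm_sketch() {
  int64_t batches = 4, m = 8, k = 16, n = 32;
  auto opts = at::dtype(at::kFloat).device(at::kCUDA);
  auto A = at::randn({batches * m * k}, opts);  // batches packed back to back
  auto B = at::randn({batches * k * n}, opts);
  auto C = at::empty({batches * m * n}, opts);
  at::cuda::blas::bgemm<float>(
      'n', 'n', m, n, k,
      /*alpha=*/1.0f, A.data_ptr<float>(), /*lda=*/m, /*stridea=*/m * k,
      B.data_ptr<float>(), /*ldb=*/k, /*strideb=*/k * n,
      /*beta=*/0.0f, C.data_ptr<float>(), /*ldc=*/m, /*stridec=*/m * n,
      batches);
}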
/* LEVEL 2 BLAS FUNCTIONS */

#define CUDABLAS_GEMV_ARGTYPES(Dtype) \
@@ -97,18 +122,6 @@ template <>
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16));
#endif

template <typename Dtype>
void ger(
int64_t m,
int64_t n,
Dtype alpha,
Dtype* x,
int64_t incx,
Dtype* y,
int64_t incy,
Dtype* a,
int64_t lda);

/* LEVEL 1 BLAS FUNCTIONS */

#define CUDABLAS_DOT_ARGTYPES(Dtype) \
