Port bmm and baddbmm from TH to ATen #42553

Closed · wants to merge 30 commits

Commits (30)
1cd351c
Port bmm and baddbmm from TH to ATen
anjali411 Aug 4, 2020
2e66241
Update on "[WIP] Port bmm and baddbmm from TH to ATen"
anjali411 Aug 5, 2020
8ae2836
Update on "[WIP] Port bmm and baddbmm from TH to ATen"
anjali411 Aug 5, 2020
01356ee
Update on "[WIP] Port bmm and baddbmm from TH to ATen"
anjali411 Aug 5, 2020
3edd722
Update on "[WIP] Port bmm and baddbmm from TH to ATen"
anjali411 Aug 5, 2020
7c8a985
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Aug 7, 2020
59da4e1
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Aug 12, 2020
248f5b6
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Aug 13, 2020
495c0e1
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Aug 26, 2020
188e763
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Oct 27, 2020
0edb947
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Oct 28, 2020
1363488
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Oct 28, 2020
ad966a5
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Oct 28, 2020
b3cee89
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Oct 28, 2020
70840de
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Oct 28, 2020
9518f8a
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 2, 2020
2d9388e
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 2, 2020
7363e7d
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 6, 2020
ee97b87
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 7, 2020
38c5057
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 9, 2020
7e578b8
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 9, 2020
a6f6d8e
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 10, 2020
815ac05
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 10, 2020
d0ec8d8
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 11, 2020
006a3bd
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 11, 2020
4c4af10
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 11, 2020
a5e2c96
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 11, 2020
7813ce7
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 11, 2020
af8c125
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 12, 2020
909b366
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 12, 2020
1 change: 0 additions & 1 deletion BUILD.bazel
@@ -378,7 +378,6 @@ filegroup(
"aten/src/THC/THCTensorCopy.cu.cc",
"aten/src/THC/THCTensorIndex.cu.cc",
"aten/src/THC/THCTensorMath.cu.cc",
"aten/src/THC/THCTensorMathBlas.cu.cc",
"aten/src/THC/THCTensorMathMagma.cu.cc",
"aten/src/THC/THCTensorMathPairwise.cu.cc",
"aten/src/THC/THCTensorMathReduce.cu.cc",
4 changes: 0 additions & 4 deletions aten/src/ATen/LegacyTHFunctionsCUDA.h
@@ -44,10 +44,6 @@ Tensor & _th_fmod_(Tensor & self, Scalar other);
Tensor & _th_fmod_(Tensor & self, const Tensor & other);
Tensor & _th_cross_kernel_out(Tensor & result, const Tensor & self, const Tensor & other, int64_t dim);
Tensor _th_cross_kernel(const Tensor & self, const Tensor & other, int64_t dim);
Tensor & _th_bmm_out(Tensor & result, const Tensor & self, const Tensor & mat2);
Tensor _th_bmm(const Tensor & self, const Tensor & mat2);
Tensor & _th_baddbmm_out(Tensor & result, const Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha);
Tensor _th_baddbmm(const Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha);
std::tuple<Tensor &,Tensor &> _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A);
std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A);
std::tuple<Tensor &,Tensor &> _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors);
17 changes: 8 additions & 9 deletions aten/src/ATen/NamedTensorUtils.cpp
@@ -517,17 +517,16 @@ std::vector<Dimname> compute_bmm_outnames(
 }
 
 std::vector<Dimname> compute_baddbmm_outnames(
-    TensorImpl* result,
-    TensorImpl* batch1,
-    TensorImpl* batch2,
-    TensorImpl* bias) {
-  if (!impl::has_names(result) && !impl::has_names(batch1) &&
-      !impl::has_names(batch2) && !impl::has_names(bias)) {
+    Tensor& result,
+    const Tensor& self,
+    const Tensor& other,
+    const Tensor& bias) {
+  if (!result.has_names() && !self.has_names()
+      && !other.has_names() && !bias.has_names()) {
     return {};
   }
-  auto bmm_names = compute_matmul_outnames(
-      impl::get_names(batch1), impl::get_names(batch2));
-  auto baddbmm_names = unify_from_right(impl::get_names(bias), bmm_names);
+  auto bmm_names = compute_matmul_outnames(self.names(), other.names());
+  auto baddbmm_names = unify_from_right(bias.names(), bmm_names);
   return baddbmm_names;
 }

8 changes: 4 additions & 4 deletions aten/src/ATen/NamedTensorUtils.h
@@ -155,10 +155,10 @@ CAFFE2_API void propagate_names_for_addmv(
 CAFFE2_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2);
 
 CAFFE2_API std::vector<Dimname> compute_baddbmm_outnames(
-    TensorImpl* result,
-    TensorImpl* self,
-    TensorImpl* other,
-    TensorImpl* bias);
+    Tensor& result,
+    const Tensor& self,
+    const Tensor& other,
+    const Tensor& bias);
 
 CAFFE2_API bool are_names_equal(TensorImpl* self, TensorImpl* other);

3 changes: 0 additions & 3 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -133,8 +133,6 @@ _(aten, _sum_cuda) \
_(aten, _tan) \
_(aten, _tanh) \
_(aten, _tanh_forward) \
_(aten, _th_baddbmm) \
_(aten, _th_bmm) \
_(aten, _th_get_device) \
_(aten, _th_kthvalue) \
_(aten, _th_mode) \
@@ -669,7 +667,6 @@ _(aten, tanh) \
_(aten, tensor) \
_(aten, tensordot) \
_(aten, tensor_split) \
_(aten, th_addmm) \
_(aten, th_clone) \
_(aten, th_norm) \
_(aten, th_pow) \
205 changes: 205 additions & 0 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -133,6 +133,56 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) {

/* LEVEL 3 BLAS FUNCTIONS */

#ifndef __HIP_PLATFORM_HCC__
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200
Collaborator: Is this macro CUDA_VERSION >= 11200 intended? If you mean CUDA 11.2, it should be 11020. I'm not sure CUDA 11.2 was a thing back in November 2020. 😅

Collaborator: No harm done, the workaround is good.

Contributor (author): @xwang233 No, my bad! We should fix that to avoid confusion in the future.
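
As an aside on the version check discussed above (an illustration only, not part of the diff): the CUDA_VERSION macro encodes the toolkit version as major * 1000 + minor * 10, so a guard for CUDA 11.2 would read 11020:

```cpp
// Illustration only: CUDA_VERSION = major * 1000 + minor * 10,
// e.g. CUDA 11.2 -> 11020, CUDA 10.1 -> 10010.
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11020  // CUDA 11.2 or newer
  // the strided-batched GEMM can be aliased directly
#else
  // older toolkits fall through to the batch-splitting workaround defined below
#endif
```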

#define cublasGemmStridedBatchedExFix cublasGemmStridedBatchedEx
#else
// Workaround for https://github.com/pytorch/pytorch/issues/45724
cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const void *alpha,
const void *A,
cudaDataType Atype,
int lda,
long long int strideA,
const void *B,
cudaDataType Btype,
int ldb,
long long int strideB,
const void *beta,
void *C,
cudaDataType Ctype,
int ldc,
long long int strideC,
int64_t batchCount,
cudaDataType computeType,
cublasGemmAlgo_t algo)
{
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major != 7) {
return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, batchCount, computeType, algo);
}
cublasStatus_t result;
constexpr int64_t split = 63 * 1024;
for(int64_t i = 0; i < batchCount; i += split) {
int64_t count = std::min<int64_t>(split, batchCount - i);
result = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha,
(char *)A + i * strideA * 2, Atype, lda, strideA,
(char *)B + i * strideB * 2, Btype, ldb, strideB,
beta,
(char *)C + i * strideC * 2, Ctype, ldc, strideC,
(int)count, computeType, algo);
TORCH_CUDABLAS_CHECK(result);
}
return result;
}
#endif
#endif
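
A side note on the pointer arithmetic in the workaround above (my reading of the code, not text from the PR): the byte offsets multiply the element stride by 2 because this wrapper is only reached from the 2-byte at::Half and at::BFloat16 specializations, i.e. offset_bytes = i * strideA * sizeof(fp16). A more general helper, assuming other dtypes were ever routed through this path, might look like:

```cpp
// Hypothetical generalization (not in the PR): byte offset of the i-th sub-batch
// given an element stride and the element size of the dtype being multiplied.
inline const char* batch_offset(const void* base, int64_t i,
                                int64_t stride_elems, size_t elem_size) {
  return static_cast<const char*>(base) + i * stride_elems * elem_size;
}
```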

#define GEMM_CHECK_ARGVALUES(Dtype) \
do { \
CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, m); \
@@ -143,6 +193,161 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) {
CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldc); \
} while (0)

#define BGEMM_CHECK_ARGVALUES(Dtype) \
do { \
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, m); \
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, n); \
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, k); \
CUDABLAS_POSINT_CHECK(bgemm<Dtype>, lda); \
CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldb); \
CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldc); \
CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, num_batches); \
} while (0)

template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(double);
TORCH_CUDABLAS_CHECK(cublasDgemmStridedBatched(
handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}

template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(float);
TORCH_CUDABLAS_CHECK(cublasSgemmStridedBatched(
handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}

template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(c10::complex<double>);
TORCH_CUDABLAS_CHECK(cublasZgemmStridedBatched(
handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
lda, stridea, reinterpret_cast<const cuDoubleComplex*>(b), ldb, strideb, reinterpret_cast<const cuDoubleComplex*>(&beta),
reinterpret_cast<cuDoubleComplex*>(c), ldc, stridec, num_batches));
}

template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(c10::complex<float>);
TORCH_CUDABLAS_CHECK(cublasCgemmStridedBatched(
Contributor (author): We can consider using cublasCgemm3mStridedBatched in the future, since it would be faster, albeit less precise.

handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
lda, stridea, reinterpret_cast<const cuComplex*>(b), ldb, strideb, reinterpret_cast<const cuComplex*>(&beta),
reinterpret_cast<cuComplex*>(c), ldc, stridec, num_batches));
}
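
Regarding the author's note above: as far as I can tell, cublasCgemm3mStridedBatched takes the same arguments as cublasCgemmStridedBatched, so the change would be a drop-in swap at that call site. A hedged sketch of what it might look like (not part of this PR):

```cpp
// Sketch only: the 3m (Gauss) complex GEMM trades a bit of accuracy for speed.
TORCH_CUDABLAS_CHECK(cublasCgemm3mStridedBatched(
    handle, opa, opb, m, n, k,
    reinterpret_cast<const cuComplex*>(&alpha),
    reinterpret_cast<const cuComplex*>(a), lda, stridea,
    reinterpret_cast<const cuComplex*>(b), ldb, strideb,
    reinterpret_cast<const cuComplex*>(&beta),
    reinterpret_cast<cuComplex*>(c), ldc, stridec, num_batches));
```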

template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
BGEMM_CHECK_ARGVALUES(at::Half);
float falpha = alpha;
float fbeta = beta;
#ifdef __HIP_PLATFORM_HCC__
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
b, rocblas_datatype_f16_r, (int)ldb, strideb,
(void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec,
c, rocblas_datatype_f16_r, (int)ldc, stridec,
(int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
0, 0));
#else
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#endif // CUDA_VERSION < 11000

cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 5){
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(
handle, opa, opb, m, n, k,
(void*)(&falpha), a, CUDA_R_16F, lda, stridea,
b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta),
c, CUDA_R_16F, ldc, stridec,
num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
for (int64_t i = 0; i < num_batches; ++i) {
at::cuda::blas::gemm<at::Half>(
transa, transb,
m, n, k,
alpha, (a + i * stridea), lda,
(b + i * strideb), ldb, beta,
(c + i * stridec), ldc);
}
}
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif // CUDA_VERSION < 11000
#endif // __HIP_PLATFORM_HCC__
}

#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
BGEMM_CHECK_ARGVALUES(at::BFloat16);
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
float falpha = alpha;
float fbeta = beta;
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(prop->major >= 8, "BFloat16 bgemm in CUDA requires Ampere or later GPU");
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
b, CUDA_R_16BF, (int)ldb, strideb,
(void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
(int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
#elif defined(__HIP_PLATFORM_HCC__)
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, rocblas_datatype_bf16_r, (int)lda, stridea,
b, rocblas_datatype_bf16_r, (int)ldb, strideb,
(void*)&fbeta, c, rocblas_datatype_bf16_r, (int)ldc, stridec,
c, rocblas_datatype_bf16_r, (int)ldc, stridec,
(int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
0, 0, NULL, NULL));
#else
TORCH_CHECK(false, "BFloat16 bgemm in CUDA requires Ampere or later GPU");
#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000
}
#endif // __HIP_PLATFORM_HCC__

template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
37 changes: 25 additions & 12 deletions aten/src/ATen/cuda/CUDABlas.h
@@ -69,6 +69,31 @@ template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
#endif

#define CUDABLAS_BGEMM_ARGTYPES(Dtype) \
char transa, char transb, int64_t m, int64_t n, int64_t k, Dtype alpha, \
const Dtype *a, int64_t lda, int64_t stridea, \
const Dtype *b, int64_t ldb, int64_t strideb, \
Dtype beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches

template <typename Dtype>
inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
AT_ERROR("at::cuda::blas::bgemm: not implemented for ", typeid(Dtype).name());
}

template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double));
template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float));
template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
#endif
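
For context, a hypothetical caller of the new at::cuda::blas::bgemm entry point declared above (the ATen-side bmm/baddbmm implementation that ends up calling it is not shown in this excerpt; a_ptr, b_ptr, c_ptr, m, n, k, and num_batches are placeholders):

```cpp
// Sketch only: multiply num_batches pairs of column-major (m x k) and (k x n)
// matrices, writing the (m x n) results into c with alpha = 1, beta = 0.
at::cuda::blas::bgemm<float>(
    /*transa=*/'n', /*transb=*/'n',
    /*m=*/m, /*n=*/n, /*k=*/k,
    /*alpha=*/1.0f,
    a_ptr, /*lda=*/m, /*stridea=*/m * k,
    b_ptr, /*ldb=*/k, /*strideb=*/k * n,
    /*beta=*/0.0f,
    c_ptr, /*ldc=*/m, /*stridec=*/m * n,
    /*num_batches=*/num_batches);
```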
/* LEVEL 2 BLAS FUNCTIONS */

#define CUDABLAS_GEMV_ARGTYPES(Dtype) \
@@ -97,18 +122,6 @@ template <>
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16));
#endif

template <typename Dtype>
void ger(
int64_t m,
int64_t n,
Dtype alpha,
Dtype* x,
int64_t incx,
Dtype* y,
int64_t incy,
Dtype* a,
int64_t lda);

/* LEVEL 1 BLAS FUNCTIONS */

#define CUDABLAS_DOT_ARGTYPES(Dtype) \