Port bmm and baddbmm from TH to ATen #42553

Closed — wants to merge 30 commits (the diff below shows changes from 2 of them)

Commits (30)
1cd351c
Port bmm and baddbmm from TH to ATen
anjali411 Aug 4, 2020
2e66241
Update on "[WIP] Port bmm and baddbmm from TH to ATen"
anjali411 Aug 5, 2020
8ae2836
Update on "[WIP] Port bmm and baddbmm from TH to ATen"
anjali411 Aug 5, 2020
01356ee
Update on "[WIP] Port bmm and baddbmm from TH to ATen"
anjali411 Aug 5, 2020
3edd722
Update on "[WIP] Port bmm and baddbmm from TH to ATen"
anjali411 Aug 5, 2020
7c8a985
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Aug 7, 2020
59da4e1
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Aug 12, 2020
248f5b6
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Aug 13, 2020
495c0e1
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Aug 26, 2020
188e763
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Oct 27, 2020
0edb947
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Oct 28, 2020
1363488
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Oct 28, 2020
ad966a5
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Oct 28, 2020
b3cee89
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Oct 28, 2020
70840de
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Oct 28, 2020
9518f8a
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 2, 2020
2d9388e
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 2, 2020
7363e7d
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 6, 2020
ee97b87
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 7, 2020
38c5057
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 9, 2020
7e578b8
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 9, 2020
a6f6d8e
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 10, 2020
815ac05
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 10, 2020
d0ec8d8
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 11, 2020
006a3bd
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 11, 2020
4c4af10
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 11, 2020
a5e2c96
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 11, 2020
7813ce7
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 11, 2020
af8c125
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 12, 2020
909b366
Update on "Port bmm and baddbmm from TH to ATen"
anjali411 Nov 12, 2020
Files changed
1 change: 0 additions & 1 deletion BUILD.bazel
@@ -380,7 +380,6 @@ filegroup(
"aten/src/THC/THCTensorCopy.cu.cc",
"aten/src/THC/THCTensorIndex.cu.cc",
"aten/src/THC/THCTensorMath.cu.cc",
"aten/src/THC/THCTensorMathBlas.cu.cc",
"aten/src/THC/THCTensorMathMagma.cu.cc",
"aten/src/THC/THCTensorMathPairwise.cu.cc",
"aten/src/THC/THCTensorMathReduce.cu.cc",
36 changes: 0 additions & 36 deletions aten/src/ATen/Declarations.cwrap
@@ -393,42 +393,6 @@
- real alpha
]]
[[
[[
name: _th_bmm
cuda_bfloat16: True
cname: baddbmm
variants:
- function
backends:
- CUDA
return: argument 0
arguments:
- arg: THTensor* result
output: True
- argument 0
- THTensor* self
- THTensor* mat2
- CONSTANT AS_REAL(0)
- CONSTANT AS_REAL(1)
]]
[[
name: _th_baddbmm
cuda_bfloat16: True
cname: baddbmm
variants:
- function
backends:
- CUDA
return: argument 0
arguments:
- arg: THTensor* result
output: True
- arg: THTensor* self
- THTensor* batch1
- THTensor* batch2
- real beta
- real alpha
]]
[[
name: _th_gels
cname: gels
17 changes: 8 additions & 9 deletions aten/src/ATen/NamedTensorUtils.cpp
@@ -517,17 +517,16 @@ std::vector<Dimname> compute_bmm_outnames(
 }
 
 std::vector<Dimname> compute_baddbmm_outnames(
-    TensorImpl* result,
-    TensorImpl* batch1,
-    TensorImpl* batch2,
-    TensorImpl* bias) {
-  if (!impl::has_names(result) && !impl::has_names(batch1) &&
-      !impl::has_names(batch2) && !impl::has_names(bias)) {
+    Tensor& result,
+    const Tensor& self,
+    const Tensor& other,
+    const Tensor& bias) {
+  if (!result.has_names() && !self.has_names()
+      && !other.has_names() && !bias.has_names()) {
     return {};
   }
-  auto bmm_names = compute_matmul_outnames(
-      impl::get_names(batch1), impl::get_names(batch2));
-  auto baddbmm_names = unify_from_right(impl::get_names(bias), bmm_names);
+  auto bmm_names = compute_matmul_outnames(self.names(), other.names());
+  auto baddbmm_names = unify_from_right(bias.names(), bmm_names);
   return baddbmm_names;
 }
 
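Note: the rewritten `compute_baddbmm_outnames` propagates names by taking the batched-matmul output names of `self` and `other` and then unifying the `bias` names from the right. Below is a minimal stand-alone sketch of that rule, using plain strings and hypothetical helpers in place of the real `Dimname` / `compute_matmul_outnames` / `unify_from_right` machinery (illustrative only, not code from this PR):

```cpp
// Illustration of the baddbmm name-propagation rule, with plain strings as a
// hypothetical stand-in for Dimname:
//   out = unify_from_right(bias.names(), matmul_outnames(self.names(), other.names()))
#include <cassert>
#include <string>
#include <vector>

using Names = std::vector<std::string>;

// Batched matmul over 3-d inputs: keep the batch dim, take rows from self,
// columns from other.
Names matmul_outnames(const Names& self, const Names& other) {
  return {self[0], self[1], other[2]};
}

// Right-aligned unification: a concrete name wins over the wildcard "*";
// concrete names in the same position must agree.
Names unify_from_right(const Names& bias, const Names& bmm) {
  Names out = bmm;
  for (size_t i = 0; i < bias.size(); ++i) {
    auto& o = out[out.size() - 1 - i];
    const auto& b = bias[bias.size() - 1 - i];
    if (o == "*") o = b;
    else assert(b == "*" || b == o);  // conflicting names would be an error
  }
  return out;
}

int main() {
  Names self  = {"B", "N", "M"};   // batch1: (B, N, M)
  Names other = {"B", "M", "P"};   // batch2: (B, M, P)
  Names bias  = {"*", "N", "P"};   // bias broadcast over the batch dim
  Names out = unify_from_right(bias, matmul_outnames(self, other));
  assert((out == Names{"B", "N", "P"}));
}
```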
8 changes: 4 additions & 4 deletions aten/src/ATen/NamedTensorUtils.h
@@ -155,10 +155,10 @@ CAFFE2_API void propagate_names_for_addmv(
 CAFFE2_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2);
 
 CAFFE2_API std::vector<Dimname> compute_baddbmm_outnames(
-    TensorImpl* result,
-    TensorImpl* self,
-    TensorImpl* other,
-    TensorImpl* bias);
+    Tensor& result,
+    const Tensor& self,
+    const Tensor& other,
+    const Tensor& bias);
 
 CAFFE2_API bool are_names_equal(TensorImpl* self, TensorImpl* other);
 
3 changes: 0 additions & 3 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -141,8 +141,6 @@ _(aten, _tan) \
_(aten, _tanh) \
_(aten, _tanh_backward) \
_(aten, _tanh_forward) \
_(aten, _th_baddbmm) \
_(aten, _th_bmm) \
_(aten, _th_get_device) \
_(aten, _th_kthvalue) \
_(aten, _th_median) \
@@ -672,7 +670,6 @@ _(aten, tan) \
_(aten, tanh) \
_(aten, tensor) \
_(aten, tensordot) \
_(aten, th_addmm) \
_(aten, th_clone) \
_(aten, th_norm) \
_(aten, th_pow) \
117 changes: 117 additions & 0 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -143,6 +143,123 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) {
CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldc); \
} while (0)

template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(double);
const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n;
const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k;
const int64_t stridec = ldc*n;
TORCH_CUDABLAS_CHECK(cublasDgemmStridedBatched(
handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}

template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(float);
const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n;
const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k;
const int64_t stridec = ldc*n;
TORCH_CUDABLAS_CHECK(cublasSgemmStridedBatched(
handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}

#ifndef __HIP_PLATFORM_HCC__
template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(c10::complex<double>);
const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n;
const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k;
const int64_t stridec = ldc*n;
TORCH_CUDABLAS_CHECK(cublasZgemmStridedBatched(
handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
lda, stridea, reinterpret_cast<const cuDoubleComplex*>(b), ldb, strideb, reinterpret_cast<const cuDoubleComplex*>(&beta),
reinterpret_cast<cuDoubleComplex*>(c), ldc, stridec, num_batches));
}
#endif

#ifndef __HIP_PLATFORM_HCC__
template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(c10::complex<float>);
const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n;
const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k;
const int64_t stridec = ldc*n;
TORCH_CUDABLAS_CHECK(cublasCgemm3mStridedBatched(
handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
lda, stridea, reinterpret_cast<const cuComplex*>(b), ldb, strideb, reinterpret_cast<const cuComplex*>(&beta),
reinterpret_cast<cuComplex*>(c), ldc, stridec, num_batches));
}
#endif

// #ifdef __HIP_PLATFORM_HCC__
// template <>
// void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
// cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
// cublasOperation_t opa = _cublasOpFromChar(transa);
// cublasOperation_t opb = _cublasOpFromChar(transb);
// float falpha = alpha;
// float fbeta = beta;
// _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
// GEMM_CHECK_ARGVALUES(at::BFloat16);
// TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(
// handle,
// opa,
// opb,
// m,
// n,
// k,
// &falpha,
// a,
// rocblas_datatype_bf16_r,
// lda,
// b,
// rocblas_datatype_bf16_r,
// ldb,
// &fbeta,
// c,
// rocblas_datatype_bf16_r,
// ldc,
// c,
// rocblas_datatype_bf16_r,
// ldc,
// rocblas_datatype_f32_r,
// rocblas_gemm_algo_standard,
// 0,
// 0));
// }
// #endif

template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasOperation_t opa = _cublasOpFromChar(transa);
cublasOperation_t opb = _cublasOpFromChar(transb);
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(at::Half);
const int64_t stridea = (transa == 'N' || transa == 'n') ? lda*k : lda*n;
const int64_t strideb = (transb == 'N' || transb == 'n') ? ldb*n : ldb*k;
const int64_t stridec = ldc*n;
TORCH_CUDABLAS_CHECK(cublasHgemmStridedBatched(
handle, opa, opb, m, n, k, reinterpret_cast<const __half *>(&alpha), reinterpret_cast<const __half *>(a), lda, stridea,
reinterpret_cast<const __half *>(b), ldb, strideb, reinterpret_cast<const __half *>(&beta), reinterpret_cast<__half *>(c), ldc, stridec, num_batches));
}

template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
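Note: each `bgemm` specialization above forwards to a `cublas*GemmStridedBatched` routine, which computes, for every batch index `p`, `C_p = alpha * op(A_p) * op(B_p) + beta * C_p`, where `A_p = A + p * stridea`, `B_p = B + p * strideb`, and `C_p = C + p * stridec`. A minimal CPU reference of the column-major, non-transposed case (illustrative only; not code from this PR):

```cpp
// Illustrative CPU reference for the semantics of a strided-batched GEMM
// (assumption: column-major layout, transa == transb == 'N'):
//   C_p = alpha * A_p * B_p + beta * C_p, with X_p = X + p * strideX.
#include <cstdint>

void bgemm_nn_reference(int64_t m, int64_t n, int64_t k, float alpha,
                        const float* a, int64_t lda, int64_t stridea,
                        const float* b, int64_t ldb, int64_t strideb,
                        float beta, float* c, int64_t ldc, int64_t stridec,
                        int64_t num_batches) {
  for (int64_t p = 0; p < num_batches; ++p) {
    const float* A = a + p * stridea;  // m x k, column-major, leading dim lda
    const float* B = b + p * strideb;  // k x n, column-major, leading dim ldb
    float* C = c + p * stridec;        // m x n, column-major, leading dim ldc
    for (int64_t j = 0; j < n; ++j) {
      for (int64_t i = 0; i < m; ++i) {
        float acc = 0.f;
        for (int64_t l = 0; l < k; ++l) {
          acc += A[i + l * lda] * B[l + j * ldb];
        }
        C[i + j * ldc] = alpha * acc + beta * C[i + j * ldc];
      }
    }
  }
}
```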
40 changes: 40 additions & 0 deletions aten/src/ATen/cuda/CUDABlas.h
@@ -50,6 +50,46 @@ template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
#endif

// cublasStatus_t cublasHgemmBatched(cublasHandle_t handle,
// cublasOperation_t transa,
// cublasOperation_t transb,
// int m, int n, int k,
// const __half *alpha,
// const __half *Aarray[], int lda,
// const __half *Barray[], int ldb,
// const __half *beta,
// __half *Carray[], int ldc,
// int batchCount)

#define CUDABLAS_BGEMM_ARGTYPES(Dtype) \
char transa, char transb, int64_t m, int64_t n, int64_t k, Dtype alpha, \
const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, Dtype beta, \
Dtype *c, int64_t ldc, int64_t num_batches

template <typename Dtype>
inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
AT_ERROR("at::cuda::blas::bgemm: not implemented for ", typeid(Dtype).name());
}

template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double));
template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float));
#ifndef __HIP_PLATFORM_HCC__
template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
#endif
#ifndef __HIP_PLATFORM_HCC__
template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
#endif
template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
// #ifdef __HIP_PLATFORM_HCC__
// template <>
// void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
// #endif

/* LEVEL 2 BLAS FUNCTIONS */

#define CUDABLAS_GEMV_ARGTYPES(Dtype) \
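Note: at the ATen level, the kernels being ported back the `bmm` and `baddbmm` operators, where `baddbmm(input, batch1, batch2, beta, alpha) = beta * input + alpha * bmm(batch1, batch2)`. A small libtorch sketch exercising those ops (illustrative only, not part of this PR; assumes a CUDA-enabled build and falls back to CPU otherwise):

```cpp
// Illustrative libtorch usage of the ops whose CUDA kernels this PR ports.
// baddbmm computes: out = beta * input + alpha * bmm(batch1, batch2).
#include <torch/torch.h>

int main() {
  torch::Device device = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;

  auto batch1 = torch::randn({8, 3, 4}, device);  // (b, n, m)
  auto batch2 = torch::randn({8, 4, 5}, device);  // (b, m, p)
  auto input  = torch::randn({8, 3, 5}, device);  // (b, n, p)

  auto out_bmm = torch::bmm(batch1, batch2);
  auto out = torch::baddbmm(input, batch1, batch2, /*beta=*/0.5, /*alpha=*/2.0);

  // Check against the explicit formula.
  auto ref = input.mul(0.5) + out_bmm.mul(2.0);
  TORCH_CHECK(torch::allclose(out, ref, /*rtol=*/1e-4, /*atol=*/1e-5));
}
```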