Port bmm and baddbmm from TH to ATen #42553
Closed
Commits: changes shown from 25 of 30 commits
1cd351c  anjali411  Port bmm and baddbmm from TH to ATen
2e66241  anjali411  Update on "[WIP] Port bmm and baddbmm from TH to ATen"
8ae2836  anjali411  Update on "[WIP] Port bmm and baddbmm from TH to ATen"
01356ee  anjali411  Update on "[WIP] Port bmm and baddbmm from TH to ATen"
3edd722  anjali411  Update on "[WIP] Port bmm and baddbmm from TH to ATen"
7c8a985  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
59da4e1  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
248f5b6  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
495c0e1  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
188e763  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
0edb947  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
1363488  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
ad966a5  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
b3cee89  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
70840de  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
9518f8a  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
2d9388e  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
7363e7d  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
ee97b87  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
38c5057  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
7e578b8  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
a6f6d8e  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
815ac05  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
d0ec8d8  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
006a3bd  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
4c4af10  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
a5e2c96  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
7813ce7  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
af8c125  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
909b366  anjali411  Update on "Port bmm and baddbmm from TH to ATen"
@@ -133,6 +133,56 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) {

/* LEVEL 3 BLAS FUNCTIONS */

#ifndef __HIP_PLATFORM_HCC__
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200
#define cublasGemmStridedBatchedExFix cublasGemmStridedBatchedEx
#else
// Workaround for https://github.com/pytorch/pytorch/issues/45724
cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle,
  cublasOperation_t transa,
  cublasOperation_t transb,
  int m,
  int n,
  int k,
  const void *alpha,
  const void *A,
  cudaDataType Atype,
  int lda,
  long long int strideA,
  const void *B,
  cudaDataType Btype,
  int ldb,
  long long int strideB,
  const void *beta,
  void *C,
  cudaDataType Ctype,
  int ldc,
  long long int strideC,
  int64_t batchCount,
  cudaDataType computeType,
  cublasGemmAlgo_t algo)
{
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major != 7) {
    return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, batchCount, computeType, algo);
  }
  cublasStatus_t result;
  constexpr int64_t split = 63 * 1024;
  for (int64_t i = 0; i < batchCount; i += split) {
    int64_t count = std::min<int64_t>(split, batchCount - i);
    result = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha,
      (char *)A + i * strideA * 2, Atype, lda, strideA,
      (char *)B + i * strideB * 2, Btype, ldb, strideB,
      beta,
      (char *)C + i * strideC * 2, Ctype, ldc, strideC,
      (int)count, computeType, algo);
    TORCH_CUDABLAS_CHECK(result);
  }
  return result;
}
#endif
#endif

#define GEMM_CHECK_ARGVALUES(Dtype)           \
  do {                                        \
    CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, m); \
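An aside, not part of the diff: on compute capability 7.x devices the fallback above issues the strided-batched GEMM in chunks of 63 * 1024 sub-batches, and the hard-coded * 2 in the pointer offsets presumably reflects the 2-byte element types (FP16/BF16), which are the only dtypes routed through this wrapper in this file. A standalone sketch of the chunk sizes it produces, with a placeholder batch count:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int64_t split = 63 * 1024;   // 64512, as in the workaround above
  int64_t batchCount = 200000;           // hypothetical batch count, for illustration
  for (int64_t i = 0; i < batchCount; i += split) {
    int64_t count = std::min<int64_t>(split, batchCount - i);
    // Each chunk corresponds to one cublasGemmStridedBatchedEx call whose
    // A/B/C pointers are advanced by i * stride * sizeof(element) bytes.
    std::printf("chunk at %lld: %lld batches\n", (long long)i, (long long)count);
  }
  return 0;
}

For batchCount = 200000 this prints chunks of 64512, 64512, 64512 and 6464 sub-batches.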
@@ -143,6 +193,161 @@ const char* _cublasGetErrorEnum(cublasStatus_t error) {
    CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldc);  \
  } while (0)

#define BGEMM_CHECK_ARGVALUES(Dtype)                   \
  do {                                                 \
    CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, m);         \
    CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, n);         \
    CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, k);         \
    CUDABLAS_POSINT_CHECK(bgemm<Dtype>, lda);          \
    CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldb);          \
    CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldc);          \
    CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, num_batches); \
  } while (0)

template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(double);
  TORCH_CUDABLAS_CHECK(cublasDgemmStridedBatched(
      handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}

template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(float);
  TORCH_CUDABLAS_CHECK(cublasSgemmStridedBatched(
      handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}

template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(c10::complex<double>);
  TORCH_CUDABLAS_CHECK(cublasZgemmStridedBatched(
      handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
      lda, stridea, reinterpret_cast<const cuDoubleComplex*>(b), ldb, strideb, reinterpret_cast<const cuDoubleComplex*>(&beta),
      reinterpret_cast<cuDoubleComplex*>(c), ldc, stridec, num_batches));
}

template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(c10::complex<float>);
  TORCH_CUDABLAS_CHECK(cublasCgemmStridedBatched(
      handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
      lda, stridea, reinterpret_cast<const cuComplex*>(b), ldb, strideb, reinterpret_cast<const cuComplex*>(&beta),
      reinterpret_cast<cuComplex*>(c), ldc, stridec, num_batches));
}
template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(at::Half);
  float falpha = alpha;
  float fbeta = beta;
#ifdef __HIP_PLATFORM_HCC__
  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
      (void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
      b, rocblas_datatype_f16_r, (int)ldb, strideb,
      (void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec,
      c, rocblas_datatype_f16_r, (int)ldc, stridec,
      (int)num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
      0, 0));
#else
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
  // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
  // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#endif  // CUDA_VERSION < 11000

  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major >= 5) {
    TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(
      handle, opa, opb, m, n, k,
      (void*)(&falpha), a, CUDA_R_16F, lda, stridea,
      b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta),
      c, CUDA_R_16F, ldc, stridec,
      num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  } else {
    for (int64_t i = 0; i < num_batches; ++i) {
      at::cuda::blas::gemm<at::Half>(
        transa, transb,
        m, n, k,
        alpha, (a + i * stridea), lda,
        (b + i * strideb), ldb, beta,
        (c + i * stridec), ldc);
    }
  }
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
  // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
  // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif  // CUDA_VERSION < 11000
#endif  // __HIP_PLATFORM_HCC__
}
#if defined(__HIP_PLATFORM_HCC__) || defined(CUDA_VERSION) && CUDA_VERSION >= 11000
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  BGEMM_CHECK_ARGVALUES(at::BFloat16);
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  float falpha = alpha;
  float fbeta = beta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  TORCH_CHECK(prop->major >= 8, "BFloat16 bgemm in CUDA requires Ampere or later GPU");
  TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(handle,
      opa, opb, (int)m, (int)n, (int)k,
      (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
      b, CUDA_R_16BF, (int)ldb, strideb,
      (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
      (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
#elif defined(__HIP_PLATFORM_HCC__)
  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
      (void*)&falpha, a, rocblas_datatype_bf16_r, (int)lda, stridea,
      b, rocblas_datatype_bf16_r, (int)ldb, strideb,
      (void*)&fbeta, c, rocblas_datatype_bf16_r, (int)ldc, stridec,
      c, rocblas_datatype_bf16_r, (int)ldc, stridec,
      (int)num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
      0, 0, NULL, NULL));
#else
  TORCH_CHECK(false, "BFloat16 bgemm in CUDA requires Ampere or later GPU");
#endif  // defined(CUDA_VERSION) && CUDA_VERSION >= 11000
}
#endif  // __HIP_PLATFORM_HCC__

template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
  // See Note [Writing Nondeterministic Operations]
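As a usage note, not part of the diff: once bmm and baddbmm dispatch through ATen, half-precision batched matmuls on CUDA are expected to reach the bgemm<at::Half> specialization above. A minimal, hypothetical smoke test against the public ATen API; the shapes, dtype, and device are placeholders:

#include <ATen/ATen.h>

int main() {
  // Hypothetical example; assumes ATen was built with CUDA support.
  // bmm: (b, n, m) x (b, m, p) -> (b, n, p); baddbmm additionally scales and adds `self`.
  auto opts = at::device(at::kCUDA).dtype(at::kHalf);
  auto a = at::randn({8, 4, 5}, opts);
  auto b = at::randn({8, 5, 3}, opts);
  auto c = at::randn({8, 4, 3}, opts);
  auto out_bmm = at::bmm(a, b);                                      // expected to hit the bgemm<at::Half> path
  auto out_baddbmm = at::baddbmm(c, a, b, /*beta=*/1, /*alpha=*/1);  // same batched-GEMM path with accumulation
  return 0;
}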
Is this macro CUDA_VERSION >= 11200 intended? If you mean CUDA 11.2, it should be 11020. I'm not sure CUDA 11.2 was a thing back in November 2020. 😅

No harm done, the workaround is good.

@xwang233 No, my bad! We should fix that to avoid confusion in the future.
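For reference: cuda.h encodes CUDA_VERSION as major * 1000 + minor * 10, so CUDA 11.2 corresponds to 11020, while 11200 would only match a hypothetical CUDA 11.20. The corrected guard would presumably look like this sketch:

#include <cuda.h>

// CUDA_VERSION = major * 1000 + minor * 10 (e.g. CUDA 11.2 -> 11020, CUDA 10.1 -> 10010).
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11020
// From CUDA 11.2 on, the batch-splitting workaround is meant to be bypassed.
#define cublasGemmStridedBatchedExFix cublasGemmStridedBatchedEx
#endif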