Add support for CUBLAS_COMPUTE_16F for GEMM operations in cuBLAS #123157
Labels
enhancement
Not as big of a feature, but technically not a bug. Should be easy to fix
matrix multiplication
module: cublas
Problem related to cublas support
module: cuda
Related to torch.cuda, and CUDA support in general
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🚀 The feature, motivation and pitch
For NVIDIA Ada architecture GPUs such as the RTX 4090, the performance of the GEMM kernel for FP16 differs between FP16 accumulation (330.3 TFLOPS) and FP32 accumulation (165.2 TFLOPS). This can be checked in the NVIDIA documentation. The accumulation mode can be controlled via the `cublasComputeType` argument when calling cuBLAS functions.
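The precision tradeoff behind this option can be illustrated with a minimal, stdlib-only sketch (this simulates the rounding behavior of a long reduction; it is not the cuBLAS kernel itself): an accumulator kept in FP16 stalls once the increments fall below half an ulp, while an FP32 accumulator of the same FP16 terms keeps making progress.

```python
import struct

def to_half(x):
    # Round a Python float to IEEE-754 binary16 and back (struct 'e' format).
    return struct.unpack('e', struct.pack('e', x))[0]

def to_single(x):
    # Round to IEEE-754 binary32, mimicking an FP32 accumulator.
    return struct.unpack('f', struct.pack('f', x))[0]

n, term = 10000, 0.01
acc16 = acc32 = 0.0
for _ in range(n):
    acc16 = to_half(acc16 + to_half(term))    # FP16 accumulation: rounds after every add
    acc32 = to_single(acc32 + to_half(term))  # FP32 accumulation of the same FP16 terms

# The FP16 accumulator stalls far below the true sum (~100), while the
# FP32 accumulator stays close to it.
print(acc16, acc32)
```

This is exactly why the FP16-accumulation fast path is opt-in rather than the default.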
The current implementation in PyTorch has an option named `torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction` (`allowFP16ReductionCuBLAS` in the ATen context). As the documentation suggests, when it is set to true, FP16-mode accumulation is allowed for FP16 GEMM. The check of this option is in:
pytorch/aten/src/ATen/cuda/CUDABlas.cpp
Lines 454 to 456 in 39901f2
Even if the user has enabled `allowFP16ReductionCuBLAS`, the following call to `cublasGemmEx` will still use FP32 as the accumulation mode (the second-to-last argument is always `CUDA_R_32F`):
pytorch/aten/src/ATen/cuda/CUDABlas.cpp
Lines 460 to 479 in 39901f2
Also, for batched GEMM operations, the cuBLAS compute mode is set to `CUBLAS_COMPUTE_32F` unconditionally, without checking the `allowFP16ReductionCuBLAS` option:
pytorch/aten/src/ATen/cuda/CUDABlas.cpp
Line 645 in 39901f2
This feature can DOUBLE the GEMM performance on commodity GPUs like the 4090/4080. For users who explicitly enable `torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction`, we should use `CUBLAS_COMPUTE_16F` as the `cublasComputeType` in cuBLAS function calls for higher performance. If this is the desired behavior, I'm very willing to write a patch to support this feature :)
Alternatives
No response
Additional context
No response
cc @ptrblck @csarofeen @xwang233