enable mkldnn bf32 matmul (#116015)
### Testing
FP32 matmul vs. mkldnn BF32 matmul on SPR (Sapphire Rapids):

single core:

Input | BF32 (ms) | FP32 (ms) | Speedup
-- | -- | -- | --
M: 128, N: 128, K: 128, trans_a: False, trans_b: False | 32.842 | 38.279 | 1.165
M: 128, N: 256, K: 128, trans_a: False, trans_b: False | 38.590 | 73.967 | 1.917
M: 8192, N: 768, K: 768, trans_a: False, trans_b: False | 18456.267 | 74588.002 | 4.041

56 cores:

Input | BF32 (ms) | FP32 (ms) | Speedup
-- | -- | -- | --
M: 8192, N: 768, K: 768, trans_a: False, trans_b: False | 1199.400 | 1715.548 | 1.430
M: 8192, N: 768, K: 768, trans_a: False, trans_b: True | 1129.204 | 1708.912 | 1.513
M: 8192, N: 768, K: 3072, trans_a: False, trans_b: False | 3655.915 | 7992.877 | 2.186
M: 8192, N: 768, K: 3072, trans_a: False, trans_b: True | 3707.993 | 8026.191 | 2.165
Batch: 768, M: 128, N: 64, K: 128 | 1296.419 | 1308.411 | 1.009
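
The BF32 path added here is gated on `at::Float32MatmulPrecision::MEDIUM` (see `use_mkldnn_bf32_matmul()` in the diff below), which users reach through `torch.set_float32_matmul_precision("medium")`. The snippet below is a minimal sketch of how such a single-core FP32 vs. BF32 comparison might be reproduced; it is not the benchmark script used for the tables above, and the warmup count, iteration count, and use of `torch.mm` are assumptions. On CPUs without BF16 support (or with MKLDNN disabled) both runs fall back to the plain FP32 path.

```python
# Minimal sketch (not the script behind the tables above): time an
# M=8192, N=768, K=768 float32 matmul with the default FP32 path and with
# the oneDNN BF32 path gated on float32 matmul precision "medium".
import time
import torch

torch.set_num_threads(1)  # mirrors the single-core rows above

M, N, K = 8192, 768, 768
a = torch.randn(M, K)
b = torch.randn(K, N)

def bench(iters=10, warmup=3):
    # Time torch.mm in milliseconds, averaged over `iters` runs.
    for _ in range(warmup):
        torch.mm(a, b)
    start = time.perf_counter()
    for _ in range(iters):
        torch.mm(a, b)
    return (time.perf_counter() - start) / iters * 1e3

torch.set_float32_matmul_precision("highest")  # strict FP32 gemm
fp32_ms = bench()

torch.set_float32_matmul_precision("medium")   # allow implicit BF16 compute (BF32)
bf32_ms = bench()

print(f"FP32: {fp32_ms:.3f} ms  BF32: {bf32_ms:.3f} ms  speedup: {fp32_ms / bf32_ms:.3f}")
```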

Pull Request resolved: #116015
Approved by: https://github.com/jgong5, https://github.com/ezyang
zhuhaozhe authored and pytorchmergebot committed Jan 20, 2024
1 parent aaae2d8 commit 0ae952d
Showing 9 changed files with 194 additions and 16 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/native/Blas.cpp
@@ -81,7 +81,7 @@ TORCH_IMPL_FUNC(addmv_out_cpu)(const Tensor &self, const Tensor &mat, const Tens
if (result.numel() != 0) {

NoNamesGuard guard;
if (use_mkldnn_lower_precision_matmul(mat, vec, /*result=*/Tensor())){
if (use_mkldnn_matmul(mat, vec, /*result=*/Tensor())){
mkldnn_matmul(mat, vec, result, beta_.to<float>(), alpha_.to<float>());
return;
}
@@ -176,7 +176,7 @@ Tensor dot(const Tensor &self, const Tensor &other){
return at::_efficientzerotensor({}, self.options());
}

if (use_mkldnn_lower_precision_matmul(self, other, /*result=*/Tensor())){
if (use_mkldnn_matmul(self, other, /*result=*/Tensor())){
// mkldnn matmul expect result have sizes info to create ideep tensor
auto r = at::empty({1, 1}, self.options());
mkldnn_matmul(self, other, r, /*beta=*/0);
5 changes: 5 additions & 0 deletions aten/src/ATen/native/CPUBlas.cpp
@@ -164,6 +164,11 @@ void gemm(
const float beta,
float *c, int64_t ldc) {
internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_MKLDNN_ENABLED()
if (mkldnn_bf32_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
return;
}
#endif
#if AT_BUILD_WITH_BLAS()
if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
2 changes: 1 addition & 1 deletion aten/src/ATen/native/LinearAlgebra.cpp
@@ -1757,7 +1757,7 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens
};

bool apply_heur = apply_mkldnn_matmul_heur(batch1.sizes()[1], batch1.sizes()[2], batch2.sizes()[2]);
if (apply_heur && use_mkldnn_lower_precision_matmul(batch1, batch2, self_or_result)) {
if (apply_heur && use_mkldnn_matmul(batch1, batch2, self_or_result)) {
try {
mkldnn_matmul(batch1, batch2, self_or_result, beta.to<float>(), alpha.to<float>());
return;
77 changes: 66 additions & 11 deletions aten/src/ATen/native/mkldnn/Matmul.cpp
@@ -53,8 +53,25 @@ bool mkldnn_fp16_gemm(
c10::Half *c, int64_t ldc) {
return false;
}
bool mkldnn_bf32_gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
float alpha,
const float *a, int64_t lda,
const float *b, int64_t ldb,
float beta,
float *c, int64_t ldc){
return false;
}

bool use_mkldnn_lower_precision_matmul(
bool use_mkldnn_bf32_matmul(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result) {
return false;
}

bool use_mkldnn_matmul(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result) {
@@ -80,22 +97,29 @@ static bool use_mkldnn_fp16_matmul() {
return at::globalContext().userEnabledMkldnn() && mkldnn_fp16_device_check();
}

static bool use_mkldnn_bf32_matmul() {
return use_mkldnn_bf16_matmul() && at::globalContext().float32MatmulPrecision() == at::Float32MatmulPrecision::MEDIUM;
}


template<typename scalar_t>
inline typename std::enable_if_t<
std::is_same_v<scalar_t, float> ||
std::is_same_v<scalar_t, c10::Half> ||
std::is_same_v<scalar_t, c10::BFloat16>,
bool>
mkldnn_lowerp_gemm(
mkldnn_gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
float alpha,
const scalar_t *a_data, int64_t lda,
const scalar_t *b_data, int64_t ldb,
float beta,
scalar_t *c_data, int64_t ldc) {
if (!(std::is_same_v<scalar_t, c10::BFloat16> ? use_mkldnn_bf16_matmul()
: use_mkldnn_fp16_matmul()) ||
bool bf16_usable = std::is_same_v<scalar_t, c10::BFloat16> && use_mkldnn_bf16_matmul();
bool fp16_usable = std::is_same_v<scalar_t, c10::Half> && use_mkldnn_fp16_matmul();
bool bf32_usable = std::is_same_v<scalar_t, float> && use_mkldnn_bf32_matmul();
if ( !(bf16_usable || fp16_usable || bf32_usable) ||
(m * n * k <= 16 * 16 * 16) || (alpha == 0.0f)) {
return false;
}
@@ -105,6 +129,7 @@ mkldnn_lowerp_gemm(
if (beta != 0.0f) {
op_attr = ideep::attr_t::fuse_sum();
}
if (std::is_same_v<scalar_t, float>) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path

// NOTE: View as c-contiguous to avoid extra reordering in mkldnn
// Use identity: C = AB <=> C^T = B^T A^T
@@ -117,9 +142,12 @@
}

auto idtype = ideep::tensor::data_type::bf16;
if constexpr (!std::is_same_v<scalar_t, c10::BFloat16>) {
if constexpr (std::is_same_v<scalar_t, c10::Half>) {
idtype = ideep::tensor::data_type::f16;
}
if constexpr (std::is_same_v<scalar_t, float>) {
idtype = ideep::tensor::data_type::f32;
}

ideep::tensor a({
/*sizes=*/{k, m},
@@ -164,7 +192,7 @@ bool mkldnn_bf16_gemm(
const c10::BFloat16 *b, int64_t ldb,
float beta,
c10::BFloat16 *c, int64_t ldc) {
return mkldnn_lowerp_gemm<c10::BFloat16>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
return mkldnn_gemm<c10::BFloat16>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

bool mkldnn_fp16_gemm(
@@ -175,9 +203,19 @@ bool mkldnn_fp16_gemm(
const c10::Half *b, int64_t ldb,
float beta,
c10::Half *c, int64_t ldc) {
return mkldnn_lowerp_gemm<c10::Half>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
return mkldnn_gemm<c10::Half>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

bool mkldnn_bf32_gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
float alpha,
const float *a, int64_t lda,
const float *b, int64_t ldb,
float beta,
float *c, int64_t ldc){
return mkldnn_gemm<float>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

void mkldnn_matmul(
const Tensor &mat1,
@@ -205,11 +243,12 @@ void mkldnn_matmul(
#else
TORCH_CHECK(
(mat1.scalar_type() == at::kBFloat16 ||
mat1.scalar_type() == at::kHalf) &&
mat1.scalar_type() == at::kHalf ||
mat1.scalar_type() == at::kFloat) &&
mat2.scalar_type() == mat1.scalar_type() &&
result.scalar_type() == mat1.scalar_type(),
"mkldnn_matmul: only enabled for bf16 and fp16 path");
if (mat1.scalar_type() == at::kBFloat16) {
if (mat1.scalar_type() == at::kBFloat16 || mat1.scalar_type() == at::kFloat) {
TORCH_CHECK(
mkldnn_bf16_device_check(),
"mkldnn_matmul: mkldnn_matmul bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq, or AWS Graviton3");
@@ -230,6 +269,7 @@ void mkldnn_matmul(
// but mkldnn matmul primitive only support bias be 1-D tensors
// to address their differences, we use mkldnn post ops to perform a fused "add" after matrix multiplication is over
if (beta != 0.0f) op_attr = ideep::attr_t::fuse_sum();
if (mat1.scalar_type() == at::kFloat) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path
// If alpha = 0, dose not need actually do gemm computation
if (alpha == 0)
return;
@@ -340,11 +380,26 @@ bool use_mkldnn_fp16_matmul(
checksize(mat1, mat2));
}

bool use_mkldnn_lower_precision_matmul(
bool use_mkldnn_bf32_matmul(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result) {

return (
use_mkldnn_bf32_matmul() &&
mat1.scalar_type() == kFloat &&
mat2.scalar_type() == kFloat &&
(!result.defined() || result.scalar_type() == kFloat) &&
mat1.numel() != 0 &&
mat2.numel() != 0 &&
checksize(mat1, mat2));
}

bool use_mkldnn_matmul(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result) {
return (use_mkldnn_bf16_matmul(mat1, mat2, result) || use_mkldnn_fp16_matmul(mat1, mat2, result));
return (use_mkldnn_bf16_matmul(mat1, mat2, result) || use_mkldnn_fp16_matmul(mat1, mat2, result) || use_mkldnn_bf32_matmul(mat1, mat2, result));
}

} // namespace native
21 changes: 20 additions & 1 deletion aten/src/ATen/native/mkldnn/Matmul.h
@@ -24,6 +24,11 @@ bool use_mkldnn_fp16_matmul(
const Tensor& mat2,
const Tensor& result_opt);

bool use_mkldnn_bf32_matmul(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result_opt);

// Try running mkldnn optimized gemm, or returns false if naive gemm would be faster
bool mkldnn_bf16_gemm(
TransposeType transa, TransposeType transb,
@@ -43,7 +48,21 @@ bool mkldnn_fp16_gemm(
float beta,
c10::Half *c, int64_t ldc);

bool use_mkldnn_lower_precision_matmul(
/*
oneDNN implicit reduced precision arithmetic feature
https://github.com/mgouicem/oneDNN/tree/mgouicem/rfcs/implicit_downconvert/rfcs/20210301-computation-datatype
to allow implicitly cast data type from FP32 to BF16 in onednn compute primitives
*/
bool mkldnn_bf32_gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
float alpha,
const float *a, int64_t lda,
const float *b, int64_t ldb,
float beta,
float *c, int64_t ldc);

bool use_mkldnn_matmul(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result);
13 changes: 13 additions & 0 deletions test/test_linalg.py
@@ -32,6 +32,7 @@
)
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \
_get_torch_cuda_version
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.distributions.binomial import Binomial
import torch.backends.opt_einsum as opt_einsum
import operator
@@ -58,6 +59,7 @@ def tearDown(self):
@dtypes(torch.float, torch.cfloat)
@precisionOverride({torch.float: 1e-06, torch.cfloat: 1e-06})
@tf32_on_and_off(5e-3)
@bf32_on_and_off(5e-3)
def test_inner(self, device, dtype):
def check(a_sizes_, b_sizes_):
for a_sizes, b_sizes in ((a_sizes_, b_sizes_), (b_sizes_, a_sizes_)):
@@ -605,6 +607,7 @@ def cholesky_test_helper(n, batch_dims, upper):
@skipCPUIfNoLapack
@dtypes(*floating_and_complex_types())
@tf32_on_and_off(0.01)
@bf32_on_and_off(0.01)
def test_old_cholesky(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix

@@ -5633,6 +5636,7 @@ def maybe_transpose(cond, m):
*[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else []))
@dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
@tf32_on_and_off(0.05)
@bf32_on_and_off(0.05)
def test_addmm(self, device, dtype):
self._test_addmm_impl(torch.addmm, None, device, dtype)

@@ -5642,6 +5646,7 @@ def test_addmm(self, device, dtype):
*[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
@dtypes(*floating_types_and(torch.bfloat16))
@tf32_on_and_off(0.05)
@bf32_on_and_off(0.05)
def test_addmm_relu(self, device, dtype):
self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)

@@ -5651,12 +5656,14 @@ def test_addmm_relu(self, device, dtype):
*[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
@dtypes(*floating_types_and(torch.bfloat16))
@tf32_on_and_off(0.05)
@bf32_on_and_off(0.05)
def test_addmm_gelu(self, device, dtype):
self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype)

@dtypes(torch.float, torch.double)
@dtypesIfCUDA(*floating_and_complex_types())
@tf32_on_and_off(0.005)
@bf32_on_and_off(0.005)
def test_addmm_sizes(self, device, dtype):
for m in [0, 1, 25]:
for n in [0, 1, 10]:
@@ -5976,6 +5983,7 @@ def int4_mm(a, b_int32, b_scales_and_zeros):
@dtypes(torch.half, torch.float32, torch.float64, torch.int32, torch.int64, torch.cfloat, torch.cdouble)
@dtypesIfCUDA(torch.float32, torch.float64, torch.cfloat, torch.cdouble)
@tf32_on_and_off(0.01)
@bf32_on_and_off(0.01)
def test_mm(self, device, dtype):
def _test_mm(n, m, p, dtype, genf):
# helper function
@@ -6155,6 +6163,7 @@ def test_strided_mm_bmm(self, device, dtype):
@onlyNativeDeviceTypes
@dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
@tf32_on_and_off(0.05)
@bf32_on_and_off(0.05)
def test_bmm(self, device, dtype):
if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
# cuBLAS does not guarantee BFloat16 support on SM < 53.
@@ -6267,6 +6276,7 @@ def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor):
@onlyNativeDeviceTypes
@dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
@tf32_on_and_off(0.05)
@bf32_on_and_off(0.05)
def test_addbmm(self, device, dtype):
if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
# cuBLAS does not guarantee BFloat16 support on SM < 53.
@@ -6340,6 +6350,7 @@ def generate_tensor():
@onlyNativeDeviceTypes
@dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
@tf32_on_and_off(0.05)
@bf32_on_and_off(0.05)
def test_baddbmm(self, device, dtype):
if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
# cuBLAS does not guarantee BFloat16 support on SM < 53.
@@ -7298,6 +7309,7 @@ def dims_full_for_fn():
self.assertEqual(r0, r1)

@tf32_on_and_off(0.001)
@bf32_on_and_off(0.001)
def test_broadcast_batched_matmul(self, device):
n_dim = random.randint(1, 8)
m_dim = random.randint(1, 8)
@@ -7627,6 +7639,7 @@ def fn(torchfn, *args):
fn(torch.slogdet, (0, 0)))

@tf32_on_and_off(0.005)
@bf32_on_and_off(0.005)
def test_tensordot(self, device):
a = torch.arange(60., device=device).reshape(3, 4, 5)
b = torch.arange(24., device=device).reshape(4, 3, 2)
7 changes: 6 additions & 1 deletion test/test_nn.py
@@ -53,7 +53,7 @@
from torch.testing._internal.common_utils import dtype2prec_DONTUSE
from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, tf32_off, tf32_on
from torch.types import _TensorOrTensors

from torch.testing._internal.common_mkldnn import bf32_on_and_off

AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()

@@ -8165,6 +8165,7 @@ def _test_module_empty_inputs(self, module, inputs):
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
"Scipy v1.0 and/or numpy not found")
@tf32_on_and_off()
@bf32_on_and_off()
def test_affine_2d_rotate0(self, device):
# scipy before 1.0.0 do not support homogeneous coordinate
# scipy.ndimage.affine_transform, so we need to skip.
@@ -8204,6 +8205,7 @@ def test_affine_2d_rotate0(self, device):
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
"Scipy v1.0 and/or numpy not found")
@tf32_on_and_off(0.001)
@bf32_on_and_off(0.001)
def test_affine_2d_rotate90(self, device):
# scipy before 1.0.0 do not support homogeneous coordinate
# scipy.ndimage.affine_transform, so we need to skip.
@@ -8252,6 +8254,7 @@ def test_affine_2d_rotate90(self, device):
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
"Scipy v1.0 and/or numpy not found")
@tf32_on_and_off(0.005)
@bf32_on_and_off(0.005)
def test_affine_2d_rotate45(self, device):
# scipy before 1.0.0 do not support homogeneous coordinate
# scipy.ndimage.affine_transform, so we need to skip.
@@ -8307,6 +8310,7 @@ def test_avg_pool_large_tensor(self, device):
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
"Scipy v1.0 and/or numpy not found")
@tf32_on_and_off(0.005)
@bf32_on_and_off(0.005)
def test_affine_2d_rotateRandom(self, device):
# scipy before 1.0.0 do not support homogeneous coordinate
# scipy.ndimage.affine_transform, so we need to skip.
@@ -8358,6 +8362,7 @@ def test_affine_2d_rotateRandom(self, device):
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
"Scipy v1.0 and/or numpy not found")
@tf32_on_and_off(0.005)
@bf32_on_and_off(0.005)
def test_affine_3d_rotateRandom(self, device):
# scipy before 1.0.0 do not support homogeneous coordinate
# scipy.ndimage.affine_transform, so we need to skip.
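
Note on the test changes: the `bf32_on_and_off` decorator imported above comes from `torch.testing._internal.common_mkldnn`, which is not part of this diff. As a rough mental model only, not the actual implementation, a decorator with that behavior might look like the sketch below, assuming it mirrors `tf32_on_and_off`: run each test once with float32 matmul precision set to "medium" (BF32 allowed, relaxed tolerance) and once with "highest" (strict FP32, default tolerance). The name `bf32_on_and_off_sketch` and the tolerance handling are assumptions.

```python
# Hypothetical sketch only: the real bf32_on_and_off lives in
# torch.testing._internal.common_mkldnn and is not shown in this diff.
import functools
from contextlib import contextmanager

import torch


@contextmanager
def _float32_matmul_precision(mode):
    # Temporarily switch the global float32 matmul precision.
    old = torch.get_float32_matmul_precision()
    torch.set_float32_matmul_precision(mode)
    try:
        yield
    finally:
        torch.set_float32_matmul_precision(old)


def bf32_on_and_off_sketch(bf32_precision=1e-5):
    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(self, *args, **kwargs):
            # BF32 on: "medium" precision lets oneDNN compute FP32 matmuls
            # in BF16 (the path added by this PR), so compare with a looser
            # tolerance.
            old_precision = getattr(self, "precision", 0.0)
            with _float32_matmul_precision("medium"):
                self.precision = max(old_precision, bf32_precision)
                try:
                    test_fn(self, *args, **kwargs)
                finally:
                    self.precision = old_precision
            # BF32 off: "highest" keeps strict FP32 compute with the test's
            # default tolerance.
            with _float32_matmul_precision("highest"):
                test_fn(self, *args, **kwargs)
        return wrapper
    return decorator
```

Applied alongside `tf32_on_and_off`, a decorator like this lets the same test body cover the strict FP32 path, the CUDA TF32 path, and the new CPU BF32 path.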
