72 commits
e6d2625  Update (zhuhaozhe, May 28, 2024)
a197dd0  Update (zhuhaozhe, May 30, 2024)
7e0bd49  Update (zhuhaozhe, Jun 4, 2024)
e01772b  Update (zhuhaozhe, Jun 4, 2024)
1dde237  Update (zhuhaozhe, Jun 5, 2024)
ed0c68f  Update (zhuhaozhe, Jun 5, 2024)
ab13767  Update (zhuhaozhe, Aug 1, 2024)
4b48bd6  Update (zhuhaozhe, Aug 1, 2024)
0110d1c  Update (zhuhaozhe, Aug 6, 2024)
4a4787e  Update (zhuhaozhe, Aug 6, 2024)
a92d816  Update (zhuhaozhe, Aug 7, 2024)
c52eb08  Update (zhuhaozhe, Aug 7, 2024)
8d0c715  Update (zhuhaozhe, Aug 29, 2024)
1eca8f9  Update (zhuhaozhe, Sep 2, 2024)
6f97328  Update (zhuhaozhe, Sep 4, 2024)
3d53e42  Update (zhuhaozhe, Sep 5, 2024)
0b7627f  Update (zhuhaozhe, Sep 9, 2024)
c60f8c7  Update (zhuhaozhe, Sep 9, 2024)
4d444c6  Update (zhuhaozhe, Sep 9, 2024)
5f8593d  Update (zhuhaozhe, Sep 10, 2024)
ce393da  Update (zhuhaozhe, Sep 10, 2024)
4b775ae  Update (zhuhaozhe, Sep 11, 2024)
e5e380b  Update (zhuhaozhe, Sep 12, 2024)
fd113f8  Update (zhuhaozhe, Sep 19, 2024)
bede96d  Update (zhuhaozhe, Sep 20, 2024)
cca2705  Update (zhuhaozhe, Sep 20, 2024)
02a0901  Update (zhuhaozhe, Sep 23, 2024)
570e6ab  Update (zhuhaozhe, Sep 23, 2024)
bc70132  Update (zhuhaozhe, Sep 23, 2024)
14b0326  Update (zhuhaozhe, Sep 24, 2024)
ca9cf2a  Update (zhuhaozhe, Sep 25, 2024)
4a3016b  Update (zhuhaozhe, Nov 1, 2024)
a037113  Update (zhuhaozhe, Nov 1, 2024)
ad37017  Update on "[inductor] enable bf32 for mkldnn linear pointwise/binary … (zhuhaozhe, Nov 29, 2024)
0b4c06f  Update (yanbing-j, Dec 5, 2024)
2183c97  Update (yanbing-j, Dec 23, 2024)
ae2f55a  Update (yanbing-j, Dec 24, 2024)
154dc67  Update (yanbing-j, Dec 25, 2024)
0d70af1  Update (yanbing-j, Dec 26, 2024)
bc7cfed  Update (yanbing-j, Jan 3, 2025)
78f8fa6  Update (yanbing-j, Jan 20, 2025)
c3c0833  Update (yanbing-j, Feb 6, 2025)
be1ae9b  Update (yanbing-j, Feb 8, 2025)
bf59db7  Update (yanbing-j, Feb 8, 2025)
a3cdf4e  Update (yanbing-j, Mar 7, 2025)
111275b  Update (yanbing-j, Mar 10, 2025)
320136d  Update (yanbing-j, Mar 13, 2025)
82443e8  Update (yanbing-j, Mar 28, 2025)
f357a54  Update (yanbing-j, Apr 30, 2025)
d19fb0c  Update (yanbing-j, May 7, 2025)
fbef203  Update (yanbing-j, May 7, 2025)
2a4770f  Update (yanbing-j, May 9, 2025)
ab7e481  Update (yanbing-j, May 10, 2025)
ac9ed89  Update (yanbing-j, May 14, 2025)
b8ad515  Update (yanbing-j, May 29, 2025)
c228966  Update (yanbing-j, Jun 13, 2025)
a942b01  Update (yanbing-j, Jun 19, 2025)
8da6226  Update (yanbing-j, Jun 20, 2025)
e55b236  Update (yanbing-j, Jun 21, 2025)
a4f2f6d  Update (yanbing-j, Jun 22, 2025)
ee48561  Update (yanbing-j, Jun 23, 2025)
600b435  Update (yanbing-j, Jun 25, 2025)
79e8dfd  Update (yanbing-j, Jun 25, 2025)
3776c31  Update (yanbing-j, Jun 25, 2025)
5761f5d  Update (yanbing-j, Jun 26, 2025)
2192f93  Update (yanbing-j, Jun 30, 2025)
682b6c5  Update (yanbing-j, Jul 1, 2025)
1d12fbb  Update (yanbing-j, Jul 1, 2025)
c7e8536  Update (yanbing-j, Jul 1, 2025)
2d2956b  Update (yanbing-j, Jul 3, 2025)
fc2e94e  Update (yanbing-j, Jul 3, 2025)
63915d3  Update (yanbing-j, Jul 7, 2025)
13 changes: 12 additions & 1 deletion aten/src/ATen/native/mkldnn/Linear.cpp
@@ -68,6 +68,11 @@ mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
 
 namespace at::native {
 
+static bool use_mkldnn_bf32_linear() {
+  return at::globalContext().float32Precision("mkldnn", "matmul") == "bf16" &&
+      mkldnn_bf16_device_check();
+}
+
 Tensor mkldnn_linear(
     const Tensor& self,
     const Tensor& weight_t, const std::optional<Tensor>& bias_opt) {
@@ -251,7 +256,9 @@ Tensor mkldnn_linear_pointwise(
        it != fusion_unary_attr_map().end(), "Fusion behavior undefined.");
    op_attr = it->second(scalars, algorithm);
  }
-
+  if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat) {
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
+  }
  if (mkldnn_bias.has_value()) {
    ideep::inner_product_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
        mkldnn_input,
@@ -341,6 +348,10 @@ Tensor mkldnn_linear_pointwise_binary(
  auto op_attr = ideep::attr_t::fuse_binary(it_binary->second, other_desc);
  auto aprop_kind = ideep::prop_kind::forward_inference;
 
+  if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat) {
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
+  }
+
  if (mkldnn_bias.has_value()) {
    ideep::inner_product_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
        mkldnn_input,
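
Taken together, these changes route a float32 linear through oneDNN's bf16 fpmath ("BF32") path once the mkldnn matmul precision is set to "bf16". A minimal usage sketch, assuming the fp32_precision knob is settable the same way the tests below read it (the module and shapes are illustrative):

import torch

# Ask oneDNN to compute fp32 matmuls with bf16 fpmath (BF32).
# Assumed settable; the tests below read this exact attribute.
torch.backends.mkldnn.matmul.fp32_precision = "bf16"

mod = torch.nn.Linear(64, 32).eval()  # weights and bias stay float32
x = torch.randn(8, 64)                # float32 input

with torch.no_grad():
    compiled = torch.compile(mod)     # inductor packs this as a mkldnn linear
    out = compiled(x)                 # bf16 fpmath internally, fp32 in and out
print(out.dtype)                      # torch.float32

Inputs and outputs remain fp32 throughout; only the internal accumulation uses bf16 fpmath, which is why the guard above fires only when input_t.scalar_type() is at::kFloat.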
24 changes: 20 additions & 4 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -699,6 +699,7 @@ def matcher_check_fn():
 
 
 class TestPatternMatcher(TestPatternMatcherBase):
+    @bf32_on_and_off()
     def test_linear_unary(self, device="cpu"):
         self.device = device
 
@@ -729,6 +730,8 @@ def forward(self, x):
             dtypes.append(torch.bfloat16)
         if is_mkldnn_fp16_supported(self.device):
             dtypes.append(torch.float16)
+        if torch.backends.mkldnn.matmul.fp32_precision == "bf16":
+            dtypes.append(torch.float32)
         options = itertools.product(unary_list, [True, False], dtypes)
         for unary_fn, bias, dtype in options:
             metrics.reset()
@@ -739,7 +742,7 @@ def forward(self, x):
 
             def matcher_check_fn():
                 match_nodes = unary_list[unary_fn]
-                if self._check_unary_is_decomposed(unary_fn):
+                if dtype != torch.float32 and self._check_unary_is_decomposed(unary_fn):
                     # Has extra dtype conversion nodes for autocast.
                     match_nodes += 2
                 self.assertEqual(
@@ -751,9 +754,14 @@ def matcher_check_fn():
                 )
 
             self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
-            # only generated 1 kernel for "to"
-            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
+            # only generated 1 kernel for "to_dtype"
+            expected_kernel_count = 2 if TEST_ACL else 1
+            if dtype == torch.float32:
+                # In BF32, input is float32, will not generate kernel for "to_dtype"
+                expected_kernel_count -= 1
+            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)
 
+    @bf32_on_and_off()
     @unittest.skipIf(not TEST_MKL, "Test requires MKL")
     def test_linear_fp32(self, device="cpu"):
         self.device = device
@@ -901,6 +909,7 @@ def matcher_check_fn():
             # 1 kernel for "to_lowp", 2 kernels for unary ops
             self.assertEqual(metrics.generated_kernel_count, 3)
 
+    @bf32_on_and_off()
     def test_linear_binary(self, device="cpu"):
         self.device = device
 
@@ -922,6 +931,8 @@ def forward(self, x, y):
             dtypes.append(torch.bfloat16)
         if is_mkldnn_fp16_supported(self.device):
             dtypes.append(torch.float16)
+        if torch.backends.mkldnn.matmul.fp32_precision == "bf16":
+            dtypes.append(torch.float32)
         options = itertools.product(
             binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes
         )
@@ -958,7 +969,12 @@ def matcher_check_fn():
                 matcher_check_fn,
                 check_autocast=dtype,
             )
-            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
+            # only generated 1 kernel for "to_dtype"
+            expected_kernel_count = 2 if TEST_ACL else 1
+            if dtype == torch.float32:
+                # In BF32, input is float32, will not generate kernel for "to_dtype"
+                expected_kernel_count -= 1
+            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)
 
     def test_linear_binary_broadcast_shapes(self, device="cpu"):
         self.device = device
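
Each decorated test runs with the BF32 setting both on and off via bf32_on_and_off(). The real decorator ships with PyTorch's test utilities; the following is only a plausible sketch of what it does, with the "ieee" fallback value assumed rather than confirmed by this diff:

import functools
import torch

def bf32_on_and_off():
    # Sketch only: run the wrapped test under both mkldnn fp32 matmul
    # precisions, then restore whatever setting was active before.
    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(self, *args, **kwargs):
            old = torch.backends.mkldnn.matmul.fp32_precision
            try:
                for precision in ("bf16", "ieee"):  # "ieee" assumed as the default
                    torch.backends.mkldnn.matmul.fp32_precision = precision
                    test_fn(self, *args, **kwargs)
            finally:
                torch.backends.mkldnn.matmul.fp32_precision = old
        return wrapper
    return decorator

This is why the test bodies can simply check torch.backends.mkldnn.matmul.fp32_precision == "bf16" to decide whether float32 joins the dtype list: the decorator supplies both states across the two runs.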
20 changes: 16 additions & 4 deletions torch/_inductor/fx_passes/mkldnn_fusion.py
@@ -1228,10 +1228,15 @@ def is_const_or_cat_by_const(weight):
             torch.bfloat16,
             torch.float16,
         )
+        bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"  # type: ignore[attr-defined]
+        use_bf16_for_fp32_weight = (
+            bf32_matmul_enabled and weight_meta_value.dtype == torch.float32
+        )
+        compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight
         # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol.
         # on aarch64, use mkldnn op for fp32 as well if acl is enabled
         if (
-            not is_lp_weight
+            not compute_with_lp
             and not mkldnn._is_mkldnn_acl_supported()
             and ((not torch._C.has_mkl) or has_free_symbols(batch_size))
         ):
@@ -1444,16 +1449,23 @@ def linear(match, *args, **kwargs):
                 torch.bfloat16,
                 torch.float16,
             )
+            bf32_matmul_enabled = (
+                torch.backends.mkldnn.matmul.fp32_precision == "bf16"  # type: ignore[attr-defined]
+            )
+            use_bf16_for_fp32_weight = (
+                bf32_matmul_enabled and weight_dtype == torch.float32
+            )
+            compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight
             batch_size = input.meta.get("val").shape[0]
             if has_free_symbols(batch_size):
-                assert is_lp_weight or mkldnn._is_mkldnn_acl_supported(), (
+                assert compute_with_lp or mkldnn._is_mkldnn_acl_supported(), (
                     f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
                 )
             packed_weight_node = mkldnn_device_op.pack_linear_weight(
-                graph, is_lp_weight, transpose_weight_node, batch_size
+                graph, compute_with_lp, transpose_weight_node, batch_size
             )
             packed_linear_node = mkldnn_device_op.pack_linear(
-                graph, is_lp_weight, batch_size, input, packed_weight_node, bias
+                graph, compute_with_lp, batch_size, input, packed_weight_node, bias
             )
 
             linear_node.replace_all_uses_with(packed_linear_node)
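
The pass now treats an fp32 weight like a low-precision one whenever BF32 is enabled, so weight prepacking and the dynamic-shape check share a single compute_with_lp flag. A standalone restatement of that gate, with names taken from the diff (should_compute_with_lp itself is a hypothetical helper written for illustration):

import torch

def should_compute_with_lp(weight_dtype: torch.dtype) -> bool:
    # Mirrors the gating in mkldnn_fusion.py: low-precision compute applies
    # if the weight is bf16/fp16, or if it is fp32 and the mkldnn matmul
    # precision has been set to "bf16" (BF32).
    is_lp_weight = weight_dtype in (torch.bfloat16, torch.float16)
    bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"
    use_bf16_for_fp32_weight = bf32_matmul_enabled and weight_dtype == torch.float32
    return is_lp_weight or use_bf16_for_fp32_weight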