Commit fa076d6

Update
[ghstack-poisoned]
zhuhaozhe committed Jun 21, 2024
1 parent 5957453 commit fa076d6
Showing 2 changed files with 20 additions and 13 deletions.
22 changes: 18 additions & 4 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -406,18 +406,21 @@ def forward(self, x):
 
     def test_linear_add_bias(self):
         class M(torch.nn.Module):
-            def __init__(self, dtype, unary_fn):
+            def __init__(self, dtype, unary_fn, cast_bias):
                 super().__init__()
                 self.linear1 = torch.nn.Linear(10, 64, bias=False)
                 self.bias1 = torch.randn(64)
                 self.linear2 = torch.nn.Linear(10, 64, bias=False)
                 self.bias2 = torch.randn(64)
+                if cast_bias:
+                    self.bias1 = self.bias1.to(dtype=dtype)
+                    self.bias2 = self.bias2.to(dtype=dtype)
                 self.unary_fn = unary_fn
 
             def forward(self, x):
                 a = self.linear1(x) + self.bias1
                 b = self.linear2(x) + self.bias2
-                return self.unary_fn(a).to(dtype), self.unary_fn(b).to(dtype)
+                return self.unary_fn(a), self.unary_fn(b)
 
         dtypes = []
         if torch.ops.mkldnn._is_mkldnn_bf16_supported():
@@ -427,7 +430,7 @@ def forward(self, x):
         options = itertools.product(unary_list, dtypes)
         for unary_fn, dtype in options:
             metrics.reset()
-            mod = M(dtype, unary_fn).eval()
+            fold_mod = M(dtype, unary_fn, cast_bias=True).eval()
             v = torch.randn(2, 10)
             matcher_count = 3
             # Add 1 for weight packing pass, add 2 for bias folding pass per linear.
@@ -437,9 +440,20 @@ def forward(self, x):
                 matcher_nodes += 2
             # we have 2 linears, so we double the matcher_count/nodes
             self._test_common(
-                mod, (v,), matcher_count * 2, matcher_nodes * 2, check_autocast=dtype
+                fold_mod,
+                (v,),
+                matcher_count * 2,
+                matcher_nodes * 2,
+                check_autocast=dtype,
             )
             self.assertEqual(metrics.generated_kernel_count, 1)
+            # we won't fold the bias if the bias is not the same dtype as the weight
+            # https://github.com/pytorch/pytorch/pull/129138
+            metrics.reset()
+            mod = M(dtype, unary_fn, cast_bias=False).eval()
+            self._test_common(mod, (v,), 2, 2, check_autocast=dtype)
+            # 1 kernel for "to_lowp", 2 kernels for unary ops
+            self.assertEqual(metrics.generated_kernel_count, 3)
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
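As context for the test changes above, here is a minimal standalone sketch (hypothetical code, not part of the test suite) of the two variants the updated test now covers: a bias-free Linear followed by an explicit bias add, where the bias either is or is not cast to the weight's autocast dtype.

import torch

# Sketch of the pattern under test (assumed shapes/names, mirroring the test's M).
linear = torch.nn.Linear(10, 64, bias=False)
bias_fp32 = torch.randn(64)               # float32 bias: dtype differs from the bf16 weight
bias_bf16 = bias_fp32.to(torch.bfloat16)  # cast bias: dtype matches the autocast weight

x = torch.randn(2, 10)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    foldable = linear(x) + bias_bf16      # eligible for the bias folding pass
    unfoldable = linear(x) + bias_fp32    # dtype mismatch: the add stays a separate op

In the test, the first case folds the add into the packed linear (one generated kernel), while the second keeps the low-precision cast and the unary ops as separate kernels (three in total).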
11 changes: 2 additions & 9 deletions torch/_inductor/fx_passes/mkldnn_fusion.py
@@ -805,7 +805,7 @@ def is_linear_add_bias(match):
             torch.bfloat16,
             torch.float16,
         )
-        if bias_meta.dtype not in (torch.bfloat16, torch.float16, torch.float):
+        if bias_meta.dtype != weight_meta.dtype:
             return False
         return (
             linear_node.args[2] is None
@@ -827,15 +827,8 @@ def linear_bias_pattern(match, *args):
         graph = match.graph
         add_node = match.output_node()
         linear_node = add_node.args[0]
-        w_node = linear_node.args[1].args[0].args[0]
-        w_dtype = w_node.meta.get("val").dtype
         new_args = list(linear_node.args)
-        bias_node = add_node.args[1]
-        to_dtype = graph.call_function(
-            prims.convert_element_type.default,
-            (bias_node, w_dtype),
-        )
-        new_args[2] = to_dtype
+        new_args[2] = add_node.args[1]
         repl = graph.call_function(
             mkldnn._linear_pointwise.default, tuple(new_args)
         )
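To make the new guard concrete, here is a minimal sketch (a hypothetical standalone helper, not the pass itself) of the check is_linear_add_bias now performs: the bias dtype must equal the weight dtype exactly, rather than merely being one of bfloat16/float16/float32.

import torch

def bias_matches_weight(bias: torch.Tensor, weight: torch.Tensor) -> bool:
    # Hypothetical standalone version of the tightened extra_check:
    # fold the add into _linear_pointwise only when dtypes agree.
    return bias.dtype == weight.dtype

w_bf16 = torch.randn(64, 10, dtype=torch.bfloat16)
print(bias_matches_weight(torch.randn(64, dtype=torch.bfloat16), w_bf16))  # True: foldable
print(bias_matches_weight(torch.randn(64), w_bf16))                        # False: fp32 bias stays unfolded

Dropping the convert_element_type insertion in linear_bias_pattern means the pass no longer casts a mismatched bias itself; mismatched cases simply fail the check and keep the explicit add.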
