[Quant][Inductor] Enable quantization conv_binary(add/add_relu) pattern fusion inside inductor (#105456)

**Summary**
Enable the `dequant-conv2d-binary_postop(add)-unary_postop(relu)-quant` pattern fusion and lowering inside inductor.
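
For context, the shape of module that exercises the new fusion path looks like the sketch below (an illustrative example mirroring the `M` module added in the test file; the PT2E quantization and `torch.compile` steps that actually produce the fused graph are omitted here):
```python
import torch


class ConvAddReLU(torch.nn.Module):
    # Illustrative module only: after PT2E quantization and compilation with
    # Inductor freezing enabled, its graph contains the
    # dequant-conv2d-add-(relu)-quant chain that this PR lowers to a single
    # onednn.qconv2d_pointwise.binary extern kernel call.
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
        self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # Both conv outputs are dequantized, summed, passed through relu,
        # and requantized; that chain is the fusion target.
        return self.relu(self.conv1(x) + self.conv2(x))
```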

**Test Plan**
```
clear && python -m pytest test_mkldnn_pattern_matcher.py -k test_qconv2d_binary
```

Pull Request resolved: #105456
Approved by: https://github.com/jgong5, https://github.com/eellison
ghstack dependencies: #104580, #104581, #104588, #104590, #105455
leslie-fang-intel authored and pytorchmergebot committed Aug 25, 2023
1 parent d2105a8 commit 1374974
Showing 6 changed files with 418 additions and 6 deletions.
74 changes: 69 additions & 5 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -61,6 +61,12 @@
lambda x, y: x.sub_(y): (1, 2, True), # call_method
}

quantization_binary_list = [
lambda x, y: torch.add(x, y),
lambda x, y: x.add(y),
lambda x, y: x.add_(y),
]


@config.patch({"freezing": True})
class TestPatternMatcherBase(TestCase):
@@ -391,6 +397,62 @@ def forward(self, x):
match_nodes = 19
self._test_common(mod, (v,), match_count, match_nodes, rtol=1e-2, atol=1e-2)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
def test_qconv2d_binary(self):
class M(torch.nn.Module):
def __init__(
self,
binary_fn,
has_relu,
**kwargs,
):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
self.binary_fn = binary_fn
self.has_relu = has_relu
self.relu = torch.nn.ReLU()

def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x)
if self.has_relu:
return self.relu(self.binary_fn(x1, x2))
else:
return self.binary_fn(x1, x2)

options = itertools.product(
quantization_binary_list,
[True, False], # has_relu
)

for binary_fn, has_relu in options:
mod = M(binary_fn, has_relu=has_relu).eval()
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
1
)
# Totally 9 pattern_matcher_count, 41 pattern_matcher_nodes + 1 optional(unary post op)
# 1. Pair of to_int8 and to_fp32 at conv input * 2, extra input of add * 1, and graph output * 1
# matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# 2. Dequant pattern matcher for dequant promotion * 1
# [convert_element_type_3, sub_1, mul_3]
# 3. Dequant-conv pattern matched in quantization weight prepack * 2
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# 4. Quantization fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
# 5. Qconv2d_add * 1
# [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, optional(relu),
# mul_6, round_4, add_4, clamp_min_3, clamp_max_3, convert_element_type_6]
self._test_common(
mod,
(v,),
9,
42 if has_relu else 41,
check_quantization=True,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
def test_qconv2d_unary(self):
@@ -468,9 +530,8 @@ def forward(self, x):

mod = M().eval()
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)
-# For now, we have annotated conv_add in x86InductorQuantizer. But we didn't implement the lowering.
-# TODO <leslie>: Modify the pattern matcher count after we implement the qconv2d_add lowering.
-# Totally 10 pattern_matcher_count, 43 pattern_matcher_nodes
+# Totally 11 pattern_matcher_count, 54 pattern_matcher_nodes
# 1. Pair of to_int8 and to_fp32 at conv input * 2, extra input of add * 1, and graph output * 1
# matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
@@ -480,11 +541,14 @@
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# 4. Quantization fusion in post-grad fusion pass * 2
# [qconv2d_pointwise_default, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
# 5. Qconv2d_add * 1
# [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, mul_6, round_4, add_4,
# clamp_min_3, clamp_max_3, convert_element_type_6]
self._test_common(
mod,
(v,),
-10,
-43,
+11,
+54,
check_quantization=True,
)

131 changes: 131 additions & 0 deletions torch/_inductor/fx_passes/quantization.py
@@ -69,6 +69,28 @@
Arg(), # algorithm
)

dequantize_accum_pattern = CallFunction(
aten.mul.Tensor,
CallFunction(
aten.sub.Tensor,
CallFunction(
prims.convert_element_type.default,
KeywordArg("accum"),
KeywordArg("accum_dq_dtype"),
),
KeywordArg("accum_zp"),
),
KeywordArg("accum_scale"),
)


def generate_pattern_with_binary(binary_post_op, computation_call, extra_input_pattern):
return CallFunction(
binary_post_op,
computation_call,
extra_input_pattern,
)


def generate_pattern_with_unary(computation_call, unary_post_op):
if unary_post_op is not None:
@@ -179,6 +201,67 @@ def qconv(match: Match, *args, **kwargs):
return qconv


def _register_quantized_conv_binary_lowering(
pattern,
pass_number,
computation_op,
fp32_output,
binary_unary_attr,
):
@register_lowering_pattern(pattern, pass_number=pass_number)
def qconv_binary(match: Match, *args, **kwargs):
x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"]
accum, accum_scale, accum_zp = (
kwargs["accum"],
kwargs["accum_scale"],
kwargs["accum_zp"],
)
packed_weight, w_scale, w_zp = (
kwargs["packed_weight"],
kwargs["w_scale"],
kwargs["w_zp"],
)
b, stride, padding, dilation, groups = (
kwargs["b"],
kwargs["stride"],
kwargs["padding"],
kwargs["dilation"],
kwargs["groups"],
)
o_inv_scale, o_zero_point = (
kwargs["o_inv_scale"],
kwargs["o_zp"],
)

computation_args = (
x,
x_scale,
x_zp,
accum,
accum_scale,
accum_zp,
packed_weight,
w_scale,
w_zp,
b,
stride,
padding,
dilation,
groups,
o_inv_scale,
o_zero_point,
fp32_output,
binary_unary_attr.binary_op_name,
binary_unary_attr.alpha,
binary_unary_attr.unary_op_name,
binary_unary_attr.scalars_attr,
binary_unary_attr.algorithm_attr,
)
return L[computation_op](*computation_args)

return qconv_binary


def _register_quantization_unary_fusion():
class UnaryAttr:
def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
@@ -208,8 +291,56 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
)


def _register_quantization_binary_fusion():
class BinaryUnaryAttr:
def __init__(
self,
binary_op_name: str,
alpha=None,
unary_op_name: str = "none",
scalars_attr=None,
algorithm_attr=None,
):
self.binary_op_name = binary_op_name
self.alpha = alpha if alpha else 1.0
self.unary_op_name = unary_op_name
self.scalars_attr = scalars_attr if scalars_attr else []
self.algorithm_attr = algorithm_attr if algorithm_attr else ""

binary_replace_patterns = {
BinaryUnaryAttr("add", 1.0, "none", [], ""): generate_pattern_with_output_quant(
generate_pattern_with_binary(
aten.add.Tensor,
dequantize_qconv_pt2e_pattern,
dequantize_accum_pattern,
)
),
BinaryUnaryAttr("add", 1.0, "relu", [], ""): generate_pattern_with_output_quant(
generate_pattern_with_unary(
generate_pattern_with_binary(
aten.add.Tensor,
dequantize_qconv_pt2e_pattern,
dequantize_accum_pattern,
),
aten.relu.default,
)
),
}

for binary_unary_attr, patterns in binary_replace_patterns.items():
# Register qconv2d_binary_unary pattern for ExternKernel Lowering
_register_quantized_conv_binary_lowering(
patterns,
0 if binary_unary_attr.unary_op_name != "none" else 1, # pass_number
torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
False, # fp32_output
binary_unary_attr, # binary_unary_attr
)


def _register_quantization_lowerings():
_register_quantization_unary_fusion()
_register_quantization_binary_fusion()


def _is_valid_dequant_promotion_pattern(match):
1 change: 1 addition & 0 deletions torch/_inductor/graph.py
@@ -789,6 +789,7 @@ def run_node(self, n: torch.fx.Node):
torch.ops.mkldnn._linear_pointwise.binary,
torch.ops.aten.mkldnn_rnn_layer.default,
torch.ops.onednn.qconv2d_pointwise.default,
torch.ops.onednn.qconv2d_pointwise.binary,
]
if torch._C.has_mkl:
need_fixed_layout += [torch.ops.mkl._mkl_linear.default]