[Quant][Inductor] Enable quantization conv_unary(relu) pattern fusion inside inductor (#105455)

**Summary**
Enable the `dequant-conv2d-unary_postop(relu)-quant` pattern fusion and lowering inside inductor.
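
For illustration, here is a minimal sketch (the module and its names are ours, not from this PR) of the eager-mode shape that yields this chain: after PT2E quantization, the conv and relu are bracketed by quantize/dequantize nodes, and the new fusion lowers the whole chain into a single `onednn.qconv2d_pointwise` extern kernel call carrying `relu` as its unary post-op attribute.

```python
import torch

class ConvReLU(torch.nn.Module):
    # Illustrative module: its PT2E-quantized graph contains the
    # dequant -> conv2d -> relu -> quant chain targeted by this fusion.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))
```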

**Test Plan**
```
clear && python -m pytest test_mkldnn_pattern_matcher.py -k test_qconv2d_unary
```

Pull Request resolved: #105455
Approved by: https://github.com/jgong5, https://github.com/eellison
ghstack dependencies: #104580, #104581, #104588, #104590
leslie-fang-intel authored and pytorchmergebot committed Aug 25, 2023
1 parent 4f3ff16 commit c1e0fb7
Showing 2 changed files with 65 additions and 20 deletions.
test/inductor/test_mkldnn_pattern_matcher.py (31 additions, 8 deletions)
```diff
@@ -45,6 +45,11 @@
     torch.nn.Tanh,
 ]
 
+quantization_unary_list = {
+    None: 0,
+    torch.nn.ReLU(): 1,
+}
+
 # The dict value is (match_count, match_nodes, inplace)
 binary_list = {
     lambda x, y: torch.add(x, y): (1, 2, False),  # call_function
@@ -392,37 +397,55 @@ def test_qconv2d_unary(self):
         class M(torch.nn.Module):
             def __init__(
                 self,
-                auto_insert_channel_last_node=False,
+                unary_fn,
+                **kwargs,
             ):
                 super().__init__()
-                if auto_insert_channel_last_node:
+                if (
+                    "auto_insert_channel_last_node" in kwargs
+                    and kwargs["auto_insert_channel_last_node"]
+                ):
                     self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1)
                 else:
                     self.conv = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
+                self.unary_fn = unary_fn
 
             def forward(self, x):
-                return self.conv(x)
+                x = self.conv(x)
+                return self.unary_fn(x) if self.unary_fn else x
 
+        options = itertools.product(
+            quantization_unary_list.keys(),
+            [True, False],  # auto_insert_channel_last_node
+        )
+
-        for auto_insert_channel_last_node in [True, False]:
-            mod = M(auto_insert_channel_last_node).eval()
+        for unary_fn, auto_insert_channel_last_node in options:
+            if auto_insert_channel_last_node and unary_fn is not None:
+                # Skip trivial test combinations to reduce test time.
+                continue
+            mod = M(
+                unary_fn, auto_insert_channel_last_node=auto_insert_channel_last_node
+            ).eval()
             v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
                 1
             )
 
-            # Totally 4 pattern_matcher_count, 17 pattern_matcher_nodes
+            # Totally pattern_matcher_count 4,
+            # pattern_matcher_nodes 17 + 1 for optional(unary_post_op)
             # 1. pair of to_int8 and to_fp32 at conv input matched in pointless_convert pass
             #    at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
             # 2. dequant-conv pattern matched in quantization weight prepack
             #    [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
             # 3. pair of to_int8 and to_fp32 at conv output matched in pointless_convert pass
             #    at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type_2, convert_element_type_3]
             # 4. Quantization fusion in post-grad fusion pass
-            #    [qconv2d_pointwise_default, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
+            #    [qconv2d_pointwise_default, optional(unary_post_op), div_1, round_2, add_1,
+            #     clamp_min_1, clamp_max_1, convert_element_type_2]
             self._test_common(
                 mod,
                 (v,),
                 4,
-                17,
+                17 + quantization_unary_list[unary_fn],
                 check_quantization=True,
             )
```
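As a reading of the updated expectations (a sketch of the bookkeeping, not code from the PR, reusing `quantization_unary_list` from the diff above): the relu post-op contributes exactly one extra matched node to group 4, which is the per-post-op increment the dict stores.

```python
# Expected totals passed to _test_common, per the comments above:
# 4 match groups either way; 17 matched nodes for the plain conv,
# 17 + 1 = 18 when torch.nn.ReLU() is the unary post-op.
for unary_fn in quantization_unary_list:
    expected_matcher_count = 4
    expected_matcher_nodes = 17 + quantization_unary_list[unary_fn]
```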
torch/_inductor/fx_passes/quantization.py (34 additions, 12 deletions)
```diff
@@ -70,6 +70,15 @@
 )
 
 
+def generate_pattern_with_unary(computation_call, unary_post_op):
+    if unary_post_op is not None:
+        return CallFunction(
+            unary_post_op,
+            computation_call,
+        )
+    return computation_call
+
+
 def generate_pattern_with_output_quant(computation_call):
     """
     quantize output:
@@ -170,24 +179,37 @@ def qconv(match: Match, *args, **kwargs):
     return qconv
 
 
-def _register_quantization_lowerings():
+def _register_quantization_unary_fusion():
     class UnaryAttr:
         def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
             self.op_name = op_name
             self.scalars_attr = scalars_attr if scalars_attr else []
             self.algorithm_attr = algorithm_attr if algorithm_attr else ""
 
-    # Register dq-conv2d-q pattern for ExternKernel Lowering
-    quantize_conv_output_pattern_pt2e = generate_pattern_with_output_quant(
-        dequantize_qconv_pt2e_pattern
-    )
-    _register_quantized_conv_lowering(
-        quantize_conv_output_pattern_pt2e,
-        2,  # pass_number
-        torch.ops.onednn.qconv2d_pointwise,  # computation_op
-        False,  # fp32_output
-        UnaryAttr("none", [], ""),  # unary_attr
-    )
+    unary_replace_patterns = {
+        UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
+            dequantize_qconv_pt2e_pattern
+        ),
+        UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
+            generate_pattern_with_unary(
+                dequantize_qconv_pt2e_pattern, aten.relu.default
+            )
+        ),
+    }
+
+    for unary_attr, patterns in unary_replace_patterns.items():
+        # Register qconv2d pattern for ExternKernel Lowering
+        _register_quantized_conv_lowering(
+            patterns,
+            1 if unary_attr.op_name != "none" else 2,  # pass_number
+            torch.ops.onednn.qconv2d_pointwise,  # computation_op
+            False,  # fp32_output
+            unary_attr,  # unary_attr
+        )
 
 
+def _register_quantization_lowerings():
+    _register_quantization_unary_fusion()
+
+
 def _is_valid_dequant_promotion_pattern(match):
```
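To make the registration flow concrete, a hedged sketch of how the two patterns are assembled from the helpers above (the intermediate variable names are illustrative, not from the PR): `generate_pattern_with_unary` optionally wraps the dequant-conv pattern in the unary op, and `generate_pattern_with_output_quant` appends the output re-quantization. Registering the relu variant at `pass_number` 1 and the plain variant at 2 lets the longer, more specific pattern attempt a match before the generic one can claim the same convolution.

```python
# Illustrative composition (variable names are not from the PR).
plain_pattern = generate_pattern_with_output_quant(
    dequantize_qconv_pt2e_pattern  # dequant -> qconv2d -> quant
)
relu_pattern = generate_pattern_with_output_quant(
    generate_pattern_with_unary(
        dequantize_qconv_pt2e_pattern,  # dequant -> qconv2d
        aten.relu.default,              # -> relu
    )                                   # -> quant
)
# relu_pattern is registered with pass_number=1, plain_pattern with 2,
# so the fused-relu match is attempted first.
```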
