[Quant][Inductor] Enable dequant promotion inside inductor (#104590)
**Summary**
Enable the `dequant pattern` promotion pass in inductor. In the qconv weight prepack pass we match the `dequant->conv2d` pattern; if the `dequant pattern` has multiple user nodes, the match fails.
Take the following example, where the output of `conv1` feeds both `conv2` and `conv3`:
```
        conv1
       /     \
   conv2    conv3
```
After the quantization flow, the generated pattern is:
```
      dequant1
          |
        conv1
          |
        quant2
          |
       dequant2
       /     \
   conv2    conv3
```
We need to duplicate `dequant2` into `dequant2` and `dequant3`, so that both the `dequant2->conv2` and `dequant3->conv3` patterns can be matched.
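After promotion, the pattern becomes:
```
      dequant1
          |
        conv1
          |
        quant2
        /    \
  dequant2  dequant3
      |        |
    conv2    conv3
```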

**Test Plan**
```
python -m pytest test_mkldnn_pattern_matcher.py -k test_dequant_promotion
```
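
For illustration only (not part of this diff): the promotion amounts to giving every user of a multi-user FX node its own clone, using public `torch.fx` graph APIs. `promote_node_per_user` below is a hypothetical helper for this sketch, not the function added by this PR.
```
import copy

import torch


def promote_node_per_user(graph: torch.fx.Graph, node: torch.fx.Node) -> None:
    # Give every user of `node` its own copy so that single-consumer patterns
    # (e.g. dequant -> conv2d) can match once per user.
    assert node.op == "call_function"
    for user in list(node.users):
        with graph.inserting_before(user):
            new_node = graph.call_function(
                node.target, args=node.args, kwargs=node.kwargs
            )
            new_node.meta = copy.copy(node.meta)
        user.replace_input_with(node, new_node)
    # The original node now has no users and is removed as dead code.
    graph.eliminate_dead_code()
```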

Pull Request resolved: #104590
Approved by: https://github.com/jgong5, https://github.com/eellison
ghstack dependencies: #104580, #104581, #104588
leslie-fang-intel authored and voznesenskym committed Aug 27, 2023
1 parent e679cd8 commit 4c9ab56
Showing 2 changed files with 99 additions and 0 deletions.
39 changes: 39 additions & 0 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -426,6 +426,45 @@ def forward(self, x):
            check_quantization=True,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_dequant_promotion(self):
        class M(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.conv2 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
                self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)

            def forward(self, x):
                temp = self.conv1(x)
                temp = self.conv2(temp) + self.conv3(temp)
                return temp

        mod = M().eval()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)
        # For now, we have annotated conv_add in x86InductorQuantizer, but the lowering is not implemented yet.
        # TODO <leslie>: Modify the pattern matcher count after we implement the qconv2d_add lowering.
        # In total: 10 pattern_matcher_count, 43 pattern_matcher_nodes
        # 1. Pair of to_int8 and to_fp32 at conv input * 2, extra input of add * 1, and graph output * 1
        #    matched in pointless_convert pass at
        #    torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
        # 2. Dequant pattern matcher for dequant promotion * 1
        #    [convert_element_type_3, sub_1, mul_3]
        # 3. Dequant-conv pattern matched in quantization weight prepack * 3
        #    [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
        # 4. Quantization fusion in post-grad fusion pass * 2
        #    [qconv2d_pointwise_default, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
        self._test_common(
            mod,
            (v,),
            10,
            43,
            check_quantization=True,
        )

    # https://github.com/pytorch/pytorch/issues/99841.
    def test_hardtanh_pattern_fallback(self):
        class Model(torch.nn.Module):
60 changes: 60 additions & 0 deletions torch/_inductor/fx_passes/quantization.py
@@ -1,3 +1,4 @@
import copy
import functools

import torch
@@ -189,6 +190,62 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
)


def _is_valid_dequant_promotion_pattern(match):
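    # The decomposed dequantize_per_tensor sequence is
    # convert_element_type (to fp32) -> sub (zero point) -> mul (scale);
    # walk it backwards from the match's output node.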
    mul_node = match.output_node()
    sub_node = mul_node.args[0]
    to_fp32_node = sub_node.args[0]
    if (
        mul_node.target is aten.mul.Tensor
        and sub_node.target is aten.sub.Tensor
        and to_fp32_node.target is prims.convert_element_type.default
        and len(list(mul_node.users)) > 1
    ):
        # The dequant pattern has more than 1 user, so it should be promoted.
        return True
    return False


def _register_dequant_promotion_pass(pattern, pass_number):
    @register_freezing_graph_pattern(
        pattern,
        extra_check=_is_valid_dequant_promotion_pattern,
        pass_number=pass_number,
    )
    def dequant_promotion(match: Match, *args, **kwargs):
        # If the dequant pattern is used by multiple nodes, promote it so that
        # each user node is connected to its own separate dequant pattern.
        def clone_to_new_node(graph, source_node, user_node):
            assert (
                source_node.op == "call_function"
            ), "clone_to_new_node only supports node.op call_function"
            with graph.inserting_before(user_node):
                new_node = graph.call_function(
                    source_node.target,
                    args=source_node.args,
                    kwargs=source_node.kwargs,
                )
                new_node.meta = copy.copy(source_node.meta)
                user_node.replace_input_with(source_node, new_node)
            return new_node

        mul_node = match.output_node()
        sub_node = mul_node.args[0]
        to_fp32_node = sub_node.args[0]
        assert mul_node.target is aten.mul.Tensor
        assert sub_node.target is aten.sub.Tensor
        assert to_fp32_node.target is prims.convert_element_type.default

        graph = match.graph
        user_node_list = list(mul_node.users)
        for user_node in user_node_list:
            # Step 1: Duplicate the mul node
            new_mul_node = clone_to_new_node(graph, mul_node, user_node)
            # Step 2: Duplicate the sub node
            new_sub_node = clone_to_new_node(graph, sub_node, new_mul_node)
            # Step 3: Duplicate the to_fp32 node
            _ = clone_to_new_node(graph, to_fp32_node, new_sub_node)


def _is_valid_dequant_conv2d_pattern(match):
    # Here we do some further check to ensure:
    # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now.
@@ -380,6 +437,9 @@ def _generate_qconv_weight_prepack_patterns():

@functools.lru_cache(None)
def _register_quantization_weight_pack_pass():
    _register_dequant_promotion_pass(
        dequantize_per_tensor_activation_pattern, pass_number=0
    )  # pass_number=0 to run before weight prepack
    weight_prepack_patterns = _generate_qconv_weight_prepack_patterns()
    for weight_prepack_pattern in weight_prepack_patterns:
        # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
