Commit d558f55

Enable the Inductor Lowering of QConv2d post op hardtanh

ghstack-source-id: ecc8e9434ffd69a36440b0cb9d5a4f72f766adc5
Pull Request resolved: #114580

leslie-fang-intel committed Nov 28, 2023
1 parent 3cfd908 commit d558f55
Showing 2 changed files with 95 additions and 25 deletions.
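
With this commit, Inductor can fuse a hardtanh (and, by extension, ReLU6) post op into the oneDNN quantized-convolution kernel. Below is a minimal sketch of the flow the new lowering targets; the PT2E quantization API names follow this commit's era (late-2023 development of PyTorch 2.2) and may differ in other releases, and the shapes, calibration, and quantizer config are illustrative assumptions, not part of this PR.

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1)
        self.hardtanh = torch.nn.Hardtanh()  # the post op this PR fuses

    def forward(self, x):
        return self.hardtanh(self.conv(x))

mod = M().eval()
x = torch.randn(1, 3, 8, 8)

# PT2E post-training quantization with the X86 Inductor quantizer.
exported = capture_pre_autograd_graph(mod, (x,))
quantizer = xiq.X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
prepared = prepare_pt2e(exported, quantizer)
prepared(x)  # calibration pass
converted = convert_pt2e(prepared)

# The qconv + hardtanh fusion happens in the post-grad pass; the tests in
# this PR run with Inductor freezing enabled, which these passes assume.
torch._inductor.config.freezing = True
with torch.no_grad():
    compiled = torch.compile(converted)
    compiled(x)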
91 changes: 70 additions & 21 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -1,5 +1,6 @@
 # Owner(s): ["module: inductor"]
 import contextlib
+import copy
 import itertools
 import unittest
 
@@ -495,17 +496,21 @@ def test_qconv2d_int8_mixed_bf16(self):
         """
         self._qconv2d_cpu_test_helper(int8_mixed_bf16=True)
 
-    def _qconv2d_unary_cpu_test_helper(self, int8_mixed_bf16=False):
+    def _qconv2d_unary_cpu_test_helper(
+        self,
+        int8_mixed_bf16=False,
+        unary_op=torch.nn.ReLU(),
+    ):
         class M(torch.nn.Module):
             def __init__(
                 self,
                 **kwargs,
             ):
                 super().__init__()
                 self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1)
-                self.unary_fn = torch.nn.ReLU()
+                self.unary_fn = copy.deepcopy(unary_op)
                 self.conv2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1)
-                self.unary_fn2 = torch.nn.ReLU()
+                self.unary_fn2 = copy.deepcopy(unary_op)
 
             def forward(self, x):
                 tmp = self.unary_fn(self.conv(x))
@@ -549,6 +554,24 @@ def test_qconv2d_relu_int8_mixed_bf16(self):
         """
         self._qconv2d_unary_cpu_test_helper(int8_mixed_bf16=True)
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_qconv2d_relu6_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->ReLU6 pattern.
+        """
+        self._qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6())
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_qconv2d_hardtanh_cpu(self):
+        r"""
+        This testcase will quantize Conv2d->Hardtanh pattern.
+        """
+        self._qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh())
+
     def _qconv2d_add_cpu_test_helper(self, use_relu=False, int8_mixed_bf16=False):
         r"""
         This testcase will quantize a Conv2d->Add pattern as:
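
Why the new ReLU6 test exercises the same lowering path as the Hardtanh one: ReLU6 is exactly Hardtanh with bounds (0, 6), so after export both decompose to the same clamp pair that the new pattern matches. A quick illustrative check:

import torch

x = torch.randn(16) * 10
assert torch.equal(torch.nn.ReLU6()(x), torch.nn.Hardtanh(0.0, 6.0)(x))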
@@ -735,26 +758,26 @@ def matcher_check_fn():
             matcher_check_fn=matcher_check_fn,
         )
 
-    @skipIfNoDynamoSupport
-    @skipIfNoONEDNN
-    @skipIfRocm
-    def test_qat_qconv2d_relu(self):
-        r"""
-        This testcase will quantize Conv2d->ReLU pattern with qat flow.
-        """
-
+    def _qat_qconv2d_unary_cpu_test_helper(
+        self,
+        unary_op=torch.nn.ReLU(),
+    ):
         class M(torch.nn.Module):
             def __init__(
                 self,
                 **kwargs,
             ):
                 super().__init__()
-                self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1)
-                self.unary_fn = torch.nn.ReLU()
-                self.bn = torch.nn.BatchNorm2d(128)
+                self.conv = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1)
+                self.unary_fn = copy.deepcopy(unary_op)
+                self.bn = torch.nn.BatchNorm2d(3)
+                self.conv2 = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1)
+                self.unary_fn2 = copy.deepcopy(unary_op)
+                self.bn2 = torch.nn.BatchNorm2d(3)
 
             def forward(self, x):
-                return self.unary_fn(self.bn(self.conv(x)))
+                tmp = self.unary_fn(self.bn(self.conv(x)))
+                return self.unary_fn2(self.bn2(self.conv2(tmp)))
 
         mod = M()
         v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)
@@ -763,15 +786,11 @@ def matcher_check_fn():
             # 1. Dequant-conv pattern matched in quantization weight prepack * 1
             # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
             self.assertEqual(
-                counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1
-            )
-            self.assertEqual(
-                counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 6
+                counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
             )
             # 2. QConv2D Unary fusion in post-grad fusion pass * 1
             # [qconv2d_pointwise_default, relu, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
-            self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1)
-            self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_nodes"], 8)
+            self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 2)
 
         self._test_common(
             mod,
@@ -781,6 +800,36 @@ def matcher_check_fn():
             matcher_check_fn=matcher_check_fn,
         )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_qat_qconv2d_relu(self):
+        r"""
+        This testcase will quantize Conv2d->ReLU pattern with qat flow.
+        """
+
+        self._qat_qconv2d_unary_cpu_test_helper()
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_qat_qconv2d_relu6(self):
+        r"""
+        This testcase will quantize Conv2d->ReLU6 pattern with qat flow.
+        """
+
+        self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6())
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_qat_qconv2d_hardtanh(self):
+        r"""
+        This testcase will quantize Conv2d->Hardtanh pattern with qat flow.
+        """
+
+        self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh())
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
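
The three QAT tests use the same helper shape but go through the QAT variant of the PT2E flow, which also folds BatchNorm into the convolution before lowering. A minimal sketch under the same era's API (prepare_qat_pt2e in place of prepare_pt2e; the is_qat flag on the quantizer config is an assumption of that version):

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

mod = torch.nn.Sequential(
    torch.nn.Conv2d(3, 3, kernel_size=3, stride=1),
    torch.nn.BatchNorm2d(3),
    torch.nn.Hardtanh(),
)
x = torch.randn(1, 3, 8, 8)

exported = capture_pre_autograd_graph(mod, (x,))
quantizer = xiq.X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_qat=True))
prepared = prepare_qat_pt2e(exported, quantizer)
prepared(x)  # a real flow would run training steps here
converted = convert_pt2e(prepared)  # conv+bn are folded; ready for Inductor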
29 changes: 25 additions & 4 deletions torch/_inductor/fx_passes/quantization.py
@@ -169,10 +169,17 @@ def generate_pattern_with_binary(
 
 def generate_pattern_with_unary(computation_call, unary_post_op):
     if unary_post_op is not None:
-        return CallFunction(
-            unary_post_op,
-            computation_call,
-        )
+        if unary_post_op == aten.hardtanh.default:
+            return CallFunction(
+                aten.clamp_max,
+                CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")),
+                KeywordArg("max_value"),
+            )
+        else:
+            return CallFunction(
+                unary_post_op,
+                computation_call,
+            )
     return computation_call
 
 
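The reason hardtanh is matched as clamp_min followed by clamp_max rather than as a single aten.hardtanh node: by the time these patterns run, aten.hardtanh has been decomposed into the two clamps, so the pattern must mirror the decomposed form and capture the bounds via KeywordArg. The underlying equivalence is easy to check (illustrative):

import torch

x = torch.randn(8)
min_value, max_value = -1.0, 1.0
decomposed = torch.clamp_max(torch.clamp_min(x, min_value), max_value)
assert torch.equal(
    decomposed, torch.nn.functional.hardtanh(x, min_value, max_value)
)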
@@ -286,6 +293,11 @@ def qconv(match: Match, *args, **kwargs):
         assert (
             kwargs["attr"] == "none"
         )  # Expected no post op fused in weight prepack phase
+        if unary_attr.op_name == "hardtanh":
+            min_value = kwargs.get("min_value")
+            max_value = kwargs.get("max_value")
+            unary_attr.scalars_attr = [min_value, max_value]
+
         computation_args = (
             x,
             x_scale,
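
What the hunk above does with the captured bounds: the KeywordArg bindings from the pattern arrive in the match kwargs, and for hardtanh they are stashed on the UnaryAttr as scalars, which later become the post-op scalars handed to the fused oneDNN qconv kernel. A standalone sketch of that bookkeeping; the UnaryAttr below is a hypothetical stand-in for the helper class this file defines:

from dataclasses import dataclass, field

@dataclass
class UnaryAttr:  # hypothetical stand-in for the helper in quantization.py
    op_name: str
    scalars_attr: list = field(default_factory=list)
    algorithm_attr: str = ""

# kwargs as the pattern matcher would bind them for hardtanh with bounds (-1, 1)
kwargs = {"min_value": -1.0, "max_value": 1.0}
unary_attr = UnaryAttr("hardtanh")
if unary_attr.op_name == "hardtanh":
    unary_attr.scalars_attr = [kwargs.get("min_value"), kwargs.get("max_value")]
print(unary_attr.scalars_attr)  # [-1.0, 1.0]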
@@ -506,6 +518,12 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
                 ),
                 dtype=original_pattern_output_dtype,
             ),
+            UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant(
+                generate_pattern_with_unary(
+                    dequantize_qconv_pt2e_pattern, aten.hardtanh.default
+                ),
+                dtype=original_pattern_output_dtype,
+            ),
         }
 
         for unary_attr, patterns in conv_unary_replace_patterns.items():
Expand All @@ -524,6 +542,9 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
UnaryAttr("relu", [], ""): generate_pattern_with_unary(
dequantize_qconv_pt2e_pattern, aten.relu.default
),
UnaryAttr("hardtanh", [], ""): generate_pattern_with_unary(
dequantize_qconv_pt2e_pattern, aten.hardtanh.default
),
}

for unary_attr, patterns in conv_unary_replace_float_out_patterns.items():
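
Note the two registrations: conv_unary_replace_patterns wraps the unary pattern in generate_pattern_with_output_quant, so it matches chains whose result is immediately re-quantized (the quantize node is fused too, producing int8 output), while conv_unary_replace_float_out_patterns matches the same conv+hardtanh fusion when the output stays in fp32/bf16. Conceptually, the two matched shapes differ only in the trailing quantize; an illustrative sketch with made-up scale/zero_point:

import torch

x = torch.randn(1, 3, 8, 8)

# Float-output flavor: qconv -> hardtanh, result stays floating point.
float_out = torch.nn.functional.hardtanh(x, -1.0, 1.0)

# Quantized-output flavor: the trailing quantize is folded into the fused
# kernel as well.
int8_out = torch.quantize_per_tensor(float_out, scale=0.01, zero_point=128, dtype=torch.quint8)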
