From d558f55db9ae9966ad5c86a5c412776c1826efb7 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Tue, 28 Nov 2023 08:46:20 +0800 Subject: [PATCH] Enable the Inductor Lowering of QConv2d post op hardtanh ghstack-source-id: ecc8e9434ffd69a36440b0cb9d5a4f72f766adc5 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114580 --- test/inductor/test_mkldnn_pattern_matcher.py | 91 +++++++++++++++----- torch/_inductor/fx_passes/quantization.py | 29 ++++++- 2 files changed, 95 insertions(+), 25 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index f2395c319c3bf..e7d986900ed2f 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] import contextlib +import copy import itertools import unittest @@ -495,7 +496,11 @@ def test_qconv2d_int8_mixed_bf16(self): """ self._qconv2d_cpu_test_helper(int8_mixed_bf16=True) - def _qconv2d_unary_cpu_test_helper(self, int8_mixed_bf16=False): + def _qconv2d_unary_cpu_test_helper( + self, + int8_mixed_bf16=False, + unary_op=torch.nn.ReLU(), + ): class M(torch.nn.Module): def __init__( self, @@ -503,9 +508,9 @@ def __init__( ): super().__init__() self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1) - self.unary_fn = torch.nn.ReLU() + self.unary_fn = copy.deepcopy(unary_op) self.conv2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1) - self.unary_fn2 = torch.nn.ReLU() + self.unary_fn2 = copy.deepcopy(unary_op) def forward(self, x): tmp = self.unary_fn(self.conv(x)) @@ -549,6 +554,24 @@ def test_qconv2d_relu_int8_mixed_bf16(self): """ self._qconv2d_unary_cpu_test_helper(int8_mixed_bf16=True) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qconv2d_relu6_cpu(self): + r""" + This testcase will quantize Conv2d->ReLU6 pattern. + """ + self._qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6()) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qconv2d_hardtanh_cpu(self): + r""" + This testcase will quantize Conv2d->Hardtanh pattern. + """ + self._qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh()) + def _qconv2d_add_cpu_test_helper(self, use_relu=False, int8_mixed_bf16=False): r""" This testcase will quantize a Conv2d->Add pattern as: @@ -735,26 +758,26 @@ def matcher_check_fn(): matcher_check_fn=matcher_check_fn, ) - @skipIfNoDynamoSupport - @skipIfNoONEDNN - @skipIfRocm - def test_qat_qconv2d_relu(self): - r""" - This testcase will quantize Conv2d->ReLU pattern with qat flow. - """ - + def _qat_qconv2d_unary_cpu_test_helper( + self, + unary_op=torch.nn.ReLU(), + ): class M(torch.nn.Module): def __init__( self, **kwargs, ): super().__init__() - self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1) - self.unary_fn = torch.nn.ReLU() - self.bn = torch.nn.BatchNorm2d(128) + self.conv = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1) + self.unary_fn = copy.deepcopy(unary_op) + self.bn = torch.nn.BatchNorm2d(3) + self.conv2 = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1) + self.unary_fn2 = copy.deepcopy(unary_op) + self.bn2 = torch.nn.BatchNorm2d(3) def forward(self, x): - return self.unary_fn(self.bn(self.conv(x))) + tmp = self.unary_fn(self.bn(self.conv(x))) + return self.unary_fn2(self.bn2(self.conv2(tmp))) mod = M() v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1) @@ -763,15 +786,11 @@ def matcher_check_fn(): # 1. Dequant-conv pattern matched in quantization weight prepack * 1 # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution] self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1 - ) - self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 6 + counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 ) # 2. QConv2D Unary fusion in post-grad fusion pass * 1 # [qconv2d_pointwise_default, relu, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2] - self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1) - self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_nodes"], 8) + self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 2) self._test_common( mod, @@ -781,6 +800,36 @@ def matcher_check_fn(): matcher_check_fn=matcher_check_fn, ) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qat_qconv2d_relu(self): + r""" + This testcase will quantize Conv2d->ReLU pattern with qat flow. + """ + + self._qat_qconv2d_unary_cpu_test_helper() + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qat_qconv2d_relu6(self): + r""" + This testcase will quantize Conv2d->ReLU6 pattern with qat flow. + """ + + self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6()) + + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfRocm + def test_qat_qconv2d_hardtanh(self): + r""" + This testcase will quantize Conv2d->Hardtanh pattern with qat flow. + """ + + self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh()) + @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index ad91483049580..7273f36ec226e 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -169,10 +169,17 @@ def generate_pattern_with_binary( def generate_pattern_with_unary(computation_call, unary_post_op): if unary_post_op is not None: - return CallFunction( - unary_post_op, - computation_call, - ) + if unary_post_op == aten.hardtanh.default: + return CallFunction( + aten.clamp_max, + CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")), + KeywordArg("max_value"), + ) + else: + return CallFunction( + unary_post_op, + computation_call, + ) return computation_call @@ -286,6 +293,11 @@ def qconv(match: Match, *args, **kwargs): assert ( kwargs["attr"] == "none" ) # Expected no post op fused in weight prepack phase + if unary_attr.op_name == "hardtanh": + min_value = kwargs.get("min_value") + max_value = kwargs.get("max_value") + unary_attr.scalars_attr = [min_value, max_value] + computation_args = ( x, x_scale, @@ -506,6 +518,12 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): ), dtype=original_pattern_output_dtype, ), + UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant( + generate_pattern_with_unary( + dequantize_qconv_pt2e_pattern, aten.hardtanh.default + ), + dtype=original_pattern_output_dtype, + ), } for unary_attr, patterns in conv_unary_replace_patterns.items(): @@ -524,6 +542,9 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): UnaryAttr("relu", [], ""): generate_pattern_with_unary( dequantize_qconv_pt2e_pattern, aten.relu.default ), + UnaryAttr("hardtanh", [], ""): generate_pattern_with_unary( + dequantize_qconv_pt2e_pattern, aten.hardtanh.default + ), } for unary_attr, patterns in conv_unary_replace_float_out_patterns.items():