diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py
index 9c39f1c14001..756de35df84c 100644
--- a/test/inductor/test_mkldnn_pattern_matcher.py
+++ b/test/inductor/test_mkldnn_pattern_matcher.py
@@ -1688,13 +1688,10 @@ def matcher_check_fn():
             to_bf16_after_binary = 2 * (add_fn == add_fn_list[2] and fq_x2)
             self.assertEqual(
                 counters["inductor"]["qlinear_binary_matcher_nodes"],
-                (4 if is_dynamic else 5) + 2 * use_relu + to_bf16_after_binary,
+                5 + 2 * use_relu + to_bf16_after_binary,
             )
 
-        is_qat_list = [False, True]
-        is_dynamic_list = [False, True]
-        cases = itertools.product(is_qat_list, is_dynamic_list)
-        for is_qat, is_dynamic in cases:
+        for is_qat in [False, True]:
             self._test_common(
                 mod,
                 (v,),
@@ -1702,7 +1699,6 @@
                 check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
                 matcher_check_fn=matcher_check_fn,
                 is_qat=is_qat,
-                is_dynamic=is_dynamic,
             )
 
     @skipIfNoDynamoSupport
diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py
index 3202900a2862..218b30bd9e33 100644
--- a/test/quantization/pt2e/test_x86inductor_quantizer.py
+++ b/test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -1198,63 +1198,44 @@ def test_linear(self):
                     node_list,
                 )
 
-    def _test_linear_unary_helper(
-        self,
-        post_op_module,
-        post_op_aten,
-        post_op_aten_inplace,
-        post_op_algo_list=None,
-        is_qat=False,
-        is_dynamic=False,
-    ):
+    @skipIfNoX86
+    def test_linear_unary(self):
         """
         Test pattern of linear with unary post ops (e.g. relu) with X86InductorQuantizer.
         """
         use_bias_list = [True, False]
         inplace_list = [True, False]
-        if post_op_algo_list is None:
-            post_op_algo_list = [None]
-        cases = itertools.product(use_bias_list, inplace_list, post_op_algo_list)
+        postop_list = [nn.ReLU, nn.LeakyReLU]  # only test two to save time
+        cases = itertools.product(use_bias_list, inplace_list, postop_list)
+        post_op_map = {
+            nn.ReLU: [torch.ops.aten.relu_.default, torch.ops.aten.relu.default],
+            nn.LeakyReLU: [
+                torch.ops.aten.leaky_relu_.default,
+                torch.ops.aten.leaky_relu.default,
+            ],
+        }
         with override_quantized_engine("x86"), torch.no_grad():
-            for use_bias, inplace, post_op_algo in cases:
-                if inplace and post_op_aten_inplace is None:
-                    continue
+            for use_bias, inplace, postop in cases:
                 m = TestHelperModules.LinearUnaryModule(
-                    use_bias=use_bias,
-                    postop=post_op_module,
-                    inplace_postop=inplace,
-                    post_op_algo=post_op_algo,
+                    use_bias=use_bias, postop=postop, inplace_postop=inplace
                 ).eval()
                 example_inputs = (torch.randn(2, 4),)
                 quantizer = X86InductorQuantizer().set_global(
-                    xiq.get_default_x86_inductor_quantization_config(
-                        is_qat=is_qat,
-                        is_dynamic=is_dynamic,
-                    )
-                )
-                quantize_per_tensor_op = (
-                    torch.ops.quantized_decomposed.quantize_per_tensor.tensor
-                    if is_dynamic
-                    else torch.ops.quantized_decomposed.quantize_per_tensor.default
-                )
-                dequantize_per_tensor_op = (
-                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
-                    if is_dynamic
-                    else torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                    xiq.get_default_x86_inductor_quantization_config()
                 )
                 node_occurrence = {
-                    # one for input of the linear
-                    quantize_per_tensor_op: 1,
-                    dequantize_per_tensor_op: 1,
+                    # one for input and weight of the conv, one for output for the relu
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
                     # quantize_per_channel for weights are const propagated
                     torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                     torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
                 }
                 node_list = [
-                    quantize_per_tensor_op,
-                    dequantize_per_tensor_op,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                     torch.ops.aten.linear.default,
-                    post_op_aten_inplace if inplace else post_op_aten,
+                    post_op_map[postop][0 if inplace else 1],
                 ]
                 self._test_quantizer(
                     m,
@@ -1262,70 +1243,47 @@ def _test_linear_unary_helper(
                     quantizer,
                     node_occurrence,
                     node_list,
-                    is_qat=is_qat,
                 )
 
     @skipIfNoX86
-    def test_linear_unary(self):
-        aten = torch.ops.aten
-        self._test_linear_unary_helper(nn.ReLU, aten.relu.default, aten.relu_.default)
-        self._test_linear_unary_helper(
-            nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default
-        )
-        self._test_linear_unary_helper(
-            nn.GELU, aten.gelu.default, None, ["none", "tanh"]
-        )
-
-    @skipIfNoX86
-    def test_linear_unary_qat(self):
-        aten = torch.ops.aten
-        self._test_linear_unary_helper(
-            nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True
-        )
-        self._test_linear_unary_helper(
-            nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default, is_qat=True
-        )
-        self._test_linear_unary_helper(
-            nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_qat=True
-        )
-
-    @skipIfNoX86
-    def test_linear_unary_dynamic(self):
-        aten = torch.ops.aten
-        self._test_linear_unary_helper(
-            nn.ReLU, aten.relu.default, aten.relu_.default, is_dynamic=True
-        )
-        self._test_linear_unary_helper(
-            nn.LeakyReLU,
-            aten.leaky_relu.default,
-            aten.leaky_relu_.default,
-            is_dynamic=True,
-        )
-        self._test_linear_unary_helper(
-            nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_dynamic=True
-        )
-
-    @skipIfNoX86
-    def test_linear_unary_dynamic_qat(self):
-        aten = torch.ops.aten
-        self._test_linear_unary_helper(
-            nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True, is_dynamic=True
-        )
-        self._test_linear_unary_helper(
-            nn.LeakyReLU,
-            aten.leaky_relu.default,
-            aten.leaky_relu_.default,
-            is_qat=True,
-            is_dynamic=True,
-        )
-        self._test_linear_unary_helper(
-            nn.GELU,
-            aten.gelu.default,
-            None,
-            ["none", "tanh"],
-            is_qat=True,
-            is_dynamic=True,
-        )
+    def test_linear_unary_gelu(self):
+        """
+        Test pattern of linear with unary post ops (e.g. gelu) with X86InductorQuantizer.
+        """
+        use_bias_list = [True, False]
+        postop = nn.GELU
+        post_op_algorithm = ["none", "tanh"]
+        cases = itertools.product(use_bias_list, post_op_algorithm)
+        with override_quantized_engine("x86"), torch.no_grad():
+            for use_bias, post_op_algo in cases:
+                m = TestHelperModules.LinearUnaryModule(
+                    use_bias=use_bias, postop=postop, post_op_algo=post_op_algo
+                ).eval()
+                example_inputs = (torch.randn(2, 4),)
+                quantizer = X86InductorQuantizer().set_global(
+                    xiq.get_default_x86_inductor_quantization_config()
+                )
+                node_occurrence = {
+                    # one for input and weight of the conv, one for output for the gelu
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
+                    # quantize_per_channel for weights are const propagated
+                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+                }
+                node_list = [
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    torch.ops.aten.linear.default,
+                    torch.ops.aten.gelu.default,
+                ]
+                self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                )
 
     def _check_annotation_stat(self, gm, expected_stat_dict):
         # Check expected annotation statistics to ensure the annotation is correct
@@ -1344,7 +1302,8 @@ def _check_annotation(node):
         for op_stat in expected_stat_dict.values():
             assert all(v == 0 for v in op_stat.values())
 
-    def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False):
+    @skipIfNoX86
+    def test_linear_binary(self):
         """
         Test pattern of linear with binary post ops (such as add) with X86InductorQuantizer.
         Currently, only add as binary post op is supported.
@@ -1354,20 +1313,7 @@ def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False):
         inplace_add_list = [False]
         example_inputs = (torch.randn(2, 16),)
         quantizer = X86InductorQuantizer().set_global(
-            xiq.get_default_x86_inductor_quantization_config(
-                is_qat=is_qat,
-                is_dynamic=is_dynamic,
-            )
-        )
-        quantize_per_tensor_op = (
-            torch.ops.quantized_decomposed.quantize_per_tensor.tensor
-            if is_dynamic
-            else torch.ops.quantized_decomposed.quantize_per_tensor.default
-        )
-        dequantize_per_tensor_op = (
-            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
-            if is_dynamic
-            else torch.ops.quantized_decomposed.dequantize_per_tensor.default
+            xiq.get_default_x86_inductor_quantization_config()
         )
         cases = itertools.product(linear_pos_list, inplace_add_list)
         with override_quantized_engine("x86"), torch.no_grad():
@@ -1379,28 +1325,26 @@ def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False):
                     node_occurrence = {
                         # Only one 1 q-dq for input of the linear
                        # No q-dq for extra input node of add
-                        quantize_per_tensor_op: 1,
-                        dequantize_per_tensor_op: 1,
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
                         # quantize_per_channel for weights are const propagated
                         torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                         torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
                     }
                 else:
-                    # convert_pt2e disables duplicate dequant for dynamic quant
-                    num_dequant = 1 if is_dynamic else 2
                     node_occurrence = {
                         # One quantize_per_tensor for both linear nodes (shared)
                        # Two dequantize_per_tensor for two linear nodes
                        # No q-dq for extra input node of add
-                        quantize_per_tensor_op: 1,
-                        dequantize_per_tensor_op: num_dequant,
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
                         # quantize_per_channel for weights are const propagated
                         torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                         torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
                     }
                 node_list = [
-                    quantize_per_tensor_op,
-                    dequantize_per_tensor_op,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                     torch.ops.aten.linear.default,
                     torch.ops.aten.add_.Tensor
                     if inplace_add
@@ -1412,7 +1356,6 @@ def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False):
                    quantizer,
                    node_occurrence,
                    node_list,
-                    is_qat=is_qat,
                )[-1]
                # One linear and add are fused. The other linear is quantized alone if present
                aten = torch.ops.aten
@@ -1426,22 +1369,6 @@ def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False):
                }
                self._check_annotation_stat(fq_m, expected_annotation_stat)
 
-    @skipIfNoX86
-    def test_linear_binary(self):
-        self._test_linear_binary_helper()
-
-    @skipIfNoX86
-    def test_linear_binary_qat(self):
-        self._test_linear_binary_helper(is_qat=True)
-
-    @skipIfNoX86
-    def test_linear_binary_dynamic(self):
-        self._test_linear_binary_helper(is_dynamic=True)
-
-    @skipIfNoX86
-    def test_linear_binary_dynamic_qat(self):
-        self._test_linear_binary_helper(is_qat=True, is_dynamic=True)
-
     @skipIfNoX86
     def test_linear_binary2(self):
         """
@@ -1452,43 +1379,28 @@ def test_linear_binary2(self):
         Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1
         """
         example_inputs = (torch.randn(2, 16),)
+        quantizer = X86InductorQuantizer().set_global(
+            xiq.get_default_x86_inductor_quantization_config()
+        )
         # TODO test for inplace add after refactoring of capture_pre_autograd_graph
         inplace_add_list = [False]
-        is_qat_list = [False, True]
-        is_dynamic_list = [False, True]
-        cases = itertools.product(inplace_add_list, is_qat_list, is_dynamic_list)
         with override_quantized_engine("x86"), torch.no_grad():
-            for inplace_add, is_qat, is_dynamic in cases:
-                quantizer = X86InductorQuantizer().set_global(
-                    xiq.get_default_x86_inductor_quantization_config(
-                        is_qat=is_qat, is_dynamic=is_dynamic
-                    )
-                )
+            for inplace_add in inplace_add_list:
                 m = TestHelperModules.LinearAddModule2(inplace_add=inplace_add).eval()
-                quantize_per_tensor_op = (
-                    torch.ops.quantized_decomposed.quantize_per_tensor.tensor
-                    if is_dynamic
-                    else torch.ops.quantized_decomposed.quantize_per_tensor.default
-                )
-                dequantize_per_tensor_op = (
-                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
-                    if is_dynamic
-                    else torch.ops.quantized_decomposed.dequantize_per_tensor.default
-                )
                 # Two q-dq nodes for inputs of linear nodes
                 # No q-dq for extra input node of add
                 node_occurrence = {
-                    quantize_per_tensor_op: 2,
-                    dequantize_per_tensor_op: 2,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
                     # quantize_per_channel for weights are const propagated
                     torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                     torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
                 }
                 node_list = [
-                    quantize_per_tensor_op,
-                    dequantize_per_tensor_op,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                     torch.ops.aten.linear.default,
-                    torch.ops.quantized_decomposed.dequantize_per_channel.default,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
                     torch.ops.aten.add_.Tensor
                     if inplace_add
                     else torch.ops.aten.add.Tensor,
@@ -1513,7 +1425,7 @@ def test_linear_binary2(self):
                self._check_annotation_stat(fq_m, expected_annotation_stat)
 
     @skipIfNoX86
-    def _test_linear_binary_unary_helper(self, is_qat=False, is_dynamic=False):
+    def test_linear_binary_unary(self):
         """
         Test pattern of linear with binary + unary post ops (such as add + relu) with X86InductorQuantizer.
         Currently, only add as binary post op and relu as unary post op are supported.
@@ -1525,20 +1437,7 @@ def _test_linear_binary_unary_helper(self, is_qat=False, is_dynamic=False):
         inplace_relu_list = [False]
         example_inputs = (torch.randn(2, 16),)
         quantizer = X86InductorQuantizer().set_global(
-            xiq.get_default_x86_inductor_quantization_config(
-                is_qat=is_qat,
-                is_dynamic=is_dynamic,
-            )
-        )
-        quantize_per_tensor_op = (
-            torch.ops.quantized_decomposed.quantize_per_tensor.tensor
-            if is_dynamic
-            else torch.ops.quantized_decomposed.quantize_per_tensor.default
-        )
-        dequantize_per_tensor_op = (
-            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
-            if is_dynamic
-            else torch.ops.quantized_decomposed.dequantize_per_tensor.default
+            xiq.get_default_x86_inductor_quantization_config()
         )
         cases = itertools.product(linear_pos_list, inplace_add_list, inplace_relu_list)
         with override_quantized_engine("x86"), torch.no_grad():
@@ -1552,28 +1451,26 @@ def _test_linear_binary_unary_helper(self, is_qat=False, is_dynamic=False):
                    node_occurrence = {
                        # Only one q-dq node for input of the linear
                        # No q-dq node for extra input node of add
-                        quantize_per_tensor_op: 1,
-                        dequantize_per_tensor_op: 1,
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
                        # note: quantize op for weights are const propagated
                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
                    }
                else:
-                    # convert_pt2e disables duplicate dequant for dynamic quant
-                    num_dequant = 1 if is_dynamic else 2
                    node_occurrence = {
                        # One quantize_per_tensor for both linear nodes (shared)
                        # Two dequantize_per_tensor for two linear nodes
                        # No q-dq for extra input node of add
-                        quantize_per_tensor_op: 1,
-                        dequantize_per_tensor_op: num_dequant,
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
                        # note: quantize op for weights are const propagated
                        torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
                        torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
                    }
                node_list = [
-                    quantize_per_tensor_op,
-                    dequantize_per_tensor_op,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                    torch.ops.aten.linear.default,
                    torch.ops.aten.add_.Tensor
                    if inplace_add
@@ -1601,91 +1498,57 @@ def _test_linear_binary_unary_helper(self, is_qat=False, is_dynamic=False):
                }
                self._check_annotation_stat(fq_m, expected_annotation_stat)
 
-    @skipIfNoX86
-    def test_linear_binary_unary(self):
-        self._test_linear_binary_unary_helper()
-
-    @skipIfNoX86
-    def test_linear_binary_unary_qat(self):
-        self._test_linear_binary_unary_helper(is_qat=True)
-
-    @skipIfNoX86
-    def test_linear_binary_unary_dynamic(self):
-        self._test_linear_binary_unary_helper(is_dynamic=True)
-
-    @skipIfNoX86
-    def test_linear_binary_unary_dynamic_qat(self):
-        self._test_linear_binary_unary_helper(is_qat=True, is_dynamic=True)
-
     @skipIfNoX86
     def test_linear_binary_unary_serials(self):
         """
         Test pattern of 2 following up linear add relu with X86InductorQuantizer.
         """
-        is_qat_list = [False, True]
-        is_dynamic_list = [False, True]
-        cases = itertools.product(is_qat_list, is_dynamic_list)
         with override_quantized_engine("x86"), torch.no_grad():
-            for is_qat, is_dynamic in cases:
-                m = TestHelperModules.SerialsLinearAddReLUModule().eval()
-                example_inputs = (torch.randn(2, 16),)
-                quantizer = X86InductorQuantizer().set_global(
-                    xiq.get_default_x86_inductor_quantization_config(
-                        is_qat=is_qat,
-                        is_dynamic=is_dynamic,
-                    )
-                )
-                quantize_per_tensor_op = (
-                    torch.ops.quantized_decomposed.quantize_per_tensor.tensor
-                    if is_dynamic
-                    else torch.ops.quantized_decomposed.quantize_per_tensor.default
-                )
-                dequantize_per_tensor_op = (
-                    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
-                    if is_dynamic
-                    else torch.ops.quantized_decomposed.dequantize_per_tensor.default
-                )
-                # convert_pt2e disables duplicate dequant for dynamic quant
-                num_dequant = 3 if is_dynamic else 4
-                node_occurrence = {
-                    # quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4
-                    # dequantize_per_tensor: 1 for each linear
-                    # No q-dq for extra input node of add
-                    quantize_per_tensor_op: 3,
-                    dequantize_per_tensor_op: num_dequant,
-                    # quantize_per_channel for weights are const propagated
-                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
-                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
-                }
-                node_list = [
-                    quantize_per_tensor_op,
-                    dequantize_per_tensor_op,
-                    torch.ops.aten.linear.default,
-                    torch.ops.quantized_decomposed.dequantize_per_channel.default,
-                    torch.ops.aten.linear.default,
-                    torch.ops.aten.linear.default,
-                    torch.ops.aten.add.Tensor,
-                    torch.ops.aten.relu.default,
-                ]
-                fq_m = self._test_quantizer(
-                    m,
-                    example_inputs,
-                    quantizer,
-                    node_occurrence,
-                    node_list,
-                )[-1]
-                # Two linear nodes are quantized alone
-                # The other two are fused with add and relu
-                aten = torch.ops.aten
-                expected_annotation_stat = {
-                    aten.linear.default: {
-                        "annotated": 4,
-                        "is_quant_out": 2,
-                    },
-                    aten.add.Tensor: {"annotated": 2, "is_quant_out": 0},
-                    aten.relu.default: {"annotated": 2, "is_quant_out": 2},
-                }
-                self._check_annotation_stat(fq_m, expected_annotation_stat)
+            m = TestHelperModules.SerialsLinearAddReLUModule().eval()
+            example_inputs = (torch.randn(2, 16),)
+            quantizer = X86InductorQuantizer().set_global(
+                xiq.get_default_x86_inductor_quantization_config()
+            )
+            node_occurrence = {
+                # quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4
+                # dequantize_per_tensor: 1 for each linear
+                # No q-dq for extra input node of add
+                torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
+                # quantize_per_channel for weights are const propagated
+                torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                torch.ops.quantized_decomposed.dequantize_per_channel.default: 4,
+            }
+            node_list = [
+                torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.aten.linear.default,
+                torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                torch.ops.aten.linear.default,
+                torch.ops.aten.linear.default,
+                torch.ops.aten.add.Tensor,
+                torch.ops.aten.relu.default,
+            ]
+            fq_m = self._test_quantizer(
+                m,
+                example_inputs,
+                quantizer,
+                node_occurrence,
+                node_list,
+            )[-1]
+            # Two linear nodes are quantized alone
+            # The other two are fused with add and relu
+            aten = torch.ops.aten
+            expected_annotation_stat = {
+                aten.linear.default: {
+                    "annotated": 4,
+                    "is_quant_out": 2,
+                },
+                aten.add.Tensor: {"annotated": 2, "is_quant_out": 0},
+                aten.relu.default: {"annotated": 2, "is_quant_out": 2},
+            }
+            self._check_annotation_stat(fq_m, expected_annotation_stat)
 
     @skipIfTorchDynamo("very slow")
     @skipIfNoX86
diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
index ecb9a14c0a4c..4cc05e46c6a7 100644
--- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
+++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
@@ -776,8 +776,10 @@ def _annotate_conv2d_fusion_pattern(self, model: torch.fx.GraphModule):
 
     def _annotate_linear_fusion_pattern(self, model: torch.fx.GraphModule):
         if config := self._get_aten_operator_qconfig(torch.ops.aten.linear.default):
-            self._annotate_linear_binary_unary(model, config)
-            self._annotate_linear_unary(model, config)
+            if config.input_activation and not config.input_activation.is_dynamic:
+                # Weiwen: Dynamic Quant of linear unary will be supported in next step
+                self._annotate_linear_binary_unary(model, config)
+                self._annotate_linear_unary(model, config)
             self._annotate_linear(model, config)
 
     def _annotate_matmul(self, model: torch.fx.GraphModule):
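A minimal sketch (not part of the patch) of the behavior the x86_inductor_quantizer.py hunk introduces: linear fusion patterns (linear + unary, linear + binary, linear + binary + unary) are now annotated only when the input activation is statically quantized, while plain linear is annotated in both modes. The sketch assumes the is_dynamic flag on get_default_x86_inductor_quantization_config, which the removed test code above also used; it is illustrative only.

import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq

for is_dynamic in (False, True):
    # Assumption: the is_dynamic flag is still accepted, as in the removed tests.
    config = xiq.get_default_x86_inductor_quantization_config(is_dynamic=is_dynamic)
    spec = config.input_activation
    # Mirrors the new gate in _annotate_linear_fusion_pattern:
    annotate_fusions = bool(spec and not spec.is_dynamic)
    print(f"is_dynamic={is_dynamic} -> annotate linear fusion patterns: {annotate_fusions}")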