diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py
index cbf9dd89c506..44416eb5063d 100644
--- a/test/inductor/test_mkldnn_pattern_matcher.py
+++ b/test/inductor/test_mkldnn_pattern_matcher.py
@@ -1385,18 +1385,6 @@ def test_dynamic_qlinear_qat_cpu(self):
                 (torch.randn((2, 4)),), bias=bias, is_dynamic=True, is_qat=True
             )
 
-    @skipIfNoDynamoSupport
-    @skipIfNoONEDNN
-    @skipIfRocm
-    def test_dynamic_qlinear_input_dim_exceeds_2(self):
-        r"""
-        This testcase will quantize a single Linear Moduel.
-        """
-        for bias in [True, False]:
-            self._qlinear_cpu_test_helper(
-                (torch.randn((2, 3, 4)),), bias=bias, is_dynamic=True
-            )
-
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
@@ -1589,13 +1577,7 @@ def test_qlinear_gelu_int8_mixed_bf16(self):
                 (torch.randn((2, 4)),), gelu, int8_mixed_bf16=True
             )
 
-    def _qlinear_dequant_promotion_cpu_test_helper(
-        self,
-        inputs,
-        int8_mixed_bf16=False,
-        is_dynamic=False,
-        matcher_check_fn=None,
-    ):
+    def _qlinear_dequant_promotion_cpu_test_helper(self, inputs, int8_mixed_bf16=False):
         class M(torch.nn.Module):
             def __init__(
                 self,
@@ -1613,7 +1595,7 @@ def forward(self, x):
 
         mod = M().eval()
 
-        def default_matcher_check_fn():
+        def matcher_check_fn():
             # 1. Dequant pattern matcher for dequant promotion * 1
             self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
             # 2. dequant-linear pattern matched in quantization weight prepack * 3
@@ -1628,10 +1610,7 @@ def default_matcher_check_fn():
             inputs,
             check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
             check_quantization=True,
-            matcher_check_fn=matcher_check_fn
-            if matcher_check_fn is not None
-            else default_matcher_check_fn,
-            is_dynamic=is_dynamic,
+            matcher_check_fn=matcher_check_fn,
         )
 
     @skipIfNoDynamoSupport
@@ -1714,37 +1693,6 @@ def test_qlinear_dequant_promotion_int8_mixed_bf16_input_dim_exceeds_2(self):
             (torch.randn((2, 3, 4)),), int8_mixed_bf16=True
         )
 
-    @skipIfNoDynamoSupport
-    @skipIfNoONEDNN
-    @skipIfRocm
-    def test_qlinear_dequant_promotion_dynamic_cpu(self):
-        r"""
-        This testcase test if dequant node before linear is promoted correctly:
-              X
-              |
-          Linear1(X)
-          /        \
-   Linear2(X)    Linear3(X)
-          \        /
-             Add
-              |
-              Y
-        """
-
-        def matcher_check_fn():
-            # 1. Dequant pattern matcher for dequant promotion * 1
-            self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
-            # 2. dequant-linear pattern matched in quantization weight prepack * 3
-            self.assertEqual(
-                counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3
-            )
-
-        self._qlinear_dequant_promotion_cpu_test_helper(
-            (torch.randn((2, 4)),),
-            matcher_check_fn=matcher_check_fn,
-            is_dynamic=True,
-        )
-
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py
index e18110b7ecc3..14b3ce0c7c3f 100644
--- a/torch/_inductor/fx_passes/quantization.py
+++ b/torch/_inductor/fx_passes/quantization.py
@@ -1246,7 +1246,6 @@ def _inner(match):
         dequant_pattern_end_node = match.output_node()
         if dequant_pattern_end_node.target not in [
             quantized_decomposed.dequantize_per_tensor.default,
-            quantized_decomposed.dequantize_per_tensor.tensor,
             prims.convert_element_type.default,
             aten.reshape.default,
         ]:
@@ -1272,11 +1271,7 @@ def _inner(match):
         )
 
         if (
-            dequant_node.target
-            in [
-                quantized_decomposed.dequantize_per_tensor.default,
-                quantized_decomposed.dequantize_per_tensor.tensor,
-            ]
+            dequant_node.target is quantized_decomposed.dequantize_per_tensor.default
             and len(list(dequant_pattern_end_node.users)) > 1
         ):
             # If dequant pattern has more than 1 users, then do dequant promoted
@@ -1341,7 +1336,6 @@ def clone_to_new_node(graph, source_node, user_node):
         dequant_pattern_end_node = match.output_node()
         assert dequant_pattern_end_node.target in [
             quantized_decomposed.dequantize_per_tensor.default,
-            quantized_decomposed.dequantize_per_tensor.tensor,
             prims.convert_element_type.default,
             aten.reshape.default,
         ]
@@ -1351,10 +1345,7 @@ def clone_to_new_node(graph, source_node, user_node):
         # * OPT(prims.convert_element_type.default) (to_bf16)
         # * dequantize_per_tensor
         def _find_first_node_in_dequant_pattern(_node):
-            if _node.target in [
-                quantized_decomposed.dequantize_per_tensor.default,
-                quantized_decomposed.dequantize_per_tensor.tensor,
-            ]:
+            if _node.target is quantized_decomposed.dequantize_per_tensor.default:
                 # For a dequant pattern, we expect the start node is a dequantize_per_tensor node
                 return _node
             else:
@@ -1367,10 +1358,10 @@ def _find_first_node_in_dequant_pattern(_node):
             dequant_pattern_end_node
         )
 
-        assert dequant_pattern_start_node.target in [
-            quantized_decomposed.dequantize_per_tensor.default,
-            quantized_decomposed.dequantize_per_tensor.tensor,
-        ]
+        assert (
+            dequant_pattern_start_node.target
+            is quantized_decomposed.dequantize_per_tensor.default
+        )
 
         # Clone the dequant pattern for each user node
         graph = match.graph
@@ -2019,9 +2010,9 @@ def _generate_qlinear_weight_prepack_patterns(
 
 def _register_dequant_promotion():
     dequant_pattern_cases = itertools.product(
-        [torch.float32, torch.bfloat16], [True, False], [True, False]
+        [torch.float32, torch.bfloat16], [True, False]
     )
-    for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases:
+    for dtype, input_dim_exceeds_two in dequant_pattern_cases:
         # 4 dequantization patterns will be matched based on the dtype and input dimension size.
         # Case 1: int8-mixed-fp32, input dim size is 2
         # Case 2: int8-mixed-fp32, input dim size exceeds 2
@@ -2045,9 +2036,7 @@ def _register_dequant_promotion():
         _register_dequant_promotion_pass(
             _may_generate_pattern_with_reshape(
                 _may_generate_pattern_with_dtype_convert(
-                    get_dequantize_per_tensor_activation_pattern(
-                        is_tensor_overload=is_tensor_overload
-                    ),
+                    get_dequantize_per_tensor_activation_pattern(),
                     KeywordArg("autocast_act_dtype"),
                     dtype == torch.bfloat16,
                 ),