diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py
index 44416eb5063de..cbf9dd89c506b 100644
--- a/test/inductor/test_mkldnn_pattern_matcher.py
+++ b/test/inductor/test_mkldnn_pattern_matcher.py
@@ -1385,6 +1385,18 @@ def test_dynamic_qlinear_qat_cpu(self):
                 (torch.randn((2, 4)),), bias=bias, is_dynamic=True, is_qat=True
             )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_dynamic_qlinear_input_dim_exceeds_2(self):
+        r"""
+        This testcase will quantize a single Linear module.
+        """
+        for bias in [True, False]:
+            self._qlinear_cpu_test_helper(
+                (torch.randn((2, 3, 4)),), bias=bias, is_dynamic=True
+            )
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
@@ -1577,7 +1589,13 @@ def test_qlinear_gelu_int8_mixed_bf16(self):
             (torch.randn((2, 4)),), gelu, int8_mixed_bf16=True
         )
 
-    def _qlinear_dequant_promotion_cpu_test_helper(self, inputs, int8_mixed_bf16=False):
+    def _qlinear_dequant_promotion_cpu_test_helper(
+        self,
+        inputs,
+        int8_mixed_bf16=False,
+        is_dynamic=False,
+        matcher_check_fn=None,
+    ):
         class M(torch.nn.Module):
             def __init__(
                 self,
@@ -1595,7 +1613,7 @@ def forward(self, x):
 
         mod = M().eval()
 
-        def matcher_check_fn():
+        def default_matcher_check_fn():
             # 1. Dequant pattern matcher for dequant promotion * 1
             self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
             # 2. dequant-linear pattern matched in quantization weight prepack * 3
@@ -1610,7 +1628,10 @@ def matcher_check_fn():
             inputs,
             check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
             check_quantization=True,
-            matcher_check_fn=matcher_check_fn,
+            matcher_check_fn=matcher_check_fn
+            if matcher_check_fn is not None
+            else default_matcher_check_fn,
+            is_dynamic=is_dynamic,
         )
 
     @skipIfNoDynamoSupport
@@ -1693,6 +1714,37 @@ def test_qlinear_dequant_promotion_int8_mixed_bf16_input_dim_exceeds_2(self):
             (torch.randn((2, 3, 4)),), int8_mixed_bf16=True
         )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_qlinear_dequant_promotion_dynamic_cpu(self):
+        r"""
+        This testcase tests if the dequant node before linear is promoted correctly:
+                  X
+                  |
+              Linear1(X)
+               /      \
+        Linear2(X)    Linear3(X)
+               \      /
+                 Add
+                  |
+                  Y
+        """
+
+        def matcher_check_fn():
+            # 1. Dequant pattern matcher for dequant promotion * 1
+            self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
+            # 2. dequant-linear pattern matched in quantization weight prepack * 3
+            self.assertEqual(
+                counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3
+            )
+
+        self._qlinear_dequant_promotion_cpu_test_helper(
+            (torch.randn((2, 4)),),
+            matcher_check_fn=matcher_check_fn,
+            is_dynamic=True,
+        )
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN
     @skipIfRocm
diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py
index 0f6e4b5a59780..dded66207ee9f 100644
--- a/torch/_inductor/fx_passes/quantization.py
+++ b/torch/_inductor/fx_passes/quantization.py
@@ -1229,6 +1229,7 @@ def _inner(match):
         dequant_pattern_end_node = match.output_node()
         if dequant_pattern_end_node.target not in [
             quantized_decomposed.dequantize_per_tensor.default,
+            quantized_decomposed.dequantize_per_tensor.tensor,
             prims.convert_element_type.default,
             aten.reshape.default,
         ]:
@@ -1254,7 +1255,11 @@ def _inner(match):
         )
         if (
-            dequant_node.target is quantized_decomposed.dequantize_per_tensor.default
+            dequant_node.target
+            in [
+                quantized_decomposed.dequantize_per_tensor.default,
+                quantized_decomposed.dequantize_per_tensor.tensor,
+            ]
             and len(list(dequant_pattern_end_node.users)) > 1
         ):
             # If dequant pattern has more than 1 users, then do dequant promoted
@@ -1319,6 +1324,7 @@ def clone_to_new_node(graph, source_node, user_node):
         dequant_pattern_end_node = match.output_node()
         assert dequant_pattern_end_node.target in [
             quantized_decomposed.dequantize_per_tensor.default,
+            quantized_decomposed.dequantize_per_tensor.tensor,
             prims.convert_element_type.default,
             aten.reshape.default,
         ]
@@ -1328,7 +1334,10 @@ def clone_to_new_node(graph, source_node, user_node):
         # * OPT(prims.convert_element_type.default) (to_bf16)
         # * dequantize_per_tensor
         def _find_first_node_in_dequant_pattern(_node):
-            if _node.target is quantized_decomposed.dequantize_per_tensor.default:
+            if _node.target in [
+                quantized_decomposed.dequantize_per_tensor.default,
+                quantized_decomposed.dequantize_per_tensor.tensor,
+            ]:
                 # For a dequant pattern, we expect the start node is a dequantize_per_tensor node
                 return _node
             else:
@@ -1341,10 +1350,10 @@ def _find_first_node_in_dequant_pattern(_node):
         dequant_pattern_start_node = _find_first_node_in_dequant_pattern(
             dequant_pattern_end_node
         )
-        assert (
-            dequant_pattern_start_node.target
-            is quantized_decomposed.dequantize_per_tensor.default
-        )
+        assert dequant_pattern_start_node.target in [
+            quantized_decomposed.dequantize_per_tensor.default,
+            quantized_decomposed.dequantize_per_tensor.tensor,
+        ]
 
         # Clone the dequant pattern for each user node
         graph = match.graph
@@ -1993,9 +2002,9 @@ def _generate_qlinear_weight_prepack_patterns(
 
 def _register_dequant_promotion():
     dequant_pattern_cases = itertools.product(
-        [torch.float32, torch.bfloat16], [True, False]
+        [torch.float32, torch.bfloat16], [True, False], [True, False]
     )
-    for dtype, input_dim_exceeds_two in dequant_pattern_cases:
+    for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases:
        # 4 dequantization patterns will be matched based on the dtype and input dimension size.
        # Case 1: int8-mixed-fp32, input dim size is 2
        # Case 2: int8-mixed-fp32, input dim size exceeds 2
@@ -2019,7 +2028,9 @@ def _register_dequant_promotion():
         _register_dequant_promotion_pass(
             _may_generate_pattern_with_reshape(
                 _may_generate_pattern_with_dtype_convert(
-                    get_dequantize_per_tensor_activation_pattern(),
+                    get_dequantize_per_tensor_activation_pattern(
+                        is_tensor_overload=is_tensor_overload
+                    ),
                     KeywordArg("autocast_act_dtype"),
                     dtype == torch.bfloat16,
                 ),
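For context, a minimal, hypothetical sketch of the two dequantize_per_tensor overloads handled above (illustrative only, not part of the patch; tensor values and shapes are arbitrary). The .default overload takes scalar scale/zero_point, while the .tensor overload takes 1-element tensors, which is the form dynamic quantization emits and which the dequant promotion pass now also matches.

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  registers quantized_decomposed ops

# Arbitrary int8 activation used only for illustration.
x_int8 = torch.randint(-128, 127, (2, 4), dtype=torch.int8)

# .default overload: scale/zero_point are Python scalars (static quantization).
y_static = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    x_int8, 0.1, 0, -128, 127, torch.int8
)

# .tensor overload: scale/zero_point are 1-element tensors computed at runtime,
# as produced by dynamic quantization.
scale = torch.tensor(0.1)
zero_point = torch.tensor(0)
y_dynamic = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
    x_int8, scale, zero_point, -128, 127, torch.int8
)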