diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 374d1f62de615..0b75e3500435c 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -1671,7 +1671,7 @@ static at::Tensor _quantized_convolution_onednn( } tensor src_scales_t = tensor(ideep::scale_t(1, act_scale)); tensor wei_scales_t = tensor(weights_scales); - tensor dst_scales_t = tensor(ideep::scale_t(1, inv_output_scale)); + tensor dst_scales_t = tensor(ideep::scale_t(1, 1.0/inv_output_scale)); tensor src_zp_t = tensor(ideep::zero_point_t(1, act_zero_point)); tensor dst_zp_t = tensor(ideep::zero_point_t(1, output_zero_point)); if (act_scale != 1.0f) { @@ -1707,7 +1707,7 @@ static at::Tensor _quantized_convolution_onednn( ideep::convolution_forward::prepare( params, src, packed_weight, expected_bias, dst_dims, dst, stride.vec(), dilation.vec(), padding.vec(), padding.vec(), groups, - src_scales, weights_scales, ideep::scale_t(1, 1.0f / inv_output_scale), + src_scales, weights_scales, ideep::scale_t(1, inv_output_scale), src_zero_points, dst_zero_points, op_attr, dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference, diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp index df6df3c35201d..166d0fd617c06 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp @@ -931,6 +931,7 @@ static at::Tensor linear_int8_with_onednn_weight( c10::string_view& unary_post_op_algorithm) { using ideep::tensor; const int64_t dim = input.dim(); + output_scale = 1.0f / output_scale; TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte, "qlinear with mkldnn tensor: data type of input should be uint8 (unsigned char)."); TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char, diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 44416eb5063de..82445b92ce693 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -555,15 +555,15 @@ def forward(self, x): def matcher_check_fn(): # 1. Dequant-Conv2D pattern matched in QConv2D weight prepack * 1 - # int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution] - # int8_mixed_bf16: [dequant_node, optional(convert_element_type_4), + # int8_mixed_fp32: [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution] + # int8_mixed_bf16: [convert_element_type_1, sub, mul_1, optional(convert_element_type_4), # dequantize_per_channel, optional(convert_element_type_3), clone, convolution] self.assertEqual( counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 ) self.assertEqual( counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], - 12 if int8_mixed_bf16 else 8, + 16 if int8_mixed_bf16 else 12, ) self._test_common( @@ -683,13 +683,14 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self): r""" This testcase will quantize Conv2d->Hardtanh pattern. 
Match.nodes: - [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor] - [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type] + [qconv2d_pointwise_default_1, convert_element_type_5, clamp_min_1, clamp_max_1, mul_2, round_2, add_1, clamp_min_2, + clamp_max_1, mul_2, round_2, add_1, clamp_min_2, clamp_max_2, convert_element_type_8 + [qconv2d_pointwise_default, convert_element_type_13, clamp_min_3, clamp_max_3] """ self._qconv2d_unary_cpu_test_helper( unary_op=torch.nn.Hardtanh(), int8_mixed_bf16=True, - qconv2d_unary_matcher_nodes=11, + qconv2d_unary_matcher_nodes=14, ) @skipIfNoDynamoSupport @@ -709,14 +710,14 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self): r""" This testcase will quantize Conv2d->Hardswish pattern. Match.nodes: - [qconv2d_pointwise_default, convert_element_type, add, clamp_min, - clamp_max, mul, div, convert_element_type, quantize_per_tensor] - [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type] + [qconv2d_pointwise_default_1, convert_element_type_5, add_1, clamp_min_1, + clamp_max_1, mul_2, div, mul_3, round_2, add_2, clamp_min_2, clamp_max_2, convert_element_type_8] + [qconv2d_pointwise_default, convert_element_type_13, add_3, clamp_min_3, clamp_max_3, mul_5, div_1] """ self._qconv2d_unary_cpu_test_helper( unary_op=torch.nn.Hardswish(), int8_mixed_bf16=True, - qconv2d_unary_matcher_nodes=17, + qconv2d_unary_matcher_nodes=20, ) @skipIfNoDynamoSupport @@ -736,14 +737,14 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self): r""" This testcase will quantize Conv2d->SiLU pattern. Match.nodes: - [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, - convert_element_type, quantize_per_tensor] - [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type] + [qconv2d_pointwise_default_1, convert_element_type_5, sigmoid, mul_2, + mul_3, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_8] + [qconv2d_pointwise_default, convert_element_type_13, sigmoid_1, mul_5] """ self._qconv2d_unary_cpu_test_helper( unary_op=torch.nn.SiLU(), int8_mixed_bf16=True, - qconv2d_unary_matcher_nodes=11, + qconv2d_unary_matcher_nodes=14, ) def _qconv2d_add_cpu_test_helper(self, use_relu=False, int8_mixed_bf16=False): @@ -1027,17 +1028,17 @@ def forward(self, x): def matcher_check_fn(): # 1. Dequant-conv pattern matched in quantization weight prepack * 1 - # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] + # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution] self.assertEqual( counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1 ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 4 + counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 6 ) # 2. QConv2D Unary fusion in post-grad fusion pass * 1 - # [qconv2d_pointwise_default, quantize_per_tensor] + # [qconv2d_pointwise_default, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2] self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1) - self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_nodes"], 2) + self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_nodes"], 7) self._test_common( mod, @@ -1106,6 +1107,7 @@ def test_qat_qconv2d_relu6(self): r""" This testcase will quantize Conv2d->ReLU6 pattern with qat flow. 
""" + self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6()) @skipIfNoDynamoSupport @@ -1115,6 +1117,7 @@ def test_qat_qconv2d_hardtanh(self): r""" This testcase will quantize Conv2d->Hardtanh pattern with qat flow. """ + self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh()) @skipIfNoDynamoSupport @@ -1124,6 +1127,7 @@ def test_qat_qconv2d_silu(self): r""" This testcase will quantize Conv2d->SiLU pattern with qat flow. """ + self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.SiLU()) @skipIfNoDynamoSupport @@ -1133,6 +1137,7 @@ def test_qat_qconv2d_hardswish(self): r""" This testcase will quantize Conv2d->Hardswish pattern with qat flow. """ + self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardswish()) @skipIfNoDynamoSupport @@ -1171,17 +1176,18 @@ def forward(self, x): def matcher_check_fn(): # 1. Dequant-conv pattern matched in quantization weight prepack * 2 - # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] + # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution] self.assertEqual( counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8 + counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12 ) # 2. Qconv2d Binary fusion in post-grad fusion pass * 1 - # [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, quantize_per_tensor] + # [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, mul_6, round_4, add_4, + # clamp_min_3, clamp_max_3, convert_element_type_6] self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1) - self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 4) + self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 11) self._test_common( mod, @@ -1230,17 +1236,18 @@ def forward(self, x): def matcher_check_fn(): # 1. Dequant-conv pattern matched in quantization weight prepack * 2 - # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] + # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution] self.assertEqual( counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8 + counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12 ) # 2. Qconv2d Binary fusion in post-grad fusion pass * 1 - # [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, relu, quantize_per_tensor] + # [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, relu, mul_6, round_4, add_4, + # clamp_min_3, clamp_max_3, convert_element_type_6] self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1) - self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 5) + self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 12) self._test_common( mod, @@ -1287,16 +1294,16 @@ def forward(self, x): def matcher_check_fn(): # 1. Dequant pattern matcher for dequant promotion * 1 - # [dequantize_per_tensor] + # [convert_element_type_3, sub_1, mul_3] self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1) - self.assertEqual(counters["inductor"]["dequant_promotion_matcher_nodes"], 1) + self.assertEqual(counters["inductor"]["dequant_promotion_matcher_nodes"], 3) # 2. 
Dequant-conv pattern matched in quantization weight prepack * 3 - # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] + # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution] self.assertEqual( counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 3 ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12 + counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 18 ) # 3. Qconv2d Binary fusion in post-grad fusion pass * 1 # [qconv2d_pointwise_default_1, add_3] @@ -1438,7 +1445,7 @@ def matcher_check_fn(): ) self.assertEqual( counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], - 13 if bias else 12, + 17 if bias else 16, ) self._qlinear_cpu_test_helper( @@ -1466,7 +1473,7 @@ def matcher_check_fn(): ) self.assertEqual( counters["inductor"]["qlinear_weight_prepack_matcher_nodes"], - 17 if bias else 16, + 21 if bias else 20, ) self._qlinear_cpu_test_helper( @@ -1715,16 +1722,12 @@ def forward(self, x1, x2): x1 = torch.randn((2, 4)) x2 = torch.randn((2, 5)) - def matcher_check_fn(): - self.assertEqual( - counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1 - ) - self._test_common( mod, (x1, x2), + 2, + 8, check_quantization=True, - matcher_check_fn=matcher_check_fn, ) @skipIfNoDynamoSupport @@ -1760,19 +1763,22 @@ def forward(self, x): v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add( 1 ) - - def matcher_check_fn(): - self.assertEqual(counters["inductor"]["qmaxpool2d_matcher_count"], 1) - self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1 - ) - self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1) - + # Totally 6 pattern_matcher_count, 31 pattern_matcher_nodes + # 1. Pair of to_int8 and to_fp32 * 3, matched in pointless_convert pass at + # torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1] + # 2. Dequant-conv pattern matched in quantization weight prepack * 1 + # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution] + # 3. qconv2d_relu fusion in post-grad fusion pass * 1 + # [qconv2d_pointwise_default, relu, mul_2, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2] + # 4. qmaxpool2d * 1 + # [convert_element_type_3, sub_1, mul_3, max_pool2d_with_indices, getitem, mul_4, round_3, add_2, + # clamp_min_2, clamp_max_2, convert_element_type_4] self._test_common( mod, (v,), + 6, + 31, check_quantization=True, - matcher_check_fn=matcher_check_fn, ) @skipIfNoDynamoSupport @@ -1846,19 +1852,22 @@ def forward(self, x): mod = M().eval() v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1) - - def matcher_check_fn(): - self.assertEqual(counters["inductor"]["qcat_matcher_count"], 1) - self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 - ) - self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 2) - + # Totally 10 pattern_matcher_count, 49 pattern_matcher_nodes + # 1. Pair of to_int8 and to_fp32 * 5, matched in pointless_convert pass at + # torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1] + # 2. Dequant-conv pattern matched in quantization weight prepack * 2 + # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution] + # 3. qconv2d fusion in post-grad fusion pass * 2 + # [qconv2d_pointwise_default, mul_2, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2] + # 4. 
qcat * 1 + # [convert_element_type_3, sub_1, mul_3, convert_element_type_7, sub_3, mul_7, cat, mul_8, round_5, + # add_4, clamp_min_4, clamp_max_4, convert_element_type_8] self._test_common( mod, (v,), + 10, + 49, check_quantization=True, - matcher_check_fn=matcher_check_fn, ) # https://github.com/pytorch/pytorch/issues/99841. diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index d59f1fffd9268..96c6a0b0f905e 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -4309,7 +4309,7 @@ def _test_qlinear_pt2e_helper( if post_op in ("none", "relu", "gelu"): qy_cpu = qlinear_op( qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps, - b, used_y_scale, used_y_zp, output_dtype, + b, 1.0 / used_y_scale, used_y_zp, output_dtype, post_op, unary_post_op_args, post_op_algo ) if post_op == "relu": @@ -4330,7 +4330,7 @@ def _test_qlinear_pt2e_helper( accum = accum.bfloat16() qy_cpu = qlinear_op( qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps, - b, used_y_scale, used_y_zp, output_dtype, + b, 1.0 / used_y_scale, used_y_zp, output_dtype, accum, x2_scale, x2_zp, "sum", binary_alpha, unary_post_op, unary_post_op_args, post_op_algo ) @@ -4348,7 +4348,7 @@ def _test_qlinear_pt2e_helper( binary_alpha = 1.0 # we only support alpha=1.0 now qy_cpu = qlinear_op( qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps, - b, used_y_scale, used_y_zp, output_dtype, + b, 1.0 / used_y_scale, used_y_zp, output_dtype, x2, 1.0, 0, "add", binary_alpha, unary_post_op, unary_post_op_args, post_op_algo ) @@ -6796,7 +6796,7 @@ def _test_qconv_impl_cpu_tensor( pads, dilations, groups, - Y_scale, + 1.0 / Y_scale, # Kernel expects pass in reciprocal of scale in fake quant Y_zero_point, qconv_output_dtype, post_op.binary_attr, @@ -6818,7 +6818,7 @@ def _test_qconv_impl_cpu_tensor( pads, dilations, groups, - Y_scale, + 1.0 / Y_scale, # Kernel expects pass in reciprocal of scale in fake quant Y_zero_point, qconv_output_dtype, post_op.unary_attr, diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index a4fd1a9191c1c..25d7039a7c8ed 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -478,6 +478,70 @@ def linear_dynamic_fp16_unpacked_weight( ) +# The difference between quantize_per_tensor.default and quantize_per_tensor.tensor is +# scale and zero_point is scalar or scalar tensor +@register_decomposition(quantized_decomposed.quantize_per_tensor.default) +def quantize_per_tensor_default_decomp_impl( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + if input.dtype == torch.bfloat16: + input = input.to(torch.float32) + inv_scale = 1.0 / scale + return torch.clamp( + torch.round(input * inv_scale) + zero_point, quant_min, quant_max + ).to(dtype) + + +# The difference between dequantize_per_tensor.default and dequantize_per_tensor.tensor is +# scale and zero_point is scalar or scalar tensor +@register_decomposition(quantized_decomposed.dequantize_per_tensor.default) +def dequantize_per_tensor_default_decomp_impl( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return (input.to(torch.float32) - zero_point) * scale + + +@register_decomposition(quantized_decomposed.quantize_per_tensor.tensor) +def quantize_per_tensor_tensor_decomp_impl( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: 
torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + if input.dtype == torch.bfloat16: + input = input.to(torch.float32) + inv_scale = 1.0 / scale + return torch.clamp( + torch.round(input * inv_scale) + zero_point, quant_min, quant_max + ).to(dtype) + + +@register_decomposition(quantized_decomposed.dequantize_per_tensor.tensor) +def dequantize_per_tensor_tensor_decomp_impl( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return (input.to(torch.float32) - zero_point.to(torch.int32)) * scale.to( + torch.float32 + ) + + @register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack) def q_embedding_bag_byte_unpack_decomp(packed): def bitcast_u8_to_f32(u8): diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index cbcb5f6fdf463..3b813c58d7c25 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -49,25 +49,10 @@ """ -def _get_pattern_output_dtype(match: Match): - """ - Get the pattern's output dtype from node's meta - Assume only 1 output node in this matched pattern. - """ - pattern_output_nodes = match.output_nodes() - assert len(pattern_output_nodes) == 1 - output_node = pattern_output_nodes[0] - assert isinstance(output_node, torch.fx.Node) - output_dtype = output_node.meta["val"].dtype - if output_dtype is torch.uint8: - output_dtype = None - return output_dtype - - def _may_generate_pattern_with_dtype_convert( - pattern, dtype=Arg(), with_dtype_convert=True, users=1 + pattern, dtype=Arg(), dtype_convert=True, users=1 ): - if with_dtype_convert: + if dtype_convert: return CallFunction( prims.convert_element_type.default, pattern, @@ -109,25 +94,30 @@ def _generate_linear_t_pattern( def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16): # only insert to_dtype if is_bf16 is True computation_call = _may_generate_pattern_with_dtype_convert( - call_fn, dtype=KeywordArg("to_float"), with_dtype_convert=is_bf16, users=users + call_fn, dtype=KeywordArg("to_float"), dtype_convert=is_bf16, users=users ) return unary_fusion(computation_call) -def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False): - dequantize_per_tensor_activation_pattern = CallFunction( - quantized_decomposed.dequantize_per_tensor.tensor - if is_tensor_overload - else quantized_decomposed.dequantize_per_tensor.default, - KeywordArg("x"), - KeywordArg("x_scale"), +""" +dequantize activation: + x = x.to(fp32) + x = x - zero_point + x = x * scale +""" +dequantize_per_tensor_activation_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.sub.Tensor, + CallFunction( + prims.convert_element_type.default, + KeywordArg("x"), + KeywordArg("x_dq_dtype"), + ), KeywordArg("x_zp"), - KeywordArg("x_quant_min"), - KeywordArg("x_quant_max"), - KeywordArg("x_dq_dtype"), - ) - return dequantize_per_tensor_activation_pattern - + ), + KeywordArg("x_scale"), +) dequantize_per_channel_weight_pattern = CallFunction( quantized_decomposed.dequantize_per_channel.default, @@ -210,13 +200,17 @@ def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1): dequantize_accum_pattern = CallFunction( - quantized_decomposed.dequantize_per_tensor.default, - KeywordArg("accum"), + aten.mul.Tensor, + CallFunction( + aten.sub.Tensor, + CallFunction( + prims.convert_element_type.default, + KeywordArg("accum"), + KeywordArg("accum_dq_dtype"), + ), + KeywordArg("accum_zp"), + ), 
KeywordArg("accum_scale"), - KeywordArg("accum_zp"), - Arg(), - Arg(), - KeywordArg("accum_dq_dtype"), ) @@ -247,18 +241,43 @@ def generate_pattern_with_unary(computation_call, unary_post_op): return computation_call -def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False): +def generate_pattern_with_output_quant( + computation_call, has_to_fp32_before_quant=False +): + """ + quantize output: + output = round(output * o_inv_scale) + output = output + zero_point + output = clamp_min(output, 0) + output = clamp_max(output, 127) + output = output.to(uint8) + """ quantized_op_output_pattern_pt2e = CallFunction( - quantized_decomposed.quantize_per_tensor.default, - _may_generate_pattern_with_dtype_convert( - computation_call, - Arg(), - with_dtype_convert, + prims.convert_element_type.default, + CallFunction( + aten.clamp_max.default, + CallFunction( + aten.clamp_min.default, + CallFunction( + aten.add.Tensor, + CallFunction( + aten.round.default, + CallFunction( + aten.mul.Tensor, + _may_generate_pattern_with_dtype_convert( + computation_call, + KeywordArg("autocast_output_quant_dtype"), + has_to_fp32_before_quant, + ), + KeywordArg("o_inv_scale"), + ), + ), + KeywordArg("o_zp"), + ), + KeywordArg("o_qmin"), + ), + KeywordArg("o_qmax"), ), - KeywordArg("o_inv_scale"), - KeywordArg("o_zp"), - KeywordArg("o_qmin"), - KeywordArg("o_qmax"), KeywordArg("o_dtype"), ) return quantized_op_output_pattern_pt2e @@ -274,9 +293,8 @@ def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_val return actual_value == expected_value -def _is_valid_quantized_conv2d_optimization_pattern(): +def _is_valid_quantized_conv2d_optimization_pattern(output_dtype): def fn(match): - output_dtype = _get_pattern_output_dtype(match) if output_dtype is not None: # Only keep matched pattern with same output_dtype qconv_node_after_weight_prepack = filter_nodes( @@ -294,11 +312,13 @@ def _register_quantized_conv_lowering( pattern, pass_number, computation_op, + output_dtype, unary_attr, + original_pattern_output_dtype=torch.float32, ): @register_lowering_pattern( pattern, - extra_check=_is_valid_quantized_conv2d_optimization_pattern(), + extra_check=_is_valid_quantized_conv2d_optimization_pattern(output_dtype), pass_number=pass_number, ) def qconv(match: Match, *args, **kwargs): @@ -322,11 +342,13 @@ def qconv(match: Match, *args, **kwargs): kwargs["dilation"], kwargs["groups"], ) - output_dtype = _get_pattern_output_dtype(match) assert output_dtype in [None, torch.float32, torch.bfloat16] # Output QParams o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0 o_zero_point = kwargs["o_zp"] if output_dtype is None else 0 + assert ( + kwargs["output_dtype"] is original_pattern_output_dtype + ) # Expected int8-in fp32-out qconv in weight prepack phase assert ( kwargs["attr"] == "none" ) # Expected no post op fused in weight prepack phase @@ -361,9 +383,8 @@ def qconv(match: Match, *args, **kwargs): return qconv -def _is_valid_quantized_linear_optimization_pattern(): +def _is_valid_quantized_linear_optimization_pattern(output_dtype): def fn(match): - output_dtype = _get_pattern_output_dtype(match) if output_dtype is not None: # Only keep matched pattern with same output_dtype qlinear_node_after_weight_prepack = filter_nodes( @@ -381,15 +402,16 @@ def _register_quantized_linear_lowering( pattern, pass_number, computation_op, + output_dtype, unary_attr, + original_pattern_output_dtype=torch.float32, ): @register_lowering_pattern( pattern, - 
extra_check=_is_valid_quantized_linear_optimization_pattern(), + extra_check=_is_valid_quantized_linear_optimization_pattern(output_dtype), pass_number=pass_number, ) def qlinear(match: Match, *args, **kwargs): - output_dtype = _get_pattern_output_dtype(match) # Activation QParams x, x_scale, x_zp = ( kwargs["x"], @@ -409,6 +431,9 @@ def qlinear(match: Match, *args, **kwargs): # Output QParams o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0 o_zero_point = kwargs["o_zp"] if output_dtype is None else 0 + assert ( + kwargs["output_dtype"] is original_pattern_output_dtype + ) # Expected int8-in fp32/bf16-out qlinear in weight prepack phase assert ( kwargs["postop_name"] == "none" ) # Expected no post op fused in weight prepack phase @@ -435,7 +460,7 @@ def qlinear(match: Match, *args, **kwargs): return qlinear -def _is_valid_quantized_conv_binary_optimization_pattern(): +def _is_valid_quantized_conv_binary_optimization_pattern(output_dtype): # Check if it's a valid Conv Binary Pattern: # * qconv2d_pointwise should only has one users # * Extra input of binary node comes from dequant pattern @@ -445,7 +470,6 @@ def _is_valid_quantized_conv_binary_optimization_pattern(): # ancestor nodes of the compute node, except for the binary node # connected to the compute node. def fn(match): - output_dtype = _get_pattern_output_dtype(match) compute_node = filter_nodes(match.nodes, torch.ops.onednn.qconv2d_pointwise)[0] # qconv2d_pointwise should only have one user if len(compute_node.users) != 1: @@ -461,8 +485,7 @@ def fn(match): assert extra_input_of_binary_node is not None # Extra input of binary node comes from dequant pattern if (not isinstance(extra_input_of_binary_node, torch.fx.Node)) or ( - extra_input_of_binary_node.target - != quantized_decomposed.dequantize_per_tensor.default + extra_input_of_binary_node.target != aten.mul.Tensor ): return False @@ -513,15 +536,15 @@ def _register_quantized_conv_binary_lowering( pattern, pass_number, computation_op, + output_dtype, binary_unary_attr, ): @register_lowering_pattern( pattern, - extra_check=_is_valid_quantized_conv_binary_optimization_pattern(), + extra_check=_is_valid_quantized_conv_binary_optimization_pattern(output_dtype), pass_number=pass_number, ) def qconv_binary(match: Match, *args, **kwargs): - output_dtype = _get_pattern_output_dtype(match) x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"] accum = ( kwargs["accum"] if output_dtype is None else kwargs["accum_after_dequant"] @@ -606,11 +629,13 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): conv_unary_replace_patterns = { UnaryAttr("none", [], ""): generate_pattern_with_output_quant( get_dequantize_qconv_pt2e_pattern(1), + has_to_fp32_before_quant=is_bf16, ), UnaryAttr("relu", [], ""): generate_pattern_with_output_quant( generate_pattern_with_unary( get_dequantize_qconv_pt2e_pattern(1), aten.relu.default ), + has_to_fp32_before_quant=is_bf16, ), UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant( _unary_fusion_pattern( @@ -619,7 +644,7 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): 1, is_bf16, ), - with_dtype_convert=is_bf16, + has_to_fp32_before_quant=False, ), UnaryAttr("hardswish", [], ""): generate_pattern_with_output_quant( _unary_fusion_pattern( @@ -628,7 +653,7 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): 2, is_bf16, ), - with_dtype_convert=is_bf16, + has_to_fp32_before_quant=False, ), UnaryAttr("swish", [], ""): 
generate_pattern_with_output_quant( _unary_fusion_pattern( @@ -637,7 +662,7 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): 2, is_bf16, ), - with_dtype_convert=is_bf16, + has_to_fp32_before_quant=False, ), } @@ -647,7 +672,9 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): patterns, 1, # pass_number torch.ops.onednn.qconv2d_pointwise, # computation_op + None, # output_dtype, None is the default value for int8 output unary_attr, # unary_attr + original_pattern_output_dtype=original_pattern_output_dtype, ) # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output @@ -655,34 +682,22 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): UnaryAttr("relu", [], ""): generate_pattern_with_unary( get_dequantize_qconv_pt2e_pattern(1), aten.relu.default ), - UnaryAttr("hardtanh", [], ""): _may_generate_pattern_with_dtype_convert( - _unary_fusion_pattern( - _hardtanh_fusion, - get_dequantize_qconv_pt2e_pattern(1), - 1, - is_bf16, - ), - Arg(), + UnaryAttr("hardtanh", [], ""): _unary_fusion_pattern( + _hardtanh_fusion, + get_dequantize_qconv_pt2e_pattern(1), + 1, is_bf16, ), - UnaryAttr("hardswish", [], ""): _may_generate_pattern_with_dtype_convert( - _unary_fusion_pattern( - _hardswish_fusion, - get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2), - 2, - is_bf16, - ), - Arg(), + UnaryAttr("hardswish", [], ""): _unary_fusion_pattern( + _hardswish_fusion, + get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, is_bf16, ), - UnaryAttr("swish", [], ""): _may_generate_pattern_with_dtype_convert( - _unary_fusion_pattern( - _silu_fusion, - get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2), - 2, - is_bf16, - ), - Arg(), + UnaryAttr("swish", [], ""): _unary_fusion_pattern( + _silu_fusion, + get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, is_bf16, ), } @@ -693,7 +708,9 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): patterns, 2, # pass_number torch.ops.onednn.qconv2d_pointwise, # computation_op + original_pattern_output_dtype, # output_dtype unary_attr, # unary_attr + original_pattern_output_dtype=original_pattern_output_dtype, ) # QLinear @@ -703,9 +720,11 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): linear_unary_replace_patterns = { UnaryAttr("none", [], ""): generate_pattern_with_output_quant( qlinear_pattern, + has_to_fp32_before_quant=is_bf16, ), UnaryAttr("relu", [], ""): generate_pattern_with_output_quant( generate_pattern_with_unary(qlinear_pattern, aten.relu.default), + has_to_fp32_before_quant=is_bf16, ), UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant( _unary_fusion_pattern( @@ -716,18 +735,18 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): 2, is_bf16, ), - with_dtype_convert=is_bf16, + has_to_fp32_before_quant=False, ), UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant( _unary_fusion_pattern( _gelu_fusion_tanh, get_qlinear_pt2e_pattern( - x_scale_zp_are_tensors, 1 if is_bf16 else 4 + x_scale_zp_are_tensors, 1 if is_bf16 else 2 ), 4, is_bf16, ), - with_dtype_convert=is_bf16, + has_to_fp32_before_quant=False, ), } @@ -736,7 +755,9 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): patterns, 1, # pass_number torch.ops.onednn.qlinear_pointwise, # computation_op + None, # output_dtype unary_attr, # unary_attr + original_pattern_output_dtype=original_pattern_output_dtype, ) # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output 
@@ -744,28 +765,20 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): UnaryAttr("relu", [], ""): generate_pattern_with_unary( qlinear_pattern, aten.relu.default ), - UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert( - _unary_fusion_pattern( - _gelu_fusion_erf, - get_qlinear_pt2e_pattern( - x_scale_zp_are_tensors, 1 if is_bf16 else 2 - ), - 2, - is_bf16, + UnaryAttr("gelu", [], "none"): _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 ), - Arg(), + 2, is_bf16, ), - UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert( - _unary_fusion_pattern( - _gelu_fusion_tanh, - get_qlinear_pt2e_pattern( - x_scale_zp_are_tensors, 1 if is_bf16 else 4 - ), - 4, - is_bf16, + UnaryAttr("gelu", [], "tanh"): _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 4 ), - Arg(), + 4, is_bf16, ), } @@ -775,7 +788,9 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): patterns, 2, # pass_number torch.ops.onednn.qlinear_pointwise, # computation_op + original_pattern_output_dtype, # output_dtype unary_attr, # unary_attr + original_pattern_output_dtype=original_pattern_output_dtype, ) @@ -807,6 +822,7 @@ def __init__( dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, ), + has_to_fp32_before_quant=int8_mixed_bf16_with_inplace_add, ), BinaryUnaryAttr( "sum", 1.0, "relu", [], "" @@ -820,6 +836,7 @@ def __init__( ), aten.relu.default, ), + has_to_fp32_before_quant=int8_mixed_bf16_with_inplace_add, ), } @@ -828,6 +845,7 @@ def __init__( patterns, 0, # pass_number torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + None, # output_dtype binary_unary_attr, # binary_unary_attr ) @@ -853,6 +871,11 @@ def __init__( patterns, 0, # pass_number torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + # Note that for int8-mixed-bf16 and non-inplace add, because we have + # q-dq inserted at extra input of add, so the non-inplace add has bf16 and fp32 inputs, + # the output dtype will be float32. + # For inplace add, there is a extra to_bf16 node at add output, so the fusion pattern has bfloat16 output. 
+ torch.bfloat16, binary_unary_attr, # binary_unary_attr ) else: @@ -860,6 +883,7 @@ def __init__( patterns, 1, # pass_number torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + torch.float32, binary_unary_attr, # binary_unary_attr ) @@ -881,6 +905,8 @@ def __init__( patterns, 1 if int8_mixed_bf16_with_inplace_add else 2, # pass_number torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + # Same output dtype setting as conv-add-relu pattern + torch.bfloat16 if int8_mixed_bf16_with_inplace_add else torch.float32, binary_unary_attr, # binary_unary_attr ) @@ -936,8 +962,6 @@ def qmaxpool2d(match: Match, *args, **kwargs): ceil_mode, ) computation_args, _ = require_channels_last(computation_op, *computation_args) - counters["inductor"]["qmaxpool2d_matcher_count"] += 1 - counters["inductor"]["qmaxpool2d_matcher_nodes"] += len(match.nodes) return L[computation_op](*computation_args) return qmaxpool2d @@ -972,7 +996,7 @@ def _register_quantization_maxpool2d(): for max_pool2d_args in max_pool2d_args_list: dequantize_maxpool2d_pattern = CallFunction( aten.max_pool2d_with_indices.default, - get_dequantize_per_tensor_activation_pattern(), + dequantize_per_tensor_activation_pattern, KeywordArg("kernel_size"), *max_pool2d_args, ) @@ -1009,23 +1033,26 @@ def _is_input_output_same_scale_zp(check_node): def fn(match): # Ensure all the inputs and output has same scale and zero point # Step 1: Check inputs/output zero point - # Get dequant nodes at input - dequant_nodes = filter_nodes( - match.nodes, quantized_decomposed.dequantize_per_tensor.default - ) - zero_points = [node.args[2] for node in dequant_nodes] - # Get quant nodes at output - quant_nodes = filter_nodes( - match.nodes, quantized_decomposed.quantize_per_tensor.default - ) - assert len(quant_nodes) == 1, "expect only 1 add node at output quant pattern" - zero_points.append(quant_nodes[0].args[2]) + sub_nodes = filter_nodes(match.nodes, aten.sub.Tensor) + zero_points = [node.args[1] for node in sub_nodes] + add_nodes = filter_nodes(match.nodes, aten.add.Tensor) + assert len(add_nodes) == 1, "expect only 1 add node at output quant pattern" + zero_points.append(add_nodes[0].args[1]) if not all(zero_point == zero_points[0] for zero_point in zero_points): return False # Step 2: Check inputs/output scale - scales = [node.args[1] for node in dequant_nodes] - scales.append(quant_nodes[0].args[1]) + mul_nodes = filter_nodes(match.nodes, aten.mul.Tensor) + # We need to find mul node at output since the scale value is reciprocal to input scale. + # Mul node at output should connect to cat node directly. + scales = [ + ( + mul_node.args[1] + if mul_node.args[0].target is check_node # type: ignore[union-attr] + else 1.0 / mul_node.args[1] # type: ignore[operator] + ) + for mul_node in mul_nodes + ] if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales): # type: ignore[arg-type] return False @@ -1045,20 +1072,22 @@ def _register_quantized_cat_lowering( def qcat(match: Match, inputs, dim, **kwargs): # inputs is with format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...] 
uint8_inputs = [input[0] for input in inputs] - counters["inductor"]["qcat_matcher_count"] += 1 - counters["inductor"]["qcat_matcher_nodes"] += len(match.nodes) return L[computation_op](uint8_inputs, dim) return qcat _raw_dequantize_per_tensor_activation_pattern = CallFunction( - quantized_decomposed.dequantize_per_tensor.default, - Arg(), - Arg(), - Arg(), - Arg(), - Arg(), + aten.mul.Tensor, + CallFunction( + aten.sub.Tensor, + CallFunction( + prims.convert_element_type.default, + Arg(), + Arg(), + ), + Arg(), + ), Arg(), ) @@ -1096,7 +1125,7 @@ def qreshape(match: Match, *args, **kwargs): def _register_quantization_reshape(): dequantize_reshape_pattern = CallFunction( torch.ops.aten.reshape.default, - get_dequantize_per_tensor_activation_pattern(), + dequantize_per_tensor_activation_pattern, KeywordArg("shape"), ) _register_quantized_reshape_lowering( @@ -1245,33 +1274,35 @@ def _inner(match): assert dtype in [torch.float32, torch.bfloat16] dequant_pattern_end_node = match.output_node() if dequant_pattern_end_node.target not in [ - quantized_decomposed.dequantize_per_tensor.default, + aten.mul.Tensor, prims.convert_element_type.default, aten.reshape.default, ]: return False if dequant_pattern_end_node.target is aten.reshape.default: - dequant_node = ( - dequant_pattern_end_node.args[ - 0 - ] # pattern: linear <- reshape <- dequant + mul_node = ( + dequant_pattern_end_node.args[0] # pattern: linear <- reshape <- mul if dtype == torch.float32 else dequant_pattern_end_node.args[0].args[ 0 - ] # pattern: linear <- reshape <- to_bf16 <- dequant + ] # pattern: linear <- reshape <- to_bf16 <- mul ) else: - dequant_node = ( - dequant_pattern_end_node # pattern: linear <- dequant + mul_node = ( + dequant_pattern_end_node # pattern: linear <- mul if dtype == torch.float32 else dequant_pattern_end_node.args[ 0 - ] # pattern: linear <- to_bf16 <- dequant + ] # pattern: linear <- to_bf16 <- mul ) + sub_node = mul_node.args[0] + to_fp32_node = sub_node.args[0] if ( - dequant_node.target is quantized_decomposed.dequantize_per_tensor.default + mul_node.target is aten.mul.Tensor + and sub_node.target is aten.sub.Tensor + and to_fp32_node.target is prims.convert_element_type.default and len(list(dequant_pattern_end_node.users)) > 1 ): # If dequant pattern has more than 1 users, then do dequant promoted @@ -1332,10 +1363,10 @@ def clone_to_new_node(graph, source_node, user_node): # Find the start node and end node of a dequant pattern # * End node should be the match.output_node() - # * Start node should be the node of dequantize_per_tensor + # * Start node should be the node of dtype convert to float32 dequant_pattern_end_node = match.output_node() assert dequant_pattern_end_node.target in [ - quantized_decomposed.dequantize_per_tensor.default, + aten.mul.Tensor, prims.convert_element_type.default, aten.reshape.default, ] @@ -1343,10 +1374,15 @@ def clone_to_new_node(graph, source_node, user_node): # For a dequant pattern, we should expect see the node list as: # * OPT(aten.reshape.default) # * OPT(prims.convert_element_type.default) (to_bf16) - # * dequantize_per_tensor + # * aten.mul + # * aten.sub + # * prims.convert_element_type.default (to_fp32) def _find_first_node_in_dequant_pattern(_node): - if _node.target is quantized_decomposed.dequantize_per_tensor.default: - # For a dequant pattern, we expect the start node is a dequantize_per_tensor node + if ( + _node.target is prims.convert_element_type.default + and _node.args[1] == torch.float32 + ): + # For a dequant pattern, we expect the start node 
is a to_fp32 node return _node else: assert ( @@ -1358,11 +1394,6 @@ def _find_first_node_in_dequant_pattern(_node): dequant_pattern_end_node ) - assert ( - dequant_pattern_start_node.target - is quantized_decomposed.dequantize_per_tensor.default - ) - # Clone the dequant pattern for each user node graph = match.graph user_node_list = list(dequant_pattern_end_node.users) @@ -1398,14 +1429,22 @@ def _inner(match): return False assert dtype in [torch.float32, torch.bfloat16] - if dtype == torch.float32: - dequant_node = conv_node.args[0] + mul_node = conv_node.args[0] else: convert_to_bf16 = conv_node.args[0] - dequant_node = convert_to_bf16.args[0] + mul_node = convert_to_bf16.args[0] + sub_node = mul_node.args[0] + to_fp32_node = sub_node.args[0] - if len(list(dequant_node.users)) != 1: + assert to_fp32_node.target is prims.convert_element_type.default + assert sub_node.target is aten.sub.Tensor + assert mul_node.target is aten.mul.Tensor + if ( + len(list(to_fp32_node.users)) != 1 + or len(list(sub_node.users)) != 1 + or len(list(mul_node.users)) != 1 + ): # Ensure the dequant pattern only has 1 user # since we will delete the dequant pattern here return False @@ -1438,10 +1477,12 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): conv_node = match.output_node() assert conv_node.target is aten.convolution.default if dtype == torch.float32: - dequant_node = conv_node.args[0] + mul_node = conv_node.args[0] else: convert_to_bf16 = conv_node.args[0] - dequant_node = convert_to_bf16.args[0] # type: ignore[union-attr] + mul_node = convert_to_bf16.args[0] # type: ignore[union-attr] + sub_node = mul_node.args[0] # type: ignore[union-attr] + to_fp32_node = sub_node.args[0] # type: ignore[union-attr] has_clone_to_channel_last_node_in_pattern = ( conv_node.args[1].target is aten.clone.default # type: ignore[union-attr] ) @@ -1544,7 +1585,10 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): # Erase the dequant pattern if dtype == torch.bfloat16: graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined] - graph.erase_node(dequant_node) + # Erase the dequant pattern + graph.erase_node(mul_node) + graph.erase_node(sub_node) + graph.erase_node(to_fp32_node) # Erase the dequant per channel pattern if clone_node is not None: graph.erase_node(clone_node) @@ -1564,7 +1608,7 @@ def _generate_dequant_convolution_node_pattern( dequant_convolution_node_pattern = CallFunction( aten.convolution.default, _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(), + dequantize_per_tensor_activation_pattern, KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), @@ -1624,7 +1668,7 @@ def _get_linear_node(match, input_dim_exceeds_two, input_contiguous): return linear_node, output_reshape_node -def _get_linear_dq_node( +def _get_linear_dq_mul_node( linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous ): act_reshape_node = None @@ -1635,30 +1679,30 @@ def _get_linear_dq_node( act_reshape_node = linear_node.args[input_index] assert act_reshape_node.target is aten.reshape.default if dtype == torch.float32: - # pattern: linear -> reshape -> dequant - dequant_node = act_reshape_node.args[0] + # pattern: linear -> reshape -> mul + mul_node = act_reshape_node.args[0] else: - # pattern: linear -> reshape -> to_bf16 -> dequant + # pattern: linear -> reshape -> to_bf16 -> mul activation_to_bf16_node = act_reshape_node.args[0] - dequant_node = activation_to_bf16_node.args[0] + mul_node = activation_to_bf16_node.args[0] else: # bmm pattern 
decomposed from linear when input dim exceeds 2 and not contiguous act_expand_node = linear_node.args[input_index] assert act_expand_node.target is aten.expand.default if dtype == torch.float32: - dequant_node = act_expand_node.args[0] + mul_node = act_expand_node.args[0] else: activation_to_bf16_node = act_expand_node.args[0] - dequant_node = activation_to_bf16_node.args[0] + mul_node = activation_to_bf16_node.args[0] else: if dtype == torch.float32: - # pattern: linear -> dequant - dequant_node = linear_node.args[input_index] + # pattern: linear -> mul + mul_node = linear_node.args[input_index] else: - # pattern: linear -> to_bf16 -> dequant + # pattern: linear -> to_bf16 -> mul activation_to_bf16_node = linear_node.args[input_index] - dequant_node = activation_to_bf16_node.args[0] - return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node + mul_node = activation_to_bf16_node.args[0] + return mul_node, act_reshape_node, activation_to_bf16_node, act_expand_node def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous): @@ -1671,21 +1715,27 @@ def _inner(match): input_index = 1 if linear_node.target is aten.addmm.default else 0 assert dtype in [torch.float32, torch.bfloat16] + ( - dequant_node, + mul_node, _, _, _, - ) = _get_linear_dq_node( + ) = _get_linear_dq_mul_node( linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous ) - assert dequant_node.target in [ - quantized_decomposed.dequantize_per_tensor.default, - quantized_decomposed.dequantize_per_tensor.tensor, - ] + sub_node = mul_node.args[0] + to_fp32_node = sub_node.args[0] - if len(list(dequant_node.users)) != 1: + assert to_fp32_node.target is prims.convert_element_type.default + assert sub_node.target is aten.sub.Tensor + assert mul_node.target is aten.mul.Tensor + if ( + len(list(to_fp32_node.users)) != 1 + or len(list(sub_node.users)) != 1 + or len(list(mul_node.users)) != 1 + ): # Ensure the dequant pattern only has 1 user # since we will delete the dequant pattern here return False @@ -1770,14 +1820,17 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs): weight_index = input_index + 1 ( - dequant_node, + mul_node, act_reshape_node, activation_to_bf16_node, act_expand_node, - ) = _get_linear_dq_node( + ) = _get_linear_dq_mul_node( linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous ) + sub_node = mul_node.args[0] + to_fp32_node = sub_node.args[0] + if input_dim_exceeds_two and not input_contiguous: wgt_expand_node = linear_node.args[weight_index] assert wgt_expand_node.target is aten.expand.default @@ -1885,7 +1938,9 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs): if dtype == torch.bfloat16: graph.erase_node(activation_to_bf16_node) # Erase the dequant pattern - graph.erase_node(dequant_node) + graph.erase_node(mul_node) + graph.erase_node(sub_node) + graph.erase_node(to_fp32_node) # Erase the dequant per channel pattern graph.erase_node(t_node) if dtype == torch.bfloat16: @@ -1899,10 +1954,7 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs): def _generate_dequant_linear_node_pattern( - _dequant_per_channel_pattern, - dtype=torch.float32, - input_dim_exceeds_two=False, - is_tensor_overload=False, + _dequant_per_channel_pattern, dtype=torch.float32, input_dim_exceeds_two=False ): assert dtype in [torch.float32, torch.bfloat16] t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) @@ -1912,7 +1964,7 @@ def _generate_dequant_linear_node_pattern( KeywordArg("b"), 
_may_generate_pattern_with_reshape( _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + dequantize_per_tensor_activation_pattern, KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), @@ -1929,7 +1981,7 @@ def _generate_dequant_linear_node_pattern( aten.mm.default, _may_generate_pattern_with_reshape( _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + dequantize_per_tensor_activation_pattern, KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), @@ -1948,7 +2000,6 @@ def _generate_dequant_bmm_node_pattern( _dequant_per_channel_pattern, dtype=torch.float32, with_bias=False, - is_tensor_overload=False, ): # When activation of linear dim exceed 2 and not contiguous t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) @@ -1959,7 +2010,7 @@ def _generate_dequant_bmm_node_pattern( CallFunction( aten.expand.default, _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + dequantize_per_tensor_activation_pattern, KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), @@ -1990,21 +2041,16 @@ def _generate_qlinear_weight_prepack_patterns( input_dim_exceeds_two=False, input_contiguous=True, with_bias=False, - is_tensor_overload=False, ): if input_dim_exceeds_two and not input_contiguous: return _generate_dequant_bmm_node_pattern( dequantize_per_channel_weight_pattern, dtype, with_bias, - is_tensor_overload, ) else: return _generate_dequant_linear_node_pattern( - dequantize_per_channel_weight_pattern, - dtype, - input_dim_exceeds_two, - is_tensor_overload, + dequantize_per_channel_weight_pattern, dtype, input_dim_exceeds_two ) @@ -2036,7 +2082,7 @@ def _register_dequant_promotion(): _register_dequant_promotion_pass( _may_generate_pattern_with_reshape( _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(), + dequantize_per_tensor_activation_pattern, KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), @@ -2094,15 +2140,13 @@ def _register_qlinear_weight_prepack(): # | OPT(add) | linear_weight_prepack_cases = itertools.product( - [torch.float32, torch.bfloat16], [True, False], [True, False] + [torch.float32, torch.bfloat16], [True, False] ) # Step 1: register patterns from mm and addmm - for dtype, input_dim_exceeds_two, is_tensor_overload in linear_weight_prepack_cases: + for dtype, input_dim_exceeds_two in linear_weight_prepack_cases: weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns( - dtype, - input_dim_exceeds_two, - is_tensor_overload=is_tensor_overload, + dtype, input_dim_exceeds_two ) for weight_prepack_pattern in weight_prepack_patterns: # Register to pass_number 1, so we can do dequant promotion in pass_number 0. 
@@ -2119,15 +2163,14 @@ def _register_qlinear_weight_prepack(): # https://github.com/pytorch/pytorch/blob/ # 80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968 # in this case, we can convert it back to qlinear - for dtype, with_bias, is_tensor_overload in itertools.product( - [torch.float32, torch.bfloat16], [True, False], [True, False] + for dtype, with_bias in itertools.product( + [torch.float32, torch.bfloat16], [True, False] ): bmm_pattern = _generate_qlinear_weight_prepack_patterns( dtype=dtype, input_dim_exceeds_two=True, input_contiguous=False, with_bias=with_bias, - is_tensor_overload=is_tensor_overload, ) _register_qlinear_weight_prepack_pass( bmm_pattern, diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index f7d2784c7ff58..a07fb4d8b1b09 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -1144,170 +1144,6 @@ def inner_fn(idx): ) -@register_lowering( - quantized_decomposed.quantize_per_tensor.default, type_promotion_kind=None -) -def quantized_decomposed_quantize_per_tensor_default( - input: TensorBox, - scale: float, - zero_point: int, - quant_min: int, - quant_max: int, - dtype: torch.dtype, -) -> TensorBox: - if input.get_dtype() == torch.bfloat16: - input = to_dtype(input, torch.float32) - assert ( - input.get_dtype() == torch.float32 - ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" - - input_loader = input.make_loader() - - def inner_fn(idx, scale, zero_point): - input = input_loader(idx) - inv_scale, zero_point = _create_constants( - 1.0 / scale, zero_point, dtype=torch.float32 - ) - val = ops.round(input * inv_scale) + zero_point - qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) - clamped = ops.minimum(ops.maximum(val, qmin), qmax) - return ops.to_dtype(clamped, dtype) - - return Pointwise.create( - device=input.get_device(), - dtype=dtype, - inner_fn=functools.partial( - inner_fn, scale=float(scale), zero_point=int(zero_point) - ), - ranges=input.get_size(), - ) - - -@register_lowering( - quantized_decomposed.dequantize_per_tensor.default, type_promotion_kind=None -) -def quantized_decomposed_dequantize_per_tensor_default( - input: TensorBox, - scale: float, - zero_point: int, - quant_min: int, - quant_max: int, - dtype: torch.dtype, -) -> TensorBox: - assert ( - input.get_dtype() == dtype - ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" - - input_loader = input.make_loader() - - def inner_fn(idx, scale, zero_point): - input = input_loader(idx) - scale, zero_point = _create_constants(scale, zero_point, dtype=torch.float32) - val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale - return val - - return Pointwise.create( - device=input.get_device(), - dtype=torch.float32, - inner_fn=functools.partial( - inner_fn, scale=float(scale), zero_point=int(zero_point) - ), - ranges=input.get_size(), - ) - - -@register_lowering( - quantized_decomposed.quantize_per_tensor.tensor, type_promotion_kind=None -) -def quantized_decomposed_quantize_per_tensor_tensor( - input: TensorBox, - scale: TensorBox, - zero_point: TensorBox, - quant_min: int, - quant_max: int, - dtype: torch.dtype, -) -> TensorBox: - if input.get_dtype() == torch.bfloat16: - input = to_dtype(input, torch.float32) - assert ( - input.get_dtype() == torch.float32 - ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" - assert len(scale.get_size()) == 0 or ( - len(scale.get_size()) == 1 and 
scale.get_size()[0] == 1 - ), "expect scale as scalar tensor" - assert len(zero_point.get_size()) == 0 or ( - len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1 - ), "expect zero_point as scalar tensor" - - input_loader = input.make_loader() - scale_loader = scale.make_loader() - zero_point_loader = zero_point.make_loader() - - def inner_fn(idx): - input = input_loader(idx) - _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ()) - _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ()) - if scale.dtype != torch.float32: - _scale = ops.to_dtype(_scale, torch.float32) - if zero_point.dtype != torch.float32: - _zero_point = ops.to_dtype(_zero_point, torch.float32) - val = ops.round(input * ops.reciprocal(_scale)) + _zero_point - qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) - clamped = ops.minimum(ops.maximum(val, qmin), qmax) - return ops.to_dtype(clamped, dtype) - - return Pointwise.create( - device=input.get_device(), - dtype=dtype, - inner_fn=inner_fn, - ranges=input.get_size(), - ) - - -@register_lowering( - quantized_decomposed.dequantize_per_tensor.tensor, type_promotion_kind=None -) -def quantized_decomposed_dequantize_per_tensor_tensor( - input: TensorBox, - scale: TensorBox, - zero_point: TensorBox, - quant_min: int, - quant_max: int, - dtype: torch.dtype, -) -> TensorBox: - assert len(scale.get_size()) == 0 or ( - len(scale.get_size()) == 1 and scale.get_size()[0] == 1 - ), "expect scale as scalar tensor" - assert len(zero_point.get_size()) == 0 or ( - len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1 - ), "expect zero_point as scalar tensor" - assert ( - input.get_dtype() == dtype - ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" - - input_loader = input.make_loader() - scale_loader = scale.make_loader() - zero_point_loader = zero_point.make_loader() - - def inner_fn(idx): - input = input_loader(idx) - _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ()) - _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ()) - if scale.dtype != torch.float32: - _scale = ops.to_dtype(_scale, torch.float32) - if zero_point.dtype != torch.float32: - _zero_point = ops.to_dtype(_zero_point, torch.float32) - val = ops.sub(ops.to_dtype(input, torch.float32), _zero_point) * _scale - return val - - return Pointwise.create( - device=input.get_device(), - dtype=torch.float32, - inner_fn=inner_fn, - ranges=input.get_size(), - ) - - @register_lowering(aten.cat) def cat(inputs, dim=0): cpu_device = inputs[0].get_device().type == "cpu"