[Inductor] [Quant] Enable lowering of quant per tensor and refactor quant pattern (#124041)

**Summary**
Per the discussion in #123444, the `decomposed quant/dequant` patterns changed after #123445. We can move the optimization of `decomposed quant/dequant` from the Inductor decomposition phase into the lowering phase to avoid depending on those changes. In this way, we can:

- Avoid the pattern matcher failure introduced in #123445
- Make the quantization pattern clearer in the pattern-matcher phase, since the `quant/dequant` nodes have not yet been decomposed (see the reference sketch below).
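
For reference, here is a minimal, semantics-only sketch of the `quantized_decomposed` per-tensor ops whose lowering this PR moves (illustrative, not the actual PyTorch source):

```python
import torch

# Semantics-only sketch of quantized_decomposed.quantize_per_tensor /
# dequantize_per_tensor (illustrative, not the PyTorch implementation).
def quantize_per_tensor(x, scale, zero_point, quant_min, quant_max, dtype=torch.uint8):
    # q = clamp(round(x / scale) + zero_point, quant_min, quant_max)
    q = torch.round(x * (1.0 / scale)) + zero_point
    return torch.clamp(q, quant_min, quant_max).to(dtype)

def dequantize_per_tensor(q, scale, zero_point, quant_min, quant_max, dtype=torch.float32):
    # x_fp = (q - zero_point) * scale
    return (q.to(dtype) - zero_point) * scale
```

Previously the Inductor decomposition expanded these ops into the convert/sub/mul and mul/round/add/clamp/convert sequences that the old matcher node lists in the test diffs below reflect; keeping them as single nodes until lowering is what makes the matched patterns compact.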

**Changes in this PR**

- Move the optimization of `decomposed quant/dequant` from the Inductor decomposition phase into the lowering phase.
- Make corresponding changes in the quantization pattern matcher to avoid BC-breaking changes.

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```
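
For context, these pattern-matcher tests exercise roughly the following PT2E quantization + `torch.compile` flow. This is a hedged sketch, not code from this PR; the API names are the public entry points around the PyTorch 2.3/2.4 timeframe, and the model, shapes, and freezing toggle are illustrative assumptions:

```python
import torch
import torch._inductor.config as inductor_config
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3)

    def forward(self, x):
        return torch.relu(self.conv(x))


x = torch.randn(1, 3, 8, 8)
m = M().eval()

# Capture the graph, insert observers, calibrate, and convert to a
# reference-quantized graph containing quantize_per_tensor /
# dequantize_per_tensor nodes.
exported = capture_pre_autograd_graph(m, (x,))
quantizer = xiq.X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
prepared = prepare_pt2e(exported, quantizer)
prepared(x)  # calibration
converted = convert_pt2e(prepared)

# Compile with Inductor; with freezing enabled, the pattern matcher rewrites
# dequant -> conv -> relu -> quant into the fused qconv ops these tests count.
inductor_config.freezing = True
with torch.no_grad():
    compiled = torch.compile(converted)
    compiled(x)
```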

Pull Request resolved: #124041
Approved by: https://github.com/peterbell10, https://github.com/jgong5
leslie-fang-intel authored and pytorchmergebot committed May 9, 2024
1 parent 96c8447 commit d83ab88
Showing 7 changed files with 435 additions and 388 deletions.
aten/src/ATen/native/quantized/cpu/qconv.cpp (4 changes: 2 additions & 2 deletions)
@@ -1671,7 +1671,7 @@ static at::Tensor _quantized_convolution_onednn(
}
tensor src_scales_t = tensor(ideep::scale_t(1, act_scale));
tensor wei_scales_t = tensor(weights_scales);
- tensor dst_scales_t = tensor(ideep::scale_t(1, 1.0/inv_output_scale));
+ tensor dst_scales_t = tensor(ideep::scale_t(1, inv_output_scale));
tensor src_zp_t = tensor(ideep::zero_point_t(1, act_zero_point));
tensor dst_zp_t = tensor(ideep::zero_point_t(1, output_zero_point));
if (act_scale != 1.0f) {
@@ -1707,7 +1707,7 @@ static at::Tensor _quantized_convolution_onednn(
ideep::convolution_forward::prepare(
params, src, packed_weight, expected_bias, dst_dims, dst,
stride.vec(), dilation.vec(), padding.vec(), padding.vec(), groups,
- src_scales, weights_scales, ideep::scale_t(1, inv_output_scale),
+ src_scales, weights_scales, ideep::scale_t(1, 1.0f / inv_output_scale),
src_zero_points, dst_zero_points,
op_attr, dnnl::algorithm::convolution_direct,
dnnl::prop_kind::forward_inference,
aten/src/ATen/native/quantized/cpu/qlinear.cpp (1 change: 0 additions & 1 deletion)
@@ -931,7 +931,6 @@ static at::Tensor linear_int8_with_onednn_weight(
c10::string_view& unary_post_op_algorithm) {
using ideep::tensor;
const int64_t dim = input.dim();
- output_scale = 1.0f / output_scale;
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte,
"qlinear with mkldnn tensor: data type of input should be uint8 (unsigned char).");
TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char,
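The two ATen changes above flip the convention for the output-scale argument: callers of the quantized conv/linear ops now pass the actual output scale rather than its reciprocal, and the kernel computes `1.0f / scale` internally only where oneDNN expects an inverse scale. A tiny worked example of the requantization math, using hypothetical values, shows that the result is the same no matter where the reciprocal is taken:

```python
# Hypothetical values, for illustration only.
y_scale, y_zero_point = 0.05, 128
acc_fp32 = 1.3  # fp32 accumulator produced by the conv/linear

# Requantize: q = clamp(round(acc / y_scale) + zero_point, 0, 255)
q_new = min(max(round(acc_fp32 / y_scale) + y_zero_point, 0), 255)          # new: pass y_scale
q_old = min(max(round(acc_fp32 * (1.0 / y_scale)) + y_zero_point, 0), 255)  # old: caller passed 1 / y_scale
assert q_new == q_old == 154
```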
test/inductor/test_mkldnn_pattern_matcher.py (121 changes: 56 additions & 65 deletions)
@@ -555,15 +555,15 @@ def forward(self, x):

def matcher_check_fn():
# 1. Dequant-Conv2D pattern matched in QConv2D weight prepack * 1
- # int8_mixed_fp32: [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
- # int8_mixed_bf16: [convert_element_type_1, sub, mul_1, optional(convert_element_type_4),
+ # int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution]
+ # int8_mixed_bf16: [dequant_node, optional(convert_element_type_4),
# dequantize_per_channel, optional(convert_element_type_3), clone, convolution]
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
)
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"],
- 16 if int8_mixed_bf16 else 12,
+ 12 if int8_mixed_bf16 else 8,
)

self._test_common(
@@ -683,14 +683,13 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self):
r"""
This testcase will quantize Conv2d->Hardtanh pattern.
Match.nodes:
- [qconv2d_pointwise_default_1, convert_element_type_5, clamp_min_1, clamp_max_1, mul_2, round_2, add_1, clamp_min_2,
- clamp_max_1, mul_2, round_2, add_1, clamp_min_2, clamp_max_2, convert_element_type_8
- [qconv2d_pointwise_default, convert_element_type_13, clamp_min_3, clamp_max_3]
+ [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor]
+ [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type]
"""
self._qconv2d_unary_cpu_test_helper(
unary_op=torch.nn.Hardtanh(),
int8_mixed_bf16=True,
- qconv2d_unary_matcher_nodes=14,
+ qconv2d_unary_matcher_nodes=11,
)

@skipIfNoDynamoSupport
@@ -710,14 +709,14 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self):
r"""
This testcase will quantize Conv2d->Hardswish pattern.
Match.nodes:
- [qconv2d_pointwise_default_1, convert_element_type_5, add_1, clamp_min_1,
- clamp_max_1, mul_2, div, mul_3, round_2, add_2, clamp_min_2, clamp_max_2, convert_element_type_8]
- [qconv2d_pointwise_default, convert_element_type_13, add_3, clamp_min_3, clamp_max_3, mul_5, div_1]
+ [qconv2d_pointwise_default, convert_element_type, add, clamp_min,
+ clamp_max, mul, div, convert_element_type, quantize_per_tensor]
+ [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type]
"""
self._qconv2d_unary_cpu_test_helper(
unary_op=torch.nn.Hardswish(),
int8_mixed_bf16=True,
- qconv2d_unary_matcher_nodes=20,
+ qconv2d_unary_matcher_nodes=17,
)

@skipIfNoDynamoSupport
@@ -737,14 +736,14 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self):
r"""
This testcase will quantize Conv2d->SiLU pattern.
Match.nodes:
- [qconv2d_pointwise_default_1, convert_element_type_5, sigmoid, mul_2,
- mul_3, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_8]
- [qconv2d_pointwise_default, convert_element_type_13, sigmoid_1, mul_5]
+ [qconv2d_pointwise_default, convert_element_type, sigmoid, mul,
+ convert_element_type, quantize_per_tensor]
+ [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type]
"""
self._qconv2d_unary_cpu_test_helper(
unary_op=torch.nn.SiLU(),
int8_mixed_bf16=True,
- qconv2d_unary_matcher_nodes=14,
+ qconv2d_unary_matcher_nodes=11,
)

def _qconv2d_add_cpu_test_helper(self, use_relu=False, int8_mixed_bf16=False):
@@ -1028,17 +1027,17 @@ def forward(self, x):

def matcher_check_fn():
# 1. Dequant-conv pattern matched in quantization weight prepack * 1
- # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
+ # [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1
)
self.assertEqual(
- counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 6
+ counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 4
)
# 2. QConv2D Unary fusion in post-grad fusion pass * 1
- # [qconv2d_pointwise_default, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
+ # [qconv2d_pointwise_default, quantize_per_tensor]
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1)
- self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_nodes"], 7)
+ self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_nodes"], 2)

self._test_common(
mod,
@@ -1107,7 +1106,6 @@ def test_qat_qconv2d_relu6(self):
r"""
This testcase will quantize Conv2d->ReLU6 pattern with qat flow.
"""

self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6())

@skipIfNoDynamoSupport
@@ -1117,7 +1115,6 @@ def test_qat_qconv2d_hardtanh(self):
r"""
This testcase will quantize Conv2d->Hardtanh pattern with qat flow.
"""

self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh())

@skipIfNoDynamoSupport
@@ -1127,7 +1124,6 @@ def test_qat_qconv2d_silu(self):
r"""
This testcase will quantize Conv2d->SiLU pattern with qat flow.
"""

self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.SiLU())

@skipIfNoDynamoSupport
@@ -1137,7 +1133,6 @@ def test_qat_qconv2d_hardswish(self):
r"""
This testcase will quantize Conv2d->Hardswish pattern with qat flow.
"""

self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardswish())

@skipIfNoDynamoSupport
@@ -1176,18 +1171,17 @@ def forward(self, x):

def matcher_check_fn():
# 1. Dequant-conv pattern matched in quantization weight prepack * 2
- # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
+ # [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
)
self.assertEqual(
- counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12
+ counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8
)
# 2. Qconv2d Binary fusion in post-grad fusion pass * 1
- # [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, mul_6, round_4, add_4,
- # clamp_min_3, clamp_max_3, convert_element_type_6]
+ # [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, quantize_per_tensor]
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1)
- self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 11)
+ self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 4)

self._test_common(
mod,
@@ -1236,18 +1230,17 @@ def forward(self, x):

def matcher_check_fn():
# 1. Dequant-conv pattern matched in quantization weight prepack * 2
- # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
+ # [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
)
self.assertEqual(
- counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12
+ counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8
)
# 2. Qconv2d Binary fusion in post-grad fusion pass * 1
- # [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, relu, mul_6, round_4, add_4,
- # clamp_min_3, clamp_max_3, convert_element_type_6]
+ # [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, relu, quantize_per_tensor]
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1)
- self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 12)
+ self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 5)

self._test_common(
mod,
@@ -1294,16 +1287,16 @@ def forward(self, x):

def matcher_check_fn():
# 1. Dequant pattern matcher for dequant promotion * 1
- # [convert_element_type_3, sub_1, mul_3]
+ # [dequantize_per_tensor]
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
- self.assertEqual(counters["inductor"]["dequant_promotion_matcher_nodes"], 3)
+ self.assertEqual(counters["inductor"]["dequant_promotion_matcher_nodes"], 1)
# 2. Dequant-conv pattern matched in quantization weight prepack * 3
- # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
+ # [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 3
)
self.assertEqual(
- counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 18
+ counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12
)
# 3. Qconv2d Binary fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default_1, add_3]
@@ -1445,7 +1438,7 @@ def matcher_check_fn():
)
self.assertEqual(
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
- 17 if bias else 16,
+ 13 if bias else 12,
)

self._qlinear_cpu_test_helper(
@@ -1473,7 +1466,7 @@ def matcher_check_fn():
)
self.assertEqual(
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
- 21 if bias else 20,
+ 17 if bias else 16,
)

self._qlinear_cpu_test_helper(
@@ -1722,12 +1715,16 @@ def forward(self, x1, x2):
x1 = torch.randn((2, 4))
x2 = torch.randn((2, 5))

+ def matcher_check_fn():
+ self.assertEqual(
+ counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
+ )

self._test_common(
mod,
(x1, x2),
- 2,
- 8,
check_quantization=True,
+ matcher_check_fn=matcher_check_fn,
)

@skipIfNoDynamoSupport
@@ -1763,22 +1760,19 @@ def forward(self, x):
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
1
)
- # Totally 6 pattern_matcher_count, 31 pattern_matcher_nodes
- # 1. Pair of to_int8 and to_fp32 * 3, matched in pointless_convert pass at
- # torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
- # 2. Dequant-conv pattern matched in quantization weight prepack * 1
- # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
- # 3. qconv2d_relu fusion in post-grad fusion pass * 1
- # [qconv2d_pointwise_default, relu, mul_2, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
- # 4. qmaxpool2d * 1
- # [convert_element_type_3, sub_1, mul_3, max_pool2d_with_indices, getitem, mul_4, round_3, add_2,
- # clamp_min_2, clamp_max_2, convert_element_type_4]

+ def matcher_check_fn():
+ self.assertEqual(counters["inductor"]["qmaxpool2d_matcher_count"], 1)
+ self.assertEqual(
+ counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1
+ )
+ self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1)

self._test_common(
mod,
(v,),
- 6,
- 31,
check_quantization=True,
+ matcher_check_fn=matcher_check_fn,
)

@skipIfNoDynamoSupport
@@ -1852,22 +1846,19 @@ def forward(self, x):

mod = M().eval()
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)
- # Totally 10 pattern_matcher_count, 49 pattern_matcher_nodes
- # 1. Pair of to_int8 and to_fp32 * 5, matched in pointless_convert pass at
- # torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
- # 2. Dequant-conv pattern matched in quantization weight prepack * 2
- # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
- # 3. qconv2d fusion in post-grad fusion pass * 2
- # [qconv2d_pointwise_default, mul_2, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
- # 4. qcat * 1
- # [convert_element_type_3, sub_1, mul_3, convert_element_type_7, sub_3, mul_7, cat, mul_8, round_5,
- # add_4, clamp_min_4, clamp_max_4, convert_element_type_8]

+ def matcher_check_fn():
+ self.assertEqual(counters["inductor"]["qcat_matcher_count"], 1)
+ self.assertEqual(
+ counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
+ )
+ self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 2)

self._test_common(
mod,
(v,),
- 10,
- 49,
check_quantization=True,
+ matcher_check_fn=matcher_check_fn,
)

# https://github.com/pytorch/pytorch/issues/99841.
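The node-count reductions throughout the test file above all have the same cause: the pattern matcher now sees single `quantize_per_tensor` / `dequantize_per_tensor` nodes instead of their decomposed forms. Roughly, using the node names that appear in the old and new comments:

```python
# Old graph fragments the matchers had to cover (per the removed comments):
#   dequantize_per_tensor ~ [convert_element_type, sub, mul]          # 3 nodes
#   quantize_per_tensor   ~ [mul (or div), round, add, clamp_min,
#                            clamp_max, convert_element_type]         # ~6 nodes
# New graph fragments, decomposed only at lowering time:
#   dequantize_per_tensor ~ [dequantize_per_tensor]                   # 1 node
#   quantize_per_tensor   ~ [quantize_per_tensor]                     # 1 node
# Hence e.g. qconv2d_unary_matcher_nodes drops from 7 to 2 and
# qconv2d_weight_prepack_matcher_nodes from 6 to 4 in the assertions above.
```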
test/quantization/core/test_quantized_op.py (10 changes: 5 additions & 5 deletions)
@@ -4309,7 +4309,7 @@ def _test_qlinear_pt2e_helper(
if post_op in ("none", "relu", "gelu"):
qy_cpu = qlinear_op(
qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
- b, 1.0 / used_y_scale, used_y_zp, output_dtype,
+ b, used_y_scale, used_y_zp, output_dtype,
post_op, unary_post_op_args, post_op_algo
)
if post_op == "relu":
@@ -4330,7 +4330,7 @@ def _test_qlinear_pt2e_helper(
accum = accum.bfloat16()
qy_cpu = qlinear_op(
qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
- b, 1.0 / used_y_scale, used_y_zp, output_dtype,
+ b, used_y_scale, used_y_zp, output_dtype,
accum, x2_scale, x2_zp, "sum", binary_alpha,
unary_post_op, unary_post_op_args, post_op_algo
)
@@ -4348,7 +4348,7 @@ def _test_qlinear_pt2e_helper(
binary_alpha = 1.0 # we only support alpha=1.0 now
qy_cpu = qlinear_op(
qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
- b, 1.0 / used_y_scale, used_y_zp, output_dtype,
+ b, used_y_scale, used_y_zp, output_dtype,
x2, 1.0, 0, "add", binary_alpha,
unary_post_op, unary_post_op_args, post_op_algo
)
@@ -6796,7 +6796,7 @@ def _test_qconv_impl_cpu_tensor(
pads,
dilations,
groups,
- 1.0 / Y_scale,  # Kernel expects pass in reciprocal of scale in fake quant
+ Y_scale,
Y_zero_point,
qconv_output_dtype,
post_op.binary_attr,
@@ -6818,7 +6818,7 @@ def _test_qconv_impl_cpu_tensor(
pads,
dilations,
groups,
- 1.0 / Y_scale,  # Kernel expects pass in reciprocal of scale in fake quant
+ Y_scale,
Y_zero_point,
qconv_output_dtype,
post_op.unary_attr,
(Diffs for the remaining 3 changed files are not shown.)

