Revert "[Inductor] [Quant] Enable lowering of quant per tensor and re…
Browse files Browse the repository at this point in the history
…factor quant pattern (#124041)"

This reverts commit 33e6791.

Reverted #124041 on behalf of https://github.com/huydhn due to: "Sorry for reverting your change, but I think there is a land race with the change https://hud.pytorch.org/pytorch/pytorch/commit/33e6791645b5950b0f39301f55b8a4a79c0ca847" (comment on #124041)
pytorchmergebot committed May 9, 2024
1 parent ca579c1 commit ea3f625
Showing 7 changed files with 387 additions and 434 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/native/quantized/cpu/qconv.cpp
@@ -1671,7 +1671,7 @@ static at::Tensor _quantized_convolution_onednn(
}
tensor src_scales_t = tensor(ideep::scale_t(1, act_scale));
tensor wei_scales_t = tensor(weights_scales);
tensor dst_scales_t = tensor(ideep::scale_t(1, inv_output_scale));
tensor dst_scales_t = tensor(ideep::scale_t(1, 1.0/inv_output_scale));
tensor src_zp_t = tensor(ideep::zero_point_t(1, act_zero_point));
tensor dst_zp_t = tensor(ideep::zero_point_t(1, output_zero_point));
if (act_scale != 1.0f) {
@@ -1707,7 +1707,7 @@ static at::Tensor _quantized_convolution_onednn(
ideep::convolution_forward::prepare(
params, src, packed_weight, expected_bias, dst_dims, dst,
stride.vec(), dilation.vec(), padding.vec(), padding.vec(), groups,
src_scales, weights_scales, ideep::scale_t(1, 1.0f / inv_output_scale),
src_scales, weights_scales, ideep::scale_t(1, inv_output_scale),
src_zero_points, dst_zero_points,
op_attr, dnnl::algorithm::convolution_direct,
dnnl::prop_kind::forward_inference,
1 change: 1 addition & 0 deletions aten/src/ATen/native/quantized/cpu/qlinear.cpp
@@ -931,6 +931,7 @@ static at::Tensor linear_int8_with_onednn_weight(
c10::string_view& unary_post_op_algorithm) {
using ideep::tensor;
const int64_t dim = input.dim();
output_scale = 1.0f / output_scale;
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte,
"qlinear with mkldnn tensor: data type of input should be uint8 (unsigned char).");
TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char,
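
The two C++ hunks above restore the pre-#124041 convention in which the CPU kernels receive the reciprocal of the output scale (the `inv_output_scale` parameter), with qlinear.cpp re-adding the internal inversion `output_scale = 1.0f / output_scale`. A minimal Python sketch of that relationship follows; the names and values are illustrative only, not code from this commit.

    # Hedged sketch of the scale convention this revert restores (illustrative values).
    y_scale = 0.04            # real per-tensor output scale
    y_zero_point = 64

    # Callers of the reverted kernels pass the reciprocal of the output scale...
    inv_output_scale = 1.0 / y_scale

    # ...and qlinear.cpp immediately inverts it back, so later code sees y_scale again.
    output_scale = 1.0 / inv_output_scale
    assert abs(output_scale - y_scale) < 1e-12
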
121 changes: 65 additions & 56 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -555,15 +555,15 @@ def forward(self, x):

def matcher_check_fn():
# 1. Dequant-Conv2D pattern matched in QConv2D weight prepack * 1
# int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution]
# int8_mixed_bf16: [dequant_node, optional(convert_element_type_4),
# int8_mixed_fp32: [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# int8_mixed_bf16: [convert_element_type_1, sub, mul_1, optional(convert_element_type_4),
# dequantize_per_channel, optional(convert_element_type_3), clone, convolution]
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
)
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"],
12 if int8_mixed_bf16 else 8,
16 if int8_mixed_bf16 else 12,
)

self._test_common(
@@ -683,13 +683,14 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self):
r"""
This testcase will quantize Conv2d->Hardtanh pattern.
Match.nodes:
[qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor]
[qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type]
[qconv2d_pointwise_default_1, convert_element_type_5, clamp_min_1, clamp_max_1, mul_2, round_2, add_1, clamp_min_2,
clamp_max_2, convert_element_type_8]
[qconv2d_pointwise_default, convert_element_type_13, clamp_min_3, clamp_max_3]
"""
self._qconv2d_unary_cpu_test_helper(
unary_op=torch.nn.Hardtanh(),
int8_mixed_bf16=True,
qconv2d_unary_matcher_nodes=11,
qconv2d_unary_matcher_nodes=14,
)

@skipIfNoDynamoSupport
@@ -709,14 +710,14 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self):
r"""
This testcase will quantize Conv2d->Hardswish pattern.
Match.nodes:
[qconv2d_pointwise_default, convert_element_type, add, clamp_min,
clamp_max, mul, div, convert_element_type, quantize_per_tensor]
[qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type]
[qconv2d_pointwise_default_1, convert_element_type_5, add_1, clamp_min_1,
clamp_max_1, mul_2, div, mul_3, round_2, add_2, clamp_min_2, clamp_max_2, convert_element_type_8]
[qconv2d_pointwise_default, convert_element_type_13, add_3, clamp_min_3, clamp_max_3, mul_5, div_1]
"""
self._qconv2d_unary_cpu_test_helper(
unary_op=torch.nn.Hardswish(),
int8_mixed_bf16=True,
qconv2d_unary_matcher_nodes=17,
qconv2d_unary_matcher_nodes=20,
)

@skipIfNoDynamoSupport
Expand All @@ -736,14 +737,14 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self):
r"""
This testcase will quantize Conv2d->SiLU pattern.
Match.nodes:
[qconv2d_pointwise_default, convert_element_type, sigmoid, mul,
convert_element_type, quantize_per_tensor]
[qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type]
[qconv2d_pointwise_default_1, convert_element_type_5, sigmoid, mul_2,
mul_3, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_8]
[qconv2d_pointwise_default, convert_element_type_13, sigmoid_1, mul_5]
"""
self._qconv2d_unary_cpu_test_helper(
unary_op=torch.nn.SiLU(),
int8_mixed_bf16=True,
qconv2d_unary_matcher_nodes=11,
qconv2d_unary_matcher_nodes=14,
)

def _qconv2d_add_cpu_test_helper(self, use_relu=False, int8_mixed_bf16=False):
@@ -1027,17 +1028,17 @@ def forward(self, x):

def matcher_check_fn():
# 1. Dequant-conv pattern matched in quantization weight prepack * 1
# [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1
)
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 4
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 6
)
# 2. QConv2D Unary fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default, quantize_per_tensor]
# [qconv2d_pointwise_default, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1)
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_nodes"], 2)
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_nodes"], 7)

self._test_common(
mod,
@@ -1106,6 +1107,7 @@ def test_qat_qconv2d_relu6(self):
r"""
This testcase will quantize Conv2d->ReLU6 pattern with qat flow.
"""

self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6())

@skipIfNoDynamoSupport
@@ -1115,6 +1117,7 @@ def test_qat_qconv2d_hardtanh(self):
r"""
This testcase will quantize Conv2d->Hardtanh pattern with qat flow.
"""

self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh())

@skipIfNoDynamoSupport
@@ -1124,6 +1127,7 @@ def test_qat_qconv2d_silu(self):
r"""
This testcase will quantize Conv2d->SiLU pattern with qat flow.
"""

self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.SiLU())

@skipIfNoDynamoSupport
@@ -1133,6 +1137,7 @@ def test_qat_qconv2d_hardswish(self):
r"""
This testcase will quantize Conv2d->Hardswish pattern with qat flow.
"""

self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardswish())

@skipIfNoDynamoSupport
@@ -1171,17 +1176,18 @@ def forward(self, x):

def matcher_check_fn():
# 1. Dequant-conv pattern matched in quantization weight prepack * 2
# [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
)
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12
)
# 2. Qconv2d Binary fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, quantize_per_tensor]
# [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, mul_6, round_4, add_4,
# clamp_min_3, clamp_max_3, convert_element_type_6]
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1)
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 4)
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 11)

self._test_common(
mod,
@@ -1230,17 +1236,18 @@ def forward(self, x):

def matcher_check_fn():
# 1. Dequant-conv pattern matched in quantization weight prepack * 2
# [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
)
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12
)
# 2. Qconv2d Binary fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, relu, quantize_per_tensor]
# [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, relu, mul_6, round_4, add_4,
# clamp_min_3, clamp_max_3, convert_element_type_6]
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1)
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 5)
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 12)

self._test_common(
mod,
@@ -1287,16 +1294,16 @@ def forward(self, x):

def matcher_check_fn():
# 1. Dequant pattern matcher for dequant promotion * 1
# [dequantize_per_tensor]
# [convert_element_type_3, sub_1, mul_3]
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_nodes"], 1)
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_nodes"], 3)
# 2. Dequant-conv pattern matched in quantization weight prepack * 3
# [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 3
)
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 18
)
# 3. Qconv2d Binary fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default_1, add_3]
@@ -1438,7 +1445,7 @@ def matcher_check_fn():
)
self.assertEqual(
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
13 if bias else 12,
17 if bias else 16,
)

self._qlinear_cpu_test_helper(
@@ -1466,7 +1473,7 @@ def matcher_check_fn():
)
self.assertEqual(
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
17 if bias else 16,
21 if bias else 20,
)

self._qlinear_cpu_test_helper(
@@ -1715,16 +1722,12 @@ def forward(self, x1, x2):
x1 = torch.randn((2, 4))
x2 = torch.randn((2, 5))

def matcher_check_fn():
self.assertEqual(
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
)

self._test_common(
mod,
(x1, x2),
2,
8,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)

@skipIfNoDynamoSupport
@@ -1760,19 +1763,22 @@ def forward(self, x):
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
1
)

def matcher_check_fn():
self.assertEqual(counters["inductor"]["qmaxpool2d_matcher_count"], 1)
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1
)
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1)

# Totally 6 pattern_matcher_count, 31 pattern_matcher_nodes
# 1. Pair of to_int8 and to_fp32 * 3, matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# 2. Dequant-conv pattern matched in quantization weight prepack * 1
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# 3. qconv2d_relu fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default, relu, mul_2, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
# 4. qmaxpool2d * 1
# [convert_element_type_3, sub_1, mul_3, max_pool2d_with_indices, getitem, mul_4, round_3, add_2,
# clamp_min_2, clamp_max_2, convert_element_type_4]
self._test_common(
mod,
(v,),
6,
31,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)

@skipIfNoDynamoSupport
@@ -1846,19 +1852,22 @@ def forward(self, x):

mod = M().eval()
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)

def matcher_check_fn():
self.assertEqual(counters["inductor"]["qcat_matcher_count"], 1)
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
)
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 2)

# Totally 10 pattern_matcher_count, 49 pattern_matcher_nodes
# 1. Pair of to_int8 and to_fp32 * 5, matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# 2. Dequant-conv pattern matched in quantization weight prepack * 2
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# 3. qconv2d fusion in post-grad fusion pass * 2
# [qconv2d_pointwise_default, mul_2, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
# 4. qcat * 1
# [convert_element_type_3, sub_1, mul_3, convert_element_type_7, sub_3, mul_7, cat, mul_8, round_5,
# add_4, clamp_min_4, clamp_max_4, convert_element_type_8]
self._test_common(
mod,
(v,),
10,
49,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)

# https://github.com/pytorch/pytorch/issues/99841.
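
The updated matcher-node counts in the test changes above follow from the decomposed form that per-tensor quantize/dequantize take in the traced graph after this revert, per the node lists in the comments: roughly [convert_element_type, sub, mul] for dequantize and [mul, round, add, clamp_min, clamp_max, convert_element_type] for quantize. A hedged Python sketch of what those node sequences compute (illustrative only, assuming uint8 quantization):

    import torch

    def dequant_per_tensor_decomposed(q, scale, zero_point):
        # Mirrors the [convert_element_type, sub, mul] sequence in the comments above.
        x = q.to(torch.float32)      # convert_element_type
        x = x - zero_point           # sub
        return x * scale             # mul

    def quant_per_tensor_decomposed(x, scale, zero_point):
        # Mirrors [mul, round, add, clamp_min, clamp_max, convert_element_type].
        q = x * (1.0 / scale)        # mul
        q = torch.round(q)           # round
        q = q + zero_point           # add
        q = torch.clamp(q, 0, 255)   # clamp_min / clamp_max (uint8 range assumed)
        return q.to(torch.uint8)     # convert_element_type
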
10 changes: 5 additions & 5 deletions test/quantization/core/test_quantized_op.py
@@ -4309,7 +4309,7 @@ def _test_qlinear_pt2e_helper(
if post_op in ("none", "relu", "gelu"):
qy_cpu = qlinear_op(
qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
b, used_y_scale, used_y_zp, output_dtype,
b, 1.0 / used_y_scale, used_y_zp, output_dtype,
post_op, unary_post_op_args, post_op_algo
)
if post_op == "relu":
@@ -4330,7 +4330,7 @@ def _test_qlinear_pt2e_helper(
accum = accum.bfloat16()
qy_cpu = qlinear_op(
qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
b, used_y_scale, used_y_zp, output_dtype,
b, 1.0 / used_y_scale, used_y_zp, output_dtype,
accum, x2_scale, x2_zp, "sum", binary_alpha,
unary_post_op, unary_post_op_args, post_op_algo
)
@@ -4348,7 +4348,7 @@ def _test_qlinear_pt2e_helper(
binary_alpha = 1.0 # we only support alpha=1.0 now
qy_cpu = qlinear_op(
qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
b, used_y_scale, used_y_zp, output_dtype,
b, 1.0 / used_y_scale, used_y_zp, output_dtype,
x2, 1.0, 0, "add", binary_alpha,
unary_post_op, unary_post_op_args, post_op_algo
)
@@ -6796,7 +6796,7 @@ def _test_qconv_impl_cpu_tensor(
pads,
dilations,
groups,
Y_scale,
1.0 / Y_scale, # Kernel expects pass in reciprocal of scale in fake quant
Y_zero_point,
qconv_output_dtype,
post_op.binary_attr,
@@ -6818,7 +6818,7 @@ def _test_qconv_impl_cpu_tensor(
pads,
dilations,
groups,
Y_scale,
1.0 / Y_scale, # Kernel expects pass in reciprocal of scale in fake quant
Y_zero_point,
qconv_output_dtype,
post_op.unary_attr,
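
For context on the `1.0 / Y_scale` and `1.0 / used_y_scale` arguments above: after this revert the op-level tests hand the kernels the reciprocal of the output scale, while the reference path works with the scale itself. Below is a hedged, self-contained sketch of such a reference computation; it is not the exact helper in this file, just a plausible fp32 reference under standard per-tensor affine quantization.

    import torch

    def ref_quantized_linear(qx, x_scale, x_zp, qw, w_scale, w_zp, bias, y_scale, y_zp):
        # Dequantize inputs, run fp32 linear, then requantize with the real output scale.
        x = (qx.to(torch.float32) - x_zp) * x_scale
        w = (qw.to(torch.float32) - w_zp) * w_scale
        y = torch.nn.functional.linear(x, w, bias)
        return torch.clamp(torch.round(y / y_scale) + y_zp, 0, 255).to(torch.uint8)

    # The kernel under test would instead be given 1.0 / y_scale, per the
    # "Kernel expects pass in reciprocal of scale in fake quant" comments above.
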
