[Quant] [PT2] Remove the annotation of conv linear output in x86InductorQuantizer

ghstack-source-id: 3e7adf22fc79dbbbd1c68486d31df18e940cfb3b
Pull Request resolved: #112140
leslie-fang-intel committed Oct 27, 2023
1 parent e382e2f commit 0929ae1
Showing 4 changed files with 222 additions and 140 deletions.
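For context, a minimal, hedged sketch of the PT2E export quantization flow that X86InductorQuantizer participates in (the SimpleConv module and example_inputs below are illustrative, not taken from this PR; API names assume the PyTorch 2.1-era torch.ao.quantization.quantize_pt2e workflow). Because the quantizer no longer annotates conv/linear outputs, convert_pt2e emits fewer quantize/dequantize pairs around those ops, which is why the pattern-matcher counts in the test diff below go down.

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class SimpleConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3)

    def forward(self, x):
        return self.conv(x)


example_inputs = (torch.randn(1, 3, 8, 8),)
exported = capture_pre_autograd_graph(SimpleConv().eval(), example_inputs)

quantizer = xiq.X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())

prepared = prepare_pt2e(exported, quantizer)   # observers follow the quantizer's annotations
prepared(*example_inputs)                      # calibration
converted = convert_pt2e(prepared)             # reference graph with quantize/dequantize nodes

# Lowering through Inductor fuses the dequant -> conv/linear chains counted in the tests below.
compiled = torch.compile(converted)
compiled(*example_inputs)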
95 changes: 38 additions & 57 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -434,22 +434,17 @@ def forward(self, x):
mod = M().eval()
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)

- # Totally pattern_matcher_count 4,
- # pattern_matcher_nodes 17
+ # Totally pattern_matcher_count 2,
+ # pattern_matcher_nodes 8
# 1. pair of to_int8 and to_fp32 at conv input matched in pointless_convert pass
# at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# 2. dequant-conv pattern matched in quantization weight prepack
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
- # 3. pair of to_int8 and to_fp32 at conv output matched in pointless_convert pass
- # at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type_2, convert_element_type_3]
- # 4. Quantization fusion in post-grad fusion pass
- # [qconv2d_pointwise_default, div_1, round_2, add_1,
- # clamp_min_1, clamp_max_1, convert_element_type_2]
self._test_common(
mod,
(v,),
- 4,
- 17,
+ 2,
+ 8,
check_quantization=True,
)
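The node lists in the comment above (convert_element_type_1, sub, mul_1 feeding convolution) describe the reference dequantization arithmetic that the weight-prepack pass rewrites. A hedged, self-contained sketch of that computation with made-up values:

import torch

x_int8 = torch.randint(0, 255, (1, 3, 8, 8), dtype=torch.uint8)  # illustrative quantized activation
scale, zero_point = 0.02, 128                                     # illustrative qparams
weight = torch.randn(16, 3, 3, 3)

# convert_element_type_1, sub, mul_1: cast to fp32, subtract the zero point, multiply by the scale.
x_fp32 = (x_int8.to(torch.float32) - zero_point) * scale
# The dequantized activation then feeds the fp32 convolution; Inductor's weight-prepack
# pass matches this dequant -> convolution chain and lowers it to qconv2d_pointwise.
y = torch.nn.functional.conv2d(x_fp32, weight)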

@@ -476,22 +471,19 @@ def forward(self, x):
mod = M().eval()
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)

- # Totally pattern_matcher_count 4,
- # pattern_matcher_nodes 18
+ # Totally pattern_matcher_count 3,
+ # pattern_matcher_nodes 10
# 1. pair of to_int8 and to_fp32 at conv input matched in pointless_convert pass
# at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# 2. dequant-conv pattern matched in quantization weight prepack
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
- # 3. pair of to_int8 and to_fp32 at conv output matched in pointless_convert pass
- # at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type_2, convert_element_type_3]
- # 4. Quantization fusion in post-grad fusion pass
- # [qconv2d_pointwise_default, relu, div_1, round_2, add_1,
- # clamp_min_1, clamp_max_1, convert_element_type_2]
+ # 3. Quantization fusion in post-grad fusion pass
+ # [qconv2d_pointwise_default, relu]
self._test_common(
mod,
(v,),
- 4,
- 18,
+ 3,
+ 10,
check_quantization=True,
)

@@ -531,8 +523,8 @@ def forward(self, x):
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
1
)
- # Totally 8 pattern_matcher_count, 39 pattern_matcher_nodes
- # 1. Pair of to_int8 and to_fp32 at conv input * 1, extra input of add * 1, and graph output * 1
+ # Totally 6 pattern_matcher_count, 26 pattern_matcher_nodes
+ # 1. Pair of to_int8 and to_fp32 at conv input * 1
# matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# NB: since quant workflow now duplicates DQ node, for each user, we wont necessarily see
@@ -549,13 +541,12 @@ def forward(self, x):
# 4. Quantization fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
# 5. Qconv2d_add * 1
- # [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3,
- # mul_6, round_4, add_4, clamp_min_3, clamp_max_3, convert_element_type_6]
+ # [qconv2d_pointwise_default_1, add_2]
self._test_common(
mod,
(v,),
- 7,
- 37,
+ 6,
+ 26,
check_quantization=True,
)
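The Qconv2d_add entry above refers to a convolution whose output is summed with an extra quantized input. The actual test module is defined elsewhere in this file; a purely illustrative module of that shape (an assumption, not the PR's code):

import torch


class ConvAdd(torch.nn.Module):
    """Illustrative only: a conv output added to an extra input, the structure
    targeted by the qconv2d binary (add) fusion."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x) + x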

@@ -598,8 +589,8 @@ def forward(self, x):
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
1
)
- # Totally 8 pattern_matcher_count, 40 pattern_matcher_nodes
- # 1. Pair of to_int8 and to_fp32 at conv input * 1, extra input of add * 1, and graph output * 1
+ # Totally 6 pattern_matcher_count, 27 pattern_matcher_nodes
+ # 1. Pair of to_int8 and to_fp32 at conv input * 1, extra input of add * 1
# matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# NB: since quant workflow now duplicates DQ node, for each user, we wont necessarily see
@@ -616,13 +607,12 @@ def forward(self, x):
# 4. Quantization fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
# 5. Qconv2d_add * 1
- # [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, relu,
- # mul_6, round_4, add_4, clamp_min_3, clamp_max_3, convert_element_type_6]
+ # [qconv2d_pointwise_default_1, add_3, relu]
self._test_common(
mod,
(v,),
- 7,
- 38,
+ 6,
+ 27,
check_quantization=True,
)

@@ -661,8 +651,8 @@ def forward(self, x):
mod = M().eval()
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)

- # Totally 11 pattern_matcher_count, 54 pattern_matcher_nodes for conv
- # 1. Pair of to_int8 and to_fp32 at conv input * 2, extra input of add * 1, and graph output * 1
+ # Totally 9 pattern_matcher_count, 41 pattern_matcher_nodes for conv
+ # 1. Pair of to_int8 and to_fp32 at conv input * 2, extra input of add * 1
# matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# NB: since quant workflow now duplicates DQ node, for each user, we wont necessarily see
@@ -679,13 +669,12 @@ def forward(self, x):
# 4. Quantization fusion in post-grad fusion pass * 2
# [qconv2d_pointwise_default, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
# 5. Qconv2d_add * 1
- # [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, mul_6, round_4, add_4,
- # clamp_min_3, clamp_max_3, convert_element_type_6]
+ # [qconv2d_pointwise_default_1, add_3]
self._test_common(
mod,
(v,),
- 10,
- 52,
+ 9,
+ 41,
check_quantization=True,
)

Expand All @@ -710,21 +699,16 @@ def forward(self, x):
mod = M(bias).eval()
v = torch.randn((2, 4))

- # Totally pattern_matcher_count 4, pattern_matcher_nodes 17
+ # Totally pattern_matcher_count 2, pattern_matcher_nodes 8
# 1. pair of to_int8 and to_fp32 at input matched in pointless_convert pass
# at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# 2. dequant-linear pattern matched in quantization weight prepack
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, t, addmm/mm]
- # 3. pair of to_int8 and to_fp32 at output matched in pointless_convert pass
- # at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type_2, convert_element_type_3]
- # 4. Quantization fusion in post-grad fusion pass
- # [qlinear_pointwise_default, div_1, round_2, add_1,
- # clamp_min_1, clamp_max_1, convert_element_type_2]
self._test_common(
mod,
(v,),
- 4,
- 17,
+ 2,
+ 8,
check_quantization=True,
)
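For linear, the matched reference pattern is analogous: the int8 activation is dequantized (convert_element_type_1, sub, mul_1) and fed to addmm/mm through the transposed weight (t). A hedged sketch with made-up shapes:

import torch

x_int8 = torch.randint(0, 255, (2, 4), dtype=torch.uint8)  # illustrative quantized activation
scale, zero_point = 0.05, 128                               # illustrative qparams
weight = torch.randn(8, 4)                                  # out_features x in_features
bias = torch.randn(8)

# convert_element_type_1, sub, mul_1: dequantize the activation.
x_fp32 = (x_int8.to(torch.float32) - zero_point) * scale
# t + addmm: the linear computation that weight prepack folds into qlinear_pointwise.
y = torch.addmm(bias, x_fp32, weight.t())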

@@ -750,21 +734,18 @@ def forward(self, x):
mod = M(bias).eval()
v = torch.randn((2, 4))

- # Totally pattern_matcher_count 4, pattern_matcher_nodes 18
+ # Totally pattern_matcher_count 3, pattern_matcher_nodes 10
# 1. pair of to_int8 and to_fp32 at input matched in pointless_convert pass
# at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# 2. dequant-linear pattern matched in quantization weight prepack
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, t, addmm/mm]
- # 3. pair of to_int8 and to_fp32 at output matched in pointless_convert pass
- # at torch/_inductor/fx_passes/joint_graph.py: [convert_element_type_2, convert_element_type_3]
- # 4. Quantization fusion in post-grad fusion pass
- # [qlinear_pointwise_default, relu, div_1, round_2, add_1,
- # clamp_min_1, clamp_max_1, convert_element_type_2]
+ # 3. Quantization fusion in post-grad fusion pass
+ # [qlinear_pointwise_default, relu]
self._test_common(
mod,
(v,),
- 4,
- 18,
+ 3,
+ 10,
check_quantization=True,
)

@@ -803,8 +784,8 @@ def forward(self, x):
mod = M().eval()
v = torch.rand((2, 4))

- # Totally 11 pattern_matcher_count, 50 pattern_matcher_nodes for linear
- # 1. Pair of to_int8 and to_fp32 at linear input * 2, extra input of add * 1, and graph output * 1
+ # Totally 6 pattern_matcher_count, 30 pattern_matcher_nodes for linear
+ # 1. Pair of to_int8 and to_fp32 at linear input,
# matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# NB: since quant workflow now duplicates DQ node, for each user, we wont necessarily see
@@ -818,13 +799,13 @@ def forward(self, x):
# [convert_element_type_3, sub_1, mul_3]
# 3. Dequant-linear pattern matched in quantization weight prepack * 3
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, permute, addmm]
- # 4. Quantization fusion in post-grad fusion pass * 3
+ # 4. Quantization fusion in post-grad fusion pass * 1
# [qlinear_pointwise_default, mul_6, round_4, add_3, clamp_min_3, clamp_max_3, convert_element_type_6]
self._test_common(
mod,
(v,),
- 10,
- 48,
+ 6,
+ 30,
check_quantization=True,
)
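The two integer arguments passed to _test_common in these hunks are the expected pattern_matcher_count and pattern_matcher_nodes values. A hedged sketch of how those counters can be inspected outside the test harness, assuming they are the Inductor counters of the same names in torch._dynamo.utils.counters:

import torch
from torch._dynamo.utils import counters


def run_and_report(model, example_inputs):
    # Compile, run once, and print how many pattern-matcher replacements fired
    # and how many graph nodes they covered.
    counters.clear()
    compiled = torch.compile(model)
    compiled(*example_inputs)
    print(counters["inductor"]["pattern_matcher_count"])
    print(counters["inductor"]["pattern_matcher_nodes"])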
