
Updated quantization pattern matching to work with optional CallFunction in pattern expressions #123444

Closed
wants to merge 1 commit

Conversation

vfdev-5
Collaborator

@vfdev-5 vfdev-5 commented Apr 5, 2024

Stack from ghstack (oldest at bottom):

Description

This PR adds a flag to _TargetArgsExpr to mark the ops (aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.div.Tensor) as optional during pattern matching. For example, suppose we want to match the following pattern:

CallFunction(
    aten.mul.Tensor,
    CallFunction(
        aten.sub.Tensor,
        CallFunction(
            prims.convert_element_type.default,
            KeywordArg('x'),
            KeywordArg('x_dq_dtype'),
        ),
        KeywordArg('x_zp'),
    ),
    KeywordArg('x_scale'),
)

then the following subgraphs would match:

convert_element_type_3 = torch.ops.prims.convert_element_type.default(convert_element_type_2, torch.float32)
sub_1 = torch.ops.aten.sub.Tensor(convert_element_type_3, 0);  convert_element_type_3 = None
mul_3 = torch.ops.aten.mul.Tensor(sub_1, 0.015795549377799034);  sub_1 = None

and

convert_element_type_3 = torch.ops.prims.convert_element_type.default(convert_element_type_2, torch.float32)
mul_3 = torch.ops.aten.mul.Tensor(convert_element_type_3, 0.015795549377799034);

and

convert_element_type_3 = torch.ops.prims.convert_element_type.default(convert_element_type_2, torch.float32)
sub_1 = torch.ops.aten.sub.Tensor(convert_element_type_3, 1.0);  convert_element_type_3 = None

and

convert_element_type_3 = torch.ops.prims.convert_element_type.default(convert_element_type_2, torch.float32)

etc

However, full control over the subgraph replacement can still be exercised in the pattern.extra_check function, as sketched below.
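To make the proposal concrete, here is a minimal sketch of how such a pattern could be declared. The `is_optional` flag is this PR's proposal, not an existing pattern_matcher feature (its exact spelling in the `_TargetArgsExpr` API is an assumption here); `extra_check` is the standard hook:

```
import torch
from torch._inductor.pattern_matcher import CallFunction, KeywordArg

aten = torch.ops.aten
prims = torch.ops.prims

dequant_pattern = CallFunction(
    aten.mul.Tensor,
    CallFunction(
        aten.sub.Tensor,
        CallFunction(
            prims.convert_element_type.default,
            KeywordArg("x"),
            KeywordArg("x_dq_dtype"),
        ),
        KeywordArg("x_zp"),
        is_optional=True,  # proposed flag: the sub node may be absent in the graph
    ),
    KeywordArg("x_scale"),
    is_optional=True,  # proposed flag: the mul node may be absent too
)

def extra_check(match):
    # Final veto over the replacement, e.g. reject matches where the
    # dequantized input was not captured.
    return "x" in match.kwargs
```

Matching then degrades gracefully: when an optional node is absent, the matcher tries its first argument directly against the current node, and extra_check keeps the final say on the replacement.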

Context

On CPU it can be beneficial to remove redundant ops like the following from the graph (see #123445):

sub = torch.ops.aten.sub.Tensor(node, 0);
mul = torch.ops.aten.mul.Tensor(sub, 1.0);
add = torch.ops.aten.add.Tensor(mul, 0);
div = torch.ops.aten.div.Tensor(add, 1.0);

but this leads to quantization pattern mismatches, because the patterns expect the now-missing mul, add, sub, and div nodes.
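For context, a standalone FX sketch of the kind of no-op elimination #123445 performs inside Inductor (illustrative only; the real pass must also account for dtype promotion, since e.g. multiplying an integer tensor by 1.0 is not a no-op):

```
import torch
from torch import fx

aten = torch.ops.aten

# Identity right-hand operands: x + 0, x - 0, x * 1, x / 1.
_IDENTITY_RHS = {
    aten.add.Tensor: 0,
    aten.sub.Tensor: 0,
    aten.mul.Tensor: 1,
    aten.div.Tensor: 1,
}

def remove_noop_binary_ops(gm: fx.GraphModule) -> fx.GraphModule:
    for node in list(gm.graph.nodes):
        if (
            node.op == "call_function"
            and node.target in _IDENTITY_RHS
            and len(node.args) == 2
            and node.args[1] == _IDENTITY_RHS[node.target]
        ):
            # Rewire consumers to the untouched input and drop the node.
            node.replace_all_uses_with(node.args[0])
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm
```

In Inductor this simplification happens before the post-grad passes run, which is exactly why the quantization patterns registered there stop matching.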


pytorch-bot bot commented Apr 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123444

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 2 Unrelated Failures

As of commit 86ddd88 with merge base 5b0ce8f:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.


if is_optional is None:
    # Resolve the default: only these elementwise binary ops are candidates
    # for implicit optionality.
    maybe_optional_ops = (aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.div.Tensor)
    if len(self.fns) == 1 and self.fns[0] in maybe_optional_ops:
Collaborator

I'm not sure it makes sense for all mul, add, sub, div nodes to be optional. It's also quite surprising that part of the pattern won't appear in the match context when we hit this path.

A few alternatives that may be worth considering:

  • Can we pass is_optional=True explicitly in the pattern definition?
  • Could we just add more pattern definitions for quantization? I see a lot of the quantization patterns already have multiple variants that are generated for different dtypes, e.g.
    def _may_generate_pattern_with_dtype_convert(

Collaborator

To expand on my reasoning: I was thinking about adding a pattern a.mul(b).add(c) -> fma(a, b, c). With this PR, I think that pattern would match literally everything, since both CallFunction nodes would be optional.
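For concreteness, such a pattern might look like the sketch below (`fma` is just the stand-in target from the comment above; the point is that both the aten.add.Tensor and aten.mul.Tensor nodes fall into maybe_optional_ops, so with implicit optionality only KeywordArg("a") would remain as the required part of the pattern):

```
import torch
from torch._inductor.pattern_matcher import CallFunction, KeywordArg

aten = torch.ops.aten

# a.mul(b).add(c): under this PR both CallFunction layers would default
# to optional, degenerating the pattern into matching nearly any node.
fma_pattern = CallFunction(
    aten.add.Tensor,
    CallFunction(aten.mul.Tensor, KeywordArg("a"), KeywordArg("b")),
    KeywordArg("c"),
)
```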

Collaborator Author

Can we pass is_optional=True explicitly in the pattern definition?

Yes, that would be possible as well. Discussing this some time ago with Mario (cc @lezcano), he suggested setting the default to None and resolving it to True for this list of ops.

Could we just add more pattern definitions for quantization?

Yes, that is also an option, but it would require a bit more code when generating the patterns.

):
    # If the pattern node is optional, skip it and try to match its first
    # argument directly against the current node.
    pattern = pattern.args[0]
    child_match = ctx.match(pattern, child_node)
Collaborator

IIUC, when we have pat1 + pat2, this will match pat1 as if it were pat1 + 0, but not pat2? What happens if pat2 is another CallFunction and not just an argument?

@@ -1145,7 +1145,7 @@ def matcher_check_fn():
         counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12
     )
     # 2. Qconv2d Binary fusion in post-grad fusion pass * 1
-    # [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, mul_6, round_4, add_4,
+    # [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, mul_6, round_4, add_4 (optional),
     # clamp_min_3, clamp_max_3, convert_element_type_6]
     self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1)
Collaborator

@peterbell10 peterbell10 Apr 5, 2024

A question for @jerryzh168. It seems like the current implementation of quantization is:

  1. dynamo captures torch.ops.quantized_decomposed.foo
  2. Inductor-specific decompositions decompose these into aten ops
  3. Inductor's pattern matcher matches those aten ops to recover the original meaning of quantized_decomposed.foo.

My question is why not keep quantized_decomposed.foo in the AOTDispatch graph? This would make pattern matching much simpler and ensure that joint graph passes don't mess up the pattern before it can be matched in the post-grad passes. If the concern is not wanting to write lowerings, then you should still be able to decompose any quantization ops that remain after the quantization pattern matcher is finished.
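For illustration, an eager-mode analogue of what step 2 produces for a dequantize (a sketch of the decomposed shape, not the actual decomposition code):

```
import torch

def dequantize_per_tensor_decomposed(x, scale, zero_point, dtype=torch.float32):
    # The aten-level chain (convert -> sub -> mul) that step 3 must then
    # pattern-match back into "a dequantize".
    x = torch.ops.prims.convert_element_type.default(x, dtype)
    x = torch.ops.aten.sub.Tensor(x, zero_point)
    return torch.ops.aten.mul.Tensor(x, scale)
```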

Collaborator

Or I suppose maybe you want all core aten ops for export's sake. In which case I think the best thing would be to have a set of "un-decompose" patterns that runs on the joint graph to undo such decompositions before any other passes run.


Contributor

@jerryzh168 jerryzh168 Apr 10, 2024

I'm not exactly sure; I think we probably want to preserve these ops in the AOTI path. Maybe they are decomposed for performance optimizations; @leslie-fang-intel or @Xia-Weiwen may know.

Collaborator

Hi @peterbell10 @jerryzh168, thanks for your comments. Also cc @jgong5 for this discussion.

  • Yes, currently we do the torch.ops.quantized_decomposed.quant/dequant decomposition during AOTAutograd, which is before the pattern matcher phase. I suppose we could move it to the Inductor lowering phase and remove the corresponding decomposition, but that may require some effort to ensure nothing is bc-breaking.
  • Another question: I think quantization is just one specific case that was broken by the changes in Fixed arange decomp for float dtype #123445. Other operators that decompose into add/mul may have similar vulnerabilities to this change. Can we move this pass after the pattern matcher to avoid breaking existing patterns?

Collaborator

Can we move this pass after pattern matcher to avoid the breaking of existing patterns?

This pass runs during the joint graph passes, which makes sense to me: the more we simplify the graph before partitioning, the better the partition we can potentially get. On the other hand, the quantization patterns are registered post-grad, which always runs after the joint graph phase.

One solution might be to move quantization passes into the joint graph phase, and move this pass to the end of the joint graph passes. That would allow us to match quantization patterns on the raw decomposed graph while still partitioning the simplified graph.

Does that sound acceptable?

Collaborator

@leslie-fang-intel leslie-fang-intel Apr 12, 2024

Thanks for the comment @peterbell10; it looks like we have two options now.

  • Option 1: Don't do the quant/dequant decomposition and leave the optimization to the lowering phase. This way we see these operators (un-decomposed) in the quantization pattern matcher phase. This option makes the quantization patterns clearer, but it may require some refactoring effort to ensure nothing is bc-breaking.
  • Option 2: As @peterbell10 mentioned above, move the quantization passes into the joint graph phase and move the remove_no_ops pass to the end of the joint graph passes. This should require less change effort than option 1.

@jgong5 any suggestions?

Collaborator

@peterbell10 peterbell10 Apr 23, 2024

@jgong5 ping
edit: nvm, I somehow missed the PR linked below in my notifications.

@vfdev-5 vfdev-5 marked this pull request as draft April 9, 2024 08:52
@vfdev-5 vfdev-5 closed this Apr 23, 2024
leslie-fang-intel added a commit that referenced this pull request Apr 28, 2024
…t per tensor and refactor quant pattern"


**Summary**
Per the discussion in #123444, the `decomposed quant/dequant` patterns changed after #123445. We can move the optimization of `decomposed quant/dequant` from the Inductor decomposition into the lowering phase to avoid those changes. In this way, we can:

- Avoid the pattern matcher failure introduced in #123445
- Make the quantization pattern clearer in the pattern matcher phase, since the `quant/dequant` nodes have not been decomposed.

**Changes in this PR**

- Move optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase.
- Corresponding changes in the quantization pattern matcher to ensure no bc-breaking.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```


cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
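A minimal eager-mode sketch of the optimization this commit message describes (illustrative only; the actual change emits Inductor IR during lowering rather than calling tensor ops): the sub/mul are emitted only when they are not identities, so the graph seen by the pattern matcher still contains the un-decomposed quant/dequant nodes.

```
import torch

def dequantize_per_tensor_lowering_sketch(x, scale, zero_point, dtype=torch.float32):
    x = x.to(dtype)
    if zero_point != 0:
        # Only emit the subtraction when the zero point is non-trivial.
        x = x - zero_point
    if scale != 1:
        # Only emit the multiplication when the scale is non-trivial.
        x = x * scale
    return x
```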
pytorchmergebot pushed a commit that referenced this pull request May 9, 2024
…uant pattern (#124041)

**Summary**
Per the discussion in #123444, the `decomposed quant/dequant` patterns changed after #123445. We can move the optimization of `decomposed quant/dequant` from the Inductor decomposition into the lowering phase to avoid those changes. In this way, we can:

- Avoid the pattern matcher failure introduced in #123445
- Make the quantization pattern clearer in the pattern matcher phase, since the `quant/dequant` nodes have not been decomposed.

**Changes in this PR**

- Move optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase.
- Corresponding changes in the quantization pattern matcher to ensure no bc-breaking.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```

Pull Request resolved: #124041
Approved by: https://github.com/peterbell10, https://github.com/jgong5
@github-actions github-actions bot deleted the gh/vfdev-5/17/head branch May 30, 2024 02:01