[X86] Enable FP8 patterns for AOT Inductor #4099
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4099
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: There is 1 currently active SEV. If your PR is affected, please view it below.
❌ 4 New Failures: As of commit 8d20988 with merge base 2c41725, the following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @CaoE! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but the CLA is no longer valid and will need to be resubmitted.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations; afterwards, the pull request will be tagged with
If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Pull request overview
This PR extends the x86 Inductor fusion/pattern-matching logic to better handle AOT Inductor-produced FX graphs for FP8 QLinear, where extra _assert_tensor_metadata nodes and lifted scale values can appear.
Changes:
- Extend FP8/int8 dequant→linear prepack patterns to tolerate `_assert_tensor_metadata` users and register variants based on dequant user-count.
- Generalize extraction of constant (scalar) scales to also support lifted constants (`get_attr`) and other FX node forms.
- Add an AOT Inductor (AOTI) execution codepath to the x86 inductor fusion tests and run FP8 tests with/without AOTI when supported.
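The two generalizations above can be sketched in plain Python (no torch); `Node`, `real_users`, and `extract_scale` are illustrative names invented here, not the PR's actual helpers:

```python
# Plain-Python sketch (no torch) of the two generalizations above; Node,
# real_users, and extract_scale are illustrative names, not the PR's helpers.
from dataclasses import dataclass, field

@dataclass
class Node:
    op: str        # e.g. "call_function" or "get_attr"
    target: str
    users: list = field(default_factory=list)

def real_users(node):
    # Ignore the _assert_tensor_metadata users that AOTI graphs can insert.
    return [u for u in node.users
            if not (u.op == "call_function"
                    and u.target == "_assert_tensor_metadata")]

def extract_scale(scale, constants):
    # Accept both an inlined scalar and an AOTI-lifted get_attr constant.
    if isinstance(scale, float):
        return scale
    if scale.op == "get_attr":
        return constants[scale.target]
    raise ValueError("unsupported scale form")

dequant = Node("call_function", "dequantize_affine_float8")
dequant.users = [Node("call_function", "_assert_tensor_metadata"),
                 Node("call_function", "linear")]
assert [u.target for u in real_users(dequant)] == ["linear"]
assert extract_scale(0.5, {}) == 0.5
assert extract_scale(Node("get_attr", "_scale0"), {"_scale0": 0.25}) == 0.25
```

The real pass operates on `torch.fx` nodes; this only mirrors the matching logic the description summarizes.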
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `torchao/testing/pt2e/utils.py` | Avoids using `.data` when feeding the FP8 weight tensor into the dequant op in a test module. |
| `torchao/quantization/pt2e/inductor_passes/x86.py` | Updates pattern matching to support AOTI graph differences (assert-metadata users, lifted scale constants) and adds user-count pattern variants. |
| `test/quantization/pt2e/test_x86inductor_fusion.py` | Adds an AOTI compilation/execution path to common test helpers and extends FP8 test coverage to include AOTI runs. |
This PR depends on pytorch/pytorch#178120
    counters.clear()
    torch._dynamo.reset()

    def aoti_compile(model, inputs):
Support compile_options too?
    import torch._inductor.constant_folding as cf

    if hasattr(cf, "add_dont_constant_fold"):
        cf.add_dont_constant_fold(
            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default
        )
What does this part do? Can you add a comment to explain?
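For context on the question: the quoted snippet appears to register the FP8 dequant op as excluded from constant folding, but only when the running torch exposes `add_dont_constant_fold`, so older builds keep working. A minimal sketch of that `hasattr`-guard pattern, with the torch module simulated via `types.SimpleNamespace` (all names here are illustrative):

```python
# Sketch of the hasattr-guard pattern from the quoted diff, with the torch
# module simulated via types.SimpleNamespace (names are illustrative).
import types

def register_if_supported(module, op):
    # Only call add_dont_constant_fold when the API exists, so torch
    # versions that predate it keep working unchanged.
    if hasattr(module, "add_dont_constant_fold"):
        module.add_dont_constant_fold(op)
        return True
    return False

kept = []
newer = types.SimpleNamespace(add_dont_constant_fold=kept.append)
older = types.SimpleNamespace()  # simulated torch without the API

assert register_if_supported(newer, "dequant_op") is True
assert kept == ["dequant_op"]
assert register_if_supported(older, "dequant_op") is False
```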
    quantizer=None,
    compile_options={},  # noqa: B006
    is_fp8=False,
    is_aoti=False,
`use_aoti` may be a better name.
    for is_aoti in aoti_options:
        for bias in [True, False]:
Better to use itertools.product or parametrize for nested options.
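The suggested `itertools.product` rewrite could look like this (the option lists are illustrative, not the test's actual values):

```python
# The nested "for is_aoti / for bias" loops in the quoted diff can be
# flattened as suggested; the option lists here are illustrative.
from itertools import product

aoti_options = [True, False]
bias_options = [True, False]

cases = list(product(aoti_options, bias_options))
assert cases == [(True, True), (True, False), (False, True), (False, False)]

for is_aoti, bias in cases:
    pass  # run one test configuration per (is_aoti, bias) combination
```

With pytest, `@pytest.mark.parametrize` over the same pairs would additionally report each combination as a separate test case.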
        False,
    ]
    try:
        import torch._inductor.constant_folding as cf
    for is_aoti in aoti_options:
        for length in [2, 3]:
Better to use itertools.product or parametrize for nested options.
|
@claude review
Claude finished @Xia-Weiwen's task in 3m 2s.

Code Review for PR #4099

Summary: This PR extends x86 Inductor fusion passes to support AOT Inductor (AOTI) FX graphs for FP8 QLinear patterns. The key changes are:

Issues Found:
1.
@claude review again. |
    torch._dynamo.reset()

    def aoti_compile(model, inputs, compile_options=None, get_source_code=False):
        # compile_options (e.g. dynamic=True) are torch.compile() kwargs and do not
In that case, I think it's better to remove the `compile_options` argument.
    for use_aoti in _get_fp8_aoti_options():
        for length in [2, 3]:
Use itertools.product or parametrize?
        *inputs,
    )
    if use_aoti:
        # compile_options (e.g. dynamic=True) are torch.compile() kwargs and do not
This comment appears to be a duplicate of the one above. It would be best to delete it to keep things concise.
|
@claude review again |
Claude finished @Xia-Weiwen's task in 2m 12s.

Re-review of PR #4099

Progress Since Last Review: The author has addressed several items from the previous review and from @Xia-Weiwen's feedback:

Remaining Issues:
1.
@claude review |
Claude finished @Xia-Weiwen's task in 3m 42s.

Re-review of PR #4099

Progress Since Last Review: All items from the second review have been addressed:

Remaining Issues:
1.
The CI failures are unrelated. |
The graph obtained by AOT Inductor differs from that obtained by the regular `torch.compile`:
`concat_dequant_quant`.
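The graph difference described in this PR can be illustrated with a torch-free sketch (the node tuples below are invented for illustration, not actual FX output):

```python
# Invented node tuples for illustration, not actual FX output: the AOTI
# export graph may carry extra nodes around the same dequant -> linear core.
compile_graph = [
    ("call_function", "dequantize_affine_float8"),  # scale inlined as a literal
    ("call_function", "linear"),
]
aoti_graph = [
    ("get_attr", "_lifted_scale"),                   # scale lifted to get_attr
    ("call_function", "_assert_tensor_metadata"),    # extra metadata assert
    ("call_function", "dequantize_affine_float8"),
    ("call_function", "linear"),
]

# A fusion pattern written against the torch.compile graph must skip the
# extra nodes before it can match the same core in the AOTI graph:
core = [n for n in aoti_graph
        if n[0] == "call_function" and n[1] != "_assert_tensor_metadata"]
assert core == compile_graph
```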