
[X86] Enable FP8 patterns for AOT Inductor #4099

Merged
Xia-Weiwen merged 14 commits into pytorch:main from CaoE:aoti_support
Apr 28, 2026

Conversation

@CaoE
Contributor

@CaoE commented Mar 17, 2026

The graph obtained by AOT Inductor differs from the one obtained by regular torch.compile:

  1. Add FP8 support for the concat_dequant_quant pattern.
  2. The scale may be lifted into a lifted tensor (a get_attr node), so the FP8 patterns need to be extended to match AOTI graphs (see the sketch below).
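
As a minimal sketch of point 2 (module and names are hypothetical, not from this PR): torch.export lifts tensor attributes such as a scale into lifted constants instead of inlining them, which is why the graph AOT Inductor sees differs from the torch.compile one.

import torch

class ScaledMul(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A scalar scale stored as a tensor attribute.
        self.scale = torch.tensor(0.5)

    def forward(self, x):
        return x * self.scale

ep = torch.export.export(ScaledMul(), (torch.randn(4),))
# The scale appears as a lifted constant input of the exported graph,
# not as a literal 0.5 or an aten.full node.
print(ep.graph)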

@pytorch-bot

pytorch-bot Bot commented Mar 17, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4099

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 4 New Failures

As of commit 8d20988 with merge base 2c41725:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla

meta-cla Bot commented Mar 17, 2026

Hi @CaoE!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@meta-cla Bot added the CLA Signed label Mar 26, 2026
@meta-cla

meta-cla Bot commented Mar 26, 2026

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@CaoE changed the title from "enable fp8 qlinear patterns for AOT Inductor" to "Enable FP8 qlinear patterns for AOT Inductor" Mar 26, 2026
@Xia-Weiwen requested a review from Copilot March 26, 2026 06:54
Contributor

Copilot AI left a comment

Pull request overview

This PR extends the x86 Inductor fusion/pattern-matching logic to better handle AOT Inductor-produced FX graphs for FP8 QLinear, where extra _assert_tensor_metadata nodes and lifted scale values can appear.

Changes:

  • Extend FP8/int8 dequant→linear prepack patterns to tolerate _assert_tensor_metadata users and register variants based on dequant user-count.
  • Generalize extraction of constant (scalar) scales to also support lifted constants (get_attr) and other FX node forms.
  • Add an AOT Inductor (AOTI) execution/codepath to the x86 inductor fusion tests and run FP8 tests with/without AOTI when supported.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

  • torchao/testing/pt2e/utils.py: Avoids using .data when feeding the FP8 weight tensor into the dequant op in a test module.
  • torchao/quantization/pt2e/inductor_passes/x86.py: Updates pattern matching to support AOTI graph differences (assert-metadata users, lifted scale constants) and adds user-count pattern variants.
  • test/quantization/pt2e/test_x86inductor_fusion.py: Adds an AOTI compilation/execution path to common test helpers and extends FP8 test coverage to include AOTI runs.


@Xia-Weiwen added the module: not user facing label Mar 27, 2026
@Xia-Weiwen marked this pull request as ready for review March 27, 2026 05:29
@Xia-Weiwen changed the title from "Enable FP8 qlinear patterns for AOT Inductor" to "[X86] Enable FP8 qlinear patterns for AOT Inductor" Mar 27, 2026
@Xia-Weiwen requested a review from jerryzh168 April 2, 2026 02:51
@CaoE marked this pull request as draft April 2, 2026 07:10
@CaoE
Contributor Author

CaoE commented Apr 2, 2026

This PR depends on pytorch/pytorch#178120

@CaoE marked this pull request as ready for review April 21, 2026 13:55
@CaoE requested review from andrewor14 and vkuzo as code owners April 21, 2026 13:55
counters.clear()
torch._dynamo.reset()

def aoti_compile(model, inputs):
Collaborator

Support compile_options too?

Comment on lines +1587 to +1592
import torch._inductor.constant_folding as cf

if hasattr(cf, "add_dont_constant_fold"):
    cf.add_dont_constant_fold(
        torch.ops.torchao.dequantize_affine_float8_non_decomposed.default
    )
Collaborator

What does this part do? Can you add a comment to explain?

quantizer=None,
compile_options={}, # noqa: B006
is_fp8=False,
is_aoti=False,
Collaborator

use_aoti may be better

Comment on lines +1597 to +1598
for is_aoti in aoti_options:
    for bias in [True, False]:
Collaborator

Better to use itertools.product or parametrize for nested options.

    False,
]
try:
    import torch._inductor.constant_folding as cf
Collaborator

ditto

Comment on lines +3375 to +3376
for is_aoti in aoti_options:
    for length in [2, 3]:
Collaborator

Better to use itertools.product or parametrize for nested options.

@Xia-Weiwen
Collaborator

@claude review

@claude

claude Bot commented Apr 24, 2026

Claude finished @Xia-Weiwen's task in 3m 2s


Code Review for PR #4099

  • Gather context and read changed files
  • Analyze PR diff against main
  • Review torchao/quantization/pt2e/inductor_passes/x86.py
  • Review test/quantization/pt2e/test_x86inductor_fusion.py
  • Review torchao/testing/pt2e/utils.py
  • Check existing reviewer feedback from @Xia-Weiwen
  • Post comprehensive review

Summary

This PR extends x86 Inductor fusion passes to support AOT Inductor (AOTI) FX graphs for FP8 QLinear patterns. The key changes are:

  1. New _extract_const_float_from_node utility — consolidates scale extraction logic, adding support for get_attr (lifted tensors from AOTI) alongside aten.full and meta-val fallback.
  2. FP8 concat-dequant-quant pattern — extends _is_valid_concat_dq_q_pattern and registers a new FP8 variant.
  3. AOTI test path — adds aoti_compile helper and is_aoti parameter to the test infrastructure.
  4. .data removal — removes .data usage when passing FP8 weight tensors.
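
For illustration, a minimal sketch of the extraction cases described above (the helper name is hypothetical; the real _extract_const_float_from_node in x86.py also has a meta-val fallback, omitted here):

import torch
from torch import fx

def extract_const_float(node, gm: fx.GraphModule):
    # Case: a plain Python scalar baked into the graph.
    if isinstance(node, (int, float)):
        return float(node)
    if not isinstance(node, fx.Node):
        return None
    # Case: aten.full([], fill_value) materializing the scale.
    if node.target is torch.ops.aten.full.default and len(node.args) >= 2:
        return float(node.args[1])
    # Case: torch.tensor([scalar]) from user code.
    if node.target is torch.tensor and len(node.args) >= 1:
        data = node.args[0]
        if isinstance(data, (list, tuple)) and len(data) == 1:
            return float(data[0])
    # Case: a lifted constant (get_attr) produced by export/AOTI.
    if node.op == "get_attr":
        t = getattr(gm, str(node.target), None)  # assumes a flat attribute name
        if isinstance(t, torch.Tensor) and t.numel() == 1:
            return float(t.item())
    return None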

Issues Found

1. is_aoti not effective for FP8 qlinear code-checking path (Confirming Copilot's earlier finding)

In _qlinear_test_helper (test_x86inductor_fusion.py:1544), when is_fp8=True, the call sets include_ops=[] and a non-empty exclude_ops. This causes check_code=True in _test_common, which enters the if check_code: branch at line 236 — a path that always uses torch.compile and never calls aoti_compile, regardless of is_aoti.

This means test_fp8_qlinear_cpu with is_aoti=True does not actually exercise the AOTI compilation path for verifying that excluded ops are absent. You should either:

  • Add AOTI handling in the check_code branch (e.g., use run_and_get_code equivalent for AOTI), or
  • Skip the exclude_ops check when is_aoti=True so it falls through to the check_quantization branch that does handle AOTI.

2. aoti_compile ignores compile_options (@Xia-Weiwen's point)

The aoti_compile function (test_x86inductor_fusion.py:193) does not accept or forward compile_options. While current callers don't rely on this, it creates an inconsistency with the torch.compile path. Consider accepting compile_options as a parameter for parity, even if currently unused.

3. Duplicated AOTI capability detection pattern (Agreeing with @Xia-Weiwen)

The following block is copy-pasted 3 times (test_fp8_qlinear_cpu:1586, test_fp8_scaled_embedding_bag:3236, test_fp8_concat_dequant_quant:3362):

aoti_options = [False]
try:
    import torch._inductor.constant_folding as cf
    if hasattr(cf, "add_dont_constant_fold"):
        cf.add_dont_constant_fold(
            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default
        )
        aoti_options = [False, True]
finally:
    pass

This should be extracted into a helper function or fixture. The try/finally: pass pattern is also unusual — if the intent is to catch import errors, try/except ImportError would be clearer. As-is, if add_dont_constant_fold raises an exception, it would propagate through the finally: pass.

Additionally, as @Xia-Weiwen suggested, consider using itertools.product or @parametrize for the is_aoti × bias / is_aoti × length loops.
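
A hedged sketch of what the deduplicated helper plus itertools.product could look like (helper name hypothetical; the op is the one registered in the tests):

import itertools

import torch

def get_fp8_aoti_options():
    options = [False]
    try:
        import torch._inductor.constant_folding as cf
    except ImportError:
        return options
    if hasattr(cf, "add_dont_constant_fold"):
        cf.add_dont_constant_fold(
            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default
        )
        options.append(True)
    return options

for use_aoti, bias in itertools.product(get_fp8_aoti_options(), [True, False]):
    ...  # run the qlinear test case for this combination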

4. Missing comment explaining AOTI capability detection

The try/finally block with add_dont_constant_fold is non-obvious. As @Xia-Weiwen noted (line 1592), a comment explaining why this is needed would help — specifically that AOTI requires preventing constant folding of the FP8 dequant op to keep the pattern matchable.

5. _extract_const_float_from_node drops torch.tensor handling

The old inline _extract_const_float in the scaled-embedding-bag pass handled a torch.tensor([scalar]) case:

if val.target is torch.tensor and len(val.args) >= 1:
    data = val.args[0]
    if isinstance(data, (list, tuple)) and len(data) == 1:
        ...

The new _extract_const_float_from_node does not include this case. Instead it adds get_attr support and meta-val fallback. If the torch.tensor pattern can still appear in practice (e.g., from user code), this is a regression. If it's now handled via the meta-val fallback, a comment noting this would be helpful.

6. Variable shadowing of builtin len

In test_int8_concat_dequant_quant (line 3318): for len in input_len_list: shadows the builtin len function. The new FP8 test correctly uses length instead. The int8 test should also be updated for consistency, though that's pre-existing.


Minor Observations

  • .data removal (torchao/testing/pt2e/utils.py:258, test_x86inductor_fusion.py:3135,3141): Using self.weight instead of self.weight.data is the correct approach for tracing/export compatibility. Good change.

  • _is_valid_concat_dq_q_pattern refactoring (x86.py:3119): The refactored function is cleaner. The final check return dq_node.args[0].target is torch.ops.aten.cat.default is a nice simplification of the old separate cat_node variable + if not ... return False / return True.

  • FP8 pattern in _register_concat_dq_q_pattern (x86.py:3223-3240): The new FP8 pattern registration using keyword args (fp8_dq_scale, fp8_q_scale) is well-structured and consistent with the existing int8 pattern.

  • concat_dq_q_fusion handler (x86.py:3161): The fusion handler correctly works for both int8 and FP8 since it operates on the match structure generically (erase q/dq nodes, replace uses with cat node).


Verdict

The core logic in x86.py looks sound — _extract_const_float_from_node is a well-structured utility, the FP8 concat pattern is correctly registered, and the o_inv_scale extraction is properly generalized. The main concerns are in the test infrastructure: the AOTI path isn't fully exercised for qlinear code checking, the capability detection pattern should be deduplicated, and the dropped torch.tensor case in the extraction utility warrants verification.

@CaoE
Contributor Author

CaoE commented Apr 24, 2026

@claude review again.

@CaoE changed the title from "[X86] Enable FP8 qlinear patterns for AOT Inductor" to "[X86] Enable FP8 patterns for AOT Inductor" Apr 24, 2026
torch._dynamo.reset()

def aoti_compile(model, inputs, compile_options=None, get_source_code=False):
    # compile_options (e.g. dynamic=True) are torch.compile() kwargs and do not
Collaborator

In that case, I think it's better to remove the compile_options argument

Contributor Author

removed

Comment on lines +3394 to +3395
for use_aoti in _get_fp8_aoti_options():
    for length in [2, 3]:
Collaborator

Use itertools.product or parametrize?

Contributor Author

revised

    *inputs,
)
if use_aoti:
    # compile_options (e.g. dynamic=True) are torch.compile() kwargs and do not
Collaborator

This comment appears to be a duplicate of the one above. It would be best to delete it to keep things concise.

Contributor Author

removed.

@Xia-Weiwen
Collaborator

@claude review again

@claude

claude Bot commented Apr 24, 2026

Claude finished @Xia-Weiwen's task in 2m 12s


Re-review of PR #4099

  • Gather context and read current diff against main
  • Review torchao/quantization/pt2e/inductor_passes/x86.py
  • Review test/quantization/pt2e/test_x86inductor_fusion.py
  • Review torchao/testing/pt2e/utils.py
  • Check if previous review issues have been addressed
  • Post updated review

Progress Since Last Review

The author has addressed several items from the previous review and from @Xia-Weiwen's feedback:

  • Duplicated AOTI capability detection: Fixed (extracted into a _get_fp8_aoti_options() helper with a clear docstring).
  • try/finally: pass pattern: Fixed (now uses try/except ImportError: pass).
  • Missing comments for AOTI capability detection: Fixed (_get_fp8_aoti_options() has a thorough docstring).
  • is_aoti → use_aoti naming: Fixed (consistent use of use_aoti).
  • itertools.product for nested loops: Fixed (all FP8 test loops now use itertools.product).
  • len builtin shadowing: Fixed (changed to length in test_int8_concat_dequant_quant).
  • _extract_const_float_from_node missing torch.tensor case: Fixed (case 2 now handles torch.tensor([scalar])).
  • compile_options for aoti_compile: Fixed (the compile_options argument was removed from _test_common for the AOTI path, per @Xia-Weiwen's suggestion).

Remaining Issues

1. expected is undefined when use_aoti=True and check_code=True with check_quantization=False

At test_x86inductor_fusion.py:309-311:

if not check_quantization:
    # Skip due to reduce range setting for Quantization on preCI system.
    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)

When use_aoti=True, the check_code branch (line 283) sets actual but never sets expected — that variable is only assigned in the else (non-AOTI) branch at line 289. If this path is ever reached with use_aoti=True and check_quantization=False, it will raise a NameError.

Currently this is not triggered because _qlinear_test_helper always passes check_quantization=True alongside is_fp8=True. However, it's a latent bug that would surface if any future test uses use_aoti=True with check_code=True and check_quantization=False. Consider either:

  • Computing expected = mod(*inputs) before the if use_aoti branch, or
  • Adding an explicit guard like if not check_quantization and not use_aoti: for the assert_close call
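
A runnable sketch of the first option (mod, inputs, use_aoti, and aoti_compiled_model are stand-ins for the test helper's variables; aoti_compiled_model is a hypothetical placeholder for the AOTI-compiled callable):

import torch

mod = torch.nn.Linear(4, 4).eval()
inputs = (torch.randn(2, 4),)
use_aoti = False
aoti_compiled_model = mod  # placeholder for the AOTI-compiled callable

expected = mod(*inputs)  # eager reference, computed before the branch
if use_aoti:
    actual = aoti_compiled_model(*inputs)
else:
    actual = torch.compile(mod)(*inputs)
torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)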

2. _is_valid_concat_dq_q_pattern: assert vs validation return

In x86.py:3130-3141, the FP8 path uses assert to verify node targets, while the int8 path also uses assert. These assertions will crash with AssertionError if a pattern match provides unexpected nodes, rather than gracefully returning False. Since this is an extra_check callback, returning False is the expected way to reject a match. If the pattern matcher can ever call this with non-matching nodes (e.g., during graph mutations), these asserts would be problematic.

This is pre-existing behavior for the int8 path, so not necessarily blocking, but worth noting that assert in validation callbacks is fragile.

3. Minor: aoti_compile defined inside _test_common on every call

The aoti_compile function is defined as a closure inside _test_common (line 222) on every invocation, even when use_aoti=False. This is harmless (the function is never called in that case), but if cleanliness is desired, it could be moved outside or gated behind if use_aoti:. This is very minor and not blocking.


Code Quality Observations (Positive)

  • _extract_const_float_from_node (x86.py:2461-2502): Well-structured with clear case comments. Handles all four extraction patterns (scalar, aten.full, aten.tensor, get_attr) plus meta-val fallback. The ordering is sensible — structural extraction is preferred before falling back to meta values.

  • _get_fp8_aoti_options() (test_x86inductor_fusion.py:108-133): Clean deduplication of the capability detection logic with a clear docstring explaining why add_dont_constant_fold is needed.

  • FP8 concat-dq-q pattern (x86.py:3233-3250): The pattern registration is clean and consistent with the int8 version. Using keyword args (fp8_dq_scale, fp8_q_scale) instead of positional args is the right choice for the FP8 op signature.

  • .data removal (utils.py:258, test_x86inductor_fusion.py:3173,3179): Correct fix — .data bypasses tracing and would cause issues with export/AOTI.


Verdict

The code is in good shape after the revisions. The main actionable item is Issue #1 (potential NameError for expected in the check_code + use_aoti path). Issues #2 and #3 are minor and non-blocking. The core fusion logic in x86.py is sound.


@Xia-Weiwen
Collaborator

@claude review

@claude

claude Bot commented Apr 24, 2026

Claude finished @Xia-Weiwen's task in 3m 42s


Re-review of PR #4099

  • Gather context and read current diff against main
  • Review torchao/quantization/pt2e/inductor_passes/x86.py
  • Review test/quantization/pt2e/test_x86inductor_fusion.py
  • Review torchao/testing/pt2e/utils.py
  • Check status of previously reported issues
  • Post comprehensive review

Progress Since Last Review

All items from the second review have been addressed:

  • expected undefined when use_aoti=True + check_code=True: Fixed (expected = mod(*inputs) is now computed before the if use_aoti branch, line 285).
  • All items from the first review (dedup, try/except, naming, itertools.product, torch.tensor case, len shadowing): Fixed (confirmed in current diff).

Remaining Issues

1. _get_fp8_aoti_options() has a persistent global side effect

_get_fp8_aoti_options() (test_x86inductor_fusion.py:108-133) calls cf.add_dont_constant_fold(...) every time it's invoked, and this registration is global and permanent — it affects all subsequent AOTI compilations in the process. This means:

  • The use_aoti=False iteration that runs first in itertools.product(_get_fp8_aoti_options(), ...) already has the side effect applied.
  • If _get_fp8_aoti_options() is called from multiple tests (it's called from 3 test methods), add_dont_constant_fold is called 3 times for the same op.
  • If add_dont_constant_fold is not idempotent, this could cause issues.

Consider either: (a) making it idempotent with @functools.lru_cache(None), or (b) guarding with a module-level flag. This is a minor robustness concern since the registration is likely idempotent, but worth noting.
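
A sketch of option (b), a module-level flag guarding the global registration (names hypothetical):

import torch

_dont_fold_registered = False

def _register_dont_fold_once():
    global _dont_fold_registered
    if _dont_fold_registered:
        return
    import torch._inductor.constant_folding as cf
    if hasattr(cf, "add_dont_constant_fold"):
        cf.add_dont_constant_fold(
            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default
        )
    _dont_fold_registered = True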

2. test_fp8_scaled_embedding_bag_with_output_quant not exercised with AOTI

test_fp8_scaled_embedding_bag_with_output_quant (line 3303) calls self._test_scaled_embedding_bag_helper(torch.float8_e4m3fn, True) without iterating over _get_fp8_aoti_options(). Compare with test_fp8_scaled_embedding_bag (line 3270) which does iterate. If the output-quant variant should also be tested under AOTI, this seems like an omission.
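
A sketch of the suggested extension, inside the test method (assumes _test_scaled_embedding_bag_helper accepts a use_aoti flag like its sibling helpers, which is an assumption about this PR's helper signature):

for use_aoti in _get_fp8_aoti_options():
    # hypothetical call shape; mirrors test_fp8_scaled_embedding_bag
    self._test_scaled_embedding_bag_helper(torch.float8_e4m3fn, True, use_aoti)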

3. _extract_const_float_from_node: get_attr traversal uses string splitting

In x86.py:2484, str(v.target).split(".") is used to traverse nested attributes. While this works, using v.target directly (which is already a string for get_attr nodes) would be clearer — str() is redundant if v.target is already a str. More importantly, if v.target were ever a non-string (though unlikely for get_attr), str() could produce unexpected attribute names. This is very minor.
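
For reference, a minimal sketch of the same traversal without the redundant str() (get_attr targets are already dotted strings like "sub.weight_scale"):

import functools
import torch

def resolve_get_attr(gm: torch.fx.GraphModule, target: str):
    # Walk nested attributes: getattr(getattr(gm, "sub"), "weight_scale")
    return functools.reduce(getattr, target.split("."), gm)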


Code Quality Observations

  • _extract_const_float_from_node (x86.py:2461-2502): Well-structured utility covering all four extraction patterns. The ordering (structural extraction before meta-val fallback) is correct. The torch.tensor case is now properly included.

  • _is_valid_concat_dq_q_pattern refactoring (x86.py:3129-3163): Clean split between FP8 and int8 paths. The FP8 path correctly uses _extract_const_float_from_node before the math.isclose comparison, handling get_attr nodes from AOTI. Returning False instead of assert for target validation is an improvement over the original code.

  • _register_concat_dequant_quant_pass with extra_check parameter (x86.py:3166): Good generalization — making extra_check injectable avoids duplicating the registration logic.

  • FP8 concat pattern registration (x86.py:3237-3254): Clean and consistent with the int8 version. Using keyword args for the FP8 op signature is correct.

  • aoti_compile closure (test_x86inductor_fusion.py:224-246): The AOTI compilation path is well-implemented — export, compile, package, and optionally extract source code. The _disable_aten_to_metadata_assertions context manager is appropriate for export compatibility.

  • .data removal (utils.py:258, test_x86inductor_fusion.py:3175,3179): Correct — .data bypasses tracing and breaks export/AOTI.

  • _get_fp8_aoti_options() docstring: Clear explanation of why add_dont_constant_fold is needed. Good improvement.


Verdict

The code is in good shape. The core fusion logic in x86.py is sound, the test infrastructure properly handles both torch.compile and AOTI paths, and all feedback from previous reviews has been addressed. The remaining items are minor (idempotency of side effect, one test missing AOTI coverage, cosmetic string handling) and not blocking. The PR is ready for merge.



@Xia-Weiwen
Collaborator

The CI failures are unrelated.

@Xia-Weiwen merged commit 9472d7d into pytorch:main Apr 28, 2026
17 of 21 checks passed

Labels

  • ciflow/rocm
  • CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
  • module: not user facing: Use this tag if you don't want this PR to show up in release notes.
