Arm backend: Lower MXFP Linear to TOSA #19969
Add fake TOSA dialect support and serializer lowering for CAST_TO_BLOCK_SCALED.
Co-authored-by: Sebastian Larsson <sebastian.larsson@arm.com>
Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com>
Change-Id: Ic7cdab5134f0fb9502f5985563f0662286ef5fb7

Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com>
Co-authored-by: Sebastian Larsson <sebastian.larsson@arm.com>
Change-Id: Iab2e1cf2ed21047bbc2a7a51604b9230fe2f2819
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19969
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Unrelated Failure
As of commit 0603d37 with merge base f0d9991:
NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label ciflow/trunk
@pytorchbot label "partner: arm"
@pytorchbot label "release notes: arm"
Pull request overview
This PR adds end-to-end support in the Arm backend to lower MXFP Linear into explicit TOSA MXFP operators by introducing new TOSA dialect ops (cast-to-block-scaled + block-scaled matmul), wiring up serialization visitors, and inserting a rewrite pass in the Arm TOSA pipeline. It also expands dtype mapping/serialization to cover MXFP-related FP8 types and updates/extends the test suite accordingly.
Changes:
- Add FP8 dtype mappings and broaden spec checks to recognize MXFP-enabled FP8 usage.
- Introduce TOSA dialect ops + Arm operator visitors for `CAST_TO_BLOCK_SCALED` and `MATMUL_T_BLOCK_SCALED`.
- Add `RewriteMXFPLinearPass` and update Arm tests/pipelines to validate the new lowering path.
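To make the two new ops more concrete, here is a minimal NumPy sketch of the idea behind a block-scaled cast followed by a block-scaled transposed matmul. It is not the PR's implementation and not the TOSA operator semantics: the block size of 32, the scale formula, and the function names are assumptions chosen for illustration, and the FP8 rounding and saturation a real cast performs are left out so the result can be compared exactly against a plain matmul.

```python
import numpy as np

BLOCK = 32      # assumed block size along the reduction axis
ELEM_EMAX = 8   # assumed max exponent of the FP8 E4M3 element type


def cast_to_block_scaled(x, block=BLOCK):
    """Return (data, scale): per-block values and one power-of-two scale per block."""
    *lead, k = x.shape
    assert k % block == 0, "last axis must be a multiple of the block size"
    blocks = x.reshape(*lead, k // block, block)
    max_abs = np.abs(blocks).max(axis=-1)
    # Avoid log2(0) for all-zero blocks; any scale works when the data is zero.
    safe = np.where(max_abs > 0, max_abs, 1.0)
    scale = np.exp2(np.floor(np.log2(safe)) - ELEM_EMAX)
    data = blocks / scale[..., None]  # a real cast would round/saturate to FP8 here
    return data.reshape(x.shape), scale


def matmul_t_block_scaled(a_data, a_scale, b_data, b_scale, block=BLOCK):
    """Compute A @ B.T where both operands are stored as (data, per-block scale)."""
    m, k = a_data.shape
    n, _ = b_data.shape
    a = a_data.reshape(m, k // block, block) * a_scale[..., None]
    b = b_data.reshape(n, k // block, block) * b_scale[..., None]
    return np.einsum("mkb,nkb->mn", a, b)


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 64))   # activations [M, K]
w = rng.standard_normal((8, 64))   # weights [N, K], as in a Linear layer
y = matmul_t_block_scaled(*cast_to_block_scaled(x), *cast_to_block_scaled(w))
```

Because the scales are powers of two and no rounding is applied, `y` matches `x @ w.T` up to floating-point summation order; with real FP8 elements the result would only approximate it.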
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| backends/arm/tosa/mapping.py | Map additional FP8 dtypes and validate FP8 support against TOSA extensions. |
| backends/arm/tosa/dialect/ops/matmul_t_block_scaled.py | Define a fake TOSA op for MXFP block-scaled matmul with shape/dtype validation. |
| backends/arm/tosa/dialect/ops/cast_to_block_scaled.py | Define a fake TOSA op for MXFP block-scaled casting with shape/dtype validation. |
| backends/arm/tosa/dialect/__init__.py | Ensure newly added dialect ops modules are imported/registered. |
| backends/arm/test/targets.bzl | Rehome MXFP linear op test and add new TOSA dialect tests to Bazel targets. |
| backends/arm/test/passes/test_rewrite_mxfp_linear_pass.py | Add pass-level tests asserting custom MXFP linear op is rewritten into TOSA MXFP ops. |
| backends/arm/test/ops/mxfp/test_mxfp_linear.py | Refactor and expand MXFP linear tests using new pipelines; add channels-last case; add VGF xfails. |
| backends/arm/test/ops/mxfp/common.py | Add shared MXFP pipeline helpers/stages for TOSA/VGF test execution. |
| backends/arm/test/ops/mxfp/__init__.py | Add package marker for MXFP op tests. |
| backends/arm/test/misc/tosa_dialect/test_tosa_dialect_mxfp_linear.py | Add fake-op level tests for MATMUL_T_BLOCK_SCALED. |
| backends/arm/test/misc/tosa_dialect/test_tosa_dialect_cast_to_block_scaled.py | Add fake-op level tests for CAST_TO_BLOCK_SCALED. |
| backends/arm/process_node.py | Extend tensor serialization path to support torch.float8_e8m0fnu via ml_dtypes. |
| backends/arm/operators/op_tosa_matmul_t_block_scaled.py | Add serializer visitor for MATMUL_T_BLOCK_SCALED. |
| backends/arm/operators/op_tosa_cast_to_block_scaled.py | Add serializer visitor for multi-output CAST_TO_BLOCK_SCALED. |
| backends/arm/operators/__init__.py | Import/register the new operator visitor modules. |
| backends/arm/operator_support/tosa_supported_operators.py | Allow MX custom op partitioning under mxfp; adjust dtype disallow list for FP8 under mxfp. |
| backends/arm/_passes/rewrite_mxfp_linear.py | Implement the rewrite of tosa_mxfp.linear into explicit TOSA MXFP ops. |
| backends/arm/_passes/arm_pass_manager.py | Insert the new rewrite pass into the TOSA lowering pipeline. |
| backends/arm/_passes/__init__.py | Export the new rewrite pass from the passes package. |
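The dtype gating described for `mapping.py` and `tosa_supported_operators.py` can be sketched roughly as follows. This is an illustration only: `FakeTosaSpec`, the `FP8_EXTENSION` table, and `disallowed_fp8_dtypes` are hypothetical names, dtypes are represented as strings to keep the sketch free of a `torch` dependency, and the `"fp8e4m3"` extension name is an assumption (only `"fp8e5m2"` and `"mxfp"` appear in the diff context on this page).

```python
from dataclasses import dataclass, field


@dataclass
class FakeTosaSpec:
    """Hypothetical stand-in for the backend's TOSA specification object."""

    extensions: set = field(default_factory=set)

    def support_extension(self, name: str) -> bool:
        return name in self.extensions


# Assumed mapping from FP8 dtype name to the dedicated extension that enables it.
FP8_EXTENSION = {
    "float8_e4m3fn": "fp8e4m3",
    "float8_e5m2": "fp8e5m2",
}


def disallowed_fp8_dtypes(spec):
    """Return FP8 dtype names enabled by neither their own extension nor mxfp."""
    return [
        dtype
        for dtype, ext in FP8_EXTENSION.items()
        if not (spec.support_extension(ext) or spec.support_extension("mxfp"))
    ]
```

With no extensions both FP8 dtypes are disallowed; enabling `mxfp` alone allows both, which is the broadening this PR describes.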
):
    return False

return True
    tosa_spec.support_extension("fp8e5m2") or tosa_spec.support_extension("mxfp")
):
    disallowed_dtypes.append(torch.float8_e5m2)
if tosa_spec.is_U55_subset:
    inputs: List[TosaArg],
    output: TosaArg,
) -> None:
    validate_num_inputs(self.target, inputs, 2)
# TODO(MLETORCH-2018): This is a local workaround for multi-output TOSA ops.
# Remove it once we can handle multiple outputs generally.
output_names = _ordered_getitem_output_names(node)
from executorch.backends.arm.operators.operator_validation_utils import (
    validate_num_inputs,
)
from executorch.backends.arm.tosa.mapping import TosaArg
    f"{CastToBlockScaledVisitor.target}: Expected exactly two getitem outputs, got {len(ordered_users)}"
)

return [user.name for user in ordered_users]
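The multi-output workaround in the hunks above (collect the `getitem` users of the cast node, order them by output index, and require exactly two) could look roughly like this. It is a sketch under assumptions, not the PR's `_ordered_getitem_output_names`: `FakeNode` is a hypothetical stand-in for a graph node (in `torch.fx`, `users` is a dict rather than a list), and the error type and message are illustrative.

```python
import operator
from dataclasses import dataclass, field


@dataclass
class FakeNode:
    """Hypothetical minimal graph node: just enough fields for this sketch."""

    name: str
    target: object = None
    args: tuple = ()
    users: list = field(default_factory=list)


def ordered_getitem_output_names(node, expected=2):
    """Return the names of node's getitem users, ordered by their output index."""
    getitems = [u for u in node.users if u.target is operator.getitem]
    # A getitem node's args are (producer, index); sort by the index.
    ordered = sorted(getitems, key=lambda u: u.args[1])
    indices = [u.args[1] for u in ordered]
    if indices != list(range(expected)):
        raise ValueError(
            f"Expected exactly {expected} getitem outputs, got indices {indices}"
        )
    return [u.name for u in ordered]


# Usage: a cast node whose users were registered in reverse order.
cast = FakeNode("cast_to_block_scaled")
scale = FakeNode("scale", operator.getitem, (cast, 1))
data = FakeNode("data", operator.getitem, (cast, 0))
cast.users = [scale, data]
names = ordered_getitem_output_names(cast)
```

Sorting by index rather than relying on user order keeps the serialized output names aligned with the op's output positions even if the graph lists the users in a different order.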
zingo left a comment
OK to merge if tests pass and a best effort has been made to update the buck2 files.
The timeout failure in test_smaller_stories is known and is being worked on separately.
Yes, I believe things look all right and the buck2 files are (hopefully) updated correctly.
Hi @martinlsm @zingo, I probably need to revert this PR. It's breaking our internal CI. Would you mind re-landing this from internal first with our engineer?
Example error from internal CI -
Of course, sorry for the problem. I created a revert of this PR here; if you need it, feel free to approve and merge it directly. I can't self-approve, and due to time zone differences we won't be back here for about 8 hours.
Thank you @zingo. Also, I made some modifications to make buck work: backends/arm/test/targets.bzl
Thanks @kirklandsign for providing the correct files. We reverted the PR, but here is a resubmission that contains your fixed buck2 files: #20065 |
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani