Revert "Arm backend: Lower MXFP Linear to TOSA"#20047
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20047
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 New Failures, 1 Cancelled JobAs of commit aa82cd0 with merge base 4c9c444 ( NEW FAILURES - The following jobs have failed:
CANCELLED JOB - The following job was cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
There was a problem hiding this comment.
Pull request overview
This PR reverts the Arm backend change that lowered MXFP Linear into explicit TOSA block-scaled operators, restoring the previous behavior by removing the MXFP-specific dialect ops, visitors, rewrite pass, and associated tests.
Changes:
- Remove TOSA dialect ops and Arm operator visitors for
CAST_TO_BLOCK_SCALED/MATMUL_T_BLOCK_SCALED, plus the MXFP linear rewrite pass. - Tighten FP8 extension validation to require the dedicated
fp8*extensions (rather than allowingmxfpas a substitute). - Simplify/relocate MXFP linear tests to only validate eager CPU/reference behavior.
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| backends/arm/tosa/mapping.py | Removes MXFP-as-fallback extension validation for FP8 dtypes. |
| backends/arm/tosa/dialect/ops/matmul_t_block_scaled.py | Deletes fake op definition for block-scaled matmul. |
| backends/arm/tosa/dialect/ops/cast_to_block_scaled.py | Deletes fake op definition for block-scaled cast. |
| backends/arm/tosa/dialect/init.py | Stops importing the removed dialect ops. |
| backends/arm/test/targets.bzl | Updates/removes test targets related to the reverted MXFP lowering pipeline. |
| backends/arm/test/passes/test_rewrite_mxfp_linear_pass.py | Deletes tests for the removed rewrite pass. |
| backends/arm/test/ops/test_mxfp_linear.py | Simplifies MXFP linear tests to eager CPU reference checks. |
| backends/arm/test/ops/mxfp/common.py | Deletes MXFP pipeline helper utilities. |
| backends/arm/test/ops/mxfp/init.py | Deletes empty package marker for removed MXFP test helpers. |
| backends/arm/test/misc/tosa_dialect/test_tosa_dialect_mxfp_linear.py | Deletes tests for removed TOSA dialect MXFP ops. |
| backends/arm/test/misc/tosa_dialect/test_tosa_dialect_cast_to_block_scaled.py | Deletes tests for removed TOSA dialect MXFP ops. |
| backends/arm/process_node.py | Removes float8_e8m0fnu handling from tensor serialization path. |
| backends/arm/operators/op_tosa_matmul_t_block_scaled.py | Deletes TOSA serializer visitor for block-scaled matmul. |
| backends/arm/operators/op_tosa_cast_to_block_scaled.py | Deletes TOSA serializer visitor for block-scaled cast. |
| backends/arm/operators/init.py | Stops importing the removed operator visitors. |
| backends/arm/operator_support/tosa_supported_operators.py | Removes MXFP custom-op support list and loosens dtype gating changes tied to MXFP. |
| backends/arm/_passes/rewrite_mxfp_linear.py | Deletes the MXFP linear rewrite pass implementation. |
| backends/arm/_passes/arm_pass_manager.py | Removes the rewrite pass from the Arm pass pipeline. |
| backends/arm/_passes/init.py | Stops exporting the removed rewrite pass. |
Comments suppressed due to low confidence (2)
backends/arm/test/ops/test_mxfp_linear.py:185
test_datais a callable (values intest_data_fparelambda: (...)) produced bycommon.parametrize, but the helper annotates it astorch.Tensor. This is misleading for readers and type checkers.
backends/arm/test/ops/test_mxfp_linear.py:214test_datahere is a callable provided bycommon.parametrize(dict values are lambdas), not atorch.Tensor. The current annotation is incorrect/misleading.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray: | ||
| tensor = tensor.detach().cpu().contiguous() | ||
| if tensor.dtype in ( | ||
| torch.bfloat16, | ||
| torch.float8_e4m3fn, | ||
| torch.float8_e5m2, | ||
| torch.float8_e8m0fnu, | ||
| ): | ||
| if tensor.dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2): | ||
| try: | ||
| import ml_dtypes # type: ignore[import-not-found] | ||
| except ImportError as e: | ||
| raise RuntimeError( | ||
| f"ml_dtypes is required to serialize {tensor.dtype} tensors for TOSA. " | ||
| "Have you run setup.sh?" | ||
| ) from e | ||
|
|
||
| ml_dtype_map = { | ||
| torch.bfloat16: (torch.uint16, ml_dtypes.bfloat16), | ||
| torch.float8_e4m3fn: (torch.uint8, ml_dtypes.float8_e4m3fn), | ||
| torch.float8_e5m2: (torch.uint8, ml_dtypes.float8_e5m2), | ||
| torch.float8_e8m0fnu: (torch.uint8, ml_dtypes.float8_e8m0fnu), | ||
| } | ||
| storage_dtype, ml_dtype = ml_dtype_map[tensor.dtype] | ||
| return tensor.view(storage_dtype).numpy().view(ml_dtype) |
Reverts #19969
cc @digantdesai @freddan80 @per @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani