Add FuseTosaTransposesPass with elementwise propagation #18947
Ninja91 wants to merge 1 commit into pytorch:main
Summary:
Part of the Ethos-U55/U85 optimization stack targeting a 22.6% NPU cycle reduction on Wake EMG.
Adds FuseTosaTransposesPass that eliminates redundant TOSA TRANSPOSE operations through four optimizations:
1. **Identity elimination** — remove TRANSPOSE with identity permutation [0,1,2,3]
2. **Inverse-pair cancellation** — remove TRANSPOSE→TRANSPOSE pairs that compose to identity
3. **Composition** — fuse consecutive non-inverse TRANSPOSEs into a single TRANSPOSE
4. **Propagation** — move TRANSPOSE through layout-agnostic ops (RESCALE, elementwise) to enable more cancellations
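The first three rewrites come down to permutation arithmetic. The sketch below is plain Python and is not taken from the pass itself. It assumes a permutation is a tuple `perm` in which output axis `i` reads input axis `perm[i]`, and the helper names (`is_identity`, `inverse`, `compose`) are illustrative only.

```python
from typing import Tuple

Perm = Tuple[int, ...]


def is_identity(perm: Perm) -> bool:
    """A transpose with this permutation leaves the tensor unchanged."""
    return all(axis == i for i, axis in enumerate(perm))


def inverse(perm: Perm) -> Perm:
    """Permutation that undoes `perm`."""
    inv = [0] * len(perm)
    for i, axis in enumerate(perm):
        inv[axis] = i
    return tuple(inv)


def compose(first: Perm, second: Perm) -> Perm:
    """Single permutation equivalent to applying `first`, then `second`."""
    return tuple(first[axis] for axis in second)


nchw_to_nhwc = (0, 2, 3, 1)
nhwc_to_nchw = inverse(nchw_to_nhwc)

# Inverse pair: the two transposes cancel and both can be removed.
assert is_identity(compose(nchw_to_nhwc, nhwc_to_nchw))

# Non-inverse pair: the two transposes fuse into one.
fused = compose(nchw_to_nhwc, nchw_to_nhwc)
assert not is_identity(fused)
```

Under these assumptions, a pair is cancellable exactly when `compose` yields the identity, and any other consecutive pair can be replaced by one TRANSPOSE carrying the composed permutation.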
The propagation pattern handles the common case where ToTosaMemoryFormatPass inserts TRANSPOSE pairs around view_copy rank boundaries, with RESCALE and elementwise ops in between:
TRANSPOSE(p) → RESCALE → relu → RESCALE → TRANSPOSE(inv(p)) → RESCALE → relu → RESCALE
For binary elementwise ops (ADD, MUL, SUB), propagation is safe only when the non-primary operand is broadcast-safe (scalar or 1-element tensor).
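The broadcast-safety condition can be sketched as a shape check. This is a hedged illustration, not the pass's code: `is_broadcast_safe` is a hypothetical helper, and it assumes the operand's static shape is available as a sequence of ints.

```python
from math import prod
from typing import Sequence


def is_broadcast_safe(shape: Sequence[int]) -> bool:
    """True for a scalar (rank 0) or any tensor holding exactly one element."""
    return prod(shape) == 1


# A single-element operand broadcasts the same way in every layout.
assert is_broadcast_safe(())
assert is_broadcast_safe((1, 1, 1, 1))

# A per-channel operand is tied to a specific axis, so moving the
# TRANSPOSE across the op would change which axis it broadcasts along.
assert not is_broadcast_safe((1, 16, 1, 1))
```

The per-channel case is the one the restriction guards against: its shape only lines up with the primary operand in one particular axis order.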
## Impact
Combined with FuseConsecutiveRescalesPass (next diff), reduces total NPU cycles on Wake/U55 by 22.6%. TRANSPOSE elimination directly reduces TRANSPOSE HW ops and enables the RESCALE fusion pass to find more fusible pairs.
Reviewed By: davidxili
Differential Revision: D92901685
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18947

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs
There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 4 New Failures, 2 Unrelated Failures
As of commit aff53bf with merge base a489707:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
Introduces a new Arm backend optimization pass to eliminate redundant TOSA TRANSPOSE operations (including propagation through elementwise ops) as part of the Ethos-U55/U85 optimization stack, and adds targeted tests/scripts to validate transpose reduction behavior.
Changes:
- Add FuseTosaTransposesPass implementing identity elimination, inverse-pair cancellation, composition, and elementwise propagation.
- Integrate the new pass into the Arm TOSA pipeline immediately after ToTosaMemoryFormatPass.
- Add unit tests for common patterns (conv chains, pooling, fan-out) plus propagation-through-elementwise cases, and a standalone comparison script.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| backends/arm/_passes/fuse_tosa_transposes_pass.py | New optimization pass to remove/fuse/cancel TOSA transposes, including propagation through elementwise ops. |
| backends/arm/_passes/arm_pass_manager.py | Wires FuseTosaTransposesPass into the standard Arm TOSA pipeline after memory-format transposes are inserted. |
| backends/arm/_passes/__init__.py | Exposes FuseTosaTransposesPass via the Arm passes package exports. |
| backends/arm/test/passes/test_fuse_tosa_transposes.py | Adds unit tests validating transpose counts and functional correctness, including propagation cases. |
| backends/arm/test/passes/fuse_tosa_transposes_comparison.py | Adds a runnable script to compare transpose counts pre/post optimization on representative models. |
                f"(eliminated {before_count - after_count}), iterations={iteration}"
            )

            return PassResult(graph_module, modified_overall)

FuseTosaTransposesPass.call() rewires/removes TRANSPOSE nodes in a way that can change intermediate tensor shapes (especially for the propagation pattern), but it never calls super().call(graph_module) to retrace and refresh meta['val'] / fake-tensor metadata. Several downstream Arm passes rely on get_first_fake_tensor(node) and meta['val'] being accurate; leaving stale metadata here can break later passes or lead to incorrect lowering. After graph mutations + recompile, invoke super().call(graph_module) (like ToTosaMemoryFormatPass/BroadcastArgsPass do) or otherwise re-run the metadata propagation.

Suggested change:

    -        return PassResult(graph_module, modified_overall)
    +        refreshed_result = super().call(graph_module)
    +        return PassResult(refreshed_result.graph_module, modified_overall)

    def test_identity_transpose_elimination() -> None:
        """
        Test that identity transposes are eliminated.
        Uses a simple pass-through module.
        """

        class IdentityModule(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x

            def get_inputs(self) -> input_t:
                return (torch.rand(1, 16, 8, 8),)

        module = IdentityModule()
        pipeline = PassPipeline[input_t](
            module,
            module.get_inputs(),
            pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass],
            passes_with_exported_program=[
                ToTosaMemoryFormatPass,
                FuseTosaTransposesPass,
            ],
        )
        pipeline.pop_stage("run_method_and_compare_outputs")
        pipeline.run()

test_identity_transpose_elimination() claims to verify identity transpose removal, but the pipeline is created without ops_before/ops_after assertions (and no other explicit checks). As written, this test only verifies the pass pipeline runs, not that identity transposes were eliminated. Consider adding an explicit check (e.g., expected TRANSPOSE count or verifying no identity-permutation TRANSPOSE nodes remain) so the test fails if the optimization regresses.
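One way to picture the explicit check the comment asks for is a counter over the post-pass graph. The sketch below uses a hypothetical `FakeNode` stand-in rather than real graph nodes, so it only illustrates the shape of the assertion. How the permutation is read from an actual TRANSPOSE node is not shown in this PR excerpt and is left out.

```python
from dataclasses import dataclass
from typing import Iterable, Tuple


@dataclass
class FakeNode:
    """Hypothetical stand-in for a graph node; not the torch.fx API."""

    op_name: str
    perm: Tuple[int, ...] = ()


def count_identity_transposes(nodes: Iterable[FakeNode]) -> int:
    """Count TRANSPOSE nodes whose permutation is (0, 1, ..., rank - 1)."""
    return sum(
        1
        for node in nodes
        if node.op_name == "TRANSPOSE"
        and node.perm == tuple(range(len(node.perm)))
    )


# A graph as it might look after the pass: no identity TRANSPOSE left.
graph_after_pass = [FakeNode("RESCALE"), FakeNode("TRANSPOSE", (0, 2, 3, 1))]
assert count_identity_transposes(graph_after_pass) == 0
```

With a check of this kind in place, the test would fail if identity elimination regressed instead of merely confirming that the pipeline runs.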
    def run_comparison(
        model: nn.Module,
        inputs: Tuple[torch.Tensor, ...],
        model_name: str
    ) -> Dict[str, int]:
        """
        Run comparison of TRANSPOSE counts with and without FuseTosaTransposesPass.
        """
        print(f"\n{'='*60}")
        print(f"Testing: {model_name}")
        print(f"{'='*60}")

        # Run pipeline WITHOUT FuseTosaTransposesPass (baseline)
        print("\n[1] Running WITHOUT FuseTosaTransposesPass...")
        pipeline_baseline = PassPipeline(
            model,
            inputs,
            pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass],
            passes_with_exported_program=[ToTosaMemoryFormatPass],
        )
        pipeline_baseline.pop_stage("run_method_and_compare_outputs")
        result_baseline = pipeline_baseline.run()

        baseline_count = count_transposes(result_baseline.graph_module)
        print(f" TRANSPOSE ops (baseline): {baseline_count}")

        # Run pipeline WITH FuseTosaTransposesPass (optimized)
        print("\n[2] Running WITH FuseTosaTransposesPass...")
        pipeline_optimized = PassPipeline(
            model,
            inputs,
            pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass],
            passes_with_exported_program=[
                ToTosaMemoryFormatPass,
                FuseTosaTransposesPass,
            ],
        )
        pipeline_optimized.pop_stage("run_method_and_compare_outputs")
        result_optimized = pipeline_optimized.run()

        optimized_count = count_transposes(result_optimized.graph_module)
        print(f" TRANSPOSE ops (optimized): {optimized_count}")

        # Calculate reduction
        reduction = baseline_count - optimized_count
        reduction_pct = (reduction / baseline_count * 100) if baseline_count > 0 else 0

        print(f"\n[3] Results Summary:")
        print(f" Baseline: {baseline_count} TRANSPOSE ops")
        print(f" Optimized: {optimized_count} TRANSPOSE ops")
        print(f" Reduction: {reduction} ops ({reduction_pct:.1f}%)")

        return {
            "model": model_name,
            "baseline": baseline_count,
            "optimized": optimized_count,
            "reduction": reduction,
            "reduction_pct": reduction_pct,
        }

run_comparison() is annotated to return Dict[str, int], but the returned dict includes a string value ('model') and a float ('reduction_pct'). With pyre-strict enabled, this should be a type error. Use a TypedDict/dataclass for the result, or widen the return type (e.g., Dict[str, object] or dict[str, int | float | str]).
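A minimal sketch of the TypedDict option, using only the standard library. `ComparisonResult` and `summarize` are illustrative names, not part of the PR; the arithmetic mirrors the reduction calculation in `run_comparison()`.

```python
from typing import TypedDict


class ComparisonResult(TypedDict):
    """Typed shape of the summary dict; the name is illustrative."""

    model: str
    baseline: int
    optimized: int
    reduction: int
    reduction_pct: float


def summarize(
    model_name: str, baseline_count: int, optimized_count: int
) -> ComparisonResult:
    """Build the summary from two TRANSPOSE counts."""
    reduction = baseline_count - optimized_count
    # Guard against division by zero when the baseline has no TRANSPOSE ops.
    reduction_pct = (
        reduction / baseline_count * 100 if baseline_count > 0 else 0.0
    )
    return {
        "model": model_name,
        "baseline": baseline_count,
        "optimized": optimized_count,
        "reduction": reduction,
        "reduction_pct": reduction_pct,
    }
```

Compared with widening the annotation to `Dict[str, object]`, a TypedDict keeps per-key types visible to the type checker, so callers can use `result["reduction_pct"]` as a float without casts.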
    def _target_name(target: object) -> str:
        """Extract a recognizable name from a node target for string matching."""
        name = str(target)
        # Handle exir_ops.backend.tosa.RESCALE.default → "RESCALE"
        # Handle exir_ops.edge.aten.add.Tensor → "add.Tensor"
        parts = name.rsplit(".", 2)
        if len(parts) >= 2:
            # For "backend__ops_tosa_RESCALE_default" patterns
            if "RESCALE" in name:
                return "RESCALE"
            # Return the last two parts for ATen ops: "add.Tensor", "clamp.default", etc.
            return ".".join(parts[-2:])
        return name

_target_name() uses str(target), but EdgeOpOverload/BackendOpOverload str includes schema text (see exir/dialects/edge/_ops.py), so the rsplit/join logic returns strings like "add.Tensor>: schema = ..." that will never match entries in _UNARY_ELEMENTWISE_TARGET_NAMES/_BINARY_ELEMENTWISE_TARGET_NAMES. This effectively disables elementwise propagation for edge ops (and likely makes the new propagation tests fail). Prefer using target.name (when present) and parsing that, or otherwise strip the schema suffix before matching.
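The "strip the schema suffix" option could look roughly like the sketch below. It operates on strings only and assumes the printed form is `<ClassName: qualified.op.name>: schema = ...`, which is inferred from the example in the comment above rather than checked against exir/dialects/edge/_ops.py. `target_name_from_str` is a hypothetical helper.

```python
import re

# Assumed printed form: "<EdgeOpOverload: aten.add.Tensor>: schema = ..."
_OP_NAME_RE = re.compile(r"<\w+: ([\w.]+)>")


def target_name_from_str(text: str) -> str:
    """Reduce a printed op target to a short name such as "add.Tensor"."""
    match = _OP_NAME_RE.search(text)
    # Fall back to the raw text when the assumed wrapper is absent.
    qualified = match.group(1) if match else text
    if "RESCALE" in qualified:
        return "RESCALE"
    parts = qualified.rsplit(".", 2)
    if len(parts) >= 2:
        # Keep the op name and overload: "add.Tensor", "clamp.default", etc.
        return ".".join(parts[-2:])
    return qualified
```

Matching on a name attribute of the target, as the comment suggests, would avoid depending on the printed format at all; the string approach is shown only because it needs no knowledge of the op classes.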
Thanks for the PR. However, tosa.TRANSPOSE is in the process of being deprecated in favor of using regular permutes consistently; see #18948. That rework should solve the same problems you are fixing here, though!