Arm backend: Add MAX_POOL2D tosa dialect op#18970
Conversation
Change-Id: I2db9e00104174952eead47820dd28cfcd8942ff6 Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18970
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 6 Cancelled Jobs, 7 PendingAs of commit 81cb360 with merge base 9d72936 ( NEW FAILURE - The following job has failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Pull request overview
Adds support for lowering 2D max pooling to a TOSA dialect MAX_POOL2D op in the Arm backend, including a rewrite pass and corresponding operator visitor updates.
Changes:
- Introduce a new TOSA dialect fake op
MAX_POOL2Dwith shape/dtype validation and output meta computation. - Add
RewriteMaxPool2dPassand integrate it into the Arm pass pipelines (includingSizeAdjustInputPassdependencies). - Update the Arm operator visitor to lower
tosa.MAX_POOL2D.default, and add a new unit test covering the rewrite.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
backends/arm/tosa/dialect/ops/max_pool2d.py |
Adds fake TOSA op schema + output meta for MAX_POOL2D. |
backends/arm/tosa/dialect/__init__.py |
Ensures the new dialect op module is imported/registered. |
backends/arm/_passes/rewrite_max_pool2d_pass.py |
New pass rewriting edge max_pool2d into backend TOSA MAX_POOL2D. |
backends/arm/_passes/arm_pass_manager.py |
Inserts RewriteMaxPool2dPass into the TOSA pipeline. |
backends/arm/_passes/size_adjust_input_pass.py |
Adds RewriteMaxPool2dPass to required-after passes. |
backends/arm/_passes/__init__.py |
Exposes the new rewrite pass. |
backends/arm/operators/op_tosa_max_pool2d.py |
Switches visitor to handle tosa.MAX_POOL2D.default and uses provided attrs directly. |
backends/arm/operators/__init__.py |
Registers the new op_tosa_max_pool2d visitor (removes old import). |
backends/arm/test/passes/test_rewrite_max_pool2d_pass.py |
Adds rewrite-pass test coverage. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if len(args) > 2 and args[2] is not None and len(args[2]) > 0: | ||
| stride = _to_2tuple(args[2]) | ||
| else: | ||
| stride = kernel |
There was a problem hiding this comment.
args[2] (stride) can be an int (e.g., max_pool2d(..., stride=2)), but this code calls len(args[2]), which will raise TypeError: object of type 'int' has no len(). Normalize stride via type checks (int/Sequence/None) before using len() and before passing it to _to_2tuple so the pass works for both scalar and tuple/list strides.
| if len(args) > 2 and args[2] is not None and len(args[2]) > 0: | |
| stride = _to_2tuple(args[2]) | |
| else: | |
| stride = kernel | |
| raw_stride = args[2] if len(args) > 2 else None | |
| if raw_stride is None: | |
| stride = kernel | |
| else: | |
| stride = _to_2tuple(raw_stride) |
| ) | ||
| pipeline.pop_stage( | ||
| "run_method_and_compare_outputs" | ||
| ) # Cannnot run aten graph with tosa dialect ops |
There was a problem hiding this comment.
Typo in comment: "Cannnot" should be "Cannot".
| ) # Cannnot run aten graph with tosa dialect ops | |
| ) # Cannot run aten graph with tosa dialect ops |
oscarandersson8218
left a comment
There was a problem hiding this comment.
Doesn't seem to need any Buck-changes. LGTM!
|
Fail unrelated |
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell