Arm backend: Fix quantized constant-folding for aten.cat lists (#18971) #19064
perheld merged 1 commit into pytorch:main
…ch#18971)

FuseConstantArgsPass resolved input_qparams by flattened input-node index, while FoldAndAnnotateQParamsPass stores them by top-level argument index. For aten.cat with a list-valued tensor argument, this caused only the first tensor to be dequantized before folding, which corrupted the fused constant.

Resolve qparams by top-level argument index and propagate that qparam through nested list and tuple arguments.

Add a regression test for quantized aten.cat constant folding with list-valued tensor inputs.

Signed-off-by: Per Held <per.held@arm.com>
Change-Id: I6e1a012d82a5dbeecb403c440a2944953dd5cba7
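The index mismatch described in the commit message can be illustrated with a small, torch-free sketch. Everything here is a hypothetical stand-in (the `QParams` dataclass, the string "tensors", and both lookup helpers are assumptions made for illustration, not the real ExecuTorch code); it only shows why a flattened index finds qparams for the first tensor of a list argument and nothing for the rest.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class QParams:
    """Hypothetical stand-in for a quantization-parameter record."""
    scale: float
    zero_point: int


# aten.cat(tensors, dim): args[0] is a list of tensors, args[1] is the dim.
# Strings stand in for tensor nodes in this sketch.
cat_args = (["t0", "t1", "t2"], 0)

# Stored by top-level argument index: a single entry covers the whole list.
input_qparams = {0: QParams(scale=0.5, zero_point=0)}


def flatten_tensor_args(args):
    """Collect tensor stand-ins from (possibly nested) list/tuple arguments."""
    flat = []
    for arg in args:
        if isinstance(arg, (list, tuple)):
            flat.extend(flatten_tensor_args(arg))
        elif isinstance(arg, str):
            flat.append(arg)
    return flat


def lookup_by_flat_index(args, qparams):
    """Pre-fix behaviour (as described): index qparams by flattened position."""
    return {
        name: qparams.get(i)
        for i, name in enumerate(flatten_tensor_args(args))
    }


def lookup_by_top_level_index(args, qparams):
    """Post-fix behaviour (as described): index by top-level argument position."""
    resolved = {}
    for i, arg in enumerate(args):
        for name in flatten_tensor_args([arg]):
            resolved[name] = qparams.get(i)
    return resolved


buggy = lookup_by_flat_index(cat_args, input_qparams)
fixed = lookup_by_top_level_index(cat_args, input_qparams)
```

In this sketch, `buggy` resolves qparams only for `t0`, while `fixed` assigns the same qparams to all three list elements.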
Dr. CI (automatically generated, updates every 15 minutes): as of commit ca01b3e with merge base c48ea12, there are 1 new failure, 2 cancelled jobs, and 3 unrelated failures (flaky or broken trunk). Artifacts and rendered test results are at hud.pytorch.org/pr/pytorch/executorch/19064. Rebasing onto the `viable/strict` branch avoids the broken-trunk failures.
Pull request overview
Fixes Arm backend constant-folding for quantized ops by aligning how input_qparams are resolved with how they’re produced, specifically for ops like aten.cat where tensor inputs can be nested inside list/tuple arguments.
Changes:
- Update `FuseConstantArgsPass` to resolve `input_qparams` by top-level positional argument index (and propagate that qparam through nested list/tuple args).
- Add a regression test that constant-folds a quantized `aten.cat` whose tensor inputs are passed via a list/tuple.
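A minimal sketch of the propagation step listed above, assuming a per-tensor affine scheme where the real value is `(q - zero_point) * scale`. The `FakeTensor` class, `dequantize_arg`, and `fold_cat` are hypothetical names introduced for illustration; they are not the functions in `fuse_constant_ops_pass.py`.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical 1-D tensor stand-in holding plain Python numbers."""
    values: list


def dequantize_arg(arg, qparam):
    """Apply one top-level qparam to every tensor nested inside `arg`.

    `qparam` is a (scale, zero_point) tuple or None. Lists and tuples are
    walked recursively; non-tensor leaves (such as the int `dim`) pass through.
    """
    if isinstance(arg, FakeTensor):
        if qparam is None:
            return arg
        scale, zero_point = qparam
        return FakeTensor([(v - zero_point) * scale for v in arg.values])
    if isinstance(arg, (list, tuple)):
        # Preserve the container type while recursing into its elements.
        return type(arg)(dequantize_arg(a, qparam) for a in arg)
    return arg


def fold_cat(args, input_qparams):
    """Dequantize by top-level argument index, then concatenate (1-D only)."""
    tensors, _dim = (
        dequantize_arg(arg, input_qparams.get(i)) for i, arg in enumerate(args)
    )
    folded = []
    for tensor in tensors:
        folded.extend(tensor.values)
    return folded


# One qparam entry for argument 0 (the tensor list); none for `dim`.
qparams = {0: (0.5, 10)}
args = ([FakeTensor([10, 12]), FakeTensor([14])], 0)
folded = fold_cat(args, qparams)
```

Because the qparam is looked up once per top-level argument and then pushed down through the list, every tensor in the `cat` input is dequantized with the same parameters before folding.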
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| backends/arm/_passes/fuse_constant_ops_pass.py | Fix qparam lookup to use top-level arg index and apply it to nested list/tuple tensor args during constant folding. |
| backends/arm/test/passes/test_fuse_constant_ops_pass.py | Add a regression test ensuring quantized constant-folding for aten.cat with list/tuple tensor inputs produces the correct fused constant. |
    cat_node = next(
        node
        for node in exported_program.graph_module.graph.nodes
        if node.op == "call_function"

…

        exported_program.graph_module
    )

    assert list(exported_program.state_dict) == ["aten_cat_default_fused_const"]
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell