-
Notifications
You must be signed in to change notification settings - Fork 767
Arm backend: Improve dtype validation #15871
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Arm backend: Improve dtype validation #15871
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/15871
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 1 Unrelated FailureAs of commit de9a130 with merge base a2078c6 ( NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR improves dtype validation in the ARM backend's node visitors by making dtype support conditional on TOSA specifications and extensions, adding bool dtype support to several operators, and refining dtype lists for better consistency and correctness.
- Enhanced dtype validation to conditionally support INT16 and INT48 based on TOSA 1.0 "int16" extension availability
- Added BOOL dtype support to operators like permute, repeat, expand, slice, and cat
- Removed INT8/INT16 from comparison operators (eq, lt, le, gt, ge) to align with TOSA spec constraints
- Added test cases with bool tensors and xfail markers for U55 hardware which doesn't support bool
Reviewed Changes
Copilot reviewed 24 out of 24 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| backends/arm/test/ops/test_repeat.py | Added bool test case with xfail for U55 |
| backends/arm/test/ops/test_permute.py | Added bool test case with xfail for U55, renamed int16 quantization test functions for clarity (int16→16a8w) |
| backends/arm/test/ops/test_expand.py | Added bool test case with xfail for U55 |
| backends/arm/operators/ops_identity.py | Added conditional dtype validation based on TOSA spec and int16 extension support |
| backends/arm/operators/op_where.py | Deduplicated BOOL from supported dtypes list (moved to base list) |
| backends/arm/operators/op_tosa_transpose.py | Reordered dtype list for consistency (BOOL first, then INT types, then FP types) |
| backends/arm/operators/op_tosa_table.py | Refactored dtype validation to conditionally support INT16 input/INT32 output based on int16 extension |
| backends/arm/operators/op_tosa_resize.py | Improved dtype validation with conditional support for INT16/INT48 based on int16 extension |
| backends/arm/operators/op_tosa_matmul.py | Added conditional support for INT16 input and INT48 output based on int16 extension |
| backends/arm/operators/op_sum.py | Added explicit dtype validation for INT32 and FP32 |
| backends/arm/operators/op_slice.py | Added BOOL dtype support |
| backends/arm/operators/op_repeat.py | Added BOOL dtype support |
| backends/arm/operators/op_permute.py | Added BOOL dtype support |
| backends/arm/operators/op_mul.py | Added INT8 and INT16 dtype support |
| backends/arm/operators/op_max_pool2d.py | Added conditional INT16 support based on int16 extension |
| backends/arm/operators/op_lt.py | Removed INT8 and INT16 from supported input dtypes |
| backends/arm/operators/op_le.py | Removed INT8 and INT16 from supported input dtypes |
| backends/arm/operators/op_index_select.py | Added validation utilities and renamed unused variable from 'index' to '_' |
| backends/arm/operators/op_gt.py | Removed INT8 and INT16 from supported input dtypes |
| backends/arm/operators/op_ge.py | Removed INT8 and INT16 from supported input dtypes |
| backends/arm/operators/op_eq.py | Removed INT8 and INT16 from supported input dtypes |
| backends/arm/operators/op_clamp.py | Added conditional INT16 support based on int16 extension |
| backends/arm/operators/op_cat.py | Added BOOL dtype support and conditional INT16 support, improved validation with proper TosaArg conversion |
| backends/arm/operators/op_avg_pool2d.py | Added conditional INT16 support based on int16 extension |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
This seem to get fails in the Arm tests that need to pass before merge, not sure if it's flakey tests, if so a rerun/rebase could help :( |
zingo
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approved but It would be nice to address the randint in the test case backends/arm/test/ops/test_expand.py se review comment on it.
|
Seem to have some arm tests fails :( |
9be7bfd to
c3946a4
Compare
Improve dtype validation in NodeVisitors. Signed-off-by: Oscar Andersson <oscar.andersson@arm.com> Change-Id: Ieb9ced1ae8d2db916e6c8bc0b45773a640d330db
7f93b26 to
de9a130
Compare
|
Cortex-m tests unrelated, broken independent of this PR |
Improve dtype validation in node-visitors. Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
Improve dtype validation in node-visitors.
cc @freddan80 @per @zingo @digantdesai