feat(MLX): Add handlers for aten.flip and aten.trunc#18982
feat(MLX): Add handlers for aten.flip and aten.trunc#18982Jah-yee wants to merge 1 commit intopytorch:mainfrom
Conversation
- flip: decompose using SliceNode with step=-1 chained per dim - trunc: use where(x>=0, floor(x), ceil(x)) since MLX lacks native trunc Both are good-first-issues from pytorch/executorch. Signed-off-by: RoomWithOutRoof <RoomWithOutRoof@users.noreply.github.com>
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18982
Note: Links to docs will display an error until the docs builds have been completed. ❌ 14 New Failures, 1 Cancelled Job, 2 Unrelated FailuresAs of commit 1c4b0a1 with merge base 490ec5c ( NEW FAILURES - The following jobs have failed:
CANCELLED JOB - The following job was cancelled. Please retry:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
|
@claude review the PR |
|
Claude finished @nil-is-all's task in 3m 52s —— View job Review of PR #18982: Add handlers for aten.flip and aten.trunc
OverallBoth handlers follow the established patterns in aten.flip handler (
|
| # | Severity | Issue |
|---|---|---|
| 1 | Medium | Add dynamic shape guard for flip dimension sizes |
| 2 | Low | Verify negative dims work with MLX SliceNode axis |
| Tests | Medium | Add test cases for both ops in test_ops.py |
The rest are nitpicks. Nice work on the decomposition approaches — using SliceNode chaining for flip and the where(x>=0, floor, ceil) pattern for trunc are clean solutions that avoid needing new MLX runtime ops.
Good day
Summary
This PR adds MLX op handlers for two PyTorch aten operators:
Implementation Details
aten.flip (issue #18918)
Uses with chained for each flip dimension:
For each dimension in the dims list, we emit a slice from to with step=-1. Chaining these slices achieves the flip operation without needing a dedicated MLX kernel.
aten.trunc (issue #18923)
Uses the mathematical decomposition:
Since MLX doesn't have a native trunc operation, we decompose it using existing nodes (GreaterEqualNode, FloorNode, CeilNode, WhereNode). This is the same approach used in other PyTorch decomposition strategies.
Testing
Both handlers use existing MLX nodes (SliceNode, IdCopyNode, GreaterEqualNode, FloorNode, CeilNode, WhereNode) — no new schema or runtime code needed.
References
aten.trunc#18923 (trunc handler)Thank you for your attention. If there are any issues or suggestions, please leave a comment and I will address them promptly.
Warmly,
RoomWithOutRoof