From 726c72160bf871ac8ea0112693b5321e5a95d277 Mon Sep 17 00:00:00 2001
From: Ishan Godawatta
Date: Wed, 22 Apr 2026 01:10:28 +0100
Subject: [PATCH] feat(mlx): add handler for aten.roll

Maps torch.roll to mlx::core::roll via a new RollNode. Adds the schema
table, the custom handler for the (shifts, dims) args, the exec_roll
runtime, and test cases covering 1D, 2D, multi-axis, negative shifts,
and negative dims. Flat roll (dims=[]) is explicitly NotImplementedError
for now; all known use cases (Swin Transformer shift-window attention)
pass dims.

Fixes #18919
---
 backends/mlx/ops.py                   | 40 ++++++++++++++++++++
 backends/mlx/runtime/MLXInterpreter.h | 10 +++++
 backends/mlx/serialization/schema.fbs | 13 ++++++-
 backends/mlx/test/test_ops.py         | 53 +++++++++++++++++++++++++++
 4 files changed, 115 insertions(+), 1 deletion(-)

diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py
index 3f7da88a793..91a50ac0b98 100644
--- a/backends/mlx/ops.py
+++ b/backends/mlx/ops.py
@@ -116,6 +116,7 @@
     RepeatNode,
     ReshapeNode,
     RMSNormNode,
+    RollNode,
     RopeNode,
     RoundNode,
     RsqrtNode,
@@ -1677,6 +1678,45 @@ def _repeat_handler(P: MLXProgramBuilder, n: Node) -> Slot:
     return out
 
 
+@REGISTRY.register(target=[torch.ops.aten.roll.default])
+def _roll_handler(P: MLXProgramBuilder, n: Node) -> Slot:
+    args = P.args(n)
+    require_args(args, 2, 3, "aten.roll")
+    require_kwargs(P.kwargs(n), set(), "aten.roll")
+    x = args[0]
+    shifts_arg = args[1]
+    dims_arg = args[2] if len(args) > 2 else []
+
+    shifts = [shifts_arg] if isinstance(shifts_arg, int) else list(shifts_arg)
+    dims: List[int] = [dims_arg] if isinstance(dims_arg, int) else list(dims_arg)
+
+    # Flat roll (torch.roll with dims=[]) would require reshape + roll +
+    # reshape at the graph level. Not yet supported; Swin-style usage always
+    # passes explicit dims.
+    if not dims:
+        raise NotImplementedError(
+            "aten.roll without dims (flat roll) is not supported by the MLX "
+            "delegate yet."
+        )
+    if len(shifts) != len(dims):
+        raise ValueError(
+            f"aten.roll: shifts and dims must have the same length, got "
+            f"shifts={shifts} (len={len(shifts)}) dims={dims} (len={len(dims)})"
+        )
+    require_static_ints(dims, "dims", "aten.roll")
+
+    out = P.make_or_get_slot(n)
+    P.emit(
+        RollNode(
+            x=P.slot_to_tid(x),
+            out=P.slot_to_tid(out),
+            shift=[P.to_int_or_vid(s) for s in shifts],
+            axes=dims,
+        )
+    )
+    return out
+
+
 @REGISTRY.register(target=[torch.ops.aten.index.Tensor])
 def _index_handler(P: MLXProgramBuilder, n: Node) -> Slot:
     args = P.args(n)
diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h
index 9fa08ab722d..1ef2713de18 100644
--- a/backends/mlx/runtime/MLXInterpreter.h
+++ b/backends/mlx/runtime/MLXInterpreter.h
@@ -1726,6 +1726,13 @@ inline void exec_all(const AllNode& n, ExecutionState& st, StreamOrDevice s) {
   }
 }
 
+inline void exec_roll(const RollNode& n, ExecutionState& st, StreamOrDevice s) {
+  const auto& x = st.const_tensor_ref(n.x);
+  auto shifts = to_shape(n.shift, st);
+  std::vector<int> axes(n.axes.begin(), n.axes.end());
+  st.set_tensor(n.out, roll(x, shifts, axes, s));
+}
+
 inline void exec_repeat(const RepeatNode& n, ExecutionState& st,
                         StreamOrDevice s) {
   const auto& x = st.const_tensor_ref(n.x);
@@ -2199,6 +2206,9 @@
       case OpCode::REPEAT:
        ops::exec_repeat(std::get<RepeatNode>(instr.node), st, s);
        break;
+      case OpCode::ROLL:
+        ops::exec_roll(std::get<RollNode>(instr.node), st, s);
+        break;
       case OpCode::SORT:
        ops::exec_sort(std::get<SortNode>(instr.node), st, s);
        break;
diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs
index 6e8d6f47db8..89254019689 100644
--- a/backends/mlx/serialization/schema.fbs
+++ b/backends/mlx/serialization/schema.fbs
@@ -668,6 +668,16 @@
 table ArgPartitionNode {
   axis: int32;
 }
 
+// Shift tensor elements along specified axes with wrap-around.
+// Maps to mlx::core::roll(a, shifts, axes).
+// Flat roll (torch.roll with dims=None) is not yet supported.
+table RollNode {
+  x: Tid (required);
+  out: Tid (required);
+  shift: [IntOrVid] (required); // Shift amount per axis (can be dynamic)
+  axes: [int32] (required);     // Axes to roll along; len(shift) == len(axes)
+}
+
 // =============================================================================
 // Math ops - Unary element-wise
 
@@ -1113,7 +1123,8 @@
   GatherMmNode,
   GatherQmmNode,
   ScanNode,
-  MetalKernelNode
+  MetalKernelNode,
+  RollNode
   // BC: Add new op nodes here (append only)
 }
 
diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py
index 7ba3902e436..d5a349d05c0 100644
--- a/backends/mlx/test/test_ops.py
+++ b/backends/mlx/test/test_ops.py
@@ -855,6 +855,59 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]:
         return (x,)
 
 
+class RollModel(nn.Module):
+    """Model that rolls a tensor along specified dimensions."""
+
+    def __init__(self, shifts: Tuple[int, ...], dims: Tuple[int, ...]):
+        super().__init__()
+        self.shifts = shifts
+        self.dims = dims
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.roll(x, shifts=self.shifts, dims=self.dims)
+
+
+@register_test
+class RollTest(OpTestCase):
+    """Test case for torch.roll()."""
+
+    name = "roll"
+    rtol = 1e-5
+    atol = 1e-5
+
+    def __init__(
+        self,
+        input_shape: Tuple[int, ...] = (4, 5),
+        shifts: Tuple[int, ...] = (1,),
+        dims: Tuple[int, ...] = (0,),
+    ):
+        self.input_shape = input_shape
+        self.shifts = shifts
+        self.dims = dims
+        shift_str = ",".join(str(s) for s in shifts)
+        dim_str = ",".join(str(d) for d in dims)
+        self.name = f"roll_shift({shift_str})_dim({dim_str})"
+
+    @classmethod
+    def get_test_configs(cls) -> List["RollTest"]:
+        return [
+            cls(input_shape=(8,), shifts=(2,), dims=(0,)),
+            cls(input_shape=(4, 5), shifts=(1,), dims=(0,)),
+            cls(input_shape=(4, 5), shifts=(-2,), dims=(1,)),
+            cls(input_shape=(3, 4, 5), shifts=(3,), dims=(2,)),
+            cls(input_shape=(3, 4, 5), shifts=(1, 2), dims=(0, 2)),
+            cls(input_shape=(3, 4, 5), shifts=(-1, -2, -3), dims=(0, 1, 2)),
+            cls(input_shape=(3, 4, 5), shifts=(2,), dims=(-1,)),
+        ]
+
+    def create_model(self) -> nn.Module:
+        return RollModel(self.shifts, self.dims)
+
+    def create_inputs(self) -> Tuple[torch.Tensor, ...]:
+        x = torch.randn(self.input_shape)
+        return (x,)
+
+
 class CatNModel(nn.Module):
     """Model that concatenates N tensors along a dimension."""