Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions backends/mlx/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
RepeatNode,
ReshapeNode,
RMSNormNode,
RollNode,
RopeNode,
RoundNode,
RsqrtNode,
Expand Down Expand Up @@ -1677,6 +1678,45 @@ def _repeat_handler(P: MLXProgramBuilder, n: Node) -> Slot:
return out


@REGISTRY.register(target=[torch.ops.aten.roll.default])
def _roll_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower ``aten.roll`` to a RollNode (circular shift along axes).

    Normalizes scalar shift/dim arguments to lists, validates that both
    lists line up, and emits a single RollNode instruction.
    """
    args = P.args(n)
    require_args(args, 2, 3, "aten.roll")
    require_kwargs(P.kwargs(n), set(), "aten.roll")
    x = args[0]

    def _listify(v):
        # aten.roll accepts a bare int or a sequence for shifts/dims.
        return [v] if isinstance(v, int) else list(v)

    shifts = _listify(args[1])
    dims: List[int] = _listify(args[2]) if len(args) > 2 else []

    # Flat roll (torch.roll with dims=[]) would require reshape + roll +
    # reshape at the graph level. Not yet supported; Swin-style usage always
    # passes explicit dims.
    if not dims:
        raise NotImplementedError(
            "aten.roll without dims (flat roll) is not supported by the MLX "
            "delegate yet."
        )
    if len(shifts) != len(dims):
        raise ValueError(
            f"aten.roll: shifts and dims must have the same length, got "
            f"shifts={shifts} (len={len(shifts)}) dims={dims} (len={len(dims)})"
        )
    # Axes must be compile-time constants; shift amounts may stay dynamic.
    require_static_ints(dims, "dims", "aten.roll")

    out = P.make_or_get_slot(n)
    P.emit(
        RollNode(
            x=P.slot_to_tid(x),
            out=P.slot_to_tid(out),
            shift=[P.to_int_or_vid(s) for s in shifts],
            axes=dims,
        )
    )
    return out


@REGISTRY.register(target=[torch.ops.aten.index.Tensor])
def _index_handler(P: MLXProgramBuilder, n: Node) -> Slot:
args = P.args(n)
Expand Down
10 changes: 10 additions & 0 deletions backends/mlx/runtime/MLXInterpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1726,6 +1726,13 @@ inline void exec_all(const AllNode& n, ExecutionState& st, StreamOrDevice s) {
}
}

// Circularly shift the input tensor along the given axes via
// mlx::core::roll. Shift amounts come through to_shape so they may be
// resolved dynamically at execution time; axes are static int32 values.
inline void exec_roll(const RollNode& n, ExecutionState& st, StreamOrDevice s) {
  const auto& input = st.const_tensor_ref(n.x);
  const auto shift_amounts = to_shape(n.shift, st);
  const std::vector<int> roll_axes(n.axes.begin(), n.axes.end());
  st.set_tensor(n.out, roll(input, shift_amounts, roll_axes, s));
}

inline void
exec_repeat(const RepeatNode& n, ExecutionState& st, StreamOrDevice s) {
const auto& x = st.const_tensor_ref(n.x);
Expand Down Expand Up @@ -2199,6 +2206,9 @@ class Interpreter {
case OpCode::REPEAT:
ops::exec_repeat(std::get<RepeatNode>(instr.node), st, s);
break;
case OpCode::ROLL:
ops::exec_roll(std::get<RollNode>(instr.node), st, s);
break;
case OpCode::SORT:
ops::exec_sort(std::get<SortNode>(instr.node), st, s);
break;
Expand Down
13 changes: 12 additions & 1 deletion backends/mlx/serialization/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,16 @@ table ArgPartitionNode {
axis: int32;
}

// Shift tensor elements along specified axes with wrap-around.
// Maps to mlx::core::roll(a, shifts, axes).
// Flat roll (torch.roll with dims=None) is not yet supported.
table RollNode {
x: Tid (required);            // Input tensor.
out: Tid (required);          // Output tensor (same shape as x).
shift: [IntOrVid] (required); // Shift amount per axis (can be dynamic)
axes: [int32] (required);     // Axes to roll along; len(shift) == len(axes)
}


// =============================================================================
// Math ops - Unary element-wise
Expand Down Expand Up @@ -1113,7 +1123,8 @@ union OpNode {
GatherMmNode,
GatherQmmNode,
ScanNode,
MetalKernelNode
MetalKernelNode,
RollNode
// BC: Add new op nodes here (append only)
}

Expand Down
53 changes: 53 additions & 0 deletions backends/mlx/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,59 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]:
return (x,)


class RollModel(nn.Module):
    """Applies a fixed circular shift (torch.roll) to its input."""

    def __init__(self, shifts: Tuple[int, ...], dims: Tuple[int, ...]):
        super().__init__()
        # Shift amounts and the axes they apply to, captured at build time.
        self.shifts, self.dims = shifts, dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.roll(x, shifts=self.shifts, dims=self.dims)


@register_test
class RollTest(OpTestCase):
    """Test case for torch.roll()."""

    name = "roll"
    rtol = 1e-5
    atol = 1e-5

    def __init__(
        self,
        input_shape: Tuple[int, ...] = (4, 5),
        shifts: Tuple[int, ...] = (1,),
        dims: Tuple[int, ...] = (0,),
    ):
        self.input_shape = input_shape
        self.shifts = shifts
        self.dims = dims
        # Encode the configuration into the test name for reporting.
        shift_str = ",".join(map(str, shifts))
        dim_str = ",".join(map(str, dims))
        self.name = f"roll_shift({shift_str})_dim({dim_str})"

    @classmethod
    def get_test_configs(cls) -> List["RollTest"]:
        # Covers 1-D/2-D/3-D inputs, negative shifts, multi-axis rolls,
        # and negative axis indices.
        cases = [
            ((8,), (2,), (0,)),
            ((4, 5), (1,), (0,)),
            ((4, 5), (-2,), (1,)),
            ((3, 4, 5), (3,), (2,)),
            ((3, 4, 5), (1, 2), (0, 2)),
            ((3, 4, 5), (-1, -2, -3), (0, 1, 2)),
            ((3, 4, 5), (2,), (-1,)),
        ]
        return [
            cls(input_shape=shape, shifts=shifts, dims=dims)
            for shape, shifts, dims in cases
        ]

    def create_model(self) -> nn.Module:
        return RollModel(self.shifts, self.dims)

    def create_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.randn(self.input_shape),)


class CatNModel(nn.Module):
"""Model that concatenates N tensors along a dimension."""

Expand Down
Loading