Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pytensor/link/mlx/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@
import pytensor.link.mlx.dispatch.signal
import pytensor.link.mlx.dispatch.signal.conv
import pytensor.link.mlx.dispatch.blockwise
import pytensor.link.mlx.dispatch.extra_ops
import pytensor.link.mlx.dispatch.sort
# isort: on
11 changes: 7 additions & 4 deletions pytensor/link/mlx/dispatch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,17 @@ def split(x, axis, splits):
# Resolve constants for significant performance improvement (14x speedup)
if constant_axis is not None:
axis = int(constant_axis)
else:
raise ValueError(
"Symbolic axis is not supported in MLX Split implementation."
)

if constant_splits is not None:
splits = constant_splits
cumsum_splits = np.cumsum(splits[:-1])
splits_arr = mx.array(constant_splits)
else:
# Dynamic case - use MLX operations
splits_arr = mx.array(splits)
cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist()

cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist()

# Validation checks
if len(splits) != op.len_splits:
Expand Down
13 changes: 12 additions & 1 deletion pytensor/link/mlx/dispatch/elemwise.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import singledispatch

import mlx.core as mx
import mlx.nn as mlx_nn
import numpy as np

from pytensor.link.mlx.dispatch.basic import mlx_funcify
Expand Down Expand Up @@ -40,7 +41,7 @@
)
from pytensor.scalar.math import Erfc, Erfcx, Sigmoid, Softplus
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.special import Softmax, SoftmaxGrad
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad


@mlx_funcify.register(DimShuffle)
Expand Down Expand Up @@ -142,6 +143,16 @@ def softmax_grad(dy, sm):
return softmax_grad


@mlx_funcify.register(LogSoftmax)
def mlx_funcify_LogSoftmax(op, **kwargs):
axis = op.axis

def log_softmax(x):
return mlx_nn.log_softmax(x, axis=axis)

return log_softmax


@mlx_funcify.register(Softplus)
def mlx_funcify_Softplus(op, **kwargs):
def softplus(x):
Expand Down
35 changes: 35 additions & 0 deletions pytensor/link/mlx/dispatch/extra_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.extra_ops import CumOp, Repeat


@mlx_funcify.register(CumOp)
def mlx_funcify_CumOp(op, **kwargs):
axis = op.axis
mode = op.mode

def cumop(x, axis=axis, mode=mode):
match mode:
case "add":
return mx.cumsum(x, axis=axis)
case "mul":
return mx.cumprod(x, axis=axis)
case _:
raise NotImplementedError(f"CumOp mode {mode} not implemented in MLX")

return cumop


@mlx_funcify.register(Repeat)
def jax_funcify_Repeat(op, **kwargs):
axis = op.axis

def repeat(x, repeats, axis=axis):
if not isinstance(repeats, int):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is known at dispatch time, raise then?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

repeats is a symbolic input. We only know axis as dispatch time.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a weird mlx-specific limitation. There might be a work-around, but I don't want to do it in this PR. Just getting some common cases covered.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you know whether repeats is 0d or 1d at dispatch time? actually isn't the op always 1d now and we use alloc for 0d?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How? It's a symbolic input

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh by checking op.inputs[1] :fivehead:

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

by checking node.inputs. Node is the second argument to all dispatches

raise NotImplementedError(
"MLX repeat does not support sequence-valued repeat argument."
)
return mx.repeat(x, repeats, axis=axis)

return repeat
38 changes: 38 additions & 0 deletions pytensor/link/mlx/dispatch/sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import warnings

import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.sort import ArgSortOp, SortOp


@mlx_funcify.register(SortOp)
def mlx_funcify_Sort(op, **kwargs):
kind = op.kind
if kind != "quicksort":
warnings.warn(
message=f"MLX sort does not support the kind argument (got kind={kind}). The argument will be "
f"ignored.",
category=UserWarning,
)

def sort(x, axis):
return mx.sort(x, axis=axis)

return sort


@mlx_funcify.register(ArgSortOp)
def mlx_funcify_ArgSort(op, **kwargs):
kind = op.kind
if kind != "quicksort":
warnings.warn(
message=f"MLX argsort does not support the kind argument (got kind={kind}). The argument will be "
f"ignored.",
category=UserWarning,
)

def argsort(x, axis):
return mx.argsort(x, axis=axis)

return argsort
29 changes: 28 additions & 1 deletion tests/link/mlx/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@
import pytest

import pytensor
from pytensor import config
from pytensor import tensor as pt
from pytensor.tensor.basic import Alloc
from tests.link.mlx.test_basic import compile_mode, mlx_mode_no_compile, mx
from tests.link.mlx.test_basic import (
compare_mlx_and_py,
compile_mode,
mlx_mode_no_compile,
mx,
)


def test_alloc_with_different_shape_types():
Expand Down Expand Up @@ -137,3 +143,24 @@ def test_empty_dynamic_shape():
"used inside compiled functions",
):
f_compiled(3, 4)


def test_split_const_axis_const_splits_compiled():
x = pt.vector("x")
splits = [2, 3]
outs = pt.split(x, splits, len(splits), axis=0)
compare_mlx_and_py([x], outs, [np.arange(5, dtype="float32")])


def test_split_dynamic_axis_const_splits():
x = pt.matrix("x")
axis = pt.scalar("axis", dtype="int64")
splits = [1, 2, 3]
outs = pt.split(x, splits, len(splits), axis=axis)

test_input = np.arange(12).astype(config.floatX).reshape(2, 6)

with pytest.raises(
ValueError, match="Symbolic axis is not supported in MLX Split implementation"
):
compare_mlx_and_py([x, axis], outs, [test_input, np.array(1)])
11 changes: 10 additions & 1 deletion tests/link/mlx/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import min as pt_min
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.special import SoftmaxGrad, softmax
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, vector, vectors
from tests.link.mlx.test_basic import compare_mlx_and_py

Expand Down Expand Up @@ -97,6 +97,15 @@ def test_softmax_grad(axis):
compare_mlx_and_py([dy, sm], [out], [dy_test_value, sm_test_value])


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_logsoftmax(axis):
x = matrix("x")
x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = log_softmax(x, axis=axis)

compare_mlx_and_py([x], [out], [x_test_value])


@pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
@pytest.mark.parametrize("axis", [0, 1])
def test_logsumexp_benchmark(size, axis, benchmark):
Expand Down
24 changes: 24 additions & 0 deletions tests/link/mlx/test_extra_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import numpy as np
import pytest

from pytensor.configdefaults import config
from pytensor.tensor import extra_ops as pt_extra_ops
from pytensor.tensor.type import matrix
from tests.link.mlx.test_basic import compare_mlx_and_py


mx = pytest.importorskip("mlx.core")


def test_extra_ops():
a = matrix("a")
a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))

out = pt_extra_ops.cumsum(a, axis=0)
compare_mlx_and_py([a], [out], [a_test])

out = pt_extra_ops.cumprod(a, axis=1)
compare_mlx_and_py([a], [out], [a_test])

out = pt_extra_ops.repeat(a, 3, axis=1)
compare_mlx_and_py([a], [out], [a_test])
22 changes: 22 additions & 0 deletions tests/link/mlx/test_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np
import pytest

from pytensor.tensor.sort import argsort, sort
from pytensor.tensor.type import matrix
from tests.link.mlx.test_basic import compare_mlx_and_py


@pytest.mark.parametrize("axis", [None, -1])
@pytest.mark.parametrize("func", (sort, argsort))
def test_sort(func, axis):
x = matrix("x", shape=(2, 2), dtype="float64")
out = func(x, axis=axis)
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_mlx_and_py([x], [out], [arr])


def test_sort_invalid_kind_warning():
x = matrix("x", shape=(2, 2), dtype="float64")
z = sort(x, axis=-1, kind="mergesort")
with pytest.warns(UserWarning, match="MLX sort does not support the kind argument"):
z.eval({x: np.array([[3.0, 1.0], [2.0, 4.0]])}, mode="MLX")