diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py
index f039263a37..93b848dc71 100644
--- a/pytensor/link/mlx/dispatch/__init__.py
+++ b/pytensor/link/mlx/dispatch/__init__.py
@@ -10,4 +10,6 @@
 import pytensor.link.mlx.dispatch.signal
 import pytensor.link.mlx.dispatch.signal.conv
 import pytensor.link.mlx.dispatch.blockwise
+import pytensor.link.mlx.dispatch.extra_ops
+import pytensor.link.mlx.dispatch.sort
 # isort: on
diff --git a/pytensor/link/mlx/dispatch/core.py b/pytensor/link/mlx/dispatch/core.py
index 6783698523..be3ed37e3a 100644
--- a/pytensor/link/mlx/dispatch/core.py
+++ b/pytensor/link/mlx/dispatch/core.py
@@ -61,14 +61,17 @@ def split(x, axis, splits):
         # Resolve constants for significant performance improvement (14x speedup)
         if constant_axis is not None:
             axis = int(constant_axis)
+        else:
+            raise ValueError(
+                "Symbolic axis is not supported in MLX Split implementation."
+            )
 
         if constant_splits is not None:
-            splits = constant_splits
-            cumsum_splits = np.cumsum(splits[:-1])
+            splits_arr = mx.array(constant_splits)
         else:
-            # Dynamic case - use MLX operations
             splits_arr = mx.array(splits)
-            cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist()
+
+        cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist()
 
         # Validation checks
         if len(splits) != op.len_splits:
diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py
index e8d9e53f42..05331dccaf 100644
--- a/pytensor/link/mlx/dispatch/elemwise.py
+++ b/pytensor/link/mlx/dispatch/elemwise.py
@@ -1,6 +1,7 @@
 from functools import singledispatch
 
 import mlx.core as mx
+import mlx.nn as mlx_nn
 import numpy as np
 
 from pytensor.link.mlx.dispatch.basic import mlx_funcify
@@ -40,7 +41,7 @@
 )
 from pytensor.scalar.math import Erfc, Erfcx, Sigmoid, Softplus
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.special import Softmax, SoftmaxGrad
+from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 
 
 @mlx_funcify.register(DimShuffle)
@@ -142,6 +143,16 @@ def softmax_grad(dy, sm):
     return softmax_grad
 
 
+@mlx_funcify.register(LogSoftmax)
+def mlx_funcify_LogSoftmax(op, **kwargs):
+    axis = op.axis
+
+    def log_softmax(x):
+        return mlx_nn.log_softmax(x, axis=axis)
+
+    return log_softmax
+
+
 @mlx_funcify.register(Softplus)
 def mlx_funcify_Softplus(op, **kwargs):
     def softplus(x):
diff --git a/pytensor/link/mlx/dispatch/extra_ops.py b/pytensor/link/mlx/dispatch/extra_ops.py
new file mode 100644
index 0000000000..cd8a0e2c8d
--- /dev/null
+++ b/pytensor/link/mlx/dispatch/extra_ops.py
@@ -0,0 +1,35 @@
+import mlx.core as mx
+
+from pytensor.link.mlx.dispatch.basic import mlx_funcify
+from pytensor.tensor.extra_ops import CumOp, Repeat
+
+
+@mlx_funcify.register(CumOp)
+def mlx_funcify_CumOp(op, **kwargs):
+    axis = op.axis
+    mode = op.mode
+
+    def cumop(x, axis=axis, mode=mode):
+        match mode:
+            case "add":
+                return mx.cumsum(x, axis=axis)
+            case "mul":
+                return mx.cumprod(x, axis=axis)
+            case _:
+                raise NotImplementedError(f"CumOp mode {mode} not implemented in MLX")
+
+    return cumop
+
+
+@mlx_funcify.register(Repeat)
+def mlx_funcify_Repeat(op, **kwargs):
+    axis = op.axis
+
+    def repeat(x, repeats, axis=axis):
+        if not isinstance(repeats, int):
+            raise NotImplementedError(
+                "MLX repeat does not support sequence-valued repeat argument."
+            )
+        return mx.repeat(x, repeats, axis=axis)
+
+    return repeat
diff --git a/pytensor/link/mlx/dispatch/sort.py b/pytensor/link/mlx/dispatch/sort.py
new file mode 100644
index 0000000000..61e13dbba6
--- /dev/null
+++ b/pytensor/link/mlx/dispatch/sort.py
@@ -0,0 +1,38 @@
+import warnings
+
+import mlx.core as mx
+
+from pytensor.link.mlx.dispatch.basic import mlx_funcify
+from pytensor.tensor.sort import ArgSortOp, SortOp
+
+
+@mlx_funcify.register(SortOp)
+def mlx_funcify_Sort(op, **kwargs):
+    kind = op.kind
+    if kind != "quicksort":
+        warnings.warn(
+            message=f"MLX sort does not support the kind argument (got kind={kind}). The argument will be "
+            "ignored.",
+            category=UserWarning,
+        )
+
+    def sort(x, axis):
+        return mx.sort(x, axis=axis)
+
+    return sort
+
+
+@mlx_funcify.register(ArgSortOp)
+def mlx_funcify_ArgSort(op, **kwargs):
+    kind = op.kind
+    if kind != "quicksort":
+        warnings.warn(
+            message=f"MLX argsort does not support the kind argument (got kind={kind}). The argument will be "
+            "ignored.",
+            category=UserWarning,
+        )
+
+    def argsort(x, axis):
+        return mx.argsort(x, axis=axis)
+
+    return argsort
diff --git a/tests/link/mlx/test_core.py b/tests/link/mlx/test_core.py
index 4539f44cec..d50d3a9959 100644
--- a/tests/link/mlx/test_core.py
+++ b/tests/link/mlx/test_core.py
@@ -2,9 +2,15 @@
 import pytest
 
 import pytensor
+from pytensor import config
 from pytensor import tensor as pt
 from pytensor.tensor.basic import Alloc
-from tests.link.mlx.test_basic import compile_mode, mlx_mode_no_compile, mx
+from tests.link.mlx.test_basic import (
+    compare_mlx_and_py,
+    compile_mode,
+    mlx_mode_no_compile,
+    mx,
+)
 
 
 def test_alloc_with_different_shape_types():
@@ -137,3 +143,24 @@ def test_empty_dynamic_shape():
         "used inside compiled functions",
     ):
         f_compiled(3, 4)
+
+
+def test_split_const_axis_const_splits_compiled():
+    x = pt.vector("x")
+    splits = [2, 3]
+    outs = pt.split(x, splits, len(splits), axis=0)
+    compare_mlx_and_py([x], outs, [np.arange(5, dtype="float32")])
+
+
+def test_split_dynamic_axis_const_splits():
+    x = pt.matrix("x")
+    axis = pt.scalar("axis", dtype="int64")
+    splits = [1, 2, 3]
+    outs = pt.split(x, splits, len(splits), axis=axis)
+
+    test_input = np.arange(12).astype(config.floatX).reshape(2, 6)
+
+    with pytest.raises(
+        ValueError, match="Symbolic axis is not supported in MLX Split implementation"
+    ):
+        compare_mlx_and_py([x, axis], outs, [test_input, np.array(1)])
diff --git a/tests/link/mlx/test_elemwise.py b/tests/link/mlx/test_elemwise.py
index 63507a9e41..7f7e77f761 100644
--- a/tests/link/mlx/test_elemwise.py
+++ b/tests/link/mlx/test_elemwise.py
@@ -30,7 +30,7 @@
 from pytensor.tensor.math import max as pt_max
 from pytensor.tensor.math import min as pt_min
 from pytensor.tensor.math import sum as pt_sum
-from pytensor.tensor.special import SoftmaxGrad, softmax
+from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
 from pytensor.tensor.type import matrix, vector, vectors
 from tests.link.mlx.test_basic import compare_mlx_and_py
 
@@ -97,6 +97,15 @@ def test_softmax_grad(axis):
     compare_mlx_and_py([dy, sm], [out], [dy_test_value, sm_test_value])
 
 
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_logsoftmax(axis):
+    x = matrix("x")
+    x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    out = log_softmax(x, axis=axis)
+
+    compare_mlx_and_py([x], [out], [x_test_value])
+
+
 @pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
 @pytest.mark.parametrize("axis", [0, 1])
 def test_logsumexp_benchmark(size, axis, benchmark):
diff --git a/tests/link/mlx/test_extra_ops.py b/tests/link/mlx/test_extra_ops.py
new file mode 100644
index 0000000000..d3ce794f1d
--- /dev/null
+++ b/tests/link/mlx/test_extra_ops.py
@@ -0,0 +1,24 @@
+import numpy as np
+import pytest
+
+from pytensor.configdefaults import config
+from pytensor.tensor import extra_ops as pt_extra_ops
+from pytensor.tensor.type import matrix
+from tests.link.mlx.test_basic import compare_mlx_and_py
+
+
+mx = pytest.importorskip("mlx.core")
+
+
+def test_extra_ops():
+    a = matrix("a")
+    a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))
+
+    out = pt_extra_ops.cumsum(a, axis=0)
+    compare_mlx_and_py([a], [out], [a_test])
+
+    out = pt_extra_ops.cumprod(a, axis=1)
+    compare_mlx_and_py([a], [out], [a_test])
+
+    out = pt_extra_ops.repeat(a, 3, axis=1)
+    compare_mlx_and_py([a], [out], [a_test])
diff --git a/tests/link/mlx/test_sort.py b/tests/link/mlx/test_sort.py
new file mode 100644
index 0000000000..8d37a76ab1
--- /dev/null
+++ b/tests/link/mlx/test_sort.py
@@ -0,0 +1,22 @@
+import numpy as np
+import pytest
+
+from pytensor.tensor.sort import argsort, sort
+from pytensor.tensor.type import matrix
+from tests.link.mlx.test_basic import compare_mlx_and_py
+
+
+@pytest.mark.parametrize("axis", [None, -1])
+@pytest.mark.parametrize("func", (sort, argsort))
+def test_sort(func, axis):
+    x = matrix("x", shape=(2, 2), dtype="float64")
+    out = func(x, axis=axis)
+    arr = np.array([[1.0, 4.0], [5.0, 2.0]])
+    compare_mlx_and_py([x], [out], [arr])
+
+
+def test_sort_invalid_kind_warning():
+    x = matrix("x", shape=(2, 2), dtype="float64")
+    z = sort(x, axis=-1, kind="mergesort")
+    with pytest.warns(UserWarning, match="MLX sort does not support the kind argument"):
+        z.eval({x: np.array([[3.0, 1.0], [2.0, 4.0]])}, mode="MLX")
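
For reviewers who want to exercise the new dispatches end to end, here is a minimal usage sketch. It is not part of the diff; it assumes `mlx` is installed and mirrors how the tests above drive the backend (`compare_mlx_and_py` and `mode="MLX"`):

```python
import numpy as np

import pytensor
from pytensor.tensor import extra_ops as pt_extra_ops
from pytensor.tensor.sort import sort
from pytensor.tensor.special import log_softmax
from pytensor.tensor.type import matrix

x = matrix("x")
outs = [
    log_softmax(x, axis=1),          # LogSoftmax -> mlx.nn.log_softmax
    pt_extra_ops.cumsum(x, axis=0),  # CumOp(mode="add") -> mx.cumsum
    sort(x, axis=-1),                # SortOp -> mx.sort
]

# Compile with the MLX backend, as the tests do via mode="MLX".
f = pytensor.function([x], outs, mode="MLX")
print(f(np.arange(6, dtype="float32").reshape(2, 3)))
```

Note that `Split` on this backend requires a constant axis: a symbolic `axis` input now raises the `ValueError` covered by `test_split_dynamic_axis_const_splits`.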