In [1]:
# General import
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as pt
from pytensor.link.jax.dispatch import jax_funcify

# Import for testing
import pytest
from pytensor.configdefaults import config
from tests.link.jax.test_basic import compare_jax_and_py
from pytensor.graph import FunctionGraph
from pytensor.graph.op import get_test_value

# Import for the op to extend to JAX
from pytensor.tensor.extra_ops import CumOp

@jax_funcify.register(CumOp)
def jax_funcify_CumOp(op, **kwargs):
    """Translate a PyTensor `CumOp` into an equivalent JAX function.

    Parameters
    ----------
    op : CumOp
        The `Op` instance being converted; its ``axis`` and ``mode``
        attributes select the axis and the JAX primitive to use.
    **kwargs
        Extra dispatch arguments supplied by the JAX linker (unused here).

    Returns
    -------
    callable
        A function ``cumop(x)`` computing the cumulative sum
        (``mode == "add"``) or cumulative product (``mode == "mul"``) of
        ``x`` along ``op.axis``.

    Raises
    ------
    NotImplementedError
        If ``op.mode`` is neither ``"add"`` nor ``"mul"``.
    """
    axis = op.axis
    mode = op.mode

    # Choose the JAX primitive once, at conversion time, so an unsupported
    # mode fails while the graph is being translated instead of deep inside
    # JAX tracing. JAX provides `jnp.cumprod`, so "mul" maps directly.
    if mode == "add":
        jax_cumop = jnp.cumsum
    elif mode == "mul":
        jax_cumop = jnp.cumprod
    else:
        raise NotImplementedError(
            f"JAX conversion of CumOp with mode={mode!r} is not supported."
        )

    def cumop(x, axis=axis):
        # jnp.cumsum / jnp.cumprod follow NumPy semantics, including
        # flattening the input when axis is None.
        return jax_cumop(x, axis=axis)

    return cumop


def test_jax_CumOp():
    """Check that both `CumOp` modes agree between the JAX and Python backends."""
    a = pt.matrix("a")
    # Start at 1 so the cumulative products are non-zero and actually
    # discriminate between a correct and an incorrect implementation.
    a.tag.test_value = np.arange(1, 10, dtype=config.floatX).reshape((3, 3))

    # mode="add" (cumulative sum)
    out = pt.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # mode="mul" (cumulative product), implemented with `jnp.cumprod`
    out = pt.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])


if __name__ == "__main__":
    test_jax_CumOp()

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
