From 9cb5e8215389abc2a6b6a837aef8488741a6cc61 Mon Sep 17 00:00:00 2001 From: Jonas Date: Thu, 12 Dec 2024 17:25:07 +0100 Subject: [PATCH 01/30] first commit for a decorator that transforms JAX to pytensor --- pytensor/link/jax/ops.py | 424 +++++++++++++++++++++++++++++++ tests/link/jax/test_as_jax_op.py | 26 ++ 2 files changed, 450 insertions(+) create mode 100644 pytensor/link/jax/ops.py create mode 100644 tests/link/jax/test_as_jax_op.py diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py new file mode 100644 index 0000000000..130ece6eda --- /dev/null +++ b/pytensor/link/jax/ops.py @@ -0,0 +1,424 @@ +"""Convert a jax function to a pytensor compatible function.""" + +import functools as ft +import logging +from collections.abc import Sequence + +import equinox as eqx +import jax +import jax.numpy as jnp +import numpy as np +from jax.tree_util import tree_flatten, tree_map, tree_unflatten + +import pytensor.compile.builders +import pytensor.tensor as pt +from pytensor.gradient import DisconnectedType +from pytensor.graph import Apply, Op +from pytensor.link.jax.dispatch import jax_funcify + + +log = logging.getLogger(__name__) + + +def _filter_ptvars(x): + return isinstance(x, pt.Variable) + + +def as_jax_op(jaxfunc, name=None): + """Return a Pytensor from a JAX jittable function. + + This decorator transforms any JAX jittable function into a function that accepts + and returns `pytensor.Variables`. The jax jittable function can accept any + nested python structure (pytrees) as input, and return any nested Python structure. + + It requires to define the output types of the returned values as pytensor types. A + unique name should also be passed in case the name of the jaxfunc is identical to + some other node. The design of this function is based on + https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick-examples/ + + Parameters + ---------- + jaxfunc : jax jittable function + function for which the node is created, can return multiple tensors as a tuple. + It is required that all return values are able to transformed to + pytensor.Variable. + name: str + Name of the created pytensor Op, defaults to the name of the passed function. + Only used internally in the pytensor graph. + + Returns + ------- + A function which can be used in a pymc.Model as function, is differentiable + and the resulting model can be compiled either with the default C backend, or + the JAX backend. + + + Notes + ----- + The function is based on a blog post by Ricardo Vieira and Adrian Seyboldt, + available at + `pymc-labls.io `__. + To accept functions and non pytensor variables as input, the function make use + of :func:`equinox.partition` and :func:`equinox.combine` to split and combine the + variables. Shapes are inferred using + :func:`pytensor.compile.builders.infer_shape` and :func:`jax.eval_shape`. + """ + + def func(*args, **kwargs): + """Return a pytensor from a jax jittable function.""" + ### Split variables: in the ones that will be transformed to JAX inputs, + ### pytensor.Variables; _WrappedFunc, that are functions that have been returned + ### from a transformed function; and the rest, static variables that are not + ### transformed. + + pt_vars, static_vars_tmp = eqx.partition( + (args, kwargs), _filter_ptvars, is_leaf=callable + ) + # is_leaf=callable is used, as libraries like diffrax or equinox might return + # functions that are still seen as a nested pytree structure. We consider them + # as wrappable functions, that will be wrapped with _WrappedFunc. 
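+        # Illustrative sketch of the split (hypothetical call, not part of the
+        # implementation): for f(x, y, scale=2.0) with pytensor tensors x and y,
+        # eqx.partition roughly yields
+        #     pt_vars:         ((x, y), {"scale": None})
+        #     static_vars_tmp: ((None, None), {"scale": 2.0})
+        # i.e. graph variables and static Python values travel on separate paths.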
+ + func_vars, static_vars = eqx.partition( + static_vars_tmp, lambda x: isinstance(x, _WrappedFunc), is_leaf=callable + ) + vars_from_func = tree_map(lambda x: x.get_vars(), func_vars) + pt_vars = dict(vars=pt_vars, vars_from_func=vars_from_func) + """ + def func_unwrapped(vars_all, static_vars): + vars, vars_from_func = vars_all["vars"], vars_all["vars_from_func"] + func_vars_evaled = tree_map( + lambda x, y: x.get_func_with_vars(y), func_vars, vars_from_func + ) + args, kwargs = eqx.combine(vars, static_vars, func_vars_evaled) + return self.jaxfunc(*args, **kwargs) + """ + + pt_vars_flat, vars_treedef = tree_flatten(pt_vars) + pt_vars_types_flat = [var.type for var in pt_vars_flat] + shapes_vars_flat = pytensor.compile.builders.infer_shape(pt_vars_flat, (), ()) + shapes_vars = tree_unflatten(vars_treedef, shapes_vars_flat) + + dummy_inputs_jax = jax.tree_util.tree_map( + lambda var, shape: jnp.empty( + [int(dim.eval()) for dim in shape], dtype=var.type.dtype + ), + pt_vars, + shapes_vars, + ) + + # Combine the static variables with the inputs, and split them again in the + # output. Static variables don't take part in the graph, or might be a + # a function that is returned. + jaxfunc_partitioned, static_out_dic = _partition_jaxfunc( + jaxfunc, static_vars, func_vars + ) + + func_flattened = _flatten_func(jaxfunc_partitioned, vars_treedef) + + jaxtypes_outvars = jax.eval_shape( + ft.partial(jaxfunc_partitioned, vars=dummy_inputs_jax), + ) + + jaxtypes_outvars_flat, outvars_treedef = tree_flatten(jaxtypes_outvars) + + pttypes_outvars = [ + pt.TensorType(dtype=var.dtype, shape=var.shape) + for var in jaxtypes_outvars_flat + ] + + ### Call the function that accepts flat inputs, which in turn calls the one that + ### combines the inputs and static variables. + jitted_sol_op_jax = jax.jit(func_flattened) + len_gz = len(pttypes_outvars) + + vjp_sol_op_jax = _get_vjp_sol_op_jax(func_flattened, len_gz) + jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax) + + if name is None: + curr_name = jaxfunc.__name__ + else: + curr_name = name + + # Get classes that creates a Pytensor Op out of our function that accept + # flattened inputs. They are created each time, to set a custom name for the + # class. + SolOp, VJPSolOp = _return_pytensor_ops_classes(curr_name) + + local_op = SolOp( + vars_treedef, + outvars_treedef, + input_types=pt_vars_types_flat, + output_types=pttypes_outvars, + jitted_sol_op_jax=jitted_sol_op_jax, + jitted_vjp_sol_op_jax=jitted_vjp_sol_op_jax, + ) + + @jax_funcify.register(SolOp) + def sol_op_jax_funcify(op, **kwargs): + return local_op.perform_jax + + @jax_funcify.register(VJPSolOp) + def vjp_sol_op_jax_funcify(op, **kwargs): + return local_op.vjp_sol_op.perform_jax + + ### Evaluate the Pytensor Op and return unflattened results + output_flat = local_op(*pt_vars_flat) + if not isinstance(output_flat, Sequence): + output_flat = [output_flat] # tree_unflatten expects a sequence. 
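+        # Sketch of the round trip (output names are hypothetical): a jaxfunc
+        # returning {"z": z, "extra": [a, b]} was flattened to [z, a, b] before
+        # building the Op; tree_unflatten now rebuilds that nesting from the
+        # flat list of pytensor output variables.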
+ outvars = tree_unflatten(outvars_treedef, output_flat) + + static_outfuncs, static_outvars = eqx.partition( + static_out_dic["out"], callable, is_leaf=callable + ) + + static_outfuncs_flat, treedef_outfuncs = jax.tree_util.tree_flatten( + static_outfuncs, is_leaf=callable + ) + for i_func, _ in enumerate(static_outfuncs_flat): + static_outfuncs_flat[i_func] = _WrappedFunc( + jaxfunc, i_func, *args, **kwargs + ) + + static_outfuncs = jax.tree_util.tree_unflatten( + treedef_outfuncs, static_outfuncs_flat + ) + static_vars = eqx.combine(static_outfuncs, static_outvars, is_leaf=callable) + + output = eqx.combine(outvars, static_vars, is_leaf=callable) + + return output + + return func + + +class _WrappedFunc: + def __init__(self, exterior_func, i_func, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.i_func = i_func + vars, static_vars = eqx.partition( + (self.args, self.kwargs), _filter_ptvars, is_leaf=callable + ) + self.vars = vars + self.static_vars = static_vars + self.exterior_func = exterior_func + + def __call__(self, *args, **kwargs): + # If called, assume that args and kwargs are pytensors, so return the result + # as pytensors. + def f(func, *args, **kwargs): + res = func(*args, **kwargs) + return res + + return as_jax_op(f)(self, *args, **kwargs) + + def get_vars(self): + return self.vars + + def get_func_with_vars(self, vars): + # Use other variables than the saved ones, to generate the function. This + # is used to transform vars externally from pytensor to JAX, and use the + # then create the function which is returned. + + args, kwargs = eqx.combine(vars, self.static_vars, is_leaf=callable) + output = self.exterior_func(*args, **kwargs) + outfuncs, _ = eqx.partition(output, callable, is_leaf=callable) + outfuncs_flat, _ = jax.tree_util.tree_flatten(outfuncs, is_leaf=callable) + interior_func = outfuncs_flat[self.i_func] + return interior_func + + +def _get_vjp_sol_op_jax(jaxfunc, len_gz): + def vjp_sol_op_jax(args): + y0 = args[:-len_gz] + gz = args[-len_gz:] + if len(gz) == 1: + gz = gz[0] + + def func(*inputs): + return jaxfunc(inputs) + + primals, vjp_fn = jax.vjp(func, *y0) + gz = tree_map( + lambda g, primal: jnp.broadcast_to(g, jnp.shape(primal)), + gz, + primals, + ) + if len(y0) == 1: + return vjp_fn(gz)[0] + else: + return tuple(vjp_fn(gz)) + + return vjp_sol_op_jax + + +def _partition_jaxfunc(jaxfunc, static_vars, func_vars): + """Partition the jax function into static and non-static variables. + + Returns a function that accepts only non-static variables and returns the non-static + variables. The returned static variables are stored in a dictionary and returned, + to allow the referencing after creating the function + + Additionally wrapped functions saved in func_vars are regenerated with + vars["vars_from_func"] as input, to allow the transformation of the variables. + """ + static_out_dic = {"out": None} + + def jaxfunc_partitioned(vars): + vars, vars_from_func = vars["vars"], vars["vars_from_func"] + func_vars_evaled = tree_map( + lambda x, y: x.get_func_with_vars(y), func_vars, vars_from_func + ) + args, kwargs = eqx.combine( + vars, static_vars, func_vars_evaled, is_leaf=callable + ) + + out = jaxfunc(*args, **kwargs) + outvars, static_out = eqx.partition(out, eqx.is_array, is_leaf=callable) + static_out_dic["out"] = static_out + return outvars + + return jaxfunc_partitioned, static_out_dic + + +### Construct the function that accepts flat inputs and returns flat outputs. 
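+# Illustrative example of the flattening contract (structure is hypothetical):
+# with vars_treedef describing {"vars": [x, y], "vars_from_func": {}}, the
+# wrapper below accepts the flat list [x, y], rebuilds the dictionary with
+# tree_unflatten, calls jaxfunc on it and flattens the result again, so JAX
+# only ever sees flat argument lists.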
+def _flatten_func(jaxfunc, vars_treedef): + def func_flattened(vars_flat): + vars = tree_unflatten(vars_treedef, vars_flat) + outvars = jaxfunc(vars) + outvars_flat, _ = tree_flatten(outvars) + return _normalize_flat_output(outvars_flat) + + return func_flattened + + +def _normalize_flat_output(output): + if len(output) > 1: + return tuple( + output + ) # Transform to tuple because jax makes a difference between + # tuple and list and not pytensor + else: + return output[0] + + +def _return_pytensor_ops_classes(name): + class SolOp(Op): + def __init__( + self, + input_treedef, + output_treeedef, + input_types, + output_types, + jitted_sol_op_jax, + jitted_vjp_sol_op_jax, + ): + self.vjp_sol_op = None + self.input_treedef = input_treedef + self.output_treedef = output_treeedef + self.input_types = input_types + self.output_types = output_types + self.jitted_sol_op_jax = jitted_sol_op_jax + self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax + + def make_node(self, *inputs): + self.num_inputs = len(inputs) + + # Define our output variables + outputs = [pt.as_tensor_variable(type()) for type in self.output_types] + self.num_outputs = len(outputs) + + self.vjp_sol_op = VJPSolOp( + self.input_treedef, + self.input_types, + self.jitted_vjp_sol_op_jax, + ) + + return Apply(self, inputs, outputs) + + def perform(self, node, inputs, outputs): + results = self.jitted_sol_op_jax(inputs) + if self.num_outputs > 1: + for i in range(self.num_outputs): + outputs[i][0] = np.array(results[i], self.output_types[i].dtype) + else: + outputs[0][0] = np.array(results, self.output_types[0].dtype) + + def perform_jax(self, *inputs): + results = self.jitted_sol_op_jax(inputs) + return results + + def grad(self, inputs, output_gradients): + # If a output is not used, it is disconnected and doesn't have a gradient. + # Set gradient here to zero for those outputs. 
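+            # (A disconnected output still needs a numeric placeholder so that
+            # the VJP gets one cotangent per output; a zero tensor of the known
+            # static shape, or a scalar zero that is broadcast to the output
+            # shape inside the VJP, leaves the product unchanged.)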
+ for i in range(self.num_outputs): + if isinstance(output_gradients[i].type, DisconnectedType): + if None not in self.output_types[i].shape: + output_gradients[i] = pt.zeros( + self.output_types[i].shape, self.output_types[i].dtype + ) + else: + output_gradients[i] = pt.zeros((), self.output_types[i].dtype) + result = self.vjp_sol_op(inputs, output_gradients) + + if self.num_inputs > 1: + return result + else: + return (result,) # Pytensor requires a tuple here + + # vector-jacobian product Op + class VJPSolOp(Op): + def __init__( + self, + input_treedef, + input_types, + jitted_vjp_sol_op_jax, + ): + self.input_treedef = input_treedef + self.input_types = input_types + self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax + + def make_node(self, y0, gz): + y0 = [ + pt.as_tensor_variable( + _y, + ).astype(self.input_types[i].dtype) + for i, _y in enumerate(y0) + ] + gz_not_disconntected = [ + pt.as_tensor_variable(_gz) + for _gz in gz + if not isinstance(_gz.type, DisconnectedType) + ] + outputs = [in_type() for in_type in self.input_types] + self.num_outputs = len(outputs) + return Apply(self, y0 + gz_not_disconntected, outputs) + + def perform(self, node, inputs, outputs): + results = self.jitted_vjp_sol_op_jax(tuple(inputs)) + if len(self.input_types) > 1: + for i, result in enumerate(results): + outputs[i][0] = np.array(result, self.input_types[i].dtype) + else: + outputs[0][0] = np.array(results, self.input_types[0].dtype) + + def perform_jax(self, *inputs): + results = self.jitted_vjp_sol_op_jax(tuple(inputs)) + if self.num_outputs == 1: + if isinstance(results, Sequence): + return results[0] + else: + return results + else: + return tuple(results) + + SolOp.__name__ = name + SolOp.__qualname__ = ".".join(SolOp.__qualname__.split(".")[:-1] + [name]) + + VJPSolOp.__name__ = "VJP_" + name + VJPSolOp.__qualname__ = ".".join( + VJPSolOp.__qualname__.split(".")[:-1] + ["VJP_" + name] + ) + + return SolOp, VJPSolOp diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py new file mode 100644 index 0000000000..6feb1124a2 --- /dev/null +++ b/tests/link/jax/test_as_jax_op.py @@ -0,0 +1,26 @@ +import jax +import numpy as np + +from pytensor import config +from pytensor.graph.fg import FunctionGraph +from pytensor.link.jax.ops import as_jax_op +from pytensor.tensor import tensor +from tests.link.jax.test_basic import compare_jax_and_py + +def test_as_jax_op1(): + # 2 parameters input, single output + rng = np.random.default_rng(14) + x = tensor("a", shape=(2,)) + y = tensor("b", shape=(2,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + return jax.nn.sigmoid(x + y) + + out = f(x, y) + + fg = FunctionGraph([x, y], [out]) + fn, _ = compare_jax_and_py(fg, test_values) From d3a277ebbe6772d8aa2bfe85eb7741ca8463ed2a Mon Sep 17 00:00:00 2001 From: Jonas Date: Thu, 12 Dec 2024 17:47:44 +0100 Subject: [PATCH 02/30] Add more tests --- pytensor/link/jax/ops.py | 7 +- tests/link/jax/test_as_jax_op.py | 368 ++++++++++++++++++++++++++++++- 2 files changed, 371 insertions(+), 4 deletions(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 130ece6eda..1b2325293d 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -241,7 +241,11 @@ def func(*inputs): primals, vjp_fn = jax.vjp(func, *y0) gz = tree_map( - lambda g, primal: jnp.broadcast_to(g, jnp.shape(primal)), + lambda g, primal: jnp.broadcast_to(g, jnp.shape(primal)).astype( + primal.dtype + ), # Also cast to the 
dtype of the primal, this shouldn't be + # necessary, but it happens that the returned dtype of the gradient isn't + # the same anymore. gz, primals, ) @@ -326,6 +330,7 @@ def make_node(self, *inputs): self.num_inputs = len(inputs) # Define our output variables + print(self.output_types) outputs = [pt.as_tensor_variable(type()) for type in self.output_types] self.num_outputs = len(outputs) diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py index 6feb1124a2..8d404d76db 100644 --- a/tests/link/jax/test_as_jax_op.py +++ b/tests/link/jax/test_as_jax_op.py @@ -1,15 +1,20 @@ import jax +import jax.numpy as jnp import numpy as np +import pytest -from pytensor import config +import pytensor.tensor as pt +from pytensor import config, grad from pytensor.graph.fg import FunctionGraph from pytensor.link.jax.ops import as_jax_op +from pytensor.scalar import all_types from pytensor.tensor import tensor from tests.link.jax.test_basic import compare_jax_and_py + def test_as_jax_op1(): # 2 parameters input, single output - rng = np.random.default_rng(14) + rng = np.random.default_rng(1) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) test_values = [ @@ -21,6 +26,363 @@ def f(x, y): return jax.nn.sigmoid(x + y) out = f(x, y) + grad_out = grad(pt.sum(out), [x, y]) + + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op2(): + # 2 parameters input, tuple output + rng = np.random.default_rng(2) + x = tensor("a", shape=(2,)) + y = tensor("b", shape=(2,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + return jax.nn.sigmoid(x + y), y * 2 + + out, _ = f(x, y) + grad_out = grad(pt.sum(out), [x, y]) + + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op3(): + # 2 parameters input, list output + rng = np.random.default_rng(3) + x = tensor("a", shape=(2,)) + y = tensor("b", shape=(2,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + return [jax.nn.sigmoid(x + y), y * 2] + + out, _ = f(x, y) + grad_out = grad(pt.sum(out), [x, y]) + + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op4(): + # single 1d input, tuple output + rng = np.random.default_rng(4) + x = tensor("a", shape=(2,)) + test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] + + @as_jax_op + def f(x): + return jax.nn.sigmoid(x), x * 2 + + out, _ = f(x) + grad_out = grad(pt.sum(out), [x]) + + fg = FunctionGraph([x], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op5(): + # single 0d input, tuple output + rng = np.random.default_rng(5) + x = tensor("a", shape=()) + test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] + + @as_jax_op + def f(x): + return jax.nn.sigmoid(x), x + + out, _ = f(x) + grad_out = grad(pt.sum(out), [x]) + + fg = FunctionGraph([x], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op6(): + # 
single input, list output + rng = np.random.default_rng(6) + x = tensor("a", shape=(2,)) + test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] + + @as_jax_op + def f(x): + return [jax.nn.sigmoid(x), 2 * x] + + out, _ = f(x) + grad_out = grad(pt.sum(out), [x]) + + fg = FunctionGraph([x], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op7(): + # 2 parameters input with pytree, tuple output + rng = np.random.default_rng(7) + x = tensor("a", shape=(2,)) + y = tensor("b", shape=(2,)) + y_tmp = {"y": y, "y2": [y**2]} + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + return jax.nn.sigmoid(x), 2 * x + y["y"] + y["y2"][0] + + out = f(x, y_tmp) + grad_out = grad(pt.sum(out[1]), [x, y]) + + fg = FunctionGraph([x, y], [out[0], out[1], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op8(): + # 2 parameters input with pytree, pytree output + rng = np.random.default_rng(8) + x = tensor("a", shape=(3,)) + y = tensor("b", shape=(1,)) + y_tmp = {"a": y, "b": [y**2]} + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + return x, jax.tree_util.tree_map(lambda x: jnp.exp(x), y) + + out = f(x, y_tmp) + grad_out = grad(pt.sum(out[1]["b"][0]), [x, y]) + + fg = FunctionGraph([x, y], [out[0], out[1]["a"], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op9(): + # 2 parameters input with pytree, pytree output and non-graph argument + rng = np.random.default_rng(9) + x = tensor("a", shape=(3,)) + y = tensor("b", shape=(1,)) + y_tmp = {"a": y, "b": [y**2]} + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y, non_model_arg): + return jnp.exp(x), jax.tree_util.tree_map(jax.nn.sigmoid, y) + + out = f(x, y_tmp, "Hello World!") + grad_out = grad(pt.sum(out[0]), [x]) + + fg = FunctionGraph([x, y], [out[0], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op10(): + # Use "None" in shape specification and have a non-used output of higher rank + rng = np.random.default_rng(10) + x = tensor("a", shape=(3,)) + y = tensor("b", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + return x[:, None] @ y[None], jnp.exp(x) + + out = f(x, y) + grad_out = grad(pt.sum(out[1]), [x]) + + fg = FunctionGraph([x, y], [out[1], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op11(): + # Test unknown static shape + rng = np.random.default_rng(11) + x = tensor("a", shape=(3,)) + y = tensor("b", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + x = pt.cumsum(x) # Now x has an unknown shape + + @as_jax_op + def f(x, y): + return x * jnp.ones(3) + + out = f(x, y) + grad_out = grad(pt.sum(out), [x]) + + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + - fg = FunctionGraph([x, y], [out]) +def 
test_as_jax_op12(): + # Test non-array return values + rng = np.random.default_rng(12) + x = tensor("a", shape=(3,)) + y = tensor("b", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y, message): + return x * jnp.ones(3), "Success: " + message + + out = f(x, y, "Hi") + grad_out = grad(pt.sum(out[0]), [x]) + + fg = FunctionGraph([x, y], [out[0], *grad_out]) fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_as_jax_op13(): + # Test nested functions + rng = np.random.default_rng(13) + x = tensor("a", shape=(3,)) + y = tensor("b", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f_internal(y): + def f_ret(t): + return y + t + + def f_ret2(t): + return f_ret(t) + t**2 + + return f_ret, y**2 * jnp.ones(1), f_ret2 + + f, y_pow, f2 = f_internal(y) + + @as_jax_op + def f_outer(x, dict_other): + f, y_pow = dict_other["func"], dict_other["y"] + return x * jnp.ones(3), f(x) * y_pow + + out = f_outer(x, {"func": f, "y": y_pow}) + grad_out = grad(pt.sum(out[1]), [x]) + + fg = FunctionGraph([x, y], [out[1], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +class TestDtypes: + @pytest.mark.parametrize("in_dtype", list(map(str, all_types))) + @pytest.mark.parametrize("out_dtype", list(map(str, all_types))) + def test_different_in_output(self, in_dtype, out_dtype): + x = tensor("a", shape=(3,), dtype=in_dtype) + y = tensor("b", shape=(3,), dtype=in_dtype) + + if "int" in in_dtype: + test_values = [ + np.random.randint(0, 10, size=(inp.type.shape)).astype(inp.type.dtype) + for inp in (x, y) + ] + else: + test_values = [ + np.random.normal(size=(inp.type.shape)).astype(inp.type.dtype) + for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + out = jnp.add(x, y) + return jnp.real(out).astype(out_dtype) + + out = f(x, y) + assert out.dtype == out_dtype + + if "float" in in_dtype and "float" in out_dtype: + grad_out = grad(out[0], [x, y]) + assert grad_out[0].dtype == in_dtype + fg = FunctionGraph([x, y], [out, *grad_out]) + else: + fg = FunctionGraph([x, y], [out]) + + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + @pytest.mark.parametrize("in1_dtype", list(map(str, all_types))) + @pytest.mark.parametrize("in2_dtype", list(map(str, all_types))) + def test_test_different_inputs(self, in1_dtype, in2_dtype): + x = tensor("a", shape=(3,), dtype=in1_dtype) + y = tensor("b", shape=(3,), dtype=in2_dtype) + + if "int" in in1_dtype: + test_values = [np.random.randint(0, 10, size=(3,)).astype(x.type.dtype)] + else: + test_values = [np.random.normal(size=(3,)).astype(x.type.dtype)] + if "int" in in2_dtype: + test_values.append(np.random.randint(0, 10, size=(3,)).astype(y.type.dtype)) + else: + test_values.append(np.random.normal(size=(3,)).astype(y.type.dtype)) + + @as_jax_op + def f(x, y): + out = jnp.add(x, y) + return jnp.real(out).astype(in1_dtype) + + out = f(x, y) + assert out.dtype == in1_dtype + + if "float" in in1_dtype and "float" in in2_dtype: + # In principle, the gradient should also be defined if the second input is + # an integer, but it doesn't work for some reason. 
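+            # (Presumably because JAX only defines VJPs for inexact
+            # (float/complex) dtypes, so an integer operand cannot receive a
+            # cotangent; this explanation is an assumption, not verified here.)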
+ grad_out = grad(out[0], [x]) + assert grad_out[0].dtype == in1_dtype + fg = FunctionGraph([x, y], [out, *grad_out]) + else: + fg = FunctionGraph([x, y], [out]) + + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) From 59141620751171547c6220b02ef658172d490545 Mon Sep 17 00:00:00 2001 From: Jonas Date: Sun, 15 Dec 2024 13:18:10 +0100 Subject: [PATCH 03/30] Define JAXOp outside of the decorator --- pytensor/link/jax/ops.py | 263 +++++++++++++++++++-------------------- 1 file changed, 130 insertions(+), 133 deletions(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 1b2325293d..60a3581550 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -84,16 +84,8 @@ def func(*args, **kwargs): ) vars_from_func = tree_map(lambda x: x.get_vars(), func_vars) pt_vars = dict(vars=pt_vars, vars_from_func=vars_from_func) - """ - def func_unwrapped(vars_all, static_vars): - vars, vars_from_func = vars_all["vars"], vars_all["vars_from_func"] - func_vars_evaled = tree_map( - lambda x, y: x.get_func_with_vars(y), func_vars, vars_from_func - ) - args, kwargs = eqx.combine(vars, static_vars, func_vars_evaled) - return self.jaxfunc(*args, **kwargs) - """ + # Infer shapes and types of the variables pt_vars_flat, vars_treedef = tree_flatten(pt_vars) pt_vars_types_flat = [var.type for var in pt_vars_flat] shapes_vars_flat = pytensor.compile.builders.infer_shape(pt_vars_flat, (), ()) @@ -135,17 +127,30 @@ def func_unwrapped(vars_all, static_vars): vjp_sol_op_jax = _get_vjp_sol_op_jax(func_flattened, len_gz) jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax) + # Get classes that creates a Pytensor Op out of our function that accept + # flattened inputs. They are created each time, to set a custom name for the + # class. + class JAXOp_local(JAXOp): + pass + + class VJPJAXOp_local(VJPJAXOp): + pass + if name is None: curr_name = jaxfunc.__name__ else: curr_name = name + JAXOp_local.__name__ = curr_name + JAXOp_local.__qualname__ = ".".join( + JAXOp_local.__qualname__.split(".")[:-1] + [curr_name] + ) - # Get classes that creates a Pytensor Op out of our function that accept - # flattened inputs. They are created each time, to set a custom name for the - # class. 
- SolOp, VJPSolOp = _return_pytensor_ops_classes(curr_name) + VJPJAXOp_local.__name__ = "VJP_" + curr_name + VJPJAXOp_local.__qualname__ = ".".join( + VJPJAXOp_local.__qualname__.split(".")[:-1] + ["VJP_" + curr_name] + ) - local_op = SolOp( + local_op = JAXOp_local( vars_treedef, outvars_treedef, input_types=pt_vars_types_flat, @@ -154,14 +159,6 @@ def func_unwrapped(vars_all, static_vars): jitted_vjp_sol_op_jax=jitted_vjp_sol_op_jax, ) - @jax_funcify.register(SolOp) - def sol_op_jax_funcify(op, **kwargs): - return local_op.perform_jax - - @jax_funcify.register(VJPSolOp) - def vjp_sol_op_jax_funcify(op, **kwargs): - return local_op.vjp_sol_op.perform_jax - ### Evaluate the Pytensor Op and return unflattened results output_flat = local_op(*pt_vars_flat) if not isinstance(output_flat, Sequence): @@ -307,123 +304,123 @@ def _normalize_flat_output(output): return output[0] -def _return_pytensor_ops_classes(name): - class SolOp(Op): - def __init__( - self, - input_treedef, - output_treeedef, - input_types, - output_types, - jitted_sol_op_jax, - jitted_vjp_sol_op_jax, - ): - self.vjp_sol_op = None - self.input_treedef = input_treedef - self.output_treedef = output_treeedef - self.input_types = input_types - self.output_types = output_types - self.jitted_sol_op_jax = jitted_sol_op_jax - self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax - - def make_node(self, *inputs): - self.num_inputs = len(inputs) - - # Define our output variables - print(self.output_types) - outputs = [pt.as_tensor_variable(type()) for type in self.output_types] - self.num_outputs = len(outputs) - - self.vjp_sol_op = VJPSolOp( - self.input_treedef, - self.input_types, - self.jitted_vjp_sol_op_jax, - ) +class JAXOp(Op): + def __init__( + self, + input_treedef, + output_treeedef, + input_types, + output_types, + jitted_sol_op_jax, + jitted_vjp_sol_op_jax, + ): + self.vjp_sol_op = None + self.input_treedef = input_treedef + self.output_treedef = output_treeedef + self.input_types = input_types + self.output_types = output_types + self.jitted_sol_op_jax = jitted_sol_op_jax + self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax + + def make_node(self, *inputs): + self.num_inputs = len(inputs) + + # Define our output variables + print(self.output_types) + outputs = [pt.as_tensor_variable(type()) for type in self.output_types] + self.num_outputs = len(outputs) + + self.vjp_sol_op = VJPJAXOp( + self.input_treedef, + self.input_types, + self.jitted_vjp_sol_op_jax, + ) - return Apply(self, inputs, outputs) + return Apply(self, inputs, outputs) - def perform(self, node, inputs, outputs): - results = self.jitted_sol_op_jax(inputs) - if self.num_outputs > 1: - for i in range(self.num_outputs): - outputs[i][0] = np.array(results[i], self.output_types[i].dtype) - else: - outputs[0][0] = np.array(results, self.output_types[0].dtype) + def perform(self, node, inputs, outputs): + results = self.jitted_sol_op_jax(inputs) + if self.num_outputs > 1: + for i in range(self.num_outputs): + outputs[i][0] = np.array(results[i], self.output_types[i].dtype) + else: + outputs[0][0] = np.array(results, self.output_types[0].dtype) + + def perform_jax(self, *inputs): + results = self.jitted_sol_op_jax(inputs) + return results + + def grad(self, inputs, output_gradients): + # If a output is not used, it is disconnected and doesn't have a gradient. + # Set gradient here to zero for those outputs. 
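+        # The reverse-mode computation itself is delegated to self.vjp_sol_op
+        # below, which receives the primal inputs plus one cotangent per output
+        # and returns one gradient per input, as pytensor's grad contract
+        # expects.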
+ for i in range(self.num_outputs): + if isinstance(output_gradients[i].type, DisconnectedType): + if None not in self.output_types[i].shape: + output_gradients[i] = pt.zeros( + self.output_types[i].shape, self.output_types[i].dtype + ) + else: + output_gradients[i] = pt.zeros((), self.output_types[i].dtype) + result = self.vjp_sol_op(inputs, output_gradients) - def perform_jax(self, *inputs): - results = self.jitted_sol_op_jax(inputs) - return results + if self.num_inputs > 1: + return result + else: + return (result,) # Pytensor requires a tuple here + + +# vector-jacobian product Op +class VJPJAXOp(Op): + def __init__( + self, + input_treedef, + input_types, + jitted_vjp_sol_op_jax, + ): + self.input_treedef = input_treedef + self.input_types = input_types + self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax + + def make_node(self, y0, gz): + y0 = [ + pt.as_tensor_variable( + _y, + ).astype(self.input_types[i].dtype) + for i, _y in enumerate(y0) + ] + gz_not_disconntected = [ + pt.as_tensor_variable(_gz) + for _gz in gz + if not isinstance(_gz.type, DisconnectedType) + ] + outputs = [in_type() for in_type in self.input_types] + self.num_outputs = len(outputs) + return Apply(self, y0 + gz_not_disconntected, outputs) + + def perform(self, node, inputs, outputs): + results = self.jitted_vjp_sol_op_jax(tuple(inputs)) + if len(self.input_types) > 1: + for i, result in enumerate(results): + outputs[i][0] = np.array(result, self.input_types[i].dtype) + else: + outputs[0][0] = np.array(results, self.input_types[0].dtype) - def grad(self, inputs, output_gradients): - # If a output is not used, it is disconnected and doesn't have a gradient. - # Set gradient here to zero for those outputs. - for i in range(self.num_outputs): - if isinstance(output_gradients[i].type, DisconnectedType): - if None not in self.output_types[i].shape: - output_gradients[i] = pt.zeros( - self.output_types[i].shape, self.output_types[i].dtype - ) - else: - output_gradients[i] = pt.zeros((), self.output_types[i].dtype) - result = self.vjp_sol_op(inputs, output_gradients) - - if self.num_inputs > 1: - return result - else: - return (result,) # Pytensor requires a tuple here - - # vector-jacobian product Op - class VJPSolOp(Op): - def __init__( - self, - input_treedef, - input_types, - jitted_vjp_sol_op_jax, - ): - self.input_treedef = input_treedef - self.input_types = input_types - self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax - - def make_node(self, y0, gz): - y0 = [ - pt.as_tensor_variable( - _y, - ).astype(self.input_types[i].dtype) - for i, _y in enumerate(y0) - ] - gz_not_disconntected = [ - pt.as_tensor_variable(_gz) - for _gz in gz - if not isinstance(_gz.type, DisconnectedType) - ] - outputs = [in_type() for in_type in self.input_types] - self.num_outputs = len(outputs) - return Apply(self, y0 + gz_not_disconntected, outputs) - - def perform(self, node, inputs, outputs): - results = self.jitted_vjp_sol_op_jax(tuple(inputs)) - if len(self.input_types) > 1: - for i, result in enumerate(results): - outputs[i][0] = np.array(result, self.input_types[i].dtype) + def perform_jax(self, *inputs): + results = self.jitted_vjp_sol_op_jax(tuple(inputs)) + if self.num_outputs == 1: + if isinstance(results, Sequence): + return results[0] else: - outputs[0][0] = np.array(results, self.input_types[0].dtype) + return results + else: + return tuple(results) - def perform_jax(self, *inputs): - results = self.jitted_vjp_sol_op_jax(tuple(inputs)) - if self.num_outputs == 1: - if isinstance(results, Sequence): - return results[0] - 
else: - return results - else: - return tuple(results) - SolOp.__name__ = name - SolOp.__qualname__ = ".".join(SolOp.__qualname__.split(".")[:-1] + [name]) +@jax_funcify.register(JAXOp) +def sol_op_jax_funcify(op, **kwargs): + return op.perform_jax - VJPSolOp.__name__ = "VJP_" + name - VJPSolOp.__qualname__ = ".".join( - VJPSolOp.__qualname__.split(".")[:-1] + ["VJP_" + name] - ) - return SolOp, VJPSolOp +@jax_funcify.register(VJPJAXOp) +def vjp_sol_op_jax_funcify(op, **kwargs): + return op.perform_jax From b810e99d05067d6337af2745809b49748efacc28 Mon Sep 17 00:00:00 2001 From: Jonas Date: Sun, 15 Dec 2024 18:00:41 +0100 Subject: [PATCH 04/30] Added comment regarding flattening of inputs --- pytensor/link/jax/ops.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 60a3581550..fba6022232 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -85,8 +85,12 @@ def func(*args, **kwargs): vars_from_func = tree_map(lambda x: x.get_vars(), func_vars) pt_vars = dict(vars=pt_vars, vars_from_func=vars_from_func) - # Infer shapes and types of the variables + # Flatten nested python structures, e.g. {"a": tensor_a, "b": [tensor_b]} + # becomes [tensor_a, tensor_b], because pytensor ops only accepts lists of + # pytensor.Variables as input. pt_vars_flat, vars_treedef = tree_flatten(pt_vars) + + # Infer shapes and types of the variables pt_vars_types_flat = [var.type for var in pt_vars_flat] shapes_vars_flat = pytensor.compile.builders.infer_shape(pt_vars_flat, (), ()) shapes_vars = tree_unflatten(vars_treedef, shapes_vars_flat) From df76e7391bbeef96220639306178bcac2f6c5a63 Mon Sep 17 00:00:00 2001 From: Jonas Date: Sun, 15 Dec 2024 20:16:48 +0100 Subject: [PATCH 05/30] Add as_jax_op to pytensor.__init__.py and to documentation --- doc/conf.py | 1 + doc/library/index.rst | 7 +++++++ pytensor/__init__.py | 12 ++++++++++++ pytensor/link/jax/ops.py | 31 ++++++++++++++----------------- tests/link/jax/test_as_jax_op.py | 3 +-- 5 files changed, 35 insertions(+), 19 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index e10dcffb90..48d81730ba 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -38,6 +38,7 @@ "jax": ("https://jax.readthedocs.io/en/latest", None), "numpy": ("https://numpy.org/doc/stable", None), "torch": ("https://pytorch.org/docs/stable", None), + "equinox": ("https://docs.kidger.site/equinox/", None), } needs_sphinx = "3" diff --git a/doc/library/index.rst b/doc/library/index.rst index e9b362f8db..70506f6120 100644 --- a/doc/library/index.rst +++ b/doc/library/index.rst @@ -61,6 +61,13 @@ Convert to Variable .. autofunction:: pytensor.as_symbolic(...) +Wrap JAX functions +================== + +.. autofunction:: as_jax_op(...) + + Alias for :func:`pytensor.link.jax.ops.as_jax_op` + Debug ===== diff --git a/pytensor/__init__.py b/pytensor/__init__.py index 3c925ac2f2..a7f9aa8058 100644 --- a/pytensor/__init__.py +++ b/pytensor/__init__.py @@ -167,6 +167,18 @@ def get_underlying_scalar_constant(v): from pytensor.scan.views import foldl, foldr, map, reduce from pytensor.compile.builders import OpFromGraph +try: + import pytensor.link.jax.ops + from pytensor.link.jax.ops import as_jax_op +except ImportError as e: + import_error_as_jax_op = e + + def as_jax_op(*args, **kwargs): + raise ImportError( + "JAX and/or equinox are not installed. 
Install them" + " to use this function: pip install pytensor[jax]" + ) from import_error_as_jax_op + # isort: on diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index fba6022232..6cb23470db 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -25,32 +25,29 @@ def _filter_ptvars(x): def as_jax_op(jaxfunc, name=None): - """Return a Pytensor from a JAX jittable function. + """Return a Pytensor function from a JAX jittable function. - This decorator transforms any JAX jittable function into a function that accepts - and returns `pytensor.Variables`. The jax jittable function can accept any - nested python structure (pytrees) as input, and return any nested Python structure. - - It requires to define the output types of the returned values as pytensor types. A - unique name should also be passed in case the name of the jaxfunc is identical to - some other node. The design of this function is based on - https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick-examples/ + This decorator transforms any JAX-jittable function into a function that accepts + and returns `pytensor.Variable`. The JAX-jittable function can accept any + nested python structure (a `Pytree + `_) as input, and might return + any nested Python structure. Parameters ---------- - jaxfunc : jax jittable function - function for which the node is created, can return multiple tensors as a tuple. - It is required that all return values are able to transformed to - pytensor.Variable. - name: str + jaxfunc : JAX-jittable function + JAX function which will be wrapped in a Pytensor Op. + name: str, optional Name of the created pytensor Op, defaults to the name of the passed function. Only used internally in the pytensor graph. Returns ------- - A function which can be used in a pymc.Model as function, is differentiable - and the resulting model can be compiled either with the default C backend, or - the JAX backend. + Callable : + A function which expects a nested python structure of `pytensor.Variable` and + static variables as inputs and returns `pytensor.Variable` with the same + API as the original jaxfunc. The resulting model can be compiled either with the + default C backend or the JAX backend. Notes diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py index 8d404d76db..3842278a04 100644 --- a/tests/link/jax/test_as_jax_op.py +++ b/tests/link/jax/test_as_jax_op.py @@ -4,9 +4,8 @@ import pytest import pytensor.tensor as pt -from pytensor import config, grad +from pytensor import as_jax_op, config, grad from pytensor.graph.fg import FunctionGraph -from pytensor.link.jax.ops import as_jax_op from pytensor.scalar import all_types from pytensor.tensor import tensor from tests.link.jax.test_basic import compare_jax_and_py From c2338fb016b79485bfdaeb0f2d73a4d35b1544c5 Mon Sep 17 00:00:00 2001 From: Jonas Date: Sun, 15 Dec 2024 20:27:16 +0100 Subject: [PATCH 06/30] Add [jax] requirement to readthedocs in order to read the docstring of as_jax_op --- doc/environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/environment.yml b/doc/environment.yml index 7b564e8fb0..5b1f8790dc 100644 --- a/doc/environment.yml +++ b/doc/environment.yml @@ -25,4 +25,4 @@ dependencies: - ablog - pip - pip: - - -e .. 
+ - -e ..[jax] From 89474ae967d864dd30f9b174beecd33b6d1e0e65 Mon Sep 17 00:00:00 2001 From: Jonas Date: Mon, 16 Dec 2024 12:24:44 +0100 Subject: [PATCH 07/30] Added an example to the docstring of as_jax_op --- pytensor/link/jax/ops.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 6cb23470db..11833a7884 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -49,6 +49,33 @@ def as_jax_op(jaxfunc, name=None): API as the original jaxfunc. The resulting model can be compiled either with the default C backend or the JAX backend. + Examples + -------- + + We define a JAX function `f_jax` that accepts a matrix `x`, a vector `y` and a + dictionary as input. This is transformed to a pytensor function with the decorator + `as_jax_op`, and can subsequently be used like normal pytensor operators, i.e. + for evaluation and calculating gradients. + + >>> import numpy + >>> import jax.numpy as jnp + >>> import pytensor + >>> import pytensor.tensor as pt + >>> x = pt.tensor("x", shape=(2,)) + >>> y = pt.tensor("y", shape=(2, 2)) + >>> a = pt.tensor("a", shape=()) + >>> args_dict = {"a": a} + >>> @pytensor.as_jax_op + ... def f_jax(x, y, args_dict): + ... z = jnp.dot(x, y) + args_dict["a"] + ... return z + >>> z = f_jax(x, y, args_dict) + >>> z_sum = pt.sum(z) + >>> grad_wrt_a = pt.grad(z_sum, a) + >>> f_all = pytensor.function([x, y, a], [z_sum, grad_wrt_a]) + >>> f_all(numpy.array([1, 2]), numpy.array([[1, 2], [3, 4]]), 1) + [array(19.), array(2.)] + Notes ----- @@ -327,7 +354,6 @@ def make_node(self, *inputs): self.num_inputs = len(inputs) # Define our output variables - print(self.output_types) outputs = [pt.as_tensor_variable(type()) for type in self.output_types] self.num_outputs = len(outputs) From 4e2e005a0b865afb66310137e3117e2829c69b22 Mon Sep 17 00:00:00 2001 From: Jonas Date: Mon, 3 Feb 2025 17:44:35 +0100 Subject: [PATCH 08/30] Use infer_static_shape, currently still with the possibility to use the previous approach for testing purposes --- pytensor/link/jax/ops.py | 41 +++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 11833a7884..8b20370330 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -24,7 +24,7 @@ def _filter_ptvars(x): return isinstance(x, pt.Variable) -def as_jax_op(jaxfunc, name=None): +def as_jax_op(jaxfunc, use_infer_static_shape=True, name=None): """Return a Pytensor function from a JAX jittable function. This decorator transforms any JAX-jittable function into a function that accepts @@ -57,10 +57,10 @@ def as_jax_op(jaxfunc, name=None): `as_jax_op`, and can subsequently be used like normal pytensor operators, i.e. for evaluation and calculating gradients. 
- >>> import numpy - >>> import jax.numpy as jnp - >>> import pytensor - >>> import pytensor.tensor as pt + >>> import numpy # doctest: +ELLIPSIS + >>> import jax.numpy as jnp # doctest: +ELLIPSIS + >>> import pytensor # doctest: +ELLIPSIS + >>> import pytensor.tensor as pt # doctest: +ELLIPSIS >>> x = pt.tensor("x", shape=(2,)) >>> y = pt.tensor("y", shape=(2, 2)) >>> a = pt.tensor("a", shape=()) @@ -116,16 +116,27 @@ def func(*args, **kwargs): # Infer shapes and types of the variables pt_vars_types_flat = [var.type for var in pt_vars_flat] - shapes_vars_flat = pytensor.compile.builders.infer_shape(pt_vars_flat, (), ()) - shapes_vars = tree_unflatten(vars_treedef, shapes_vars_flat) - - dummy_inputs_jax = jax.tree_util.tree_map( - lambda var, shape: jnp.empty( - [int(dim.eval()) for dim in shape], dtype=var.type.dtype - ), - pt_vars, - shapes_vars, - ) + + if use_infer_static_shape: + shapes_vars_flat = [ + pt.basic.infer_static_shape(var.shape)[1] for var in pt_vars_flat + ] + + dummy_inputs_jax_flat = [ + jnp.empty(shape, dtype=var.type.dtype) + for var, shape in zip(pt_vars_flat, shapes_vars_flat, strict=True) + ] + + else: + shapes_vars_flat = pytensor.compile.builders.infer_shape( + pt_vars_flat, (), () + ) + dummy_inputs_jax_flat = [ + jnp.empty([int(dim.eval()) for dim in shape], dtype=var.type.dtype) + for var, shape in zip(pt_vars_flat, shapes_vars_flat, strict=True) + ] + + dummy_inputs_jax = tree_unflatten(vars_treedef, dummy_inputs_jax_flat) # Combine the static variables with the inputs, and split them again in the # output. Static variables don't take part in the graph, or might be a From e6b52d67fa597f4d2122a862184e9d2d7766af94 Mon Sep 17 00:00:00 2001 From: Jonas Date: Mon, 3 Feb 2025 17:53:16 +0100 Subject: [PATCH 09/30] Remove `sol` in variable names --- pytensor/link/jax/ops.py | 48 ++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 8b20370330..992ba5ee20 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -160,11 +160,11 @@ def func(*args, **kwargs): ### Call the function that accepts flat inputs, which in turn calls the one that ### combines the inputs and static variables. - jitted_sol_op_jax = jax.jit(func_flattened) + jitted_jax_op = jax.jit(func_flattened) len_gz = len(pttypes_outvars) - vjp_sol_op_jax = _get_vjp_sol_op_jax(func_flattened, len_gz) - jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax) + vjp_jax_op = _get_vjp_jax_op(func_flattened, len_gz) + jitted_vjp_jax_op = jax.jit(vjp_jax_op) # Get classes that creates a Pytensor Op out of our function that accept # flattened inputs. 
They are created each time, to set a custom name for the @@ -194,8 +194,8 @@ class VJPJAXOp_local(VJPJAXOp): outvars_treedef, input_types=pt_vars_types_flat, output_types=pttypes_outvars, - jitted_sol_op_jax=jitted_sol_op_jax, - jitted_vjp_sol_op_jax=jitted_vjp_sol_op_jax, + jitted_jax_op=jitted_jax_op, + jitted_vjp_jax_op=jitted_vjp_jax_op, ) ### Evaluate the Pytensor Op and return unflattened results @@ -265,8 +265,8 @@ def get_func_with_vars(self, vars): return interior_func -def _get_vjp_sol_op_jax(jaxfunc, len_gz): - def vjp_sol_op_jax(args): +def _get_vjp_jax_op(jaxfunc, len_gz): + def vjp_jax_op(args): y0 = args[:-len_gz] gz = args[-len_gz:] if len(gz) == 1: @@ -290,7 +290,7 @@ def func(*inputs): else: return tuple(vjp_fn(gz)) - return vjp_sol_op_jax + return vjp_jax_op def _partition_jaxfunc(jaxfunc, static_vars, func_vars): @@ -350,16 +350,16 @@ def __init__( output_treeedef, input_types, output_types, - jitted_sol_op_jax, - jitted_vjp_sol_op_jax, + jitted_jax_op, + jitted_vjp_jax_op, ): - self.vjp_sol_op = None + self.vjp_jax_op = None self.input_treedef = input_treedef self.output_treedef = output_treeedef self.input_types = input_types self.output_types = output_types - self.jitted_sol_op_jax = jitted_sol_op_jax - self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax + self.jitted_jax_op = jitted_jax_op + self.jitted_vjp_jax_op = jitted_vjp_jax_op def make_node(self, *inputs): self.num_inputs = len(inputs) @@ -368,16 +368,16 @@ def make_node(self, *inputs): outputs = [pt.as_tensor_variable(type()) for type in self.output_types] self.num_outputs = len(outputs) - self.vjp_sol_op = VJPJAXOp( + self.vjp_jax_op = VJPJAXOp( self.input_treedef, self.input_types, - self.jitted_vjp_sol_op_jax, + self.jitted_vjp_jax_op, ) return Apply(self, inputs, outputs) def perform(self, node, inputs, outputs): - results = self.jitted_sol_op_jax(inputs) + results = self.jitted_jax_op(inputs) if self.num_outputs > 1: for i in range(self.num_outputs): outputs[i][0] = np.array(results[i], self.output_types[i].dtype) @@ -385,7 +385,7 @@ def perform(self, node, inputs, outputs): outputs[0][0] = np.array(results, self.output_types[0].dtype) def perform_jax(self, *inputs): - results = self.jitted_sol_op_jax(inputs) + results = self.jitted_jax_op(inputs) return results def grad(self, inputs, output_gradients): @@ -399,7 +399,7 @@ def grad(self, inputs, output_gradients): ) else: output_gradients[i] = pt.zeros((), self.output_types[i].dtype) - result = self.vjp_sol_op(inputs, output_gradients) + result = self.vjp_jax_op(inputs, output_gradients) if self.num_inputs > 1: return result @@ -413,11 +413,11 @@ def __init__( self, input_treedef, input_types, - jitted_vjp_sol_op_jax, + jitted_vjp_jax_op, ): self.input_treedef = input_treedef self.input_types = input_types - self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax + self.jitted_vjp_jax_op = jitted_vjp_jax_op def make_node(self, y0, gz): y0 = [ @@ -436,7 +436,7 @@ def make_node(self, y0, gz): return Apply(self, y0 + gz_not_disconntected, outputs) def perform(self, node, inputs, outputs): - results = self.jitted_vjp_sol_op_jax(tuple(inputs)) + results = self.jitted_vjp_jax_op(tuple(inputs)) if len(self.input_types) > 1: for i, result in enumerate(results): outputs[i][0] = np.array(result, self.input_types[i].dtype) @@ -444,7 +444,7 @@ def perform(self, node, inputs, outputs): outputs[0][0] = np.array(results, self.input_types[0].dtype) def perform_jax(self, *inputs): - results = self.jitted_vjp_sol_op_jax(tuple(inputs)) + results = 
self.jitted_vjp_jax_op(tuple(inputs)) if self.num_outputs == 1: if isinstance(results, Sequence): return results[0] @@ -455,10 +455,10 @@ def perform_jax(self, *inputs): @jax_funcify.register(JAXOp) -def sol_op_jax_funcify(op, **kwargs): +def jax_op_funcify(op, **kwargs): return op.perform_jax @jax_funcify.register(VJPJAXOp) -def vjp_sol_op_jax_funcify(op, **kwargs): +def vjp_jax_op_funcify(op, **kwargs): return op.perform_jax From abf99f152e909f2528c56ea5f2471c89c5bffc6d Mon Sep 17 00:00:00 2001 From: Jonas Date: Tue, 4 Feb 2025 20:49:59 +0100 Subject: [PATCH 10/30] Rename tests and make static variables test more meaningfull --- tests/link/jax/test_as_jax_op.py | 54 ++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py index 3842278a04..286a1334f7 100644 --- a/tests/link/jax/test_as_jax_op.py +++ b/tests/link/jax/test_as_jax_op.py @@ -11,8 +11,7 @@ from tests.link.jax.test_basic import compare_jax_and_py -def test_as_jax_op1(): - # 2 parameters input, single output +def test_2in_1out(): rng = np.random.default_rng(1) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) @@ -33,8 +32,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op2(): - # 2 parameters input, tuple output +def test_2in_tupleout(): rng = np.random.default_rng(2) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) @@ -55,8 +53,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op3(): - # 2 parameters input, list output +def test_2in_listout(): rng = np.random.default_rng(3) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) @@ -77,8 +74,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op4(): - # single 1d input, tuple output +def test_1din_tupleout(): rng = np.random.default_rng(4) x = tensor("a", shape=(2,)) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @@ -96,8 +92,7 @@ def f(x): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op5(): - # single 0d input, tuple output +def test_0din_tupleout(): rng = np.random.default_rng(5) x = tensor("a", shape=()) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @@ -115,8 +110,7 @@ def f(x): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op6(): - # single input, list output +def test_1in_listout(): rng = np.random.default_rng(6) x = tensor("a", shape=(2,)) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @@ -135,8 +129,7 @@ def f(x): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op7(): - # 2 parameters input with pytree, tuple output +def test_pytreein_tupleout(): rng = np.random.default_rng(7) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) @@ -159,8 +152,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op8(): - # 2 parameters input with pytree, pytree output +def test_pytreein_pytreeout(): rng = np.random.default_rng(8) x = tensor("a", shape=(3,)) y = tensor("b", shape=(1,)) @@ -180,8 +172,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op9(): - # 2 parameters input with pytree, pytree output and non-graph argument +def test_pytreein_pytreeout_w_nongraphargs(): rng = np.random.default_rng(9) x = tensor("a", shape=(3,)) y = tensor("b", shape=(1,)) @@ -191,18 +182,35 @@ def test_as_jax_op9(): ] @as_jax_op - def f(x, y, non_model_arg): - return jnp.exp(x), jax.tree_util.tree_map(jax.nn.sigmoid, y) - - out 
= f(x, y_tmp, "Hello World!") - grad_out = grad(pt.sum(out[0]), [x]) + def f(x, y, depth, which_variable): + if which_variable == "x": + var = x + elif which_variable == "y": + var = y["a"] + y["b"][0] + else: + return "Unsupported argument" + for _ in range(depth): + var = jax.nn.sigmoid(var) + return var + # arguments depth and which_variable are not part of the graph + out = f(x, y_tmp, depth=3, which_variable="x") + grad_out = grad(pt.sum(out), [x]) fg = FunctionGraph([x, y], [out[0], *grad_out]) fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + out = f(x, y_tmp, depth=7, which_variable="y") + grad_out = grad(pt.sum(out), [x]) + fg = FunctionGraph([x, y], [out[0], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) with jax.disable_jit(): fn, _ = compare_jax_and_py(fg, test_values) + out = f(x, y_tmp, depth=10, which_variable="z") + assert out == "Unsupported argument" + def test_as_jax_op10(): # Use "None" in shape specification and have a non-used output of higher rank From 21d252db131d909cbebdefe1a0bfeac4a1dc8191 Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 5 Feb 2025 17:40:26 +0100 Subject: [PATCH 11/30] More test renaming, forgot a few --- tests/link/jax/test_as_jax_op.py | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py index 286a1334f7..d361acec4c 100644 --- a/tests/link/jax/test_as_jax_op.py +++ b/tests/link/jax/test_as_jax_op.py @@ -11,7 +11,7 @@ from tests.link.jax.test_basic import compare_jax_and_py -def test_2in_1out(): +def test_two_inputs_single_output(): rng = np.random.default_rng(1) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) @@ -32,7 +32,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_2in_tupleout(): +def test_two_inputs_tuple_output(): rng = np.random.default_rng(2) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) @@ -53,7 +53,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_2in_listout(): +def test_two_inputs_list_output(): rng = np.random.default_rng(3) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) @@ -74,7 +74,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_1din_tupleout(): +def test_single_input_tuple_output(): rng = np.random.default_rng(4) x = tensor("a", shape=(2,)) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @@ -92,7 +92,7 @@ def f(x): fn, _ = compare_jax_and_py(fg, test_values) -def test_0din_tupleout(): +def test_scalar_input_tuple_output(): rng = np.random.default_rng(5) x = tensor("a", shape=()) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @@ -110,7 +110,7 @@ def f(x): fn, _ = compare_jax_and_py(fg, test_values) -def test_1in_listout(): +def test_single_input_list_output(): rng = np.random.default_rng(6) x = tensor("a", shape=(2,)) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @@ -129,7 +129,7 @@ def f(x): fn, _ = compare_jax_and_py(fg, test_values) -def test_pytreein_tupleout(): +def test_pytree_input_tuple_output(): rng = np.random.default_rng(7) x = tensor("a", shape=(2,)) y = tensor("b", shape=(2,)) @@ -152,7 +152,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_pytreein_pytreeout(): +def test_pytree_input_pytree_output(): rng = np.random.default_rng(8) x = tensor("a", shape=(3,)) y = tensor("b", shape=(1,)) @@ -172,7 +172,7 @@ def f(x, y): fn, _ = 
compare_jax_and_py(fg, test_values) -def test_pytreein_pytreeout_w_nongraphargs(): +def test_pytree_input_with_non_graph_args(): rng = np.random.default_rng(9) x = tensor("a", shape=(3,)) y = tensor("b", shape=(1,)) @@ -212,8 +212,7 @@ def f(x, y, depth, which_variable): assert out == "Unsupported argument" -def test_as_jax_op10(): - # Use "None" in shape specification and have a non-used output of higher rank +def test_unused_matrix_product_and_exp_gradient(): rng = np.random.default_rng(10) x = tensor("a", shape=(3,)) y = tensor("b", shape=(3,)) @@ -235,8 +234,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op11(): - # Test unknown static shape +def test_unknown_static_shape(): rng = np.random.default_rng(11) x = tensor("a", shape=(3,)) y = tensor("b", shape=(3,)) @@ -260,8 +258,7 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op12(): - # Test non-array return values +def test_non_array_return_values(): rng = np.random.default_rng(12) x = tensor("a", shape=(3,)) y = tensor("b", shape=(3,)) @@ -283,8 +280,7 @@ def f(x, y, message): fn, _ = compare_jax_and_py(fg, test_values) -def test_as_jax_op13(): - # Test nested functions +def test_nested_functions(): rng = np.random.default_rng(13) x = tensor("a", shape=(3,)) y = tensor("b", shape=(3,)) From 729169659fc999854b51de686addbab86ec85fc0 Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 5 Feb 2025 17:47:27 +0100 Subject: [PATCH 12/30] Refactoring of ops.py: code is in general cleaner, and JAXOp can now be used without the decorator @as_jax_op --- pytensor/link/jax/ops.py | 692 ++++++++++++++++++++------------------- 1 file changed, 350 insertions(+), 342 deletions(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 992ba5ee20..ca780da9c8 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -1,6 +1,5 @@ """Convert a jax function to a pytensor compatible function.""" -import functools as ft import logging from collections.abc import Sequence @@ -20,62 +19,242 @@ log = logging.getLogger(__name__) -def _filter_ptvars(x): - return isinstance(x, pt.Variable) +class JAXOp(Op): + """ + JAXOp is a PyTensor Op that wraps a JAX function, providing both forward computation and reverse-mode differentiation (via the VJPJAXOp class). + + Parameters + ---------- + input_types : list + A list of PyTensor types for each input variable. + output_types : list + A list of PyTensor types for each output variable. + flat_func : callable + The JAX function that computes outputs from inputs. Inputs and outputs have to be provided as flat arrays. + name : str, optional + A custom name for the Op instance. If provided, the class name will be + updated accordingly. + + Example + ------- + This example defines a simple function that sums the input array with a dynamic shape. + + >>> import numpy as np + >>> import jax + >>> import jax.numpy as jnp + >>> from pytensor.tensor import TensorType + >>> + >>> # Create the jax function that sums the input array. + >>> def sum_function(x, y): + ... return jnp.sum(x + y) + >>> + >>> # Create the input and output types, input has a dynamic shape. + >>> input_type = TensorType("float32", shape=(None,)) + >>> output_type = TensorType("float32", shape=(1,)) + >>> + >>> # Instantiate a JAXOp; tree definitions are set to None for simplicity. + >>> op = JAXOp( + ... [input_type, input_type], [output_type], sum_function, name="DummyJAXOp" + ... ) + >>> # Define symbolic input variables. 
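+    >>> # The example below uses pytensor and pt, which are not imported above
+    >>> import pytensor
+    >>> import pytensor.tensor as pt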
+ >>> x = pt.tensor("x", dtype="float32", shape=(2,)) + >>> y = pt.tensor("y", dtype="float32", shape=(2,)) + >>> # Compile a PyTensor function. + >>> result = op(x, y) + >>> f = pytensor.function([x, y], [result]) + >>> print( + ... f( + ... np.array([2.0, 3.0], dtype=np.float32), + ... np.array([4.0, 5.0], dtype=np.float32), + ... ) + ... ) + [array(14., dtype=float32)] + >>> + >>> # Compute the gradient of op(x, y) with respect to x. + >>> g = pt.grad(result[0], x) + >>> grad_f = pytensor.function([x, y], [g]) + >>> print( + ... grad_f( + ... np.array([2.0, 3.0], dtype=np.float32), + ... np.array([4.0, 5.0], dtype=np.float32), + ... ) + ... ) + [array([1., 1.], dtype=float32)] + """ + + def __init__(self, input_types, output_types, flat_func, name=None): + self.input_types = input_types + self.output_types = output_types + self.num_inputs = len(input_types) + self.num_outputs = len(output_types) + normalized_flat_func = _normalize_flat_func(flat_func) + self.jitted_func = jax.jit(normalized_flat_func) + + vjp_func = _get_vjp_jax_op(normalized_flat_func, len(output_types)) + normalized_vjp_func = _normalize_flat_func(vjp_func) + self.jitted_vjp = jax.jit(normalized_vjp_func) + self.vjp_jax_op = VJPJAXOp( + self.input_types, + self.jitted_vjp, + name=("VJP" + name) if name is not None else None, + ) + + if name is not None: + self.custom_name = name + self.__class__.__name__ = name + self.__class__.__qualname__ = ".".join( + self.__class__.__qualname__.split(".")[:-1] + [name] + ) + + def make_node(self, *inputs): + outputs = [pt.as_tensor_variable(typ()) for typ in self.output_types] + return Apply(self, inputs, outputs) + + def perform(self, node, inputs, outputs): + results = self.jitted_func(*inputs) + if self.num_outputs > 1: + for i in range(self.num_outputs): + outputs[i][0] = np.array(results[i], self.output_types[i].dtype) + else: + outputs[0][0] = np.array(results, self.output_types[0].dtype) + + def perform_jax(self, *inputs): + return self.jitted_func(*inputs) + + def grad(self, inputs, output_gradients): + # If a output is not used, it gets disconnected by pytensor and won't have a + # gradient. Set gradient here to zero for those outputs. + for i in range(self.num_outputs): + if isinstance(output_gradients[i].type, DisconnectedType): + zero_shape = ( + self.output_types[i].shape + if None not in self.output_types[i].shape + else () + ) + output_gradients[i] = pt.zeros(zero_shape, self.output_types[i].dtype) + + # Compute the gradient. 
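+        # The VJP op is called with the original inputs followed by the output
+        # cotangents, and returns one gradient per input; PyTensor expects a
+        # sequence of gradients, so a single result is wrapped in a tuple below.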
+ grad_result = self.vjp_jax_op(inputs, output_gradients) + return grad_result if self.num_inputs > 1 else (grad_result,) + + +class VJPJAXOp(Op): + def __init__(self, input_types, jitted_vjp, name=None): + self.input_types = input_types + self.jitted_vjp = jitted_vjp + if name is not None: + self.custom_name = name + self.__class__.__name__ = name + self.__class__.__qualname__ = ".".join( + self.__class__.__qualname__.split(".")[:-1] + [name] + ) + + def make_node(self, y0, gz): + y0_converted = [ + pt.as_tensor_variable(y).astype(self.input_types[i].dtype) + for i, y in enumerate(y0) + ] + gz_not_disconnected = [ + pt.as_tensor_variable(g) + for g in gz + if not isinstance(g.type, DisconnectedType) + ] + outputs = [typ() for typ in self.input_types] + self.num_outputs = len(outputs) + return Apply(self, y0_converted + gz_not_disconnected, outputs) + + def perform(self, node, inputs, outputs): + results = self.jitted_vjp(*inputs) + if len(self.input_types) > 1: + for i, res in enumerate(results): + outputs[i][0] = np.array(res, self.input_types[i].dtype) + else: + outputs[0][0] = np.array(results, self.input_types[0].dtype) + + def perform_jax(self, *inputs): + return self.jitted_vjp(*inputs) + + +def _normalize_flat_func(func): + def normalized_func(*flat_vars): + out_flat = func(*flat_vars) + if isinstance(out_flat, Sequence): + return tuple(out_flat) if len(out_flat) > 1 else out_flat[0] + else: + return out_flat + + return normalized_func + + +def _get_vjp_jax_op(flat_func, num_out): + def vjp_op(*args): + y0 = args[:-num_out] + gz = args[-num_out:] + if len(gz) == 1: + gz = gz[0] + + def f(*inputs): + return flat_func(*inputs) + + primals, vjp_fn = jax.vjp(f, *y0) + + def broadcast_to_shape(g, shape): + if g.ndim > 0 and g.shape[0] == 1: + g_squeezed = jnp.squeeze(g, axis=0) + else: + g_squeezed = g + return jnp.broadcast_to(g_squeezed, shape) + + gz = tree_map( + lambda g, p: broadcast_to_shape(g, jnp.shape(p)).astype(p.dtype), + gz, + primals, + ) + return vjp_fn(gz) + + return vjp_op def as_jax_op(jaxfunc, use_infer_static_shape=True, name=None): - """Return a Pytensor function from a JAX jittable function. + """Return a Pytensor-compatible function from a JAX jittable function. - This decorator transforms any JAX-jittable function into a function that accepts - and returns `pytensor.Variable`. The JAX-jittable function can accept any + This decorator wraps a JAX function so that it accepts and returns `pytensor.Variable` + objects. The JAX-jittable function can accept any nested python structure (a `Pytree `_) as input, and might return any nested Python structure. Parameters ---------- - jaxfunc : JAX-jittable function - JAX function which will be wrapped in a Pytensor Op. - name: str, optional - Name of the created pytensor Op, defaults to the name of the passed function. - Only used internally in the pytensor graph. + jaxfunc : Callable + A JAX function to be wrapped. + use_infer_static_shape : bool, optional + If True, use static shape inference; otherwise, use runtime shape inference. + Default is True. + name : str, optional + A custom name for the created Pytensor Op instance. If None, the name of jaxfunc + is used. Returns ------- - Callable : - A function which expects a nested python structure of `pytensor.Variable` and - static variables as inputs and returns `pytensor.Variable` with the same - API as the original jaxfunc. The resulting model can be compiled either with the - default C backend or the JAX backend. 
+ Callable + A function that wraps the given JAX function so that it can be called with + pytensor.Variable inputs and returns pytensor.Variable outputs. Examples -------- - We define a JAX function `f_jax` that accepts a matrix `x`, a vector `y` and a - dictionary as input. This is transformed to a pytensor function with the decorator - `as_jax_op`, and can subsequently be used like normal pytensor operators, i.e. - for evaluation and calculating gradients. - - >>> import numpy # doctest: +ELLIPSIS - >>> import jax.numpy as jnp # doctest: +ELLIPSIS - >>> import pytensor # doctest: +ELLIPSIS - >>> import pytensor.tensor as pt # doctest: +ELLIPSIS - >>> x = pt.tensor("x", shape=(2,)) - >>> y = pt.tensor("y", shape=(2, 2)) - >>> a = pt.tensor("a", shape=()) - >>> args_dict = {"a": a} - >>> @pytensor.as_jax_op - ... def f_jax(x, y, args_dict): - ... z = jnp.dot(x, y) + args_dict["a"] - ... return z - >>> z = f_jax(x, y, args_dict) - >>> z_sum = pt.sum(z) - >>> grad_wrt_a = pt.grad(z_sum, a) - >>> f_all = pytensor.function([x, y, a], [z_sum, grad_wrt_a]) - >>> f_all(numpy.array([1, 2]), numpy.array([[1, 2], [3, 4]]), 1) - [array(19.), array(2.)] - + >>> import jax.numpy as jnp + >>> import pytensor.tensor as pt + >>> @as_jax_op + ... def add(x, y): + ... return jnp.add(x, y) + >>> x = pt.scalar("x") + >>> y = pt.scalar("y") + >>> result = add(x, y) + >>> f = pytensor.function([x, y], [result]) + >>> print(f(1, 2)) + [array(3.)] Notes ----- @@ -87,145 +266,165 @@ def as_jax_op(jaxfunc, use_infer_static_shape=True, name=None): of :func:`equinox.partition` and :func:`equinox.combine` to split and combine the variables. Shapes are inferred using :func:`pytensor.compile.builders.infer_shape` and :func:`jax.eval_shape`. + """ def func(*args, **kwargs): - """Return a pytensor from a jax jittable function.""" - ### Split variables: in the ones that will be transformed to JAX inputs, - ### pytensor.Variables; _WrappedFunc, that are functions that have been returned - ### from a transformed function; and the rest, static variables that are not - ### transformed. - - pt_vars, static_vars_tmp = eqx.partition( - (args, kwargs), _filter_ptvars, is_leaf=callable + # 1. Partition inputs into dynamic pytensor variables, wrapped functions and + # static variables. + # Static variables don't take part in the graph. + pt_vars, func_vars, static_vars = _split_inputs(args, kwargs) + + # 2. Get the original variables from the wrapped functions. + vars_from_func = tree_map(lambda f: f.get_vars(), func_vars) + input_dict = {"vars": pt_vars, "vars_from_func": vars_from_func} + + # 3. Flatten the input dictionary. + # e.g. {"a": tensor_a, "b": [tensor_b]} becomes [tensor_a, tensor_b], because + # pytensor ops only accepts lists of pytensor.Variables as input. + pt_vars_flat, pt_vars_treedef = tree_flatten( + input_dict, ) - # is_leaf=callable is used, as libraries like diffrax or equinox might return - # functions that are still seen as a nested pytree structure. We consider them - # as wrappable functions, that will be wrapped with _WrappedFunc. + pt_types = [var.type for var in pt_vars_flat] - func_vars, static_vars = eqx.partition( - static_vars_tmp, lambda x: isinstance(x, _WrappedFunc), is_leaf=callable + # 4. Create dummy inputs for shape inference. 
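+        # The dummy arrays carry only shape and dtype information; they are
+        # consumed by jax.eval_shape in step 6, which traces the function
+        # abstractly and never evaluates it on real data.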
+ shapes = _infer_shapes(pt_vars_flat, use_infer_static_shape) + dummy_in_flat = _create_dummy_inputs_from_shapes( + pt_vars_flat, shapes, use_infer_static_shape ) - vars_from_func = tree_map(lambda x: x.get_vars(), func_vars) - pt_vars = dict(vars=pt_vars, vars_from_func=vars_from_func) + dummy_inputs = tree_unflatten(pt_vars_treedef, dummy_in_flat) - # Flatten nested python structures, e.g. {"a": tensor_a, "b": [tensor_b]} - # becomes [tensor_a, tensor_b], because pytensor ops only accepts lists of - # pytensor.Variables as input. - pt_vars_flat, vars_treedef = tree_flatten(pt_vars) + # 5. Partition the JAX function into dynamic and static parts. + jaxfunc_dynamic, static_out_dic = _partition_jaxfunc( + jaxfunc, static_vars, func_vars + ) + flat_func = _flatten_func(jaxfunc_dynamic, pt_vars_treedef) + + # 6. Infer output types using JAX's eval_shape. + out_treedef, pt_types_flat = _infer_output_types(jaxfunc_dynamic, dummy_inputs) + + # 7. Create the Pytensor Op instance. + curr_name = "JAXOp_" + (jaxfunc.__name__ if name is None else name) + op_instance = JAXOp( + pt_types, + pt_types_flat, + flat_func, + name=curr_name, + ) - # Infer shapes and types of the variables - pt_vars_types_flat = [var.type for var in pt_vars_flat] + # 8. Execute the op and unflatten the outputs. + output_flat = op_instance(*pt_vars_flat) + if not isinstance(output_flat, Sequence): + output_flat = [output_flat] + outvars = tree_unflatten(out_treedef, output_flat) - if use_infer_static_shape: - shapes_vars_flat = [ - pt.basic.infer_static_shape(var.shape)[1] for var in pt_vars_flat - ] + # 9. Combine with static outputs and wrap eventual output functions with + # _WrappedFunc + return _process_outputs(static_out_dic, jaxfunc, args, kwargs, outvars) - dummy_inputs_jax_flat = [ - jnp.empty(shape, dtype=var.type.dtype) - for var, shape in zip(pt_vars_flat, shapes_vars_flat, strict=True) - ] + return func - else: - shapes_vars_flat = pytensor.compile.builders.infer_shape( - pt_vars_flat, (), () - ) - dummy_inputs_jax_flat = [ - jnp.empty([int(dim.eval()) for dim in shape], dtype=var.type.dtype) - for var, shape in zip(pt_vars_flat, shapes_vars_flat, strict=True) - ] - dummy_inputs_jax = tree_unflatten(vars_treedef, dummy_inputs_jax_flat) +def _filter_ptvars(x): + return isinstance(x, pt.Variable) + - # Combine the static variables with the inputs, and split them again in the - # output. Static variables don't take part in the graph, or might be a - # a function that is returned. - jaxfunc_partitioned, static_out_dic = _partition_jaxfunc( - jaxfunc, static_vars, func_vars - ) +def _split_inputs(args, kwargs): + """Split inputs into pytensor variables, static values and wrapped functions.""" - func_flattened = _flatten_func(jaxfunc_partitioned, vars_treedef) + pt_vars, static_tmp = eqx.partition( + (args, kwargs), _filter_ptvars, is_leaf=callable + ) + # is_leaf=callable is used, as libraries like diffrax or equinox might return + # functions that are still seen as a nested pytree structure. We consider them + # as wrappable functions, that will be wrapped with _WrappedFunc. 
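+    # A second partition separates those _WrappedFunc instances from the truly
+    # static values, so that only non-graph constants remain in static_vars.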
+ func_vars, static_vars = eqx.partition( + static_tmp, lambda x: isinstance(x, _WrappedFunc), is_leaf=callable + ) + return pt_vars, func_vars, static_vars - jaxtypes_outvars = jax.eval_shape( - ft.partial(jaxfunc_partitioned, vars=dummy_inputs_jax), - ) - jaxtypes_outvars_flat, outvars_treedef = tree_flatten(jaxtypes_outvars) +def _infer_shapes(pt_vars_flat, use_infer_static_shape): + """Infer shapes of pytensor variables.""" + if use_infer_static_shape: + return [pt.basic.infer_static_shape(var.shape)[1] for var in pt_vars_flat] + else: + return pytensor.compile.builders.infer_shape(pt_vars_flat, (), ()) + - pttypes_outvars = [ - pt.TensorType(dtype=var.dtype, shape=var.shape) - for var in jaxtypes_outvars_flat +def _create_dummy_inputs_from_shapes(pt_vars_flat, shapes, use_infer_static_shape): + """Create dummy inputs for the jax function from inferred shapes.""" + if use_infer_static_shape: + return [ + jnp.empty(shape, dtype=var.type.dtype) + for var, shape in zip(pt_vars_flat, shapes, strict=True) + ] + else: + return [ + jnp.empty([int(dim.eval()) for dim in shape], dtype=var.type.dtype) + for var, shape in zip(pt_vars_flat, shapes, strict=True) ] - ### Call the function that accepts flat inputs, which in turn calls the one that - ### combines the inputs and static variables. - jitted_jax_op = jax.jit(func_flattened) - len_gz = len(pttypes_outvars) - vjp_jax_op = _get_vjp_jax_op(func_flattened, len_gz) - jitted_vjp_jax_op = jax.jit(vjp_jax_op) +def _infer_output_types(jaxfunc_part, dummy_inputs): + """Infer output types using JAX's eval_shape.""" + jax_out = jax.eval_shape(jaxfunc_part, dummy_inputs) + jax_out_flat, out_treedef = tree_flatten(jax_out) + pt_out_types = [ + pt.TensorType(dtype=var.dtype, shape=var.shape) for var in jax_out_flat + ] + return out_treedef, pt_out_types - # Get classes that creates a Pytensor Op out of our function that accept - # flattened inputs. They are created each time, to set a custom name for the - # class. - class JAXOp_local(JAXOp): - pass - class VJPJAXOp_local(VJPJAXOp): - pass +def _process_outputs(static_out_dic, jaxfunc, args, kwargs, outvars): + """Process and combine static outputs with the dynamic ones.""" + static_funcs, static_vars_out = eqx.partition( + static_out_dic["out"], callable, is_leaf=callable + ) + flat_static, func_treedef = tree_flatten(static_funcs, is_leaf=callable) + for i in range(len(flat_static)): + flat_static[i] = _WrappedFunc(jaxfunc, i, *args, **kwargs) + static_funcs = tree_unflatten(func_treedef, flat_static) + static_combined = eqx.combine(static_funcs, static_vars_out, is_leaf=callable) + return eqx.combine(outvars, static_combined, is_leaf=callable) - if name is None: - curr_name = jaxfunc.__name__ - else: - curr_name = name - JAXOp_local.__name__ = curr_name - JAXOp_local.__qualname__ = ".".join( - JAXOp_local.__qualname__.split(".")[:-1] + [curr_name] - ) - VJPJAXOp_local.__name__ = "VJP_" + curr_name - VJPJAXOp_local.__qualname__ = ".".join( - VJPJAXOp_local.__qualname__.split(".")[:-1] + ["VJP_" + curr_name] - ) +def _partition_jaxfunc(jaxfunc, static_vars, func_vars): + """Split the jax function into dynamic and static components. - local_op = JAXOp_local( - vars_treedef, - outvars_treedef, - input_types=pt_vars_types_flat, - output_types=pttypes_outvars, - jitted_jax_op=jitted_jax_op, - jitted_vjp_jax_op=jitted_vjp_jax_op, - ) + Returns a function that accepts only non-static variables and returns the non-static + variables. 
The returned static variables are stored in a dictionary and returned, + to allow the referencing after creating the function - ### Evaluate the Pytensor Op and return unflattened results - output_flat = local_op(*pt_vars_flat) - if not isinstance(output_flat, Sequence): - output_flat = [output_flat] # tree_unflatten expects a sequence. - outvars = tree_unflatten(outvars_treedef, output_flat) + Additionally wrapped functions saved in func_vars are regenerated with + vars["vars_from_func"] as input, to allow the transformation of the variables. + """ + static_out_dic = {"out": None} - static_outfuncs, static_outvars = eqx.partition( - static_out_dic["out"], callable, is_leaf=callable + def jaxfunc_partitioned(vars): + dyn_vars, func_vars_input = vars["vars"], vars["vars_from_func"] + evaluated_funcs = tree_map( + lambda f, v: f.get_func_with_vars(v), func_vars, func_vars_input ) - - static_outfuncs_flat, treedef_outfuncs = jax.tree_util.tree_flatten( - static_outfuncs, is_leaf=callable + args, kwargs = eqx.combine( + dyn_vars, static_vars, evaluated_funcs, is_leaf=callable ) - for i_func, _ in enumerate(static_outfuncs_flat): - static_outfuncs_flat[i_func] = _WrappedFunc( - jaxfunc, i_func, *args, **kwargs - ) + output = jaxfunc(*args, **kwargs) + out_dyn, static_out = eqx.partition(output, eqx.is_array, is_leaf=callable) + static_out_dic["out"] = static_out + return out_dyn - static_outfuncs = jax.tree_util.tree_unflatten( - treedef_outfuncs, static_outfuncs_flat - ) - static_vars = eqx.combine(static_outfuncs, static_outvars, is_leaf=callable) + return jaxfunc_partitioned, static_out_dic - output = eqx.combine(outvars, static_vars, is_leaf=callable) - return output +def _flatten_func(jaxfunc, treedef): + def flat_func(*flat_vars): + vars = tree_unflatten(treedef, flat_vars) + out = jaxfunc(vars) + out_flat, _ = tree_flatten(out) + return out_flat - return func + return flat_func class _WrappedFunc: @@ -233,6 +432,7 @@ def __init__(self, exterior_func, i_func, *args, **kwargs): self.args = args self.kwargs = kwargs self.i_func = i_func + # Partition the inputs to separate dynamic variables from static ones. vars, static_vars = eqx.partition( (self.args, self.kwargs), _filter_ptvars, is_leaf=callable ) @@ -244,8 +444,7 @@ def __call__(self, *args, **kwargs): # If called, assume that args and kwargs are pytensors, so return the result # as pytensors. def f(func, *args, **kwargs): - res = func(*args, **kwargs) - return res + return func(*args, **kwargs) return as_jax_op(f)(self, *args, **kwargs) @@ -256,202 +455,11 @@ def get_func_with_vars(self, vars): # Use other variables than the saved ones, to generate the function. This # is used to transform vars externally from pytensor to JAX, and use the # then create the function which is returned. 
- args, kwargs = eqx.combine(vars, self.static_vars, is_leaf=callable) output = self.exterior_func(*args, **kwargs) - outfuncs, _ = eqx.partition(output, callable, is_leaf=callable) - outfuncs_flat, _ = jax.tree_util.tree_flatten(outfuncs, is_leaf=callable) - interior_func = outfuncs_flat[self.i_func] - return interior_func - - -def _get_vjp_jax_op(jaxfunc, len_gz): - def vjp_jax_op(args): - y0 = args[:-len_gz] - gz = args[-len_gz:] - if len(gz) == 1: - gz = gz[0] - - def func(*inputs): - return jaxfunc(inputs) - - primals, vjp_fn = jax.vjp(func, *y0) - gz = tree_map( - lambda g, primal: jnp.broadcast_to(g, jnp.shape(primal)).astype( - primal.dtype - ), # Also cast to the dtype of the primal, this shouldn't be - # necessary, but it happens that the returned dtype of the gradient isn't - # the same anymore. - gz, - primals, - ) - if len(y0) == 1: - return vjp_fn(gz)[0] - else: - return tuple(vjp_fn(gz)) - - return vjp_jax_op - - -def _partition_jaxfunc(jaxfunc, static_vars, func_vars): - """Partition the jax function into static and non-static variables. - - Returns a function that accepts only non-static variables and returns the non-static - variables. The returned static variables are stored in a dictionary and returned, - to allow the referencing after creating the function - - Additionally wrapped functions saved in func_vars are regenerated with - vars["vars_from_func"] as input, to allow the transformation of the variables. - """ - static_out_dic = {"out": None} - - def jaxfunc_partitioned(vars): - vars, vars_from_func = vars["vars"], vars["vars_from_func"] - func_vars_evaled = tree_map( - lambda x, y: x.get_func_with_vars(y), func_vars, vars_from_func - ) - args, kwargs = eqx.combine( - vars, static_vars, func_vars_evaled, is_leaf=callable - ) - - out = jaxfunc(*args, **kwargs) - outvars, static_out = eqx.partition(out, eqx.is_array, is_leaf=callable) - static_out_dic["out"] = static_out - return outvars - - return jaxfunc_partitioned, static_out_dic - - -### Construct the function that accepts flat inputs and returns flat outputs. 
-def _flatten_func(jaxfunc, vars_treedef): - def func_flattened(vars_flat): - vars = tree_unflatten(vars_treedef, vars_flat) - outvars = jaxfunc(vars) - outvars_flat, _ = tree_flatten(outvars) - return _normalize_flat_output(outvars_flat) - - return func_flattened - - -def _normalize_flat_output(output): - if len(output) > 1: - return tuple( - output - ) # Transform to tuple because jax makes a difference between - # tuple and list and not pytensor - else: - return output[0] - - -class JAXOp(Op): - def __init__( - self, - input_treedef, - output_treeedef, - input_types, - output_types, - jitted_jax_op, - jitted_vjp_jax_op, - ): - self.vjp_jax_op = None - self.input_treedef = input_treedef - self.output_treedef = output_treeedef - self.input_types = input_types - self.output_types = output_types - self.jitted_jax_op = jitted_jax_op - self.jitted_vjp_jax_op = jitted_vjp_jax_op - - def make_node(self, *inputs): - self.num_inputs = len(inputs) - - # Define our output variables - outputs = [pt.as_tensor_variable(type()) for type in self.output_types] - self.num_outputs = len(outputs) - - self.vjp_jax_op = VJPJAXOp( - self.input_treedef, - self.input_types, - self.jitted_vjp_jax_op, - ) - - return Apply(self, inputs, outputs) - - def perform(self, node, inputs, outputs): - results = self.jitted_jax_op(inputs) - if self.num_outputs > 1: - for i in range(self.num_outputs): - outputs[i][0] = np.array(results[i], self.output_types[i].dtype) - else: - outputs[0][0] = np.array(results, self.output_types[0].dtype) - - def perform_jax(self, *inputs): - results = self.jitted_jax_op(inputs) - return results - - def grad(self, inputs, output_gradients): - # If a output is not used, it is disconnected and doesn't have a gradient. - # Set gradient here to zero for those outputs. 
- for i in range(self.num_outputs): - if isinstance(output_gradients[i].type, DisconnectedType): - if None not in self.output_types[i].shape: - output_gradients[i] = pt.zeros( - self.output_types[i].shape, self.output_types[i].dtype - ) - else: - output_gradients[i] = pt.zeros((), self.output_types[i].dtype) - result = self.vjp_jax_op(inputs, output_gradients) - - if self.num_inputs > 1: - return result - else: - return (result,) # Pytensor requires a tuple here - - -# vector-jacobian product Op -class VJPJAXOp(Op): - def __init__( - self, - input_treedef, - input_types, - jitted_vjp_jax_op, - ): - self.input_treedef = input_treedef - self.input_types = input_types - self.jitted_vjp_jax_op = jitted_vjp_jax_op - - def make_node(self, y0, gz): - y0 = [ - pt.as_tensor_variable( - _y, - ).astype(self.input_types[i].dtype) - for i, _y in enumerate(y0) - ] - gz_not_disconntected = [ - pt.as_tensor_variable(_gz) - for _gz in gz - if not isinstance(_gz.type, DisconnectedType) - ] - outputs = [in_type() for in_type in self.input_types] - self.num_outputs = len(outputs) - return Apply(self, y0 + gz_not_disconntected, outputs) - - def perform(self, node, inputs, outputs): - results = self.jitted_vjp_jax_op(tuple(inputs)) - if len(self.input_types) > 1: - for i, result in enumerate(results): - outputs[i][0] = np.array(result, self.input_types[i].dtype) - else: - outputs[0][0] = np.array(results, self.input_types[0].dtype) - - def perform_jax(self, *inputs): - results = self.jitted_vjp_jax_op(tuple(inputs)) - if self.num_outputs == 1: - if isinstance(results, Sequence): - return results[0] - else: - return results - else: - return tuple(results) + out_funcs, _ = eqx.partition(output, callable, is_leaf=callable) + out_funcs_flat, _ = tree_flatten(out_funcs, is_leaf=callable) + return out_funcs_flat[self.i_func] @jax_funcify.register(JAXOp) From 3e2949d045a63404ab6372dcbf5363af71381b40 Mon Sep 17 00:00:00 2001 From: Jonas Date: Thu, 6 Feb 2025 13:45:46 +0100 Subject: [PATCH 13/30] Clean up tests --- tests/link/jax/test_as_jax_op.py | 124 ++++++++++++++----------------- 1 file changed, 55 insertions(+), 69 deletions(-) diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py index d361acec4c..71c34c04e5 100644 --- a/tests/link/jax/test_as_jax_op.py +++ b/tests/link/jax/test_as_jax_op.py @@ -13,8 +13,8 @@ def test_two_inputs_single_output(): rng = np.random.default_rng(1) - x = tensor("a", shape=(2,)) - y = tensor("b", shape=(2,)) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] @@ -34,8 +34,8 @@ def f(x, y): def test_two_inputs_tuple_output(): rng = np.random.default_rng(2) - x = tensor("a", shape=(2,)) - y = tensor("b", shape=(2,)) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] @@ -44,19 +44,22 @@ def test_two_inputs_tuple_output(): def f(x, y): return jax.nn.sigmoid(x + y), y * 2 - out, _ = f(x, y) - grad_out = grad(pt.sum(out), [x, y]) + out1, out2 = f(x, y) + grad_out = grad(pt.sum(out1 + out2), [x, y]) - fg = FunctionGraph([x, y], [out, *grad_out]) + fg = FunctionGraph([x, y], [out1, out2, *grad_out]) fn, _ = compare_jax_and_py(fg, test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + # must_be_device_array is False, because the with disabled jit compilation, + # inputs are not automatically transformed to jax.Array anymore + fn, _ = 
compare_jax_and_py(fg, test_values, must_be_device_array=False) -def test_two_inputs_list_output(): +def test_two_inputs_list_output_one_unused_output(): + # One output is unused, to test whether the wrapper can handle DisconnectedType rng = np.random.default_rng(3) - x = tensor("a", shape=(2,)) - y = tensor("b", shape=(2,)) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] @@ -76,63 +79,62 @@ def f(x, y): def test_single_input_tuple_output(): rng = np.random.default_rng(4) - x = tensor("a", shape=(2,)) + x = tensor("x", shape=(2,)) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @as_jax_op def f(x): return jax.nn.sigmoid(x), x * 2 - out, _ = f(x) - grad_out = grad(pt.sum(out), [x]) + out1, out2 = f(x) + grad_out = grad(pt.sum(out1), [x]) - fg = FunctionGraph([x], [out, *grad_out]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) fn, _ = compare_jax_and_py(fg, test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) def test_scalar_input_tuple_output(): rng = np.random.default_rng(5) - x = tensor("a", shape=()) + x = tensor("x", shape=()) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @as_jax_op def f(x): return jax.nn.sigmoid(x), x - out, _ = f(x) - grad_out = grad(pt.sum(out), [x]) + out1, out2 = f(x) + grad_out = grad(pt.sum(out1), [x]) - fg = FunctionGraph([x], [out, *grad_out]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) fn, _ = compare_jax_and_py(fg, test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) def test_single_input_list_output(): rng = np.random.default_rng(6) - x = tensor("a", shape=(2,)) + x = tensor("x", shape=(2,)) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] @as_jax_op def f(x): return [jax.nn.sigmoid(x), 2 * x] - out, _ = f(x) - grad_out = grad(pt.sum(out), [x]) + out1, out2 = f(x) + grad_out = grad(pt.sum(out1), [x]) - fg = FunctionGraph([x], [out, *grad_out]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) fn, _ = compare_jax_and_py(fg, test_values) - with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) def test_pytree_input_tuple_output(): rng = np.random.default_rng(7) - x = tensor("a", shape=(2,)) - y = tensor("b", shape=(2,)) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) y_tmp = {"y": y, "y2": [y**2]} test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) @@ -149,13 +151,13 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) def test_pytree_input_pytree_output(): rng = np.random.default_rng(8) - x = tensor("a", shape=(3,)) - y = tensor("b", shape=(1,)) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(1,)) y_tmp = {"a": y, "b": [y**2]} test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) @@ -171,11 +173,14 @@ def f(x, y): fg = FunctionGraph([x, y], [out[0], out[1]["a"], *grad_out]) fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + def 
test_pytree_input_with_non_graph_args(): rng = np.random.default_rng(9) - x = tensor("a", shape=(3,)) - y = tensor("b", shape=(1,)) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(1,)) y_tmp = {"a": y, "b": [y**2]} test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) @@ -212,10 +217,13 @@ def f(x, y, depth, which_variable): assert out == "Unsupported argument" -def test_unused_matrix_product_and_exp_gradient(): +def test_unused_matrix_product(): + # A matrix output is unused, to test whether the wrapper can handle a + # DisconnectedType with a larger dimension. + rng = np.random.default_rng(10) - x = tensor("a", shape=(3,)) - y = tensor("b", shape=(3,)) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] @@ -236,19 +244,19 @@ def f(x, y): def test_unknown_static_shape(): rng = np.random.default_rng(11) - x = tensor("a", shape=(3,)) - y = tensor("b", shape=(3,)) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] - x = pt.cumsum(x) # Now x has an unknown shape + x_cumsum = pt.cumsum(x) # Now x_cumsum has an unknown shape @as_jax_op def f(x, y): return x * jnp.ones(3) - out = f(x, y) + out = f(x_cumsum, y) grad_out = grad(pt.sum(out), [x]) fg = FunctionGraph([x, y], [out, *grad_out]) @@ -258,32 +266,10 @@ def f(x, y): fn, _ = compare_jax_and_py(fg, test_values) -def test_non_array_return_values(): - rng = np.random.default_rng(12) - x = tensor("a", shape=(3,)) - y = tensor("b", shape=(3,)) - test_values = [ - rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) - ] - - @as_jax_op - def f(x, y, message): - return x * jnp.ones(3), "Success: " + message - - out = f(x, y, "Hi") - grad_out = grad(pt.sum(out[0]), [x]) - - fg = FunctionGraph([x, y], [out[0], *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) - - with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) - - def test_nested_functions(): rng = np.random.default_rng(13) - x = tensor("a", shape=(3,)) - y = tensor("b", shape=(3,)) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) test_values = [ rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] @@ -319,8 +305,8 @@ class TestDtypes: @pytest.mark.parametrize("in_dtype", list(map(str, all_types))) @pytest.mark.parametrize("out_dtype", list(map(str, all_types))) def test_different_in_output(self, in_dtype, out_dtype): - x = tensor("a", shape=(3,), dtype=in_dtype) - y = tensor("b", shape=(3,), dtype=in_dtype) + x = tensor("x", shape=(3,), dtype=in_dtype) + y = tensor("y", shape=(3,), dtype=in_dtype) if "int" in in_dtype: test_values = [ @@ -356,8 +342,8 @@ def f(x, y): @pytest.mark.parametrize("in1_dtype", list(map(str, all_types))) @pytest.mark.parametrize("in2_dtype", list(map(str, all_types))) def test_test_different_inputs(self, in1_dtype, in2_dtype): - x = tensor("a", shape=(3,), dtype=in1_dtype) - y = tensor("b", shape=(3,), dtype=in2_dtype) + x = tensor("x", shape=(3,), dtype=in1_dtype) + y = tensor("y", shape=(3,), dtype=in2_dtype) if "int" in in1_dtype: test_values = [np.random.randint(0, 10, size=(3,)).astype(x.type.dtype)] From 38b17b58ca04666fd3f3e7fd121518185de5039a Mon Sep 17 00:00:00 2001 From: Jonas Date: Thu, 6 Feb 2025 14:13:56 +0100 Subject: [PATCH 14/30] Add to some tests a direct call to JAXOp --- tests/link/jax/test_as_jax_op.py | 131 
+++++++++++++++++++++++++++---- 1 file changed, 114 insertions(+), 17 deletions(-) diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py index 71c34c04e5..62fd270032 100644 --- a/tests/link/jax/test_as_jax_op.py +++ b/tests/link/jax/test_as_jax_op.py @@ -6,8 +6,9 @@ import pytensor.tensor as pt from pytensor import as_jax_op, config, grad from pytensor.graph.fg import FunctionGraph +from pytensor.link.jax.ops import JAXOp from pytensor.scalar import all_types -from pytensor.tensor import tensor +from pytensor.tensor import TensorType, tensor from tests.link.jax.test_basic import compare_jax_and_py @@ -19,11 +20,11 @@ def test_two_inputs_single_output(): rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] - @as_jax_op def f(x, y): return jax.nn.sigmoid(x + y) - out = f(x, y) + # Test with as_jax_op decorator + out = as_jax_op(f)(x, y) grad_out = grad(pt.sum(out), [x, y]) fg = FunctionGraph([x, y], [out, *grad_out]) @@ -31,6 +32,17 @@ def f(x, y): with jax.disable_jit(): fn, _ = compare_jax_and_py(fg, test_values) + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,))], + f, + ) + out = jax_op(x, y) + grad_out = grad(pt.sum(out), [x, y]) + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + def test_two_inputs_tuple_output(): rng = np.random.default_rng(2) @@ -40,11 +52,11 @@ def test_two_inputs_tuple_output(): rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] - @as_jax_op def f(x, y): return jax.nn.sigmoid(x + y), y * 2 - out1, out2 = f(x, y) + # Test with as_jax_op decorator + out1, out2 = as_jax_op(f)(x, y) grad_out = grad(pt.sum(out1 + out2), [x, y]) fg = FunctionGraph([x, y], [out1, out2, *grad_out]) @@ -54,6 +66,17 @@ def f(x, y): # inputs are not automatically transformed to jax.Array anymore fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))], + f, + ) + out1, out2 = jax_op(x, y) + grad_out = grad(pt.sum(out1 + out2), [x, y]) + fg = FunctionGraph([x, y], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + def test_two_inputs_list_output_one_unused_output(): # One output is unused, to test whether the wrapper can handle DisconnectedType @@ -64,11 +87,11 @@ def test_two_inputs_list_output_one_unused_output(): rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] - @as_jax_op def f(x, y): return [jax.nn.sigmoid(x + y), y * 2] - out, _ = f(x, y) + # Test with as_jax_op decorator + out, _ = as_jax_op(f)(x, y) grad_out = grad(pt.sum(out), [x, y]) fg = FunctionGraph([x, y], [out, *grad_out]) @@ -76,17 +99,28 @@ def f(x, y): with jax.disable_jit(): fn, _ = compare_jax_and_py(fg, test_values) + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))], + f, + ) + out, _ = jax_op(x, y) + grad_out = grad(pt.sum(out), [x, y]) + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + def test_single_input_tuple_output(): rng = np.random.default_rng(4) x = tensor("x", shape=(2,)) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] - @as_jax_op def f(x): return jax.nn.sigmoid(x), x * 2 - out1, out2 = f(x) + # Test with as_jax_op decorator + out1, out2 = as_jax_op(f)(x) grad_out = 
grad(pt.sum(out1), [x]) fg = FunctionGraph([x], [out1, out2, *grad_out]) @@ -94,17 +128,28 @@ def f(x): with jax.disable_jit(): fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type], + [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))], + f, + ) + out1, out2 = jax_op(x) + grad_out = grad(pt.sum(out1), [x]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + def test_scalar_input_tuple_output(): rng = np.random.default_rng(5) x = tensor("x", shape=()) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] - @as_jax_op def f(x): return jax.nn.sigmoid(x), x - out1, out2 = f(x) + # Test with as_jax_op decorator + out1, out2 = as_jax_op(f)(x) grad_out = grad(pt.sum(out1), [x]) fg = FunctionGraph([x], [out1, out2, *grad_out]) @@ -112,17 +157,28 @@ def f(x): with jax.disable_jit(): fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type], + [TensorType(config.floatX, shape=()), TensorType(config.floatX, shape=())], + f, + ) + out1, out2 = jax_op(x) + grad_out = grad(pt.sum(out1), [x]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + def test_single_input_list_output(): rng = np.random.default_rng(6) x = tensor("x", shape=(2,)) test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] - @as_jax_op def f(x): return [jax.nn.sigmoid(x), 2 * x] - out1, out2 = f(x) + # Test with as_jax_op decorator + out1, out2 = as_jax_op(f)(x) grad_out = grad(pt.sum(out1), [x]) fg = FunctionGraph([x], [out1, out2, *grad_out]) @@ -130,6 +186,20 @@ def f(x): with jax.disable_jit(): fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + # Test direct JAXOp usage, with unspecified output shapes + jax_op = JAXOp( + [x.type], + [ + TensorType(config.floatX, shape=(None,)), + TensorType(config.floatX, shape=(None,)), + ], + f, + ) + out1, out2 = jax_op(x) + grad_out = grad(pt.sum(out1), [x]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + def test_pytree_input_tuple_output(): rng = np.random.default_rng(7) @@ -144,6 +214,7 @@ def test_pytree_input_tuple_output(): def f(x, y): return jax.nn.sigmoid(x), 2 * x + y["y"] + y["y2"][0] + # Test with as_jax_op decorator out = f(x, y_tmp) grad_out = grad(pt.sum(out[1]), [x, y]) @@ -167,6 +238,7 @@ def test_pytree_input_pytree_output(): def f(x, y): return x, jax.tree_util.tree_map(lambda x: jnp.exp(x), y) + # Test with as_jax_op decorator out = f(x, y_tmp) grad_out = grad(pt.sum(out[1]["b"][0]), [x, y]) @@ -198,6 +270,7 @@ def f(x, y, depth, which_variable): var = jax.nn.sigmoid(var) return var + # Test with as_jax_op decorator # arguments depth and which_variable are not part of the graph out = f(x, y_tmp, depth=3, which_variable="x") grad_out = grad(pt.sum(out), [x]) @@ -228,11 +301,11 @@ def test_unused_matrix_product(): rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] - @as_jax_op def f(x, y): return x[:, None] @ y[None], jnp.exp(x) - out = f(x, y) + # Test with as_jax_op decorator + out = as_jax_op(f)(x, y) grad_out = grad(pt.sum(out[1]), [x]) fg = FunctionGraph([x, y], [out[1], *grad_out]) @@ -241,6 +314,20 @@ def f(x, y): with jax.disable_jit(): fn, _ = compare_jax_and_py(fg, test_values) + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [ + TensorType(config.floatX, 
shape=(3, 3)), + TensorType(config.floatX, shape=(3,)), + ], + f, + ) + out = jax_op(x, y) + grad_out = grad(pt.sum(out[1]), [x]) + fg = FunctionGraph([x, y], [out[1], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + def test_unknown_static_shape(): rng = np.random.default_rng(11) @@ -252,11 +339,10 @@ def test_unknown_static_shape(): x_cumsum = pt.cumsum(x) # Now x_cumsum has an unknown shape - @as_jax_op def f(x, y): return x * jnp.ones(3) - out = f(x_cumsum, y) + out = as_jax_op(f)(x_cumsum, y) grad_out = grad(pt.sum(out), [x]) fg = FunctionGraph([x, y], [out, *grad_out]) @@ -265,6 +351,17 @@ def f(x, y): with jax.disable_jit(): fn, _ = compare_jax_and_py(fg, test_values) + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(None,))], + f, + ) + out = jax_op(x_cumsum, y) + grad_out = grad(pt.sum(out), [x]) + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + def test_nested_functions(): rng = np.random.default_rng(13) From bb75938b12daa8f4e9158abc8c48d60bd6dac2b4 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 6 May 2025 10:17:24 +0200 Subject: [PATCH 15/30] temporary as_jax_op fix --- pytensor/link/jax/ops.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index ca780da9c8..b61c702bd3 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -333,13 +333,15 @@ def _split_inputs(args, kwargs): """Split inputs into pytensor variables, static values and wrapped functions.""" pt_vars, static_tmp = eqx.partition( - (args, kwargs), _filter_ptvars, is_leaf=callable + (args, kwargs), + _filter_ptvars, # is_leaf=callable ) # is_leaf=callable is used, as libraries like diffrax or equinox might return # functions that are still seen as a nested pytree structure. We consider them # as wrappable functions, that will be wrapped with _WrappedFunc. 
func_vars, static_vars = eqx.partition( - static_tmp, lambda x: isinstance(x, _WrappedFunc), is_leaf=callable + static_tmp, + lambda x: isinstance(x, _WrappedFunc), # is_leaf=callable ) return pt_vars, func_vars, static_vars @@ -407,10 +409,12 @@ def jaxfunc_partitioned(vars): lambda f, v: f.get_func_with_vars(v), func_vars, func_vars_input ) args, kwargs = eqx.combine( - dyn_vars, static_vars, evaluated_funcs, is_leaf=callable + dyn_vars, + static_vars, + evaluated_funcs, # is_leaf=callable ) output = jaxfunc(*args, **kwargs) - out_dyn, static_out = eqx.partition(output, eqx.is_array, is_leaf=callable) + out_dyn, static_out = eqx.partition(output, eqx.is_array) # , is_leaf=callable) static_out_dic["out"] = static_out return out_dyn From d04f41d80836b9dfdaa798754236f54737389c1b Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 16 Sep 2025 12:54:03 +0200 Subject: [PATCH 16/30] Simplify as_jax_op --- .github/workflows/test.yml | 2 +- pytensor/__init__.py | 2 +- pytensor/link/jax/dispatch/basic.py | 6 + pytensor/link/jax/ops.py | 492 ++++++++++------------------ tests/link/jax/test_as_jax_op.py | 182 +++++----- 5 files changed, 272 insertions(+), 412 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0e6c9ee0f2..664ba0fdc5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -208,7 +208,7 @@ jobs: micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx; fi if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi - if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tfp-nightly; fi + if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi diff --git a/pytensor/__init__.py b/pytensor/__init__.py index a7f9aa8058..924c31225b 100644 --- a/pytensor/__init__.py +++ b/pytensor/__init__.py @@ -173,7 +173,7 @@ def get_underlying_scalar_constant(v): except ImportError as e: import_error_as_jax_op = e - def as_jax_op(*args, **kwargs): + def as_jax_op(jaxfunc): raise ImportError( "JAX and/or equinox are not installed. 
Install them" " to use this function: pip install pytensor[jax]" diff --git a/pytensor/link/jax/dispatch/basic.py b/pytensor/link/jax/dispatch/basic.py index 66eb647cca..4735f9aa98 100644 --- a/pytensor/link/jax/dispatch/basic.py +++ b/pytensor/link/jax/dispatch/basic.py @@ -13,6 +13,7 @@ from pytensor.graph import Constant from pytensor.graph.fg import FunctionGraph from pytensor.ifelse import IfElse +from pytensor.link.jax.ops import JAXOp from pytensor.link.utils import fgraph_to_python from pytensor.raise_op import CheckAndRaise @@ -142,3 +143,8 @@ def opfromgraph(*inputs): return fgraph_fn(*inputs) return opfromgraph + + +@jax_funcify.register(JAXOp) +def jax_op_funcify(op, **kwargs): + return op.perform_jax diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index b61c702bd3..bf757c8c02 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -2,18 +2,13 @@ import logging from collections.abc import Sequence +from functools import wraps -import equinox as eqx -import jax -import jax.numpy as jnp import numpy as np -from jax.tree_util import tree_flatten, tree_map, tree_unflatten -import pytensor.compile.builders import pytensor.tensor as pt from pytensor.gradient import DisconnectedType -from pytensor.graph import Apply, Op -from pytensor.link.jax.dispatch import jax_funcify +from pytensor.graph import Apply, Op, Variable log = logging.getLogger(__name__) @@ -30,7 +25,7 @@ class JAXOp(Op): output_types : list A list of PyTensor types for each output variable. flat_func : callable - The JAX function that computes outputs from inputs. Inputs and outputs have to be provided as flat arrays. + The JAX function that computes outputs from inputs. name : str, optional A custom name for the Op instance. If provided, the class name will be updated accordingly. @@ -52,7 +47,7 @@ class JAXOp(Op): >>> input_type = TensorType("float32", shape=(None,)) >>> output_type = TensorType("float32", shape=(1,)) >>> - >>> # Instantiate a JAXOp; tree definitions are set to None for simplicity. + >>> # Instantiate a JAXOp >>> op = JAXOp( ... [input_type, input_type], [output_type], sum_function, name="DummyJAXOp" ... 
) @@ -82,140 +77,88 @@ class JAXOp(Op): [array([1., 1.], dtype=float32)] """ - def __init__(self, input_types, output_types, flat_func, name=None): - self.input_types = input_types - self.output_types = output_types - self.num_inputs = len(input_types) - self.num_outputs = len(output_types) - normalized_flat_func = _normalize_flat_func(flat_func) - self.jitted_func = jax.jit(normalized_flat_func) - - vjp_func = _get_vjp_jax_op(normalized_flat_func, len(output_types)) - normalized_vjp_func = _normalize_flat_func(vjp_func) - self.jitted_vjp = jax.jit(normalized_vjp_func) - self.vjp_jax_op = VJPJAXOp( - self.input_types, - self.jitted_vjp, - name=("VJP" + name) if name is not None else None, - ) - - if name is not None: - self.custom_name = name - self.__class__.__name__ = name - self.__class__.__qualname__ = ".".join( - self.__class__.__qualname__.split(".")[:-1] + [name] - ) - - def make_node(self, *inputs): - outputs = [pt.as_tensor_variable(typ()) for typ in self.output_types] + __props__ = ("input_types", "output_types", "jax_func", "name") + + def __init__(self, input_types, output_types, jax_func, name=None): + import jax + + self.input_types = tuple(input_types) + self.output_types = tuple(output_types) + self.jax_func = jax_func + self.jitted_func = jax.jit(jax_func) + self.name = name + super().__init__() + + def __repr__(self): + base = self.__class__.__name__ + if self.name is not None: + base = f"{base}{self.name}" + props = list(self.__props__) + props.remove("name") + props = ",".join(f"{prop}={getattr(self, prop, '?')}" for prop in props) + return f"{base}({props})" + + def make_node(self, *inputs: Variable) -> Apply: + outputs = [typ() for typ in self.output_types] return Apply(self, inputs, outputs) def perform(self, node, inputs, outputs): results = self.jitted_func(*inputs) - if self.num_outputs > 1: - for i in range(self.num_outputs): - outputs[i][0] = np.array(results[i], self.output_types[i].dtype) - else: - outputs[0][0] = np.array(results, self.output_types[0].dtype) + if len(results) != len(outputs): + raise ValueError( + f"Expected {len(outputs)} outputs from jax function, got {len(results)}." + ) + for i, result in enumerate(results): + outputs[i][0] = np.array(result, self.output_types[i].dtype) def perform_jax(self, *inputs): - return self.jitted_func(*inputs) + output = self.jitted_func(*inputs) + if len(output) == 1: + return output[0] + return output def grad(self, inputs, output_gradients): - # If a output is not used, it gets disconnected by pytensor and won't have a - # gradient. Set gradient here to zero for those outputs. - for i in range(self.num_outputs): - if isinstance(output_gradients[i].type, DisconnectedType): - zero_shape = ( - self.output_types[i].shape - if None not in self.output_types[i].shape - else () - ) - output_gradients[i] = pt.zeros(zero_shape, self.output_types[i].dtype) - - # Compute the gradient. 
- grad_result = self.vjp_jax_op(inputs, output_gradients) - return grad_result if self.num_inputs > 1 else (grad_result,) - - -class VJPJAXOp(Op): - def __init__(self, input_types, jitted_vjp, name=None): - self.input_types = input_types - self.jitted_vjp = jitted_vjp - if name is not None: - self.custom_name = name - self.__class__.__name__ = name - self.__class__.__qualname__ = ".".join( - self.__class__.__qualname__.split(".")[:-1] + [name] + import jax + + wrt_index = [] + for i, out in enumerate(output_gradients): + if not isinstance(out.type, DisconnectedType): + wrt_index.append(i) + + num_inputs = len(inputs) + + def vjp_jax_op(*args): + inputs = args[:num_inputs] + covectors = args[num_inputs:] + assert len(covectors) == len(wrt_index) + + def func_restricted(*inputs): + out = self.jax_func(*inputs) + return [out[i].astype(self.output_types[i].dtype) for i in wrt_index] + + _primals, vjp_fn = jax.vjp(func_restricted, *inputs) + dtypes = [self.output_types[i].dtype for i in wrt_index] + return vjp_fn( + [ + covector.astype(dtype) + for covector, dtype in zip(covectors, dtypes, strict=True) + ] ) - def make_node(self, y0, gz): - y0_converted = [ - pt.as_tensor_variable(y).astype(self.input_types[i].dtype) - for i, y in enumerate(y0) - ] - gz_not_disconnected = [ - pt.as_tensor_variable(g) - for g in gz - if not isinstance(g.type, DisconnectedType) - ] - outputs = [typ() for typ in self.input_types] - self.num_outputs = len(outputs) - return Apply(self, y0_converted + gz_not_disconnected, outputs) - - def perform(self, node, inputs, outputs): - results = self.jitted_vjp(*inputs) - if len(self.input_types) > 1: - for i, res in enumerate(results): - outputs[i][0] = np.array(res, self.input_types[i].dtype) - else: - outputs[0][0] = np.array(results, self.input_types[0].dtype) - - def perform_jax(self, *inputs): - return self.jitted_vjp(*inputs) - - -def _normalize_flat_func(func): - def normalized_func(*flat_vars): - out_flat = func(*flat_vars) - if isinstance(out_flat, Sequence): - return tuple(out_flat) if len(out_flat) > 1 else out_flat[0] - else: - return out_flat - - return normalized_func - - -def _get_vjp_jax_op(flat_func, num_out): - def vjp_op(*args): - y0 = args[:-num_out] - gz = args[-num_out:] - if len(gz) == 1: - gz = gz[0] - - def f(*inputs): - return flat_func(*inputs) - - primals, vjp_fn = jax.vjp(f, *y0) - - def broadcast_to_shape(g, shape): - if g.ndim > 0 and g.shape[0] == 1: - g_squeezed = jnp.squeeze(g, axis=0) - else: - g_squeezed = g - return jnp.broadcast_to(g_squeezed, shape) - - gz = tree_map( - lambda g, p: broadcast_to_shape(g, jnp.shape(p)).astype(p.dtype), - gz, - primals, + op = JAXOp( + self.input_types + tuple(self.output_types[i] for i in wrt_index), + [self.input_types[i] for i in range(num_inputs)], + vjp_jax_op, + name="VJP" + (self.name if self.name is not None else ""), ) - return vjp_fn(gz) - return vjp_op + output = op(*[*inputs, *[output_gradients[i] for i in wrt_index]]) + if not isinstance(output, Sequence): + output = [output] + return output -def as_jax_op(jaxfunc, use_infer_static_shape=True, name=None): +def as_jax_op(jaxfunc): """Return a Pytensor-compatible function from a JAX jittable function. This decorator wraps a JAX function so that it accepts and returns `pytensor.Variable` @@ -228,12 +171,6 @@ def as_jax_op(jaxfunc, use_infer_static_shape=True, name=None): ---------- jaxfunc : Callable A JAX function to be wrapped. 
- use_infer_static_shape : bool, optional - If True, use static shape inference; otherwise, use runtime shape inference. - Default is True. - name : str, optional - A custom name for the created Pytensor Op instance. If None, the name of jaxfunc - is used. Returns ------- @@ -256,6 +193,32 @@ def as_jax_op(jaxfunc, use_infer_static_shape=True, name=None): >>> print(f(1, 2)) [array(3.)] + We can also pass arbitrary jax pytree structures as inputs and outputs: + + >>> import jax + >>> import jax.numpy as jnp + >>> import pytensor.tensor as pt + >>> @as_jax_op + ... def complex_function(x, y, scale=1.0): + ... return { + ... "sum": jnp.add(x, y) * scale, + ... } + >>> x = pt.vector("x") + >>> y = pt.vector("y") + >>> result = complex_function(x, y, scale=2.0) + >>> f = pytensor.function([x, y], [result["sum"]]) + + Or even Equinox modules: + + >>> x = tensor("x", shape=(3,)) + >>> y = tensor("y", shape=(3,)) + >>> mlp = nn.MLP(3, 3, 3, depth=2, activation=jnp.tanh, key=jax.random.key(0)) + >>> mlp = eqx.tree_at(lambda m: m.layers[0].bias, mlp, y) + >>> @as_jax_op + >>> def f(x, mlp): + >>> return mlp(x) + >>> out = f(x, mlp) + Notes ----- The function is based on a blog post by Ricardo Vieira and Adrian Seyboldt, @@ -268,209 +231,102 @@ def as_jax_op(jaxfunc, use_infer_static_shape=True, name=None): :func:`pytensor.compile.builders.infer_shape` and :func:`jax.eval_shape`. """ - + name = jaxfunc.__name__ + + try: + import equinox as eqx + import jax + import jax.numpy as jnp + except ImportError as e: + raise ImportError( + "The as_jax_op decorator requires both jax and equinox to be installed." + ) from e + + @wraps(jaxfunc) def func(*args, **kwargs): - # 1. Partition inputs into dynamic pytensor variables, wrapped functions and + # Partition inputs into dynamic pytensor variables, wrapped functions and # static variables. # Static variables don't take part in the graph. - pt_vars, func_vars, static_vars = _split_inputs(args, kwargs) - # 2. Get the original variables from the wrapped functions. - vars_from_func = tree_map(lambda f: f.get_vars(), func_vars) - input_dict = {"vars": pt_vars, "vars_from_func": vars_from_func} + pt_vars, static_vars = eqx.partition( + (args, kwargs), lambda x: isinstance(x, pt.Variable) + ) - # 3. Flatten the input dictionary. - # e.g. {"a": tensor_a, "b": [tensor_b]} becomes [tensor_a, tensor_b], because - # pytensor ops only accepts lists of pytensor.Variables as input. - pt_vars_flat, pt_vars_treedef = tree_flatten( - input_dict, + # Flatten the input dictionary. + pt_vars_flat, pt_vars_treedef = jax.tree.flatten( + pt_vars, ) pt_types = [var.type for var in pt_vars_flat] - # 4. Create dummy inputs for shape inference. - shapes = _infer_shapes(pt_vars_flat, use_infer_static_shape) - dummy_in_flat = _create_dummy_inputs_from_shapes( - pt_vars_flat, shapes, use_infer_static_shape - ) - dummy_inputs = tree_unflatten(pt_vars_treedef, dummy_in_flat) + # We need to figure out static shapes so that we can figure + # out the output types. + input_shapes = [var.type.shape for var in pt_vars_flat] + resolved_input_shapes = [] + for var, shape in zip(pt_vars_flat, input_shapes, strict=True): + if any(s is None for s in shape): + _, shape = pt.basic.infer_static_shape(var.shape) + if any(s is None for s in shape): + raise ValueError( + f"Input variable {var} has a shape with undetermined " + "shape. Please provide inputs with fully determined shapes " + "by calling pt.specify_shape." 
+ ) + resolved_input_shapes.append(shape) + + # Figure out output types using jax.eval_shape. + extra_output_storage = {} + + def wrap_jaxfunc(args): + vars = jax.tree.unflatten(pt_vars_treedef, args) + args, kwargs = eqx.combine( + vars, + static_vars, + ) + outputs = jaxfunc(*args, **kwargs) + output_vals, output_static = eqx.partition(outputs, eqx.is_array) + extra_output_storage["output_static"] = output_static + outputs_flat, output_treedef = jax.tree.flatten(output_vals) + extra_output_storage["output_treedef"] = output_treedef + return outputs_flat + + dummy_inputs = [ + jnp.ones(shape, dtype=var.type.dtype) + for var, shape in zip(pt_vars_flat, resolved_input_shapes, strict=True) + ] - # 5. Partition the JAX function into dynamic and static parts. - jaxfunc_dynamic, static_out_dic = _partition_jaxfunc( - jaxfunc, static_vars, func_vars - ) - flat_func = _flatten_func(jaxfunc_dynamic, pt_vars_treedef) + output_shapes_flat = jax.eval_shape(wrap_jaxfunc, dummy_inputs) + output_treedef = extra_output_storage["output_treedef"] + output_static = extra_output_storage["output_static"] + pt_output_types = [ + pt.TensorType(dtype=var.dtype, shape=var.shape) + for var in output_shapes_flat + ] - # 6. Infer output types using JAX's eval_shape. - out_treedef, pt_types_flat = _infer_output_types(jaxfunc_dynamic, dummy_inputs) + def flat_func(*flat_vars): + vars = jax.tree.unflatten(pt_vars_treedef, flat_vars) + args, kwargs = eqx.combine( + vars, + static_vars, + ) + outputs = jaxfunc(*args, **kwargs) + output_vals, _ = eqx.partition(outputs, eqx.is_array) + outputs_flat, _ = jax.tree.flatten(output_vals) + return outputs_flat - # 7. Create the Pytensor Op instance. - curr_name = "JAXOp_" + (jaxfunc.__name__ if name is None else name) op_instance = JAXOp( pt_types, - pt_types_flat, + pt_output_types, flat_func, - name=curr_name, + name=name, ) # 8. Execute the op and unflatten the outputs. output_flat = op_instance(*pt_vars_flat) if not isinstance(output_flat, Sequence): output_flat = [output_flat] - outvars = tree_unflatten(out_treedef, output_flat) + outvars = jax.tree.unflatten(output_treedef, output_flat) + outvars = eqx.combine(outvars, output_static) - # 9. Combine with static outputs and wrap eventual output functions with - # _WrappedFunc - return _process_outputs(static_out_dic, jaxfunc, args, kwargs, outvars) + return outvars return func - - -def _filter_ptvars(x): - return isinstance(x, pt.Variable) - - -def _split_inputs(args, kwargs): - """Split inputs into pytensor variables, static values and wrapped functions.""" - - pt_vars, static_tmp = eqx.partition( - (args, kwargs), - _filter_ptvars, # is_leaf=callable - ) - # is_leaf=callable is used, as libraries like diffrax or equinox might return - # functions that are still seen as a nested pytree structure. We consider them - # as wrappable functions, that will be wrapped with _WrappedFunc. 
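As background for the partitioning used throughout this file: `equinox.partition` splits a pytree into the leaves matching a filter and the remainder, and `equinox.combine` undoes the split. A small sketch with made-up inputs (not taken from the patch):

    import equinox as eqx
    import pytensor.tensor as pt

    args = (pt.vector("x"), {"y": pt.vector("y"), "scale": 2.0})
    dynamic, static = eqx.partition(args, lambda leaf: isinstance(leaf, pt.Variable))
    # `dynamic` keeps the pytensor Variables (other leaves become None),
    # `static` keeps the remaining leaves such as the float 2.0.
    restored = eqx.combine(dynamic, static)  # round-trips back to `args`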
- func_vars, static_vars = eqx.partition( - static_tmp, - lambda x: isinstance(x, _WrappedFunc), # is_leaf=callable - ) - return pt_vars, func_vars, static_vars - - -def _infer_shapes(pt_vars_flat, use_infer_static_shape): - """Infer shapes of pytensor variables.""" - if use_infer_static_shape: - return [pt.basic.infer_static_shape(var.shape)[1] for var in pt_vars_flat] - else: - return pytensor.compile.builders.infer_shape(pt_vars_flat, (), ()) - - -def _create_dummy_inputs_from_shapes(pt_vars_flat, shapes, use_infer_static_shape): - """Create dummy inputs for the jax function from inferred shapes.""" - if use_infer_static_shape: - return [ - jnp.empty(shape, dtype=var.type.dtype) - for var, shape in zip(pt_vars_flat, shapes, strict=True) - ] - else: - return [ - jnp.empty([int(dim.eval()) for dim in shape], dtype=var.type.dtype) - for var, shape in zip(pt_vars_flat, shapes, strict=True) - ] - - -def _infer_output_types(jaxfunc_part, dummy_inputs): - """Infer output types using JAX's eval_shape.""" - jax_out = jax.eval_shape(jaxfunc_part, dummy_inputs) - jax_out_flat, out_treedef = tree_flatten(jax_out) - pt_out_types = [ - pt.TensorType(dtype=var.dtype, shape=var.shape) for var in jax_out_flat - ] - return out_treedef, pt_out_types - - -def _process_outputs(static_out_dic, jaxfunc, args, kwargs, outvars): - """Process and combine static outputs with the dynamic ones.""" - static_funcs, static_vars_out = eqx.partition( - static_out_dic["out"], callable, is_leaf=callable - ) - flat_static, func_treedef = tree_flatten(static_funcs, is_leaf=callable) - for i in range(len(flat_static)): - flat_static[i] = _WrappedFunc(jaxfunc, i, *args, **kwargs) - static_funcs = tree_unflatten(func_treedef, flat_static) - static_combined = eqx.combine(static_funcs, static_vars_out, is_leaf=callable) - return eqx.combine(outvars, static_combined, is_leaf=callable) - - -def _partition_jaxfunc(jaxfunc, static_vars, func_vars): - """Split the jax function into dynamic and static components. - - Returns a function that accepts only non-static variables and returns the non-static - variables. The returned static variables are stored in a dictionary and returned, - to allow the referencing after creating the function - - Additionally wrapped functions saved in func_vars are regenerated with - vars["vars_from_func"] as input, to allow the transformation of the variables. - """ - static_out_dic = {"out": None} - - def jaxfunc_partitioned(vars): - dyn_vars, func_vars_input = vars["vars"], vars["vars_from_func"] - evaluated_funcs = tree_map( - lambda f, v: f.get_func_with_vars(v), func_vars, func_vars_input - ) - args, kwargs = eqx.combine( - dyn_vars, - static_vars, - evaluated_funcs, # is_leaf=callable - ) - output = jaxfunc(*args, **kwargs) - out_dyn, static_out = eqx.partition(output, eqx.is_array) # , is_leaf=callable) - static_out_dic["out"] = static_out - return out_dyn - - return jaxfunc_partitioned, static_out_dic - - -def _flatten_func(jaxfunc, treedef): - def flat_func(*flat_vars): - vars = tree_unflatten(treedef, flat_vars) - out = jaxfunc(vars) - out_flat, _ = tree_flatten(out) - return out_flat - - return flat_func - - -class _WrappedFunc: - def __init__(self, exterior_func, i_func, *args, **kwargs): - self.args = args - self.kwargs = kwargs - self.i_func = i_func - # Partition the inputs to separate dynamic variables from static ones. 
- vars, static_vars = eqx.partition( - (self.args, self.kwargs), _filter_ptvars, is_leaf=callable - ) - self.vars = vars - self.static_vars = static_vars - self.exterior_func = exterior_func - - def __call__(self, *args, **kwargs): - # If called, assume that args and kwargs are pytensors, so return the result - # as pytensors. - def f(func, *args, **kwargs): - return func(*args, **kwargs) - - return as_jax_op(f)(self, *args, **kwargs) - - def get_vars(self): - return self.vars - - def get_func_with_vars(self, vars): - # Use other variables than the saved ones, to generate the function. This - # is used to transform vars externally from pytensor to JAX, and use the - # then create the function which is returned. - args, kwargs = eqx.combine(vars, self.static_vars, is_leaf=callable) - output = self.exterior_func(*args, **kwargs) - out_funcs, _ = eqx.partition(output, callable, is_leaf=callable) - out_funcs_flat, _ = tree_flatten(out_funcs, is_leaf=callable) - return out_funcs_flat[self.i_func] - - -@jax_funcify.register(JAXOp) -def jax_op_funcify(op, **kwargs): - return op.perform_jax - - -@jax_funcify.register(VJPJAXOp) -def vjp_jax_op_funcify(op, **kwargs): - return op.perform_jax diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py index 62fd270032..bb87f8f1e9 100644 --- a/tests/link/jax/test_as_jax_op.py +++ b/tests/link/jax/test_as_jax_op.py @@ -1,17 +1,17 @@ -import jax -import jax.numpy as jnp import numpy as np import pytest import pytensor.tensor as pt from pytensor import as_jax_op, config, grad -from pytensor.graph.fg import FunctionGraph from pytensor.link.jax.ops import JAXOp from pytensor.scalar import all_types from pytensor.tensor import TensorType, tensor from tests.link.jax.test_basic import compare_jax_and_py +jax = pytest.importorskip("jax") + + def test_two_inputs_single_output(): rng = np.random.default_rng(1) x = tensor("x", shape=(2,)) @@ -27,10 +27,12 @@ def f(x, y): out = as_jax_op(f)(x, y) grad_out = grad(pt.sum(out), [x, y]) - fg = FunctionGraph([x, y], [out, *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out, *grad_out], test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out, *grad_out], test_values) + + def f(x, y): + return [jax.nn.sigmoid(x + y)] # Test direct JAXOp usage jax_op = JAXOp( @@ -40,8 +42,7 @@ def f(x, y): ) out = jax_op(x, y) grad_out = grad(pt.sum(out), [x, y]) - fg = FunctionGraph([x, y], [out, *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out, *grad_out], test_values) def test_two_inputs_tuple_output(): @@ -59,12 +60,13 @@ def f(x, y): out1, out2 = as_jax_op(f)(x, y) grad_out = grad(pt.sum(out1 + out2), [x, y]) - fg = FunctionGraph([x, y], [out1, out2, *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out1, out2, *grad_out], test_values) with jax.disable_jit(): # must_be_device_array is False, because the with disabled jit compilation, # inputs are not automatically transformed to jax.Array anymore - fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + compare_jax_and_py( + [x, y], [out1, out2, *grad_out], test_values, must_be_device_array=False + ) # Test direct JAXOp usage jax_op = JAXOp( @@ -74,8 +76,7 @@ def f(x, y): ) out1, out2 = jax_op(x, y) grad_out = grad(pt.sum(out1 + out2), [x, y]) - fg = FunctionGraph([x, y], [out1, out2, *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, 
y], [out1, out2, *grad_out], test_values) def test_two_inputs_list_output_one_unused_output(): @@ -94,10 +95,9 @@ def f(x, y): out, _ = as_jax_op(f)(x, y) grad_out = grad(pt.sum(out), [x, y]) - fg = FunctionGraph([x, y], [out, *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out, *grad_out], test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out, *grad_out], test_values) # Test direct JAXOp usage jax_op = JAXOp( @@ -107,8 +107,7 @@ def f(x, y): ) out, _ = jax_op(x, y) grad_out = grad(pt.sum(out), [x, y]) - fg = FunctionGraph([x, y], [out, *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out, *grad_out], test_values) def test_single_input_tuple_output(): @@ -123,10 +122,11 @@ def f(x): out1, out2 = as_jax_op(f)(x) grad_out = grad(pt.sum(out1), [x]) - fg = FunctionGraph([x], [out1, out2, *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x], [out1, out2, *grad_out], test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + compare_jax_and_py( + [x], [out1, out2, *grad_out], test_values, must_be_device_array=False + ) # Test direct JAXOp usage jax_op = JAXOp( @@ -136,8 +136,7 @@ def f(x): ) out1, out2 = jax_op(x) grad_out = grad(pt.sum(out1), [x]) - fg = FunctionGraph([x], [out1, out2, *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x], [out1, out2, *grad_out], test_values) def test_scalar_input_tuple_output(): @@ -152,10 +151,11 @@ def f(x): out1, out2 = as_jax_op(f)(x) grad_out = grad(pt.sum(out1), [x]) - fg = FunctionGraph([x], [out1, out2, *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x], [out1, out2, *grad_out], test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + compare_jax_and_py( + [x], [out1, out2, *grad_out], test_values, must_be_device_array=False + ) # Test direct JAXOp usage jax_op = JAXOp( @@ -165,8 +165,7 @@ def f(x): ) out1, out2 = jax_op(x) grad_out = grad(pt.sum(out1), [x]) - fg = FunctionGraph([x], [out1, out2, *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x], [out1, out2, *grad_out], test_values) def test_single_input_list_output(): @@ -181,10 +180,11 @@ def f(x): out1, out2 = as_jax_op(f)(x) grad_out = grad(pt.sum(out1), [x]) - fg = FunctionGraph([x], [out1, out2, *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x], [out1, out2, *grad_out], test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + compare_jax_and_py( + [x], [out1, out2, *grad_out], test_values, must_be_device_array=False + ) # Test direct JAXOp usage, with unspecified output shapes jax_op = JAXOp( @@ -197,8 +197,7 @@ def f(x): ) out1, out2 = jax_op(x) grad_out = grad(pt.sum(out1), [x]) - fg = FunctionGraph([x], [out1, out2, *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x], [out1, out2, *grad_out], test_values) def test_pytree_input_tuple_output(): @@ -218,11 +217,12 @@ def f(x, y): out = f(x, y_tmp) grad_out = grad(pt.sum(out[1]), [x, y]) - fg = FunctionGraph([x, y], [out[0], out[1], *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out[0], out[1], *grad_out], test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + 
compare_jax_and_py( + [x, y], [out[0], out[1], *grad_out], test_values, must_be_device_array=False + ) def test_pytree_input_pytree_output(): @@ -236,17 +236,21 @@ def test_pytree_input_pytree_output(): @as_jax_op def f(x, y): - return x, jax.tree_util.tree_map(lambda x: jnp.exp(x), y) + return x, jax.tree_util.tree_map(lambda x: jax.numpy.exp(x), y) # Test with as_jax_op decorator out = f(x, y_tmp) grad_out = grad(pt.sum(out[1]["b"][0]), [x, y]) - fg = FunctionGraph([x, y], [out[0], out[1]["a"], *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out[0], out[1]["a"], *grad_out], test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + compare_jax_and_py( + [x, y], + [out[0], out[1]["a"], *grad_out], + test_values, + must_be_device_array=False, + ) def test_pytree_input_with_non_graph_args(): @@ -274,17 +278,15 @@ def f(x, y, depth, which_variable): # arguments depth and which_variable are not part of the graph out = f(x, y_tmp, depth=3, which_variable="x") grad_out = grad(pt.sum(out), [x]) - fg = FunctionGraph([x, y], [out[0], *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out[0], *grad_out], test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out[0], *grad_out], test_values) out = f(x, y_tmp, depth=7, which_variable="y") grad_out = grad(pt.sum(out), [x]) - fg = FunctionGraph([x, y], [out[0], *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out[0], *grad_out], test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out[0], *grad_out], test_values) out = f(x, y_tmp, depth=10, which_variable="z") assert out == "Unsupported argument" @@ -302,17 +304,16 @@ def test_unused_matrix_product(): ] def f(x, y): - return x[:, None] @ y[None], jnp.exp(x) + return x[:, None] @ y[None], jax.numpy.exp(x) # Test with as_jax_op decorator out = as_jax_op(f)(x, y) grad_out = grad(pt.sum(out[1]), [x]) - fg = FunctionGraph([x, y], [out[1], *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out[1], *grad_out], test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out[1], *grad_out], test_values) # Test direct JAXOp usage jax_op = JAXOp( @@ -325,8 +326,7 @@ def f(x, y): ) out = jax_op(x, y) grad_out = grad(pt.sum(out[1]), [x]) - fg = FunctionGraph([x, y], [out[1], *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out[1], *grad_out], test_values) def test_unknown_static_shape(): @@ -340,16 +340,15 @@ def test_unknown_static_shape(): x_cumsum = pt.cumsum(x) # Now x_cumsum has an unknown shape def f(x, y): - return x * jnp.ones(3) + return [x * jax.numpy.ones(3)] - out = as_jax_op(f)(x_cumsum, y) + (out,) = as_jax_op(f)(x_cumsum, y) grad_out = grad(pt.sum(out), [x]) - fg = FunctionGraph([x, y], [out, *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out, *grad_out], test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out, *grad_out], test_values) # Test direct JAXOp usage jax_op = JAXOp( @@ -359,11 +358,13 @@ def f(x, y): ) out = jax_op(x_cumsum, y) grad_out = grad(pt.sum(out), [x]) - fg = FunctionGraph([x, y], [out, *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], 
[out, *grad_out], test_values) -def test_nested_functions(): +def test_nn(): + import equinox as eqx + import equinox.nn as nn + rng = np.random.default_rng(13) x = tensor("x", shape=(3,)) y = tensor("y", shape=(3,)) @@ -371,31 +372,22 @@ def test_nested_functions(): rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] - @as_jax_op - def f_internal(y): - def f_ret(t): - return y + t - - def f_ret2(t): - return f_ret(t) + t**2 - - return f_ret, y**2 * jnp.ones(1), f_ret2 - - f, y_pow, f2 = f_internal(y) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) + mlp = nn.MLP(3, 3, 3, depth=2, activation=jax.numpy.tanh, key=jax.random.key(0)) + mlp = eqx.tree_at(lambda m: m.layers[0].bias, mlp, y) @as_jax_op - def f_outer(x, dict_other): - f, y_pow = dict_other["func"], dict_other["y"] - return x * jnp.ones(3), f(x) * y_pow + def f(x, mlp): + return mlp(x) - out = f_outer(x, {"func": f, "y": y_pow}) - grad_out = grad(pt.sum(out[1]), [x]) + out = f(x, mlp) + grad_out = grad(pt.sum(out), [x]) - fg = FunctionGraph([x, y], [out[1], *grad_out]) - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out[1], *grad_out], test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out[1], *grad_out], test_values) class TestDtypes: @@ -418,8 +410,8 @@ def test_different_in_output(self, in_dtype, out_dtype): @as_jax_op def f(x, y): - out = jnp.add(x, y) - return jnp.real(out).astype(out_dtype) + out = jax.numpy.add(x, y) + return jax.numpy.real(out).astype(out_dtype) out = f(x, y) assert out.dtype == out_dtype @@ -427,14 +419,15 @@ def f(x, y): if "float" in in_dtype and "float" in out_dtype: grad_out = grad(out[0], [x, y]) assert grad_out[0].dtype == in_dtype - fg = FunctionGraph([x, y], [out, *grad_out]) + compare_jax_and_py([x, y], [out, *grad_out], test_values) else: - fg = FunctionGraph([x, y], [out]) - - fn, _ = compare_jax_and_py(fg, test_values) + compare_jax_and_py([x, y], [out], test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + if "float" in in_dtype and "float" in out_dtype: + compare_jax_and_py([x, y], [out, *grad_out], test_values) + else: + compare_jax_and_py([x, y], [out], test_values) @pytest.mark.parametrize("in1_dtype", list(map(str, all_types))) @pytest.mark.parametrize("in2_dtype", list(map(str, all_types))) @@ -453,8 +446,8 @@ def test_test_different_inputs(self, in1_dtype, in2_dtype): @as_jax_op def f(x, y): - out = jnp.add(x, y) - return jnp.real(out).astype(in1_dtype) + out = jax.numpy.add(x, y) + return jax.numpy.real(out).astype(in1_dtype) out = f(x, y) assert out.dtype == in1_dtype @@ -464,11 +457,16 @@ def f(x, y): # an integer, but it doesn't work for some reason. 
grad_out = grad(out[0], [x]) assert grad_out[0].dtype == in1_dtype - fg = FunctionGraph([x, y], [out, *grad_out]) + inputs = [x, y] + outputs = [out, *grad_out] else: - fg = FunctionGraph([x, y], [out]) + inputs = [x, y] + outputs = [out] - fn, _ = compare_jax_and_py(fg, test_values) + fn, _ = compare_jax_and_py(inputs, outputs, test_values) with jax.disable_jit(): - fn, _ = compare_jax_and_py(fg, test_values) + if "float" in in1_dtype and "float" in in2_dtype: + compare_jax_and_py([x, y], [out, *grad_out], test_values) + else: + compare_jax_and_py([x, y], [out], test_values) From 8ecf45ccfa4f8cc20324e1f4f67ff4d56e7714a9 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 16 Sep 2025 20:10:45 +0200 Subject: [PATCH 17/30] optionally eval shapes in as_jax_op --- pytensor/link/jax/ops.py | 242 +++++++++++++++++++------------ tests/link/jax/test_as_jax_op.py | 4 +- 2 files changed, 148 insertions(+), 98 deletions(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index bf757c8c02..9e5f5488bc 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -7,6 +7,7 @@ import numpy as np import pytensor.tensor as pt +from pytensor.compile.function import function from pytensor.gradient import DisconnectedType from pytensor.graph import Apply, Op, Variable @@ -158,7 +159,7 @@ def func_restricted(*inputs): return output -def as_jax_op(jaxfunc): +def as_jax_op(jaxfunc=None, *, allow_eval=True): """Return a Pytensor-compatible function from a JAX jittable function. This decorator wraps a JAX function so that it accepts and returns `pytensor.Variable` @@ -169,8 +170,11 @@ def as_jax_op(jaxfunc): Parameters ---------- - jaxfunc : Callable - A JAX function to be wrapped. + jaxfunc : Callable, optional + A JAX function to be wrapped. If None, returns a decorator function. + allow_eval : bool, default=True + Whether to allow evaluation of symbolic shapes when input shapes are + not fully determined. Returns ------- @@ -223,7 +227,7 @@ def as_jax_op(jaxfunc): ----- The function is based on a blog post by Ricardo Vieira and Adrian Seyboldt, available at - `pymc-labls.io `__. To accept functions and non pytensor variables as input, the function make use of :func:`equinox.partition` and :func:`equinox.combine` to split and combine the @@ -231,102 +235,148 @@ def as_jax_op(jaxfunc): :func:`pytensor.compile.builders.infer_shape` and :func:`jax.eval_shape`. """ - name = jaxfunc.__name__ - try: - import equinox as eqx - import jax - import jax.numpy as jnp - except ImportError as e: - raise ImportError( - "The as_jax_op decorator requires both jax and equinox to be installed." - ) from e - - @wraps(jaxfunc) - def func(*args, **kwargs): - # Partition inputs into dynamic pytensor variables, wrapped functions and - # static variables. - # Static variables don't take part in the graph. - - pt_vars, static_vars = eqx.partition( - (args, kwargs), lambda x: isinstance(x, pt.Variable) - ) - - # Flatten the input dictionary. - pt_vars_flat, pt_vars_treedef = jax.tree.flatten( - pt_vars, - ) - pt_types = [var.type for var in pt_vars_flat] - - # We need to figure out static shapes so that we can figure - # out the output types. - input_shapes = [var.type.shape for var in pt_vars_flat] - resolved_input_shapes = [] - for var, shape in zip(pt_vars_flat, input_shapes, strict=True): - if any(s is None for s in shape): - _, shape = pt.basic.infer_static_shape(var.shape) - if any(s is None for s in shape): - raise ValueError( - f"Input variable {var} has a shape with undetermined " - "shape. 
Please provide inputs with fully determined shapes " - "by calling pt.specify_shape." - ) - resolved_input_shapes.append(shape) - - # Figure out output types using jax.eval_shape. - extra_output_storage = {} + def decorator(func): + name = func.__name__ + + try: + import equinox as eqx + import jax + except ImportError as e: + raise ImportError( + "The as_jax_op decorator requires both jax and equinox to be installed." + ) from e + + @wraps(func) + def wrapper(*args, **kwargs): + # Partition inputs into dynamic pytensor variables, wrapped functions and + # static variables. + # Static variables don't take part in the graph. + + input_vars, static_input = eqx.partition( + (args, kwargs), lambda x: isinstance(x, pt.Variable) + ) - def wrap_jaxfunc(args): - vars = jax.tree.unflatten(pt_vars_treedef, args) - args, kwargs = eqx.combine( - vars, - static_vars, + # Flatten the input dictionary. + input_flat, input_treedef = jax.tree.flatten( + input_vars, ) - outputs = jaxfunc(*args, **kwargs) - output_vals, output_static = eqx.partition(outputs, eqx.is_array) - extra_output_storage["output_static"] = output_static - outputs_flat, output_treedef = jax.tree.flatten(output_vals) - extra_output_storage["output_treedef"] = output_treedef - return outputs_flat - - dummy_inputs = [ - jnp.ones(shape, dtype=var.type.dtype) - for var, shape in zip(pt_vars_flat, resolved_input_shapes, strict=True) - ] - - output_shapes_flat = jax.eval_shape(wrap_jaxfunc, dummy_inputs) - output_treedef = extra_output_storage["output_treedef"] - output_static = extra_output_storage["output_static"] - pt_output_types = [ - pt.TensorType(dtype=var.dtype, shape=var.shape) - for var in output_shapes_flat - ] - - def flat_func(*flat_vars): - vars = jax.tree.unflatten(pt_vars_treedef, flat_vars) - args, kwargs = eqx.combine( - vars, - static_vars, + input_types = [var.type for var in input_flat] + + # We need to figure out static shapes so that we can figure + # out the output types. + output_types, output_treedef, output_static = _find_output_types( + func, input_flat, input_treedef, static_input, allow_eval=allow_eval ) - outputs = jaxfunc(*args, **kwargs) - output_vals, _ = eqx.partition(outputs, eqx.is_array) - outputs_flat, _ = jax.tree.flatten(output_vals) - return outputs_flat - - op_instance = JAXOp( - pt_types, - pt_output_types, - flat_func, - name=name, - ) - # 8. Execute the op and unflatten the outputs. - output_flat = op_instance(*pt_vars_flat) - if not isinstance(output_flat, Sequence): - output_flat = [output_flat] - outvars = jax.tree.unflatten(output_treedef, output_flat) - outvars = eqx.combine(outvars, output_static) + def flat_func(*flat_vars): + vars = jax.tree.unflatten(input_treedef, flat_vars) + args, kwargs = eqx.combine( + vars, + static_input, + ) + outputs = func(*args, **kwargs) + output_vals, _ = eqx.partition(outputs, eqx.is_array) + outputs_flat, _ = jax.tree.flatten(output_vals) + return outputs_flat + + op_instance = JAXOp( + input_types, + output_types, + flat_func, + name=name, + ) - return outvars + # 8. Execute the op and unflatten the outputs. 
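The flatten/unflatten round trip used here is worth spelling out once: a PyTensor Op only accepts a flat list of variables, so the pytree structure is recorded as a treedef and restored after the Op is applied. A self-contained sketch with illustrative values:

    import jax

    outputs = {"a": 1.0, "b": (2.0, 3.0)}
    flat, treedef = jax.tree.flatten(outputs)    # [1.0, 2.0, 3.0] plus the tree structure
    rebuilt = jax.tree.unflatten(treedef, flat)  # {"a": 1.0, "b": (2.0, 3.0)}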
+ output_flat = op_instance(*input_flat) + if not isinstance(output_flat, Sequence): + output_flat = [output_flat] + outvars = jax.tree.unflatten(output_treedef, output_flat) + outvars = eqx.combine(outvars, output_static) + + return outvars + + return wrapper + + if jaxfunc is None: + return decorator + else: + return decorator(jaxfunc) + + +def _find_output_types( + jaxfunc, inputs_flat, input_treedef, static_input, *, allow_eval=True +): + import equinox as eqx + import jax + import jax.numpy as jnp + + resolved_input_shapes = [] + needs_eval = False + for var in inputs_flat: + # If shape is already fully determined, use it directly + if not any(s is None for s in var.type.shape): + resolved_input_shapes.append(var.type.shape) + continue + + # Try to infer static shape + _, shape = pt.basic.infer_static_shape(var.shape) + if not any(s is None for s in shape): + resolved_input_shapes.append(shape) + continue - return func + # Shape still has undetermined dimensions + if not allow_eval: + raise ValueError( + f"Input variable {var} has a shape with undetermined " + "shape. Please provide inputs with fully determined shapes " + "by calling pt.specify_shape." + ) + needs_eval = True + resolved_input_shapes.append(var.shape) + + if needs_eval: + try: + shape_fn = function( + [], + resolved_input_shapes, + on_unused_input="ignore", + mode="FAST_COMPILE", + ) + except Exception as e: + raise ValueError( + "Could not compile a function to infer example shapes. " + "Please provide inputs with fully determined shapes by " + "calling pt.specify_shape." + ) from e + resolved_input_shapes = shape_fn() + + # Figure out output types using jax.eval_shape. + extra_output_storage = {} + + dummy_inputs = [ + jnp.ones(shape, dtype=var.type.dtype) + for var, shape in zip(inputs_flat, resolved_input_shapes, strict=True) + ] + + def wrap_jaxfunc(args): + vars = jax.tree.unflatten(input_treedef, args) + args, kwargs = eqx.combine( + vars, + static_input, + ) + outputs = jaxfunc(*args, **kwargs) + output_vals, output_static = eqx.partition(outputs, eqx.is_array) + extra_output_storage["output_static"] = output_static + outputs_flat, output_treedef = jax.tree.flatten(output_vals) + extra_output_storage["output_treedef"] = output_treedef + return outputs_flat + + output_shapes_flat = jax.eval_shape(wrap_jaxfunc, dummy_inputs) + output_treedef = extra_output_storage["output_treedef"] + output_static = extra_output_storage["output_static"] + output_types = [ + pt.TensorType(dtype=var.dtype, shape=var.shape) for var in output_shapes_flat + ] + + return output_types, output_treedef, output_static diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py index bb87f8f1e9..b6998d4ab2 100644 --- a/tests/link/jax/test_as_jax_op.py +++ b/tests/link/jax/test_as_jax_op.py @@ -384,10 +384,10 @@ def f(x, mlp): out = f(x, mlp) grad_out = grad(pt.sum(out), [x]) - compare_jax_and_py([x, y], [out[1], *grad_out], test_values) + compare_jax_and_py([x, y], [out, *grad_out], test_values) with jax.disable_jit(): - compare_jax_and_py([x, y], [out[1], *grad_out], test_values) + compare_jax_and_py([x, y], [out, *grad_out], test_values) class TestDtypes: From abd668ba10058601511ade2cf12a11276242930c Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 16 Sep 2025 20:31:00 +0200 Subject: [PATCH 18/30] minor coding style changes in as_jax_op --- pytensor/link/jax/ops.py | 274 +++++++++++++++++++++------------------ 1 file changed, 150 insertions(+), 124 deletions(-) diff --git a/pytensor/link/jax/ops.py 
b/pytensor/link/jax/ops.py index 9e5f5488bc..fce3975e24 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -17,7 +17,8 @@ class JAXOp(Op): """ - JAXOp is a PyTensor Op that wraps a JAX function, providing both forward computation and reverse-mode differentiation (via the VJPJAXOp class). + JAXOp is a PyTensor Op that wraps a JAX function, providing both forward + computation and reverse-mode differentiation (via VJP). Parameters ---------- @@ -25,7 +26,7 @@ class JAXOp(Op): A list of PyTensor types for each input variable. output_types : list A list of PyTensor types for each output variable. - flat_func : callable + jax_function : callable The JAX function that computes outputs from inputs. name : str, optional A custom name for the Op instance. If provided, the class name will be @@ -80,13 +81,13 @@ class JAXOp(Op): __props__ = ("input_types", "output_types", "jax_func", "name") - def __init__(self, input_types, output_types, jax_func, name=None): + def __init__(self, input_types, output_types, jax_function, name=None): import jax self.input_types = tuple(input_types) self.output_types = tuple(output_types) - self.jax_func = jax_func - self.jitted_func = jax.jit(jax_func) + self.jax_func = jax_function + self.jitted_func = jax.jit(jax_function) self.name = name super().__init__() @@ -100,77 +101,96 @@ def __repr__(self): return f"{base}({props})" def make_node(self, *inputs: Variable) -> Apply: - outputs = [typ() for typ in self.output_types] + """Create an Apply node with the given inputs and inferred outputs.""" + outputs = [output_type() for output_type in self.output_types] return Apply(self, inputs, outputs) def perform(self, node, inputs, outputs): + """Execute the JAX function and store results in output storage.""" results = self.jitted_func(*inputs) if len(results) != len(outputs): raise ValueError( - f"Expected {len(outputs)} outputs from jax function, got {len(results)}." + f"JAX function returned {len(results)} outputs, but " + f"{len(outputs)} were expected." ) for i, result in enumerate(results): - outputs[i][0] = np.array(result, self.output_types[i].dtype) + outputs[i][0] = np.array(result, dtype=self.output_types[i].dtype) def perform_jax(self, *inputs): - output = self.jitted_func(*inputs) - if len(output) == 1: - return output[0] - return output + """Execute the JAX function directly, returning JAX arrays.""" + outputs = self.jitted_func(*inputs) + if len(outputs) == 1: + return outputs[0] + return outputs def grad(self, inputs, output_gradients): + """Compute gradients using JAX's vector-Jacobian product (VJP).""" import jax - wrt_index = [] - for i, out in enumerate(output_gradients): - if not isinstance(out.type, DisconnectedType): - wrt_index.append(i) + # Find indices of outputs that need gradients + connected_output_indices = [] + for i, output_grad in enumerate(output_gradients): + if not isinstance(output_grad.type, DisconnectedType): + connected_output_indices.append(i) num_inputs = len(inputs) - def vjp_jax_op(*args): - inputs = args[:num_inputs] - covectors = args[num_inputs:] - assert len(covectors) == len(wrt_index) - - def func_restricted(*inputs): - out = self.jax_func(*inputs) - return [out[i].astype(self.output_types[i].dtype) for i in wrt_index] + def vjp_operation(*args): + """VJP operation that computes gradients w.r.t. 
inputs.""" + input_values = args[:num_inputs] + cotangent_vectors = args[num_inputs:] + assert len(cotangent_vectors) == len(connected_output_indices) + + def restricted_function(*input_values): + """Restricted function that only returns connected outputs.""" + outputs = self.jax_func(*input_values) + return [ + outputs[i].astype(self.output_types[i].dtype) + for i in connected_output_indices + ] - _primals, vjp_fn = jax.vjp(func_restricted, *inputs) - dtypes = [self.output_types[i].dtype for i in wrt_index] - return vjp_fn( + _primals, vjp_function = jax.vjp(restricted_function, *input_values) + output_dtypes = [ + self.output_types[i].dtype for i in connected_output_indices + ] + return vjp_function( [ - covector.astype(dtype) - for covector, dtype in zip(covectors, dtypes, strict=True) + cotangent.astype(dtype) + for cotangent, dtype in zip( + cotangent_vectors, output_dtypes, strict=True + ) ] ) - op = JAXOp( - self.input_types + tuple(self.output_types[i] for i in wrt_index), + # Create VJP operation + vjp_op = JAXOp( + self.input_types + + tuple(self.output_types[i] for i in connected_output_indices), [self.input_types[i] for i in range(num_inputs)], - vjp_jax_op, + vjp_operation, name="VJP" + (self.name if self.name is not None else ""), ) - output = op(*[*inputs, *[output_gradients[i] for i in wrt_index]]) - if not isinstance(output, Sequence): - output = [output] - return output + gradient_outputs = vjp_op( + *[*inputs, *[output_gradients[i] for i in connected_output_indices]] + ) + if not isinstance(gradient_outputs, Sequence): + gradient_outputs = [gradient_outputs] + return gradient_outputs -def as_jax_op(jaxfunc=None, *, allow_eval=True): - """Return a Pytensor-compatible function from a JAX jittable function. +def as_jax_op(jax_function=None, *, allow_eval=True): + """Return a PyTensor-compatible function from a JAX jittable function. - This decorator wraps a JAX function so that it accepts and returns `pytensor.Variable` - objects. The JAX-jittable function can accept any - nested python structure (a `Pytree - `_) as input, and might return - any nested Python structure. + This decorator wraps a JAX function so that it accepts and returns + `pytensor.Variable` objects. The JAX-jittable function can accept any + nested Python structure (a `Pytree + `_) as input, and might + return any nested Python structure. Parameters ---------- - jaxfunc : Callable, optional + jax_function : Callable, optional A JAX function to be wrapped. If None, returns a decorator function. allow_eval : bool, default=True Whether to allow evaluation of symbolic shapes when input shapes are @@ -212,16 +232,17 @@ def as_jax_op(jaxfunc=None, *, allow_eval=True): >>> result = complex_function(x, y, scale=2.0) >>> f = pytensor.function([x, y], [result["sum"]]) - Or even Equinox modules: + Or Equinox modules: - >>> x = tensor("x", shape=(3,)) - >>> y = tensor("y", shape=(3,)) - >>> mlp = nn.MLP(3, 3, 3, depth=2, activation=jnp.tanh, key=jax.random.key(0)) + >>> x = pt.tensor("x", shape=(3,)) + >>> y = pt.tensor("y", shape=(3,)) + >>> import equinox as eqx + >>> mlp = eqx.nn.MLP(3, 3, 3, depth=2, activation=jnp.tanh, key=jax.random.key(0)) >>> mlp = eqx.tree_at(lambda m: m.layers[0].bias, mlp, y) >>> @as_jax_op - >>> def f(x, mlp): - >>> return mlp(x) - >>> out = f(x, mlp) + ... def neural_network(x, mlp): + ... return mlp(x) + >>> out = neural_network(x, mlp) Notes ----- @@ -229,8 +250,8 @@ def as_jax_op(jaxfunc=None, *, allow_eval=True): available at `pymc-labs.io `__. 
- To accept functions and non pytensor variables as input, the function make use - of :func:`equinox.partition` and :func:`equinox.combine` to split and combine the + To accept functions and non-PyTensor variables as input, the function uses + :func:`equinox.partition` and :func:`equinox.combine` to split and combine the variables. Shapes are inferred using :func:`pytensor.compile.builders.infer_shape` and :func:`jax.eval_shape`. @@ -249,95 +270,98 @@ def decorator(func): @wraps(func) def wrapper(*args, **kwargs): - # Partition inputs into dynamic pytensor variables, wrapped functions and - # static variables. - # Static variables don't take part in the graph. - - input_vars, static_input = eqx.partition( + # Partition inputs into dynamic PyTensor variables and static variables. + # Static variables don't participate in the computational graph. + pytensor_variables, static_values = eqx.partition( (args, kwargs), lambda x: isinstance(x, pt.Variable) ) - # Flatten the input dictionary. - input_flat, input_treedef = jax.tree.flatten( - input_vars, - ) - input_types = [var.type for var in input_flat] + # Flatten the PyTensor variables for processing + variables_flat, variables_treedef = jax.tree.flatten(pytensor_variables) + input_types = [var.type for var in variables_flat] - # We need to figure out static shapes so that we can figure - # out the output types. + # Determine output types by analyzing the function structure output_types, output_treedef, output_static = _find_output_types( - func, input_flat, input_treedef, static_input, allow_eval=allow_eval + func, + variables_flat, + variables_treedef, + static_values, + allow_eval=allow_eval, ) - def flat_func(*flat_vars): - vars = jax.tree.unflatten(input_treedef, flat_vars) - args, kwargs = eqx.combine( - vars, - static_input, + def flattened_function(*flat_variables): + """Execute the original function with flattened inputs.""" + variables = jax.tree.unflatten(variables_treedef, flat_variables) + reconstructed_args, reconstructed_kwargs = eqx.combine( + variables, static_values ) - outputs = func(*args, **kwargs) - output_vals, _ = eqx.partition(outputs, eqx.is_array) - outputs_flat, _ = jax.tree.flatten(output_vals) - return outputs_flat + function_outputs = func(*reconstructed_args, **reconstructed_kwargs) + array_outputs, _ = eqx.partition(function_outputs, eqx.is_array) + flattened_outputs, _ = jax.tree.flatten(array_outputs) + return flattened_outputs - op_instance = JAXOp( + # Create the JAX operation + jax_op_instance = JAXOp( input_types, output_types, - flat_func, + flattened_function, name=name, ) - # 8. Execute the op and unflatten the outputs. 
- output_flat = op_instance(*input_flat) - if not isinstance(output_flat, Sequence): - output_flat = [output_flat] - outvars = jax.tree.unflatten(output_treedef, output_flat) - outvars = eqx.combine(outvars, output_static) + # Execute the operation and reconstruct the output structure + flattened_results = jax_op_instance(*variables_flat) + if not isinstance(flattened_results, Sequence): + flattened_results = [flattened_results] + + output_variables = jax.tree.unflatten(output_treedef, flattened_results) + final_outputs = eqx.combine(output_variables, output_static) - return outvars + return final_outputs return wrapper - if jaxfunc is None: + if jax_function is None: return decorator else: - return decorator(jaxfunc) + return decorator(jax_function) def _find_output_types( - jaxfunc, inputs_flat, input_treedef, static_input, *, allow_eval=True + jax_function, inputs_flat, input_treedef, static_input, *, allow_eval=True ): + """Determine output types by analyzing the JAX function structure.""" import equinox as eqx import jax import jax.numpy as jnp resolved_input_shapes = [] - needs_eval = False - for var in inputs_flat: + requires_shape_evaluation = False + + for variable in inputs_flat: # If shape is already fully determined, use it directly - if not any(s is None for s in var.type.shape): - resolved_input_shapes.append(var.type.shape) + if not any(dimension is None for dimension in variable.type.shape): + resolved_input_shapes.append(variable.type.shape) continue # Try to infer static shape - _, shape = pt.basic.infer_static_shape(var.shape) - if not any(s is None for s in shape): - resolved_input_shapes.append(shape) + _, inferred_shape = pt.basic.infer_static_shape(variable.shape) + if not any(dimension is None for dimension in inferred_shape): + resolved_input_shapes.append(inferred_shape) continue # Shape still has undetermined dimensions if not allow_eval: raise ValueError( - f"Input variable {var} has a shape with undetermined " - "shape. Please provide inputs with fully determined shapes " - "by calling pt.specify_shape." + f"Input variable {variable} has undetermined shape dimensions. " + "Please provide inputs with fully determined shapes by calling " + "pt.specify_shape." ) - needs_eval = True - resolved_input_shapes.append(var.shape) + requires_shape_evaluation = True + resolved_input_shapes.append(variable.shape) - if needs_eval: + if requires_shape_evaluation: try: - shape_fn = function( + shape_evaluation_function = function( [], resolved_input_shapes, on_unused_input="ignore", @@ -349,34 +373,36 @@ def _find_output_types( "Please provide inputs with fully determined shapes by " "calling pt.specify_shape." ) from e - resolved_input_shapes = shape_fn() + resolved_input_shapes = shape_evaluation_function() - # Figure out output types using jax.eval_shape. 
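Since the output types are derived without ever running the wrapped function, it may help to recall what `jax.eval_shape` does: it traces the function abstractly and returns `ShapeDtypeStruct` objects instead of arrays. A standalone sketch; the function and shapes below are made up for illustration.

    import jax
    import jax.numpy as jnp

    def f(x):
        return x @ x.T, jnp.sum(x)

    abstract_out = jax.eval_shape(f, jax.ShapeDtypeStruct((4, 3), jnp.float32))
    # -> (ShapeDtypeStruct(shape=(4, 4), dtype=float32),
    #     ShapeDtypeStruct(shape=(), dtype=float32))
    # No actual computation is performed, only shape/dtype propagation.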
- extra_output_storage = {} + # Determine output types using jax.eval_shape with dummy inputs + output_metadata_storage = {} - dummy_inputs = [ - jnp.ones(shape, dtype=var.type.dtype) - for var, shape in zip(inputs_flat, resolved_input_shapes, strict=True) + dummy_input_arrays = [ + jnp.ones(shape, dtype=variable.type.dtype) + for variable, shape in zip(inputs_flat, resolved_input_shapes, strict=True) ] - def wrap_jaxfunc(args): - vars = jax.tree.unflatten(input_treedef, args) - args, kwargs = eqx.combine( - vars, - static_input, - ) - outputs = jaxfunc(*args, **kwargs) - output_vals, output_static = eqx.partition(outputs, eqx.is_array) - extra_output_storage["output_static"] = output_static - outputs_flat, output_treedef = jax.tree.flatten(output_vals) - extra_output_storage["output_treedef"] = output_treedef - return outputs_flat - - output_shapes_flat = jax.eval_shape(wrap_jaxfunc, dummy_inputs) - output_treedef = extra_output_storage["output_treedef"] - output_static = extra_output_storage["output_static"] + def wrapped_jax_function(input_arrays): + """Wrapper to extract output metadata during shape evaluation.""" + variables = jax.tree.unflatten(input_treedef, input_arrays) + reconstructed_args, reconstructed_kwargs = eqx.combine(variables, static_input) + function_outputs = jax_function(*reconstructed_args, **reconstructed_kwargs) + array_outputs, static_outputs = eqx.partition(function_outputs, eqx.is_array) + + # Store metadata for later use + output_metadata_storage["output_static"] = static_outputs + flattened_outputs, output_structure = jax.tree.flatten(array_outputs) + output_metadata_storage["output_treedef"] = output_structure + return flattened_outputs + + output_shapes_flat = jax.eval_shape(wrapped_jax_function, dummy_input_arrays) + output_treedef = output_metadata_storage["output_treedef"] + output_static = output_metadata_storage["output_static"] + output_types = [ - pt.TensorType(dtype=var.dtype, shape=var.shape) for var in output_shapes_flat + pt.TensorType(dtype=output_shape.dtype, shape=output_shape.shape) + for output_shape in output_shapes_flat ] return output_types, output_treedef, output_static From fcff09b1a4971bcd0102e0ab389924bec0415d99 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 16 Sep 2025 20:40:04 +0200 Subject: [PATCH 19/30] set output shape to None if not statically known in as_jax_op --- pytensor/__init__.py | 2 +- pytensor/link/jax/ops.py | 21 +++++++++++++++------ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/pytensor/__init__.py b/pytensor/__init__.py index 924c31225b..0714ac6d54 100644 --- a/pytensor/__init__.py +++ b/pytensor/__init__.py @@ -173,7 +173,7 @@ def get_underlying_scalar_constant(v): except ImportError as e: import_error_as_jax_op = e - def as_jax_op(jaxfunc): + def as_jax_op(jax_function=None, allow_eval=True): raise ImportError( "JAX and/or equinox are not installed. 
Install them" " to use this function: pip install pytensor[jax]" diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index fce3975e24..d5d96479b4 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -280,7 +280,7 @@ def wrapper(*args, **kwargs): variables_flat, variables_treedef = jax.tree.flatten(pytensor_variables) input_types = [var.type for var in variables_flat] - # Determine output types by analyzing the function structure + # Determine output types by calling the function through jax.eval_shape output_types, output_treedef, output_static = _find_output_types( func, variables_flat, @@ -329,7 +329,7 @@ def flattened_function(*flat_variables): def _find_output_types( jax_function, inputs_flat, input_treedef, static_input, *, allow_eval=True ): - """Determine output types by analyzing the JAX function structure.""" + """Determine output types with jax.eval_shape on dummy inputs.""" import equinox as eqx import jax import jax.numpy as jnp @@ -400,9 +400,18 @@ def wrapped_jax_function(input_arrays): output_treedef = output_metadata_storage["output_treedef"] output_static = output_metadata_storage["output_static"] - output_types = [ - pt.TensorType(dtype=output_shape.dtype, shape=output_shape.shape) - for output_shape in output_shapes_flat - ] + # If we used shape evaluation, set all output shapes to unknown + if requires_shape_evaluation: + output_types = [ + pt.TensorType( + dtype=output_shape.dtype, shape=tuple(None for _ in output_shape.shape) + ) + for output_shape in output_shapes_flat + ] + else: + output_types = [ + pt.TensorType(dtype=output_shape.dtype, shape=output_shape.shape) + for output_shape in output_shapes_flat + ] return output_types, output_treedef, output_static From 866b4ba8a77d5290bd25999a83c05fa795c6c0bb Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 17 Sep 2025 00:14:03 +0200 Subject: [PATCH 20/30] clean up global import --- pytensor/__init__.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/pytensor/__init__.py b/pytensor/__init__.py index 0714ac6d54..ee1dc6bab5 100644 --- a/pytensor/__init__.py +++ b/pytensor/__init__.py @@ -166,19 +166,7 @@ def get_underlying_scalar_constant(v): from pytensor.scan.basic import scan from pytensor.scan.views import foldl, foldr, map, reduce from pytensor.compile.builders import OpFromGraph - -try: - import pytensor.link.jax.ops - from pytensor.link.jax.ops import as_jax_op -except ImportError as e: - import_error_as_jax_op = e - - def as_jax_op(jax_function=None, allow_eval=True): - raise ImportError( - "JAX and/or equinox are not installed. 
Install them" - " to use this function: pip install pytensor[jax]" - ) from import_error_as_jax_op - +from pytensor.link.jax.ops import as_jax_op # isort: on From 0ae53a0b0f03dae146e489b59f572cd88770bd0d Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 17 Sep 2025 00:14:03 +0200 Subject: [PATCH 21/30] remove name from JaxOp.__props__ --- pytensor/link/jax/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index d5d96479b4..dd37798062 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -79,7 +79,7 @@ class JAXOp(Op): [array([1., 1.], dtype=float32)] """ - __props__ = ("input_types", "output_types", "jax_func", "name") + __props__ = ("input_types", "output_types", "jax_func") def __init__(self, input_types, output_types, jax_function, name=None): import jax From 07a2c431adb73003b4538ca3784aa80d0dafeb9b Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 17 Sep 2025 00:14:03 +0200 Subject: [PATCH 22/30] don't compile in shape eval of as_jax_op --- pytensor/link/jax/ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index dd37798062..3d6a2e965a 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -8,6 +8,7 @@ import pytensor.tensor as pt from pytensor.compile.function import function +from pytensor.compile.mode import Mode from pytensor.gradient import DisconnectedType from pytensor.graph import Apply, Op, Variable @@ -365,7 +366,7 @@ def _find_output_types( [], resolved_input_shapes, on_unused_input="ignore", - mode="FAST_COMPILE", + mode=Mode(linker="py", optimizer="fast_compile"), ) except Exception as e: raise ValueError( From f83212623c5d61c4c7561e9778df1caea45c1e6b Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 17 Sep 2025 00:14:03 +0200 Subject: [PATCH 23/30] more tests for as_jax_op --- tests/link/jax/test_as_jax_op.py | 38 ++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py index b6998d4ab2..f1f89be2a6 100644 --- a/tests/link/jax/test_as_jax_op.py +++ b/tests/link/jax/test_as_jax_op.py @@ -3,6 +3,7 @@ import pytensor.tensor as pt from pytensor import as_jax_op, config, grad +from pytensor.compile.sharedvalue import shared from pytensor.link.jax.ops import JAXOp from pytensor.scalar import all_types from pytensor.tensor import TensorType, tensor @@ -390,6 +391,43 @@ def f(x, mlp): compare_jax_and_py([x, y], [out, *grad_out], test_values) +def test_no_inputs(): + def f(): + return jax.numpy.array(42.0) + + out = as_jax_op(f)() + assert out.eval() == 42.0 + + +def test_unknown_shape(): + x = tensor("x", shape=(None,)) + + def f(x): + return x * 2 + + with pytest.raises(ValueError, match="Please provide inputs"): + as_jax_op(f)(x) + + +def test_unknown_shape_with_eval(): + x = shared(np.ones(3)) + assert x.type.shape == (None,) + + def f(x): + return x * 2 + + out = as_jax_op(f)(x) + grad_out = grad(pt.sum(out), [x]) + + compare_jax_and_py([], [out, *grad_out], []) + + with jax.disable_jit(): + compare_jax_and_py([], [out, *grad_out], [], must_be_device_array=False) + + with pytest.raises(ValueError, match="Please provide inputs"): + as_jax_op(f, allow_eval=False)(x) + + class TestDtypes: @pytest.mark.parametrize("in_dtype", list(map(str, all_types))) @pytest.mark.parametrize("out_dtype", list(map(str, all_types))) From bf2d0b3725795fe21cef68ae0eb07005503f6760 Mon Sep 17 00:00:00 2001 
From: Adrian Seyboldt Date: Wed, 24 Sep 2025 13:24:15 +0200 Subject: [PATCH 24/30] changes based on review --- pytensor/link/jax/ops.py | 102 ++++++++++++++++++++------------------- 1 file changed, 53 insertions(+), 49 deletions(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 3d6a2e965a..9a7dd5b5b9 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -1,19 +1,16 @@ """Convert a jax function to a pytensor compatible function.""" -import logging from collections.abc import Sequence from functools import wraps import numpy as np -import pytensor.tensor as pt from pytensor.compile.function import function from pytensor.compile.mode import Mode from pytensor.gradient import DisconnectedType from pytensor.graph import Apply, Op, Variable - - -log = logging.getLogger(__name__) +from pytensor.tensor.basic import infer_static_shape +from pytensor.tensor.type import TensorType class JAXOp(Op): @@ -28,7 +25,8 @@ class JAXOp(Op): output_types : list A list of PyTensor types for each output variable. jax_function : callable - The JAX function that computes outputs from inputs. + The JAX function that computes outputs from inputs. It should + always return a tuple of outputs, even if there is only one output. name : str, optional A custom name for the Op instance. If provided, the class name will be updated accordingly. @@ -48,7 +46,7 @@ class JAXOp(Op): >>> >>> # Create the input and output types, input has a dynamic shape. >>> input_type = TensorType("float32", shape=(None,)) - >>> output_type = TensorType("float32", shape=(1,)) + >>> output_type = TensorType("float32", shape=()) >>> >>> # Instantiate a JAXOp >>> op = JAXOp( @@ -94,32 +92,45 @@ def __init__(self, input_types, output_types, jax_function, name=None): def __repr__(self): base = self.__class__.__name__ - if self.name is not None: - base = f"{base}{self.name}" props = list(self.__props__) - props.remove("name") - props = ",".join(f"{prop}={getattr(self, prop, '?')}" for prop in props) + if self.name is not None: + props.insert(0, "name") + props = ", ".join(f"{prop}={getattr(self, prop)}" for prop in props) return f"{base}({props})" def make_node(self, *inputs: Variable) -> Apply: """Create an Apply node with the given inputs and inferred outputs.""" + if len(inputs) != len(self.input_types): + raise ValueError( + f"Op {self} expected {len(self.input_types)} inputs, got {len(inputs)}" + ) + filtered_inputs = [ + inp_type.filter_variable(inp) + for inp, inp_type in zip(inputs, self.input_types) + ] outputs = [output_type() for output_type in self.output_types] - return Apply(self, inputs, outputs) + return Apply(self, filtered_inputs, outputs) def perform(self, node, inputs, outputs): """Execute the JAX function and store results in output storage.""" results = self.jitted_func(*inputs) + if not isinstance(results, tuple): + raise TypeError("JAX function must return a tuple of outputs.") if len(results) != len(outputs): raise ValueError( f"JAX function returned {len(results)} outputs, but " f"{len(outputs)} were expected." 
) - for i, result in enumerate(results): - outputs[i][0] = np.array(result, dtype=self.output_types[i].dtype) + for output_container, result, out_type in zip( + outputs, results, self.output_types + ): + output_container[0] = np.array(result, dtype=out_type.dtype) def perform_jax(self, *inputs): """Execute the JAX function directly, returning JAX arrays.""" outputs = self.jitted_func(*inputs) + if not isinstance(outputs, tuple): + raise TypeError("JAX function must return a tuple of outputs.") if len(outputs) == 1: return outputs[0] return outputs @@ -129,10 +140,11 @@ def grad(self, inputs, output_gradients): import jax # Find indices of outputs that need gradients - connected_output_indices = [] - for i, output_grad in enumerate(output_gradients): - if not isinstance(output_grad.type, DisconnectedType): - connected_output_indices.append(i) + connected_output_indices = [ + i + for i, output_grad in enumerate(output_gradients) + if not isinstance(output_grad.type, DisconnectedType) + ] num_inputs = len(inputs) @@ -163,21 +175,24 @@ def restricted_function(*input_values): ] ) + if self.name is not None: + name = "vjp_" + self.name + else: + name = "vjp_jax_op" + # Create VJP operation vjp_op = JAXOp( self.input_types + tuple(self.output_types[i] for i in connected_output_indices), [self.input_types[i] for i in range(num_inputs)], vjp_operation, - name="VJP" + (self.name if self.name is not None else ""), + name=name, ) - gradient_outputs = vjp_op( - *[*inputs, *[output_gradients[i] for i in connected_output_indices]] + return vjp_op( + *[*inputs, *[output_gradients[i] for i in connected_output_indices]], + return_list=True, ) - if not isinstance(gradient_outputs, Sequence): - gradient_outputs = [gradient_outputs] - return gradient_outputs def as_jax_op(jax_function=None, *, allow_eval=True): @@ -235,34 +250,23 @@ def as_jax_op(jax_function=None, *, allow_eval=True): Or Equinox modules: - >>> x = pt.tensor("x", shape=(3,)) - >>> y = pt.tensor("y", shape=(3,)) - >>> import equinox as eqx - >>> mlp = eqx.nn.MLP(3, 3, 3, depth=2, activation=jnp.tanh, key=jax.random.key(0)) - >>> mlp = eqx.tree_at(lambda m: m.layers[0].bias, mlp, y) - >>> @as_jax_op - ... def neural_network(x, mlp): - ... return mlp(x) - >>> out = neural_network(x, mlp) - - Notes - ----- - The function is based on a blog post by Ricardo Vieira and Adrian Seyboldt, - available at - `pymc-labs.io `__. - To accept functions and non-PyTensor variables as input, the function uses - :func:`equinox.partition` and :func:`equinox.combine` to split and combine the - variables. Shapes are inferred using - :func:`pytensor.compile.builders.infer_shape` and :func:`jax.eval_shape`. - + >>> x = pt.tensor("x", shape=(3,)) # doctest +SKIP + >>> y = pt.tensor("y", shape=(3,)) # doctest +SKIP + >>> import equinox as eqx # doctest +SKIP + >>> mlp = eqx.nn.MLP( + ... 3, 3, 3, depth=2, activation=jnp.tanh, key=jax.random.key(0) + ... ) # doctest +SKIP + >>> mlp = eqx.tree_at(lambda m: m.layers[0].bias, mlp, y) # doctest +SKIP + >>> @as_jax_op # doctest +SKIP + ... def neural_network(x, mlp): # doctest +SKIP + ... 
return mlp(x) # doctest +SKIP + >>> out = neural_network(x, mlp) # doctest +SKIP """ def decorator(func): name = func.__name__ try: - import equinox as eqx import jax except ImportError as e: raise ImportError( @@ -345,7 +349,7 @@ def _find_output_types( continue # Try to infer static shape - _, inferred_shape = pt.basic.infer_static_shape(variable.shape) + _, inferred_shape = infer_static_shape(variable.shape) if not any(dimension is None for dimension in inferred_shape): resolved_input_shapes.append(inferred_shape) continue @@ -404,14 +408,14 @@ def wrapped_jax_function(input_arrays): # If we used shape evaluation, set all output shapes to unknown if requires_shape_evaluation: output_types = [ - pt.TensorType( + TensorType( dtype=output_shape.dtype, shape=tuple(None for _ in output_shape.shape) ) for output_shape in output_shapes_flat ] else: output_types = [ - pt.TensorType(dtype=output_shape.dtype, shape=output_shape.shape) + TensorType(dtype=output_shape.dtype, shape=output_shape.shape) for output_shape in output_shapes_flat ] From 1e93af99334b01265ae096f6158ec262c61b1b90 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 24 Sep 2025 13:24:15 +0200 Subject: [PATCH 25/30] rename as_jax_op to wrap_jax --- doc/library/index.rst | 5 +- pytensor/__init__.py | 2 +- pytensor/link/jax/ops.py | 10 +- .../{test_as_jax_op.py => test_wrap_jax.py} | 171 ++++++++++++------ 4 files changed, 119 insertions(+), 69 deletions(-) rename tests/link/jax/{test_as_jax_op.py => test_wrap_jax.py} (80%) diff --git a/doc/library/index.rst b/doc/library/index.rst index 70506f6120..63cf7572a6 100644 --- a/doc/library/index.rst +++ b/doc/library/index.rst @@ -64,9 +64,9 @@ Convert to Variable Wrap JAX functions ================== -.. autofunction:: as_jax_op(...) +.. autofunction:: wrap_jax(...) - Alias for :func:`pytensor.link.jax.ops.as_jax_op` + Alias for :func:`pytensor.link.jax.ops.wrap_jax` Debug ===== @@ -74,4 +74,3 @@ Debug .. autofunction:: pytensor.dprint(...) Alias for :func:`pytensor.printing.debugprint` - diff --git a/pytensor/__init__.py b/pytensor/__init__.py index ee1dc6bab5..12f67c9a37 100644 --- a/pytensor/__init__.py +++ b/pytensor/__init__.py @@ -166,7 +166,7 @@ def get_underlying_scalar_constant(v): from pytensor.scan.basic import scan from pytensor.scan.views import foldl, foldr, map, reduce from pytensor.compile.builders import OpFromGraph -from pytensor.link.jax.ops import as_jax_op +from pytensor.link.jax.ops import wrap_jax # isort: on diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index 9a7dd5b5b9..a895cf52bc 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -195,7 +195,7 @@ def restricted_function(*input_values): ) -def as_jax_op(jax_function=None, *, allow_eval=True): +def wrap_jax(jax_function=None, *, allow_eval=True): """Return a PyTensor-compatible function from a JAX jittable function. This decorator wraps a JAX function so that it accepts and returns @@ -223,7 +223,7 @@ def as_jax_op(jax_function=None, *, allow_eval=True): >>> import jax.numpy as jnp >>> import pytensor.tensor as pt - >>> @as_jax_op + >>> @wrap_jax ... def add(x, y): ... return jnp.add(x, y) >>> x = pt.scalar("x") @@ -238,7 +238,7 @@ def as_jax_op(jax_function=None, *, allow_eval=True): >>> import jax >>> import jax.numpy as jnp >>> import pytensor.tensor as pt - >>> @as_jax_op + >>> @wrap_jax ... def complex_function(x, y, scale=1.0): ... return { ... "sum": jnp.add(x, y) * scale, @@ -257,7 +257,7 @@ def as_jax_op(jax_function=None, *, allow_eval=True): ... 
3, 3, 3, depth=2, activation=jnp.tanh, key=jax.random.key(0) ... ) # doctest +SKIP >>> mlp = eqx.tree_at(lambda m: m.layers[0].bias, mlp, y) # doctest +SKIP - >>> @as_jax_op # doctest +SKIP + >>> @wrap_jax # doctest +SKIP ... def neural_network(x, mlp): # doctest +SKIP ... return mlp(x) # doctest +SKIP >>> out = neural_network(x, mlp) # doctest +SKIP @@ -270,7 +270,7 @@ def decorator(func): import jax except ImportError as e: raise ImportError( - "The as_jax_op decorator requires both jax and equinox to be installed." + "The wrap_jax decorator requires both jax and equinox to be installed." ) from e @wraps(func) diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_wrap_jax.py similarity index 80% rename from tests/link/jax/test_as_jax_op.py rename to tests/link/jax/test_wrap_jax.py index f1f89be2a6..2052b5f4db 100644 --- a/tests/link/jax/test_as_jax_op.py +++ b/tests/link/jax/test_wrap_jax.py @@ -1,8 +1,7 @@ import numpy as np import pytest -import pytensor.tensor as pt -from pytensor import as_jax_op, config, grad +from pytensor import config, grad, wrap_jax from pytensor.compile.sharedvalue import shared from pytensor.link.jax.ops import JAXOp from pytensor.scalar import all_types @@ -24,16 +23,16 @@ def test_two_inputs_single_output(): def f(x, y): return jax.nn.sigmoid(x + y) - # Test with as_jax_op decorator - out = as_jax_op(f)(x, y) - grad_out = grad(pt.sum(out), [x, y]) + # Test with wrap_jax decorator + out = wrap_jax(f)(x, y) + grad_out = grad(out.sum(), [x, y]) compare_jax_and_py([x, y], [out, *grad_out], test_values) with jax.disable_jit(): compare_jax_and_py([x, y], [out, *grad_out], test_values) def f(x, y): - return [jax.nn.sigmoid(x + y)] + return (jax.nn.sigmoid(x + y),) # Test direct JAXOp usage jax_op = JAXOp( @@ -42,10 +41,32 @@ def f(x, y): f, ) out = jax_op(x, y) - grad_out = grad(pt.sum(out), [x, y]) + grad_out = grad(out.sum(), [x, y]) compare_jax_and_py([x, y], [out, *grad_out], test_values) +def test_op_returns_list(): + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) + + test_values = [np.ones((2,)).astype(config.floatX) for inp in (x, y)] + + def f(x, y): + return jax.nn.sigmoid(x + y) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,))], + f, + ) + + with pytest.raises(TypeError, match="tuple of outputs"): + out = jax_op(x, y) + grad_out = grad(out.sum(), [x, y]) + compare_jax_and_py([x, y], [out, *grad_out], test_values) + + def test_two_inputs_tuple_output(): rng = np.random.default_rng(2) x = tensor("x", shape=(2,)) @@ -57,9 +78,9 @@ def test_two_inputs_tuple_output(): def f(x, y): return jax.nn.sigmoid(x + y), y * 2 - # Test with as_jax_op decorator - out1, out2 = as_jax_op(f)(x, y) - grad_out = grad(pt.sum(out1 + out2), [x, y]) + # Test with wrap_jax decorator + out1, out2 = wrap_jax(f)(x, y) + grad_out = grad((out1 + out2).sum(), [x, y]) compare_jax_and_py([x, y], [out1, out2, *grad_out], test_values) with jax.disable_jit(): @@ -76,7 +97,7 @@ def f(x, y): f, ) out1, out2 = jax_op(x, y) - grad_out = grad(pt.sum(out1 + out2), [x, y]) + grad_out = grad((out1 + out2).sum(), [x, y]) compare_jax_and_py([x, y], [out1, out2, *grad_out], test_values) @@ -90,11 +111,11 @@ def test_two_inputs_list_output_one_unused_output(): ] def f(x, y): - return [jax.nn.sigmoid(x + y), y * 2] + return (jax.nn.sigmoid(x + y), y * 2) - # Test with as_jax_op decorator - out, _ = as_jax_op(f)(x, y) - grad_out = grad(pt.sum(out), [x, y]) + # Test with wrap_jax decorator + out, _ = wrap_jax(f)(x, y) + 
grad_out = grad(out.sum(), [x, y]) compare_jax_and_py([x, y], [out, *grad_out], test_values) with jax.disable_jit(): @@ -107,7 +128,7 @@ def f(x, y): f, ) out, _ = jax_op(x, y) - grad_out = grad(pt.sum(out), [x, y]) + grad_out = grad(out.sum(), [x, y]) compare_jax_and_py([x, y], [out, *grad_out], test_values) @@ -119,9 +140,9 @@ def test_single_input_tuple_output(): def f(x): return jax.nn.sigmoid(x), x * 2 - # Test with as_jax_op decorator - out1, out2 = as_jax_op(f)(x) - grad_out = grad(pt.sum(out1), [x]) + # Test with wrap_jax decorator + out1, out2 = wrap_jax(f)(x) + grad_out = grad(out1.sum(), [x]) compare_jax_and_py([x], [out1, out2, *grad_out], test_values) with jax.disable_jit(): @@ -136,7 +157,7 @@ def f(x): f, ) out1, out2 = jax_op(x) - grad_out = grad(pt.sum(out1), [x]) + grad_out = grad(out1.sum(), [x]) compare_jax_and_py([x], [out1, out2, *grad_out], test_values) @@ -148,9 +169,9 @@ def test_scalar_input_tuple_output(): def f(x): return jax.nn.sigmoid(x), x - # Test with as_jax_op decorator - out1, out2 = as_jax_op(f)(x) - grad_out = grad(pt.sum(out1), [x]) + # Test with wrap_jax decorator + out1, out2 = wrap_jax(f)(x) + grad_out = grad(out1.sum(), [x]) compare_jax_and_py([x], [out1, out2, *grad_out], test_values) with jax.disable_jit(): @@ -165,7 +186,7 @@ def f(x): f, ) out1, out2 = jax_op(x) - grad_out = grad(pt.sum(out1), [x]) + grad_out = grad(out1.sum(), [x]) compare_jax_and_py([x], [out1, out2, *grad_out], test_values) @@ -175,11 +196,11 @@ def test_single_input_list_output(): test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] def f(x): - return [jax.nn.sigmoid(x), 2 * x] + return (jax.nn.sigmoid(x), 2 * x) - # Test with as_jax_op decorator - out1, out2 = as_jax_op(f)(x) - grad_out = grad(pt.sum(out1), [x]) + # Test with wrap_jax decorator + out1, out2 = wrap_jax(f)(x) + grad_out = grad(out1.sum(), [x]) compare_jax_and_py([x], [out1, out2, *grad_out], test_values) with jax.disable_jit(): @@ -197,7 +218,7 @@ def f(x): f, ) out1, out2 = jax_op(x) - grad_out = grad(pt.sum(out1), [x]) + grad_out = grad(out1.sum(), [x]) compare_jax_and_py([x], [out1, out2, *grad_out], test_values) @@ -210,13 +231,13 @@ def test_pytree_input_tuple_output(): rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] - @as_jax_op + @wrap_jax def f(x, y): return jax.nn.sigmoid(x), 2 * x + y["y"] + y["y2"][0] - # Test with as_jax_op decorator + # Test with wrap_jax decorator out = f(x, y_tmp) - grad_out = grad(pt.sum(out[1]), [x, y]) + grad_out = grad(out[1].sum(), [x, y]) compare_jax_and_py([x, y], [out[0], out[1], *grad_out], test_values) @@ -235,13 +256,13 @@ def test_pytree_input_pytree_output(): rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] - @as_jax_op + @wrap_jax def f(x, y): return x, jax.tree_util.tree_map(lambda x: jax.numpy.exp(x), y) - # Test with as_jax_op decorator + # Test with wrap_jax decorator out = f(x, y_tmp) - grad_out = grad(pt.sum(out[1]["b"][0]), [x, y]) + grad_out = grad(out[1]["b"][0].sum(), [x, y]) compare_jax_and_py([x, y], [out[0], out[1]["a"], *grad_out], test_values) @@ -263,7 +284,7 @@ def test_pytree_input_with_non_graph_args(): rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] - @as_jax_op + @wrap_jax def f(x, y, depth, which_variable): if which_variable == "x": var = x @@ -275,16 +296,16 @@ def f(x, y, depth, which_variable): var = jax.nn.sigmoid(var) return var - # Test with as_jax_op decorator + # Test with wrap_jax decorator # arguments depth and which_variable are not 
part of the graph out = f(x, y_tmp, depth=3, which_variable="x") - grad_out = grad(pt.sum(out), [x]) + grad_out = grad(out.sum(), [x]) compare_jax_and_py([x, y], [out[0], *grad_out], test_values) with jax.disable_jit(): compare_jax_and_py([x, y], [out[0], *grad_out], test_values) out = f(x, y_tmp, depth=7, which_variable="y") - grad_out = grad(pt.sum(out), [x]) + grad_out = grad(out.sum(), [x]) compare_jax_and_py([x, y], [out[0], *grad_out], test_values) with jax.disable_jit(): compare_jax_and_py([x, y], [out[0], *grad_out], test_values) @@ -307,9 +328,9 @@ def test_unused_matrix_product(): def f(x, y): return x[:, None] @ y[None], jax.numpy.exp(x) - # Test with as_jax_op decorator - out = as_jax_op(f)(x, y) - grad_out = grad(pt.sum(out[1]), [x]) + # Test with wrap_jax decorator + out = wrap_jax(f)(x, y) + grad_out = grad(out[1].sum(), [x]) compare_jax_and_py([x, y], [out[1], *grad_out], test_values) @@ -326,7 +347,7 @@ def f(x, y): f, ) out = jax_op(x, y) - grad_out = grad(pt.sum(out[1]), [x]) + grad_out = grad(out[1].sum(), [x]) compare_jax_and_py([x, y], [out[1], *grad_out], test_values) @@ -338,13 +359,13 @@ def test_unknown_static_shape(): rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) ] - x_cumsum = pt.cumsum(x) # Now x_cumsum has an unknown shape + x_cumsum = x.cumsum() # Now x_cumsum has an unknown shape def f(x, y): - return [x * jax.numpy.ones(3)] + return (x * jax.numpy.ones(3),) - (out,) = as_jax_op(f)(x_cumsum, y) - grad_out = grad(pt.sum(out), [x]) + (out,) = wrap_jax(f)(x_cumsum, y) + grad_out = grad(out.sum(), [x]) compare_jax_and_py([x, y], [out, *grad_out], test_values) @@ -358,13 +379,13 @@ def f(x, y): f, ) out = jax_op(x_cumsum, y) - grad_out = grad(pt.sum(out), [x]) + grad_out = grad(out.sum(), [x]) compare_jax_and_py([x, y], [out, *grad_out], test_values) def test_nn(): - import equinox as eqx - import equinox.nn as nn + eqx = pytest.importorskip("equinox") + nn = pytest.importorskip("equinox.nn") rng = np.random.default_rng(13) x = tensor("x", shape=(3,)) @@ -378,12 +399,12 @@ def test_nn(): mlp = nn.MLP(3, 3, 3, depth=2, activation=jax.numpy.tanh, key=jax.random.key(0)) mlp = eqx.tree_at(lambda m: m.layers[0].bias, mlp, y) - @as_jax_op + @wrap_jax def f(x, mlp): return mlp(x) out = f(x, mlp) - grad_out = grad(pt.sum(out), [x]) + grad_out = grad(out.sum(), [x]) compare_jax_and_py([x, y], [out, *grad_out], test_values) @@ -395,7 +416,7 @@ def test_no_inputs(): def f(): return jax.numpy.array(42.0) - out = as_jax_op(f)() + out = wrap_jax(f)() assert out.eval() == 42.0 @@ -406,7 +427,7 @@ def f(x): return x * 2 with pytest.raises(ValueError, match="Please provide inputs"): - as_jax_op(f)(x) + wrap_jax(f)(x) def test_unknown_shape_with_eval(): @@ -416,8 +437,8 @@ def test_unknown_shape_with_eval(): def f(x): return x * 2 - out = as_jax_op(f)(x) - grad_out = grad(pt.sum(out), [x]) + out = wrap_jax(f)(x) + grad_out = grad(out.sum(), [x]) compare_jax_and_py([], [out, *grad_out], []) @@ -425,7 +446,37 @@ def f(x): compare_jax_and_py([], [out, *grad_out], [], must_be_device_array=False) with pytest.raises(ValueError, match="Please provide inputs"): - as_jax_op(f, allow_eval=False)(x) + wrap_jax(f, allow_eval=False)(x) + + +def test_decorator_forms(): + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) + + @wrap_jax + def the_name1(x, y): + return (x + y).sum() + + @wrap_jax(allow_eval=True) + def the_name2(x, y): + return (x + y).sum() + + the_name1(x, y) + the_name2(x, y) + + +def test_repr(): + x = tensor("x", shape=(3,)) + y = 
tensor("y", shape=(3,)) + + def the_name(x, y): + return (x + y).sum() + + jax_op = wrap_jax(the_name) + assert "the_name" in repr(jax_op(x, y).owner.op) + + (grad_x, _) = grad(jax_op(x, y), [x, y]) + assert "vjp_the_name" in repr(grad_x.owner.op) class TestDtypes: @@ -446,7 +497,7 @@ def test_different_in_output(self, in_dtype, out_dtype): for inp in (x, y) ] - @as_jax_op + @wrap_jax def f(x, y): out = jax.numpy.add(x, y) return jax.numpy.real(out).astype(out_dtype) @@ -482,7 +533,7 @@ def test_test_different_inputs(self, in1_dtype, in2_dtype): else: test_values.append(np.random.normal(size=(3,)).astype(y.type.dtype)) - @as_jax_op + @wrap_jax def f(x, y): out = jax.numpy.add(x, y) return jax.numpy.real(out).astype(in1_dtype) @@ -501,7 +552,7 @@ def f(x, y): inputs = [x, y] outputs = [out] - fn, _ = compare_jax_and_py(inputs, outputs, test_values) + compare_jax_and_py(inputs, outputs, test_values) with jax.disable_jit(): if "float" in in1_dtype and "float" in in2_dtype: From 10d2097e908404ff19d17ae0ca925eb5ac364423 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 24 Sep 2025 13:24:15 +0200 Subject: [PATCH 26/30] remove equinox dependency --- pytensor/link/jax/ops.py | 118 +++++++++++++++++++++++++++++++++++---- 1 file changed, 108 insertions(+), 10 deletions(-) diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index a895cf52bc..dde60f8e57 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -270,15 +270,15 @@ def decorator(func): import jax except ImportError as e: raise ImportError( - "The wrap_jax decorator requires both jax and equinox to be installed." + "The wrap_jax decorator requires jax to be installed." ) from e @wraps(func) def wrapper(*args, **kwargs): # Partition inputs into dynamic PyTensor variables and static variables. # Static variables don't participate in the computational graph. 
- pytensor_variables, static_values = eqx.partition( - (args, kwargs), lambda x: isinstance(x, pt.Variable) + pytensor_variables, static_values = _eqx_partition( + (args, kwargs), lambda x: isinstance(x, Variable) ) # Flatten the PyTensor variables for processing @@ -297,13 +297,13 @@ def wrapper(*args, **kwargs): def flattened_function(*flat_variables): """Execute the original function with flattened inputs.""" variables = jax.tree.unflatten(variables_treedef, flat_variables) - reconstructed_args, reconstructed_kwargs = eqx.combine( + reconstructed_args, reconstructed_kwargs = _eqx_combine( variables, static_values ) function_outputs = func(*reconstructed_args, **reconstructed_kwargs) - array_outputs, _ = eqx.partition(function_outputs, eqx.is_array) + array_outputs, _ = _eqx_partition(function_outputs, _is_array) flattened_outputs, _ = jax.tree.flatten(array_outputs) - return flattened_outputs + return tuple(flattened_outputs) # Create the JAX operation jax_op_instance = JAXOp( @@ -319,7 +319,7 @@ def flattened_function(*flat_variables): flattened_results = [flattened_results] output_variables = jax.tree.unflatten(output_treedef, flattened_results) - final_outputs = eqx.combine(output_variables, output_static) + final_outputs = _eqx_combine(output_variables, output_static) return final_outputs @@ -335,7 +335,6 @@ def _find_output_types( jax_function, inputs_flat, input_treedef, static_input, *, allow_eval=True ): """Determine output types with jax.eval_shape on dummy inputs.""" - import equinox as eqx import jax import jax.numpy as jnp @@ -391,9 +390,9 @@ def _find_output_types( def wrapped_jax_function(input_arrays): """Wrapper to extract output metadata during shape evaluation.""" variables = jax.tree.unflatten(input_treedef, input_arrays) - reconstructed_args, reconstructed_kwargs = eqx.combine(variables, static_input) + reconstructed_args, reconstructed_kwargs = _eqx_combine(variables, static_input) function_outputs = jax_function(*reconstructed_args, **reconstructed_kwargs) - array_outputs, static_outputs = eqx.partition(function_outputs, eqx.is_array) + array_outputs, static_outputs = _eqx_partition(function_outputs, _is_array) # Store metadata for later use output_metadata_storage["output_static"] = static_outputs @@ -420,3 +419,102 @@ def wrapped_jax_function(input_arrays): ] return output_types, output_treedef, output_static + + +# From the equinox library, licensed under Apache 2.0 +# https://github.com/patrick-kidger/equinox +# +# Copied here to avoid a dependency on equinox just these functions. +def _eqx_combine(*pytrees, is_leaf=None): + """Combines multiple PyTrees into one PyTree, by replacing `None` leaves. + + !!! example + + ```python + pytree1 = [None, 1, 2] + pytree2 = [0, None, None] + equinox.combine(pytree1, pytree2) # [0, 1, 2] + ``` + + !!! tip + + The idea is that `equinox.combine` should be used to undo a call to + [`equinox.filter`][] or [`equinox.partition`][]. + + **Arguments:** + + - `*pytrees`: a sequence of PyTrees all with the same structure. + - `is_leaf`: As [`equinox.partition`][]. + + **Returns:** + + A PyTree with the same structure as its inputs. Each leaf will be the first + non-`None` leaf found in the corresponding leaves of `pytrees` as they are + iterated over. 
+ """ + import jax + + if is_leaf is None: + _is_leaf = _is_none + else: + _is_leaf = lambda x: _is_none(x) or is_leaf(x) # noqa: E731 + + return jax.tree.map(_combine, *pytrees, is_leaf=_is_leaf) + + +def _eqx_partition( + pytree, + filter_spec, + replace=None, + is_leaf=None, +): + """Splits a PyTree into two pieces. Equivalent to + `filter(...), filter(..., inverse=True)`, but slightly more efficient. + + !!! info + + See also [`equinox.combine`][] to reconstitute the PyTree again. + """ + import jax + + filter_tree = jax.tree.map(_make_filter_tree(is_leaf), filter_spec, pytree) + left = jax.tree.map(lambda mask, x: x if mask else replace, filter_tree, pytree) + right = jax.tree.map(lambda mask, x: replace if mask else x, filter_tree, pytree) + return left, right + + +def _make_filter_tree(is_leaf): + import jax + import jax.core + + def _filter_tree(mask, arg): + if isinstance(mask, jax.core.Tracer): + raise ValueError("`filter_spec` leaf values cannot be traced arrays.") + if isinstance(mask, bool): + return jax.tree.map(lambda _: mask, arg, is_leaf=is_leaf) + elif callable(mask): + return jax.tree.map(mask, arg, is_leaf=is_leaf) + else: + raise ValueError( + "`filter_spec` must consist of booleans and callables only." + ) + + return _filter_tree + + +def _is_array(element) -> bool: + """Returns `True` if `element` is a JAX array or NumPy array.""" + import jax + + return isinstance(element, np.ndarray | np.generic | jax.Array) + + +def _combine(*args): + for arg in args: + if arg is not None: + return arg + return None + + +def _is_none(x): + return x is None From fceed2ce55f84dbf50afb8188a3e97de82c0a054 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 24 Sep 2025 13:24:15 +0200 Subject: [PATCH 27/30] rename as_op to wrap_py --- doc/extending/creating_an_op.rst | 12 ++++----- .../extending_pytensor_solution_1.py | 10 +++---- pytensor/compile/__init__.py | 1 + pytensor/compile/ops.py | 22 ++++++++++++---- tests/compile/test_ops.py | 26 +++++++++++++++---- 5 files changed, 50 insertions(+), 21 deletions(-) diff --git a/doc/extending/creating_an_op.rst b/doc/extending/creating_an_op.rst index b9aa77f81f..3da251155f 100644 --- a/doc/extending/creating_an_op.rst +++ b/doc/extending/creating_an_op.rst @@ -803,10 +803,10 @@ You can omit the :meth:`Rop` functions. Try to implement the testing apparatus d :download:`Solution` -:func:`as_op` +:func:`wrap_py` ------------- -:func:`as_op` is a Python decorator that converts a Python function into a +:func:`wrap_py` is a Python decorator that converts a Python function into a basic PyTensor :class:`Op` that will call the supplied function during execution. This isn't the recommended way to build an :class:`Op`, but allows for a quick implementation. @@ -839,11 +839,11 @@ It takes an optional :meth:`infer_shape` parameter that must have this signature inputs PyTensor variables that were declared. .. note:: - The python function wrapped by the :func:`as_op` decorator needs to return a new + The python function wrapped by the :func:`wrap_py` decorator needs to return a new data allocation, no views or in place modification of the input. -:func:`as_op` Example +:func:`wrap_py` Example ^^^^^^^^^^^^^^^^^^^^^ .. 
testcode:: asop @@ -852,14 +852,14 @@ It takes an optional :meth:`infer_shape` parameter that must have this signature import pytensor.tensor as pt import numpy as np from pytensor import function - from pytensor.compile.ops import as_op + from pytensor.compile.ops import wrap_py def infer_shape_numpy_dot(fgraph, node, input_shapes): ashp, bshp = input_shapes return [ashp[:-1] + bshp[-1:]] - @as_op( + @wrap_py( itypes=[pt.dmatrix, pt.dmatrix], otypes=[pt.dmatrix], infer_shape=infer_shape_numpy_dot, diff --git a/doc/extending/extending_pytensor_solution_1.py b/doc/extending/extending_pytensor_solution_1.py index ff470ec420..c5a2f8b4e5 100644 --- a/doc/extending/extending_pytensor_solution_1.py +++ b/doc/extending/extending_pytensor_solution_1.py @@ -167,9 +167,9 @@ def test_infer_shape(self): import numpy as np -# as_op exercice +# wrap_py exercice import pytensor -from pytensor.compile.ops import as_op +from pytensor.compile.ops import wrap_py def infer_shape_numpy_dot(fgraph, node, input_shapes): @@ -177,7 +177,7 @@ def infer_shape_numpy_dot(fgraph, node, input_shapes): return [ashp[:-1] + bshp[-1:]] -@as_op( +@wrap_py( itypes=[pt.fmatrix, pt.fmatrix], otypes=[pt.fmatrix], infer_shape=infer_shape_numpy_dot, @@ -192,7 +192,7 @@ def infer_shape_numpy_add_sub(fgraph, node, input_shapes): return [ashp[0]] -@as_op( +@wrap_py( itypes=[pt.fmatrix, pt.fmatrix], otypes=[pt.fmatrix], infer_shape=infer_shape_numpy_add_sub, @@ -201,7 +201,7 @@ def numpy_add(a, b): return np.add(a, b) -@as_op( +@wrap_py( itypes=[pt.fmatrix, pt.fmatrix], otypes=[pt.fmatrix], infer_shape=infer_shape_numpy_add_sub, diff --git a/pytensor/compile/__init__.py b/pytensor/compile/__init__.py index f6a95fe163..8c7fe5f396 100644 --- a/pytensor/compile/__init__.py +++ b/pytensor/compile/__init__.py @@ -56,6 +56,7 @@ register_deep_copy_op_c_code, register_view_op_c_code, view_op, + wrap_py, ) from pytensor.compile.profiling import ProfileStats from pytensor.compile.sharedvalue import SharedVariable, shared, shared_constructor diff --git a/pytensor/compile/ops.py b/pytensor/compile/ops.py index a4eba4079f..72b1447b32 100644 --- a/pytensor/compile/ops.py +++ b/pytensor/compile/ops.py @@ -1,6 +1,6 @@ """ This file contains auxiliary Ops, used during the compilation phase and Ops -building class (:class:`FromFunctionOp`) and decorator (:func:`as_op`) that +building class (:class:`FromFunctionOp`) and decorator (:func:`wrap_py`) that help make new Ops more rapidly. """ @@ -268,12 +268,12 @@ def __reduce__(self): obj = load_back(mod, name) except (ImportError, KeyError, AttributeError): raise pickle.PicklingError( - f"Can't pickle as_op(), not found as {mod}.{name}" + f"Can't pickle wrap_py(), not found as {mod}.{name}" ) else: if obj is not self: raise pickle.PicklingError( - f"Can't pickle as_op(), not the object at {mod}.{name}" + f"Can't pickle wrap_py(), not the object at {mod}.{name}" ) return load_back, (mod, name) @@ -282,6 +282,18 @@ def _infer_shape(self, fgraph, node, input_shapes): def as_op(itypes, otypes, infer_shape=None): + import warnings + + warnings.warn( + "pytensor.as_op is deprecated and will be removed in a future release. " + "Please use pytensor.wrap_py instead.", + DeprecationWarning, + stacklevel=2, + ) + return wrap_py(itypes, otypes, infer_shape) + + +def wrap_py(itypes, otypes, infer_shape=None): """ Decorator that converts a function into a basic PyTensor op that will call the supplied function as its implementation. 
@@ -301,8 +313,8 @@ def infer_shape(fgraph, node, input_shapes): Examples -------- - @as_op(itypes=[pytensor.tensor.fmatrix, pytensor.tensor.fmatrix], - otypes=[pytensor.tensor.fmatrix]) + @wrap_py(itypes=[pytensor.tensor.fmatrix, pytensor.tensor.fmatrix], + otypes=[pytensor.tensor.fmatrix]) def numpy_dot(a, b): return numpy.dot(a, b) diff --git a/tests/compile/test_ops.py b/tests/compile/test_ops.py index 461c7793ad..5b7a5ea24a 100644 --- a/tests/compile/test_ops.py +++ b/tests/compile/test_ops.py @@ -1,14 +1,15 @@ import pickle import numpy as np +import pytest from pytensor import function -from pytensor.compile.ops import as_op +from pytensor.compile.ops import as_op, wrap_py from pytensor.tensor.type import dmatrix, dvector from tests import unittest_tools as utt -@as_op([dmatrix, dmatrix], dmatrix) +@wrap_py([dmatrix, dmatrix], dmatrix) def mul(a, b): """ This is for test_pickle, since the function still has to be @@ -21,7 +22,7 @@ class TestOpDecorator(utt.InferShapeTester): def test_1arg(self): x = dmatrix("x") - @as_op(dmatrix, dvector) + @wrap_py(dmatrix, dvector) def cumprod(x): return np.cumprod(x) @@ -31,13 +32,28 @@ def cumprod(x): assert np.allclose(r, r0), (r, r0) + def test_deprecation(self): + x = dmatrix("x") + + with pytest.warns(DeprecationWarning): + + @as_op(dmatrix, dvector) + def cumprod(x): + return np.cumprod(x) + + fn = function([x], cumprod(x)) + r = fn([[1.5, 5], [2, 2]]) + r0 = np.array([1.5, 7.5, 15.0, 30.0]) + + assert np.allclose(r, r0), (r, r0) + def test_2arg(self): x = dmatrix("x") x.tag.test_value = np.zeros((2, 2)) y = dvector("y") y.tag.test_value = [0, 0, 0, 0] - @as_op([dmatrix, dvector], dvector) + @wrap_py([dmatrix, dvector], dvector) def cumprod_plus(x, y): return np.cumprod(x) + y @@ -57,7 +73,7 @@ def infer_shape(fgraph, node, shapes): x, y = shapes return [y] - @as_op([dmatrix, dvector], dvector, infer_shape) + @wrap_py([dmatrix, dvector], dvector, infer_shape) def cumprod_plus(x, y): return np.cumprod(x) + y From 09f911399a89f46e34cc1d342e8c0eec2e7b3dcf Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 24 Sep 2025 13:24:15 +0200 Subject: [PATCH 28/30] fix doctests of wrap_jax --- doc/conf.py | 1 - pytensor/link/jax/ops.py | 25 +++++++++++++++++++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 48d81730ba..e10dcffb90 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -38,7 +38,6 @@ "jax": ("https://jax.readthedocs.io/en/latest", None), "numpy": ("https://numpy.org/doc/stable", None), "torch": ("https://pytorch.org/docs/stable", None), - "equinox": ("https://docs.kidger.site/equinox/", None), } needs_sphinx = "3" diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py index dde60f8e57..cada35afdd 100644 --- a/pytensor/link/jax/ops.py +++ b/pytensor/link/jax/ops.py @@ -42,7 +42,7 @@ class JAXOp(Op): >>> >>> # Create the jax function that sums the input array. >>> def sum_function(x, y): - ... return jnp.sum(x + y) + ... return (jnp.sum(x + y),) >>> >>> # Create the input and output types, input has a dynamic shape. >>> input_type = TensorType("float32", shape=(None,)) @@ -67,7 +67,7 @@ class JAXOp(Op): [array(14., dtype=float32)] >>> >>> # Compute the gradient of op(x, y) with respect to x. - >>> g = pt.grad(result[0], x) + >>> g = pt.grad(result, x) >>> grad_f = pytensor.function([x, y], [g]) >>> print( ... 
grad_f( @@ -223,6 +223,7 @@ def wrap_jax(jax_function=None, *, allow_eval=True): >>> import jax.numpy as jnp >>> import pytensor.tensor as pt + >>> from pytensor import wrap_jax >>> @wrap_jax ... def add(x, y): ... return jnp.add(x, y) @@ -238,13 +239,14 @@ def wrap_jax(jax_function=None, *, allow_eval=True): >>> import jax >>> import jax.numpy as jnp >>> import pytensor.tensor as pt + >>> from pytensor import wrap_jax >>> @wrap_jax ... def complex_function(x, y, scale=1.0): ... return { ... "sum": jnp.add(x, y) * scale, ... } - >>> x = pt.vector("x") - >>> y = pt.vector("y") + >>> x = pt.vector("x", shape=(3,)) + >>> y = pt.vector("y", shape=(3,)) >>> result = complex_function(x, y, scale=2.0) >>> f = pytensor.function([x, y], [result["sum"]]) @@ -261,6 +263,21 @@ def wrap_jax(jax_function=None, *, allow_eval=True): ... def neural_network(x, mlp): # doctest +SKIP ... return mlp(x) # doctest +SKIP >>> out = neural_network(x, mlp) # doctest +SKIP + + If the input shapes are not fully determined, and valid + input shapes cannot be inferred by evaluating the inputs either, + an error will be raised: + + >>> import jax.numpy as jnp + >>> import pytensor.tensor as pt + >>> @wrap_jax + ... def add(x, y): + ... return jnp.add(x, y) + >>> x = pt.vector("x") # shape is not fully determined + >>> y = pt.vector("y") # shape is not fully determined + >>> result = add(x, y) + ValueError: Could not compile a function to infer example shapes. Please provide inputs with fully determined shapes by calling pt.specify_shape. + ... """ def decorator(func): From 6bf98f915cc2cfb6ab391757e5f4a71e677360a5 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> Date: Tue, 30 Sep 2025 19:41:15 +0200 Subject: [PATCH 29/30] Update pytensor/compile/ops.py --- pytensor/compile/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/compile/ops.py b/pytensor/compile/ops.py index 72b1447b32..51398cd7d8 100644 --- a/pytensor/compile/ops.py +++ b/pytensor/compile/ops.py @@ -287,7 +287,7 @@ def as_op(itypes, otypes, infer_shape=None): warnings.warn( "pytensor.as_op is deprecated and will be removed in a future release. " "Please use pytensor.wrap_py instead.", - DeprecationWarning, + FutureWarning, stacklevel=2, ) return wrap_py(itypes, otypes, infer_shape) From c3874c1243ef8872478b6d5e3c935cb258f30589 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> Date: Tue, 30 Sep 2025 19:41:46 +0200 Subject: [PATCH 30/30] Update tests/compile/test_ops.py --- tests/compile/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/compile/test_ops.py b/tests/compile/test_ops.py index 5b7a5ea24a..6d81caae6c 100644 --- a/tests/compile/test_ops.py +++ b/tests/compile/test_ops.py @@ -35,7 +35,7 @@ def cumprod(x): def test_deprecation(self): x = dmatrix("x") - with pytest.warns(DeprecationWarning): + with pytest.warns(FutureWarning): @as_op(dmatrix, dvector) def cumprod(x):