diff --git a/tesseract_jax/primitive.py b/tesseract_jax/primitive.py index 637fd05..c776fbc 100644 --- a/tesseract_jax/primitive.py +++ b/tesseract_jax/primitive.py @@ -3,18 +3,20 @@ import functools from collections.abc import Sequence -from typing import Any +from typing import Any, TypeVar import jax.tree import numpy as np from jax import ShapeDtypeStruct, dtypes, extend from jax.core import ShapedArray -from jax.interpreters import ad, mlir, xla +from jax.interpreters import ad, batching, mlir, xla from jax.tree_util import PyTreeDef from jax.typing import ArrayLike from tesseract_core import Tesseract -from tesseract_jax.tesseract_compat import Jaxeract +from tesseract_jax.tesseract_compat import Jaxeract, combine_args + +T = TypeVar("T") tesseract_dispatch_p = extend.core.Primitive("tesseract_dispatch") tesseract_dispatch_p.multiple_results = True @@ -35,21 +37,13 @@ def __hash__(self) -> int: def split_args( - flat_args: Sequence[Any], is_static_mask: Sequence[bool] -) -> tuple[tuple[ArrayLike, ...], tuple[_Hashable, ...]]: - """Split a flat argument list into a tuple (array_args, static_args).""" - static_args = tuple( - _make_hashable(arg) - for arg, is_static in zip(flat_args, is_static_mask, strict=True) - if is_static - ) - array_args = tuple( - arg - for arg, is_static in zip(flat_args, is_static_mask, strict=True) - if not is_static - ) - - return array_args, static_args + flat_args: Sequence[T], mask: Sequence[bool] +) -> tuple[tuple[T, ...], tuple[T, ...]]: + """Split a flat argument tuple according to mask (mask_False, mask_True).""" + lists = ([], []) + for a, m in zip(flat_args, mask, strict=True): + lists[m].append(a) + return tuple(tuple(args) for args in lists) @tesseract_dispatch_p.def_abstract_eval @@ -238,6 +232,48 @@ def _dispatch(*args: ArrayLike) -> Any: mlir.register_lowering(tesseract_dispatch_p, tesseract_dispatch_lowering) +def tesseract_dispatch_batching( + array_args: ArrayLike | ShapedArray | Any, + axes: Sequence[Any], + *, + 
static_args: tuple[_Hashable, ...], + input_pytreedef: PyTreeDef, + output_pytreedef: PyTreeDef, + output_avals: tuple[ShapeDtypeStruct, ...], + is_static_mask: tuple[bool, ...], + client: Jaxeract, + eval_func: str, +) -> Any: + """Defines how to dispatch batch operations such as vmap (which is used by jax.jacobian).""" + new_args = [ + arg if ax is batching.not_mapped else batching.moveaxis(arg, ax, 0) + for arg, ax in zip(array_args, axes, strict=True) + ] + + is_batched_mask = [d is not batching.not_mapped for d in axes] + unbatched_args, batched_args = split_args(new_args, is_batched_mask) + + def _batch_fun(batched_args: tuple): + combined_args = combine_args(unbatched_args, batched_args, is_batched_mask) + return tesseract_dispatch_p.bind( + *combined_args, + static_args=static_args, + input_pytreedef=input_pytreedef, + output_pytreedef=output_pytreedef, + output_avals=output_avals, + is_static_mask=is_static_mask, + client=client, + eval_func=eval_func, + ) + + outvals = jax.lax.map(_batch_fun, batched_args) + + return tuple(outvals), (0,) * len(outvals) + + +batching.primitive_batchers[tesseract_dispatch_p] = tesseract_dispatch_batching + + def _check_dtype(dtype: Any) -> None: dt = np.dtype(dtype) if dtypes.canonicalize_dtype(dt) != dt: @@ -318,6 +354,7 @@ def apply_tesseract( flat_args, input_pytreedef = jax.tree.flatten(inputs) is_static_mask = tuple(_is_static(arg) for arg in flat_args) array_args, static_args = split_args(flat_args, is_static_mask) + static_args = tuple(_make_hashable(arg) for arg in static_args) # Get abstract values for outputs, so we can unflatten them later output_pytreedef, avals = None, None diff --git a/tesseract_jax/tesseract_compat.py b/tesseract_jax/tesseract_compat.py index 856b662..4d011c4 100644 --- a/tesseract_jax/tesseract_compat.py +++ b/tesseract_jax/tesseract_compat.py @@ -1,6 +1,7 @@ # Copyright 2025 Pasteur Labs. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Sequence from typing import Any, TypeAlias import jax.tree @@ -12,6 +13,24 @@ PyTree: TypeAlias = Any +def combine_args(args0: Sequence, args1: Sequence, mask: Sequence[bool]) -> tuple: + """Merge the elements of two lists based on a mask. + + The combined length of the two lists is required to be equal to the length of the mask. + `combine_args` will populate the new list according to the mask: if the mask evaluates + to `False` it will take the next item of the first list, if it evaluates to `True` it will + take from the second list. + + Example: + >>> combine_args(["foo", "bar"], [0, 1, 2], [1, 0, 0, 1, 1]) + (0, 'foo', 'bar', 1, 2) + """ + assert sum(mask) == len(args1) and len(mask) - sum(mask) == len(args0) + args0_iter, args1_iter = iter(args0), iter(args1) + combined_args = [next(args1_iter) if m else next(args0_iter) for m in mask] + return tuple(combined_args) + + def unflatten_args( array_args: tuple[ArrayLike, ...], static_args: tuple[Any, ...], @@ -20,23 +39,14 @@ def unflatten_args( remove_static_args: bool = False, ) -> PyTree: """Unflatten lists of arguments (static and not) into a pytree.""" - combined_args = [] - static_iter = iter(static_args) - array_iter = iter(array_args) - - for is_static in is_static_mask: - if is_static: - elem = next(static_iter) - elem = elem.wrapped if hasattr(elem, "wrapped") else elem - - if remove_static_args: - combined_args.append(None) - else: - combined_args.append(elem) - - else: - combined_args.append(next(array_iter)) + if remove_static_args: + static_args_converted = [None] * len(static_args) + else: + static_args_converted = [ + elem.wrapped if hasattr(elem, "wrapped") else elem for elem in static_args + ] + combined_args = combine_args(array_args, static_args_converted, is_static_mask) result = jax.tree.unflatten(input_pytreedef, combined_args) if remove_static_args: diff --git a/tests/nested_tesseract/tesseract_api.py 
b/tests/nested_tesseract/tesseract_api.py index 6455d53..d762a48 100644 --- a/tests/nested_tesseract/tesseract_api.py +++ b/tests/nested_tesseract/tesseract_api.py @@ -81,6 +81,17 @@ def vector_jacobian_product( return out +def jacobian(inputs: InputSchema, jac_inputs: set[str], jac_outputs: set[str]): + jac = {dy: {dx: [0.0, 0.0, 0.0] for dx in jac_inputs} for dy in jac_outputs} + + if "scalars.a" in jac_inputs and "scalars.a" in jac_outputs: + jac["scalars.a"]["scalars.a"] = 10.0 + if "vectors.v" in jac_inputs and "vectors.v" in jac_outputs: + jac["vectors.v"]["vectors.v"] = [[10.0, 0, 0], [0, 10.0, 0], [0, 0, 10.0]] + + return jac + + def abstract_eval(abstract_inputs): """Calculate output shape of apply from the shape of its inputs.""" return { diff --git a/tests/test_endtoend.py b/tests/test_endtoend.py index a79a4b7..92cb5f3 100644 --- a/tests/test_endtoend.py +++ b/tests/test_endtoend.py @@ -1,6 +1,7 @@ # Copyright 2025 Pasteur Labs. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 + import jax import numpy as np import pytest @@ -36,7 +37,7 @@ def _assert_pytree_isequal(a, b, rtol=None, atol=None): else: assert a_elem == b_elem, f"Values are different: {a_elem} != {b_elem}" except AssertionError as e: - failures.append(a_path, str(e)) + failures.append((a_path, str(e))) if failures: msg = "\n".join(f"Path: {path}, Error: {error}" for path, error in failures) @@ -148,9 +149,103 @@ def f(x, y): _assert_pytree_isequal(vjp, vjp_raw) +@pytest.mark.parametrize("use_jit", [True, False]) +@pytest.mark.parametrize("jac_direction", ["fwd", "rev"]) +def test_univariate_tesseract_jacobian( + served_univariate_tesseract_raw, use_jit, jac_direction +): + rosenbrock_tess = Tesseract(served_univariate_tesseract_raw) + + # make things callable without keyword args + def f(x, y): + return apply_tesseract(rosenbrock_tess, inputs=dict(x=x, y=y))["result"] + + if jac_direction == "fwd": + f = jax.jacfwd(f, argnums=(0, 1)) + rosenbrock_raw = jax.jacfwd(rosenbrock_impl, 
argnums=(0, 1)) + else: + f = jax.jacrev(f, argnums=(0, 1)) + rosenbrock_raw = jax.jacrev(rosenbrock_impl, argnums=(0, 1)) + + if use_jit: + f = jax.jit(f) + rosenbrock_raw = jax.jit(rosenbrock_raw) + + x, y = np.array(0.0), np.array(0.0) + jac = f(x, y) + + # Test against Tesseract client + jac_ref = rosenbrock_tess.jacobian( + inputs=dict(x=x, y=y), jac_inputs=["x", "y"], jac_outputs=["result"] + ) + + # Convert from nested dict to nested tuple + jac_ref = tuple((jac_ref["result"]["x"], jac_ref["result"]["y"])) + _assert_pytree_isequal(jac, jac_ref) + + # Test against direct implementation + jac_raw = rosenbrock_raw(x, y) + _assert_pytree_isequal(jac, jac_raw) + + +@pytest.mark.parametrize("use_jit", [True, False]) +def test_univariate_tesseract_vmap(served_univariate_tesseract_raw, use_jit): + rosenbrock_tess = Tesseract(served_univariate_tesseract_raw) + + # make things callable without keyword args + def f(x, y): + return apply_tesseract(rosenbrock_tess, inputs=dict(x=x, y=y))["result"] + + # add one batch dimension + for axes in [(0, 0), (0, None), (None, 0)]: + x = np.arange(3) if axes[0] is not None else np.array(0.0) + y = np.arange(3) if axes[1] is not None else np.array(0.0) + f_vmapped = jax.vmap(f, in_axes=axes) + raw_vmapped = jax.vmap(rosenbrock_impl, in_axes=axes) + + if use_jit: + f_vmapped = jax.jit(f_vmapped) + raw_vmapped = jax.jit(raw_vmapped) + + result = f_vmapped(x, y) + result_raw = raw_vmapped(x, y) + + _assert_pytree_isequal(result, result_raw) + + # add an additional batch dimension + for extra_dim in [0, 1, -1]: + if axes[0] is not None: + x = np.arange(6).reshape(2, 3) + if axes[1] is not None: + y = np.arange(6).reshape(2, 3) + + additional_axes = tuple( + extra_dim if ax is not None else None for ax in axes + ) + + for out_axis in [0, 1, -1]: + f_vmappedtwice = jax.vmap( + f_vmapped, in_axes=additional_axes, out_axes=out_axis + ) + raw_vmappedtwice = jax.vmap( + raw_vmapped, in_axes=additional_axes, out_axes=out_axis + ) + + if 
use_jit: + f_vmappedtwice = jax.jit(f_vmappedtwice) + raw_vmappedtwice = jax.jit(raw_vmappedtwice) + + result = f_vmappedtwice(x, y) + result_raw = raw_vmappedtwice(x, y) + + _assert_pytree_isequal(result, result_raw) + + @pytest.mark.parametrize("use_jit", [True, False]) def test_nested_tesseract_apply(served_nested_tesseract_raw, use_jit): - nested_tess = Tesseract(served_nested_tesseract_raw) + nested_tess = Tesseract.from_tesseract_api( + "tests/nested_tesseract/tesseract_api.py" + ) a, b = np.array(1.0, dtype="float32"), np.array(2.0, dtype="float32") v, w = ( np.array([1.0, 2.0, 3.0], dtype="float32"), @@ -286,6 +381,123 @@ def f(a, v): _assert_pytree_isequal(vjp, vjp_ref) +@pytest.mark.parametrize("use_jit", [True, False]) +@pytest.mark.parametrize("jac_direction", ["fwd", "rev"]) +def test_nested_tesseract_jacobian(served_nested_tesseract_raw, use_jit, jac_direction): + nested_tess = Tesseract(served_nested_tesseract_raw) + a, b = np.array(1.0, dtype="float32"), np.array(2.0, dtype="float32") + v, w = ( + np.array([1.0, 2.0, 3.0], dtype="float32"), + np.array([5.0, 7.0, 9.0], dtype="float32"), + ) + + def f(a, v): + return apply_tesseract( + nested_tess, + inputs=dict( + scalars={"a": a, "b": b}, + vectors={"v": v, "w": w}, + other_stuff={"s": "hey!", "i": 1234, "f": 2.718}, + ), + ) + + if jac_direction == "fwd": + f = jax.jacfwd(f, argnums=(0, 1)) + else: + f = jax.jacrev(f, argnums=(0, 1)) + + if use_jit: + f = jax.jit(f) + + jac = f(a, v) + + jac_ref = nested_tess.jacobian( + inputs=dict( + scalars={"a": a, "b": b}, + vectors={"v": v, "w": w}, + other_stuff={"s": "hey!", "i": 1234, "f": 2.718}, + ), + jac_inputs=["scalars.a", "vectors.v"], + jac_outputs=["scalars.a", "vectors.v"], + ) + # JAX returns a 2-layered nested dict containing tuples of arrays + # we need to flatten it to match the Tesseract output (2 layered nested dict of arrays) + jac = { + "scalars.a": { + "scalars.a": jac["scalars"]["a"][0], + "vectors.v": jac["scalars"]["a"][1], + }, + 
"vectors.v": { + "scalars.a": jac["vectors"]["v"][0], + "vectors.v": jac["vectors"]["v"][1], + }, + } + _assert_pytree_isequal(jac, jac_ref) + + +@pytest.mark.parametrize("use_jit", [True, False]) +def test_nested_tesseract_vmap(served_nested_tesseract_raw, use_jit): + nested_tess = Tesseract(served_nested_tesseract_raw) + b = np.array(2.0, dtype="float32") + w = np.array([5.0, 7.0, 9.0], dtype="float32") + + def f(a, v, s, i): + return apply_tesseract( + nested_tess, + inputs={ + "scalars": {"a": a, "b": b}, + "vectors": {"v": v, "w": w}, + "other_stuff": {"s": s, "i": i, "f": 2.718}, + }, + ) + + def f_raw(a, v, s, i): + return { + "scalars": {"a": a * 10 + b, "b": b}, + "vectors": {"v": v * 10 + w, "w": w}, + } + + # add one batch dimension + for a_axis in [None, 0]: + for v_axis in [None, -1, 0, 1]: + if a_axis is None and v_axis is None: + continue + if a_axis == 0: + a = np.arange(4, dtype="float32") + else: + a = np.array(0.0, dtype="float32") + if v_axis == 0: + v = np.arange(12, dtype="float32").reshape((4, 3)) + elif v_axis in [-1, 1]: + v = np.arange(12, dtype="float32").reshape((3, 4)) + else: + v = np.arange(3, dtype="float32") + + mapped_in_axes = (a_axis, v_axis, None, None) + + for mapped_out_axes in [-1, 0, 1] if v_axis else [0]: + if v_axis: + mapped_out_axes = { + "scalars": {"a": 0, "b": 0}, + "vectors": {"v": mapped_out_axes, "w": 0}, + } + f_vmapped = jax.vmap( + f, in_axes=mapped_in_axes, out_axes=mapped_out_axes + ) + raw_vmapped = jax.vmap( + f_raw, in_axes=mapped_in_axes, out_axes=mapped_out_axes + ) + + if use_jit: + f_vmapped = jax.jit(f_vmapped, static_argnames=["s", "i"]) + raw_vmapped = jax.jit(raw_vmapped, static_argnames=["s", "i"]) + + result = f_vmapped(a, v, "hello", 3) + result_raw = raw_vmapped(a, v, "hello", 3) + + _assert_pytree_isequal(result, result_raw) + + @pytest.mark.parametrize("use_jit", [True, False]) def test_partial_differentiation(served_univariate_tesseract_raw, use_jit): """Test that differentiation works 
correctly in cases where some inputs are constants."""