From a6bcc1a201d1f9b95044ed3e6a3bf8c3b23924a6 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Fri, 20 Jun 2025 02:45:21 +0100 Subject: [PATCH 01/11] feat: support simple sequential batching rule --- tesseract_jax/primitive.py | 81 ++++++++++++++++++++++++------- tesseract_jax/tesseract_compat.py | 32 ++++++------ 2 files changed, 79 insertions(+), 34 deletions(-) diff --git a/tesseract_jax/primitive.py b/tesseract_jax/primitive.py index 102e69a..17f7784 100644 --- a/tesseract_jax/primitive.py +++ b/tesseract_jax/primitive.py @@ -3,18 +3,20 @@ import functools from collections.abc import Sequence -from typing import Any +from typing import Any, TypeVar import jax.tree import numpy as np from jax import ShapeDtypeStruct, dtypes, extend from jax.core import ShapedArray -from jax.interpreters import ad, mlir, xla +from jax.interpreters import ad, batching, mlir, xla from jax.tree_util import PyTreeDef from jax.typing import ArrayLike from tesseract_core import Tesseract -from tesseract_jax.tesseract_compat import Jaxeract +from tesseract_jax.tesseract_compat import Jaxeract, combine_args + +T = TypeVar("T") tesseract_dispatch_p = extend.core.Primitive("tesseract_dispatch") tesseract_dispatch_p.multiple_results = True @@ -35,21 +37,13 @@ def __hash__(self) -> int: def split_args( - flat_args: Sequence[Any], is_static_mask: Sequence[bool] -) -> tuple[tuple[ArrayLike, ...], tuple[_Hashable, ...]]: - """Split a flat argument list into a tuple (array_args, static_args).""" - static_args = tuple( - _make_hashable(arg) - for arg, is_static in zip(flat_args, is_static_mask, strict=True) - if is_static - ) - array_args = tuple( - arg - for arg, is_static in zip(flat_args, is_static_mask, strict=True) # fmt: skip - if not is_static - ) - - return array_args, static_args + flat_args: Sequence[T], mask: Sequence[bool] +) -> tuple[tuple[T, ...], tuple[T, ...]]: + """Split a flat argument tuple according to mask (mask_False, mask_True).""" + lists = ([], []) + for a, m in zip(flat_args, mask, strict=True): + lists[m].append(a) + return tuple(tuple(args) for args in lists) @tesseract_dispatch_p.def_abstract_eval @@ -238,6 +232,55 @@ def _dispatch(*args: ArrayLike) -> Any: mlir.register_lowering(tesseract_dispatch_p, tesseract_dispatch_lowering) +def tesseract_dispatch_batching( + array_args: ArrayLike | ShapedArray | Any, + axes: Sequence[Any], + *, + static_args: tuple[_Hashable, ...], + input_pytreedef: PyTreeDef, + output_pytreedef: PyTreeDef, + output_avals: tuple[ShapeDtypeStruct, ...], + is_static_mask: tuple[bool, ...], + client: Jaxeract, + eval_func: str, +) -> Any: + """Defines how to dispatch batch operations such as vmap (which is used by jax.jacobian).""" + new_args = [ + arg if ax is batching.not_mapped else batching.moveaxis(arg, ax, 0) + for arg, ax in zip(array_args, axes, strict=True) + ] + + # if output_pytreedef is not None: + # output_pytreedef_expanded = tuple( + # None if layout is None else tuple(n + 1 for n in layout) + (0,) + # for layout in output_pytreedef + # ) + + is_batched_mask = [d is not batching.not_mapped for d in axes] + unbatched_args, batched_args = split_args(new_args, is_batched_mask) + + def _batch_fun(batched_args: tuple): + combined_args = combine_args(unbatched_args, batched_args, is_batched_mask) + return tesseract_dispatch_p.bind( + *combined_args, + static_args=static_args, + input_pytreedef=input_pytreedef, + output_pytreedef=output_pytreedef, + output_avals=output_avals, + is_static_mask=is_static_mask, + client=client, + eval_func=eval_func, + ) + + g = lambda _, x: ((), _batch_fun(x)) + _, outvals = jax.lax.scan(g, (), batched_args) + + return tuple(outvals), (0,) * len(outvals) + + +batching.primitive_batchers[tesseract_dispatch_p] = tesseract_dispatch_batching + + def _check_dtype(dtype: Any) -> None: dt = np.dtype(dtype) if dtypes.canonicalize_dtype(dt) != dt: @@ -316,8 +359,10 @@ def apply_tesseract( client = Jaxeract(tesseract_client) flat_args, input_pytreedef = jax.tree.flatten(inputs) + is_static_mask = tuple(_is_static(arg) for arg in flat_args) array_args, static_args = split_args(flat_args, is_static_mask) + static_args = tuple(_make_hashable(arg) for arg in static_args) # Get abstract values for outputs, so we can unflatten them later output_pytreedef, avals = None, None diff --git a/tesseract_jax/tesseract_compat.py b/tesseract_jax/tesseract_compat.py index 856b662..a629511 100644 --- a/tesseract_jax/tesseract_compat.py +++ b/tesseract_jax/tesseract_compat.py @@ -1,6 +1,7 @@ # Copyright 2025 Pasteur Labs. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Sequence from typing import Any, TypeAlias import jax.tree @@ -12,6 +13,14 @@ PyTree: TypeAlias = Any +def combine_args(args0: Sequence, args1: Sequence, mask: Sequence[bool]) -> tuple: + """Merge the elements of two lists based on a mask.""" + assert sum(mask) == len(args1) and len(mask) - sum(mask) == len(args0) + args0_iter, args1_iter = iter(args0), iter(args1) + combined_args = [next(args1_iter) if m else next(args0_iter) for m in mask] + return tuple(combined_args) + + def unflatten_args( array_args: tuple[ArrayLike, ...], static_args: tuple[Any, ...], @@ -20,23 +29,14 @@ def unflatten_args( remove_static_args: bool = False, ) -> PyTree: """Unflatten lists of arguments (static and not) into a pytree.""" - combined_args = [] - static_iter = iter(static_args) - array_iter = iter(array_args) - - for is_static in is_static_mask: - if is_static: - elem = next(static_iter) - elem = elem.wrapped if hasattr(elem, "wrapped") else elem - - if remove_static_args: - combined_args.append(None) - else: - combined_args.append(elem) - - else: - combined_args.append(next(array_iter)) + if remove_static_args: + static_args_converted = [None] * len(static_args) + else: + static_args_converted = [ + elem.wrapped if hasattr(elem, "wrapped") else elem for elem in static_args + ] + combined_args = combine_args(array_args, static_args_converted, is_static_mask) result = jax.tree.unflatten(input_pytreedef, combined_args) if remove_static_args: From c8f3081971d230484e3ce43d162b7fb328150b66 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Fri, 20 Jun 2025 03:14:16 +0100 Subject: [PATCH 02/11] delete unnecessarily added blank line --- tesseract_jax/primitive.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tesseract_jax/primitive.py b/tesseract_jax/primitive.py index 17f7784..c8edce6 100644 --- a/tesseract_jax/primitive.py +++ b/tesseract_jax/primitive.py @@ -359,7 +359,6 @@ def apply_tesseract( client = Jaxeract(tesseract_client) flat_args, input_pytreedef = jax.tree.flatten(inputs) - is_static_mask = tuple(_is_static(arg) for arg in flat_args) array_args, static_args = split_args(flat_args, is_static_mask) static_args = tuple(_make_hashable(arg) for arg in static_args) From 47efbc627c14e030b55711875fdda655a241d09d Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Fri, 20 Jun 2025 19:07:51 +0100 Subject: [PATCH 03/11] added jacobian tests --- tests/nested_tesseract/tesseract_api.py | 11 ++++ tests/test_endtoend.py | 84 ++++++++++++++++++++++++- 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/tests/nested_tesseract/tesseract_api.py b/tests/nested_tesseract/tesseract_api.py index 6455d53..d762a48 100644 --- a/tests/nested_tesseract/tesseract_api.py +++ b/tests/nested_tesseract/tesseract_api.py @@ -81,6 +81,17 @@ def vector_jacobian_product( return out +def jacobian(inputs: InputSchema, jac_inputs: set[str], jac_outputs: set[str]): + jac = {dy: {dx: [0.0, 0.0, 0.0] for dx in jac_inputs} for dy in jac_outputs} + + if "scalars.a" in jac_inputs and "scalars.a" in jac_outputs: + jac["scalars.a"]["scalars.a"] = 10.0 + if "vectors.v" in jac_inputs and "vectors.v" in jac_outputs: + jac["vectors.v"]["vectors.v"] = [[10.0, 0, 0], [0, 10.0, 0], [0, 0, 10.0]] + + return jac + + def abstract_eval(abstract_inputs): """Calculate output shape of apply from the shape of its inputs.""" return { diff --git a/tests/test_endtoend.py b/tests/test_endtoend.py index a79a4b7..04dabdf 100644 --- a/tests/test_endtoend.py +++ b/tests/test_endtoend.py @@ -36,7 +36,7 @@ def _assert_pytree_isequal(a, b, rtol=None, atol=None): else: assert a_elem == b_elem, f"Values are different: {a_elem} != {b_elem}" except AssertionError as e: - failures.append(a_path, str(e)) + failures.append((a_path, str(e))) if failures: msg = "\n".join(f"Path: {path}, Error: {error}" for path, error in failures) @@ -148,6 +148,39 @@ def f(x, y): _assert_pytree_isequal(vjp, vjp_raw) +@pytest.mark.parametrize("use_jit", [True, False]) +@pytest.mark.parametrize("jacfun", [jax.jacfwd, jax.jacrev]) +def test_univariate_tesseract_jacobian( + served_univariate_tesseract_raw, use_jit, jacfun +): + rosenbrock_tess = Tesseract(served_univariate_tesseract_raw) + + # make things callable without keyword args + def f(x, y): + return apply_tesseract(rosenbrock_tess, inputs=dict(x=x, y=y))["result"] + + rosenbrock_raw = rosenbrock_impl + if use_jit: + f = jax.jit(f) + rosenbrock_raw = jax.jit(rosenbrock_raw) + + x, y = np.array(0.0), np.array(0.0) + jac = jacfun(f, argnums=(0, 1))(x, y) + + # Test against Tesseract client + jac_ref = rosenbrock_tess.jacobian( + inputs=dict(x=x, y=y), jac_inputs=["x", "y"], jac_outputs=["result"] + ) + + # Convert from nested dict to nested tuplw + jac_ref = tuple((jac_ref["result"]["x"], jac_ref["result"]["y"])) + _assert_pytree_isequal(jac, jac_ref) + + # Test against direct implementation + jac_raw = jacfun(rosenbrock_raw, argnums=(0, 1))(x, y) + _assert_pytree_isequal(jac, jac_raw) + + @pytest.mark.parametrize("use_jit", [True, False]) def test_nested_tesseract_apply(served_nested_tesseract_raw, use_jit): nested_tess = Tesseract(served_nested_tesseract_raw) @@ -286,6 +319,55 @@ def f(a, v): _assert_pytree_isequal(vjp, vjp_ref) +@pytest.mark.parametrize("use_jit", [True, False]) +@pytest.mark.parametrize("jacfun", [jax.jacfwd, jax.jacrev]) +def test_nested_tesseract_jacobian(served_nested_tesseract_raw, use_jit, jacfun): + nested_tess = Tesseract(served_nested_tesseract_raw) + a, b = np.array(1.0, dtype="float32"), np.array(2.0, dtype="float32") + v, w = ( + np.array([1.0, 2.0, 3.0], dtype="float32"), + np.array([5.0, 7.0, 9.0], dtype="float32"), + ) + + def f(a, v): + return apply_tesseract( + nested_tess, + inputs=dict( + scalars={"a": a, "b": b}, + vectors={"v": v, "w": w}, + other_stuff={"s": "hey!", "i": 1234, "f": 2.718}, + ), + ) + + if use_jit: + f = jax.jit(f) + + jac = jacfun(f, argnums=(0, 1))(a, v) + + jac_ref = nested_tess.jacobian( + inputs=dict( + scalars={"a": a, "b": b}, + vectors={"v": v, "w": w}, + other_stuff={"s": "hey!", "i": 1234, "f": 2.718}, + ), + jac_inputs=["scalars.a", "vectors.v"], + jac_outputs=["scalars.a", "vectors.v"], + ) + # JAX returns a 2-layered nested dict containing tuples of arrays + # we need to flatten it to match the Tesseract output (2 layered nested dict of arrays) + jac = { + "scalars.a": { + "scalars.a": jac["scalars"]["a"][0], + "vectors.v": jac["scalars"]["a"][1], + }, + "vectors.v": { + "scalars.a": jac["vectors"]["v"][0], + "vectors.v": jac["vectors"]["v"][1], + }, + } + _assert_pytree_isequal(jac, jac_ref) + + @pytest.mark.parametrize("use_jit", [True, False]) def test_partial_differentiation(served_univariate_tesseract_raw, use_jit): """Test that differentiation works correctly in cases where some inputs are constants.""" From 070c72ebce057abac26e2e1b170587c8dbba5fdc Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Sat, 21 Jun 2025 01:10:16 +0100 Subject: [PATCH 04/11] revert to map instead of scan --- tesseract_jax/primitive.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tesseract_jax/primitive.py b/tesseract_jax/primitive.py index c8edce6..c776fbc 100644 --- a/tesseract_jax/primitive.py +++ b/tesseract_jax/primitive.py @@ -250,12 +250,6 @@ def tesseract_dispatch_batching( for arg, ax in zip(array_args, axes, strict=True) ] - # if output_pytreedef is not None: - # output_pytreedef_expanded = tuple( - # None if layout is None else tuple(n + 1 for n in layout) + (0,) - # for layout in output_pytreedef - # ) - is_batched_mask = [d is not batching.not_mapped for d in axes] unbatched_args, batched_args = split_args(new_args, is_batched_mask) @@ -272,8 +266,7 @@ def _batch_fun(batched_args: tuple): eval_func=eval_func, ) - g = lambda _, x: ((), _batch_fun(x)) - _, outvals = jax.lax.scan(g, (), batched_args) + outvals = jax.lax.map(_batch_fun, batched_args) return tuple(outvals), (0,) * len(outvals) From 0ecf6c425be6040cb6d15be8d99d55c997329961 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Sat, 21 Jun 2025 02:39:46 +0100 Subject: [PATCH 05/11] add vmap test for univariate tesseract --- tests/test_endtoend.py | 68 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 62 insertions(+), 6 deletions(-) diff --git a/tests/test_endtoend.py b/tests/test_endtoend.py index 04dabdf..7d0d119 100644 --- a/tests/test_endtoend.py +++ b/tests/test_endtoend.py @@ -1,6 +1,8 @@ # Copyright 2025 Pasteur Labs. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from functools import partial + import jax import numpy as np import pytest @@ -149,23 +151,26 @@ def f(x, y): @pytest.mark.parametrize("use_jit", [True, False]) -@pytest.mark.parametrize("jacfun", [jax.jacfwd, jax.jacrev]) +@pytest.mark.parametrize( + "jacfun", [partial(jax.jacfwd, argnums=(0, 1)), partial(jax.jacrev, argnums=(0, 1))] +) def test_univariate_tesseract_jacobian( served_univariate_tesseract_raw, use_jit, jacfun ): rosenbrock_tess = Tesseract(served_univariate_tesseract_raw) # make things callable without keyword args + @jacfun def f(x, y): return apply_tesseract(rosenbrock_tess, inputs=dict(x=x, y=y))["result"] - rosenbrock_raw = rosenbrock_impl + rosenbrock_raw = jacfun(rosenbrock_impl) if use_jit: f = jax.jit(f) rosenbrock_raw = jax.jit(rosenbrock_raw) x, y = np.array(0.0), np.array(0.0) - jac = jacfun(f, argnums=(0, 1))(x, y) + jac = f(x, y) # Test against Tesseract client jac_ref = rosenbrock_tess.jacobian( @@ -177,10 +182,58 @@ def f(x, y): _assert_pytree_isequal(jac, jac_ref) # Test against direct implementation - jac_raw = jacfun(rosenbrock_raw, argnums=(0, 1))(x, y) + jac_raw = rosenbrock_raw(x, y) _assert_pytree_isequal(jac, jac_raw) +@pytest.mark.parametrize("use_jit", [True, False]) +def test_univariate_tesseract_vmap(served_univariate_tesseract_raw, use_jit): + rosenbrock_tess = Tesseract(served_univariate_tesseract_raw) + + # make things callable without keyword args + def f(x, y): + return apply_tesseract(rosenbrock_tess, inputs=dict(x=x, y=y))["result"] + + # add one batch dimension + for axes in [(0, 0), (0, None), (None, 0)]: + x = np.arange(3) if axes[0] is not None else np.array(0.0) + y = np.arange(3) if axes[1] is not None else np.array(0.0) + f_vmapped = jax.vmap(f, in_axes=axes) + raw_vmapped = jax.vmap(rosenbrock_impl, in_axes=axes) + + if use_jit: + f_vmapped = jax.jit(f_vmapped) + raw_vmapped = jax.jit(raw_vmapped) + + result = f_vmapped(x, y) + result_raw = raw_vmapped(x, y) + + _assert_pytree_isequal(result, result_raw) + + # add an additional batch dimension + for extra_dim in [0, 1, -1]: + if axes[0] is not None: + x = np.arange(6).reshape(2, 3) + if axes[1] is not None: + y = np.arange(6).reshape(2, 3) + + additional_axes = tuple( + extra_dim if ax is not None else None for ax in axes + ) + + f_vmappedtwice = jax.vmap(f_vmapped, in_axes=additional_axes) + raw_vmappedtwice = jax.vmap(raw_vmapped, in_axes=additional_axes) + + if use_jit: + f_vmappedtwice = jax.jit(f_vmappedtwice) + raw_vmappedtwice = jax.jit(raw_vmappedtwice) + + result = f_vmappedtwice(x, y) + result_raw = raw_vmappedtwice(x, y) + + _assert_pytree_isequal(result, result_raw) + + @pytest.mark.parametrize("use_jit", [True, False]) def test_nested_tesseract_apply(served_nested_tesseract_raw, use_jit): nested_tess = Tesseract(served_nested_tesseract_raw) @@ -320,7 +373,9 @@ def f(a, v): @pytest.mark.parametrize("use_jit", [True, False]) -@pytest.mark.parametrize("jacfun", [jax.jacfwd, jax.jacrev]) +@pytest.mark.parametrize( + "jacfun", [partial(jax.jacfwd, argnums=(0, 1)), partial(jax.jacrev, argnums=(0, 1))] +) def test_nested_tesseract_jacobian(served_nested_tesseract_raw, use_jit, jacfun): nested_tess = Tesseract(served_nested_tesseract_raw) a, b = np.array(1.0, dtype="float32"), np.array(2.0, dtype="float32") @@ -329,6 +384,7 @@ def test_nested_tesseract_jacobian(served_nested_tesseract_raw, use_jit, jacfun) np.array([5.0, 7.0, 9.0], dtype="float32"), ) + @jacfun def f(a, v): return apply_tesseract( nested_tess, @@ -342,7 +398,7 @@ def f(a, v): if use_jit: f = jax.jit(f) - jac = jacfun(f, argnums=(0, 1))(a, v) + jac = f(a, v) jac_ref = nested_tess.jacobian( inputs=dict( From 894cf76b55ad12a6ffac6bdfc5bda1d6c7e1c692 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 23 Jun 2025 09:29:49 +0100 Subject: [PATCH 06/11] iterate over out axes when vmapped twice --- tests/test_endtoend.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/tests/test_endtoend.py b/tests/test_endtoend.py index 7d0d119..6148788 100644 --- a/tests/test_endtoend.py +++ b/tests/test_endtoend.py @@ -221,17 +221,22 @@ def f(x, y): extra_dim if ax is not None else None for ax in axes ) - f_vmappedtwice = jax.vmap(f_vmapped, in_axes=additional_axes) - raw_vmappedtwice = jax.vmap(raw_vmapped, in_axes=additional_axes) - - if use_jit: - f_vmappedtwice = jax.jit(f_vmappedtwice) - raw_vmappedtwice = jax.jit(raw_vmappedtwice) - - result = f_vmappedtwice(x, y) - result_raw = raw_vmappedtwice(x, y) - - _assert_pytree_isequal(result, result_raw) + for out_axis in [0, 1, -1]: + f_vmappedtwice = jax.vmap( + f_vmapped, in_axes=additional_axes, out_axes=out_axis + ) + raw_vmappedtwice = jax.vmap( + raw_vmapped, in_axes=additional_axes, out_axes=out_axis + ) + + if use_jit: + f_vmappedtwice = jax.jit(f_vmappedtwice) + raw_vmappedtwice = jax.jit(raw_vmappedtwice) + + result = f_vmappedtwice(x, y) + result_raw = raw_vmappedtwice(x, y) + + _assert_pytree_isequal(result, result_raw) @pytest.mark.parametrize("use_jit", [True, False]) From 8805d57ae0cc537f415052e4e1ee6c032b4ca18d Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 23 Jun 2025 14:01:41 +0100 Subject: [PATCH 07/11] add nested tesserat vmap tests --- tests/test_endtoend.py | 67 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/tests/test_endtoend.py b/tests/test_endtoend.py index 6148788..c8b52ad 100644 --- a/tests/test_endtoend.py +++ b/tests/test_endtoend.py @@ -241,7 +241,9 @@ def f(x, y): @pytest.mark.parametrize("use_jit", [True, False]) def test_nested_tesseract_apply(served_nested_tesseract_raw, use_jit): - nested_tess = Tesseract(served_nested_tesseract_raw) + nested_tess = Tesseract.from_tesseract_api( + "tests/nested_tesseract/tesseract_api.py" + ) a, b = np.array(1.0, dtype="float32"), np.array(2.0, dtype="float32") v, w = ( np.array([1.0, 2.0, 3.0], dtype="float32"), @@ -429,6 +431,69 @@ def f(a, v): _assert_pytree_isequal(jac, jac_ref) +@pytest.mark.parametrize("use_jit", [True, False]) +def test_nested_tesseract_vmap(served_nested_tesseract_raw, use_jit): + nested_tess = Tesseract(served_nested_tesseract_raw) + b = np.array(2.0, dtype="float32") + w = np.array([5.0, 7.0, 9.0], dtype="float32") + + def f(a, v, s, i): + return apply_tesseract( + nested_tess, + inputs={ + "scalars": {"a": a, "b": b}, + "vectors": {"v": v, "w": w}, + "other_stuff": {"s": s, "i": i, "f": 2.718}, + }, + ) + + def f_raw(a, v, s, i): + return { + "scalars": {"a": a * 10 + b, "b": b}, + "vectors": {"v": v * 10 + w, "w": w}, + } + + # add one batch dimension + for a_axis in [None, 0]: + for v_axis in [None, -1, 0, 1]: + if a_axis is None and v_axis is None: + continue + if a_axis == 0: + a = np.arange(4, dtype="float32") + else: + a = np.array(0.0, dtype="float32") + if v_axis == 0: + v = np.arange(12, dtype="float32").reshape((4, 3)) + elif v_axis in [-1, 1]: + v = np.arange(12, dtype="float32").reshape((3, 4)) + else: + v = np.arange(3, dtype="float32") + + mapped_in_axes = (a_axis, v_axis, None, None) + + for mapped_out_axes in [-1, 0, 1] if v_axis else [0]: + if v_axis: + mapped_out_axes = { + "scalars": {"a": 0, "b": 0}, + "vectors": {"v": mapped_out_axes, "w": 0}, + } + f_vmapped = jax.vmap( + f, in_axes=mapped_in_axes, out_axes=mapped_out_axes + ) + raw_vmapped = jax.vmap( + f_raw, in_axes=mapped_in_axes, out_axes=mapped_out_axes + ) + + if use_jit: + f_vmapped = jax.jit(f_vmapped, static_argnames=["s", "i"]) + raw_vmapped = jax.jit(raw_vmapped, static_argnames=["s", "i"]) + + result = f_vmapped(a, v, "hello", 3) + result_raw = raw_vmapped(a, v, "hello", 3) + + _assert_pytree_isequal(result, result_raw) + + @pytest.mark.parametrize("use_jit", [True, False]) def test_partial_differentiation(served_univariate_tesseract_raw, use_jit): """Test that differentiation works correctly in cases where some inputs are constants.""" From 5953f9438c7c9ae59965079fe3cb4f01e4a5d699 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 23 Jun 2025 14:11:52 +0100 Subject: [PATCH 08/11] paramaterize jacobian direction with strings rather than callables --- tests/test_endtoend.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/tests/test_endtoend.py b/tests/test_endtoend.py index c8b52ad..94bca55 100644 --- a/tests/test_endtoend.py +++ b/tests/test_endtoend.py @@ -1,7 +1,6 @@ # Copyright 2025 Pasteur Labs. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from functools import partial import jax import numpy as np @@ -151,20 +150,23 @@ def f(x, y): @pytest.mark.parametrize("use_jit", [True, False]) -@pytest.mark.parametrize( - "jacfun", [partial(jax.jacfwd, argnums=(0, 1)), partial(jax.jacrev, argnums=(0, 1))] -) +@pytest.mark.parametrize("jac_direction", ["fwd", "rev"]) def test_univariate_tesseract_jacobian( - served_univariate_tesseract_raw, use_jit, jacfun + served_univariate_tesseract_raw, use_jit, jac_direction ): rosenbrock_tess = Tesseract(served_univariate_tesseract_raw) # make things callable without keyword args - @jacfun def f(x, y): return apply_tesseract(rosenbrock_tess, inputs=dict(x=x, y=y))["result"] - rosenbrock_raw = jacfun(rosenbrock_impl) + if jac_direction == "fwd": + f = jax.jacfwd(f, argnums=(0, 1)) + rosenbrock_raw = jax.jacfwd(rosenbrock_impl, argnums=(0, 1)) + else: + f = jax.jacrev(f, argnums=(0, 1)) + rosenbrock_raw = jax.jacrev(rosenbrock_impl, argnums=(0, 1)) + if use_jit: f = jax.jit(f) rosenbrock_raw = jax.jit(rosenbrock_raw) @@ -380,10 +382,8 @@ def f(a, v): @pytest.mark.parametrize("use_jit", [True, False]) -@pytest.mark.parametrize( - "jacfun", [partial(jax.jacfwd, argnums=(0, 1)), partial(jax.jacrev, argnums=(0, 1))] -) -def test_nested_tesseract_jacobian(served_nested_tesseract_raw, use_jit, jacfun): +@pytest.mark.parametrize("jac_direction", ["fwd", "rev"]) +def test_nested_tesseract_jacobian(served_nested_tesseract_raw, use_jit, jac_direction): nested_tess = Tesseract(served_nested_tesseract_raw) a, b = np.array(1.0, dtype="float32"), np.array(2.0, dtype="float32") v, w = ( @@ -391,7 +391,6 @@ def test_nested_tesseract_jacobian(served_nested_tesseract_raw, use_jit, jacfun) np.array([5.0, 7.0, 9.0], dtype="float32"), ) - @jacfun def f(a, v): return apply_tesseract( nested_tess, @@ -402,6 +401,11 @@ def f(a, v): ), ) + if jac_direction == "fwd": + f = jax.jacfwd(f, argnums=(0, 1)) + else: + f = jax.jacrev(f, argnums=(0, 1)) + if use_jit: f = jax.jit(f) From 801ef8e5518ac73bb7344526a103d3b5c4a64fe4 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 23 Jun 2025 14:28:34 +0100 Subject: [PATCH 09/11] updated docstring for combine_args --- tesseract_jax/tesseract_compat.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tesseract_jax/tesseract_compat.py b/tesseract_jax/tesseract_compat.py index a629511..a704023 100644 --- a/tesseract_jax/tesseract_compat.py +++ b/tesseract_jax/tesseract_compat.py @@ -14,7 +14,15 @@ def combine_args(args0: Sequence, args1: Sequence, mask: Sequence[bool]) -> tuple: - """Merge the elements of two lists based on a mask.""" + """Merge the elements of two lists based on a mask. + + The length of the two lists is required to be equal to the length of the mask. + `combine_args` will populate the new list according to the mask: if the mask evaluates + to `False` it will take the next item of the first list, if it evaluate to `True` it will + take from the second list. + For example, merging the lists ["foo", "bar"] and [0, 1, 2] wih the mask [1, 0, 0, 1, 1] + will return [0, "foo", "bar", 1, 2] + """ assert sum(mask) == len(args1) and len(mask) - sum(mask) == len(args0) args0_iter, args1_iter = iter(args0), iter(args1) combined_args = [next(args1_iter) if m else next(args0_iter) for m in mask] From 30662992d0a41d9e61f3d9a9e5bfb48e80a4b3ce Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 23 Jun 2025 14:28:59 +0100 Subject: [PATCH 10/11] correct tuplw typo --- tests/test_endtoend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_endtoend.py b/tests/test_endtoend.py index 94bca55..92cb5f3 100644 --- a/tests/test_endtoend.py +++ b/tests/test_endtoend.py @@ -179,7 +179,7 @@ def f(x, y): inputs=dict(x=x, y=y), jac_inputs=["x", "y"], jac_outputs=["result"] ) - # Convert from nested dict to nested tuplw + # Convert from nested dict to nested tuple jac_ref = tuple((jac_ref["result"]["x"], jac_ref["result"]["y"])) _assert_pytree_isequal(jac, jac_ref) From 632b88443a03702ce3772fefc3a24cb2a9c6b61f Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 23 Jun 2025 20:34:18 +0100 Subject: [PATCH 11/11] reformat `combine_args` example MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Dion Häfner --- tesseract_jax/tesseract_compat.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tesseract_jax/tesseract_compat.py b/tesseract_jax/tesseract_compat.py index a704023..4d011c4 100644 --- a/tesseract_jax/tesseract_compat.py +++ b/tesseract_jax/tesseract_compat.py @@ -20,8 +20,10 @@ def combine_args(args0: Sequence, args1: Sequence, mask: Sequence[bool]) -> tupl `combine_args` will populate the new list according to the mask: if the mask evaluates to `False` it will take the next item of the first list, if it evaluate to `True` it will take from the second list. - For example, merging the lists ["foo", "bar"] and [0, 1, 2] wih the mask [1, 0, 0, 1, 1] - will return [0, "foo", "bar", 1, 2] + + Example: + >>> combine_args(["foo", "bar"], [0, 1, 2], [1, 0, 0, 1, 1]) + [0, "foo", "bar", 1, 2] """ assert sum(mask) == len(args1) and len(mask) - sum(mask) == len(args0) args0_iter, args1_iter = iter(args0), iter(args1)