Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 55 additions & 18 deletions tesseract_jax/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@

import functools
from collections.abc import Sequence
from typing import Any
from typing import Any, TypeVar

import jax.tree
import numpy as np
from jax import ShapeDtypeStruct, dtypes, extend
from jax.core import ShapedArray
from jax.interpreters import ad, mlir, xla
from jax.interpreters import ad, batching, mlir, xla
from jax.tree_util import PyTreeDef
from jax.typing import ArrayLike
from tesseract_core import Tesseract

from tesseract_jax.tesseract_compat import Jaxeract
from tesseract_jax.tesseract_compat import Jaxeract, combine_args

T = TypeVar("T")

tesseract_dispatch_p = extend.core.Primitive("tesseract_dispatch")
tesseract_dispatch_p.multiple_results = True
Expand All @@ -35,21 +37,13 @@ def __hash__(self) -> int:


def split_args(
flat_args: Sequence[Any], is_static_mask: Sequence[bool]
) -> tuple[tuple[ArrayLike, ...], tuple[_Hashable, ...]]:
"""Split a flat argument list into a tuple (array_args, static_args)."""
static_args = tuple(
_make_hashable(arg)
for arg, is_static in zip(flat_args, is_static_mask, strict=True)
if is_static
)
array_args = tuple(
arg
for arg, is_static in zip(flat_args, is_static_mask, strict=True)
if not is_static
)

return array_args, static_args
flat_args: Sequence[T], mask: Sequence[bool]
) -> tuple[tuple[T, ...], tuple[T, ...]]:
"""Split a flat argument tuple according to mask (mask_False, mask_True)."""
lists = ([], [])
for a, m in zip(flat_args, mask, strict=True):
lists[m].append(a)
return tuple(tuple(args) for args in lists)


@tesseract_dispatch_p.def_abstract_eval
Expand Down Expand Up @@ -238,6 +232,48 @@ def _dispatch(*args: ArrayLike) -> Any:
mlir.register_lowering(tesseract_dispatch_p, tesseract_dispatch_lowering)


def tesseract_dispatch_batching(
array_args: ArrayLike | ShapedArray | Any,
axes: Sequence[Any],
*,
static_args: tuple[_Hashable, ...],
input_pytreedef: PyTreeDef,
output_pytreedef: PyTreeDef,
output_avals: tuple[ShapeDtypeStruct, ...],
is_static_mask: tuple[bool, ...],
client: Jaxeract,
eval_func: str,
) -> Any:
"""Defines how to dispatch batch operations such as vmap (which is used by jax.jacobian)."""
new_args = [
arg if ax is batching.not_mapped else batching.moveaxis(arg, ax, 0)
for arg, ax in zip(array_args, axes, strict=True)
]

is_batched_mask = [d is not batching.not_mapped for d in axes]
unbatched_args, batched_args = split_args(new_args, is_batched_mask)

def _batch_fun(batched_args: tuple):
combined_args = combine_args(unbatched_args, batched_args, is_batched_mask)
return tesseract_dispatch_p.bind(
*combined_args,
static_args=static_args,
input_pytreedef=input_pytreedef,
output_pytreedef=output_pytreedef,
output_avals=output_avals,
is_static_mask=is_static_mask,
client=client,
eval_func=eval_func,
)

outvals = jax.lax.map(_batch_fun, batched_args)

return tuple(outvals), (0,) * len(outvals)


batching.primitive_batchers[tesseract_dispatch_p] = tesseract_dispatch_batching


def _check_dtype(dtype: Any) -> None:
dt = np.dtype(dtype)
if dtypes.canonicalize_dtype(dt) != dt:
Expand Down Expand Up @@ -318,6 +354,7 @@ def apply_tesseract(
flat_args, input_pytreedef = jax.tree.flatten(inputs)
is_static_mask = tuple(_is_static(arg) for arg in flat_args)
array_args, static_args = split_args(flat_args, is_static_mask)
static_args = tuple(_make_hashable(arg) for arg in static_args)

# Get abstract values for outputs, so we can unflatten them later
output_pytreedef, avals = None, None
Expand Down
42 changes: 26 additions & 16 deletions tesseract_jax/tesseract_compat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Sequence
from typing import Any, TypeAlias

import jax.tree
Expand All @@ -12,6 +13,24 @@
PyTree: TypeAlias = Any


def combine_args(args0: Sequence, args1: Sequence, mask: Sequence[bool]) -> tuple:
"""Merge the elements of two lists based on a mask.

The length of the two lists is required to be equal to the length of the mask.
`combine_args` will populate the new list according to the mask: if the mask evaluates
to `False` it will take the next item of the first list, if it evaluate to `True` it will
take from the second list.

Example:
>>> combine_args(["foo", "bar"], [0, 1, 2], [1, 0, 0, 1, 1])
[0, "foo", "bar", 1, 2]
"""
assert sum(mask) == len(args1) and len(mask) - sum(mask) == len(args0)
args0_iter, args1_iter = iter(args0), iter(args1)
combined_args = [next(args1_iter) if m else next(args0_iter) for m in mask]
return tuple(combined_args)


def unflatten_args(
array_args: tuple[ArrayLike, ...],
static_args: tuple[Any, ...],
Expand All @@ -20,23 +39,14 @@ def unflatten_args(
remove_static_args: bool = False,
) -> PyTree:
"""Unflatten lists of arguments (static and not) into a pytree."""
combined_args = []
static_iter = iter(static_args)
array_iter = iter(array_args)

for is_static in is_static_mask:
if is_static:
elem = next(static_iter)
elem = elem.wrapped if hasattr(elem, "wrapped") else elem

if remove_static_args:
combined_args.append(None)
else:
combined_args.append(elem)

else:
combined_args.append(next(array_iter))
if remove_static_args:
static_args_converted = [None] * len(static_args)
else:
static_args_converted = [
elem.wrapped if hasattr(elem, "wrapped") else elem for elem in static_args
]

combined_args = combine_args(array_args, static_args_converted, is_static_mask)
result = jax.tree.unflatten(input_pytreedef, combined_args)

if remove_static_args:
Expand Down
11 changes: 11 additions & 0 deletions tests/nested_tesseract/tesseract_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ def vector_jacobian_product(
return out


def jacobian(inputs: InputSchema, jac_inputs: set[str], jac_outputs: set[str]):
jac = {dy: {dx: [0.0, 0.0, 0.0] for dx in jac_inputs} for dy in jac_outputs}

if "scalars.a" in jac_inputs and "scalars.a" in jac_outputs:
jac["scalars.a"]["scalars.a"] = 10.0
if "vectors.v" in jac_inputs and "vectors.v" in jac_outputs:
jac["vectors.v"]["vectors.v"] = [[10.0, 0, 0], [0, 10.0, 0], [0, 0, 10.0]]

return jac


def abstract_eval(abstract_inputs):
"""Calculate output shape of apply from the shape of its inputs."""
return {
Expand Down
Loading
Loading