diff --git a/tesseract_jax/primitive.py b/tesseract_jax/primitive.py index c776fbc..78e90eb 100644 --- a/tesseract_jax/primitive.py +++ b/tesseract_jax/primitive.py @@ -5,6 +5,7 @@ from collections.abc import Sequence from typing import Any, TypeVar +import jax.numpy as jnp import jax.tree import numpy as np from jax import ShapeDtypeStruct, dtypes, extend @@ -246,7 +247,7 @@ def tesseract_dispatch_batching( ) -> Any: """Defines how to dispatch batch operations such as vmap (which is used by jax.jacobian).""" new_args = [ - arg if ax is batching.not_mapped else batching.moveaxis(arg, ax, 0) + arg if ax is batching.not_mapped else jnp.moveaxis(arg, ax, 0) for arg, ax in zip(array_args, axes, strict=True) ]