Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Mar 31, 2024
1 parent 60612c1 commit 0a2c4b2
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 11 deletions.
3 changes: 1 addition & 2 deletions equinox/internal/_loop/bounded.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def bounded_while_loop(
- `body_fun`: As `lax.while_loop`.
- `init_val`: As `lax.while_loop`.
- `max_steps`: A bound on the maximum number of steps, after which the loop
terminates unconditionally. Can be set to `None` for arbitrarily many steps.
terminates unconditionally.
- `buffers`: If passed, then every leaf of `tree_leaves(buffers(init_val))` must
be an array; all such arrays become buffers supporting only `[]` and
`.at[].set()`. However they will act efficiently, without spurious copies.
Expand All @@ -43,7 +43,6 @@ def bounded_while_loop(
**Returns:**
The final value; as `lax.while_loop`.
"""

if not isinstance(max_steps, int) or max_steps < 0:
Expand Down
12 changes: 11 additions & 1 deletion equinox/internal/_loop/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,17 @@ def new_cond_fun(val):
if type(max_steps) is not int:
raise ValueError("`max_steps` must be a Python integer")
out = out & (step < max_steps)
return nonbatchable(out)
# Need to allow being constant across the batch. This seems to arise in some
# edge case when using `bounded_while_loop`, in which:
# - `lax.scan` saves a copy of its state (in particular `step`) for use in the
# backward pass;
# - for some reason the `jax.checkpoint` causes such state to pick up a spurious
# batch tracer.
# See:
# https://github.com/patrick-kidger/optimistix/issues/48#issuecomment-2009221739
return nonbatchable(
out, name="`equinox.internal.while_loop`", allow_constant_across_batch=True
)

def new_body_fun(val):
tag = object()
Expand Down
42 changes: 34 additions & 8 deletions equinox/internal/_nontraceable.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import PyTree

Expand Down Expand Up @@ -122,34 +123,59 @@ def nondifferentiable_backward(
return combine(jtu.tree_unflatten(treedef, flat), static)


def _cannot_batch(x, b, *, msg):
def _cannot_batch(x, b, *, msg, allow_constant_across_batch):
(x,) = x
(b,) = b
if b is batching.not_mapped:
return x, b
else:
raise ValueError(msg)
if allow_constant_across_batch:
x = error_if(x, jnp.min(x) != jnp.max(x), msg)
return jnp.take(x, 0, axis=b), batching.not_mapped
else:
raise ValueError(msg)


nonbatchable_p = jax.core.Primitive("nonbatchable")
nonbatchable_p.def_impl(lambda x, *, msg: x)
nonbatchable_p.def_abstract_eval(lambda x, *, msg: x)
nonbatchable_p.def_impl(lambda x, *, msg, allow_constant_across_batch: x)
nonbatchable_p.def_abstract_eval(lambda x, *, msg, allow_constant_across_batch: x)
batching.primitive_batchers[nonbatchable_p] = _cannot_batch
mlir.register_lowering(
nonbatchable_p, mlir.lower_fun(lambda x, *, msg: x, multiple_results=False)
nonbatchable_p,
mlir.lower_fun(
lambda x, *, msg, allow_constant_across_batch: x, multiple_results=False
),
)


def nonbatchable(
x: PyTree, *, name: Optional[str] = None, msg: Optional[str] = None
x: PyTree,
*,
name: Optional[str] = None,
msg: Optional[str] = None,
allow_constant_across_batch: bool = False,
) -> PyTree:
"""Identity function. Raises a trace-time assert if it is batched."""
dynamic, static = partition(x, is_array)
flat, treedef = jtu.tree_flatten(dynamic)
if msg is None:
if name is None:
name = "This operation"
msg = f"Unexpected batch tracer. {name} cannot be vmap'd."
bind = ft.partial(nonbatchable_p.bind, msg=msg)
if allow_constant_across_batch:
msg = (
f"Nonconstant batch. {name} has received a batch of values that were "
"expected to be constant. This is probably an internal error in the "
"library you are using."
)
else:
msg = (
f"Unexpected batch tracer. {name} cannot be vmap'd. This is probably "
"an internal error in the library you are using."
)
bind = ft.partial(
nonbatchable_p.bind,
msg=msg,
allow_constant_across_batch=allow_constant_across_batch,
)
flat = map(bind, flat)
return combine(jtu.tree_unflatten(treedef, flat), static)

0 comments on commit 0a2c4b2

Please sign in to comment.