From c7efe51cc422b62cd3d727b83a0dae5179c50a2a Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sat, 30 Mar 2024 21:40:08 +0100 Subject: [PATCH] Fix for https://github.com/patrick-kidger/optimistix/issues/48 --- equinox/internal/_loop/bounded.py | 3 +-- equinox/internal/_loop/common.py | 12 ++++++++- equinox/internal/_nontraceable.py | 42 +++++++++++++++++++++++++------ 3 files changed, 46 insertions(+), 11 deletions(-) diff --git a/equinox/internal/_loop/bounded.py b/equinox/internal/_loop/bounded.py index 647ffc95..4e02ab22 100644 --- a/equinox/internal/_loop/bounded.py +++ b/equinox/internal/_loop/bounded.py @@ -32,7 +32,7 @@ def bounded_while_loop( - `body_fun`: As `lax.while_loop`. - `init_val`: As `lax.while_loop`. - `max_steps`: A bound on the maximum number of steps, after which the loop - terminates unconditionally. Can be set to `None` for arbitrarily many steps. + terminates unconditionally. - `buffers`: If passed, then every leaf of `tree_leaves(buffers(init_val))` must be an array; all such arrays become buffers supporting only `[]` and `.at[].set()`. However they will act efficiently, without spurious copies. @@ -43,7 +43,6 @@ def bounded_while_loop( **Returns:** The final value; as `lax.while_loop`. - """ if not isinstance(max_steps, int) or max_steps < 0: diff --git a/equinox/internal/_loop/common.py b/equinox/internal/_loop/common.py index 24219b9c..19e444c6 100644 --- a/equinox/internal/_loop/common.py +++ b/equinox/internal/_loop/common.py @@ -429,7 +429,17 @@ def new_cond_fun(val): if type(max_steps) is not int: raise ValueError("`max_steps` must be a Python integer") out = out & (step < max_steps) - return nonbatchable(out) + # Need to allow being constant across the batch. This seems to arise in some + # edge case when using `bounded_while_loop`, in which: + # - `lax.scan` saves a copy of its state (in particular `step`) for use in the + # backward pass; + # - for some reason the `jax.checkpoint` causes such state to pick up a spurious + # batch tracer. + # See: + # https://github.com/patrick-kidger/optimistix/issues/48#issuecomment-2009221739 + return nonbatchable( + out, name="`equinox.internal.while_loop`", allow_constant_across_batch=True + ) def new_body_fun(val): tag = object() diff --git a/equinox/internal/_nontraceable.py b/equinox/internal/_nontraceable.py index 01260b21..04efa073 100644 --- a/equinox/internal/_nontraceable.py +++ b/equinox/internal/_nontraceable.py @@ -9,6 +9,7 @@ import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir +import jax.numpy as jnp import jax.tree_util as jtu from jaxtyping import PyTree @@ -122,26 +123,37 @@ def nondifferentiable_backward( return combine(jtu.tree_unflatten(treedef, flat), static) -def _cannot_batch(x, b, *, msg): +def _cannot_batch(x, b, *, msg, allow_constant_across_batch): (x,) = x (b,) = b if b is batching.not_mapped: return x, b else: - raise ValueError(msg) + if allow_constant_across_batch: + x = error_if(x, jnp.min(x) != jnp.max(x), msg) + return x, b + else: + raise ValueError(msg) nonbatchable_p = jax.core.Primitive("nonbatchable") -nonbatchable_p.def_impl(lambda x, *, msg: x) -nonbatchable_p.def_abstract_eval(lambda x, *, msg: x) +nonbatchable_p.def_impl(lambda x, *, msg, allow_constant_across_batch: x) +nonbatchable_p.def_abstract_eval(lambda x, *, msg, allow_constant_across_batch: x) batching.primitive_batchers[nonbatchable_p] = _cannot_batch mlir.register_lowering( - nonbatchable_p, mlir.lower_fun(lambda x, *, msg: x, multiple_results=False) + nonbatchable_p, + mlir.lower_fun( + lambda x, *, msg, allow_constant_across_batch: x, multiple_results=False + ), ) def nonbatchable( - x: PyTree, *, name: Optional[str] = None, msg: Optional[str] = None + x: PyTree, + *, + name: Optional[str] = None, + msg: Optional[str] = None, + allow_constant_across_batch: bool = False, ) -> PyTree: """Identity function. Raises a trace-time assert if it is batched.""" dynamic, static = partition(x, is_array) @@ -149,7 +161,21 @@ def nonbatchable( if msg is None: if name is None: name = "This operation" - msg = f"Unexpected batch tracer. {name} cannot be vmap'd." - bind = ft.partial(nonbatchable_p.bind, msg=msg) + if allow_constant_across_batch: + msg = ( + f"Nonconstant batch. {name} has received a batch of values that were " + "expected to be constant. This is probably an internal error in the " + "library you are using." + ) + else: + msg = ( + f"Unexpected batch tracer. {name} cannot be vmap'd. This is probably " + "an internal error in the library you are using." + ) + bind = ft.partial( + nonbatchable_p.bind, + msg=msg, + allow_constant_across_batch=allow_constant_across_batch, + ) flat = map(bind, flat) return combine(jtu.tree_unflatten(treedef, flat), static)