Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for https://github.com/patrick-kidger/optimistix/issues/48 #694

Merged
merged 1 commit into from
Mar 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions equinox/internal/_loop/bounded.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def bounded_while_loop(
- `body_fun`: As `lax.while_loop`.
- `init_val`: As `lax.while_loop`.
- `max_steps`: A bound on the maximum number of steps, after which the loop
terminates unconditionally. Can be set to `None` for arbitrarily many steps.
terminates unconditionally.
- `buffers`: If passed, then every leaf of `tree_leaves(buffers(init_val))` must
be an array; all such arrays become buffers supporting only `[]` and
`.at[].set()`. However they will act efficiently, without spurious copies.
Expand All @@ -43,7 +43,6 @@ def bounded_while_loop(
**Returns:**

The final value; as `lax.while_loop`.

"""

if not isinstance(max_steps, int) or max_steps < 0:
Expand Down
12 changes: 11 additions & 1 deletion equinox/internal/_loop/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,17 @@ def new_cond_fun(val):
if type(max_steps) is not int:
raise ValueError("`max_steps` must be a Python integer")
out = out & (step < max_steps)
return nonbatchable(out)
# Need to allow being constant across the batch. This seems to arise in some
# edge case when using `bounded_while_loop`, in which:
# - `lax.scan` saves a copy of its state (in particular `step`) for use in the
# backward pass;
# - for some reason the `jax.checkpoint` causes such state to pick up a spurious
# batch tracer.
# See:
# https://github.com/patrick-kidger/optimistix/issues/48#issuecomment-2009221739
return nonbatchable(
out, name="`equinox.internal.while_loop`", allow_constant_across_batch=True
)

def new_body_fun(val):
tag = object()
Expand Down
42 changes: 34 additions & 8 deletions equinox/internal/_nontraceable.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import PyTree

Expand Down Expand Up @@ -122,34 +123,59 @@ def nondifferentiable_backward(
return combine(jtu.tree_unflatten(treedef, flat), static)


def _cannot_batch(x, b, *, msg):
def _cannot_batch(x, b, *, msg, allow_constant_across_batch):
(x,) = x
(b,) = b
if b is batching.not_mapped:
return x, b
else:
raise ValueError(msg)
if allow_constant_across_batch:
x = error_if(x, jnp.min(x) != jnp.max(x), msg)
return x[0], batching.not_mapped
else:
raise ValueError(msg)


nonbatchable_p = jax.core.Primitive("nonbatchable")
nonbatchable_p.def_impl(lambda x, *, msg: x)
nonbatchable_p.def_abstract_eval(lambda x, *, msg: x)
nonbatchable_p.def_impl(lambda x, *, msg, allow_constant_across_batch: x)
nonbatchable_p.def_abstract_eval(lambda x, *, msg, allow_constant_across_batch: x)
batching.primitive_batchers[nonbatchable_p] = _cannot_batch
mlir.register_lowering(
nonbatchable_p, mlir.lower_fun(lambda x, *, msg: x, multiple_results=False)
nonbatchable_p,
mlir.lower_fun(
lambda x, *, msg, allow_constant_across_batch: x, multiple_results=False
),
)


def nonbatchable(
x: PyTree, *, name: Optional[str] = None, msg: Optional[str] = None
x: PyTree,
*,
name: Optional[str] = None,
msg: Optional[str] = None,
allow_constant_across_batch: bool = False,
) -> PyTree:
"""Identity function. Raises a trace-time assert if it is batched."""
dynamic, static = partition(x, is_array)
flat, treedef = jtu.tree_flatten(dynamic)
if msg is None:
if name is None:
name = "This operation"
msg = f"Unexpected batch tracer. {name} cannot be vmap'd."
bind = ft.partial(nonbatchable_p.bind, msg=msg)
if allow_constant_across_batch:
msg = (
f"Nonconstant batch. {name} has received a batch of values that were "
"expected to be constant. This is probably an internal error in the "
"library you are using."
)
else:
msg = (
f"Unexpected batch tracer. {name} cannot be vmap'd. This is probably "
"an internal error in the library you are using."
)
bind = ft.partial(
nonbatchable_p.bind,
msg=msg,
allow_constant_across_batch=allow_constant_across_batch,
)
flat = map(bind, flat)
return combine(jtu.tree_unflatten(treedef, flat), static)
Loading