Hi @patrick-kidger, it seems that the latest jax 0.4.34 has introduced a new bug.
When differentiating through `diffeqsolve` with `jax.grad`, I get the error:
TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
Here is a MWE:
and the full stacktrace:
TypeError Traceback (most recent call last)
File Untitled-6:12
      9 term = dx.ODETerm(lambda t, y, _: -a * y)
     10 return dx.diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0).ys[0, -1]
---> 12 jax.grad(tograd)(1.0)
[... skipping hidden 10 frame]
File Untitled-6:10
      8 def tograd(a):
      9 term = dx.ODETerm(lambda t, y, _: -a * y)
---> 10 return dx.diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0).ys[0, -1]
[... skipping hidden 27 frame]
File ~/miniconda3/lib/python3.11/site-packages/equinox/internal/_loop/checkpointed.py:1272, in _stop_gradient_on_unperturbed_jvp(***failed resolving arguments***)
   1268 del primals, tangents
   1269 perturb_val, perturb_body_fun = jtu.tree_map(
   1270 lambda _, t: t is not None, (init_val, body_fun), (t_init_val, t_body_fun)
   1271 )
-> 1272 perturb_val = _resolve_perturb_val(
   1273 init_val, body_fun, perturb_val, perturb_body_fun
   1274 )
   1275 t_final_val = jtu.tree_map(
   1276 _perturb_to_tang, t_final_val, perturb_val, is_leaf=_is_none
   1277 )
   1278 return final_val, t_final_val
File ~/miniconda3/lib/python3.11/site-packages/equinox/internal/_loop/checkpointed.py:1241, in _resolve_perturb_val(final_val, body_fun, perturb_final_val, perturb_body_fun)
   1238 else:
   1239 perturb_val = jtu.tree_map(operator.or_, perturb_val, new_perturb_val)
-> 1241 perturb_val = jax.eval_shape(_resolve_perturb_val_impl).value
   1242 return perturb_val
[... skipping hidden 12 frame]
File ~/miniconda3/lib/python3.11/site-packages/equinox/internal/_loop/checkpointed.py:1214, in _resolve_perturb_val.<locals>._resolve_perturb_val_impl()
   1211 return _out
   1213 # Not `jax.jvp`, so as not to error if `body_fun` has any `custom_vjp`s.
-> 1214 jax.linearize(_to_linearize, dynamic)
   1215 if new_perturb_val is sentinel:
   1216 # `_dynamic_out` in `_to_linearize` had no JVP tracers at all, despite
   1217 # `_dynamic` having them. Presumably the user's `_body_fun` has no
   1218 # differentiable dependency whatsoever.
   1219 # This can happen if all the autograd is happening through
   1220 # `perturb_body_fun`.
   1221 return Static(perturb_val)
[... skipping hidden 5 frame]
File ~/miniconda3/lib/python3.11/site-packages/equinox/internal/_loop/checkpointed.py:1207, in _resolve_perturb_val.<locals>._resolve_perturb_val_impl.<locals>._to_linearize(_dynamic)
   1205 def _to_linearize(_dynamic):
   1206 _body_fun, _val = combine(_dynamic, static)
-> 1207 _out = _body_fun(_val)
   1208 _dynamic_out, _static_out = partition(_out, is_inexact_array)
   1209 _dynamic_out = _record_symbolic_zeros(_dynamic_out)
[... skipping hidden 10 frame]
File ~/miniconda3/lib/python3.11/site-packages/jax/_src/custom_derivatives.py:351, in _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args)
    344 msg = ("Custom JVP rule must produce primal and tangent outputs with "
    345 "corresponding shapes and dtypes, but got:\n{}")
    346 disagreements = (
    347 f" primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}"
    348 for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out)
    349 if av_et != av_t)
--> 351 raise TypeError(msg.format('\n'.join(disagreements)))
    352 yield primals_out + tangents_out, (out_tree, primal_avals)
TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])