Differentiating through diffeqsolve under jax 0.4.34 raises TypeError: Custom JVP rule ... #508

@gautierronan

Hi @patrick-kidger, it seems that the latest JAX release (0.4.34) has introduced a new bug.

When differentiating through diffeqsolve with jax.grad, I get the error:

TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
  primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
  primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])

Here is an MWE:

import diffrax as dx
import jax.numpy as jnp
import jax

solver = dx.Dopri5()
y0 = jnp.array([2., 3.])

def tograd(a):
    term = dx.ODETerm(lambda t, y, _: -a * y)
    return dx.diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0).ys[0, -1]

jax.grad(tograd)(1.0)

and the full stacktrace:

TypeError                                 Traceback (most recent call last)
File Untitled-6:12
      9     term = dx.ODETerm(lambda t, y, _: -a * y)
     10     return dx.diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0).ys[0, -1]
---> 12 jax.grad(tograd)(1.0)

    [... skipping hidden 10 frame]

File Untitled-6:10
      8 def tograd(a):
      9     term = dx.ODETerm(lambda t, y, _: -a * y)
---> 10     return dx.diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0).ys[0, -1]

    [... skipping hidden 27 frame]

File ~/miniconda3/lib/python3.11/site-packages/equinox/internal/_loop/checkpointed.py:1272, in _stop_gradient_on_unperturbed_jvp(***failed resolving arguments***)
   1268 del primals, tangents
   1269 perturb_val, perturb_body_fun = jtu.tree_map(
   1270     lambda _, t: t is not None, (init_val, body_fun), (t_init_val, t_body_fun)
   1271 )
-> 1272 perturb_val = _resolve_perturb_val(
   1273     init_val, body_fun, perturb_val, perturb_body_fun
   1274 )
   1275 t_final_val = jtu.tree_map(
   1276     _perturb_to_tang, t_final_val, perturb_val, is_leaf=_is_none
   1277 )
   1278 return final_val, t_final_val

File ~/miniconda3/lib/python3.11/site-packages/equinox/internal/_loop/checkpointed.py:1241, in _resolve_perturb_val(final_val, body_fun, perturb_final_val, perturb_body_fun)
   1238         else:
   1239             perturb_val = jtu.tree_map(operator.or_, perturb_val, new_perturb_val)
-> 1241 perturb_val = jax.eval_shape(_resolve_perturb_val_impl).value
   1242 return perturb_val

    [... skipping hidden 12 frame]

File ~/miniconda3/lib/python3.11/site-packages/equinox/internal/_loop/checkpointed.py:1214, in _resolve_perturb_val.<locals>._resolve_perturb_val_impl()
   1211     return _out
   1213 # Not `jax.jvp`, so as not to error if `body_fun` has any `custom_vjp`s.
-> 1214 jax.linearize(_to_linearize, dynamic)
   1215 if new_perturb_val is sentinel:
   1216     # `_dynamic_out` in `_to_linearize` had no JVP tracers at all, despite
   1217     # `_dynamic` having them. Presumably the user's `_body_fun` has no
   1218     # differentiable dependency whatsoever.
   1219     # This can happen if all the autograd is happening through
   1220     # `perturb_body_fun`.
   1221     return Static(perturb_val)

    [... skipping hidden 5 frame]

File ~/miniconda3/lib/python3.11/site-packages/equinox/internal/_loop/checkpointed.py:1207, in _resolve_perturb_val.<locals>._resolve_perturb_val_impl.<locals>._to_linearize(_dynamic)
   1205 def _to_linearize(_dynamic):
   1206     _body_fun, _val = combine(_dynamic, static)
-> 1207     _out = _body_fun(_val)
   1208     _dynamic_out, _static_out = partition(_out, is_inexact_array)
   1209     _dynamic_out = _record_symbolic_zeros(_dynamic_out)

    [... skipping hidden 10 frame]

File ~/miniconda3/lib/python3.11/site-packages/jax/_src/custom_derivatives.py:351, in _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args)
    344     msg = ("Custom JVP rule must produce primal and tangent outputs with "
    345            "corresponding shapes and dtypes, but got:\n{}")
    346     disagreements = (
    347         f"  primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}"
    348         for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out)
    349         if av_et != av_t)
--> 351     raise TypeError(msg.format('\n'.join(disagreements)))
    352 yield primals_out + tangents_out, (out_tree, primal_avals)

TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
  primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
  primal bool[] with tangent bool[], expecting tangent ShapedArray(float0[])
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
  primal int32[] with tangent int32[], expecting tangent ShapedArray(float0[])
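
For context, here is a minimal sketch of the `float0` convention that the error message refers to: JAX gives integer and boolean outputs a tangent of dtype `float0` rather than a tangent of the same dtype. The function name `float_and_int_outputs` is hypothetical and only for illustration; this is not the code path inside equinox, and it does not reproduce the bug itself.

```python
import jax
import jax.numpy as jnp
import numpy as np


def float_and_int_outputs(x):
    # Hypothetical function: one differentiable float output and one
    # integer output (think of a step counter carried through a loop).
    return 2.0 * x, jnp.asarray(3, dtype=jnp.int32)


# Forward-mode differentiation at x = 1.0 with input tangent 1.0.
primals_out, tangents_out = jax.jvp(float_and_int_outputs, (1.0,), (1.0,))
float_tangent, int_tangent = tangents_out

print(primals_out[1].dtype)  # dtype of the integer primal output
print(int_tangent.dtype)     # dtype of the tangent JAX pairs with it

# The tangent of the integer output is a float0 zero, not an int32 array.
is_float0 = int_tangent.dtype == jax.dtypes.float0

# A float0 zero tangent can also be built by hand with NumPy.
manual_zero = np.zeros((), dtype=jax.dtypes.float0)
```

The error above presumably means that a custom JVP rule somewhere in the loop is returning `int32[]` and `bool[]` tangents for its non-differentiable outputs, where JAX 0.4.34 expects `float0` tangents like `manual_zero`.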
