Skip to content

TypeError when using diffrax with JAX v0.7.0 #662

@pennbay

Description

@pennbay

Reproducible Example

Run the basic example from the Diffrax documentation:

from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController

vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
                  stepsize_controller=stepsize_controller)

print(sol.ts)  # DeviceArray([0.   , 1.   , 2.   , 3.    ])
print(sol.ys)  # DeviceArray([1.   , 0.368, 0.135, 0.0498])

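It works with JAX v0.6.2 but fails with JAX v0.7.0 (here the nightly build `0.7.0.dev20250706`).
The error is raised from Equinox rather than from Diffrax itself. `equinox.internal._loop.common._maybe_set` binds the `maybe_set` primitive with `kwargs=kwargs`, which is a plain `dict` (here `{}`). As of v0.7, JAX rejects primitive parameters that are not hashable. The full traceback follows: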

Environment

Python version: 3.12.3

Pip package list:

diffrax                       0.7.0
equinox                       0.12.2
jax                           0.7.0.dev20250706+11fd61f3b   /opt/jax
jax-cuda12-pjrt               0.7.0.dev20250706             /opt/jaxlibs/jax_cuda12_pjrt
jax-cuda12-plugin             0.7.0.dev20250706             /opt/jaxlibs/jax_cuda12_plugin
jaxlib                        0.7.0.dev20250706             /opt/jaxlibs/jaxlib
jaxtyping                     0.3.2
nsys-jax                      0.1.dev1173+gefb11b7          /opt/nsys-jax/.github/container/nsys_jax

Errors

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
    [... skipping hidden 1 frame]

File /opt/jax/jax/_src/util.py:298, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
    297   return f(*args, **kwargs)
--> 298 return cached(config.trace_context() if trace_context_in_key else _ignore(),
    299               *args, **kwargs)

TypeError: unhashable type: 'dict'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
File /opt/jax/jax/_src/interpreters/partial_eval.py:1892, in _verify_params_are_hashable(primitive, params)
   1891 try:
-> 1892   hash(v)
   1893 except TypeError as e:

TypeError: unhashable type: 'dict'

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Cell In[2], line 7
      4 saveat = diffrax.SaveAt(ts=[0., 1., 2., 3.])
      5 stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)
----> 7 sol = diffrax.diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
      8                   stepsize_controller=stepsize_controller)
     10 print(sol.ts)  # DeviceArray([0.   , 1.   , 2.   , 3.    ])
     11 print(sol.ys)  # DeviceArray([1.   , 0.368, 0.135, 0.0498])

    [... skipping hidden 18 frame]

File ~/.local/lib/python3.12/site-packages/diffrax/_integrate.py:1416, in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, event, max_steps, throw, progress_meter, solver_state, controller_state, made_jump, discrete_terminating_event)
   1389 init_state = State(
   1390     y=y0,
   1391     tprev=tprev,
   (...)   1409     event_mask=event_mask,
   1410 )
   1412 #
   1413 # Main loop
   1414 #
-> 1416 final_state, aux_stats = adjoint.loop(
   1417     args=args,
   1418     terms=terms,
   1419     solver=solver,
   1420     stepsize_controller=stepsize_controller,
   1421     event=event,
   1422     saveat=saveat,
   1423     t0=t0,
   1424     t1=t1,
   1425     dt0=dt0,
   1426     max_steps=max_steps,
   1427     init_state=init_state,
   1428     throw=throw,
   1429     passed_solver_state=passed_solver_state,
   1430     passed_controller_state=passed_controller_state,
   1431     progress_meter=progress_meter,
   1432 )
   1434 #
   1435 # Finish up
   1436 #
   1438 progress_meter.close(final_state.progress_meter_state)

    [... skipping hidden 1 frame]

File ~/.local/lib/python3.12/site-packages/diffrax/_adjoint.py:299, in RecursiveCheckpointAdjoint.loop(***failed resolving arguments***)
    295     outer_while_loop = ft.partial(
    296         _outer_loop, kind="checkpointed", checkpoints=self.checkpoints
    297     )
    298     msg = None
--> 299 final_state = self._loop(
    300     terms=terms,
    301     saveat=saveat,
    302     init_state=init_state,
    303     max_steps=max_steps,
    304     inner_while_loop=inner_while_loop,
    305     outer_while_loop=outer_while_loop,
    306     **kwargs,
    307 )
    308 if msg is not None:
    309     final_state = eqxi.nondifferentiable_backward(
    310         final_state, msg=msg, symbolic=True
    311     )

File ~/.local/lib/python3.12/site-packages/diffrax/_integrate.py:619, in loop(solver, stepsize_controller, event, saveat, t0, t1, dt0, max_steps, terms, args, init_state, inner_while_loop, outer_while_loop, progress_meter)
    617 static_made_jump = init_state.made_jump
    618 static_result = init_state.result
--> 619 _, traced_jump, traced_result = eqx.filter_eval_shape(body_fun_aux, init_state)
    620 if traced_jump:
    621     static_made_jump = None

    [... skipping hidden 16 frame]

File ~/.local/lib/python3.12/site-packages/diffrax/_integrate.py:349, in loop.<locals>.body_fun_aux(state)
    342 state = _handle_static(state)
    344 #
    345 # Actually do some differential equation solving! Make numerical steps, adapt
    346 # step sizes, all that jazz.
    347 #
--> 349 (y, y_error, dense_info, solver_state, solver_result) = solver.step(
    350     terms,
    351     state.tprev,
    352     state.tnext,
    353     state.y,
    354     args,
    355     state.solver_state,
    356     state.made_jump,
    357 )
    359 # e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
    360 # we get a negative value for y, and then get a NaN vector field. (And then
    361 # everything breaks.) See #143.
    362 y_error = jtu.tree_map(lambda x: jnp.where(jnp.isnan(x), jnp.inf, x), y_error)

    [... skipping hidden 1 frame]

File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:1149, in AbstractRungeKutta.step(***failed resolving arguments***)
   1142 const_result = const_result_sentinel = object()
   1143 # Needs to be an `eqxi.while_loop` as:
   1144 # (a) we may have variable length: e.g. an FSAL explicit RK scheme will have one
   1145 #     more stage on the first step.
   1146 # (b) to work around a limitation of JAX's autodiff being unable to express
   1147 #     "triangular computations" (every stage depends on all previous stages)
   1148 #     without spurious copies.
-> 1149 final_val = eqxi.while_loop(
   1150     cond_stage,
   1151     rk_stage,
   1152     init_val,
   1153     max_steps=num_stages,
   1154     buffers=buffers,
   1155     kind="checkpointed" if self.scan_kind is None else self.scan_kind,
   1156     checkpoints=num_stages,
   1157     base=num_stages,
   1158 )
   1159 _, y1, f1_for_fsal, _, _, fs, ks, result = final_val
   1160 assert const_result is not const_result_sentinel

File ~/.local/lib/python3.12/site-packages/equinox/internal/_loop/loop.py:107, in while_loop(***failed resolving arguments***)
    105 elif kind == "checkpointed":
    106     del kind, base
--> 107     return checkpointed_while_loop(
    108         cond_fun,
    109         body_fun,
    110         init_val,
    111         max_steps=max_steps,
    112         buffers=buffers,
    113         checkpoints=checkpoints,
    114     )
    115 elif kind == "bounded":
    116     del kind, checkpoints

File ~/.local/lib/python3.12/site-packages/equinox/internal/_loop/checkpointed.py:247, in checkpointed_while_loop(***failed resolving arguments***)
    245 cond_fun_ = filter_closure_convert(cond_fun_, init_val_)
    246 cond_fun_ = jtu.tree_map(_stop_gradient, cond_fun_)
--> 247 body_fun_ = filter_closure_convert(body_fun_, init_val_)
    248 vjp_arg = (init_val_, body_fun_)
    249 final_val_ = _checkpointed_while_loop(
    250     vjp_arg, cond_fun_, checkpoints, buffers_, max_steps
    251 )

    [... skipping hidden 17 frame]

File ~/.local/lib/python3.12/site-packages/equinox/internal/_loop/common.py:471, in common_rewrite.<locals>.new_body_fun(val)
    469 step, pred, _, val = val
    470 buffer_val = _wrap_buffers(val, pred, tag)
--> 471 buffer_val2 = body_fun(buffer_val)
    472 # Needed to work with `disable_jit`, as then we lose the automatic
    473 # ArrayLike->Array cast provided by JAX's while loops.
    474 # The input `val` is already cast to Array below, so this matches that.
    475 buffer_val2 = jtu.tree_map(fixed_asarray, buffer_val2)

File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:1050, in AbstractRungeKutta.step.<locals>.rk_stage(val)
   1048     assert ki is not _unused
   1049     assert ks is not _unused
-> 1050     ks = ty_map(lambda x, xs: xs.at[stage_index].set(x), ki, ks)
   1051 nonlocal const_result
   1052 if const_result is const_result_sentinel:

File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:625, in AbstractRungeKutta.step.<locals>.ty_map(fn, *trees)
    624 def ty_map(fn, *trees):
--> 625     return t_map(lambda *_trees: y_map(fn, *_trees), *trees)

File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:604, in AbstractRungeKutta.step.<locals>.t_map(fn, implicit_val, *trees)
    601     else:
    602         return fn(*_trees)
--> 604 return jtu.tree_map(_fn, tableaus, *trees)

    [... skipping hidden 2 frame]

File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:602, in AbstractRungeKutta.step.<locals>.t_map.<locals>._fn(tableau, *_trees)
    600     return implicit_val
    601 else:
--> 602     return fn(*_trees)

File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:625, in AbstractRungeKutta.step.<locals>.ty_map.<locals>.<lambda>(*_trees)
    624 def ty_map(fn, *trees):
--> 625     return t_map(lambda *_trees: y_map(fn, *_trees), *trees)

File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:611, in AbstractRungeKutta.step.<locals>.y_map(fn, *trees)
    608 def _fn(_, *_trees):
    609     return fn(*_trees)
--> 611 return jtu.tree_map(_fn, y0, *trees)

    [... skipping hidden 2 frame]

File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:609, in AbstractRungeKutta.step.<locals>.y_map.<locals>._fn(_, *_trees)
    608 def _fn(_, *_trees):
--> 609     return fn(*_trees)

File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:1050, in AbstractRungeKutta.step.<locals>.rk_stage.<locals>.<lambda>(x, xs)
   1048     assert ki is not _unused
   1049     assert ks is not _unused
-> 1050     ks = ty_map(lambda x, xs: xs.at[stage_index].set(x), ki, ks)
   1051 nonlocal const_result
   1052 if const_result is const_result_sentinel:

    [... skipping hidden 1 frame]

File ~/.local/lib/python3.12/site-packages/equinox/internal/_loop/common.py:343, in _BufferItem.set(self, x, pred, **kwargs)
    341 else:
    342     makes_false_steps = True
--> 343 return self._buffer._op(
    344     pred, self._item, x, _maybe_set, kwargs, makes_false_steps
    345 )

    [... skipping hidden 1 frame]

File ~/.local/lib/python3.12/site-packages/equinox/internal/_loop/common.py:298, in _Buffer._op(self, pred, item, x, op, kwargs, makes_false_steps)
    296     array = self._array._op(pred, item, x, op, kwargs, makes_false_steps)
    297 else:
--> 298     array = op(
    299         pred,
    300         self._array,
    301         x,
    302         item,
    303         kwargs=kwargs,
    304         makes_false_steps=makes_false_steps,
    305     )
    306 return _Buffer(array, self._pred, self._tag, self._makes_false_steps)

File ~/.local/lib/python3.12/site-packages/equinox/internal/_loop/common.py:248, in _maybe_set(pred, xs, x, i, kwargs, makes_false_steps)
    246 i_dynamic, i_static = partition(i, is_array)
    247 i_dynamic_leaves, i_treedef = jtu.tree_flatten(i_dynamic)
--> 248 [out] = maybe_set_p.bind(
    249     pred,
    250     xs,
    251     x,
    252     *i_dynamic_leaves,
    253     i_static=i_static,
    254     i_treedef=i_treedef,
    255     kwargs=kwargs,
    256     makes_false_steps=makes_false_steps,
    257 )
    258 return out

    [... skipping hidden 5 frame]

File /opt/jax/jax/_src/interpreters/partial_eval.py:1894, in _verify_params_are_hashable(primitive, params)
   1892   hash(v)
   1893 except TypeError as e:
-> 1894   raise TypeError(
   1895     "As of JAX v0.7, parameters to jaxpr equations must have __hash__ and "
   1896     f"__eq__ methods. In a call to primitive {primitive}, the value of "
   1897     f"parameter {k} was not hashable: {v}") from e

TypeError: As of JAX v0.7, parameters to jaxpr equations must have __hash__ and __eq__ methods. In a call to primitive maybe_set, the value of parameter kwargs was not hashable: {}

Metadata

Metadata

Assignees

No one assigned

    Labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions