-
-
Notifications
You must be signed in to change notification settings - Fork 177
TypeError when using diffrax with JAX v0.7.0 #662
Copy link
Copy link
Closed
Labels
question (User queries)
Description
Reproducible Example
Run the basic example from the Diffrax documentation:
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController
vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)
sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
stepsize_controller=stepsize_controller)
print(sol.ts) # DeviceArray([0. , 1. , 2. , 3. ])
print(sol.ys) # DeviceArray([1. , 0.368, 0.135, 0.0498])
It works when JAX is v0.6.2, but it fails when JAX is v0.7.0.
Details of the environment and the full error traceback are given below:
Environment
Python version: 3.12.3
Pip package list:
diffrax 0.7.0
equinox 0.12.2
jax 0.7.0.dev20250706+11fd61f3b /opt/jax
jax-cuda12-pjrt 0.7.0.dev20250706 /opt/jaxlibs/jax_cuda12_pjrt
jax-cuda12-plugin 0.7.0.dev20250706 /opt/jaxlibs/jax_cuda12_plugin
jaxlib 0.7.0.dev20250706 /opt/jaxlibs/jaxlib
jaxtyping 0.3.2
nsys-jax 0.1.dev1173+gefb11b7 /opt/nsys-jax/.github/container/nsys_jax
Errors
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
[... skipping hidden 1 frame]
File /opt/jax/jax/_src/util.py:298, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
297 return f(*args, **kwargs)
--> 298 return cached(config.trace_context() if trace_context_in_key else _ignore(),
299 *args, **kwargs)
TypeError: unhashable type: 'dict'
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
File /opt/jax/jax/_src/interpreters/partial_eval.py:1892, in _verify_params_are_hashable(primitive, params)
1891 try:
-> 1892 hash(v)
1893 except TypeError as e:
TypeError: unhashable type: 'dict'
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
Cell In[2], line 7
4 saveat = diffrax.SaveAt(ts=[0., 1., 2., 3.])
5 stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)
----> 7 sol = diffrax.diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
8 stepsize_controller=stepsize_controller)
10 print(sol.ts) # DeviceArray([0. , 1. , 2. , 3. ])
11 print(sol.ys) # DeviceArray([1. , 0.368, 0.135, 0.0498])
[... skipping hidden 18 frame]
File ~/.local/lib/python3.12/site-packages/diffrax/_integrate.py:1416, in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, event, max_steps, throw, progress_meter, solver_state, controller_state, made_jump, discrete_terminating_event)
1389 init_state = State(
1390 y=y0,
1391 tprev=tprev,
(...) 1409 event_mask=event_mask,
1410 )
1412 #
1413 # Main loop
1414 #
-> 1416 final_state, aux_stats = adjoint.loop(
1417 args=args,
1418 terms=terms,
1419 solver=solver,
1420 stepsize_controller=stepsize_controller,
1421 event=event,
1422 saveat=saveat,
1423 t0=t0,
1424 t1=t1,
1425 dt0=dt0,
1426 max_steps=max_steps,
1427 init_state=init_state,
1428 throw=throw,
1429 passed_solver_state=passed_solver_state,
1430 passed_controller_state=passed_controller_state,
1431 progress_meter=progress_meter,
1432 )
1434 #
1435 # Finish up
1436 #
1438 progress_meter.close(final_state.progress_meter_state)
[... skipping hidden 1 frame]
File ~/.local/lib/python3.12/site-packages/diffrax/_adjoint.py:299, in RecursiveCheckpointAdjoint.loop(***failed resolving arguments***)
295 outer_while_loop = ft.partial(
296 _outer_loop, kind="checkpointed", checkpoints=self.checkpoints
297 )
298 msg = None
--> 299 final_state = self._loop(
300 terms=terms,
301 saveat=saveat,
302 init_state=init_state,
303 max_steps=max_steps,
304 inner_while_loop=inner_while_loop,
305 outer_while_loop=outer_while_loop,
306 **kwargs,
307 )
308 if msg is not None:
309 final_state = eqxi.nondifferentiable_backward(
310 final_state, msg=msg, symbolic=True
311 )
File ~/.local/lib/python3.12/site-packages/diffrax/_integrate.py:619, in loop(solver, stepsize_controller, event, saveat, t0, t1, dt0, max_steps, terms, args, init_state, inner_while_loop, outer_while_loop, progress_meter)
617 static_made_jump = init_state.made_jump
618 static_result = init_state.result
--> 619 _, traced_jump, traced_result = eqx.filter_eval_shape(body_fun_aux, init_state)
620 if traced_jump:
621 static_made_jump = None
[... skipping hidden 16 frame]
File ~/.local/lib/python3.12/site-packages/diffrax/_integrate.py:349, in loop.<locals>.body_fun_aux(state)
342 state = _handle_static(state)
344 #
345 # Actually do some differential equation solving! Make numerical steps, adapt
346 # step sizes, all that jazz.
347 #
--> 349 (y, y_error, dense_info, solver_state, solver_result) = solver.step(
350 terms,
351 state.tprev,
352 state.tnext,
353 state.y,
354 args,
355 state.solver_state,
356 state.made_jump,
357 )
359 # e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
360 # we get a negative value for y, and then get a NaN vector field. (And then
361 # everything breaks.) See #143.
362 y_error = jtu.tree_map(lambda x: jnp.where(jnp.isnan(x), jnp.inf, x), y_error)
[... skipping hidden 1 frame]
File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:1149, in AbstractRungeKutta.step(***failed resolving arguments***)
1142 const_result = const_result_sentinel = object()
1143 # Needs to be an `eqxi.while_loop` as:
1144 # (a) we may have variable length: e.g. an FSAL explicit RK scheme will have one
1145 # more stage on the first step.
1146 # (b) to work around a limitation of JAX's autodiff being unable to express
1147 # "triangular computations" (every stage depends on all previous stages)
1148 # without spurious copies.
-> 1149 final_val = eqxi.while_loop(
1150 cond_stage,
1151 rk_stage,
1152 init_val,
1153 max_steps=num_stages,
1154 buffers=buffers,
1155 kind="checkpointed" if self.scan_kind is None else self.scan_kind,
1156 checkpoints=num_stages,
1157 base=num_stages,
1158 )
1159 _, y1, f1_for_fsal, _, _, fs, ks, result = final_val
1160 assert const_result is not const_result_sentinel
File ~/.local/lib/python3.12/site-packages/equinox/internal/_loop/loop.py:107, in while_loop(***failed resolving arguments***)
105 elif kind == "checkpointed":
106 del kind, base
--> 107 return checkpointed_while_loop(
108 cond_fun,
109 body_fun,
110 init_val,
111 max_steps=max_steps,
112 buffers=buffers,
113 checkpoints=checkpoints,
114 )
115 elif kind == "bounded":
116 del kind, checkpoints
File ~/.local/lib/python3.12/site-packages/equinox/internal/_loop/checkpointed.py:247, in checkpointed_while_loop(***failed resolving arguments***)
245 cond_fun_ = filter_closure_convert(cond_fun_, init_val_)
246 cond_fun_ = jtu.tree_map(_stop_gradient, cond_fun_)
--> 247 body_fun_ = filter_closure_convert(body_fun_, init_val_)
248 vjp_arg = (init_val_, body_fun_)
249 final_val_ = _checkpointed_while_loop(
250 vjp_arg, cond_fun_, checkpoints, buffers_, max_steps
251 )
[... skipping hidden 17 frame]
File ~/.local/lib/python3.12/site-packages/equinox/internal/_loop/common.py:471, in common_rewrite.<locals>.new_body_fun(val)
469 step, pred, _, val = val
470 buffer_val = _wrap_buffers(val, pred, tag)
--> 471 buffer_val2 = body_fun(buffer_val)
472 # Needed to work with `disable_jit`, as then we lose the automatic
473 # ArrayLike->Array cast provided by JAX's while loops.
474 # The input `val` is already cast to Array below, so this matches that.
475 buffer_val2 = jtu.tree_map(fixed_asarray, buffer_val2)
File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:1050, in AbstractRungeKutta.step.<locals>.rk_stage(val)
1048 assert ki is not _unused
1049 assert ks is not _unused
-> 1050 ks = ty_map(lambda x, xs: xs.at[stage_index].set(x), ki, ks)
1051 nonlocal const_result
1052 if const_result is const_result_sentinel:
File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:625, in AbstractRungeKutta.step.<locals>.ty_map(fn, *trees)
624 def ty_map(fn, *trees):
--> 625 return t_map(lambda *_trees: y_map(fn, *_trees), *trees)
File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:604, in AbstractRungeKutta.step.<locals>.t_map(fn, implicit_val, *trees)
601 else:
602 return fn(*_trees)
--> 604 return jtu.tree_map(_fn, tableaus, *trees)
[... skipping hidden 2 frame]
File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:602, in AbstractRungeKutta.step.<locals>.t_map.<locals>._fn(tableau, *_trees)
600 return implicit_val
601 else:
--> 602 return fn(*_trees)
File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:625, in AbstractRungeKutta.step.<locals>.ty_map.<locals>.<lambda>(*_trees)
624 def ty_map(fn, *trees):
--> 625 return t_map(lambda *_trees: y_map(fn, *_trees), *trees)
File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:611, in AbstractRungeKutta.step.<locals>.y_map(fn, *trees)
608 def _fn(_, *_trees):
609 return fn(*_trees)
--> 611 return jtu.tree_map(_fn, y0, *trees)
[... skipping hidden 2 frame]
File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:609, in AbstractRungeKutta.step.<locals>.y_map.<locals>._fn(_, *_trees)
608 def _fn(_, *_trees):
--> 609 return fn(*_trees)
File ~/.local/lib/python3.12/site-packages/diffrax/_solver/runge_kutta.py:1050, in AbstractRungeKutta.step.<locals>.rk_stage.<locals>.<lambda>(x, xs)
1048 assert ki is not _unused
1049 assert ks is not _unused
-> 1050 ks = ty_map(lambda x, xs: xs.at[stage_index].set(x), ki, ks)
1051 nonlocal const_result
1052 if const_result is const_result_sentinel:
[... skipping hidden 1 frame]
File ~/.local/lib/python3.12/site-packages/equinox/internal/_loop/common.py:343, in _BufferItem.set(self, x, pred, **kwargs)
341 else:
342 makes_false_steps = True
--> 343 return self._buffer._op(
344 pred, self._item, x, _maybe_set, kwargs, makes_false_steps
345 )
[... skipping hidden 1 frame]
File ~/.local/lib/python3.12/site-packages/equinox/internal/_loop/common.py:298, in _Buffer._op(self, pred, item, x, op, kwargs, makes_false_steps)
296 array = self._array._op(pred, item, x, op, kwargs, makes_false_steps)
297 else:
--> 298 array = op(
299 pred,
300 self._array,
301 x,
302 item,
303 kwargs=kwargs,
304 makes_false_steps=makes_false_steps,
305 )
306 return _Buffer(array, self._pred, self._tag, self._makes_false_steps)
File ~/.local/lib/python3.12/site-packages/equinox/internal/_loop/common.py:248, in _maybe_set(pred, xs, x, i, kwargs, makes_false_steps)
246 i_dynamic, i_static = partition(i, is_array)
247 i_dynamic_leaves, i_treedef = jtu.tree_flatten(i_dynamic)
--> 248 [out] = maybe_set_p.bind(
249 pred,
250 xs,
251 x,
252 *i_dynamic_leaves,
253 i_static=i_static,
254 i_treedef=i_treedef,
255 kwargs=kwargs,
256 makes_false_steps=makes_false_steps,
257 )
258 return out
[... skipping hidden 5 frame]
File /opt/jax/jax/_src/interpreters/partial_eval.py:1894, in _verify_params_are_hashable(primitive, params)
1892 hash(v)
1893 except TypeError as e:
-> 1894 raise TypeError(
1895 "As of JAX v0.7, parameters to jaxpr equations must have __hash__ and "
1896 f"__eq__ methods. In a call to primitive {primitive}, the value of "
1897 f"parameter {k} was not hashable: {v}") from e
TypeError: As of JAX v0.7, parameters to jaxpr equations must have __hash__ and __eq__ methods. In a call to primitive maybe_set, the value of parameter kwargs was not hashable: {}
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
question (User queries)