Skip to content

Commit 90bf627

Browse files
ArianAmanipatrick-kidger
authored andcommitted
Fix dtype comparison using 'is' instead of '==' in _integrate.py
- Replace 'is' with '==' for dtype comparisons on lines 265 and 388 - Fixes assertion errors in scenarios involving serialization/distributed computing - Ensures robust dtype comparison regardless of how dtype objects are created - Maintains backward compatibility and intended behavior Fixes #678
1 parent eb2c2b2 commit 90bf627

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

diffrax/_integrate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def _clip_to_end(tprev, tnext, t1, keep_step):
262262
# The tolerance means that we don't end up with too-small intervals for
263263
# dense output, which then gives numerically unstable answers due to floating
264264
# point errors.
265-
if tnext.dtype is jnp.dtype("float64"):
265+
if tnext.dtype == jnp.dtype("float64"):
266266
tol = 1e-10
267267
else:
268268
tol = 1e-6
@@ -385,7 +385,7 @@ def body_fun_aux(state):
385385
error_order,
386386
state.controller_state,
387387
)
388-
assert jnp.result_type(keep_step) is jnp.dtype(bool)
388+
assert jnp.result_type(keep_step) == jnp.dtype(bool)
389389

390390
#
391391
# Do some book-keeping.

0 commit comments

Comments
 (0)