Skip to content

Commit

Permalink
patched edge-case in LinearInterpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
packquickly committed Sep 2, 2023
1 parent 737bf39 commit 685a8f3
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions diffrax/global_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,13 @@ def _index(_ys):
prev_t = self.ts[index]
next_t = self.ts[index + 1]
diff_t = next_t - prev_t

return (
prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t)
).ω
diff_nonzero = diff_t >= jnp.finfo(diff_t.dtype).eps
safe_diff = jnp.where(diff_nonzero, diff_t, jnp.ones_like(diff_t))
return jnp.where(
diff_nonzero,
(prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / safe_diff)).ω,
prev_ys
)

@eqx.filter_jit
def derivative(self, t: Scalar, left: bool = True) -> PyTree:
Expand Down

0 comments on commit 685a8f3

Please sign in to comment.