Handling discontinuities in time derivative? #58

Closed
jaschau opened this issue Feb 9, 2022 · 12 comments · Fixed by #89

Comments

@jaschau

jaschau commented Feb 9, 2022

Hi,
first of all, let me say that this looks like an amazing project.
I am looking forward to playing around with this :).

In a concrete problem I am dealing with, I have a forced system where the external force is piecewise constant. The external force changes at specific time points (t1, ..., tn), causing a discontinuity in the time derivative.
I would like to use adaptive step-size solvers for increased accuracy, but naively applying adaptive step-size solvers will "waste" a lot of steps to find the point of change.

Would including the change points in SaveAt avoid this problem?
Or is there some other recommended way to handle this?

@jaschau
Author

jaschau commented Feb 9, 2022

This feels related to #13, but I cannot tell whether it is exactly the same problem.

@patrick-kidger
Owner

patrick-kidger commented Feb 9, 2022

No worries, glad you like it.

As long as the time points (t1, ..., tn) are known in advance, this is possible. Set diffeqsolve(..., stepsize_controller=PIDController(jump_ts=jnp.array([t1, t2, ..., tn]))).

(Note that SaveAt won't help with this problem at all; the differential equation solve is handled independently of the choice of output times.)

FWIW #13 is about handling the case in which the (t1, ..., tn) are not known until during the solve, because they depend on the evolving state somehow.

@jaschau
Author

jaschau commented Feb 9, 2022

Thanks for the clarification! Makes total sense now that you would need event handling for #13. Luckily, since it's an external control, I know the time points in advance and can readily use your solution.

If I may, I'd like to take the opportunity to ask for one more clarification.
How do you implement mini-batching w/ adaptive solvers? Are the steps taken the same for all entries in the batch and determined by the batch entry with the most rapid dynamics? Or are the steps taken different for each batch entry?

The reason why I'm asking is that in my case with external forcing, I'd expect the steps to depend on the external inputs as they heavily affect the dynamics. If the steps taken are the same for all entries in the batch, I'd expect that I might run into performance issues w/ batching.

@patrick-kidger
Owner

The steps are not the same for all batch elements; they're determined independently. So you should be all right!

@jaschau
Author

jaschau commented Feb 9, 2022

Great! Then I'm even more excited! :) Thanks for the prompt clarifications and congratulations on the great work!

@jaschau jaschau closed this as completed Feb 9, 2022
@jaschau
Author

jaschau commented Mar 30, 2022

Hi,
I have finally got around to giving this a try. But whenever I specify step_ts and jump_ts in the PID controller, I always encounter a RuntimeError: The maximum number of solver steps was reached. Try increasing max_steps.
I have managed to come up with a minimal example that demonstrates the issue on a harmonic oscillator with discontinuous external forcing. At t=7.5, the external forcing changes discontinuously. Without specifying step_ts and jump_ts, everything works fine. When I specify step_ts and jump_ts, I encounter the RuntimeError.
This feels related to #86, but I am not sure.

# %%
import os
import numpy as np
import matplotlib.pyplot as plt

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import optax  # https://github.com/deepmind/optax


# %% [markdown]
# # Pendulum without forcing

# %%
def ode_simple_pendulum(t, y, args):
    # first-order formulation of the damped simple pendulum
    # d^2x/dt^2 + 2 lambda omega dx/dt + omega^2 x = 0
    # note that we measure t in units of T = 2 pi/omega, i.e., the period of the
    # undamped pendulum in the small-angle approximation
    x, v = y
    lambd = args[0]
    # return vector of [dotx, dotv]
    return [v, - 4 * np.pi**2 * x - 4 * np.pi * lambd*v]

lambda_pendulum = [0.05]
# initial condition
y0 = [1.5, 0.0]

tspan = (0.0, 15.)
tgrid = jnp.arange(0., 15., step=0.02)

term = diffrax.ODETerm(ode_simple_pendulum)
solver = diffrax.Dopri5()
solution = diffrax.diffeqsolve(
    term, 
    solver,
    t0=tspan[0],
    t1=tspan[1],
    args=lambda_pendulum,
    dt0=0.02,
    saveat=diffrax.SaveAt(ts=tgrid),
    y0=y0,
    stepsize_controller=diffrax.PIDController(
        rtol=1e-4,
        atol=1e-6
    ),
)

# %% [markdown]
# ## Verify that solution looks reasonable

# %%
# first component corresponds to x(t), second component to x'(t)
plt.plot(tgrid, solution.ys[0])

# %% [markdown]
# # Pendulum with forcing (w/o discontinuity handling)

# %%
# time points of discontinuities in the forcing
t_force = jnp.array([0., 7.5])
# force values to apply until next discontinuity
force = jnp.array([10., -10.])
# constant interpolation
force_fct = lambda t: force[jnp.maximum(0, jnp.searchsorted(t_force, t) - 1)]
# outputs 10.
print(force_fct(0.0))
# outputs 10.
print(force_fct(7.3))
# outputs -10.
print(force_fct(7.51))


def ode_simple_pendulum_driven(t, y, args):
    # first-order formulation of the driven, damped simple pendulum
    # d^2x/dt^2 + 2 lambda omega dx/dt + omega^2 x = F(t), with F(t) = force_fct(t)
    # note that we measure t in units of T = 2 pi/omega, i.e., the period of the
    # undamped pendulum in the small-angle approximation
    x, v = y
    lambd = args[0]
    # return vector of [dotx, dotv]
    return [v, - 4 * np.pi**2 * x - 4 * np.pi * lambd*v + force_fct(t)]

term = diffrax.ODETerm(ode_simple_pendulum_driven)
solver = diffrax.Dopri5()
solution = diffrax.diffeqsolve(
    term, 
    solver,
    t0=tspan[0],
    t1=tspan[1],
    args=lambda_pendulum,
    dt0=0.02,
    saveat=diffrax.SaveAt(ts=tgrid),
    y0=y0,
    stepsize_controller=diffrax.PIDController(
        rtol=1e-4,
        atol=1e-6
    ),
)

# %% [markdown]
# ## Verify solution

# %%
plt.plot(tgrid, solution.ys[0])

# %% [markdown]
# # Pendulum with forcing (w/discontinuity handling)
# This fails with `ERROR:absl:Outside call <function _call.<locals>.<lambda> at 0x7f94402e4e50> threw exception The maximum number of solver steps was reached. Try increasing max_steps`.

# %%
term = diffrax.ODETerm(ode_simple_pendulum_driven)
solver = diffrax.Dopri5()
solution = diffrax.diffeqsolve(
    term, 
    solver,
    t0=tspan[0],
    t1=tspan[1],
    args=lambda_pendulum,
    dt0=0.02,
    saveat=diffrax.SaveAt(ts=tgrid),
    y0=y0,
    # with max_steps=int(1e6) this runs for ages and still does not finish
    #max_steps=int(1e6),
    stepsize_controller=diffrax.PIDController(
        rtol=1e-4,
        atol=1e-6,
        # this leads to issues
        step_ts=[7.5], 
        jump_ts=[7.5],
    ),
)

@jaschau jaschau reopened this Mar 30, 2022
@jaschau
Author

jaschau commented Mar 30, 2022

Just for completeness, this happens with

diffrax 0.0.6
jax 0.3.1
jaxlib 0.3.0+cuda11.cudnn82

@jaschau
Author

jaschau commented Mar 30, 2022

And one last remark: I encountered similar issues in DiffEqFlux.jl when specifying discontinuities in float32 precision and had to make sure I used float64. I don't know if it's the same root cause here.
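For context on the precision point, a small NumPy check of the gap between t = 7.5 and the next representable number (1 ulp) in each precision. In JAX, float64 is off by default and has to be enabled with jax.config.update("jax_enable_x64", True).

```python
import numpy as np

t_jump = 7.5
# Spacing to the next representable floating-point number (1 ulp) at t_jump
ulp32 = np.spacing(np.float32(t_jump))  # 2**-21, about 4.8e-7
ulp64 = np.spacing(np.float64(t_jump))  # 2**-50, about 8.9e-16
print(ulp32, ulp64)
```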

@patrick-kidger
Owner

This should be fixed in #89. This was a relatively subtle bug in how jump_ts was handled. Basically, specifying jump_ts would make the solver step to the floating-point number 1 ulp before the time specified in jump_ts (which is correct). After making a jump, the logic was then to begin the next step 1 ulp after the endpoint of the previous step... which placed it precisely on the jump. This then counted as being "before" the jump, and the solver got stuck.

The correct behaviour is to move forward two ulps after a jump, so that the next step starts just after the jump!
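The ulp arithmetic can be illustrated with numpy.nextafter in float32. This only illustrates the reasoning above and is not diffrax's implementation; the helper ulps_forward is hypothetical.

```python
import numpy as np


def ulps_forward(t, n):
    """Hypothetical helper: move t forward by n ulps in its own dtype."""
    up = t.dtype.type(np.inf)
    for _ in range(n):
        t = np.nextafter(t, up)
    return t


jump = np.float32(7.5)
step_end = np.nextafter(jump, np.float32(-np.inf))  # step ends 1 ulp before the jump

buggy_start = ulps_forward(step_end, 1)  # 1 ulp forward: lands exactly on the jump
fixed_start = ulps_forward(step_end, 2)  # 2 ulps forward: lands just after the jump

print(step_end < jump, buggy_start == jump, fixed_start > jump)  # True True True
```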

Running your code with and without specifying jump_ts, I now find that without specifying jump_ts the solver takes 233 steps; with jump_ts specified, the solver takes only 181 steps. Hurrah: the expected performance improvement.

By the way, specifying step_ts and jump_ts to be the same values is technically contradictory. The former is used to specify locations that the solver must step to. The latter is used to specify locations that the solver must step to either side of (and so avoid ever making a step that touches or crosses that one floating point number in particular).
The latter takes precedence so things don't break, but in this example you only need to specify jump_ts; step_ts does nothing.

@jaschau
Author

jaschau commented Mar 30, 2022

Thanks a lot for the swift response, the fix, and the explanation of the semantics of jump_ts and step_ts! Really appreciated! I was wondering about the difference between step_ts and jump_ts myself; I think it would be great to add the explanation above to the code as an additional comment.
I actually started with only specifying jump_ts. However, this triggers an Exception in

lib/python3.8/site-packages/diffrax/step_size_controller/adaptive.py in _clip_jump_ts(self, t0, t1)
    550                 f"{t1.dtype}."
    551             )
--> 552         t0_index = jnp.searchsorted(self.step_ts, t0)
    553         t1_index = jnp.searchsorted(self.step_ts, t1)
    554         cond = t0_index < t1_index

because step_ts is None.
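The failing lines use searchsorted to find a step_ts entry inside the proposed step. A simplified NumPy sketch of that clipping idea (hypothetical; the name clip_to_step_ts is made up and this is not diffrax's code) shows where a None check has to come before searchsorted is called.

```python
import numpy as np


def clip_to_step_ts(step_ts, t0, t1):
    """Hypothetical sketch: shorten a proposed step [t0, t1] so that it ends at
    the first step_ts entry lying strictly inside (t0, t1)."""
    if step_ts is None:
        return t1  # no step_ts given: nothing to clip (the missing guard)
    step_ts = np.asarray(step_ts)
    t0_index = np.searchsorted(step_ts, t0, side="right")  # first entry > t0
    t1_index = np.searchsorted(step_ts, t1, side="left")  # first entry >= t1
    if t0_index < t1_index:
        return step_ts[t0_index]
    return t1


print(clip_to_step_ts(None, 7.0, 8.0), clip_to_step_ts([7.5], 7.0, 8.0))
```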

@patrick-kidger
Owner

patrick-kidger commented Mar 30, 2022

Right, this was a separate issue that I've fixed along the way. step_ts and jump_ts should, fingers crossed, be reliable now.

@jaschau
Author

jaschau commented Mar 30, 2022

Cool, I will check it on my original problem as well once the changes have been merged into master and report back if I encounter any further problems. Again, thanks for the great support!
