
[WIP] Delay differential equations #164

Closed
wants to merge 1 commit into from

Conversation

patrick-kidger
Owner

@patrick-kidger patrick-kidger commented Sep 24, 2022

@thibmonsel

This is a quick WIP draft of how we might add support for delay diffeqs into Diffrax.

The goal is to make the API follow:

def vector_field(t, y, args, *, history):
    ...

delays = [lambda t, y, args: 1.0,
          lambda t, y, args: max(y, 1)]

diffeqsolve(ODETerm(vector_field), ..., delays=delays)

There are several pieces that still need doing:

  • The nonlinear solve, with respect to the dense solution over each step. (E.g. as per Section 4.1 of the DelayDiffEq.jl paper; see the sketch after this list.)
  • Detecting discontinuities and stepping to them directly. (Section 4.2)
  • Possibly add special support for "nice" delays, that we might be able to handle more efficiently? E.g. as long as our minimal delay is larger than our step size then the nonlinear solve can be skipped.
  • Adding documentation.
  • Adding an example.
  • Probably now would be a good time to figure out how to add support for solving DAEs as well (e.g. see [Feature request] differential algebraic equations #62). Both involve a nonlinear solve, and both involve passing extra information to the user-provided vector field. It might be that we can make use of the same mechanisms for both. (And at the very least we should ensure that any choices we make now don't negatively impact DAE support later.)
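For reference, the nonlinear solve in the first bullet amounts to a fixed-point iteration whenever a delay is shorter than the step, since the step then depends on its own dense output. A minimal sketch, with hypothetical helpers make_interpolant, solver_step and converged (none of these are Diffrax APIs):

y_candidate = y_prev
for _ in range(max_iters):
    # Build an interpolant over [t_prev, t_next] from the candidate endpoint...
    history = make_interpolant(t_prev, t_next, y_prev, y_candidate)
    # ...and re-take the step using it as the history; iterate to convergence.
    y_new = solver_step(terms, t_prev, t_next, y_prev, history=history)
    if converged(y_new, y_candidate):
        break
    y_candidate = y_new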

@thibmonsel
Contributor

thibmonsel commented Sep 27, 2022

@patrick-kidger These are notes, and also a way for me to make sure that I don't misunderstand certain sections of the code.

In integrate.py:

  • The loop function takes successive steps in order to solve the differential equation.
  • body_fun inside loop performs one integration step and returns an updated State object.
  • Inside body_fun the step size is adapted; should the controller be the one to handle discontinuities?
  • Regarding discontinuities we might have some issues: in the general case we must not integrate past the discontinuities, so pre-allocating memory for saveat.ts and saveat.ys cannot be exact. To cope with this, allocate a larger array for saveat.ts and saveat.ys and at the end remove/slice all of the NaNs?
  • The diffeqsolve function initializes the stepsize_controller and solver_state, allocates memory for ts and ys and the first state init_state, then calls the loop function to integrate and returns the solution.

@patrick-kidger
Owner Author

I don't think I understand your penultimate point. Can you expand on that please?

Everything else is correct.

@thibmonsel
Contributor

  • In the general case of a DDE: $y'(t) = f(t, y(t), y(t-\tau_1), y(t-\tau_2), \dots, y(t-\tau_n))$ where $\forall i,\ \tau_i = \tau_i(t, y(t))$.
  • During an integration step (tprev to tnext) we don't know a priori whether we will hit a discontinuity. There is a discontinuity if the following function vanishes: $g_{i,r}(t) = t - \tau_i - \lambda_r$, where the $\lambda_r$ are known previous discontinuities. If we need to save the values of y at ts, but we don't save the values of y at discontinuities, we will get incorrect estimates of y further down the integration. Hence the memory-allocation issue for ys. (A sketch of this check follows below.)
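For illustration, checking whether some $g_{i,r}$ vanishes inside a step might look like this (a minimal runnable sketch; the delay is constant and all names are illustrative):

import jax.numpy as jnp

# A new discontinuity exists inside [tprev, tnext] if
# g(t) = t - tau_i(t, y(t)) - lambda_r changes sign over the step.
def g(t, y, args, delay, known_disc):
    return t - delay(t, y, args) - known_disc

delay = lambda t, y, args: 1.0  # constant delay, so y/args are unused here
tprev, tnext, known_disc = 0.9, 1.1, 0.0
sign_change = jnp.sign(g(tprev, None, None, delay, known_disc)) != jnp.sign(
    g(tnext, None, None, delay, known_disc)
)  # True: the propagated discontinuity at t = 1.0 lies inside the step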

@patrick-kidger
Owner Author

patrick-kidger commented Sep 28, 2022

So I don't think there should be any memory allocation issues. (Indeed I don't think the changes we make here will need to touch the saving-to-output code at all.)

Diffrax already has a max_steps parameter, and (if saving dense output) we automatically preallocate a buffer that is max_steps long and save our dense interpolation into it as we go along. If we exceed that number of steps an error is thrown. And the existing interface to the dense interpolation already knows how to trim off any excess space in the buffer, when reading from it.

From our point of view now: I think all we need to do is wrap the step size controller, so that if a discontinuity is detected then the endpoint of the next step is clipped to the discontinuity. Also we should set made_jump=True. (This is used to tell FSAL solvers that in this case F != L, and to tell the stepsize controller to perturb the start of the next step to just after the discontinuity. I can expand more on these points if you're curious.)
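As a rough sketch of the wrapping idea (illustrative only; t_disc stands for a discontinuity located by root-finding, and is not an actual Diffrax variable):

import jax.numpy as jnp

def clip_step(tprev, tnext, t_disc):
    # If a discontinuity falls strictly inside the proposed step, end the step
    # exactly on it and flag the jump for the solver and controller.
    disc_inside = (t_disc > tprev) & (t_disc < tnext)
    tnext = jnp.where(disc_inside, t_disc, tnext)
    made_jump = disc_inside
    return tnext, made_jump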

Does that make sense?

@thibmonsel
Contributor

That makes sense for max_steps! What does made_jump=True do? Is it just telling the solver not to use FSAL for the next step? I'm not sure I understand the stepsize controller perturbation after the discontinuity.

Regarding discontinuities, the controller will also need to return a new variable I call step_back that tells us whether we need to restart the integration at the previous step. This new bool step_back just makes sure that the discontinuity is not "too close" to t0; if it were, numerical integration would have problems. Hence we also update state.dense_save_index and the other arguments. The _State will also need another argument ys_history that appends the corresponding "historical data" at each step, in order to properly use the:

history = DenseInterpolation(
    ts=state.dense_ts,
    ts_size=state.dense_save_index + 1,
    interpolation_cls=solver.interpolation_cls,
    infos=state.ys_history,
    direction=1,
)

With that being said, I think I'll need to modify the book-keeping in integrate.py (line 272 and so on).
Let me know if I need to elaborate 🚀

@patrick-kidger
Owner Author

I don't think either a step_back or a ys_history is needed.

Regarding the former: indeed, after a discontinuity we should place t0 at the floating-point number that comes after the discontinuity. This is one of the things that should be tied to made_jump. This currently happens for the PIDController here:

_t1 = jnp.where(made_jump, nextafter(nextafter(t1)), t1)

But we should hoist this out to happen in the main integrate.py. (So that it happens regardless of choice of stepsize controller.)

For ys_history: this is the same thing as is already stored in dense_infos.

@thibmonsel
Contributor

Ok, so this gives t1 a little push if it's on a discontinuity, in order to continue the integration?

_t1 = jnp.where(made_jump, nextafter(nextafter(t1)), t1)

I'm not sure I see how not having a step_back would ensure that we don't integrate from t0 to t1_discontinuous when the interval length is small enough to give bad estimates of y.

And regarding ys_history: how would we get accurate estimates of y(t-delay) at the start of our integration if we only use dense_infos (this variable gives the last interpolation of y, right)?

@thibmonsel
Contributor

Here is a concrete example of the "problems" that for now I can't see how to fix (that's why I was proposing ys_history and step_back).

Concrete example:

Issue with ys_history:

def vector_field(t, y, args, *, history):
    return jnp.array(
        [
            1.8 * y[0] * (1 - history(t)[0]),
            1.8 * y[1] * (1 - history(t)[1]),
        ]
    )

delay1 = lambda t, y, args: 3.0
delay2 = lambda t, y, args: 4.0
delays = [delay1, delay2]

initial_history = lambda t : jnp.array([1.2, 1.2])

For a given integration step, from tprev to tnext:
The first equation will fetch values of y from tprev - 3 to tnext - 3.
The second equation will fetch values of y from tprev - 4 to tnext - 4.

If we are at tprev = 3.5 and tnext = 3.6:

The second equation of the vector field will use the initial_history function, while the first will use the interpolated y obtained from dense_infos.

Issue for step_back:

In our given example the discontinuities are located at multiples of 3 and 4. Let us suppose we integrate from tprev=5.99 to tnext=6.2. With the current state of the implementation (tell me if I'm wrong), we would clip tnext to 6.0, so the next step integrates from tprev=5.99 to tnext=6.0, and that results in instabilities.

@patrick-kidger
Owner Author

Regarding ys_history.

So dense_infos actually has the entire history of the solve so far. (i.e. it's not just the previous step.)

However it is true that this doesn't contain the initial condition (the initial history function). Probably we'll need to do something like this:

# assume `y0_history` is passed in to integrate.py::loop

# as already added in this PR (line 147 at time of writing)
history = DenseInterpolation(...)
history_vals = []
for delay in delays:
    delay_val = delay(state.tprev, state.y, args)
    delay_val = state.tprev - delay_val  # whoops, forgot this line in my first draft!

    # query the dense solution, clamping the time to be at least t0
    history_delay_val = jnp.maximum(delay_val, t0)
    history_val = history.evaluate(history_delay_val)

    # query the initial history function, clamping the time to be at most t0
    y0_delay_val = jnp.minimum(delay_val, t0)
    y0_val = y0_history(y0_delay_val)

    history_val = jnp.where(delay_val < t0, y0_val, history_val)
    history_vals.append(history_val)
history_vals = tuple(history_vals)

Bit annoying that we need to evaluate y0_history on every step, even once we're past the start, but c'est la vie. I think this is good enough for a first draft; I can think of some other possibly-too-clever ideas to try to work around this later.


Regarding step_back: thanks, that example is super useful. You're concerned about the dense interpolation over the region [5.99, 6] being a bad interpolation, due to floating point errors? Right, that's a valid concern. I think we can handle it with the existing variables though.

Let me start by explaining how discontinuity handling works at present.

First of all, Diffrax already has some support for discontinuity handling via PIDController(jump_ts=...). If our next step is proposed over the interval [a, b], and we have some discontinuity τ such that τ in jump_ts and a < τ < b, then the proposed next step is instead trimmed to happen over the interval [a, prevbefore(b)], where prevbefore(b) is the floating-point number immediately before b. Moreover, the next-next-step then happens over the interval [nextafter(b), ?], where now nextafter(b) is the floating-point number immediately after b.

Moving from prevbefore(b) to nextafter(b) is the purpose of the double-nextafter in the code snippet of my previous message.
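For illustration, the prevbefore/nextafter operations defined above can be expressed with jnp.nextafter (Diffrax's helpers serve the same role):

import jax.numpy as jnp

tau = jnp.float32(6.0)
prevbefore_tau = jnp.nextafter(tau, -jnp.inf)  # float immediately before τ: the trimmed step ends here
nextafter_tau = jnp.nextafter(tau, jnp.inf)    # float immediately after τ: the next step starts here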

Now we'd like DDE solvers to work regardless of the choice of stepsize controller, so I think we need to:

  1. Hoist the double-nextafter bit from PIDController into the main loop.
  2. Write something that looks a bit like this bit, that handles the prevbefore part.

Okay, moving on. Setting made_jump does a few things. Off the top of my head, I think it:

  1. Tells an FSAL solver that the F=L condition is broken, and to re-evaluate.
  2. Makes the double-nextafter bit happen in PIDController.
  3. Tells PIDController that the original stepsize b - a should be used when selecting the next step, not the trimmed stepsize prevbefore(b) - a. (Step size controllers work by proposing the next step as a multiple of the previous one.)

In practice we'll be hoisting 2., and I think we want to leave 1. and 3. alone; we should just make sure to set made_jump = True so that they happen.

On to your point about numerical instability! Right, so this is an issue I've encountered before, at the very end of a differential equation solve. When doing an entire solve over the full interval [t0, t1], it may happen that we end up stepping to some t1 - ε, so that the very last step happens over [t1 - ε, t1], and indeed this means that the last piece of our dense interpolation produces wacky values. (Which breaks things like solution.evaluate(t1).)

This actually comes up frequently when using fixed step size controllers. Even if analytically we have that dt divides t1 - t0, in practice floating point errors may mean that t0 + dt + dt + ... + dt ends up at t1 - ε.
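A quick illustration of that effect (a toy check, not Diffrax code):

import jax.numpy as jnp

t0, t1, dt = jnp.float32(0.0), jnp.float32(1.0), jnp.float32(0.1)
t = t0
for _ in range(10):
    t = t + dt
print(t == t1)  # typically False: rounding error accumulates, so t misses t1 slightly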

The solution is the code here:

def _clip_to_end(tprev, tnext, t1, keep_step):

which clips things slightly away from, or clips directly to, t1. (Note that this also needs a little care when rejecting steps, hence the dependence on keep_step.)

In practice this is currently only being used to handle clipping to the very end of the integration, at t1. I can see your point that we should probably also apply the same logic every time we clip to a discontinuity. (Both DDE discontinuities and PIDController(jump_ts=...) discontinuities, I think.)

If you want to read a little more about the current clipping implementation, then this is discussed in #86 and #58.

@thibmonsel
Contributor

thibmonsel commented Sep 30, 2022

Thanks for the helpful insight Patrick, I really appreciate it!

In your 4th paragraph after the history = DenseInterpolation(...) snippet, did you mean:

and we have some discontinuity τ such that τ in jump_ts and a < τ < b, then the proposed next step is instead trimmed to happen over the interval [a, prevbefore( τ )], where prevbefore( τ ) is the floating-point number immediately before $\tau$

and not

and we have some discontinuity τ such that τ in jump_ts and a < τ < b, then the proposed next step is instead trimmed to happen over the interval [a, prevbefore(b)], where prevbefore(b) is the floating-point number immediately before b.


dense_infos has the whole past information of the integration, so we can use that to get y(t-delay) when t-delay > 0, if I'm not mistaken. I agree that we can fetch any value of y(t-tau) with history = DenseInterpolation(...) if t-tau >= 0.


However, I disagree that the instantiation history_vals = tuple(history_vals) is enough to get the estimate of y for the new step. Depending on the integration scheme this code snippet might work, but not in all cases.

For an explicit Euler scheme, $y_{n+1} = y_n + h f(y_n)$, history_vals will work to get $y_{n+1}$.

For an ERK method (classical RK4):

$y_{n+1} = y_n + \tfrac{1}{6}(K_1 + 2K_2 + 2K_3 + K_4)$
$K_1 = h f(x_n, y_n)$
$K_2 = h f(x_n + \tfrac{1}{2}h, y_n + \tfrac{1}{2}K_1)$
$K_3 = h f(x_n + \tfrac{1}{2}h, y_n + \tfrac{1}{2}K_2)$
$K_4 = h f(x_n + h, y_n + K_3)$

history_vals will handle the computation of $K_1$, but we still need new evaluations of the history function for $K_2$, $K_3$, $K_4$, and this can only be done by having an interpolated y(t-τ) that spans both negative and positive times.
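To make that concrete, here is a minimal sketch (names hypothetical) of a history callable built from dense_interp and y0_history, so that every RK stage can query y(t - τ) at its own stage time:

import jax.numpy as jnp

def make_history(dense_interp, y0_history, t0):
    # `dense_interp` covers [t0, tprev]; `y0_history` covers t <= t0.
    def history(t_eval):
        # Clamp both queries, since jnp.where evaluates both branches.
        interp_val = dense_interp.evaluate(jnp.maximum(t_eval, t0))
        initial_val = y0_history(jnp.minimum(t_eval, t0))
        return jnp.where(t_eval < t0, initial_val, interp_val)
    return history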


Regarding Diffrax discontinuity handling: I can assume that jump_ts works well to handle discontinuities and chooses the "good" steps accordingly. However, in our case of general DDEs we don't know jump_ts in advance; we only learn of a discontinuity's existence when trying to take a new step, if that makes sense. (One workaround would be to change controller.jump_ts every time we detect a new jump and restart the step?)


I've been using the ConstantStepSize controller in order to get a first version of the code. Since you talk about the PID controller with made_jump, would you recommend switching to it? It depends on how you think the controllers should work, i.e. would a ConstantStepSize that handles discontinuities be incoherent with its definition as a constant step size controller?

@patrick-kidger
Owner Author

patrick-kidger commented Sep 30, 2022

b vs τ: correct; I should have written τ.

Regarding tuple(history_vals): agreed; you're right. We should still have a VectorFieldWrapper, but instead do:

### integrate.py
class _HistoryVectorField(eqx.Module):
    vector_field: Callable
    dense_interp: DenseInterpolation
    y0_history: Callable
    delays: Sequence[Callable]

    def __call__(self, t, y, args):
        history = ...  # implementation as above
        return self.vector_field(t, y, args, history=history)

### integrate.py::loop
is_vf_wrapper = lambda x: isinstance(x, VectorFieldWrapper)

def _apply_history(x):
    if is_vf_wrapper(x):
        vector_field = _HistoryVectorField(x.vector_field, dense_interp, y0_history, delays)
        return VectorFieldWrapper(vector_field)
    else:
        return x

terms_ = jtu.tree_map(_apply_history, terms, is_leaf=is_vf_wrapper)

Regarding jump_ts: I think we should build in an independent/separate mechanism here. (And in particular not try to overwrite jump_ts.) As you say, we don't know the jump points in advance. Moreover this is something specific to PIDController, and we would like to be able to solve DDEs with any stepsize controller.

Indeed I think it makes sense to solve DDEs with a fixed stepsize controller. For simplicity I propose that given a discontinuity τ in [t0+2dt, t0+3dt], then ConstantStepSize should place steps at e.g.

t0
t0+dt
t0+2dt
τ
τ+dt
τ+2dt
...
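In code, a quick check of that placement (an illustrative sketch, not the actual controller code):

import jax.numpy as jnp

def propose_step(tprev, dt, tau):
    # As in the clipping sketch earlier: land on tau if it falls inside the
    # proposed step; afterwards the dt grid restarts from tau.
    tnext = tprev + dt
    return jnp.where((tau > tprev) & (tau < tnext), tau, tnext)

ts = [jnp.float32(0.0)]
for _ in range(5):
    ts.append(propose_step(ts[-1], 0.5, 1.2))
# ts: 0.0, 0.5, 1.0, 1.2, 1.7, 2.2  (discontinuity at 1.2 inside [1.0, 1.5])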

@thibmonsel
Contributor

Roger that for the ConstantStepSize implementation!

Regarding _HistoryVectorField: this seems to make sense on my end! Hopefully I'll get something working soon enough!

Regarding discontinuity checking (i.e. the jump_ts conversation from before):
I was considering adding a new discontinuities argument to the _State class, much like state.dense_ts, that progressively adds the detected jumps by replacing a jnp.inf in the stack. Let me know if you had something else in mind.
👍

@patrick-kidger
Owner Author

patrick-kidger commented Sep 30, 2022

Discontinuities: that is, a record of all the discontinuities encountered, for the purposes of outputting this to the user as an additional statistic? [i.e. not for any internal purpose.] Sounds reasonable to me. Let's gate that on a new SaveAt(discontinuities=True).

@thibmonsel
Contributor

thibmonsel commented Oct 3, 2022

The discontinuities need to be tracked because they can give rise to new discontinuities in the general case (#164 (comment)). If the delays are constant, as in the example here, we can optimise the code better.

@patrick-kidger
Owner Author

Ah, agreed!

@thibmonsel
Contributor

Hi again Patrick,
Would you mind telling me where you do the final clipping of your ys for dense outputs? I'll be needing it for the discontinuities and I can't seem to pinpoint its location. Thanks in advance!

@patrick-kidger
Owner Author

Sure thing. It's here:

index = jnp.clip(index - 1, a_min=0, a_max=maxlen)

To expand on what's going on here: we can't clip the size of the memory buffer. (JAX doesn't support such things.) So instead we pass a collection of buffers, all of size max_steps, into DenseInterpolation:

infos: DenseInfos

and additionally specify how far through them we got (the ts_size field). And this is what is then used in the clipping that I first linked above.

@thibmonsel
Contributor

thibmonsel commented Oct 4, 2022

Thanks, so your function

def _interpret_t(self, t: Scalar, left: bool) -> Tuple[Scalar, Scalar]:
gets the index of t in ts, and afterwards slicing is done to get rid of the infs? I know that slicing and dynamically shaped arrays are a pain in JAX...

All of the "magic" for array management is done within the GobalInterpolation class then ?

I'm not completely sure we have the same definition of clipping. By clipping the infs I mean the following operation:
[0, 1, 2, 3, inf, inf, inf] -> clipping -> [0, 1, 2, 3]

@patrick-kidger
Owner Author

So dynamically shaped arrays don't exist in JAX. As such we don't do any clipping at all.

E.g. try running sol = diffeqsolve(..., saveat=SaveAt(steps=True)); print(sol.ys). This will save the output across all steps, but the number of steps isn't known until runtime. So the output will be of fixed length max_steps, regardless of how many steps are taken. (And padded with infs if needed.)

I suspect we will need to do something similar here. I think the maximal number of discontinuities we can encounter is also of size max_steps, and that likewise we should allocate a buffer of this size. My rationale is that every time we encounter a discontinuity we step to it directly, so if we have more discontinuities than steps then the solve will fail anyway.
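For instance, the inf padding can be dropped after the solve, outside of JIT (a sketch following the SaveAt(steps=True) example above):

import jax.numpy as jnp

sol = diffeqsolve(..., saveat=SaveAt(steps=True))
keep = jnp.isfinite(sol.ts)          # padded entries are inf
ts, ys = sol.ts[keep], sol.ys[keep]  # boolean indexing only works outside jit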

@thibmonsel
Contributor

Ok, I'll try that to see what happens, thanks. My current issue is that state.discontinuity has a shape of max_steps, and the current process uses this large array to check for new roots. This means that we are running a nonlinear solve max_steps times at each step... I have tried to instantiate a tuple for state.discontinuity, but we can't vmap over that, unless maybe we create a custom container of PyTrees?

As of now, during initialisation, state.discontinuity = jnp.full(max_steps + 1, jnp.inf) and an index is used to add new values.

@patrick-kidger
Owner Author

I think instantiating a single array, and then performing a vmap'd nonlinear solve, is probably the correct approach.

I think there's a few tricks we can use to speed this up.

For one thing: could we perform a cheap vmap'd bisection search until we've identified the first discontinuity, and then switch to Newton's method for just this first one?

(I recall DelayDiffEq.jl had a step where they evaluated over ten intermediate points, that might be relevant here.)
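To sketch the vmap'd bisection idea (all names hypothetical; the delay is written as a function of t only, for brevity):

import jax
import jax.numpy as jnp

# Bisect g(t) = t - delay(t) - lam over the step for every known
# discontinuity lam at once, then keep the earliest root found.
def bisect(lam, tprev, tnext, delay, n_iters=20):
    g = lambda t: t - delay(t) - lam
    lo, hi = tprev, tnext
    for _ in range(n_iters):
        mid = 0.5 * (lo + hi)
        same_sign_as_lo = jnp.sign(g(mid)) == jnp.sign(g(lo))
        lo = jnp.where(same_sign_as_lo, mid, lo)
        hi = jnp.where(same_sign_as_lo, hi, mid)
    root = 0.5 * (lo + hi)
    # Only meaningful if g actually changes sign over the step:
    return jnp.where(jnp.sign(g(tprev)) != jnp.sign(g(tnext)), root, jnp.inf)

lams = jnp.full(100, jnp.inf).at[0].set(0.0)  # buffer of known discontinuities
roots = jax.vmap(lambda lam: bisect(lam, 0.9, 1.1, lambda t: 1.0))(lams)
first_disc = jnp.min(roots)  # ~1.0 here; inf when no discontinuity in the step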

@thibmonsel
Contributor

thibmonsel commented Oct 5, 2022

Definitely open to those ideas; doing a cheap bisection first could probably speed up the process!

Just in terms of speed:

  • with state.discontinuity = jnp.full(200 + 1, jnp.inf), the integration of a simple DDE takes 3.2s
  • with state.discontinuity = jnp.full(max_steps + 1, jnp.inf), the integration takes 14s

(I recall DelayDiffEq.jl had a step where they evaluated over ten intermediate points, that might be relevant here.)

I'll take a look at the documentation/paper, but if you have a reference for that I won't say no.

Btw, a contribution will be coming your way (not sure I have the rights to push to the delay branch) with:

  • a workable example of a DDE, for ConstantStepSize only (any solver can be used), but limited since we can't (yet) use adaptive-step solvers.

Workable example

Modeling $x'(t) = 1.8\,x(t)(1 - x(t-\tau))$ with $x(t<0) = 1.2$:

def vector_field(t, y, args, *, history):
    return 1.8 * y * (1 - history[0])

delay1 = lambda t, y, args: 2.0
delays = [delay1]
y0_history = lambda t: jnp.array([1.2])

y0 = y0_history(0.0)
# evaluate the history at t0 - tau:
history_val = y0_history(0.0 - delay1(0.0, y0, None))
discontinuity = None
if (y0_history(0.0) != vector_field(0.0, y0, None, history=tuple(history_val))).any():
    discontinuity = (0.0,)

made_jump = discontinuity is not None
t0, t1 = 0.0, 100.0
ts = jnp.linspace(t0, t1, 1000)

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field),
    diffrax.Dopri5(),
    t0=ts[0],
    t1=ts[-1],
    dt0=ts[1] - ts[0],
    y0=y0,
    max_steps=10**5,
    stepsize_controller=diffrax.ConstantStepSize(),
    adjoint=diffrax.NoAdjoint(),
    saveat=diffrax.SaveAt(ts=ts, dense=True),
    delays=delays,
    discontinuity=discontinuity,
    y0_history=y0_history,
    made_jump=made_jump,
)

Depending on the value of $\tau$, the stiffness of the problem cripples the integration (most likely because we only have the ConstantStepSize controller).

@patrick-kidger
Owner Author

In terms of speed -- hmm, that's definitely an unfortunate slow-down. Let's see how well we can optimise this, and if needed we can introduce an additional max_discontinuities variable to control the size of this buffer.

Regarding the ten points in DelayDiffEq.jl -- see section 4.2 of https://arxiv.org/abs/2208.12879, in which they "check for sign changes ... at pre-defined number of equally spaced time points ... in the current time interval". I'm not sure where I got the number ten from; possibly somewhere else in the same paper or just something I'd read elsewhere.
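That equally-spaced check could look roughly like this (a sketch, with g as discussed earlier in the thread):

import jax
import jax.numpy as jnp

def has_sign_change(g, tprev, tnext, num_points=10):
    # Evaluate g at equally spaced points across the current step and look
    # for a sign change between neighbouring samples.
    ts = jnp.linspace(tprev, tnext, num_points)
    gs = jax.vmap(g)(ts)
    return jnp.any(jnp.sign(gs[:-1]) != jnp.sign(gs[1:]))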

"a contribution will be coming your way" -- excellent, I look forward to it! Open a pull request against this branch; we'll iterate here until this is ready to merge.

Regarding your code snippet: a few thoughts that come to mind looking at it:

  • We should make it so that you pass y0=y0_history if delays is not None; i.e. no separate y0_history argument.
  • Thinking about it, there's no reason delays has to be specifically a tuple of callables. We should make this a general PyTree of callables! Then for example someone could pass a dictionary delays={0.1: lambda t, y, args: 0.1} and then access this in their vector field as history[0.1], which is quite an elegant syntax (see the sketch after this list). Or if they have only a single delay they could pass in delays=lambda t, y, args: ... and then access that as just history in the vector field.
  • FYI, you should generally avoid passing in large values for max_steps. For deep technical reasons, the integration time will usually increase every time max_steps passes a power of 16. (As it just so happens, you won't see that here due to using adjoint=NoAdjoint(), which uses a different codepath, so I'm just mentioning this as a general good-to-know heads-up.)
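For instance, the PyTree-of-callables API from the second bullet might be used as follows (a sketch of the proposal, not something implemented yet):

delays = {0.1: lambda t, y, args: 0.1}

def vector_field(t, y, args, *, history):
    # Look the delayed value up with the same dict key used in `delays`.
    return -history[0.1]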

@patrick-kidger
Owner Author

Closing in favour of #169.
