
How to include control variables? #128

Closed
juliuskittler opened this issue Oct 13, 2020 · 1 comment
juliuskittler commented Oct 13, 2020

Many thanks for your great work. The library really solves a relevant problem, and it works well for me for ODEs based purely on state variables. However, I am facing issues when trying to include control variables u in addition to state variables x. The forward simulation works well (i.e. predicting x with a single call to the odeint function), but fitting parameters does not (i.e. multiple iterations, each consisting of a forward simulation followed by a PyTorch optimization step): I get a very bad fit, and the loss can even increase during optimization.

Hence, my main question is: What is your recommended way to include control variables u into the ODE if the goal is to fit the parameters of the ODE?


To elaborate, I am sharing a reproducible example (Google Colab notebook) that shows how I have tried to include the control variables u so far: https://drive.google.com/file/d/1_F80gHZG5tJzKKWU4Ku-po_iroo2B1kz/view?usp=sharing

torchdiffeq.odeint expects the system function as its first argument, a function of time t and state x (https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/odeint.py). To handle controls, I have defined this function such that it internally looks up the controls u corresponding to the input t. For this, I use an interpolation function (based on a historic sequence of u values and their corresponding timepoints t). I wrote a custom interpolation function because existing ones generally use numpy, which raises errors when the inputs are tensors that require gradients.
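For concreteness, here is a minimal sketch of such a torch-only interpolation (linear; a hypothetical implementation, assuming `u_time` is a 1-D tensor sorted in ascending order, not the exact code from my notebook):

```python
import torch

def input_interp(t, u_time, u_vals):
    # Piecewise-linear interpolation of the control u at time t, using only
    # torch ops so gradients can flow through u_vals if needed.
    # Assumes u_time is a 1-D tensor sorted in ascending order.
    t = torch.clamp(t, u_time[0], u_time[-1])
    idx = torch.searchsorted(u_time, t.reshape(1)).clamp(1, len(u_time) - 1)[0]
    t0, t1 = u_time[idx - 1], u_time[idx]
    w = (t - t0) / (t1 - t0)
    return (1.0 - w) * u_vals[idx - 1] + w * u_vals[idx]
```

The system function then looks up u internally, e.g. `rhs_fun = lambda t, x: f(x, input_interp(t, u_time, u_vals))` for the controlled dynamics f.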

As you can see, with this approach the fitted parameters are no better than the initial parameters. Once I remove the control u (and replace it with the state x), fitting works well.

rtqichen (Owner) commented Oct 15, 2020

So the way you're handling u is okay, but the biggest problem is that u in this example has discontinuities.

Our ODE solvers rely on Taylor series, so x(t) needs to have enough continuous derivatives. But with a discontinuous u(t), the first derivative of x is discontinuous (in general this should be true; I didn't fully delve into how u affects x in your example).
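A toy illustration of why this matters (plain Python, with fixed-step forward Euler standing in as a stand-in for the solver's Taylor-based steps, not the actual torchdiffeq solver): integrating dx/dt = u(t) for a u that jumps at t = 1, a step that straddles the jump misses its contribution entirely, while a grid aligned with the jump recovers the exact integral:

```python
def u(t):
    # Piecewise-constant control with a discontinuity at t = 1
    return 0.0 if t < 1.0 else 1.0

def euler(f, x0, ts):
    # Fixed-step forward Euler over the time grid ts
    x = x0
    for t0, t1 in zip(ts[:-1], ts[1:]):
        x = x + (t1 - t0) * f(t0, x)
    return x

rhs = lambda t, x: u(t)  # dx/dt = u(t), so the true x(1.5) - x(0.5) is 0.5

x_straddle = euler(rhs, 0.0, [0.5, 1.5])       # one step across the jump: 0.0
x_aligned  = euler(rhs, 0.0, [0.5, 1.0, 1.5])  # grid split at the jump: 0.5
```

An adaptive high-order solver hits the same wall: its error estimates assume smoothness, so it either wastes tiny steps near the jump or silently loses accuracy.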

To test this, I changed

```python
u_fun = lambda t: input_interp(t, u_time, u_vals)
```

to

```python
u_fun = lambda t: torch.sin(t) * 0.5 + 0.5
```

then, just for sanity, used a finer grid `timepoints = torch.linspace(0, 30, 100)` and a larger learning rate `optimizer = torch.optim.Adam(params, lr=0.01)`. The loss then goes down and I think it's converging to the true solution:

```
Iter 991/1000 - Loss: 0.025541825219988823 - Time: 0s
[9.422207832336426, 9.926614761352539, 0.4010310173034668]
```

There's a simple fix to this, provided you know exactly where the discontinuities are.

We can view the combined system [u, x](t). We can still solve this ODE; we just can't rely on the solver to handle discontinuities in the state. The grid_points option exists for exactly this: at the specified time values, the solver restarts its Taylor series approximation and never applies a Taylor approximation across them. I kept the timepoints and optimizer changes, changed u_fun back to the original, and changed

```python
x_pred = torchdiffeq.odeint(rhs_fun, x0, timepoints, method="dopri5")
```

to

```python
x_pred = torchdiffeq.odeint(rhs_fun, x0, timepoints, method="dopri5",
                            options={"grid_points": torch.tensor(u_time), "eps": 1e-6})
```

Optimization then proceeds okay:

```
Iter 451/1000 - Loss: 0.2769869863986969 - Time: 0s
[9.268129348754883, 10.511204719543457, 0.25258389115333557]
```

The small gotcha is the eps argument, which ensures the solver never evaluates anywhere within an epsilon-ball of the time values where the discontinuities occur. This can add some numerical error, but it shouldn't be a problem. (In your particular example, I could also run with eps set to 0.)
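To make the epsilon-ball idea concrete, here is a hypothetical helper (illustration only; `clip_eps` is not part of torchdiffeq) showing the effect: evaluation times that land too close to a jump get nudged to one side of it, so the right-hand side is never queried at the discontinuity itself.

```python
def clip_eps(t, jump_times, eps=1e-6):
    # Illustration only (not the torchdiffeq implementation): nudge an
    # evaluation time out of the epsilon-ball around each discontinuity,
    # so the RHS is never evaluated at (or within eps of) a jump.
    for tj in jump_times:
        if abs(t - tj) < eps:
            return tj - eps if t < tj else tj + eps
    return t
```

Times away from the jumps pass through unchanged, which is why the extra numerical error stays on the order of eps.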
