
Increasing divergence and oscillation between steps #60

Closed
rafaelvalle opened this issue Jun 14, 2019 · 13 comments

@rafaelvalle
Contributor

rafaelvalle commented Jun 14, 2019

Hey @rtqichen!

While doing experiments on a model, I found that the L2 norm of the distance between subsequent states increases with the relative tolerance set to 0, across different values of the absolute tolerance.

I understand that this distance is proportionally upper bounded by atol + rtol * norm(state), but I would not expect it to grow, because to me that suggests the solution is diverging or oscillating. For example, some state at step T is more similar to a state at step 0 than to a state at a step in (0, T).

Any thoughts? I'm using the adaptive solver with RK45.
Here are some plots showing the infinity norm of the state, the L2 norm of the distance between subsequent states, and the respective error threshold.

Absolute tolerance 1e-5, Relative tolerance 0
[plot]

Absolute tolerance 1e-9, Relative tolerance 0
[plot]

Absolute tolerance 1e-10, Relative tolerance 0
[plot]

@rafaelvalle rafaelvalle changed the title Increasing divergence between subsequent steps / diverging solution Increasing divergence and oscillation between steps Jun 14, 2019
@rafaelvalle
Contributor Author

rafaelvalle commented Jun 14, 2019

Here's a plot of the difference between numerical solutions for pairs of absolute error tolerance values in decreasing order, starting at approximately (1e-5, 1e-5 - ε) and ending at (1e-10, 1e-10 - ε).
I expected the distance between these numerical solutions to decrease in proportion to the decrease in absolute tolerance, not to increase. RK45.
[plot]

@rtqichen
Owner

rtqichen commented Jun 17, 2019

> distance is proportionally upper bounded by atol + rtol * norm(state)

The distance between consecutive states isn't bounded, as it depends on the learned f. What atol + rtol * norm(state) bounds is the error estimate, which is an estimate of the 5th-order derivative of x(t). (This is the error of a truncated Taylor series, which is what RK45 uses for each step of the solver.)
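The acceptance test an adaptive RK solver applies can be sketched as follows. This is a minimal illustration of the tolerance formula above, not torchdiffeq's actual internals; the function name and the exact norm used (RMS here) are illustrative assumptions.

```python
import math

def error_ratio(err_est, y0, y1, atol, rtol):
    """RMS of the per-component error estimate scaled by its tolerance.

    tol_i = atol + rtol * max(|y0_i|, |y1_i|); a step is accepted when
    the returned ratio is <= 1.  Note that this bounds the *local* error
    estimate of a single step, not the distance between consecutive states.
    """
    total = 0.0
    for e, a, b in zip(err_est, y0, y1):
        tol = atol + rtol * max(abs(a), abs(b))
        total += (e / tol) ** 2
    return math.sqrt(total / len(err_est))

# An error estimate well inside atol is accepted (ratio <= 1); a larger
# one forces the solver to shrink the step and retry (ratio > 1).
print(error_ratio([1e-6, 2e-6], [1.0, -0.5], [1.1, -0.4], atol=1e-5, rtol=0.0))
print(error_ratio([1e-4], [1.0], [1.0], atol=1e-5, rtol=0.0))
```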

Also, a diverging ODE isn't really a problem unless you want to integrate to t=infinity. Sometimes the learned ODE learns to expand (e.g. if transforming something bounded in [0,1] to something in a much larger range of values).

> step T is more similar to a state at step 0 than a state in step (0, T)

Hmm.. have you verified this? This would suggest that the steps (0, T) are not being useful, since a constant trajectory would've sufficed to transform x(0) to x(T). It's possible that a part of the state behaves this way in order to transform another part of the state, which is perfectly valid. But this shouldn't occur for the entire state, especially if you've used zero initialization for the ODE function.

> difference between numerical solutions of absolute error tolerance value pairs in decreasing order

I don't think this pair-wise difference is very interpretable, as the solutions could just be oscillating around the same value. It would be more meaningful to always compare to the same solution (e.g. with tolerances set to 1e-10). This way you should see that the solutions are getting closer and closer to the 1e-10 estimate when decreasing tolerance.
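The comparison suggested above is easy to reproduce on a toy problem. The sketch below uses scipy.integrate.solve_ivp as a stand-in for torchdiffeq's odeint (assuming scipy is installed), solving dy/dt = -y at several loose tolerances and measuring each solution against the same tight-tolerance (1e-10) reference instead of against its neighbor.

```python
from scipy.integrate import solve_ivp

def f(t, y):
    # Simple linear ODE with known solution y(t) = y0 * exp(-t).
    return -y

t_span, y0 = (0.0, 1.0), [1.0]

# Tight-tolerance reference, analogous to running with atol = rtol = 1e-10.
ref = solve_ivp(f, t_span, y0, method="RK45", rtol=1e-10, atol=1e-10).y[0, -1]

# Measure each looser solution against the same reference; pairwise
# differences between neighboring tolerances can oscillate, but the
# distance to the reference shrinks as the tolerance tightens.
for tol in [1e-4, 1e-6, 1e-8]:
    sol = solve_ivp(f, t_span, y0, method="RK45", rtol=tol, atol=tol).y[0, -1]
    print(tol, abs(sol - ref))
```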

@rafaelvalle
Contributor Author

rafaelvalle commented Jun 18, 2019

> step T is more similar to a state at step 0 than a state in step (0, T)
> Hmm.. have you verified this?

Yes, take a look at the plots below. This behavior diminishes with smaller error tolerances. And yes, I've used zero initialization for the ODE function.

The distance matrices below show that for relatively large error tolerances, the last state is more similar to states that are farther away in time from it.
[distance matrices plot]

I also find it unexpected that at higher error tolerances the Neural ODE passes through states that are closer to the target solution than the final solution itself.

For example, the plots below show the L2 distance between the current state and the target state(ground truth). In multiple cases, states before the final state are closer to the target state than the final state.
[plot: L2 distance from target state]

@rtqichen
Owner

rtqichen commented Jun 18, 2019

Hmm, correct me if I'm wrong, but these plots look like you're logging the states where f is evaluated, rather than using the odeint API to query the intermediate states? You want to be doing the latter, not the former, when analyzing states along the trajectory.

One thing to note is that the RK45 solver evaluates f at intermediate probes that are not along the actual trajectory, and these probes come in bursts of 5 in this loop (https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/rk_common.py#L49).

Every 6 evaluations contains one that is along the path (the last one), preceded by 5 that are probes.

The initial step size selection evaluates f twice (https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/misc.py#L84), so the RK45 algorithm always performs 2+6*steps evaluations of f.

Instead of logging the states when f is evaluated, can you use the odeint API to query what the intermediate states are? These states will be along the actual estimated trajectory.

@rafaelvalle
Contributor Author

rafaelvalle commented Jun 19, 2019

Yes! With the exception of the last state, the states are the inputs to f, including the initial value.
The last state in these plots is the output of the odeint API at time 1 for integration time [0, 1], just like the mnist example. It should be on the path; better yet, it is the numerical solution itself.

> RK45 algorithm always performs 2+7*steps evaluations of f.

In the L2 plot with absolute error 0.1, step 0 is the initial value, steps 1 and 2 belong to the initial step size selection, steps 3 through 7 and 9 through 13 are not on the trajectory, and steps 8 and 14 are on the trajectory, correct?

Is it possible to interact with the odeint API to query states along the estimated trajectory instead of the states evaluated at the integration time, in my case 0 and 1?

By the way, thank you for your help. It is extremely valuable!

@rtqichen
Owner

Yes, do odeint(func, x0, t=torch.linspace(0, 1, 50)). This will return 50 points along the trajectory, at no extra cost compared to solving t=[0, 1] in terms of the number of func evaluations. These are actual states along the estimated path. With a loose tolerance, the resulting plots will likely look closer to those from a tighter tolerance.

No worries! Happy to help.

@rtqichen
Owner

rtqichen commented Jun 20, 2019

I made a few mistakes in my earlier answer. The number of evaluations should be 2 + 6*steps, and the last evaluation is the last state. I've corrected these mistakes in the previous reply. Apologies if I've confused you.

To clarify which evaluations during the solver are actually on the estimated path:

0  1 |  2  3  4  5  6  7 |  8  9  10 11 12 13
o  x |  x  x  x  x  x  o |  x  x  x  x  x  o

The evaluations indicated by o are those on the path.

The first two are for selecting the initial step size. This solver run actually only takes 2 "steps", but evaluates f 6 times per step.
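One way to check this kind of bookkeeping on any solver is to wrap f in a counter. The sketch below is illustrative only: it uses a hand-rolled fixed-step classical RK4 (exactly 4 evaluations per step) rather than torchdiffeq's adaptive RK45, which additionally spends 2 evaluations on initial step size selection as described above.

```python
class CountingFunc:
    """Wraps an ODE right-hand side and counts how often it is evaluated."""
    def __init__(self, f):
        self.f = f
        self.nfe = 0
    def __call__(self, t, y):
        self.nfe += 1
        return self.f(t, y)

def rk4_solve(f, y0, t0, t1, steps):
    """Classical fixed-step RK4: exactly 4 evaluations of f per step."""
    h = (t1 - t0) / steps
    t, y = t0, y0
    for _ in range(steps):
        k1 = f(t, y)
        k2 = f(t + h / 2, y + h * k1 / 2)
        k3 = f(t + h / 2, y + h * k2 / 2)
        k4 = f(t + h, y + h * k3)
        y = y + (h / 6) * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return y

counted = CountingFunc(lambda t, y: -y)  # dy/dt = -y
y1 = rk4_solve(counted, 1.0, 0.0, 1.0, steps=10)
print(counted.nfe)  # 4 evaluations per step * 10 steps -> 40
```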

@rafaelvalle
Contributor Author

rafaelvalle commented Jun 20, 2019

No sweat at all, the explanation is very clear now.
Regarding odeint(func, x0, t=torch.linspace(0, 1, 50)): from our discussion, I do not understand the interaction between providing times spaced 1/50 apart and using adaptive-step solvers like RK45.
Can you point me to resources that explain this? Besides reading the code in this repo :-)

@rtqichen
Owner

rtqichen commented Jun 20, 2019

The input t=torch.linspace(0, 1, 50) specifies the times at which the ODE solver should produce output. The ODE solver will always integrate from min t (0) to max t (1), and the intermediate values of t have no effect on how the ODE is solved. Instead, the t argument is very useful if you ever want to know what the intermediate values are (e.g. for plotting). This comes at a very small additional cost and doesn't even use more evaluations of func.

(This follows the same API as scipy.integrate.odeint https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.odeint.html which was used in our earliest experiments.)
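The same property is visible in scipy's newer solve_ivp interface (a sketch, assuming scipy is installed): requesting 50 output times via t_eval does not change the number of evaluations of f, because the extra points come from the solver's dense interpolant rather than from additional steps.

```python
import numpy as np
from scipy.integrate import solve_ivp

def f(t, y):
    return -y

# Same integration, two different sets of requested output times.
sparse = solve_ivp(f, (0.0, 1.0), [1.0], method="RK45", t_eval=[0.0, 1.0])
dense = solve_ivp(f, (0.0, 1.0), [1.0], method="RK45",
                  t_eval=np.linspace(0, 1, 50))

# The 48 extra intermediate outputs cost no extra evaluations of f.
print(sparse.nfev, dense.nfev)
```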

At some point I should update the README to contain this information, but I've been lazy..

If you want to understand the finer details of how ODEs can be solved, "Solving Ordinary Differential Equations I: Nonstiff Problems" by Hairer et al. is pretty good and describes the details of RK. The adams method was based on their Section III.5.

@rafaelvalle
Contributor Author

Great, that makes everything very clear!
I'll try to compile a few of the explanations you gave in the issues on this repo and possibly put up a PR.

@rtqichen
Owner

That would be very useful! (But please don't feel obligated to.)

@rafaelvalle
Contributor Author

rafaelvalle commented Jul 16, 2019

@rtqichen Just a heads up that I'll soon send a PR with a FAQ addressing common issues and questions related to Neural ODEs, using your answer whenever possible.

@rafaelvalle
Contributor Author

> I wrote a few mistakes in my earlier answer. The number of evaluations should be 2 + 6*steps. The last evaluation is the last state. I've corrected these mistakes in the previous reply. Apologies if I've confused you.
>
> To clarify which evaluations during the solver are actually on the estimated path:
>
> 0  1 |  2  3  4  5  6  7 |  8  9  10 12 13 14
> o  x |  x  x  x  x  x  o |  x  x  x  x  x  o
>
> The evaluations indicated by o are those on the path.
>
> The first two are for selecting the initial step size. This solver actually only takes 2 "steps", but evaluated 6 times per step.

Hey Ricky (@rtqichen), I just realized there's a typo in this example: the count jumps from 10 to 12...
I'll update it in the FAQ as well and put up a PR.
