Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Example for adaptive step size choice #372

Merged
merged 2 commits into from
Feb 20, 2024

Conversation

ParticularlyPythonicBS
Copy link
Contributor

PR to complete #371
Added example to PIDController with a simple pendulum example illustrating the effect of tolerance of adaptive step choice controller.

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! I think this'd be a great addition to the docs.

@@ -179,6 +179,52 @@ class PIDController(
common to refer to solving an equation to specific tolerances, without
necessarily stating which solver was used.)

!!! Example
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is quite a long example, so let's maybe make it expandable with ???

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds great!

```
We can integrate this using:
```python
def integrator(dynamics, state0, stepsize_controller):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it'd be good to make this example a little more concise. Perhaps something like this:

import diffrax as dfx

def dynamics(t, y, args):
    dtheta = y["omega"]
    domega = - jnp.sin(y["theta"])
    return dict(theta=dtheta, omega=domega)

y0 = dict(theta=0.1, omega=0)
term = dfx.ODETerm(dynamics)
sol = dfx.diffeqsolve(
    term, solver, t0=0, t1=1000, dt0=0.1, y0,
    saveat=dfx.SaveAts(ts=jnp.linspace(0, 1000, 10000),
    max_steps=2**20,
    stepsize_controller=...
)

WDYT?

(WIth a little ..., as we discuss different choices of that below.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That works well,
yeah the class was a vestige of the much larger system I simplified for this, but this is much better as an example

The phase portraits of the pendulum from the different tolerances clearly
illustrate the impact of the choice of `rtol` and `atol` on the accuracy of
the solution.
![Phase portrait of pendulum](../imgs/pendulum_adaptive_steps.png)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I probably wouldn't use "incorrect" in the title here, as it's suggestive of a bug. This is the numerical method doing exactly what's expected, namely a cheap low-accuracy solution!
Maybe something like "less accurate" and "more accurate" instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pendulum_adaptive_steps

Sounds good, here is the corrected plot.

@ParticularlyPythonicBS
Copy link
Contributor Author

Just following up to see if anything else is needed

@patrick-kidger
Copy link
Owner

Can you push your changes from the last round of comments?
After that, let's merge this! :)

@patrick-kidger patrick-kidger merged commit d97ba20 into patrick-kidger:main Feb 20, 2024
2 checks passed
@patrick-kidger
Copy link
Owner

Alright, merged -- thank you for the contribution!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants