Skip to content

Commit

Permalink
ENH: Initial version of solve_ivp
Browse files Browse the repository at this point in the history
  • Loading branch information
nmayorov committed Apr 20, 2017
1 parent 7b41dc4 commit 5ec790f
Show file tree
Hide file tree
Showing 7 changed files with 879 additions and 0 deletions.
3 changes: 3 additions & 0 deletions scipy/integrate/__init__.py
Expand Up @@ -49,6 +49,8 @@
ode -- Integrate ODE using VODE and ZVODE routines.
complex_ode -- Convert a complex-valued ODE to real-valued and integrate.
solve_bvp -- Solve a boundary value problem for a system of ODEs.
solve_ivp -- Alternative routine for ODE integration with capabilities
similar to MATLAB.
"""
from __future__ import division, print_function, absolute_import

Expand All @@ -57,6 +59,7 @@
from .quadpack import *
from ._ode import *
from ._bvp import solve_bvp
from ._py import solve_ivp

__all__ = [s for s in dir() if not s.startswith('_')]
from numpy.testing import Tester
Expand Down
4 changes: 4 additions & 0 deletions scipy/integrate/_py/__init__.py
@@ -0,0 +1,4 @@
"""Suite of ODE solvers implemented in Python."""


from .ivp import solve_ivp
171 changes: 171 additions & 0 deletions scipy/integrate/_py/common.py
@@ -0,0 +1,171 @@
"""Common functions for ODE solvers."""
import numpy as np
from scipy.optimize import brentq, OptimizeResult


EPS = np.finfo(float).eps


class ODEResult(OptimizeResult):
pass


def norm(x):
"""Compute RMS norm."""
return np.linalg.norm(x) / x.size ** 0.5


def select_initial_step(fun, a, b, ya, fa, order, rtol, atol):
"""Empirically select a good initial step.
The algorithm is described in [1]_.
Parameters
----------
fun : callable
Right-hand side of the system.
a : float
Initial value of the independent variable.
b : float
Final value value of the independent variable.
ya : ndarray, shape (n,)
Initial value of the dependent variable.
fa : ndarray, shape (n,)
Initial value of the derivative, i. e. ``fun(x0, y0)``.
order : float
Method order.
rtol : float
Desired relative tolerance.
atol : float
Desired absolute tolerance.
Returns
-------
h_abs : float
Absolute value of the suggested initial step.
References
----------
.. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential
Equations I: Nonstiff Problems", Sec. II.4.
"""
scale = atol + np.abs(ya) * rtol
d0 = norm(ya / scale)
d1 = norm(fa / scale)
if d0 < 1e-5 or d1 < 1e-5:
h0 = 1e-6
else:
h0 = 0.01 * d0 / d1

s = np.sign(b - a)
y1 = ya + h0 * s * fa
f1 = fun(a + h0 * s, y1)
d2 = norm((f1 - fa) / scale) / h0

if d1 <= 1e-15 and d2 <= 1e-15:
h1 = max(1e-6, h0 * 1e-3)
else:
h1 = (0.01 / max(d1, d2)) ** (1 / order)

return min(100 * h0, h1)


def get_active_events(g, g_new, direction):
"""Find which event occurred during an integration step.
Parameters
----------
g, g_new : array_like, shape (n_events,)
Values of event functions at a current and next points.
direction : ndarray, shape (n_events,)
Event "direction" according to definition in `solve_ivp`.
Returns
-------
active_events : ndarray
Indices of events which occurred during the step.
"""
g, g_new = np.asarray(g), np.asarray(g_new)
up = (g <= 0) & (g_new >= 0)
down = (g >= 0) & (g_new <= 0)
either = up | down
mask = (up & (direction > 0) |
down & (direction < 0) |
either & (direction == 0))

return np.nonzero(mask)[0]


def handle_events(sol, events, active_events, is_terminal, x, x_new):
"""Helper function to handle events.
Parameters
----------
sol : callable
Function ``sol(x)`` which evaluates an ODE solution.
events : list of callables, length n_events
Event functions.
active_events : ndarray
Indices of events which occurred
is_terminal : ndarray, shape (n_events,)
Which events are terminate.
x, x_new : float
Previous and new values of the independed variable, it will be used as
a bracketing interval.
Returns
-------
root_indices : ndarray
Indices of events which take zero before a possible termination.
roots : ndarray
Values of x at which events take zero values.
terminate : bool
Whether a termination event occurred.
"""
roots = []
for event_index in active_events:
roots.append(solve_event_equation(events[event_index], sol, x, x_new))

roots = np.asarray(roots)

if np.any(is_terminal[active_events]):
if x_new > x:
order = np.argsort(roots)
else:
order = np.argsort(-roots)
active_events = active_events[order]
roots = roots[order]
t = np.nonzero(is_terminal[active_events])[0][0]
active_events = active_events[:t + 1]
roots = roots[:t + 1]
terminate = True
else:
terminate = False

return active_events, roots, terminate


def solve_event_equation(event, sol, x, x_new):
"""Solve an equation corresponding to an ODE event.
The equation is ``event(x, y(x)) = 0``, here ``y(x)`` is known from an
ODE solver using some sort of interpolation. It is solved by
`scipy.optimize.brentq` with xtol=atol=4*EPS.
Parameters
----------
event : callable
Function ``event(x, y)``.
sol : callable
Computed solution ``y(x)``. It should be defined only between `x` and
`x_new`.
x, x_new : float
Previous and new values of the independed variable, it will be used as
a bracketing interval.
Returns
-------
root : float
Found solution.
"""
return brentq(lambda t: event(t, sol(t)), x, x_new, xtol=4 * EPS)
190 changes: 190 additions & 0 deletions scipy/integrate/_py/ivp.py
@@ -0,0 +1,190 @@
"""Generic interface for initial value problem solvers."""
from warnings import warn
import numpy as np
from .common import select_initial_step, EPS, ODEResult
from .rk import rk


METHOD_ORDER = {
'RK23': 3,
'RK45': 5
}


TERMINATION_MESSAGES = {
0: "The solver failed to reach the interval end or a termination event.",
1: "The solver successfully reached the interval end.",
2: "A termination event occurred."
}


def validate_tol(rtol, atol, n):
"""Validate tolerance values."""
if rtol < 100 * EPS:
warn("`rtol` is too low, setting to {}".format(100 * EPS))
rtol = 100 * EPS

atol = np.asarray(atol)
if atol.ndim > 0 and atol.shape != (n,):
raise ValueError("`atol` has wrong shape.")

if np.any(atol < 0):
raise ValueError("`atol` must be positive.")

return rtol, atol


def solve_ivp(fun, x_span, ya, rtol=1e-3, atol=1e-6, method='RK45',
events=None):
"""Solve an initial value problem for a system of ODEs.
This function numerically integrates a system of ODEs given an initial
value::
dy / dx = f(x, y)
y(a) = ya
Here x is a 1-dimensional independent variable, y(x) is a n-dimensional
vector-valued function and ya is a n-dimensional vector with initial
values.
Parameters
----------
fun : callable
Right-hand side of the system. The calling signature is ``fun(x, y)``.
Here ``x`` is a scalar, and ``y`` is ndarray with shape (n,). It
must return an array_like with shape (n,).
x_span : 2-tuple of floats
Interval of integration (a, b). The solver starts with x=a and
integrates until it reaches x=b.
ya : array_like, shape (n,)
Initial values for y.
rtol, atol : float and array_like, optional
Relative and absolute tolerances. The solver keeps the error estimates
less than ``atol` + rtol * abs(y)``. Here `rtol` controls a relative
accuracy (number of correct digits). But if a component of `y` is
approximately below `atol` then the error only needs to fall within
the same `atol` threshold, and the number of correct digits is not
guaranteed. If components of y have different scales, it might be
beneficial to set different `atol` values for different components by
passing array_like with shape (n,) for `atol`. Default values are
1e-3 for `rtol` and 1e-6 for `atol`.
method : string, optional
Integration method to use:
* 'RK45' (default): Explicit Runge-Kutta method of order 5 with an
automatic step size control [1]_. A 4-th order accurate quartic
polynomial is used for the continuous extension [2]_.
* 'RK23': Explicit Runge-Kutta method of order 3 with an automatic
step size control [3]_. A 3-th order accurate cubic Hermit
polynomial is used for the continuous extension.
events : callable, list of callables or None, optional
Events to track. Events are defined by functions which take
a zero value at a point of an event. Each function must have a
signature ``event(x, y)`` and return float, the solver will find an
accurate value of ``x`` at which ``event(x, y(x)) = 0`` using a root
finding algorithm. Additionally each ``event`` function might have
attributes:
* terminate : bool, whether to terminate integration if this
event occurs. Implicitly False if not assigned.
* direction : float, direction of crossing a zero. If `direction`
is positive then `event` must go from negative to positive, and
vice-versa if `direction` is negative. If 0, then either way will
count. Implicitly 0 if not assigned.
You can assign attributes like ``event.terminate = True`` to any
function in Python. If None (default), events won't be tracked.
Returns
-------
Bunch object with the following fields defined:
sol : PPoly
Found solution for y as `scipy.interpolate.PPoly` instance, a C1
continuous spline.
x : ndarray, shape (n_points,)
Values of the independent variable at which the solver made steps.
y : ndarray, shape (n, n_points)
Solution values at `x`.
yp : ndarray, shape (n, n_points)
Solution derivatives at `x`, i.e. ``fun(x, y)``.
x_events : ndarray, tuple of ndarray or None
Arrays containing values of x at each corresponding events was
detected. If `events` contained only 1 event, then `x_events` will
be ndarray itself. None if `events` was None.
status : int
Reason for algorithm termination:
* 0: The solver failed to reach the interval end or a termination
event.
* 1: The solver successfully reached the interval end.
* 2: A termination event occurred.
message : string
Verbal description of the termination reason.
success : bool
True if the solver reached the interval end or a termination event
(``status > 0``).
References
----------
.. [1] J. R. Dormand, P. J. Prince, "A family of embedded Runge-Kutta
formulae", Journal of Computational and Applied Mathematics, Vol. 6,
No. 1, pp. 19-26, 1980.
.. [2] L. W. Shampine, "Some Practical Runge-Kutta Formulas", Mathematics
of Computation,, Vol. 46, No. 173, pp. 135–150, 1986.
.. [3] P. Bogacki, L.F. Shampine, "A 3(2) Pair of Runge-Kutta Formulas",
Appl. Math. Lett. Vol. 2, No. 4. pp. 321-325, 1989.
"""
if method not in METHOD_ORDER:
raise ValueError("`method` must be 'RK23' or 'RK45'.")

ya = np.atleast_1d(ya)
if ya.ndim != 1:
raise ValueError("`ya` must be 1-dimensional.")

a, b = float(x_span[0]), float(x_span[1])
if a == b:
raise ValueError("Initial and final `x` must be distinct.")

def fun_wrapped(x, y):
return np.asarray(fun(x, y), dtype=float)

fa = fun_wrapped(a, ya)
if fa.shape != ya.shape:
raise ValueError("`fun` return is expected to have shape {}, "
"but actually has {}.".format(ya.shape, fa.shape))

if callable(events):
events = (events,)

h_abs = select_initial_step(fun, a, b, ya, fa, METHOD_ORDER[method],
rtol, atol)

if events is not None:
direction = np.empty(len(events))
is_terminal = np.empty(len(events), dtype=bool)
for i, event in enumerate(events):
try:
is_terminal[i] = event.terminate
except AttributeError:
is_terminal[i] = False

try:
direction[i] = event.direction
except AttributeError:
direction[i] = 0
else:
direction = None
is_terminal = None

max_step = 0.1 * np.abs(b - a)

status, sol, xs, ys, fs, x_events = rk(
fun_wrapped, a, b, ya, fa, h_abs, rtol, atol, max_step, method,
events, direction, is_terminal)

return ODEResult(sol=sol, x=xs, y=ys, yp=fs, x_events=x_events,
status=status, message=TERMINATION_MESSAGES[status],
success=status > 0)

0 comments on commit 5ec790f

Please sign in to comment.