From 5ec790f8a975483784a96affa34cf61025936aa7 Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Tue, 28 Jun 2016 17:28:12 +0500 Subject: [PATCH] ENH: Initial version of solve_ivp --- scipy/integrate/__init__.py | 3 + scipy/integrate/_py/__init__.py | 4 + scipy/integrate/_py/common.py | 171 +++++++++++++++ scipy/integrate/_py/ivp.py | 190 ++++++++++++++++ scipy/integrate/_py/rk.py | 345 ++++++++++++++++++++++++++++++ scipy/integrate/setup.py | 3 + scipy/integrate/tests/test_ivp.py | 163 ++++++++++++++ 7 files changed, 879 insertions(+) create mode 100644 scipy/integrate/_py/__init__.py create mode 100644 scipy/integrate/_py/common.py create mode 100644 scipy/integrate/_py/ivp.py create mode 100644 scipy/integrate/_py/rk.py create mode 100644 scipy/integrate/tests/test_ivp.py diff --git a/scipy/integrate/__init__.py b/scipy/integrate/__init__.py index bc40adff7941..d3ec99129e5b 100644 --- a/scipy/integrate/__init__.py +++ b/scipy/integrate/__init__.py @@ -49,6 +49,8 @@ ode -- Integrate ODE using VODE and ZVODE routines. complex_ode -- Convert a complex-valued ODE to real-valued and integrate. solve_bvp -- Solve a boundary value problem for a system of ODEs. + solve_ivp -- Alternative routine for ODE integration with capabilities + similar to MATLAB. """ from __future__ import division, print_function, absolute_import @@ -57,6 +59,7 @@ from .quadpack import * from ._ode import * from ._bvp import solve_bvp +from ._py import solve_ivp __all__ = [s for s in dir() if not s.startswith('_')] from numpy.testing import Tester diff --git a/scipy/integrate/_py/__init__.py b/scipy/integrate/_py/__init__.py new file mode 100644 index 000000000000..d5c102ca48ce --- /dev/null +++ b/scipy/integrate/_py/__init__.py @@ -0,0 +1,4 @@ +"""Suite of ODE solvers implemented in Python.""" + + +from .ivp import solve_ivp diff --git a/scipy/integrate/_py/common.py b/scipy/integrate/_py/common.py new file mode 100644 index 000000000000..e4c7afd2148b --- /dev/null +++ b/scipy/integrate/_py/common.py @@ -0,0 +1,171 @@ +"""Common functions for ODE solvers.""" +import numpy as np +from scipy.optimize import brentq, OptimizeResult + + +EPS = np.finfo(float).eps + + +class ODEResult(OptimizeResult): + pass + + +def norm(x): + """Compute RMS norm.""" + return np.linalg.norm(x) / x.size ** 0.5 + + +def select_initial_step(fun, a, b, ya, fa, order, rtol, atol): + """Empirically select a good initial step. + + The algorithm is described in [1]_. + + Parameters + ---------- + fun : callable + Right-hand side of the system. + a : float + Initial value of the independent variable. + b : float + Final value value of the independent variable. + ya : ndarray, shape (n,) + Initial value of the dependent variable. + fa : ndarray, shape (n,) + Initial value of the derivative, i. e. ``fun(x0, y0)``. + order : float + Method order. + rtol : float + Desired relative tolerance. + atol : float + Desired absolute tolerance. + + Returns + ------- + h_abs : float + Absolute value of the suggested initial step. + + References + ---------- + .. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential + Equations I: Nonstiff Problems", Sec. II.4. + """ + scale = atol + np.abs(ya) * rtol + d0 = norm(ya / scale) + d1 = norm(fa / scale) + if d0 < 1e-5 or d1 < 1e-5: + h0 = 1e-6 + else: + h0 = 0.01 * d0 / d1 + + s = np.sign(b - a) + y1 = ya + h0 * s * fa + f1 = fun(a + h0 * s, y1) + d2 = norm((f1 - fa) / scale) / h0 + + if d1 <= 1e-15 and d2 <= 1e-15: + h1 = max(1e-6, h0 * 1e-3) + else: + h1 = (0.01 / max(d1, d2)) ** (1 / order) + + return min(100 * h0, h1) + + +def get_active_events(g, g_new, direction): + """Find which event occurred during an integration step. + + Parameters + ---------- + g, g_new : array_like, shape (n_events,) + Values of event functions at a current and next points. + direction : ndarray, shape (n_events,) + Event "direction" according to definition in `solve_ivp`. + + Returns + ------- + active_events : ndarray + Indices of events which occurred during the step. + """ + g, g_new = np.asarray(g), np.asarray(g_new) + up = (g <= 0) & (g_new >= 0) + down = (g >= 0) & (g_new <= 0) + either = up | down + mask = (up & (direction > 0) | + down & (direction < 0) | + either & (direction == 0)) + + return np.nonzero(mask)[0] + + +def handle_events(sol, events, active_events, is_terminal, x, x_new): + """Helper function to handle events. + + Parameters + ---------- + sol : callable + Function ``sol(x)`` which evaluates an ODE solution. + events : list of callables, length n_events + Event functions. + active_events : ndarray + Indices of events which occurred + is_terminal : ndarray, shape (n_events,) + Which events are terminate. + x, x_new : float + Previous and new values of the independed variable, it will be used as + a bracketing interval. + + Returns + ------- + root_indices : ndarray + Indices of events which take zero before a possible termination. + roots : ndarray + Values of x at which events take zero values. + terminate : bool + Whether a termination event occurred. + """ + roots = [] + for event_index in active_events: + roots.append(solve_event_equation(events[event_index], sol, x, x_new)) + + roots = np.asarray(roots) + + if np.any(is_terminal[active_events]): + if x_new > x: + order = np.argsort(roots) + else: + order = np.argsort(-roots) + active_events = active_events[order] + roots = roots[order] + t = np.nonzero(is_terminal[active_events])[0][0] + active_events = active_events[:t + 1] + roots = roots[:t + 1] + terminate = True + else: + terminate = False + + return active_events, roots, terminate + + +def solve_event_equation(event, sol, x, x_new): + """Solve an equation corresponding to an ODE event. + + The equation is ``event(x, y(x)) = 0``, here ``y(x)`` is known from an + ODE solver using some sort of interpolation. It is solved by + `scipy.optimize.brentq` with xtol=atol=4*EPS. + + Parameters + ---------- + event : callable + Function ``event(x, y)``. + sol : callable + Computed solution ``y(x)``. It should be defined only between `x` and + `x_new`. + x, x_new : float + Previous and new values of the independed variable, it will be used as + a bracketing interval. + + Returns + ------- + root : float + Found solution. + """ + return brentq(lambda t: event(t, sol(t)), x, x_new, xtol=4 * EPS) diff --git a/scipy/integrate/_py/ivp.py b/scipy/integrate/_py/ivp.py new file mode 100644 index 000000000000..68a4db36b43d --- /dev/null +++ b/scipy/integrate/_py/ivp.py @@ -0,0 +1,190 @@ +"""Generic interface for initial value problem solvers.""" +from warnings import warn +import numpy as np +from .common import select_initial_step, EPS, ODEResult +from .rk import rk + + +METHOD_ORDER = { + 'RK23': 3, + 'RK45': 5 +} + + +TERMINATION_MESSAGES = { + 0: "The solver failed to reach the interval end or a termination event.", + 1: "The solver successfully reached the interval end.", + 2: "A termination event occurred." +} + + +def validate_tol(rtol, atol, n): + """Validate tolerance values.""" + if rtol < 100 * EPS: + warn("`rtol` is too low, setting to {}".format(100 * EPS)) + rtol = 100 * EPS + + atol = np.asarray(atol) + if atol.ndim > 0 and atol.shape != (n,): + raise ValueError("`atol` has wrong shape.") + + if np.any(atol < 0): + raise ValueError("`atol` must be positive.") + + return rtol, atol + + +def solve_ivp(fun, x_span, ya, rtol=1e-3, atol=1e-6, method='RK45', + events=None): + """Solve an initial value problem for a system of ODEs. + + This function numerically integrates a system of ODEs given an initial + value:: + + dy / dx = f(x, y) + y(a) = ya + + Here x is a 1-dimensional independent variable, y(x) is a n-dimensional + vector-valued function and ya is a n-dimensional vector with initial + values. + + Parameters + ---------- + fun : callable + Right-hand side of the system. The calling signature is ``fun(x, y)``. + Here ``x`` is a scalar, and ``y`` is ndarray with shape (n,). It + must return an array_like with shape (n,). + x_span : 2-tuple of floats + Interval of integration (a, b). The solver starts with x=a and + integrates until it reaches x=b. + ya : array_like, shape (n,) + Initial values for y. + rtol, atol : float and array_like, optional + Relative and absolute tolerances. The solver keeps the error estimates + less than ``atol` + rtol * abs(y)``. Here `rtol` controls a relative + accuracy (number of correct digits). But if a component of `y` is + approximately below `atol` then the error only needs to fall within + the same `atol` threshold, and the number of correct digits is not + guaranteed. If components of y have different scales, it might be + beneficial to set different `atol` values for different components by + passing array_like with shape (n,) for `atol`. Default values are + 1e-3 for `rtol` and 1e-6 for `atol`. + method : string, optional + Integration method to use: + + * 'RK45' (default): Explicit Runge-Kutta method of order 5 with an + automatic step size control [1]_. A 4-th order accurate quartic + polynomial is used for the continuous extension [2]_. + * 'RK23': Explicit Runge-Kutta method of order 3 with an automatic + step size control [3]_. A 3-th order accurate cubic Hermit + polynomial is used for the continuous extension. + + events : callable, list of callables or None, optional + Events to track. Events are defined by functions which take + a zero value at a point of an event. Each function must have a + signature ``event(x, y)`` and return float, the solver will find an + accurate value of ``x`` at which ``event(x, y(x)) = 0`` using a root + finding algorithm. Additionally each ``event`` function might have + attributes: + + * terminate : bool, whether to terminate integration if this + event occurs. Implicitly False if not assigned. + * direction : float, direction of crossing a zero. If `direction` + is positive then `event` must go from negative to positive, and + vice-versa if `direction` is negative. If 0, then either way will + count. Implicitly 0 if not assigned. + + You can assign attributes like ``event.terminate = True`` to any + function in Python. If None (default), events won't be tracked. + + Returns + ------- + Bunch object with the following fields defined: + sol : PPoly + Found solution for y as `scipy.interpolate.PPoly` instance, a C1 + continuous spline. + x : ndarray, shape (n_points,) + Values of the independent variable at which the solver made steps. + y : ndarray, shape (n, n_points) + Solution values at `x`. + yp : ndarray, shape (n, n_points) + Solution derivatives at `x`, i.e. ``fun(x, y)``. + x_events : ndarray, tuple of ndarray or None + Arrays containing values of x at each corresponding events was + detected. If `events` contained only 1 event, then `x_events` will + be ndarray itself. None if `events` was None. + status : int + Reason for algorithm termination: + + * 0: The solver failed to reach the interval end or a termination + event. + * 1: The solver successfully reached the interval end. + * 2: A termination event occurred. + + message : string + Verbal description of the termination reason. + success : bool + True if the solver reached the interval end or a termination event + (``status > 0``). + + References + ---------- + .. [1] J. R. Dormand, P. J. Prince, "A family of embedded Runge-Kutta + formulae", Journal of Computational and Applied Mathematics, Vol. 6, + No. 1, pp. 19-26, 1980. + .. [2] L. W. Shampine, "Some Practical Runge-Kutta Formulas", Mathematics + of Computation,, Vol. 46, No. 173, pp. 135–150, 1986. + .. [3] P. Bogacki, L.F. Shampine, "A 3(2) Pair of Runge-Kutta Formulas", + Appl. Math. Lett. Vol. 2, No. 4. pp. 321-325, 1989. + """ + if method not in METHOD_ORDER: + raise ValueError("`method` must be 'RK23' or 'RK45'.") + + ya = np.atleast_1d(ya) + if ya.ndim != 1: + raise ValueError("`ya` must be 1-dimensional.") + + a, b = float(x_span[0]), float(x_span[1]) + if a == b: + raise ValueError("Initial and final `x` must be distinct.") + + def fun_wrapped(x, y): + return np.asarray(fun(x, y), dtype=float) + + fa = fun_wrapped(a, ya) + if fa.shape != ya.shape: + raise ValueError("`fun` return is expected to have shape {}, " + "but actually has {}.".format(ya.shape, fa.shape)) + + if callable(events): + events = (events,) + + h_abs = select_initial_step(fun, a, b, ya, fa, METHOD_ORDER[method], + rtol, atol) + + if events is not None: + direction = np.empty(len(events)) + is_terminal = np.empty(len(events), dtype=bool) + for i, event in enumerate(events): + try: + is_terminal[i] = event.terminate + except AttributeError: + is_terminal[i] = False + + try: + direction[i] = event.direction + except AttributeError: + direction[i] = 0 + else: + direction = None + is_terminal = None + + max_step = 0.1 * np.abs(b - a) + + status, sol, xs, ys, fs, x_events = rk( + fun_wrapped, a, b, ya, fa, h_abs, rtol, atol, max_step, method, + events, direction, is_terminal) + + return ODEResult(sol=sol, x=xs, y=ys, yp=fs, x_events=x_events, + status=status, message=TERMINATION_MESSAGES[status], + success=status > 0) diff --git a/scipy/integrate/_py/rk.py b/scipy/integrate/_py/rk.py new file mode 100644 index 000000000000..b1140208358b --- /dev/null +++ b/scipy/integrate/_py/rk.py @@ -0,0 +1,345 @@ +"""Explicit Runge-Kutta methods.""" +import numpy as np +from scipy.interpolate import PPoly +from .common import get_active_events, handle_events, norm + +# Algorithm parameters. + +# Multiply steps computed from asymptotic behaviour of errors by this. +SAFETY = 0.9 + +MAX_FACTOR = 5 # Maximum allowed increase in a step size. +MIN_FACTOR = 0.2 # Minimum allowed decrease in a step size. + +# Butcher tables. See `rk_step` for explanation. + +# Bogacki–Shampine scheme. +C23 = np.array([1/2, 3/4]) +A23 = [np.array([1/2]), + np.array([0, 3/4])] +B23 = np.array([2/9, 1/3, 4/9]) +# Coefficients for estimation errors. The difference between B's for lower +# and higher order accuracy methods. +E23 = np.array([5/72, -1/12, -1/9, 1/8]) + +# Dormand–Prince scheme. +C45 = np.array([1/5, 3/10, 4/5, 8/9, 1]) +A45 = [np.array([1/5]), + np.array([3/40, 9/40]), + np.array([44/45, -56/15, 32/9]), + np.array([19372/6561, -25360/2187, 64448/6561, -212/729]), + np.array([9017/3168, -355/33, 46732/5247, 49/176, -5103/18656])] +B45 = np.array([35/384, 0, 500/1113, 125/192, -2187/6784, 11/84]) +E45 = np.array([-71/57600, 0, 71/16695, -71/1920, 17253/339200, -22/525, 1/40]) + +# Coefficients to compute y(x + 0.5 * h) from RK stages with a 4-rd order +# accuracy. Then it can be used for quartic interpolation with a 4-rd order +# accuracy. +M45 = np.array([613/3072, 0, 125/159, -125/1536, 8019/54272, -11/96, 1/16]) + + +def prepare_method(method, n, s): + """Choose appropriate matrices for a RK method. + + See `rk_step` for the explanation of returned matrices. + """ + if method == 'RK45': + A = A45 + B = B45 + C = C45 + E = E45 + M = M45 + K = np.empty((7, n)) + order = 5 + elif method == 'RK23': + A = A23 + B = B23 + C = C23 + E = E23 + M = None + order = 3 + K = np.empty((4, n)) + else: + raise ValueError("`method` must be 'RK45' or 'RK23'.") + + return A, B, C, E, M, K, order + + +def rk_step(fun, x, y, f, h, A, B, C, E, K): + """Perform a single Runge-Kutta step. + + This function computes a prediction of an explicit Runge-Kutta method and + also estimates the error of a less accurate method. + + Notation for Butcher tableau is as in [1]_. + + Parameters + ---------- + fun : callable + Right-hand side of the system. + x : float + Current value of the independent variable. + y : ndarray, shape (n,) + Current value of the solution. + f : ndarray, shape (n,) + Current value of the derivative of the solution, i.e. ``fun(x, y)``. + h : float, shape (n,) + Step for x to use. + A : list of ndarray, length n_stages - 1 + Coefficients for combining previous RK stages for computing the next + stage. For explicit methods the coefficients above the main diagonal + are zeros, so they are stored as a list of arrays of increasing + lengths. The first stage is always just `f`, thus no coefficients are + required. + B : ndarray, shape (n_stages,) + Coefficients for combining RK stages for computing the final + prediction. + C : ndarray, shape (n_stages - 1,) + Coefficients for incrementing x for computing RK stages. The value for + the first stage is always zero, thus it is not stored. + E : ndarray, shape (n_stages + 1,) + Coefficients for estimating the error of a less accurate method. They + are computed as the difference between b's in an extended tableau. + K : ndarray, shape (n_stages + 1, n) + Storage array for putting RK stages here. Stages are stored in rows. + + Returns + ------- + y_new : ndarray, shape (n,) + Solution at x + h computed with a higher accuracy. + f_new : ndarray, shape (n,) + Derivative ``fun(x + h, y_new)``. + error : ndarray, shape (n,) + Error estimate. + + References + ---------- + .. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential + Equations I: Nonstiff Problems", Sec. II.4. + """ + K[0] = f + for s, (a, c) in enumerate(zip(A, C)): + dy = np.dot(K[:s + 1].T, a) * h + K[s + 1] = fun(x + c * h, y + dy) + + y_new = y + h * np.dot(K[:-1].T, B) + f_new = fun(x + h, y_new) + + K[-1] = f_new + error = np.dot(K.T, E) * h + + return y_new, f_new, error + + +def create_spline(x, y, f, ym): + """Create a cubic or quartic spline given values and derivatives. + + Parameters + ---------- + x : ndarray, shape (n_points,) + Values of the independent variable. + y : ndarray, shape (n_points, n) + Values of the dependent variable at `x`. + f : ndarray, shape (n_points, n) + Values of the derivatives of `y` evaluated at `x`. + ym : ndarray with shape (n_points, n) or None + Values of the dependent variables at middle points between values + of `x`. If None, a cubic spline will be constructed, and a quartic + spline otherwise. + + Returns + ------- + sol : PPoly + Constructed spline as a PPoly instance. + """ + from scipy.interpolate import PPoly + + if x[-1] < x[0]: + x = x[::-1] + y = y[::-1] + if ym is not None: + ym = ym[::-1] + f = f[::-1] + + h = np.diff(x) + + y0 = y[:-1] + y1 = y[1:] + f0 = f[:-1] + f1 = f[1:] + + n_points, n = y.shape + h = h[:, None] + if ym is None: + c = np.empty((4, n_points - 1, n)) + slope = (y1 - y0) / h + t = (f0 + f1 - 2 * slope) / h + c[0] = t / h + c[1] = (slope - f0) / h - t + c[2] = f0 + c[3] = y0 + else: + c = np.empty((5, n_points - 1, n)) + c[0] = (-8 * y0 - 8 * y1 + 16 * ym) / h**4 + (- 2 * f0 + 2 * f1) / h**3 + c[1] = (18 * y0 + 14 * y1 - 32 * ym) / h**3 + (5 * f0 - 3 * f1) / h**2 + c[2] = (-11 * y0 - 5 * y1 + 16 * ym) / h**2 + (-4 * f0 + f1) / h + c[3] = f0 + c[4] = y0 + + c = np.rollaxis(c, 2) + return PPoly(c, x, extrapolate=True, axis=1) + + +def create_spline_one_step(x, x_new, y, y_new, f, f_new, ym): + """Create a spline for a single step. + + Parameters + ---------- + x, x_new : float + Previous and new values of the independed variable. + y, y_new : float + Previous and new values of the dependent variable. + f, f_new : float + Previous and new values of the derivative of the dependent variable. + ym : float or None + Value of the dependent variable at the middle point between `x` and + `x_new`. If provided the quartic spline is constructed, if None + the cubic spline is constructed. + + Returns + ------- + sol : PPoly + Constructed spline as a PPoly instance. + """ + if x_new < x: + x0, x1 = x_new, x + y0, y1 = y_new, y + f0, f1 = f_new, f + else: + x0, x1 = x, x_new + y0, y1 = y, y_new + f0, f1 = f, f_new + + h = x1 - x0 + n = y.shape[0] + if ym is None: + c = np.empty((4, 1, n), dtype=y.dtype) + slope = (y1 - y0) / h + t = (f0 + f1 - 2 * slope) / h + c[0] = t / h + c[1] = (slope - f0) / h - t + c[2] = f0 + c[3] = y0 + else: + c = np.empty((5, 1, n), dtype=y.dtype) + c[0] = (-8 * y0 - 8 * y1 + 16 * ym) / h**4 + (- 2 * f0 + 2 * f1) / h**3 + c[1] = (18 * y0 + 14 * y1 - 32 * ym) / h**3 + (5 * f0 - 3 * f1) / h**2 + c[2] = (-11 * y0 - 5 * y1 + 16 * ym) / h**2 + (-4 * f0 + f1) / h + c[3] = f0 + c[4] = y0 + + c = np.rollaxis(c, 2) + return PPoly(c, [x0, x1], extrapolate=True, axis=1) + + +def rk(fun, a, b, ya, fa, h_abs, rtol, atol, max_step, method, events, + direction, is_terminal): + """Integrate an ODE by Runge-Kutta method.""" + s = np.sign(b - a) + + A, B, C, E, M, K, order = prepare_method(method, ya.shape[0], s) + + x = a + y = ya + f = fa + + ys = [y] + xs = [x] + fs = [f] + if order == 3: + yms = None + else: + yms = [] + + if events is not None: + g = [event(x, y) for event in events] + x_events = [[] for _ in range(len(events))] + else: + x_events = None + + status = None + while status is None: + h_abs = min(h_abs, max_step) + + d = abs(b - x) + if h_abs > d: + status = 1 + h_abs = d + x_new = b + h = h_abs * s + else: + h = h_abs * s + x_new = x + h + if x_new == x: # h less than spacing between numbers. + status = 0 + + y_new, f_new, error = rk_step(fun, x, y, f, h, A, B, C, E, K) + scale = atol + np.maximum(np.abs(y), np.abs(y_new)) * rtol + error_norm = norm(error / scale) + + if error_norm > 1: + h_abs *= max(MIN_FACTOR, SAFETY * error_norm**(-1/order)) + status = None + continue + + if M is not None: + ym = y + 0.5 * h * np.dot(K.T, M) + else: + ym = None + + if events is not None: + g_new = [event(x_new, y_new) for event in events] + active_events = get_active_events(g, g_new, direction) + g = g_new + if active_events.size > 0: + sol = create_spline_one_step(x, x_new, y, y_new, f, f_new, ym) + root_indices, roots, terminate = handle_events( + sol, events, active_events, is_terminal, x, x_new) + + for e, xe in zip(root_indices, roots): + x_events[e].append(xe) + + if terminate: + status = 2 + x_new = roots[-1] + y_new = sol(x_new) + if ym is not None: + ym = sol(0.5 * (x + x_new)) + f_new = fun(x_new, y_new) + + x = x_new + y = y_new + f = f_new + ys.append(y) + xs.append(x) + fs.append(f) + if ym is not None: + yms.append(ym) + + with np.errstate(divide='ignore'): + h_abs *= min(MAX_FACTOR, SAFETY * error_norm**(-1/order)) + + xs = np.asarray(xs) + ys = np.asarray(ys) + fs = np.asarray(fs) + if yms is not None: + yms = np.asarray(yms) + + sol = create_spline(xs, ys, fs, yms) + + if x_events: + x_events = [np.asarray(xe) for xe in x_events] + if len(x_events) == 1: + x_events = x_events[0] + + return status, sol, xs, ys.T, fs.T, x_events diff --git a/scipy/integrate/setup.py b/scipy/integrate/setup.py index dc3b00fdbac9..37513cbeba97 100755 --- a/scipy/integrate/setup.py +++ b/scipy/integrate/setup.py @@ -89,9 +89,12 @@ def configuration(parent_package='',top_path=None): depends=(odepack_src + mach_src), **lapack_opt) + config.add_subpackage('_py') + config.add_data_dir('tests') return config + if __name__ == '__main__': from numpy.distutils.core import setup setup(**configuration(top_path='').todict()) diff --git a/scipy/integrate/tests/test_ivp.py b/scipy/integrate/tests/test_ivp.py new file mode 100644 index 000000000000..864d2b4449c2 --- /dev/null +++ b/scipy/integrate/tests/test_ivp.py @@ -0,0 +1,163 @@ +import numpy as np +from numpy.testing import (assert_, assert_allclose, run_module_suite, + assert_equal) +from scipy.integrate import solve_ivp + + +def fun_rational(x, y): + return np.array([y[1] / x, + y[1] * (y[0] + 2 * y[1] - 1) / (x * (y[0] - 1))]) + + +def sol_rational(x): + return np.vstack((x / (x + 10), 10 * x / (x + 10)**2)) + + +def event_rational_1(x, y): + return y[0] - y[1] ** 0.7 + + +def event_rational_2(x, y): + return y[1] ** 0.6 - y[0] + + +def event_rational_3(x, y): + return x - 7.4 + + +def compute_error(y, y_true, rtol, atol): + e = (y - y_true) / (atol + rtol * y_true) + return np.sqrt(np.sum(e**2, axis=0) / e.shape[0]) + + +def test_rk(): + rtol = 1e-3 + atol = 1e-6 + for method in ['RK23', 'RK45']: + for x_span in ([5, 9], [5, 1]): + res = solve_ivp(fun_rational, x_span, [1/3, 2/9], rtol=rtol, + atol=atol, method=method) + assert_equal(res.x[0], x_span[0]) + assert_equal(res.x[-1], x_span[-1]) + assert_(res.x_events is None) + assert_(res.success) + assert_equal(res.status, 1) + + y_true = sol_rational(res.x) + e = compute_error(res.y, y_true, rtol, atol) + assert_(np.all(e < 0.2)) + + xc = np.linspace(*x_span) + yc_true = sol_rational(xc) + yc = res.sol(xc) + + e = compute_error(yc, yc_true, rtol, atol) + assert_(np.all(e < 0.2)) + + assert_allclose(res.sol(res.x), res.y, rtol=1e-15, atol=1e-15) + assert_allclose(res.sol(res.x, 1), res.yp, rtol=1e-15, atol=1e-13) + + +def test_rk_events(): + event_rational_3.terminate = True + + for method in ['RK23', 'RK45']: + res = solve_ivp(fun_rational, [5, 8], [1/3, 2/9], method=method, + events=(event_rational_1, event_rational_2)) + assert_equal(res.status, 1) + assert_equal(res.x_events[0].size, 1) + assert_equal(res.x_events[1].size, 1) + assert_(5.3 < res.x_events[0][0] < 5.7) + assert_(7.3 < res.x_events[1][0] < 7.7) + + event_rational_1.direction = 1 + event_rational_2.direction = 1 + res = solve_ivp(fun_rational, [5, 8], [1 / 3, 2 / 9], method=method, + events=(event_rational_1, event_rational_2)) + assert_equal(res.status, 1) + assert_equal(res.x_events[0].size, 1) + assert_equal(res.x_events[1].size, 0) + assert_(5.3 < res.x_events[0][0] < 5.7) + + event_rational_1.direction = -1 + event_rational_2.direction = -1 + res = solve_ivp(fun_rational, [5, 8], [1 / 3, 2 / 9], method=method, + events=(event_rational_1, event_rational_2)) + assert_equal(res.status, 1) + assert_equal(res.x_events[0].size, 0) + assert_equal(res.x_events[1].size, 1) + assert_(7.3 < res.x_events[1][0] < 7.7) + + event_rational_1.direction = 0 + event_rational_2.direction = 0 + + res = solve_ivp(fun_rational, [5, 8], [1 / 3, 2 / 9], method=method, + events=(event_rational_1, event_rational_2, + event_rational_3)) + assert_equal(res.status, 2) + assert_equal(res.x_events[0].size, 1) + assert_equal(res.x_events[1].size, 0) + assert_equal(res.x_events[2].size, 1) + assert_(5.3 < res.x_events[0][0] < 5.7) + assert_(7.3 < res.x_events[2][0] < 7.5) + + # Also test that termination by event doesn't break interpolants. + xc = np.linspace(res.x[0], res.x[-1]) + yc_true = sol_rational(xc) + yc = res.sol(xc) + e = compute_error(yc, yc_true, 1e-3, 1e-6) + assert_(np.all(e < 0.2)) + + # Test in backward direction. + event_rational_1.direction = 0 + event_rational_2.direction = 0 + for method in ['RK23', 'RK45']: + res = solve_ivp(fun_rational, [8, 5], [4/9, 20/81], method=method, + events=(event_rational_1, event_rational_2)) + assert_equal(res.status, 1) + assert_equal(res.x_events[0].size, 1) + assert_equal(res.x_events[1].size, 1) + assert_(5.3 < res.x_events[0][0] < 5.7) + assert_(7.3 < res.x_events[1][0] < 7.7) + + event_rational_1.direction = -1 + event_rational_2.direction = -1 + res = solve_ivp(fun_rational, [8, 5], [4/9, 20/81], method=method, + events=(event_rational_1, event_rational_2)) + assert_equal(res.status, 1) + assert_equal(res.x_events[0].size, 1) + assert_equal(res.x_events[1].size, 0) + assert_(5.3 < res.x_events[0][0] < 5.7) + + event_rational_1.direction = 1 + event_rational_2.direction = 1 + res = solve_ivp(fun_rational, [8, 5], [4/9, 20/81], method=method, + events=(event_rational_1, event_rational_2)) + assert_equal(res.status, 1) + assert_equal(res.x_events[0].size, 0) + assert_equal(res.x_events[1].size, 1) + assert_(7.3 < res.x_events[1][0] < 7.7) + + event_rational_1.direction = 0 + event_rational_2.direction = 0 + + res = solve_ivp(fun_rational, [8, 5], [4/9, 20/81], method=method, + events=(event_rational_1, event_rational_2, + event_rational_3)) + assert_equal(res.status, 2) + assert_equal(res.x_events[0].size, 0) + assert_equal(res.x_events[1].size, 1) + assert_equal(res.x_events[2].size, 1) + assert_(7.3 < res.x_events[1][0] < 7.7) + assert_(7.3 < res.x_events[2][0] < 7.5) + + # Also test that termination by event doesn't break interpolants. + xc = np.linspace(res.x[-1], res.x[0]) + yc_true = sol_rational(xc) + yc = res.sol(xc) + e = compute_error(yc, yc_true, 1e-3, 1e-6) + assert_(np.all(e < 0.2)) + + +if __name__ == '__main__': + run_module_suite()