Skip to content

Commit

Permalink
Merge branch 'high_order_ibm_test' of https://github.com/pnkraemer/pr…
Browse files Browse the repository at this point in the history
…obnum into high_order_ibm_test
  • Loading branch information
pnkraemer committed May 5, 2021
2 parents 4bdf42a + 5c27673 commit cd3dedb
Show file tree
Hide file tree
Showing 5 changed files with 430 additions and 0 deletions.
44 changes: 44 additions & 0 deletions src/probnum/diffeq/wrappedscipyodesolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Make a ProbNum ODE solution out of a scipy ODE solution."""
import numpy as np
from scipy.integrate._ivp.common import OdeSolution

from probnum import _randomvariablelist, diffeq, randvars
from probnum.filtsmooth.timeseriesposterior import (
DenseOutputLocationArgType,
DenseOutputValueType,
)


class WrappedScipyODESolution(diffeq.ODESolution):
"""Make a ProbNum ODESolution out of a SciPy OdeSolution."""

def __init__(self, scipy_solution: OdeSolution, rvs: list):
self.scipy_solution = scipy_solution

# rvs is of the type `list` of `RandomVariable` and can therefore be
# directly transformed into a _RandomVariableList
rv_states = _randomvariablelist._RandomVariableList(rvs)
super().__init__(locations=scipy_solution.ts, states=rv_states)

def __call__(self, t: DenseOutputLocationArgType) -> DenseOutputValueType:
"""Evaluate the time-continuous solution at time t.
Parameters
----------
t
Location / time at which to evaluate the continuous ODE solution.
Returns
-------
randvars.RandomVariable or _randomvariablelist._RandomVariableList
Estimate of the states at time ``t`` based on a fourth order polynomial.
"""

states = self.scipy_solution(t).T
if np.isscalar(t):
solution_as_rv = randvars.Constant(states)
else:
solution_as_rv = _randomvariablelist._RandomVariableList(
[randvars.Constant(state) for state in states]
)
return solution_as_rv
138 changes: 138 additions & 0 deletions src/probnum/diffeq/wrappedscipysolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Wrapper class of scipy.integrate. for RK23 and RK45.
Dense-output can not be used for DOP853, if you use other RK-methods,
make sure, that the current implementation works for them .
"""
import numpy as np
from scipy.integrate._ivp import rk
from scipy.integrate._ivp.common import OdeSolution

from probnum import diffeq, randvars
from probnum.diffeq import wrappedscipyodesolution
from probnum.type import FloatArgType


class WrappedScipyRungeKutta(diffeq.ODESolver):
"""Wrappper for Runge-Kutta methods from Scipy, implements the stepfunction and
dense output."""

def __init__(self, solver: rk.RungeKutta):
self.solver = solver
self.interpolants = None

# ProbNum ODESolver needs an ivp
ivp = diffeq.IVP(
timespan=[self.solver.t, self.solver.t_bound],
initrv=randvars.Constant(self.solver.y),
rhs=self.solver._fun,
)

# Dopri853 as implemented in SciPy computes the dense output differently.
if isinstance(solver, rk.DOP853):
raise TypeError(
"Dense output interpolation of DOP853 is currently not supported. Choose a different RK-method."
)
super().__init__(ivp=ivp, order=solver.order)

def initialise(self):
"""Return t0 and y0 (for the solver, which might be different to ivp.y0) and
initialize the solver. Reset the solver when solving the ODE multiple times,
i.e. explicitly setting y_old, t, y and f to the respective initial values,
otherwise those are wrong when running the solver twice.
Returns
-------
self.ivp.t0: float
initial time point
self.ivp.initrv: randvars.RandomVariable
initial random variable
"""

self.interpolants = []
self.solver.y_old = None
self.solver.t = self.ivp.t0
self.solver.y = self.ivp.initrv.mean
self.solver.f = self.solver.fun(self.solver.t, self.solver.y)
return self.ivp.t0, self.ivp.initrv

def step(
self, start: FloatArgType, stop: FloatArgType, current: randvars, **kwargs
):
"""Perform one ODE-step from start to stop and set variables to the
corresponding values.
To specify start and stop directly, rk_step() and not _step_impl() is used.
Parameters
----------
start : float
starting location of the step
stop : float
stopping location of the step
current : :obj:`list` of :obj:`RandomVariable`
current state of the ODE.
Returns
-------
random_var : randvars.RandomVariable
Estimated states of the discrete-time solution.
error_estimation : float
estimated error after having performed the step.
"""

y = current.mean
dt = stop - start
y_new, f_new = rk.rk_step(
self.solver.fun,
start,
y,
self.solver.f,
dt,
self.solver.A,
self.solver.B,
self.solver.C,
self.solver.K,
)

# Unnormalized error estimation is used as the error estimation is normalized in
# solve().
error_estimation = self.solver._estimate_error(self.solver.K, dt)
y_new_as_rv = randvars.Constant(y_new)

# Update the solver settings. This part is copied from scipy's _step_impl().
self.solver.h_previous = dt
self.solver.y_old = current
self.solver.t_old = start
self.solver.t = stop
self.solver.y = y_new
self.solver.h_abs = dt
self.solver.f = f_new
return y_new_as_rv, error_estimation

def rvlist_to_odesol(self, times: np.array, rvs: np.array):
"""Create a ScipyODESolution object which is a subclass of
diffeq.ODESolution."""
scipy_solution = OdeSolution(times, self.interpolants)
probnum_solution = wrappedscipyodesolution.WrappedScipyODESolution(
scipy_solution, rvs
)
return probnum_solution

def method_callback(self, time, current_guess, current_error):
"""Call dense output after each step and store the interpolants."""
dense = self.dense_output()
self.interpolants.append(dense)

def dense_output(self):
"""Compute the interpolant after each step.
Returns
-------
sol : rk.RkDenseOutput
Interpolant between the last and current location.
"""
Q = self.solver.K.T.dot(self.solver.P)
sol = rk.RkDenseOutput(
self.solver.t_old, self.solver.t, self.solver.y_old.mean, Q
)
return sol
31 changes: 31 additions & 0 deletions tests/test_diffeq/test_wrappedscipy_cases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np
from scipy.integrate._ivp import rk

from probnum import diffeq
from probnum.diffeq import wrappedscipysolver


def setup_solver(y0, ode):
scipysolver = rk.RK45(ode.rhs, ode.t0, y0, ode.tmax)
testsolver = wrappedscipysolver.WrappedScipyRungeKutta(
rk.RK45(ode.rhs, ode.t0, y0, ode.tmax)
)
return testsolver, scipysolver


def case_lorenz():
y0 = np.array([0.0, 1.0, 1.05])
ode = diffeq.lorenz([0.0, 1.0], y0)
return setup_solver(y0, ode)


def case_logistic():
y0 = np.array([0.1])
ode = diffeq.logistic([0.0, 1.0], y0)
return setup_solver(y0, ode)


def case_lotkavolterra():
y0 = np.array([0.1, 0.1])
ode = diffeq.lotkavolterra([0.0, 1.0], y0)
return setup_solver(y0, ode)
45 changes: 45 additions & 0 deletions tests/test_diffeq/test_wrappedscipyodesolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np
import pytest_cases

from probnum import _randomvariablelist, diffeq, randvars


@pytest_cases.fixture
@pytest_cases.parametrize_with_cases(
"testsolver, scipysolver", cases=".test_wrappedscipy_cases"
)
def solution_case(testsolver, scipysolver):
testsolution = testsolver.solve(diffeq.AdaptiveSteps(0.1, atol=1e-12, rtol=1e-12))
scipysolution = testsolution.scipy_solution
return testsolution, scipysolution


def test_locations(solution_case):
testsolution, scipysolution = solution_case
scipy_t = scipysolution.ts
probnum_t = testsolution.locations
np.testing.assert_allclose(scipy_t, probnum_t, atol=1e-14, rtol=1e-14)


def test_call_isscalar(solution_case):
testsolution, scipysolution = solution_case
t = 0.1
call_scalar = testsolution(t)
call_array = testsolution([0.1, 0.2, 0.3])
assert np.isscalar(t)
assert isinstance(call_scalar, randvars.Constant)
assert isinstance(call_array, _randomvariablelist._RandomVariableList)


def test_states(solution_case):
testsolution, scipysolution = solution_case
scipy_states = np.array(scipysolution(scipysolution.ts)).T
probnum_states = np.array(testsolution.states.mean)
np.testing.assert_allclose(scipy_states, probnum_states, atol=1e-14, rtol=1e-14)


def test_call__(solution_case):
testsolution, scipysolution = solution_case
scipy_call = scipysolution(scipysolution.ts)
probnum_call = testsolution(scipysolution.ts).mean.T
np.testing.assert_allclose(scipy_call, probnum_call, atol=1e-14, rtol=1e-14)
Loading

0 comments on commit cd3dedb

Please sign in to comment.