-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'high_order_ibm_test' of https://github.com/pnkraemer/pr…
…obnum into high_order_ibm_test
- Loading branch information
Showing
5 changed files
with
430 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
"""Make a ProbNum ODE solution out of a scipy ODE solution.""" | ||
import numpy as np | ||
from scipy.integrate._ivp.common import OdeSolution | ||
|
||
from probnum import _randomvariablelist, diffeq, randvars | ||
from probnum.filtsmooth.timeseriesposterior import ( | ||
DenseOutputLocationArgType, | ||
DenseOutputValueType, | ||
) | ||
|
||
|
||
class WrappedScipyODESolution(diffeq.ODESolution): | ||
"""Make a ProbNum ODESolution out of a SciPy OdeSolution.""" | ||
|
||
def __init__(self, scipy_solution: OdeSolution, rvs: list): | ||
self.scipy_solution = scipy_solution | ||
|
||
# rvs is of the type `list` of `RandomVariable` and can therefore be | ||
# directly transformed into a _RandomVariableList | ||
rv_states = _randomvariablelist._RandomVariableList(rvs) | ||
super().__init__(locations=scipy_solution.ts, states=rv_states) | ||
|
||
def __call__(self, t: DenseOutputLocationArgType) -> DenseOutputValueType: | ||
"""Evaluate the time-continuous solution at time t. | ||
Parameters | ||
---------- | ||
t | ||
Location / time at which to evaluate the continuous ODE solution. | ||
Returns | ||
------- | ||
randvars.RandomVariable or _randomvariablelist._RandomVariableList | ||
Estimate of the states at time ``t`` based on a fourth order polynomial. | ||
""" | ||
|
||
states = self.scipy_solution(t).T | ||
if np.isscalar(t): | ||
solution_as_rv = randvars.Constant(states) | ||
else: | ||
solution_as_rv = _randomvariablelist._RandomVariableList( | ||
[randvars.Constant(state) for state in states] | ||
) | ||
return solution_as_rv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
"""Wrapper class of scipy.integrate. for RK23 and RK45. | ||
Dense-output can not be used for DOP853, if you use other RK-methods, | ||
make sure, that the current implementation works for them . | ||
""" | ||
import numpy as np | ||
from scipy.integrate._ivp import rk | ||
from scipy.integrate._ivp.common import OdeSolution | ||
|
||
from probnum import diffeq, randvars | ||
from probnum.diffeq import wrappedscipyodesolution | ||
from probnum.type import FloatArgType | ||
|
||
|
||
class WrappedScipyRungeKutta(diffeq.ODESolver): | ||
"""Wrappper for Runge-Kutta methods from Scipy, implements the stepfunction and | ||
dense output.""" | ||
|
||
def __init__(self, solver: rk.RungeKutta): | ||
self.solver = solver | ||
self.interpolants = None | ||
|
||
# ProbNum ODESolver needs an ivp | ||
ivp = diffeq.IVP( | ||
timespan=[self.solver.t, self.solver.t_bound], | ||
initrv=randvars.Constant(self.solver.y), | ||
rhs=self.solver._fun, | ||
) | ||
|
||
# Dopri853 as implemented in SciPy computes the dense output differently. | ||
if isinstance(solver, rk.DOP853): | ||
raise TypeError( | ||
"Dense output interpolation of DOP853 is currently not supported. Choose a different RK-method." | ||
) | ||
super().__init__(ivp=ivp, order=solver.order) | ||
|
||
def initialise(self): | ||
"""Return t0 and y0 (for the solver, which might be different to ivp.y0) and | ||
initialize the solver. Reset the solver when solving the ODE multiple times, | ||
i.e. explicitly setting y_old, t, y and f to the respective initial values, | ||
otherwise those are wrong when running the solver twice. | ||
Returns | ||
------- | ||
self.ivp.t0: float | ||
initial time point | ||
self.ivp.initrv: randvars.RandomVariable | ||
initial random variable | ||
""" | ||
|
||
self.interpolants = [] | ||
self.solver.y_old = None | ||
self.solver.t = self.ivp.t0 | ||
self.solver.y = self.ivp.initrv.mean | ||
self.solver.f = self.solver.fun(self.solver.t, self.solver.y) | ||
return self.ivp.t0, self.ivp.initrv | ||
|
||
def step( | ||
self, start: FloatArgType, stop: FloatArgType, current: randvars, **kwargs | ||
): | ||
"""Perform one ODE-step from start to stop and set variables to the | ||
corresponding values. | ||
To specify start and stop directly, rk_step() and not _step_impl() is used. | ||
Parameters | ||
---------- | ||
start : float | ||
starting location of the step | ||
stop : float | ||
stopping location of the step | ||
current : :obj:`list` of :obj:`RandomVariable` | ||
current state of the ODE. | ||
Returns | ||
------- | ||
random_var : randvars.RandomVariable | ||
Estimated states of the discrete-time solution. | ||
error_estimation : float | ||
estimated error after having performed the step. | ||
""" | ||
|
||
y = current.mean | ||
dt = stop - start | ||
y_new, f_new = rk.rk_step( | ||
self.solver.fun, | ||
start, | ||
y, | ||
self.solver.f, | ||
dt, | ||
self.solver.A, | ||
self.solver.B, | ||
self.solver.C, | ||
self.solver.K, | ||
) | ||
|
||
# Unnormalized error estimation is used as the error estimation is normalized in | ||
# solve(). | ||
error_estimation = self.solver._estimate_error(self.solver.K, dt) | ||
y_new_as_rv = randvars.Constant(y_new) | ||
|
||
# Update the solver settings. This part is copied from scipy's _step_impl(). | ||
self.solver.h_previous = dt | ||
self.solver.y_old = current | ||
self.solver.t_old = start | ||
self.solver.t = stop | ||
self.solver.y = y_new | ||
self.solver.h_abs = dt | ||
self.solver.f = f_new | ||
return y_new_as_rv, error_estimation | ||
|
||
def rvlist_to_odesol(self, times: np.array, rvs: np.array): | ||
"""Create a ScipyODESolution object which is a subclass of | ||
diffeq.ODESolution.""" | ||
scipy_solution = OdeSolution(times, self.interpolants) | ||
probnum_solution = wrappedscipyodesolution.WrappedScipyODESolution( | ||
scipy_solution, rvs | ||
) | ||
return probnum_solution | ||
|
||
def method_callback(self, time, current_guess, current_error): | ||
"""Call dense output after each step and store the interpolants.""" | ||
dense = self.dense_output() | ||
self.interpolants.append(dense) | ||
|
||
def dense_output(self): | ||
"""Compute the interpolant after each step. | ||
Returns | ||
------- | ||
sol : rk.RkDenseOutput | ||
Interpolant between the last and current location. | ||
""" | ||
Q = self.solver.K.T.dot(self.solver.P) | ||
sol = rk.RkDenseOutput( | ||
self.solver.t_old, self.solver.t, self.solver.y_old.mean, Q | ||
) | ||
return sol |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import numpy as np | ||
from scipy.integrate._ivp import rk | ||
|
||
from probnum import diffeq | ||
from probnum.diffeq import wrappedscipysolver | ||
|
||
|
||
def setup_solver(y0, ode): | ||
scipysolver = rk.RK45(ode.rhs, ode.t0, y0, ode.tmax) | ||
testsolver = wrappedscipysolver.WrappedScipyRungeKutta( | ||
rk.RK45(ode.rhs, ode.t0, y0, ode.tmax) | ||
) | ||
return testsolver, scipysolver | ||
|
||
|
||
def case_lorenz(): | ||
y0 = np.array([0.0, 1.0, 1.05]) | ||
ode = diffeq.lorenz([0.0, 1.0], y0) | ||
return setup_solver(y0, ode) | ||
|
||
|
||
def case_logistic(): | ||
y0 = np.array([0.1]) | ||
ode = diffeq.logistic([0.0, 1.0], y0) | ||
return setup_solver(y0, ode) | ||
|
||
|
||
def case_lotkavolterra(): | ||
y0 = np.array([0.1, 0.1]) | ||
ode = diffeq.lotkavolterra([0.0, 1.0], y0) | ||
return setup_solver(y0, ode) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import numpy as np | ||
import pytest_cases | ||
|
||
from probnum import _randomvariablelist, diffeq, randvars | ||
|
||
|
||
@pytest_cases.fixture | ||
@pytest_cases.parametrize_with_cases( | ||
"testsolver, scipysolver", cases=".test_wrappedscipy_cases" | ||
) | ||
def solution_case(testsolver, scipysolver): | ||
testsolution = testsolver.solve(diffeq.AdaptiveSteps(0.1, atol=1e-12, rtol=1e-12)) | ||
scipysolution = testsolution.scipy_solution | ||
return testsolution, scipysolution | ||
|
||
|
||
def test_locations(solution_case): | ||
testsolution, scipysolution = solution_case | ||
scipy_t = scipysolution.ts | ||
probnum_t = testsolution.locations | ||
np.testing.assert_allclose(scipy_t, probnum_t, atol=1e-14, rtol=1e-14) | ||
|
||
|
||
def test_call_isscalar(solution_case): | ||
testsolution, scipysolution = solution_case | ||
t = 0.1 | ||
call_scalar = testsolution(t) | ||
call_array = testsolution([0.1, 0.2, 0.3]) | ||
assert np.isscalar(t) | ||
assert isinstance(call_scalar, randvars.Constant) | ||
assert isinstance(call_array, _randomvariablelist._RandomVariableList) | ||
|
||
|
||
def test_states(solution_case): | ||
testsolution, scipysolution = solution_case | ||
scipy_states = np.array(scipysolution(scipysolution.ts)).T | ||
probnum_states = np.array(testsolution.states.mean) | ||
np.testing.assert_allclose(scipy_states, probnum_states, atol=1e-14, rtol=1e-14) | ||
|
||
|
||
def test_call__(solution_case): | ||
testsolution, scipysolution = solution_case | ||
scipy_call = scipysolution(scipysolution.ts) | ||
probnum_call = testsolution(scipysolution.ts).mean.T | ||
np.testing.assert_allclose(scipy_call, probnum_call, atol=1e-14, rtol=1e-14) |
Oops, something went wrong.