Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

BUG: integrate: handle additional function arguments with the dopri5 …

…solver (fix for ticket 1392)
  • Loading branch information...
commit 1e2d0c92c85b15adfa15abbc8bae37ab9ce5d0a0 1 parent 388b3da
@WarrenWeckesser WarrenWeckesser authored
Showing with 118 additions and 2 deletions.
  1. +1 −1  scipy/integrate/ode.py
  2. +117 −1 scipy/integrate/tests/test_integrate.py
View
2  scipy/integrate/ode.py
@@ -746,7 +746,7 @@ def reset(self,n,has_jac):
self.success = 1
def run(self,f,jac,y0,t0,t1,f_params,jac_params):
- x,y,iwork,idid = self.runner(*((f,t0,y0,t1) + tuple(self.call_args)))
+ x,y,iwork,idid = self.runner(*((f,t0,y0,t1) + tuple(self.call_args) + (f_params,)))
if idid < 0:
warnings.warn(self.name + ': ' +
self.messages.get(idid, 'Unexpected idid=%s'%idid))
View
118 scipy/integrate/tests/test_integrate.py
@@ -7,7 +7,8 @@
from numpy import arange, zeros, array, dot, sqrt, cos, sin, eye, pi, exp, \
allclose
-from numpy.testing import assert_, TestCase, run_module_suite
+from numpy.testing import assert_, TestCase, run_module_suite, \
+ assert_array_almost_equal
from scipy.integrate import odeint, ode, complex_ode
#------------------------------------------------------------------------------
@@ -206,5 +207,120 @@ def verify(self, zs, t):
#------------------------------------------------------------------------------
+def f(t, x):
+ dxdt = [x[1], -x[0]]
+ return dxdt
+
+def jac(t, x):
+ j = array([[ 0.0, 1.0],
+ [-1.0, 0.0]])
+ return j
+
+def f1(t, x, omega):
+ dxdt = [omega*x[1], -omega*x[0]]
+ return dxdt
+
+def jac1(t, x, omega):
+ j = array([[ 0.0, omega],
+ [-omega, 0.0]])
+ return j
+
+def f2(t, x, omega1, omega2):
+ dxdt = [omega1*x[1], -omega2*x[0]]
+ return dxdt
+
+def jac2(t, x, omega1, omega2):
+ j = array([[ 0.0, omega1],
+ [-omega2, 0.0]])
+ return j
+
+def fv(t, x, omega):
+ dxdt = [omega[0]*x[1], -omega[1]*x[0]]
+ return dxdt
+
+def jacv(t, x, omega):
+ j = array([[ 0.0, omega[0]],
+ [-omega[1], 0.0]])
+ return j
+
+
+class ODECheckParameterUse(object):
+ """Call an ode-class solver with several cases of parameter use."""
+
+ # This class is intentionally not a TestCase subclass.
+ # solver_name must be set before tests can be run with this class.
+
+ # Set these in subclasses.
+ solver_name = ''
+ solver_uses_jac = False
+
+ def _get_solver(self, f, jac):
+ solver = ode(f, jac)
+ if self.solver_uses_jac:
+ solver.set_integrator(self.solver_name, atol=1e-9, rtol=1e-7,
+ with_jacobian=self.solver_uses_jac)
+ else:
+ # XXX Shouldn't set_integrator *always* accept the keyword arg
+ # 'with_jacobian', and perhaps raise an exception if it is set
+ # to True if the solver can't actually use it?
+ solver.set_integrator(self.solver_name, atol=1e-9, rtol=1e-7)
+ return solver
+
+ def _check_solver(self, solver):
+ ic = [1.0, 0.0]
+ solver.set_initial_value(ic, 0.0)
+ solver.integrate(pi)
+ assert_array_almost_equal(solver.y, [-1.0, 0.0])
+
+ def test_no_params(self):
+ solver = self._get_solver(f, jac)
+ self._check_solver(solver)
+
+ def test_one_scalar_param(self):
+ solver = self._get_solver(f1, jac1)
+ omega = 1.0
+ solver.set_f_params(omega)
+ if self.solver_uses_jac:
+ solver.set_jac_params(omega)
+ self._check_solver(solver)
+
+ def test_two_scalar_params(self):
+ solver = self._get_solver(f2, jac2)
+ omega1 = 1.0
+ omega2 = 1.0
+ solver.set_f_params(omega1, omega2)
+ if self.solver_uses_jac:
+ solver.set_jac_params(omega1, omega2)
+ self._check_solver(solver)
+
+ def test_vector_param(self):
+ solver = self._get_solver(fv, jacv)
+ omega = [1.0, 1.0]
+ solver.set_f_params(omega)
+ if self.solver_uses_jac:
+ solver.set_jac_params(omega)
+ self._check_solver(solver)
+
+
+class DOPRI5CheckParameterUse(ODECheckParameterUse, TestCase):
+ solver_name = 'dopri5'
+ solver_uses_jac = False
+
+
+class DOP853CheckParameterUse(ODECheckParameterUse, TestCase):
+ solver_name = 'dop853'
+ solver_uses_jac = False
+
+
+class VODECheckParameterUse(ODECheckParameterUse, TestCase):
+ solver_name = 'vode'
+ solver_uses_jac = True
+
+
+class ZVODECheckParameterUse(ODECheckParameterUse, TestCase):
+ solver_name = 'zvode'
+ solver_uses_jac = True
+
+
if __name__ == "__main__":
run_module_suite()
Please sign in to comment.
Something went wrong with that request. Please try again.