From 34ba78a83ddb0eb27587296d4844cd9ee2a6f6cf Mon Sep 17 00:00:00 2001 From: Scott Marquis Date: Thu, 10 Oct 2019 17:06:33 +0100 Subject: [PATCH] #633 updated klu to take in abs and rel tol --- pybamm/solvers/klu_sparse_solver.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pybamm/solvers/klu_sparse_solver.py b/pybamm/solvers/klu_sparse_solver.py index 019d5ddff7..cd8b25be71 100644 --- a/pybamm/solvers/klu_sparse_solver.py +++ b/pybamm/solvers/klu_sparse_solver.py @@ -20,16 +20,15 @@ def have_klu(): class KLU(pybamm.DaeSolver): """Solve a discretised model, using sundials with the KLU sparse linear solver. - Parameters + Parameters ---------- - method : str, optional - The method to use in solve_ivp (default is "BDF") - tolerance : float, optional - The tolerance for the solver (default is 1e-8). Set as the both reltol and - abstol in solve_ivp. + rtol : float, optional + The relative tolerance for the solver (default is 1e-6). + atol : float, optional + The absolute tolerance for the solver (default is 1e-6). root_method : str, optional The method to use to find initial conditions (default is "lm") - tolerance : float, optional + root_tol : float, optional The tolerance for the initial-condition solver (default is 1e-8). max_steps: int, optional The maximum number of steps the solver will take before terminating @@ -37,12 +36,13 @@ class KLU(pybamm.DaeSolver): """ def __init__( - self, method="ida", tol=1e-8, root_method="lm", root_tol=1e-6, max_steps=1000 + self, rtol=1e-6, atol=1e-6, root_method="lm", root_tol=1e-6, max_steps=1000 ): + if klu_spec is None: raise ImportError("KLU is not installed") - super().__init__(method, tol, root_method, root_tol, max_steps) + super().__init__("ida", rtol, atol, root_method, root_tol, max_steps) def integrate(self, residuals, y0, t_eval, events, mass_matrix, jacobian): """ @@ -71,8 +71,8 @@ def integrate(self, residuals, y0, t_eval, events, mass_matrix, jacobian): def eqsres(t, y, ydot, return_residuals): return_residuals[:] = residuals(t, y, ydot) - rtol = self.tol - atol = self.tol + rtol = self._rtol + atol = self._atol if jacobian: jac_y0_t0 = jacobian(t_eval[0], y0)