diff --git a/pybamm/simulation.py b/pybamm/simulation.py index 75d7f14274..6e1be1c6b2 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -471,25 +471,15 @@ def step( if solver is None: solver = self.solver - if save is False: - # Don't pass previous solution - self._solution = solver.step( - None, - self.built_model, - dt, - npts=npts, - external_variables=external_variables, - inputs=inputs, - ) - else: - self._solution = solver.step( - self._solution, - self.built_model, - dt, - npts=npts, - external_variables=external_variables, - inputs=inputs, - ) + self._solution = solver.step( + self._solution, + self.built_model, + dt, + npts=npts, + external_variables=external_variables, + inputs=inputs, + save=save, + ) def get_variable_array(self, *variables): """ diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 536734fabe..8b9fd509dd 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -575,7 +575,14 @@ def solve(self, model, t_eval, external_variables=None, inputs=None): return solution def step( - self, old_solution, model, dt, npts=2, external_variables=None, inputs=None + self, + old_solution, + model, + dt, + npts=2, + external_variables=None, + inputs=None, + save=True, ): """ Step the solution of the model forward by a given time increment. The @@ -599,7 +606,8 @@ def step( values at the current time inputs : dict, optional Any input parameters to pass to the model when solving - + save : bool + Turn on to store the solution of all previous timesteps Raises ------ @@ -677,7 +685,7 @@ def step( pybamm.logger.debug( "Step time: {}".format(timer.format(solution.solve_time)) ) - if old_solution is None: + if save is False or old_solution is None: return solution else: return old_solution + solution