diff --git a/torchdiffeq/_impl/odeint.py b/torchdiffeq/_impl/odeint.py index a174219a..8b9a2c70 100644 --- a/torchdiffeq/_impl/odeint.py +++ b/torchdiffeq/_impl/odeint.py @@ -28,7 +28,7 @@ } -def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None): +def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None, info=False): """Integrate a system of ordinary differential equations. Solves the initial value problem for a non-stiff system of first order ODEs: @@ -58,12 +58,17 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even event_fn: Function that maps the state `y` to a Tensor. The solve terminates when event_fn evaluates to zero. If this is not None, all but the first elements of `t` are ignored. + info: scipy.solve_ivp by default returns a class OdeResult with informatino about + the solve. If info=True, this will return not just the solution vector `y`, but + the original `OdeResult` as well. Returns: y: Tensor, where the first dimension corresponds to different time points. Contains the solved value of y for each desired time point in `t`, with the initial value `y0` being the first element along the first dimension. + OdeResult: Only returned if info=True. See the following link for details. + https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html Raises: ValueError: if an invalid `method` is provided. @@ -73,8 +78,15 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even solver = SOLVERS[method](func=func, y0=y0, rtol=rtol, atol=atol, **options) + if method is not None: + if (method[0:5] != "scipy") and (info == True): + raise ValueError("The info parameter may only be used for scipy solvers!") + if event_fn is None: - solution = solver.integrate(t) + if info: + solution, OdeResult = solver.integrate(t, info=info) + else: + solution = solver.integrate(t, info=info) else: event_t, solution = solver.integrate_until_event(t[0], event_fn) event_t = event_t.to(t) @@ -85,7 +97,10 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even solution = _flat_to_shape(solution, (len(t),), shapes) if event_fn is None: - return solution + if info: + return solution, OdeResult + else: + return solution else: return event_t, solution diff --git a/torchdiffeq/_impl/scipy_wrapper.py b/torchdiffeq/_impl/scipy_wrapper.py index 06f93273..017cdd66 100644 --- a/torchdiffeq/_impl/scipy_wrapper.py +++ b/torchdiffeq/_impl/scipy_wrapper.py @@ -22,11 +22,11 @@ def __init__(self, func, y0, rtol, atol, solver="LSODA", **unused_kwargs): self.solver = solver self.func = convert_func_to_numpy(func, self.shape, self.device, self.dtype) - def integrate(self, t): + def integrate(self, t, info=False): if t.numel() == 1: return torch.tensor(self.y0)[None].to(self.device, self.dtype) t = t.detach().cpu().numpy() - sol = solve_ivp( + oderesult = solve_ivp( self.func, t_span=[t.min(), t.max()], y0=self.y0, @@ -35,9 +35,12 @@ def integrate(self, t): rtol=self.rtol, atol=self.atol, ) - sol = torch.tensor(sol.y).T.to(self.device, self.dtype) + sol = torch.tensor(oderesult.y).T.to(self.device, self.dtype) sol = sol.reshape(-1, *self.shape) - return sol + if info: + return sol, oderesult + else: + return sol def convert_func_to_numpy(func, shape, device, dtype): diff --git a/torchdiffeq/_impl/solvers.py b/torchdiffeq/_impl/solvers.py index 6915f2bd..bf26d31e 100644 --- a/torchdiffeq/_impl/solvers.py +++ b/torchdiffeq/_impl/solvers.py @@ -21,7 +21,7 @@ def _before_integrate(self, t): def _advance(self, next_t): raise NotImplementedError - def integrate(self, t): + def integrate(self, t, **kwargs): solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) solution[0] = self.y0 t = t.to(self.dtype) @@ -91,7 +91,7 @@ def _grid_constructor(func, y0, t): def _step_func(self, func, t0, dt, t1, y0): pass - def integrate(self, t): + def integrate(self, t, **kwargs): time_grid = self.grid_constructor(self.func, self.y0, t) assert time_grid[0] == t[0] and time_grid[-1] == t[-1]