From 9464b2b60ad089b96964845ae6e2f3c1b2d29457 Mon Sep 17 00:00:00 2001 From: ChrisDeGrendele Date: Fri, 26 Mar 2021 11:48:38 -0700 Subject: [PATCH 1/3] add info param to return scipy information --- torchdiffeq/_impl/odeint.py | 9 +++++++-- torchdiffeq/_impl/scipy_wrapper.py | 11 +++++++---- torchdiffeq/_impl/solvers.py | 4 ++-- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/torchdiffeq/_impl/odeint.py b/torchdiffeq/_impl/odeint.py index a174219ad..8aefbae0e 100644 --- a/torchdiffeq/_impl/odeint.py +++ b/torchdiffeq/_impl/odeint.py @@ -28,7 +28,7 @@ } -def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None): +def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None, info=False): """Integrate a system of ordinary differential equations. Solves the initial value problem for a non-stiff system of first order ODEs: @@ -58,12 +58,17 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even event_fn: Function that maps the state `y` to a Tensor. The solve terminates when event_fn evaluates to zero. If this is not None, all but the first elements of `t` are ignored. + info: scipy.solve_ivp by default returns a class OdeResult with informatino about + the solve. If info=True, this will return not just the solution vector `y`, but + the original `OdeResult` as well. Returns: y: Tensor, where the first dimension corresponds to different time points. Contains the solved value of y for each desired time point in `t`, with the initial value `y0` being the first element along the first dimension. + OdeResult: Only returned if info=True. See the following link for details. + https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html Raises: ValueError: if an invalid `method` is provided. @@ -74,7 +79,7 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even solver = SOLVERS[method](func=func, y0=y0, rtol=rtol, atol=atol, **options) if event_fn is None: - solution = solver.integrate(t) + solution = solver.integrate(t, info=info) else: event_t, solution = solver.integrate_until_event(t[0], event_fn) event_t = event_t.to(t) diff --git a/torchdiffeq/_impl/scipy_wrapper.py b/torchdiffeq/_impl/scipy_wrapper.py index 06f93273e..017cdd667 100644 --- a/torchdiffeq/_impl/scipy_wrapper.py +++ b/torchdiffeq/_impl/scipy_wrapper.py @@ -22,11 +22,11 @@ def __init__(self, func, y0, rtol, atol, solver="LSODA", **unused_kwargs): self.solver = solver self.func = convert_func_to_numpy(func, self.shape, self.device, self.dtype) - def integrate(self, t): + def integrate(self, t, info=False): if t.numel() == 1: return torch.tensor(self.y0)[None].to(self.device, self.dtype) t = t.detach().cpu().numpy() - sol = solve_ivp( + oderesult = solve_ivp( self.func, t_span=[t.min(), t.max()], y0=self.y0, @@ -35,9 +35,12 @@ def integrate(self, t): rtol=self.rtol, atol=self.atol, ) - sol = torch.tensor(sol.y).T.to(self.device, self.dtype) + sol = torch.tensor(oderesult.y).T.to(self.device, self.dtype) sol = sol.reshape(-1, *self.shape) - return sol + if info: + return sol, oderesult + else: + return sol def convert_func_to_numpy(func, shape, device, dtype): diff --git a/torchdiffeq/_impl/solvers.py b/torchdiffeq/_impl/solvers.py index 6915f2bd9..bf26d31e1 100644 --- a/torchdiffeq/_impl/solvers.py +++ b/torchdiffeq/_impl/solvers.py @@ -21,7 +21,7 @@ def _before_integrate(self, t): def _advance(self, next_t): raise NotImplementedError - def integrate(self, t): + def integrate(self, t, **kwargs): solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) solution[0] = self.y0 t = t.to(self.dtype) @@ -91,7 +91,7 @@ def _grid_constructor(func, y0, t): def _step_func(self, func, t0, dt, t1, y0): pass - def integrate(self, t): + def integrate(self, t, **kwargs): time_grid = self.grid_constructor(self.func, self.y0, t) assert time_grid[0] == t[0] and time_grid[-1] == t[-1] From 1de63ab19f76a16ed16c8e5b48d6c084e3011b5a Mon Sep 17 00:00:00 2001 From: ChrisDeGrendele Date: Wed, 31 Mar 2021 09:23:05 -0700 Subject: [PATCH 2/3] Add error for nonscipy solvers + return ODEresult --- torchdiffeq/_impl/odeint.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/torchdiffeq/_impl/odeint.py b/torchdiffeq/_impl/odeint.py index 8aefbae0e..e56ad4bd8 100644 --- a/torchdiffeq/_impl/odeint.py +++ b/torchdiffeq/_impl/odeint.py @@ -78,8 +78,15 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even solver = SOLVERS[method](func=func, y0=y0, rtol=rtol, atol=atol, **options) + if method is not None: + if (method[0:5] != "scipy") and (info == True): + raise ValueError("The info parameter may only be used for scipy solvers!") + if event_fn is None: - solution = solver.integrate(t, info=info) + if info: + solution, OdeResult = solver.integrate(t, info=info) + else: + solution = solver.integrate(t, info=info) else: event_t, solution = solver.integrate_until_event(t[0], event_fn) event_t = event_t.to(t) @@ -90,7 +97,7 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even solution = _flat_to_shape(solution, (len(t),), shapes) if event_fn is None: - return solution + return solution, OdeResult else: return event_t, solution From 6158af9ba2311675fa354f6c9b62e47794e3d83f Mon Sep 17 00:00:00 2001 From: ChrisDeGrendele Date: Wed, 31 Mar 2021 09:38:52 -0700 Subject: [PATCH 3/3] tests pass --- torchdiffeq/_impl/odeint.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchdiffeq/_impl/odeint.py b/torchdiffeq/_impl/odeint.py index e56ad4bd8..8b9a2c708 100644 --- a/torchdiffeq/_impl/odeint.py +++ b/torchdiffeq/_impl/odeint.py @@ -97,7 +97,10 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even solution = _flat_to_shape(solution, (len(t),), shapes) if event_fn is None: - return solution, OdeResult + if info: + return solution, OdeResult + else: + return solution else: return event_t, solution