Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add info param to return scipy information #157

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 18 additions & 3 deletions torchdiffeq/_impl/odeint.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
}


def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None):
def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None, info=False):
"""Integrate a system of ordinary differential equations.

Solves the initial value problem for a non-stiff system of first order ODEs:
Expand Down Expand Up @@ -58,12 +58,17 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even
event_fn: Function that maps the state `y` to a Tensor. The solve terminates when
event_fn evaluates to zero. If this is not None, all but the first elements of
`t` are ignored.
info: scipy.solve_ivp by default returns a class OdeResult with informatino about
the solve. If info=True, this will return not just the solution vector `y`, but
the original `OdeResult` as well.

Returns:
y: Tensor, where the first dimension corresponds to different
time points. Contains the solved value of y for each desired time point in
`t`, with the initial value `y0` being the first element along the first
dimension.
OdeResult: Only returned if info=True. See the following link for details.
https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html

Raises:
ValueError: if an invalid `method` is provided.
Expand All @@ -73,8 +78,15 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even

solver = SOLVERS[method](func=func, y0=y0, rtol=rtol, atol=atol, **options)

if method is not None:
if (method[0:5] != "scipy") and (info == True):
raise ValueError("The info parameter may only be used for scipy solvers!")

if event_fn is None:
solution = solver.integrate(t)
if info:
solution, OdeResult = solver.integrate(t, info=info)
else:
solution = solver.integrate(t, info=info)
else:
event_t, solution = solver.integrate_until_event(t[0], event_fn)
event_t = event_t.to(t)
Expand All @@ -85,7 +97,10 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even
solution = _flat_to_shape(solution, (len(t),), shapes)

if event_fn is None:
return solution
if info:
return solution, OdeResult
else:
return solution
else:
return event_t, solution

Expand Down
11 changes: 7 additions & 4 deletions torchdiffeq/_impl/scipy_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def __init__(self, func, y0, rtol, atol, solver="LSODA", **unused_kwargs):
self.solver = solver
self.func = convert_func_to_numpy(func, self.shape, self.device, self.dtype)

def integrate(self, t):
def integrate(self, t, info=False):
if t.numel() == 1:
return torch.tensor(self.y0)[None].to(self.device, self.dtype)
t = t.detach().cpu().numpy()
sol = solve_ivp(
oderesult = solve_ivp(
self.func,
t_span=[t.min(), t.max()],
y0=self.y0,
Expand All @@ -35,9 +35,12 @@ def integrate(self, t):
rtol=self.rtol,
atol=self.atol,
)
sol = torch.tensor(sol.y).T.to(self.device, self.dtype)
sol = torch.tensor(oderesult.y).T.to(self.device, self.dtype)
sol = sol.reshape(-1, *self.shape)
return sol
if info:
return sol, oderesult
else:
return sol


def convert_func_to_numpy(func, shape, device, dtype):
Expand Down
4 changes: 2 additions & 2 deletions torchdiffeq/_impl/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _before_integrate(self, t):
def _advance(self, next_t):
raise NotImplementedError

def integrate(self, t):
def integrate(self, t, **kwargs):
solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device)
solution[0] = self.y0
t = t.to(self.dtype)
Expand Down Expand Up @@ -91,7 +91,7 @@ def _grid_constructor(func, y0, t):
def _step_func(self, func, t0, dt, t1, y0):
pass

def integrate(self, t):
def integrate(self, t, **kwargs):
time_grid = self.grid_constructor(self.func, self.y0, t)
assert time_grid[0] == t[0] and time_grid[-1] == t[-1]

Expand Down