Skip to content

Commit

Permalink
fixed bug with options; moved RKAdaptiveStepsizeODESolver to rk_common;
Browse files Browse the repository at this point in the history
fixed bug with adjoint_params.
  • Loading branch information
patrick-kidger committed Aug 5, 2020
1 parent 1c8cc50 commit fb0ac0f
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 145 deletions.
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -21,7 +21,7 @@
description="ODE solvers and adjoint sensitivity analysis in PyTorch.",
url="https://github.com/rtqichen/torchdiffeq",
packages=setuptools.find_packages(),
install_requires=['torch>=1.0.0'],
install_requires=['torch>=1.3.0'],
python_requires='~=3.6',
classifiers=[
"Programming Language :: Python :: 3",
Expand Down
3 changes: 1 addition & 2 deletions torchdiffeq/_impl/adaptive_heun.py
@@ -1,6 +1,5 @@
import torch
from .rk_common import _ButcherTableau
from .solvers import RKAdaptiveStepsizeODESolver
from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver


_ADAPTIVE_HEUN_TABLEAU = _ButcherTableau(
Expand Down
8 changes: 5 additions & 3 deletions torchdiffeq/_impl/adjoint.py
Expand Up @@ -43,6 +43,8 @@ def backward(ctx, grad_y):

if adjoint_options is None:
adjoint_options = {}
else:
adjoint_options = adjoint_options.copy()

# We assume that any grid points are given to us ordered in the same direction as for the forward pass (for
# compatibility with setting adjoint_options = options), so we need to flip them around here.
Expand All @@ -51,7 +53,6 @@ def backward(ctx, grad_y):
except KeyError:
pass
else:
adjoint_options = adjoint_options.copy()
adjoint_options['grid_points'] = grid_points.flip(0)

# Backward compatibility: by default use a mixed L-infinity/RMS norm over the input, where we treat t, each
Expand Down Expand Up @@ -155,9 +156,10 @@ def odeint_adjoint(func, y0, t, rtol=1e-6, atol=1e-12, method=None, options=None

# We need this in order to access the variables inside this module,
# since we have no other way of getting variables along the execution path.
if adjoint_params is not None and not isinstance(func, nn.Module):
if adjoint_params is None and not isinstance(func, nn.Module):
raise ValueError('func must be an instance of nn.Module to specify the adjoint parameters; alternatively they '
'can be specified explicitly via the `adjoint_params` argument.')
'can be specified explicitly via the `adjoint_params` argument. If there are no parameters '
'then it is allowable to set `adjoint_params=()`.')

# Must come before we default adjoint_options to options; using the same norm for both wouldn't make any sense.
try:
Expand Down
4 changes: 2 additions & 2 deletions torchdiffeq/_impl/bosh3.py
@@ -1,6 +1,6 @@
import torch
from .rk_common import _ButcherTableau
from .solvers import RKAdaptiveStepsizeODESolver
from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver


_BOGACKI_SHAMPINE_TABLEAU = _ButcherTableau(
alpha=torch.tensor([1/2, 3/4, 1.], dtype=torch.float64),
Expand Down
3 changes: 1 addition & 2 deletions torchdiffeq/_impl/dopri5.py
@@ -1,6 +1,5 @@
import torch
from .rk_common import _ButcherTableau
from .solvers import RKAdaptiveStepsizeODESolver
from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver


_DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau(
Expand Down
4 changes: 2 additions & 2 deletions torchdiffeq/_impl/dopri8.py
@@ -1,7 +1,7 @@
import numpy as np
import torch
from .rk_common import _ButcherTableau
from .solvers import RKAdaptiveStepsizeODESolver
from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver


A = [ 1/18, 1/12, 1/8, 5/16, 3/8, 59/400, 93/200, 5490023248/9719169821, 13/20, 1201146811/1299019798, 1, 1, 1]

Expand Down
3 changes: 2 additions & 1 deletion torchdiffeq/_impl/misc.py
Expand Up @@ -161,6 +161,8 @@ def _check_inputs(func, y0, t, rtol, atol, method, options, SOLVERS):
# Normalise method and options
if options is None:
options = {}
else:
options = options.copy()
if method is None:
method = 'dopri5'
if method not in SOLVERS:
Expand Down Expand Up @@ -197,7 +199,6 @@ def _check_inputs(func, y0, t, rtol, atol, method, options, SOLVERS):
except KeyError:
pass
else:
options = options.copy()
options['grid_points'] = -grid_points

# Can only do after having normalised time
Expand Down
131 changes: 131 additions & 0 deletions torchdiffeq/_impl/rk_common.py
@@ -1,5 +1,11 @@
import bisect
import collections
import torch
from .interp import _interp_evaluate, _interp_fit
from .misc import (_compute_error_ratio,
_select_initial_step,
_optimal_step_size)
from .solvers import AdaptiveStepsizeODESolver


_ButcherTableau = collections.namedtuple('_ButcherTableau', 'alpha, beta, c_sol, c_error')
Expand Down Expand Up @@ -95,3 +101,128 @@ def rk4_alt_step_func(func, t, dt, y, k1=None):
k3 = func(t + dt * _two_thirds, y + dt * (k2 - k1 * _one_third))
k4 = func(t + dt, y + dt * (k1 - k2 + k3))
return (k1 + 3 * (k2 + k3) + k4) * dt * 0.125


class RKAdaptiveStepsizeODESolver(AdaptiveStepsizeODESolver):
order: int
tableau: _ButcherTableau
mid: torch.Tensor

def __init__(self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2,
max_num_steps=2 ** 31 - 1, grid_points=None, eps=0., dtype=torch.float64, **kwargs):
super(RKAdaptiveStepsizeODESolver, self).__init__(dtype=dtype, y0=y0, **kwargs)

# We use mixed precision. y has its original dtype (probably float32), whilst all 'time'-like objects use
# `dtype` (defaulting to float64).
dtype = torch.promote_types(dtype, y0.dtype)
device = y0.device

self.func = lambda t, y: func(t.type_as(y), y)
self.rtol = torch.as_tensor(rtol, dtype=dtype, device=device)
self.atol = torch.as_tensor(atol, dtype=dtype, device=device)
self.first_step = None if first_step is None else torch.as_tensor(first_step, dtype=dtype, device=device)
self.safety = torch.as_tensor(safety, dtype=dtype, device=device)
self.ifactor = torch.as_tensor(ifactor, dtype=dtype, device=device)
self.dfactor = torch.as_tensor(dfactor, dtype=dtype, device=device)
self.max_num_steps = torch.as_tensor(max_num_steps, dtype=torch.int32, device=device)
grid_points = torch.tensor([], dtype=dtype, device=device) if grid_points is None else grid_points.to(dtype)
self.grid_points = grid_points
self.eps = torch.as_tensor(eps, dtype=dtype, device=device)
self.dtype = dtype

# Copy from class to instance to set device
self.tableau = _ButcherTableau(alpha=self.tableau.alpha.to(device=device, dtype=y0.dtype),
beta=[b.to(device=device, dtype=y0.dtype) for b in self.tableau.beta],
c_sol=self.tableau.c_sol.to(device=device, dtype=y0.dtype),
c_error=self.tableau.c_error.to(device=device, dtype=y0.dtype))
self.mid = self.mid.to(device=device, dtype=y0.dtype)

def _before_integrate(self, t):
f0 = self.func(t[0], self.y0)
if self.first_step is None:
first_step = _select_initial_step(self.func, t[0], self.y0, self.order - 1, self.rtol, self.atol,
self.norm, f0=f0)
else:
first_step = self.first_step
self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, [self.y0] * 5)
self.next_grid_index = min(bisect.bisect(self.grid_points.tolist(), t[0]), len(self.grid_points) - 1)

def _advance(self, next_t):
"""Interpolate through the next time point, integrating as necessary."""
n_steps = 0
while next_t > self.rk_state.t1:
assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps)
self.rk_state = self._adaptive_step(self.rk_state)
n_steps += 1
return _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t)

def _adaptive_step(self, rk_state):
"""Take an adaptive Runge-Kutta step to integrate the ODE."""
y0, f0, _, t0, dt, interp_coeff = rk_state
# dtypes: self.y0.dtype (probably float32); self.dtype (probably float64)
# used for state and timelike objects respectively.
# Then:
# y0.dtype == self.y0.dtype
# f0.dtype == self.y0.dtype
# t0.dtype == self.dtype
# dt.dtype == self.dtype
# for coeff in interp_coeff: coeff.dtype == self.y0.dtype


########################################################
# Assertions #
########################################################
assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item())
assert torch.isfinite(y0).all(), 'non-finite values in state `y`: {}'.format(y0)

########################################################
# Make step, respecting prescribed grid points #
########################################################
on_grid = len(self.grid_points) and t0 < self.grid_points[self.next_grid_index] < t0 + dt
if on_grid:
dt = self.grid_points[self.next_grid_index] - t0
eps = min(0.5 * dt, self.eps)
dt = dt - eps
else:
eps = 0

y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=self.tableau)
# dtypes:
# y1.dtype == self.y0.dtype
# f1.dtype == self.y0.dtype
# y1_error.dtype == self.dtype
# k.dtype == self.y0.dtype

########################################################
# Error Ratio #
########################################################
error_ratio = _compute_error_ratio(y1_error, self.rtol, self.atol, y0, y1, self.norm)
accept_step = error_ratio <= 1
# dtypes:
# error_ratio.dtype == self.dtype

########################################################
# Update RK State #
########################################################
t_next = t0 + dt + 2 * eps if accept_step else t0
y_next = y1 if accept_step else y0
if on_grid and accept_step:
# We've just passed a discontinuity in f; we should update f to match the side of the discontinuity we're
# now on.
if eps != 0:
f1 = self.func(t_next, y_next)
if self.next_grid_index != len(self.grid_points) - 1:
self.next_grid_index += 1
f_next = f1 if accept_step else f0
interp_coeff = self._interp_fit(y0, y1, k, dt) if accept_step else interp_coeff
dt_next = _optimal_step_size(dt, error_ratio, self.safety, self.ifactor, self.dfactor, self.order)
rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff)
return rk_state

def _interp_fit(self, y0, y1, k, dt):
"""Fit an interpolating polynomial to the results of a Runge-Kutta step."""
dt = dt.type_as(y0)
y_mid = y0 + k.matmul(dt * self.mid).view_as(y0)
f0 = k[..., 0]
f1 = k[..., -1]
return _interp_fit(y0, y1, y_mid, f0, f1, dt)
133 changes: 1 addition & 132 deletions torchdiffeq/_impl/solvers.py
@@ -1,12 +1,6 @@
import abc
import bisect
import torch
from .interp import _interp_evaluate, _interp_fit
from .rk_common import _ButcherTableau, _RungeKuttaState, _runge_kutta_step
from .misc import (_compute_error_ratio,
_handle_unused_kwargs,
_select_initial_step,
_optimal_step_size)
from .misc import _handle_unused_kwargs


class AdaptiveStepsizeODESolver(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -107,128 +101,3 @@ def _linear_interp(self, t0, t1, y0, y1, t):
return y1
slope = (t - t0) / (t1 - t0)
return y0 + slope * (y1 - y0)


class RKAdaptiveStepsizeODESolver(AdaptiveStepsizeODESolver):
order: int
tableau: _ButcherTableau
mid: torch.Tensor

def __init__(self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2,
max_num_steps=2 ** 31 - 1, grid_points=None, eps=0., dtype=torch.float64, **kwargs):
super(RKAdaptiveStepsizeODESolver, self).__init__(dtype=dtype, y0=y0, **kwargs)

# We use mixed precision. y has its original dtype (probably float32), whilst all 'time'-like objects use
# `dtype` (defaulting to float64).
dtype = torch.promote_types(dtype, y0.dtype)
device = y0.device

self.func = lambda t, y: func(t.type_as(y), y)
self.rtol = torch.as_tensor(rtol, dtype=dtype, device=device)
self.atol = torch.as_tensor(atol, dtype=dtype, device=device)
self.first_step = None if first_step is None else torch.as_tensor(first_step, dtype=dtype, device=device)
self.safety = torch.as_tensor(safety, dtype=dtype, device=device)
self.ifactor = torch.as_tensor(ifactor, dtype=dtype, device=device)
self.dfactor = torch.as_tensor(dfactor, dtype=dtype, device=device)
self.max_num_steps = torch.as_tensor(max_num_steps, dtype=torch.int32, device=device)
grid_points = torch.tensor([], dtype=dtype, device=device) if grid_points is None else grid_points.to(dtype)
self.grid_points = grid_points
self.eps = torch.as_tensor(eps, dtype=dtype, device=device)
self.dtype = dtype

# Copy from class to instance to set device
self.tableau = _ButcherTableau(alpha=self.tableau.alpha.to(device=device, dtype=y0.dtype),
beta=[b.to(device=device, dtype=y0.dtype) for b in self.tableau.beta],
c_sol=self.tableau.c_sol.to(device=device, dtype=y0.dtype),
c_error=self.tableau.c_error.to(device=device, dtype=y0.dtype))
self.mid = self.mid.to(device=device, dtype=y0.dtype)

def _before_integrate(self, t):
f0 = self.func(t[0], self.y0)
if self.first_step is None:
first_step = _select_initial_step(self.func, t[0], self.y0, self.order - 1, self.rtol, self.atol,
self.norm, f0=f0)
else:
first_step = self.first_step
self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, [self.y0] * 5)
self.next_grid_index = min(bisect.bisect(self.grid_points.tolist(), t[0]), len(self.grid_points) - 1)

def _advance(self, next_t):
"""Interpolate through the next time point, integrating as necessary."""
n_steps = 0
while next_t > self.rk_state.t1:
assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps)
self.rk_state = self._adaptive_step(self.rk_state)
n_steps += 1
return _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t)

def _adaptive_step(self, rk_state):
"""Take an adaptive Runge-Kutta step to integrate the ODE."""
y0, f0, _, t0, dt, interp_coeff = rk_state
# dtypes: self.y0.dtype (probably float32); self.dtype (probably float64)
# used for state and timelike objects respectively.
# Then:
# y0.dtype == self.y0.dtype
# f0.dtype == self.y0.dtype
# t0.dtype == self.dtype
# dt.dtype == self.dtype
# for coeff in interp_coeff: coeff.dtype == self.y0.dtype


########################################################
# Assertions #
########################################################
assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item())
assert torch.isfinite(y0).all(), 'non-finite values in state `y`: {}'.format(y0)

########################################################
# Make step, respecting prescribed grid points #
########################################################
on_grid = len(self.grid_points) and t0 < self.grid_points[self.next_grid_index] < t0 + dt
if on_grid:
dt = self.grid_points[self.next_grid_index] - t0
eps = min(0.5 * dt, self.eps)
dt = dt - eps
else:
eps = 0

y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=self.tableau)
# dtypes:
# y1.dtype == self.y0.dtype
# f1.dtype == self.y0.dtype
# y1_error.dtype == self.dtype
# k.dtype == self.y0.dtype

########################################################
# Error Ratio #
########################################################
error_ratio = _compute_error_ratio(y1_error, self.rtol, self.atol, y0, y1, self.norm)
accept_step = error_ratio <= 1
# dtypes:
# error_ratio.dtype == self.dtype

########################################################
# Update RK State #
########################################################
t_next = t0 + dt + 2 * eps if accept_step else t0
y_next = y1 if accept_step else y0
if on_grid and accept_step:
# We've just passed a discontinuity in f; we should update f to match the side of the discontinuity we're
# now on.
if eps != 0:
f1 = self.func(t_next, y_next)
if self.next_grid_index != len(self.grid_points) - 1:
self.next_grid_index += 1
f_next = f1 if accept_step else f0
interp_coeff = self._interp_fit(y0, y1, k, dt) if accept_step else interp_coeff
dt_next = _optimal_step_size(dt, error_ratio, self.safety, self.ifactor, self.dfactor, self.order)
rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff)
return rk_state

def _interp_fit(self, y0, y1, k, dt):
"""Fit an interpolating polynomial to the results of a Runge-Kutta step."""
dt = dt.type_as(y0)
y_mid = y0 + k.matmul(dt * self.mid).view_as(y0)
f0 = k[..., 0]
f1 = k[..., -1]
return _interp_fit(y0, y1, y_mid, f0, f1, dt)

0 comments on commit fb0ac0f

Please sign in to comment.