I'm running into an issue with the ReversibleHeun solver, which may or may not simply be a matter of choosing a proper step size. I've tried to make a minimal working example (MWE) that still captures the essence of my use case.
I have a quadratic potential function $\phi(x,y;t)$ that defines gradient dynamics and shifts in time, so the fixed point of the system moves around. I'm trying to simulate Langevin dynamics, and diffrax has been really useful so far.
However, the ReversibleHeun method seems to become unstable in a somewhat odd way, and I can't figure out why. Notably, the instability persists even with no noise in the system.
The example below defines the potential, defines the drift as its negative gradient, and uses a WeaklyDiagonalControlTerm for the isotropic, homogeneous noise. In the zero-noise case, the Heun method works fine with a step size of $0.1$, while ReversibleHeun becomes unstable. As $dt$ decreases to $0.001$, ReversibleHeun appears to match Heun.
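For reference, the drift in this MWE is linear in the state. Writing $c(t) = (u(t), v(t))$ for the moving fixed point (notation introduced here for convenience), the potential from the code's docstring and its negative gradient are:

```math
\phi(x, y; t) = (x - u(t))^2 + (y - v(t))^2, \qquad
f(t, \mathbf{y}) = -\nabla_{\mathbf{y}} \phi = -2\,\big(\mathbf{y} - c(t)\big), \qquad
\frac{\partial f}{\partial \mathbf{y}} = -2\,I .
```

So both coordinates relax toward the moving target at rate $\lambda = -2$. A fixed step $h$ = `dt0` therefore corresponds to $z = h\lambda = -0.2,\ -0.02,\ -0.002$ for the three step sizes tried.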
I'm wondering whether one should expect ReversibleHeun to require a small step size, or whether something deeper is going on. Any guidance would be appreciated.
I'm using diffrax version 0.5.0.
```python
import numpy as np
import matplotlib.pyplot as plt
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jrandom
from diffrax import VirtualBrownianTree, ODETerm, MultiTerm
from diffrax import WeaklyDiagonalControlTerm, ReversibleHeun, Heun
from diffrax import diffeqsolve, SaveAt

SEED = 123
rng = np.random.default_rng(seed=SEED)
key = jrandom.PRNGKey(seed=rng.integers(2**32))


def sigmoid(t, a, b, tcrit):
    """Sigmoid helper function"""
    return 0.5 * (a + b + (b - a) * jnp.tanh(t-tcrit))


def potential(t, y, args):
    """A quadratic potential function where the fixed point changes.

    p(x,y) = (x - u(t))^2 + (y - v(t))^2

    with u(t) and v(t) sigmoidal functions defined by the arguments in `args`
    """
    a1 = args['a1']
    b1 = args['b1']
    t1 = args['t1']
    a2 = args['a2']
    b2 = args['b2']
    t2 = args['t2']
    u = sigmoid(t, a1, b1, t1)
    v = sigmoid(t, a2, b2, t2)
    dy = y - jnp.array([u, v])
    return jnp.sum(dy * dy)


### Define drift and diffusion terms
def f(t, y, args):
    """Drift is defined via the gradient of the potential"""
    return -jax.jacfwd(potential, 1)(t, y, args)


def g(t, y, args):
    """Constant diffusion. Noise scale is a parameter `sigma` in `args`."""
    return args['sigma'] * jnp.ones(y.shape, dtype=jnp.float64)


# Demonstrate Heun solver works but ReversibleHeun becomes unstable
dt0 = 0.1  # Initial solver step size: ReversibleHeun unstable
# dt0 = 0.01  # Initial solver step size: Beginning of an instability
# dt0 = 0.001  # Initial solver step size: Matches Heun method

args = {
    'a1': 0,  # x fixed point starts at 0, moves to 1 at t=5
    'b1': 1,
    't1': 5,
    'a2': 1,  # y fixed point starts at 1, moves to 0 at t=5
    'b2': 0,
    't2': 5,
    'sigma': 0.0  # SET NOISE TO 0
}

max_steps = 4096 * 8  # increase max number of steps to be safe
vbt_tol = 1e-6  # tolerance on VirtualBrownianTree
t0 = 0.
t1 = 10.
y0 = jnp.array([0, 0], dtype=jnp.float64)  # (0, 0) initial condition

key, subkey = jrandom.split(key, 2)
brownian_motion = VirtualBrownianTree(
    t0, t1, tol=vbt_tol,
    shape=(len(y0),),
    key=subkey
)

terms = MultiTerm(
    ODETerm(f),
    WeaklyDiagonalControlTerm(g, brownian_motion)
)

ts_save = jnp.linspace(t0, t1, 101)
saveat = SaveAt(ts=ts_save)

sol_heun = diffeqsolve(
    terms, Heun(),
    t0, t1, dt0=dt0,
    y0=y0,
    saveat=saveat,
    args=args,
    max_steps=max_steps,
)

sol_rev_heun = diffeqsolve(
    terms, ReversibleHeun(),
    t0, t1, dt0=dt0,
    y0=y0,
    saveat=saveat,
    args=args,
    max_steps=max_steps,
)

fig, [ax1, ax2] = plt.subplots(2, 1)

ax1.plot(ts_save, sol_heun.ys, label=['x (heun)','y (heun)'])
ax1.plot(
    ts_save, sigmoid(ts_save, args['a1'], args['b1'], args['t1']),
    ':', label='fixed point x'
)
ax1.plot(
    ts_save, sigmoid(ts_save, args['a2'], args['b2'], args['t2']),
    ':', label='fixed point y'
)
ax1.set_xlabel('t')
ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax1.set_title("Heun Method")

ax2.plot(ts_save, sol_rev_heun.ys, label=['x (rev heun)','y (rev heun)'])
ax2.plot(
    ts_save, sigmoid(ts_save, args['a1'], args['b1'], args['t1']),
    ':', label='fixed point x'
)
ax2.plot(
    ts_save, sigmoid(ts_save, args['a2'], args['b2'], args['t2']),
    ':', label='fixed point y'
)
ax2.set_xlabel('t')
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax2.set_title("Reversible Heun Method")

fig.suptitle(f"No noise, dt0={dt0}")
plt.tight_layout()
plt.show()
```
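To separate the solver's behaviour from the diffrax setup, here is a dependency-free sketch of the zero-noise problem reduced to one dimension with a fixed target $c = 1$. This simplification is an assumption: the original target moves. The sketch also assumes the reversible Heun update rule below, which is my reading of how diffrax's `ReversibleHeun` steps in the ODE case; it is worth checking against the library source. The method carries an auxiliary state $\hat y$ and the previous vector-field evaluation between steps:
$\hat y_{n+1} = 2 y_n - \hat y_n + h f(\hat y_n)$, $\; y_{n+1} = y_n + \tfrac{h}{2}\big(f(\hat y_n) + f(\hat y_{n+1})\big)$.
The function names are hypothetical.

```python
def drift(y, target=1.0):
    """Negative gradient of the 1D potential (y - target)**2 (fixed target, simplified MWE)."""
    return -2.0 * (y - target)


def reversible_heun(y0, h, t_end):
    """Fixed-step reversible Heun for an ODE, assuming the update rule stated above."""
    y, y_hat = y0, y0            # auxiliary state starts equal to y0
    f_hat = drift(y_hat)         # vector field at y_hat, carried from step to step
    for _ in range(round(t_end / h)):
        y_hat_next = 2.0 * y - y_hat + h * f_hat
        f_hat_next = drift(y_hat_next)
        y = y + 0.5 * h * (f_hat + f_hat_next)
        y_hat, f_hat = y_hat_next, f_hat_next
    return y


def heun(y0, h, t_end):
    """Classical fixed-step Heun method, for comparison."""
    y = y0
    for _ in range(round(t_end / h)):
        f0 = drift(y)
        f1 = drift(y + h * f0)
        y = y + 0.5 * h * (f0 + f1)
    return y


# Integrate from y = 0 to t = 10, as in the MWE, and report the distance to the fixed point
for h in (0.1, 0.01, 0.001):
    err_heun = abs(heun(0.0, h, 10.0) - 1.0)
    err_rev = abs(reversible_heun(0.0, h, 10.0) - 1.0)
    print(f"h={h}: |error| Heun={err_heun:.2e}, ReversibleHeun={err_rev:.2e}")
```

Under these assumptions, even this stripped-down scalar problem should show the same pattern: the Heun error decays at every step size, while the reversible Heun error at `h=0.1` ends up orders of magnitude larger than the initial offset.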
I think this is expected! ReversibleHeun is quite an unstable solver, and it often requires smaller step sizes than other solvers. This is partly because it carries additional memory from one step to the next (beyond just the evolving state). I could believe that this memory, combined with the "moving target" nature of your problem, makes it a particularly poor fit.
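A linear stability check makes the role of that extra memory concrete. This is a sketch that assumes the reversible Heun update $\hat y_{n+1} = 2y_n - \hat y_n + h f(\hat y_n)$, $y_{n+1} = y_n + \tfrac{h}{2}(f(\hat y_n) + f(\hat y_{n+1}))$ (my reading of the ODE case). Applied to $y' = \lambda y$ with $z = h\lambda$, the pair $(y_n, \hat y_n)$ is multiplied at each step by

$$M(z) = \begin{pmatrix} 1 + z & z^2/2 \\ 2 & z - 1 \end{pmatrix}, \qquad \operatorname{tr} M = 2z, \quad \det M = -1,$$

so its eigenvalues are $\mu = z \pm \sqrt{z^2 + 1}$. For any real $z \neq 0$, these are real with product $-1$, so one of them has modulus greater than $1$. That is a spurious, sign-alternating mode, and it grows even though the true solution decays. Heun has only the single factor $1 + z + z^2/2$. The function names below are hypothetical, and $\lambda = -2$ comes from the MWE's drift.

```python
import numpy as np


def reversible_heun_amplification(z):
    """Amplification matrix on (y_n, y_hat_n) for y' = lam * y with z = h * lam.

    Derived by hand from the assumed update
        y_hat_{n+1} = 2 y_n - y_hat_n + z y_hat_n
        y_{n+1}     = y_n + (z / 2) (y_hat_n + y_hat_{n+1})
    """
    return np.array([[1.0 + z, 0.5 * z**2],
                     [2.0, z - 1.0]])


def heun_amplification(z):
    """Amplification factor of the classical Heun method on y' = lam * y."""
    return 1.0 + z + 0.5 * z**2


lam = -2.0  # Jacobian eigenvalue of the drift -grad(phi) in the MWE
for h in (0.1, 0.01, 0.001):
    z = h * lam
    mu = np.linalg.eigvals(reversible_heun_amplification(z))
    spectral_radius = np.max(np.abs(mu))
    print(f"h={h}: Heun |R|={abs(heun_amplification(z)):.4f}, "
          f"ReversibleHeun spectral radius={spectral_radius:.4f}")
```

The growing eigenvalue has modulus $|z| + \sqrt{1 + z^2} = e^{\operatorname{asinh}|z|} \approx e^{|z|}$. Over a fixed time $T$, it therefore grows roughly like $e^{|\lambda| T}$ whatever the step size. Shrinking `dt0` mainly reduces how strongly that mode gets seeded, not how fast it grows. That would be consistent with `dt0=0.001` merely appearing to match Heun over $t \in [0, 10]$.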