Description
When solving backwards in time (from `t1` to `t0` with step `-dt0`), `AbstractSRK` produces an incorrect solution.
Here is an example:
```python
import jax.numpy as jnp
import jax.random as jr
import diffrax as dfx


def drift(t, y, args):
    a, b = args
    return -a * y


def diffusion(t, y, args):
    a, b = args
    return b


t0 = 0
t1 = 1.0
dt0 = 0.001
y0 = jnp.array([1.0])
args = (1.0, 1.0)
# SRK solvers need space-time Lévy area from the Brownian motion
W = dfx.VirtualBrownianTree(
    t0, t1, tol=1e-6, shape=(), key=jr.PRNGKey(2), levy_area=dfx.SpaceTimeLevyArea
)
terms = dfx.MultiTerm(dfx.ODETerm(drift), dfx.ControlTerm(diffusion, W))

print("Forwards-in-time")
solver = dfx.Heun()
sol_heun = dfx.diffeqsolve(terms, solver, t0, t1, dt0, y0, args)
print(f"Heun: {sol_heun.ys[-1][0]}")
solver = dfx.SlowRK()
sol_shark = dfx.diffeqsolve(terms, solver, t0, t1, dt0, y0, args)
print(f"SRK: {sol_shark.ys[-1][0]}")

# Same problem, started at t1 and integrated down to t0 with a negative step
print("Backwards-in-time")
solver = dfx.Heun()
sol_heun = dfx.diffeqsolve(terms, solver, t1, t0, -dt0, y0, args)
print(f"Heun: {sol_heun.ys[-1][0]}")
solver = dfx.SlowRK()
sol_shark = dfx.diffeqsolve(terms, solver, t1, t0, -dt0, y0, args)
print(f"SRK: {sol_shark.ys[-1][0]}")
```

Output:
```
Forwards-in-time
Heun: 1.2419108152389526
SRK: 1.2419666051864624
Backwards-in-time
Heun: 0.34244588017463684
SRK: -0.14933258295059204
```

Fix
This is due to the following lines in `AbstractSRK` (`diffrax/diffrax/_solver/srk.py`, lines 355 to 356 at commit 14baa1e):

```python
# time increment
h = t1 - t0
```
When solving backwards in time, `h` computed this way has the wrong sign. Instead, we should take the time increment from the Brownian increment returned by the `VirtualBrownianTree`:

```python
# Brownian increment (and space-time Lévy area)
bm_inc = diffusion.contr(t0, t1, use_levy=True)
# time increment
h = bm_inc.dt
```

With this change, the output becomes:
```
Forwards-in-time
Heun: 1.2419108152389526
SRK: 1.2419666051864624
Backwards-in-time
Heun: 0.34244588017463684
SRK: 0.34262511134147644
```
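To illustrate why the sign of `h` matters, here is a minimal, self-contained sketch in plain Python (no diffrax). It assumes, as the sign error implies, that during a backwards solve the solver steps through a negated time frame in which `t0 < t1`, while the control increment reports its `dt` in the original, decreasing direction. `Increment`, `reversed_contr` and `euler_backwards` are hypothetical illustrations, not diffrax internals; the drift is `-a * y` from the example above with the diffusion switched off, and explicit Euler stands in for the SRK step.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Increment:
    """Hypothetical stand-in for a control increment that carries its own `dt`."""
    dt: float  # time increment in the ORIGINAL time direction


def reversed_contr(s0: float, s1: float) -> Increment:
    """Hypothetical control for a backwards solve.

    The solver works in a negated frame s = -t, so s0 < s1, but the
    increment is reported in original time, which runs from -s0 down to -s1.
    """
    return Increment(dt=(-s1) - (-s0))


def euler_backwards(a: float, y_start: float, n_steps: int, use_increment_dt: bool) -> float:
    """Explicit Euler for dy/dt = -a * y, integrated from t = 1 down to t = 0.

    use_increment_dt=False computes h as s1 - s0 (the analogue of `h = t1 - t0`);
    use_increment_dt=True takes h from the increment (the analogue of `h = bm_inc.dt`).
    """
    y = y_start
    ds = 1.0 / n_steps
    for k in range(n_steps):
        # Step endpoints in the negated frame: s goes from -1 up to 0
        s0 = -1.0 + k * ds
        s1 = -1.0 + (k + 1) * ds
        h = reversed_contr(s0, s1).dt if use_increment_dt else s1 - s0
        y = y + (-a * y) * h
    return y


y_buggy = euler_backwards(a=1.0, y_start=1.0, n_steps=1000, use_increment_dt=False)
y_fixed = euler_backwards(a=1.0, y_start=1.0, n_steps=1000, use_increment_dt=True)
print(f"h = t1 - t0: {y_buggy:.4f}")
print(f"h = inc.dt:  {y_fixed:.4f}  (exact y(0) = e = {math.e:.4f})")
```

With `h = t1 - t0` the drift is applied with the wrong sign, so the solution decays as though it were integrated forwards (towards `e**-1`), whereas taking `h` from the increment recovers growth towards `e`. In the real SRK step the Brownian increment already points in the correct direction, so a positive `h` mixes two time directions within one step.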