
Solving backwards-in-time fails for AbstractSRK solvers #598

@sammccallum

Description


When solving backwards in time, from t1 to t0 with step size -dt0, solvers based on AbstractSRK (such as SlowRK) produce an incorrect solution.

Here is an example:

import diffrax as dfx
import jax.numpy as jnp
import jax.random as jr


def drift(t, y, args):
    a, b = args
    return -a * y


def diffusion(t, y, args):
    a, b = args
    return b


t0 = 0
t1 = 1.0
dt0 = 0.001
y0 = jnp.array([1.0])
args = (1.0, 1.0)

# SRK solvers need space-time Lévy area, so request it from the Brownian motion.
W = dfx.VirtualBrownianTree(
    t0, t1, tol=1e-6, shape=(), key=jr.PRNGKey(2), levy_area=dfx.SpaceTimeLevyArea
)
terms = dfx.MultiTerm(dfx.ODETerm(drift), dfx.ControlTerm(diffusion, W))

print("Forwards-in-time")
solver = dfx.Heun()
sol_heun = dfx.diffeqsolve(terms, solver, t0, t1, dt0, y0, args)
print(f"Heun: {sol_heun.ys[-1][0]}")

solver = dfx.SlowRK()
sol_shark = dfx.diffeqsolve(terms, solver, t0, t1, dt0, y0, args)
print(f"SRK: {sol_shark.ys[-1][0]}")


# Same problem, started from y0 at t1 and integrated back to t0 with step -dt0.
print("Backwards-in-time")
solver = dfx.Heun()
sol_heun = dfx.diffeqsolve(terms, solver, t1, t0, -dt0, y0, args)
print(f"Heun: {sol_heun.ys[-1][0]}")

solver = dfx.SlowRK()
sol_shark = dfx.diffeqsolve(terms, solver, t1, t0, -dt0, y0, args)
print(f"SRK: {sol_shark.ys[-1][0]}")

Output:

Forwards-in-time
Heun: 1.2419108152389526
SRK: 1.2419666051864624
Backwards-in-time
Heun: 0.34244588017463684
SRK: -0.14933258295059204
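
Dropping the noise (b = 0) gives a reference point for the backward solve. The drift alone is dy/dt = -a y, so integrating from t1 back to t0 gives y(t0) = y0 e^{a (t1 - t0)}, which is e ≈ 2.718 for the values above. An explicit Euler step on the drift shows what a time increment of the wrong sign does: the backward solve converges to the decaying solution instead of the growing one. Below, y_0 is the value at t1 (the code's y0) and Δt = dt0.

```latex
\begin{aligned}
y_{n+1} &= y_n + h\,(-a\,y_n), \qquad n = 0, \dots, N-1, \quad N = \frac{t_1 - t_0}{\Delta t}, \\
h = -\Delta t \;&\Rightarrow\; y_N = (1 + a\,\Delta t)^N y_0 \;\to\; e^{a (t_1 - t_0)}\, y_0 \quad (\Delta t \to 0), \\
h = +\Delta t \;&\Rightarrow\; y_N = (1 - a\,\Delta t)^N y_0 \;\to\; e^{-a (t_1 - t_0)}\, y_0 \quad (\Delta t \to 0).
\end{aligned}
```

With a = 1 and Δt = 0.001, the two cases give about 2.717 and 0.368.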

Fix

The cause is this line in AbstractSRK:

# time increment
h = t1 - t0

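Here, h has the wrong sign. For a backward solve, diffeqsolve integrates in negated time, so inside the solver t1 - t0 is positive, while the Brownian increment returned by diffusion.contr (including its dt field) is negative, matching the actual direction of integration. We should therefore take h from the dt of that increment, which comes from the VirtualBrownianTree: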

# Brownian increment (and space-time Lévy area)
bm_inc = diffusion.contr(t0, t1, use_levy=True)

# time increment
h = bm_inc.dt

With this change, the output becomes:

Forwards-in-time
Heun: 1.2419108152389526
SRK: 1.2419666051864624
Backwards-in-time
Heun: 0.34244588017463684
SRK: 0.34262511134147644
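
The sign mismatch can be reproduced without diffrax. The sketch below is a simplified pure-Python model, not diffrax's implementation: it assumes that a backward solve is carried out in negated time s = -t and that each term's increment is computed over the original interval and then multiplied by the direction (-1). All names (`Increment`, `brownian_increment`, `wrapped_increment`, `drift_step`, `solve`) are hypothetical, the Brownian path is fixed at zero so only the drift matters, and the step is explicit Euler rather than an SRK method.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Increment:
    """Hypothetical stand-in for a Brownian increment carrying a `dt` field."""

    dt: float
    W: float


def brownian_increment(t0, t1):
    # Hypothetical path with W(t) = 0, so only the time increment matters here.
    return Increment(dt=t1 - t0, W=0.0)


def wrapped_increment(direction, s0, s1):
    """Simplified model of a term evaluated in the solver's time frame.

    The solver integrates in s = direction * t, so s0 < s1 even when the
    original problem runs backwards. The increment is taken over the
    corresponding interval in original time and multiplied by `direction`.
    """
    t0, t1 = (s0, s1) if direction == 1 else (-s1, -s0)
    inc = brownian_increment(t0, t1)
    return Increment(dt=direction * inc.dt, W=direction * inc.W)


def drift_step(y, s0, s1, direction, a, h_from_increment):
    """One explicit Euler step of dy = -a*y dt (noise omitted in this sketch)."""
    inc = wrapped_increment(direction, s0, s1)
    # Buggy variant: h = s1 - s0, always positive in the solver frame.
    # Fixed variant: h = inc.dt, which carries the sign of the original direction.
    h = inc.dt if h_from_increment else s1 - s0
    return y + h * (-a * y)


def solve(t_start, t_end, y0, a, dt, h_from_increment):
    # Map the original interval into the solver frame, where time increases.
    direction = 1 if t_start < t_end else -1
    s_start, s_end = direction * t_start, direction * t_end
    n_steps = round((s_end - s_start) / dt)
    y = y0
    for i in range(n_steps):
        s0 = s_start + i * dt
        y = drift_step(y, s0, s0 + dt, direction, a, h_from_increment)
    return y


# Backwards from t = 1 to t = 0, starting from y = 1.
fixed = solve(1.0, 0.0, 1.0, a=1.0, dt=0.001, h_from_increment=True)
buggy = solve(1.0, 0.0, 1.0, a=1.0, dt=0.001, h_from_increment=False)
print(f"h = bm_inc.dt: {fixed:.4f}")  # close to e ~ 2.718
print(f"h = t1 - t0:   {buggy:.4f}")  # close to 1/e ~ 0.368
```

Taking h from the increment keeps the drift consistent with the Brownian increment in both directions, since both come from the same call; for a forward solve (direction 1) the two choices of h coincide.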

Metadata

Assignees: No one assigned
Labels: No labels
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests