Hi,
I lost some time over a small mistake of mine, because the resulting error was quite obscure.
I passed `EulerHeun` instead of `EulerHeun()` as the solver argument.
Minimal example:
```python
import jax.random as jr
from diffrax import diffeqsolve, ControlTerm, EulerHeun, MultiTerm, ODETerm, SaveAt, VirtualBrownianTree

t0, t1 = 0, 3
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: 0.1 * t
brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=jr.PRNGKey(0))
terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))
solver = EulerHeun  # the mistake: this should be EulerHeun()
saveat = SaveAt(dense=True)

sol = diffeqsolve(terms, solver, t0, t1, dt0=0.05, y0=1.0, saveat=saveat)
print(sol.evaluate(1.1))  # with EulerHeun(): DeviceArray(0.89436394)
```
This fails inside diffrax on the assertion:

```
--> 137 assert type(term_contr_kwargs) is tuple
```
For the `Euler` solver, passed the same way (as the class), the error is different:

```
TypeError: functools.partial() argument after ** must be a mapping, not property
```
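Both messages look consistent with library code reading an attribute of the solver that is defined as a `property`: on an instance the property evaluates to a value, but on the class it returns the `property` object itself, which then fails a `type(...) is tuple` check or `**` unpacking. The toy below mimics that mechanism in plain Python, as a sketch only; `ToySolver`, `term_kwargs` and `apply_step` are hypothetical names and do not reproduce diffrax's actual internals.

```python
import functools


class ToySolver:
    # Hypothetical stand-in for a solver class (not diffrax code).
    @property
    def term_kwargs(self):
        # On an instance, this property evaluates to a dict.
        return {"scale": 2.0}

    def step(self, y, scale):
        return scale * y


def apply_step(solver, y):
    # Mimics library code that forwards a property's value as **kwargs.
    step = functools.partial(solver.step, **solver.term_kwargs)
    return step(y)


# Passing an instance works: the property yields {"scale": 2.0}.
print(apply_step(ToySolver(), 1.5))  # 3.0

# Passing the class: the attribute is now the property object itself.
print(type(ToySolver.term_kwargs))  # <class 'property'>
try:
    apply_step(ToySolver, 1.5)
except TypeError as err:
    # Same kind of "argument after ** must be a mapping, not property" error.
    print(err)
```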
Maybe this can be improved? I can take a look later if I have some time.
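One possible shape for such an improvement is an early check at the public entry point, so the error names the user's argument rather than surfacing from internals. This is a minimal sketch under assumptions: `AbstractSolver`, `EulerHeun` and `check_solver` below are hypothetical stand-ins defined locally, not diffrax's real classes or API, and where such a check would actually live in diffrax is left open.

```python
import inspect


class AbstractSolver:
    # Hypothetical stand-in for a library's solver base class.
    pass


class EulerHeun(AbstractSolver):
    # Hypothetical stand-in; only the name mirrors the issue.
    pass


def check_solver(solver):
    """Fail early with a readable message if `solver` is not a solver instance."""
    if inspect.isclass(solver) and issubclass(solver, AbstractSolver):
        # The exact mistake from this issue: the class was passed, not an instance.
        raise TypeError(
            f"Expected a solver instance but got the class `{solver.__name__}`. "
            f"Did you mean `{solver.__name__}()`?"
        )
    if not isinstance(solver, AbstractSolver):
        raise TypeError(
            f"Expected a solver instance, got an object of type `{type(solver).__name__}`."
        )
    return solver


check_solver(EulerHeun())  # passes silently
try:
    check_solver(EulerHeun)
except TypeError as err:
    print(err)  # Expected a solver instance but got the class `EulerHeun`. Did you mean `EulerHeun()`?
```

Checking with `inspect.isclass` before `isinstance` is what lets the message suggest the fix (`EulerHeun()`) instead of only reporting a wrong type.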