JIT #37

Closed
pnkraemer opened this issue Aug 17, 2021 · 1 comment
pnkraemer commented Aug 17, 2021

Maybe we can write some of the code so that it is jittable out of the box.

https://jax.readthedocs.io/en/latest/jax.html#jax.jit

pnkraemer commented

Inputs and outputs of jittable functions are jnp.arrays, scalars, or (possibly nested) tuples/lists/dicts thereof.
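A minimal sketch of what that means in practice. The function `shift_and_scale` and its arguments are made up for illustration and are not part of the solver; the point is only that a tuple of arrays plus a scalar goes in, and a tuple of arrays comes out.

```python
import jax
import jax.numpy as jnp


@jax.jit
def shift_and_scale(state, factor):
    # `state` is a tuple of arrays and `factor` a Python scalar; both are traced.
    mean, cov = state
    # The return value is again a tuple of arrays.
    return factor * mean, factor**2 * cov


mean0 = jnp.array([2.0, 0.0])
cov0 = jnp.eye(2)
new_mean, new_cov = shift_and_scale((mean0, cov0), 3.0)
```

Arbitrary Python objects as arguments would need to be registered as pytrees or marked static first.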

To JIT the solver, there are a few speedbumps:

  • the assert ... statements in EK1
  • the if step_change < min_change check in the adaptive steps
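Both items have the same likely cause: under jit the values are tracers, and a Python `if` (or `assert`) needs a concrete boolean. A rough sketch of the problem and one possible workaround follows. The functions `clip_step_python` and `clip_step_traceable` are hypothetical stand-ins, since the actual step-size logic is not shown here, and `jnp.where` is only one option (`jax.lax.cond` would be another).

```python
import jax
import jax.numpy as jnp


def clip_step_python(step_change, min_change):
    # Python-level branch: needs a concrete bool, which a traced value cannot give.
    if step_change < min_change:
        return min_change
    return step_change


def clip_step_traceable(step_change, min_change):
    # Select between the two values inside the traced computation instead.
    return jnp.where(step_change < min_change, min_change, step_change)


try:
    jax.jit(clip_step_python)(0.05, 0.2)
    python_branch_jittable = True
except jax.errors.ConcretizationTypeError:
    # Raised while tracing, when the comparison result is converted to bool.
    python_branch_jittable = False

clipped = jax.jit(clip_step_traceable)(0.05, 0.2)
```

An `assert` on a traced value fails in the same way, which would explain why the EK1 asserts get in the way.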

The following works after commenting out the EK1 asserts:

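import jax
import jax.numpy as jnp
import tornado  # the ODE-solver package this issue belongs to


@jax.jit
def solve(y0):
    # Build the test problem inside the jitted function so that y0 is traced.
    ivp = tornado.ivp.vanderpol(y0=y0)

    solution, _ = tornado.ivpsolve.solve(
        ivp,
        method="ek1_diag",
        solver_order=5,
        adaptive=False,  # fixed steps; the adaptive branch is not jittable yet
        dt=0.1,
        benchmark_mode=True,
    )
    return solution.y.mean


y0 = jnp.array([2.0, 0.0])
solve(y0)  # first call: traces and compiles
solve(y0)  # second call: reuses the compiled function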

Note: it is pointless to test this at such a high level, because compilation takes ages.
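One plausible reason for the long compilation (an assumption, not something verified on the solver) is that a Python time-stepping loop gets unrolled during tracing, so the traced program grows with the number of steps. A toy sketch with a made-up `euler_step` as a stand-in for the real step function, comparing a Python loop against `jax.lax.scan`:

```python
import jax
import jax.numpy as jnp


def euler_step(y, dt=0.1):
    # One explicit Euler step for y' = -y (a stand-in, not the EK1 step).
    return y + dt * (-y)


def make_unrolled(num_steps):
    def solve(y0):
        y = y0
        for _ in range(num_steps):  # Python loop: the body is traced num_steps times
            y = euler_step(y)
        return y

    return solve


def make_scanned(num_steps):
    def solve(y0):
        def body(y, _):
            return euler_step(y), None

        # The body is traced once, independent of num_steps.
        y, _ = jax.lax.scan(body, y0, xs=None, length=num_steps)
        return y

    return solve


def count_eqns(fun, y0):
    # Number of top-level equations in the traced program.
    return len(jax.make_jaxpr(fun)(y0).jaxpr.eqns)


y0 = jnp.array([2.0, 0.0])
unrolled_small = count_eqns(make_unrolled(2), y0)
unrolled_large = count_eqns(make_unrolled(20), y0)
scanned_small = count_eqns(make_scanned(2), y0)
scanned_large = count_eqns(make_scanned(20), y0)
```

The unrolled count grows with the number of steps, while the scanned count does not. Jitting and testing small pieces like a single step is cheap either way.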
