In [176]:
import time

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
import numpy as np

In [50]:
from matplotlib import pyplot as  plt

In [106]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu


In [52]:
#jax.config.update("jax_enable_x64", True)

In [113]:
jnp.array([1.0],dtype = float)

DeviceArray([1.], dtype=float32)

In [115]:
class Lorenz(eqx.Module):
    k1: float

    def __call__(self, t, y, args):
        f0 = 10.0*(y[1] - y[0])
        f1 = self.k1 * y[0] - y[1] - y[0] * y[2]
        f2 = y[0] * y[1] - (8/3)*y[2]
        return jnp.stack([f0, f1, f2])

In [320]:
@jax.jit
def main(k1):
    lorenz = Lorenz(k1)
    terms = diffrax.ODETerm(lorenz)
    t0 = 0.0
    t1 = 1.0
    y0 = jnp.array([1.0, 0.0, 0.0])
    dt0 = 0.001
    solver = diffrax.Tsit5()
    saveat = diffrax.SaveAt(ts = jnp.array([t0,t1]))
    stepsize_controller = diffrax.PIDController(rtol=1e-6, atol=1e-3)
    sol = diffrax.diffeqsolve(
        terms,
        solver,
        t0,
        t1,
        dt0,
        y0,
#         saveat=saveat,
        #stepsize_controller=stepsize_controller,
    )
    return sol

In [321]:
main(28.0)


  [ts, save_index] + jax.tree_leaves(ys),
  lambda s: [s.ts, s.save_index] + jax.tree_leaves(s.ys),


Solution(
  t0=f32[],
  t1=f32[],
  ts=f32[1],
  ys=f32[1,3],
  interpolation=None,
  stats={
    'compiled_num_steps':
    None,
    'max_steps':
    i32[],
    'num_accepted_steps':
    i32[],
    'num_rejected_steps':
    i32[],
    'num_steps':
    i32[]
  },
  result=i32[],
  solver_state=None,
  controller_state=None,
  made_jump=None
)

In [323]:

start = time.time()
sol = main(28.0)
end = time.time()

print("Results:")
for ti, yi in zip(sol.ts, sol.ys):
    print(f"t={ti.item()}, y={yi.tolist()}")
print(f"Took {sol.stats['num_steps']} steps in {end - start} seconds.")

Results:
t=1.0, y=[-9.408454895019531, -9.096183776855469, 28.581666946411133]
Took 1001 steps in 0.051194190979003906 seconds.


In [324]:
numberOfParameters = 768000
parameterList = jnp.linspace(0.0,21.0,numberOfParameters)

In [325]:
parameterList

DeviceArray([0.0000000e+00, 2.7343785e-05, 5.4687571e-05, ...,
             2.0999945e+01, 2.0999973e+01, 2.1000000e+01], dtype=float32)

In [328]:
%timeit main(28.0)

44.8 ms ± 426 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [329]:
out = jax.vmap(main)(parameterList)

In [330]:
%timeit jax.vmap(main)(parameterList)

975 ms ± 158 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [334]:
@jax.jit
def solve(prob):
    t0 = 0.0
    t1 = 1.0
    y0 = jnp.array([1.0, 0.0, 0.0])
    dt0 = 0.001
    sol = diffrax.diffeqsolve(
        prob,
        diffrax.Tsit5(),
        t0,
        t1,
        dt0,
        y0,
#         saveat=saveat,
        #stepsize_controller=stepsize_controller,
    )
    return sol

In [335]:
prob = diffrax.ODETerm(Lorenz(28.0))

In [336]:
prob

ODETerm(vector_field=Lorenz(k1=28.0))

In [337]:
solve(prob)

  [ts, save_index] + jax.tree_leaves(ys),
  lambda s: [s.ts, s.save_index] + jax.tree_leaves(s.ys),


Solution(
  t0=f32[],
  t1=f32[],
  ts=f32[1],
  ys=f32[1,3],
  interpolation=None,
  stats={
    'compiled_num_steps':
    None,
    'max_steps':
    i32[],
    'num_accepted_steps':
    i32[],
    'num_rejected_steps':
    i32[],
    'num_steps':
    i32[]
  },
  result=i32[],
  solver_state=None,
  controller_state=None,
  made_jump=None
)

In [338]:
%timeit solve(prob)

44.6 ms ± 80.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [342]:
@jax.jit
def main(k1):
    lorenz = Lorenz(k1)
    terms = diffrax.ODETerm(lorenz)
    t0 = 0.0
    t1 = 1.0
    y0 = jnp.array([1.0, 0.0, 0.0])
    dt0 = 0.001
    solver = diffrax.Tsit5()
    saveat = diffrax.SaveAt(ts = jnp.array([t0,t1]))
    stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)
    sol = diffrax.diffeqsolve(
        terms,
        solver,
        t0,
        t1,
        dt0,
        y0,
#         saveat=saveat,
        stepsize_controller=stepsize_controller,
    )
    return sol

In [343]:
%timeit main(28.0)

  [ts, save_index] + jax.tree_leaves(ys),
  lambda s: [s.ts, s.save_index] + jax.tree_leaves(s.ys),


11.2 ms ± 64.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [345]:
%timeit jax.vmap(main)(parameterList)

254 ms ± 320 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
