In [1]:
import time

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
import numpy as np
import os
import sys

In [2]:
numberOfParameters = int(sys.argv[1])

ValueError: invalid literal for int() with base 10: '-f'

In [3]:
from jax.lib import xla_bridge
print("Working on :", xla_bridge.get_backend().platform)

Working on : gpu


In [4]:
jax.config.update("jax_enable_x64", True)


In [5]:
y0 = jnp.array([1.0, 0.0, 0.0])

In [6]:
class Robertson(eqx.Module):
    k1: float
    k2: float
    k3: float

    def __call__(self, t, y, args):
        f0 = -self.k1 * y[0] + self.k3 * y[1] * y[2]
        f1 = self.k1 * y[0] - self.k2 * y[1] ** 2 - self.k3 * y[1] * y[2]
        f2 = self.k2 * y[1] ** 2
        return jnp.stack([f0, f1, f2])

In [38]:
@jax.jit
def main(k3):
    robertson = Robertson(0.04, 3e7, k3)
    terms = diffrax.ODETerm(robertson)
    t0 = 0.0
    t1 = 1e5
    y0 = jnp.array([1.0, 0.0, 0.0])
    dt0 = 0.0002
    solver = diffrax.Kvaerno3()
    saveat = diffrax.SaveAt(ts = jnp.array([t0,t1]))
    stepsize_controller = diffrax.PIDController(rtol=1e-3, atol=1e-6)
    sol = diffrax.diffeqsolve(
        terms,
        solver,
        t0,
        t1,
        dt0,
        y0,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
#         max_steps=None
    )
    return sol

In [39]:
# main(1e4)

# start = time.time()
# sol = main(1e4)
# end = time.time()

# print("Results:")
# for ti, yi in zip(sol.ts, sol.ys):
#     print(f"t={ti.item()}, y={yi.tolist()}")
# print(f"Took {sol.stats['num_steps']} steps in {end - start} seconds.")

In [47]:
numberOfParameters = 768000
parameterList = jnp.linspace(10.0,1e4,numberOfParameters)

In [48]:
# a = jax.numpy.array([[ 1.01290589e-03,  2.75272126e-05, -2.69166597e-04,
#         -5.58780779e-06],
#        [ 2.75272126e-05,  1.34740128e-03, -4.34192721e-06,
#         -3.00849575e-04],
#        [-2.69166597e-04, -4.34192721e-06,  1.28766222e-04,
#          7.41944929e-07],
#        [-5.58780779e-06, -3.00849575e-04,  7.41944929e-07,
#          7.99537441e-05]])

# jax.numpy.linalg.cholesky(a)   # NOk

In [49]:
parameterList

Array([   10.        ,    10.01300783,    10.02601566, ...,
        9999.97398434,  9999.98699217, 10000.        ], dtype=float64)

In [50]:
import timeit

In [51]:
sol = jax.vmap(main)(parameterList)

In [52]:
sol.stats

{'max_steps': Array([4096, 4096, 4096, ..., 4096, 4096, 4096], dtype=int64, weak_type=True),
 'num_accepted_steps': Array([56, 56, 56, ..., 63, 63, 63], dtype=int64, weak_type=True),
 'num_rejected_steps': Array([0, 0, 0, ..., 0, 0, 0], dtype=int64, weak_type=True),
 'num_steps': Array([56, 56, 56, ..., 63, 63, 63], dtype=int64, weak_type=True)}

In [53]:
%timeit jax.vmap(main)(parameterList)

1.58 s ± 2.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [65]:
data = jax.vmap(main)(parameterList)

In [25]:
data.ys

Array([[[1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [2.10665622e-08, 8.42450306e-11, 9.99999979e-01]],

       [[1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.11241824e-07, 1.93206882e-10, 9.99999889e-01]],

       [[1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [2.72473913e-07, 3.02256787e-10, 9.99999727e-01]],

       ...,

       [[1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.77834458e-02, 7.25947078e-08, 9.82216482e-01]],

       [[1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.78246728e-02, 7.26711608e-08, 9.82175254e-01]],

       [[1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.78659329e-02, 7.27475660e-08, 9.82133994e-01]]],      dtype=float64, weak_type=True)

In [35]:
res = timeit.repeat(lambda: jax.vmap(main)(parameterList),repeat = 100,number = 1)

In [44]:
best_time = min(res)*1000

In [46]:
file = open("./data/Stiff_ODE/Jax_times_adaptive.txt","a+")
file.write('{0} {1}\n'.format(numberOfParameters, best_time))
file.close()