In [1]:
import jax
import jax.numpy as jnp

# Global flag to set a specific platform, must be used at startup.
jax.config.update('jax_platform_name', 'cpu')

x = jnp.square(2)
print(repr(x.device_buffer.device()))  # CpuDevice(id=0)

CpuDevice(id=0)


In [2]:
import time

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
import numpy as np

In [3]:
from matplotlib import pyplot as  plt

In [4]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

cpu


In [5]:
#jax.config.update("jax_enable_x64", True)

In [6]:
jnp.array([1.0],dtype = float)

DeviceArray([1.], dtype=float32)

In [7]:
class Lorenz(eqx.Module):
    k1: float

    def __call__(self, t, y, args):
        f0 = 10.0*(y[1] - y[0])
        f1 = self.k1 * y[0] - y[1] - y[0] * y[2]
        f2 = y[0] * y[1] - (8/3)*y[2]
        return jnp.stack([f0, f1, f2])

In [8]:
@jax.jit
def main(k1):
    lorenz = Lorenz(k1)
    terms = diffrax.ODETerm(lorenz)
    t0 = 0.0
    t1 = 1.0
    y0 = jnp.array([1.0, 0.0, 0.0])
    dt0 = 0.001
    solver = diffrax.Tsit5()
    saveat = diffrax.SaveAt(ts = jnp.array([t0,t1]))
    stepsize_controller = diffrax.PIDController(rtol=1e-6, atol=1e-3)
    sol = diffrax.diffeqsolve(
        terms,
        solver,
        t0,
        t1,
        dt0,
        y0,
#         saveat=saveat,
        #stepsize_controller=stepsize_controller,
    )
    return sol

In [9]:
main(28.0)


  [ts, save_index] + jax.tree_leaves(ys),
  lambda s: [s.ts, s.save_index] + jax.tree_leaves(s.ys),


Solution(
  t0=f32[],
  t1=f32[],
  ts=f32[1],
  ys=f32[1,3],
  interpolation=None,
  stats={
    'compiled_num_steps':
    None,
    'max_steps':
    i32[],
    'num_accepted_steps':
    i32[],
    'num_rejected_steps':
    i32[],
    'num_steps':
    i32[]
  },
  result=i32[],
  solver_state=None,
  controller_state=None,
  made_jump=None
)

In [10]:

start = time.time()
sol = main(28.0)
end = time.time()

print("Results:")
for ti, yi in zip(sol.ts, sol.ys):
    print(f"t={ti.item()}, y={yi.tolist()}")
print(f"Took {sol.stats['num_steps']} steps in {end - start} seconds.")

Results:
t=1.0, y=[-9.408472061157227, -9.096207618713379, 28.58167266845703]
Took 1001 steps in 0.0003418922424316406 seconds.


In [11]:
numberOfParameters = 768000
parameterList = jnp.linspace(0.0,21.0,numberOfParameters)

In [12]:
parameterList

DeviceArray([0.0000000e+00, 2.7343785e-05, 5.4687571e-05, ...,
             2.0999945e+01, 2.0999973e+01, 2.1000000e+01], dtype=float32)

In [13]:
%timeit main(28.0)

152 µs ± 1.2 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [329]:
out = jax.vmap(main)(parameterList)

In [330]:
%timeit jax.vmap(main)(parameterList)

975 ms ± 158 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [334]:
@jax.jit
def solve(prob):
    t0 = 0.0
    t1 = 1.0
    y0 = jnp.array([1.0, 0.0, 0.0])
    dt0 = 0.001
    sol = diffrax.diffeqsolve(
        prob,
        diffrax.Tsit5(),
        t0,
        t1,
        dt0,
        y0,
#         saveat=saveat,
        #stepsize_controller=stepsize_controller,
    )
    return sol

In [335]:
prob = diffrax.ODETerm(Lorenz(28.0))

In [336]:
prob

ODETerm(vector_field=Lorenz(k1=28.0))

In [337]:
solve(prob)

  [ts, save_index] + jax.tree_leaves(ys),
  lambda s: [s.ts, s.save_index] + jax.tree_leaves(s.ys),


Solution(
  t0=f32[],
  t1=f32[],
  ts=f32[1],
  ys=f32[1,3],
  interpolation=None,
  stats={
    'compiled_num_steps':
    None,
    'max_steps':
    i32[],
    'num_accepted_steps':
    i32[],
    'num_rejected_steps':
    i32[],
    'num_steps':
    i32[]
  },
  result=i32[],
  solver_state=None,
  controller_state=None,
  made_jump=None
)

In [23]:
%timeit sol = solve(prob)

NameError: name 'solve' is not defined

In [39]:
@jax.jit
def main(k1):
    lorenz = Lorenz(k1)
    terms = diffrax.ODETerm(lorenz)
    t0 = 0.0
    t1 = 1.0
    y0 = jnp.array([1.0, 0.0, 0.0])
    dt0 = 0.001
    solver = diffrax.Tsit5()
    saveat = diffrax.SaveAt(ts = jnp.array([t0,t1]))
    stepsize_controller = diffrax.PIDController(rtol=1e-9, atol=1e-8)
    sol = diffrax.diffeqsolve(
        terms,
        solver,
        t0,
        t1,
        dt0,
        y0,
#         saveat=saveat,
        stepsize_controller=stepsize_controller,
    )
    return sol

In [40]:
%timeit sol = main(28.0)

  [ts, save_index] + jax.tree_leaves(ys),
  lambda s: [s.ts, s.save_index] + jax.tree_leaves(s.ys),
ERROR:absl:Outside call <jax.experimental.host_callback._CallbackWrapper object at 0x7fda983e8fa0> threw exception The maximum number of solver steps was reached. Try increasing `max_steps`..


XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: RuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.

At:
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/diffrax/misc/errors.py(33): raises
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/jax/experimental/host_callback.py(725): __call__
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/jax/experimental/host_callback.py(1294): _outside_call_run_callback
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/jax/experimental/host_callback.py(1163): wrapped_callback
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/jax/interpreters/mlir.py(1567): _wrapped_callback
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/jax/interpreters/mlir.py(1592): _wrapped_callback
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/jax/_src/dispatch.py(878): _execute_compiled
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/jax/_src/dispatch.py(237): _xla_call_impl
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/jax/core.py(701): process_call
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/jax/core.py(1955): call_bind
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/jax/core.py(1939): bind
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/jax/_src/api.py(606): cache_miss
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/jax/_src/traceback_util.py(162): reraise_with_filtered_traceback
  <magic-timeit>(1): inner
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/IPython/core/magics/execution.py(156): timeit
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/IPython/core/magics/execution.py(1162): timeit
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py(2309): run_line_magic
  /state/partition1/slurm_tmp/20500750.0.0/ipykernel_57217/3263374720.py(1): <module>
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3378): run_code
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3318): run_ast_nodes
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3139): run_cell_async
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py(2940): _run_cell
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py(2885): run_cell
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/ipykernel/zmqshell.py(528): run_cell
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/ipykernel/ipkernel.py(383): do_execute
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/ipykernel/kernelbase.py(730): execute_request
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/ipykernel/kernelbase.py(406): dispatch_shell
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/ipykernel/kernelbase.py(499): process_one
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/ipykernel/kernelbase.py(510): dispatch_queue
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/asyncio/events.py(80): _run
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/asyncio/base_events.py(1896): _run_once
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/asyncio/base_events.py(600): run_forever
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/tornado/platform/asyncio.py(199): start
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/ipykernel/kernelapp.py(712): start
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/traitlets/config/application.py(982): launch_instance
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/site-packages/ipykernel_launcher.py(17): <module>
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/runpy.py(86): _run_code
  /home/gridsan/utkarsh/.conda/envs/venv/lib/python3.10/runpy.py(196): _run_module_as_main


In [None]:
sol.stats['num_steps']