In [1]:
%load_ext autoreload
%autoreload 2

import functools
import time

import jax
from jax import numpy as jnp
from jax._src import array
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.mesh_utils import create_device_mesh
import numpy as np

In [2]:
# Global shape of the demo array; its leading axis is split across the mesh.
shape = (1600, 100)
devices = jax.devices("cpu")
# 1-D mesh over all CPU devices with a single axis named "x". JAX typically
# exposes a single CPU device unless
# XLA_FLAGS=--xla_force_host_platform_device_count=N is set before jax is
# imported; only then is the data really split across several devices.
mesh = Mesh(create_device_mesh((len(devices),), devices), "x")
# Partition dimension 0 over the mesh axis; trailing dimensions are not split.
shard = jax.NamedSharding(mesh, P(mesh.axis_names))
# Per-shard data is generated with shape[0] // len(devices) rows, so the
# leading dimension must divide evenly for the shards to tile the full array.
assert shape[0] % len(devices) == 0, (
    f"shape[0]={shape[0]} is not divisible by {len(devices)} devices")

def long_running(shape_, delay_s: float = 2.0) -> np.ndarray:
  """Simulate a slow host-side producer of a random float32 array.

  Args:
    shape_: Sequence of dimension sizes. Elements may be Python ints or NumPy
      integer scalars / 0-d arrays (e.g. when the shape is passed through
      ``jax.pure_callback`` as operands); they are normalized to ``int``.
    delay_s: Seconds to sleep before producing data, emulating slow I/O or
      compute. Defaults to 2.0, the previously hard-coded delay; pass 0 to
      skip the wait (e.g. in tests).

  Returns:
    float32 ndarray of shape ``shape_`` with standard-normal samples. An empty
    ``shape_`` yields a 0-d array rather than a bare Python float.
  """
  dims = tuple(int(d) for d in shape_)
  time.sleep(delay_s)
  # np.random.randn() with no dims returns a Python float (no .astype), so
  # go through np.asarray to always hand back a float32 ndarray.
  return np.asarray(np.random.randn(*dims), dtype=np.float32)

@functools.partial(jax.jit, static_argnums=(0,))
def long_running_jit(shape_):
  """Jitted wrapper that pulls host data in via a callback, then adds 1.

  Args:
    shape_: Static (hashable) tuple of ints giving the output shape. Because it
      is a static argument, each distinct shape triggers a fresh compilation.

  Returns:
    float32 array of shape ``shape_`` equal to ``long_running(shape_) + 1``.
  """
  # shape_ is static, so close over it instead of passing it as callback
  # operands (which converts each dimension into a JAX array that the callback
  # then receives back as a NumPy array).
  host_data = jax.pure_callback(
      lambda: long_running(shape_),
      jax.ShapeDtypeStruct(shape_, jnp.float32))
  # A trivial on-device op after the callback, presumably so the compiled
  # program contains arithmetic that cost_analysis() can report as flops.
  # int8 + float32 promotes to float32, so the output dtype stays float32.
  return jnp.ones((), dtype=jnp.int8) + host_data

def _timed(label, fn):
  """Call ``fn()``, print "<label> time: <seconds> s", and return its result."""
  start = time.perf_counter()  # monotonic clock, suited to interval timing
  result = fn()
  print(f"{label} time: {time.perf_counter() - start:.4e} s")
  return result


with jax.default_device(jax.devices("cpu")[0]):
  # Platform of the current default device (forced to CPU by this context).
  platform = list(jnp.zeros(()).devices())[0].platform
  print(platform)
  if platform not in ("cuda", "gpu"):
    # Lowering traces the function but does not execute it, so the host
    # callback is not run here.
    print(f"flops = {long_running_jit.lower(shape).cost_analysis()['flops']}")

  # Each device gets shape[0] // len(devices) leading rows; the shard index
  # passed to the callback is ignored, so every shard is a fresh random draw.
  per_shard_shape = (shape[0] // len(devices),) + shape[1:]

  # JAX dispatches asynchronously: the "Call" and "Calc" timings measure only
  # how long it takes to enqueue the work, not to finish it.
  c = _timed("Call", lambda: jax.make_array_from_callback(
      shape, shard, lambda _index: long_running_jit(per_shard_shape)))
  d = _timed("Calc", lambda: 2 * c)

  # Block on d (which depends on c) so the measured wait covers both the host
  # callbacks that fill c and the multiplication that produces d.
  _timed("Block", d.block_until_ready)

cpu
flops = 159999.0
Call time: 1.5227e-02 s
Calc time: 1.4534e-02 s
Block time: 1.9945e+00 s
