In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
import asyncio
import functools
import tempfile
import time
import os
from uuid import uuid4
import pathlib
import uuid
from typing import cast, Sequence
from concurrent.futures import ThreadPoolExecutor
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import broadcast_one_to_all
from jax._src.array import ArrayImpl
from jax._src import array
from jax.core import ShapedArray, ConcreteArray
from jax import sharding
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import multihost_utils as hutils
from jaxlib import xla_client as xc

import jaxfi as jaxm
from wrap_torch2jax import torch2jax_with_vjp

In [19]:
from jax.experimental.array_serialization import serialization_future

global_mesh = Mesh(np.array(jax.devices("cpu")).reshape((4, 2)), ("x", "y"))
global_input_shape = (8, 2)
mesh_axes = P("x", "y")
shard = NamedSharding(global_mesh, mesh_axes)
array1 = jax.device_put(jaxm.randn(global_input_shape), shard)

temp_ckpt_dir = tempfile.TemporaryDirectory().name
fut = serialization_future.nonblocking_save(array1, temp_ckpt_dir)
fut.pytree == jax.ShapeDtypeStruct(shape=(8, 2), dtype=np.float32)
while not fut.done():
  time.sleep(0.01)

In [20]:
jax.debug.visualize_array_sharding(array1)

In [25]:
fut = serialization_future.nonblocking_load(
    temp_ckpt_dir, shardings=[NamedSharding(global_mesh, mesh_axes)])
fut.pytree == jax.ShapeDtypeStruct(shape=(8, 2), dtype=np.float32)
while not fut.done():
  time.sleep(0.01)
m1 = fut.result()
assert m1.addressable_shards[0].data.shape == (2, 1)
assert m1.dtype == np.float32

['float32[8, 2] -> 556db1c3-5075-4847-bf58-6af82cf00de5']


In [26]:
from pathlib import Path
from etils.epath import Path as ePath
print(type(Path("hello").__fspath__()))
print(type(ePath("gs://rdyro-mixtral-storage").__fspath__()))
print(Path("hello").__fspath__())
print(ePath("gs://rdyro-mixtral-storage").__fspath__())

<class 'str'>
<class 'str'>
hello
gs://rdyro-mixtral-storage


In [4]:
from jax.experimental.array_serialization.new_api import load_pytree, save, load
from jax.experimental.array_serialization.serialization_future import nonblocking_load, nonblocking_save
import jaxfi as jaxm

In [21]:
with tempfile.TemporaryDirectory() as tmpdir:
  fut = nonblocking_save(jaxm.randn((100, 100)), tmpdir)
  assert fut.pytree == jax.ShapeDtypeStruct((100, 100), jnp.float32)
  print(fut.pytree)

ShapeDtypeStruct(shape=(100, 100), dtype=float32)


In [9]:
r = jax.random.normal(jax.random.key(int(time.time())), (100, 100))
fut = nonblocking_save({"hello": 1, "b": [r, r]}, "test_checkpoint")
print(fut.pytree)

{'b': [ShapeDtypeStruct(shape=(100, 100), dtype=float32), ShapeDtypeStruct(shape=(100, 100), dtype=float32)], 'hello': 1}


In [10]:
fut = nonblocking_load("test_checkpoint")
print(fut.pytree)

['float32[100, 100] -> 0e5629f4-abef-496c-bff2-1ce20f693062', 'float32[100, 100] -> 2cb4f0b0-f196-4efb-91b5-a2b6263eb1a1', 'int -> dcbd494e-ce1e-4f0f-9b14-1af663fbbe91']
{'b': [ShapeDtypeStruct(shape=(100, 100), dtype=float32), ShapeDtypeStruct(shape=(100, 100), dtype=float32)], 'hello': 'int'}


In [11]:
fut.pytree

{'b': [ShapeDtypeStruct(shape=(100, 100), dtype=float32),
  ShapeDtypeStruct(shape=(100, 100), dtype=float32)],
 'hello': 'int'}

In [12]:
fut.result()

{'b': [Array([[-0.08076067, -0.6227396 ,  0.22723912, ..., -0.86740726,
          -0.80425376,  0.8863032 ],
         [ 0.5787659 ,  1.218969  ,  0.3317969 , ..., -0.42311195,
          -0.43619475, -0.7986426 ],
         [ 0.64526623, -0.5652721 , -0.9524942 , ...,  0.10780644,
          -2.1995645 , -0.3796637 ],
         ...,
         [ 1.5891613 , -1.6145512 ,  0.62276554, ..., -0.4588255 ,
          -1.3757391 , -1.32873   ],
         [-0.10977724,  1.0510896 , -1.5869526 , ...,  0.66006345,
           1.0710827 ,  0.41363144],
         [-1.3962336 ,  0.42232367, -3.105433  , ...,  0.28272235,
          -1.5353185 ,  0.23680422]], dtype=float32),
  Array([[-0.08076067, -0.6227396 ,  0.22723912, ..., -0.86740726,
          -0.80425376,  0.8863032 ],
         [ 0.5787659 ,  1.218969  ,  0.3317969 , ..., -0.42311195,
          -0.43619475, -0.7986426 ],
         [ 0.64526623, -0.5652721 , -0.9524942 , ...,  0.10780644,
          -2.1995645 , -0.3796637 ],
         ...,
         [ 1.5

In [4]:
async def hello():
  print("Hello World")
  return

def sync_hello():
  print("Hello World")
  return

In [21]:
fut = asyncio.ensure_future(hello())

Hello World


In [15]:
futures = []
with ThreadPoolExecutor(max_workers=1) as executor:
    futures.append(executor.submit(hello))
    futures[0].done()

In [17]:
futures[0].result()

<coroutine object hello at 0x7f65e413b400>

In [12]:
print(asyncio.iscoroutine(hello()))
print(asyncio.iscoroutine(sync_hello()))

True
Hello World
False


  print(asyncio.iscoroutine(hello()))


In [3]:
def torch_fn(x):
  time.sleep(10)
  return x

inp = jax.ShapeDtypeStruct((100, 100), jnp.float32)
fn = torch2jax_with_vjp(torch_fn, inp, output_shapes=inp)

In [5]:
shape = (1600, 100)
devices = jax.devices("cpu")
mesh = Mesh(create_device_mesh((len(devices),), devices), "x")
shard = jax.NamedSharding(mesh, P(mesh.axis_names))

def long_running(shape_):
  time.sleep(2)
  return np.random.randn(*shape_).astype(np.float32)

@functools.partial(jax.jit, static_argnums=(0,))
def long_running_jit(shape_):
  #a = jnp.ones(round(1e1))
  #a = a ** 2
  b = jax.pure_callback(lambda shape: long_running(shape), 
                        jax.ShapeDtypeStruct(shape_, jnp.float32), shape_)
  #return a[0] + b
  return jnp.ones((), dtype=jnp.int8) + b


t = time.time()
c = array.make_array_from_callback(
  shape, shard, lambda i: long_running_jit((shape[0] // len(devices),) + shape[1:]))
t = time.time() - t
print(f"Call time: {t:.4e} s")

t = time.time()
d = 2 * c
t = time.time() - t
print(f"Calc time: {t:.4e} s")

t = time.time()
c.block_until_ready()
t = time.time() - t
print(f"Block time: {t:.4e} s")

Call time: 1.6790e-02 s
Calc time: 1.2831e-02 s
Block time: 1.5994e+01 s


In [103]:
type(c)

jaxlib.xla_extension.ArrayImpl

In [104]:
t = time.time()
c.block_until_ready()
t = time.time() - t

: 

: 

In [83]:
arg = (200, 10)
print(long_running_jit.lower(arg).cost_analysis()["flops"])
t = time.time()
a = long_running_jit(arg)
t = time.time() - t
print(f"Call time: {t:.4e}")
print("Done already")
a.block_until_ready()

1999.0
Call time: 9.3753e-03
Done already


Array([[ 0.1762979 ,  1.0097841 ,  0.9234634 , ...,  2.679797  ,
         1.3891101 ,  0.24603891],
       [-0.11763883,  2.5261579 ,  1.4714091 , ...,  0.09255052,
        -0.09681976,  0.3977055 ],
       [ 0.9161737 ,  1.7152143 ,  0.70685816, ...,  0.3879246 ,
         2.891816  ,  1.4030812 ],
       ...,
       [ 0.9910855 ,  1.4816061 ,  0.16083038, ...,  1.5364587 ,
         1.5826927 ,  0.04808569],
       [-1.573318  , -0.19487882,  0.17291683, ...,  2.2511487 ,
        -0.3745637 ,  3.6890392 ],
       [-1.6377559 ,  1.0804694 , -0.1147145 , ...,  0.05407035,
         0.81654763, -1.6515357 ]], dtype=float32)

In [84]:
print(long_running_jit.lower((200, 10)).as_text())
print(long_running_jit.lower((200, 10)).cost_analysis())

module @jit_long_running_jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main() -> (tensor<200x10xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %c = stablehlo.constant dense<200> : tensor<i32>
    %c_0 = stablehlo.constant dense<10> : tensor<i32>
    %c_1 = stablehlo.constant dense<93959129167808> : tensor<i64>
    %0 = stablehlo.custom_call @xla_python_cpu_callback(%c_1, %c, %c_0) {api_version = 2 : i32, backend_config = "93959129167808", mhlo.sharding = "{maximal device=0}", operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<i64>, tensor<i32>, tensor<i32>) -> tuple<tensor<200x10xf32>>
    %1 = stablehlo.get_tuple_element %0[0] : (tuple<tensor<200x10xf32>>) -> tensor<200x10xf32>
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %2 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>

In [37]:
def high_flops(a, b):
  time.sleep(10.0) 
  c = a ** 4 - b
  return a

shape = (1000, 1000)
a, b = jnp.ones(shape), jnp.zeros(shape)
fn = jax.jit(high_flops)
print(fn.lower(a, b).cost_analysis())

#fn(jnp.zeros(shape))

{}


In [213]:
jaxpr = jax.make_jaxpr(high_flops)(3.0)
print(jaxpr.in_avals)
print(jaxpr.eqns)
print(jaxpr)

[ShapedArray(float32[], weak_type=True)]
[a:f32[] = integer_pow[y=2] b]
{ lambda ; a:f32[]. let b:f32[] = integer_pow[y=2] a in (b,) }


In [177]:
@jax.jit
def f():
  return fn(jnp.zeros((100, 100)))

In [178]:
f()

Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [42]:
def long_running(x):
    el = x.reshape(-1)[0]
    return jax.pure_callback(lambda x: (time.sleep(10), x)[1], 
                              jax.ShapeDtypeStruct((), dtype=el.dtype), el)

@jax.jit
def high_flops(a, b):
  #if not isinstance(a, core.Tracer):
  #time.sleep(10.0) 
  c = a ** 4 - b + long_running(a)
  #return long_running(a)
  return c

shape = (100, 100)
a, b = jnp.ones(shape), jnp.zeros(shape)
#fn = jax.jit(high_flops)
fn = high_flops
print(fn.lower(a, b).cost_analysis())

t = time.time()
c = fn(a, b)
t = time.time() - t
print(f"{t = :.4e}")
t = time.time()
c.block_until_ready()
t = time.time() - t
print(f"{t = :.4e}")

{'flops': 39999.0, 'bytes accessed': 600019.0, 'optimal_seconds': -1.0, 'utilization0{}': 9.000100135803223, 'utilization1{}': 5.0, 'bytes accessed0{}': 200011.0, 'bytes accessed1{}': 159999.0, 'bytes accessedout{}': 240007.0}
t = 1.4348e-02
t = 1.0001e+01


In [220]:
from jaxlib.mlir import mhlo

ImportError: cannot import name 'mhlo' from 'jaxlib.mlir' (unknown location)

# Making a array with async dispatch

In [129]:
shape, dtype = (100, 100), jnp.float32
data = np.random.randn(*shape).astype(dtype)
shapedarray = ShapedArray(shape, dtype)
shard = sharding.SingleDeviceSharding(jax.devices("cpu")[0])

In [137]:
def hcb(s):
  #await asyncio.sleep(10)
  time.sleep(10)
  return np.random.randn(*(100, 100)).astype(np.float32)

@jax.jit
def make_array():
  return jax.pure_callback(hcb, jax.ShapeDtypeStruct((100, 100), jnp.float32), jnp.ones(()))

In [138]:
comp = make_array.lower().compile()

In [141]:
a = comp()

In [223]:
from jaxlib.mlir.dialects import mhlo

In [126]:
a = make_array()

In [75]:
@jax.jit
def fn(x):
  return x ** 5.5

In [221]:
#from jax.interpreters import mhlo as hlo
#mhlo.SendOp

ImportError: cannot import name 'mhlo' from 'jax.interpreters' (/home/rdyro/Projects/jax/jax/interpreters/__init__.py)

In [146]:
type(jnp.zeros(100))

jaxlib.xla_extension.ArrayImpl

In [148]:
from jaxlib.xla_extension import ArrayImpl

In [154]:
from jaxlib import xla_extension as xe

In [None]:
jax.Array

In [144]:
jax.Array(shape, jnp.float32)


TypeError: Array() takes no arguments

In [80]:
from jax.interpreters import xla

xla.apply_primitive()

In [124]:
def dcb(index):
  time.sleep(10)
  return np.random.randn(*(100, 100)).astype(np.float32)

array.make_array_from_callback(shape, shard, dcb)

Array([[ 0.5237842 , -0.13295816,  0.8163282 , ...,  1.0272064 ,
        -1.0020953 , -2.2172556 ],
       [-0.21399932, -0.82629573,  0.28917634, ...,  1.138213  ,
        -1.4303926 ,  0.22079563],
       [ 0.7029308 ,  0.76928127, -0.67187524, ..., -0.7529223 ,
         0.07898451,  0.49682337],
       ...,
       [ 1.1057003 , -0.51633984, -0.6576222 , ..., -0.99347425,
         0.69257563,  1.2636473 ],
       [-0.4952039 , -1.3495808 ,  0.14881878, ..., -2.272792  ,
        -1.4284267 , -0.7063522 ],
       [ 0.5917428 ,  0.58725923,  1.2979989 , ..., -0.978442  ,
         0.14061783,  2.0366719 ]], dtype=float32)

In [121]:
ArrayImpl(shapedarray, shard, cast(Sequence[ArrayImpl], [jnp.asarray(data)]), committed=True, _skip_checks=True)

Array([[-0.8177348 ,  1.6091626 , -0.77163905, ...,  0.4970985 ,
         1.2692783 ,  1.2193668 ],
       [ 0.71688765,  0.7734844 ,  0.6703563 , ...,  1.3276428 ,
         0.24779364,  0.6105239 ],
       [-2.158977  ,  1.0940106 ,  0.18226077, ...,  1.6948227 ,
        -0.19424263, -0.732435  ],
       ...,
       [-0.06510899, -0.9032963 , -1.3804721 , ...,  1.3381408 ,
        -0.06489277, -0.02748343],
       [ 0.28411415,  0.38746634, -0.03548809, ...,  0.06469295,
        -1.0033218 ,  0.3748052 ],
       [-0.4197347 ,  0.5445934 , -0.03203459, ...,  1.2798468 ,
        -0.62778944,  2.0478637 ]], dtype=float32)

# Communicating between hosts

In [56]:
vals = broadcast_one_to_all(np.frombuffer(uuid4().bytes, dtype=np.uint32))

In [62]:
str(uuid.UUID(bytes=vals.tobytes()))

'3f29ff2d-0240-4753-8e15-0ebac623fa6b'

In [59]:
u.bytes = vals.tobytes()

TypeError: UUID objects are immutable

In [None]:
u

In [25]:
d = create_device_mesh((8,), jax.devices("cpu"))

In [None]:
@functools.partial(shard_map, mesh=create_device_mesh(jax.devices("cpu")

In [18]:
@functools.partial(jax.pmap, axis_name="x")
def get_unique_val(salt: None):
  if salt is None:
    salt = jnp.ones((jax.process_count(),))
  key = jax.pure_callback(lambda: np.array(time.time_ns() % 2 ** 31, dtype=np.int32), 
                          jax.ShapeDtypeStruct((), jnp.int32))
  key = key + salt.astype(jnp.int32)
  key_all = lax.broadcast(key, (jax.process_count(),))
  return key_all

In [19]:
get_unique_val()

ValueError: pmap wrapped function must be passed at least one argument containing an array, got empty *args=() and **kwargs={}