In [1]:
from __future__ import annotations
import time

#from jaxfi import jaxm
import numpy as np
import jax.dlpack
import torch
from torch import Tensor
import torch.utils.dlpack
from jax import Array

In [5]:
def transfer(x: Array | Tensor, via: str = "dlpack", device: str = "cuda"):
    assert via in ("dlpack", "cpu")
    if isinstance(x, Array):
        if via == "dlpack":
            return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))
        else:
            return torch.as_tensor(np.array(x), device=device)
    else:
        if via == "dlpack":
            return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x))
        else:
            return jax.device_put(jax.numpy.array(x.detach().cpu().numpy()), device=jax.devices(device)[0])


In [7]:
#transfer(jaxm.rand((10), device="cuda"), device="cuda")
transfer(torch.randn(10, device="cuda"), via="cpu", device="cuda")

StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)

In [82]:
def time_transfer(*args, **kw):
    trials = 3
    ret = transfer(*args, **kw)
    t = time.time()
    for _ in range(trials):
        ret = transfer(*args, **kw)
    t = time.time() - t
    print(f"{type(args[0])} {kw} = {t / trials:.4e} s")
    return ret

In [86]:
shape = (10 ** 4, 10 ** 4)
for via in ["dlpack", "cpu"]:
    for device in ["cpu", "cuda"]:
        for x in [jaxm.ones(shape, device=device), torch.randn(shape, device=device)]:
            x2 = time_transfer(x, via=via, device=device)
            if isinstance(x, Array):
                expected_device = torch.device("cuda:0" if device == "cuda" else device)
                assert isinstance(x2, Tensor) and x2.device == expected_device
                trials, t = 3, time.time()
                for _ in range(trials):
                    torch.sum(x2)
                t = time.time() - t
                print(f"Operation on took {t / trials:.4e} s")
            else:
                assert isinstance(x2, Array) and x2.device() == jaxm.resolve_device(device)
                trials, t = 3, time.time()
                for _ in range(trials):
                    jaxm.sum(x2).block_until_ready()
                t = time.time() - t
                print(f"Operation on took {t / trials:.4e} s")

<class 'jaxlib.xla_extension.ArrayImpl'> {'via': 'dlpack', 'device': 'cpu'} = 9.4573e-06 s
Operation on took 2.0425e-02 s
<class 'torch.Tensor'> {'via': 'dlpack', 'device': 'cpu'} = 5.0863e-05 s
Operation on took 1.1913e-01 s
<class 'jaxlib.xla_extension.ArrayImpl'> {'via': 'dlpack', 'device': 'cuda'} = 1.1444e-05 s
Operation on took 3.2028e-05 s
<class 'torch.Tensor'> {'via': 'dlpack', 'device': 'cuda'} = 4.5458e-05 s
Operation on took 8.4257e-04 s
<class 'jaxlib.xla_extension.ArrayImpl'> {'via': 'cpu', 'device': 'cpu'} = 8.4309e-02 s
Operation on took 1.9971e-02 s
<class 'torch.Tensor'> {'via': 'cpu', 'device': 'cpu'} = 4.1650e-01 s
Operation on took 8.3929e-02 s
<class 'jaxlib.xla_extension.ArrayImpl'> {'via': 'cpu', 'device': 'cuda'} = 6.9831e-02 s
Operation on took 6.6837e-05 s
<class 'torch.Tensor'> {'via': 'cpu', 'device': 'cuda'} = 1.5130e-01 s
Operation on took 1.9184e-02 s


In [36]:
jax.ShapedArray(jax.numpy.zeros(10), dtype=jax.numpy.float32)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].

In [32]:
from __future__ import annotations

import numpy as np
import jax.dlpack
import torch
from torch import Tensor
import torch.utils.dlpack
from jax import Array

def transfer(x: Array | Tensor, via: str = "dlpack", device: str = "cuda"):
    assert via in ("dlpack", "cpu")
    if isinstance(x, Array):
        if via == "dlpack":
            return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))
        else:
            return torch.as_tensor(np.array(x), device=device)
    else:
        if via == "dlpack":
            return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x))
        else:
            return jax.device_put(jax.array(x.detach().cpu().numpy()), device=jax.devices(device)[0])

def f(x: Array) -> Array:
    device = "cuda"
    z = x
    def g(y):
        nonlocal z
        return transfer(torch.sum(transfer(y, via="dlpack", device=device)), via="dlpack", device=device)

    return jax.debug.callback(g, x)

r = jax.device_put(jax.numpy.ones(1000).astype(jax.numpy.float32), jax.devices("cuda")[0])

try:
    print(f(r))
    print("Untraced JAX function did work.")
except:
    print("Untraced JAX function did NOT work.")

try:
    print(jax.jit(f)(r))
    print("Traced JAX function did work.")
except:
    print("Traced JAX function did NOT work.")

[]
Untraced JAX function did work.
Traced JAX function did NOT work.


2023-06-21 14:18:19.277519: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2432] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: CpuCallback error: RuntimeError: data must be a Tensor

At:
  /tmp/ipykernel_942670/4080544319.py(19): transfer
  /tmp/ipykernel_942670/4080544319.py(28): g
  /home/rdyro/.pyenv/versions/devel/lib/python3.9/site-packages/jax/_src/debugging.py(227): _flat_callback
  /home/rdyro/.pyenv/versions/devel/lib/python3.9/site-packages/jax/_src/debugging.py(85): debug_callback_impl
  /home/rdyro/.pyenv/versions/devel/lib/python3.9/site-packages/jax/_src/debugging.py(146): _callback
  /home/rdyro/.pyenv/versions/devel/lib/python3.9/site-packages/jax/_src/interpreters/mlir.py(1810): _wrapped_callback
  /home/rdyro/.pyenv/versions/devel/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py(1908): __call__
  /home/rdyro/.pyenv/versions/devel/lib/python3.9/site-packages/

In [43]:
fn = jax.jit(f)

In [48]:
out = fn.lower(jax.numpy.zeros(5))

In [52]:
out2 = out.compiler_ir()

In [62]:
from jax.lib import xla_client

In [65]:
help(xla_client.ops.CustomCall)

Help on built-in function CustomCall in module jaxlib.xla_extension.ops:

CustomCall(...) method of builtins.PyCapsule instance
    CustomCall(builder: jaxlib.xla_extension.XlaBuilder, call_target_name: bytes, operands: Span[jaxlib.xla_extension.XlaOp], shape: jaxlib.xla_extension.Shape, opaque: bytes = b'', has_side_effect: bool = False, schedule: jaxlib.xla_extension.ops.CustomCallSchedule = <CustomCallSchedule.SCHEDULE_NONE: 0>, api_version: jaxlib.xla_extension.ops.CustomCallApiVersion = <CustomCallApiVersion.API_VERSION_ORIGINAL: 1>) -> jaxlib.xla_extension.XlaOp



In [66]:
fn.to_pycapsule()

AttributeError: 'PjitFunction' object has no attribute 'to_pycapsule'

In [61]:
print(out.as_text())

module @jit_f {
  func.func public @main(%arg0: tensor<5xf64> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) {
    %0 = stablehlo.constant dense<94202629491968> : tensor<i64>
    %1 = stablehlo.custom_call @xla_python_gpu_callback(%0, %arg0) {api_version = 2 : i32, backend_config = "94202629491968", has_side_effect = true, mhlo.sharding = "{maximal device=0}", operand_layouts = [dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>], result_layouts = []} : (tensor<i64>, tensor<5xf64>) -> tuple<>
    return
  }
}



In [147]:
print(jaxm.jit(f)(r))

<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>


TypeError: Argument to to_dlpack must be a jax.Array, got <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>

# custom interpreters

In [9]:
from jax.interpreters import mlir, partial_eval as pe

In [12]:
help(pe.abstract_eval_fun)

Help on function abstract_eval_fun in module jax._src.interpreters.partial_eval:

abstract_eval_fun(fun, *avals, debug_info=None, **params)



In [13]:
help(jax.interpreters.partial_eval)

Help on module jax.interpreters.partial_eval in jax.interpreters:

NAME
    jax.interpreters.partial_eval

DESCRIPTION
    # Copyright 2018 The JAX Authors.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     https://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.

DATA
    AbstractedAxesSpec = typing.Union[typing.Dict[int, typing.Hashable], t...
    AbstractedAxisName = typing.Hashable
        A generic version of collections.abc.Hashable.
    
    Const = typing.Any
        Special type indicating an uncon

In [17]:
from jax.experimental import callback
help(callback)

ImportError: cannot import name 'callback' from 'jax.experimental' (/home/rdyro/.pyenv/versions/devel/lib/python3.9/site-packages/jax/experimental/__init__.py)

In [16]:
jax.experimental

AttributeError: module 'jax.experimental' has no attribute 'callback'

In [18]:
from jax._src import core

In [20]:
help(core.raise_to_shaped)

Help on function raise_to_shaped in module jax._src.core:

raise_to_shaped(aval: 'AbstractValue', weak_type=None)



In [21]:
io_callback_p = core.Primitive("io_callback")

In [26]:
io_callback_p.get_bind_params()

TypeError: get_bind_params() missing 1 required positional argument: 'params'

In [28]:
help(mlir.emit_python_callback)

Help on function emit_python_callback in module jax._src.interpreters.mlir:

emit_python_callback(ctx: 'LoweringRuleContext', callback, token: 'Optional[Any]', operands: 'List[ir.Value]', operand_avals: 'List[core.ShapedArray]', result_avals: 'List[core.ShapedArray]', has_side_effect: 'bool', *, sharding: 'Optional[xc.OpSharding]' = None, operand_layouts: 'Optional[Sequence[Optional[Sequence[int]]]]' = None, result_layouts: 'Optional[Sequence[Optional[Sequence[int]]]]' = None) -> 'Tuple[List[ir.Value], Any, Any]'
    Emits MLIR that calls back to a provided Python function.

