In [1]:
import jax
from jax import jit
# Split the host CPU into 12 logical XLA devices. This must run before JAX
# initializes its backend (first array creation / jax.devices() call); later
# cells depend on it, e.g. `jax.devices()[5]` in the device_put example.
jax.config.update('jax_num_cpu_devices', 12)

import jax.numpy as jnp

# Module-level global read inside `impure_func` -- this is the "impurity"
# being demonstrated.
y = 0

@jit   # Different behavior with jit: the Python body only runs while tracing
def impure_func(x):
  """Return ``x + y`` where ``y`` is the module-level global, not an argument.

  Deliberately impure (it prints and reads a global). Under ``@jit`` the
  Python body executes only while JAX traces the function, so the ``print``
  fires at trace time and the value of ``y`` seen during that trace is baked
  into the compiled computation. Later calls that hit the jit cache reuse that
  value and print nothing.
  """
  print("Inside:", y)
  return x + y

# The loop target rebinds the *global* `y` on every iteration.
# NOTE(review): range(2, 3) yields only y = 2, so impure_func is traced and
# called exactly once ("Inside: 2", "Result: 4") and the jit caching effect is
# never visible. Use e.g. range(3) to see later calls reuse the traced `y`.
for y in range(2, 3):
  print("Result:", impure_func(y))

Inside: 2
Result: 4


In [2]:
class CustomClass:
  """Small container whose `calc` method can be decorated with `@jit`.

  The class is registered as a JAX pytree (below), so when `calc` is called
  the instance itself goes through `jit` as an ordinary argument:
    * `x`   -> dynamic child (traced under jit)
    * `mul` -> static aux data (a concrete Python value at trace time)
  """

  def __init__(self, x: jnp.ndarray, mul: bool):
    self.x = x
    self.mul = mul

  @jit
  def calc(self, y):
    """Return ``self.x * y`` if ``self.mul`` is truthy, else ``y`` unchanged."""
    # Python-level branching is allowed here because `mul` travels as static
    # aux data rather than as a traced value.
    return self.x * y if self.mul else y

  def _tree_flatten(self):
    """Split the instance into ``(children, aux_data)`` for JAX."""
    dynamic_leaves = (self.x,)          # arrays / dynamic values
    static_fields = dict(mul=self.mul)  # static values
    return dynamic_leaves, static_fields

  @classmethod
  def _tree_unflatten(cls, aux_data, children):
    """Rebuild an instance from the pieces produced by ``_tree_flatten``."""
    return cls(*children, **aux_data)


from jax import tree_util

# Without this registration, jit would reject a CustomClass instance passed as `self`.
tree_util.register_pytree_node(
    CustomClass,
    CustomClass._tree_flatten,
    CustomClass._tree_unflatten,
)

In [3]:
from jax import numpy as jnp
# An array created without an explicit device lands on the default device
# (CpuDevice(id=0) in the recorded output).
default_placement = jnp.ones(3).devices()
print(default_placement)

{CpuDevice(id=0)}


In [4]:
import jax

from jax import device_put

arr = device_put(1, jax.devices()[5])
print(arr.devices())

{CpuDevice(id=5)}


In [5]:
import numpy as np
import jax

def f(x):  # function we're benchmarking (works in both NumPy & JAX)
  """Return ``x.T @ (x - x.mean(axis=0))``.

  Subtracting ``x.mean(axis=0)`` centers each column, so for the 2-D input used
  in the benchmark below this is X^T times the column-centered X. It only uses
  operations shared by NumPy and ``jax.numpy`` arrays, so the same function can
  be timed on both backends.
  """
  column_means = x.mean(axis=0)
  centered = x - column_means
  return x.T @ centered

# Benchmark input: a 1000x1000 float32 host array.
x_np = np.ones((1000, 1000), dtype=np.float32)  # same as JAX default dtype (float32 unless x64 is enabled)
%timeit f(x_np)  # measure NumPy runtime

# measure JAX device transfer time; block_until_ready() waits for the
# asynchronously dispatched transfer so %time actually covers it.
# NOTE(review): the recorded output shows 0 ns here; with a CPU-only backend
# there is presumably little or no real copy to measure -- verify before
# drawing conclusions about transfer cost.
%time x_jax = jax.device_put(x_np).block_until_ready()

f_jit = jax.jit(f)
# First call: traces and compiles f for this shape/dtype AND executes it once,
# so the time reported is compilation plus one run.
%time f_jit(x_jax).block_until_ready()  # measure JAX compilation (+ first run) time
# Subsequent calls reuse the cached compilation. block_until_ready() is needed
# because JAX dispatches asynchronously and would otherwise return before the
# computation has finished.
%timeit f_jit(x_jax).block_until_ready()  # measure JAX runtime

10.7 ms ± 839 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
CPU times: total: 0 ns
Wall time: 0 ns
CPU times: total: 0 ns
Wall time: 56.1 ms
6.54 ms ± 263 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
from jax import numpy as jnp
# Element-wise "> 0" test; the cell's displayed value is the boolean mask.
jnp.greater(jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0]), 0)

Array([False, False, False,  True,  True], dtype=bool)

In [7]:
import jax
import numpy as np
import jax.numpy as jnp

def add(x, y):
  """Return ``x + y`` for any operands that support ``+`` (here, JAX arrays)."""
  total = x + y
  return total

# Two independent (2, 3) arrays of ones placed on the default device.
# Alternatively written as: x = jnp.array(np.ones((2, 3)))
x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))

# Compile `add` with the buffer of argument 1 (`y`) donated. The result has
# the same shape and dtype as `y`, so XLA may reuse y's buffer for it; treat
# `y` as invalid after this call.
add_donating_y = jax.jit(add, donate_argnums=(1,))
z = add_donating_y(x, y)