In [4]:
import numpy as np
import jax.numpy as jnp
import jax

def f(x):
  """Benchmark kernel: x^T @ (x - column means of x).

  Uses only array methods/operators shared by NumPy and jax.numpy, so the
  same function can be timed eagerly with NumPy and under jax.jit.
  """
  centered = x - x.mean(axis=0)  # subtract each column's mean
  return x.T @ centered

print("numpy")
x_np = np.ones((500, 500), dtype=np.float32)  # same as JAX default dtype
%timeit f(x_np)  # measure NumPy runtime
# print("transfer")
# %time x_jax = jax.device_put(x_np)  # measure JAX device transfer time
x_jax = jnp.asarray(x_np)
f_jit = jax.jit(f)
print("compile")

%time f_jit(x_jax).block_until_ready()  # measure JAX compilation time
print("compiled run")
%time f_jit(x_jax).block_until_ready()  # measure JAX runtime

numpy
The slowest run took 4.64 times longer than the fastest. This could mean that an intermediate result is being cached.
3.91 ms ± 2.02 ms per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
compile
CPU times: user 288 ms, sys: 645 ms, total: 933 ms
Wall time: 149 ms
compiled run
CPU times: user 10.5 ms, sys: 4.42 ms, total: 14.9 ms
Wall time: 2.93 ms


DeviceArray([[-2.3748726e-05, -2.3748726e-05, -2.3748726e-05, ...,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
             [-2.3748726e-05, -2.3748726e-05, -2.3748726e-05, ...,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
             [-2.3748726e-05, -2.3748726e-05, -2.3748726e-05, ...,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
             ...,
             [-2.3748726e-05, -2.3748726e-05, -2.3748726e-05, ...,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
             [-2.3748726e-05, -2.3748726e-05, -2.3748726e-05, ...,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
             [-2.3748726e-05, -2.3748726e-05, -2.3748726e-05, ...,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00]],            dtype=float32)