```python
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
```

## Multiplying Matrices
We’ll be generating random data in the following examples. One big difference between NumPy and JAX is how you generate random numbers: JAX has no global random state, so every random function takes an explicit PRNG key. For more details, see [Common Gotchas in JAX](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers).

```python
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
```

```text
[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]
```
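Because there is no hidden global state, the key alone determines the output. A minimal sketch of what that means in practice: reusing a key reproduces the same numbers, while `random.split` derives a fresh subkey for new samples. The variable names are illustrative.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# Reusing the same key gives identical samples: nothing is advanced behind the scenes.
a = random.normal(key, (3,))
b = random.normal(key, (3,))

# To get new numbers, split the key and consume the subkey instead.
key, subkey = random.split(key)
c = random.normal(subkey, (3,))

print(jnp.array_equal(a, b))  # True
print(jnp.array_equal(a, c))  # False
```

The usual pattern is to keep one "parent" key around and split off a subkey each time new randomness is needed.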


Let’s dive right in and multiply two big matrices.

```python
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the default device: the GPU, if one is available
```

```text
10.5 ms ± 477 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```


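We added `block_until_ready()` because JAX uses asynchronous dispatch by default (see Asynchronous dispatch): `jnp.dot` returns as soon as the work has been queued on the device, so without waiting for the result, `%timeit` would measure only the dispatch time rather than the matrix multiplication itself.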
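A minimal sketch of why the call matters, using a hypothetical helper `time_matmul`: without `block_until_ready()`, the timer can stop as soon as the work is dispatched, before the result exists. The helper name and matrix size are illustrative assumptions, and the actual numbers depend on your hardware; on an accelerator the dispatch-only time is typically much smaller.

```python
import time

import jax.numpy as jnp
from jax import random


def time_matmul(x, block=True):
    """Hypothetical helper: seconds spent on jnp.dot(x, x.T), optionally waiting for the result."""
    start = time.perf_counter()
    y = jnp.dot(x, x.T)  # returns as soon as the work is dispatched
    if block:
        y.block_until_ready()  # wait until the device has actually finished
    return time.perf_counter() - start, y


x = random.normal(random.PRNGKey(0), (1000, 1000), dtype=jnp.float32)
time_matmul(x)  # warm-up call, so one-time compilation is not counted below

dispatch_only, pending = time_matmul(x, block=False)
pending.block_until_ready()  # drain the queued work before the next measurement

full, y = time_matmul(x, block=True)
print(f"dispatch only: {dispatch_only * 1e3:.2f} ms, with blocking: {full * 1e3:.2f} ms")
```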

JAX NumPy functions work on regular NumPy arrays.
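A quick sketch of what that looks like: a plain NumPy array can be passed straight to a `jax.numpy` function, and the result comes back as a JAX array rather than a NumPy one. This assumes a recent JAX release where `jax.Array` is the array type.

```python
import numpy as np
import jax
import jax.numpy as jnp

x_np = np.arange(6, dtype=np.float32).reshape(2, 3)  # ordinary host-side NumPy array
y = jnp.dot(x_np, x_np.T)  # accepted directly; JAX moves the data to the device for you

print(isinstance(x_np, jax.Array))  # False
print(isinstance(y, jax.Array))     # True
print(np.asarray(y))                # [[ 5. 14.]
                                    #  [14. 50.]]
```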

```python
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
```

```text
46.6 ms ± 6.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```


That’s slower because the NumPy array lives in host memory, so it has to be transferred to the GPU on every call. You can ensure that an NDArray is backed by device memory using `device_put()`.

```python
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()
```

```text
9.28 ms ± 318 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
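To check where the data ends up, a small sketch: `device_put()` returns a JAX array placed on a device (here passed explicitly as `jax.devices()[0]`, the first device of the default backend), while the input remains an ordinary NumPy array in host memory. The explicit device argument is optional and included only for clarity.

```python
import numpy as np
import jax
from jax import device_put

x_host = np.random.normal(size=(3, 3)).astype(np.float32)
device = jax.devices()[0]           # first device of the default backend
x_dev = device_put(x_host, device)  # one explicit host -> device transfer

print(type(x_host).__name__)         # ndarray
print(isinstance(x_dev, jax.Array))  # True
print(jax.default_backend())         # e.g. "gpu" or "cpu", depending on the machine
```

Doing the transfer once, outside the timed loop, is what makes the last measurement comparable to the one that started from `random.normal`.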


The output of `device_put()` still acts like an NDArray, but it only copies values back to the CPU when they’re needed, for example for printing, plotting, saving to disk, or branching. The behavior of `device_put()` is equivalent to the function `jit(lambda x: x)`, but it’s faster.
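A minimal sketch of the two claims above: `jit(lambda x: x)` yields a device-backed array with the same values as `device_put()`, and converting with `np.asarray()` is one of the points where values are copied back to host memory. This checks only that the results match, not the relative speed.

```python
import numpy as np
import jax
from jax import device_put, jit

x_host = np.random.normal(size=(4, 4)).astype(np.float32)

a = device_put(x_host)        # direct transfer to the default device
b = jit(lambda x: x)(x_host)  # compiled identity: same result, but also pays for tracing/compilation

print(isinstance(a, jax.Array), isinstance(b, jax.Array))  # True True

# np.asarray copies the values back into a host-side NumPy array.
a_back = np.asarray(a)
print(type(a_back).__name__)                  # ndarray
print(np.array_equal(a_back, x_host))         # True
print(np.array_equal(np.asarray(b), x_host))  # True
```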