In [12]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import warnings

warnings.filterwarnings(action='ignore')

### Multiplying Matrices
We’ll be generating random data in the following examples. One big difference between NumPy and JAX is how you generate random numbers: JAX has no global random state, so every function in jax.random takes an explicit PRNG key. For more details, see Common Gotchas in JAX.
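A minimal sketch of the explicit-key pattern (the variable names are illustrative): the same key always reproduces the same draw, so you split it to get fresh randomness and use each subkey once.

```python
from jax import random

key = random.PRNGKey(0)

# Reusing the same key reproduces exactly the same numbers.
a = random.normal(key, (3,))
b = random.normal(key, (3,))

# To get new numbers, split the key and consume the subkey once.
key, subkey = random.split(key)
c = random.normal(subkey, (3,))

print(a)
print(b)
print(c)
```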



In [3]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!


Metal device set to: Apple M1
[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [4]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
# JAX dispatches work asynchronously; block_until_ready() makes %timeit wait for the result
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

66.4 ms ± 18.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
import numpy as np
# x is now a NumPy array in host memory, so JAX copies it to the device on every call
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

77.3 ms ± 7.63 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)  # copy x to the default device once, outside the timed loop
%timeit jnp.dot(x, x.T).block_until_ready()

65.5 ms ± 432 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


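If you have a GPU (or TPU!), these calls run on the accelerator and have the potential to be much faster than on CPU. In the second cell x is a NumPy array, so it has to be copied to the device on every call; device_put() copies it once up front, which is likely why the last cell is faster and much less variable. See “Is JAX faster than NumPy?” for a fuller comparison of the performance characteristics of NumPy and JAX.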
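A small sketch of the two input paths, assuming a recent JAX in which device arrays are instances of jax.Array. Both give the same result; the only difference is where the data lives before the call. The smaller size is an arbitrary choice so the sketch runs quickly.

```python
import numpy as np
import jax
import jax.numpy as jnp

size = 512  # smaller than the benchmark above, chosen only to keep this fast
x_np = np.random.normal(size=(size, size)).astype(np.float32)

# Host (NumPy) input: JAX transfers the array to the default device for this call.
out_from_host = jnp.dot(x_np, x_np.T)

# Transfer once, then reuse the device-resident array.
x_dev = jax.device_put(x_np)
out_from_device = jnp.dot(x_dev, x_dev.T)

print(type(x_np).__name__, "->", type(x_dev).__name__)
print(jnp.allclose(out_from_host, out_from_device))
```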

JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:

- jit(), for speeding up your code
- grad(), for taking derivatives
- vmap(), for automatic vectorization or batching.

Let’s go over these one by one. We’ll also end up composing them in interesting ways.
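As a preview of how the three transformations compose, here is a minimal sketch that computes one gradient per example in a batch. The loss function and data are made up for illustration.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

# Hypothetical per-example loss: squared error of a linear model.
def loss(w, x, y):
  pred = jnp.dot(w, x)
  return (pred - y) ** 2

w = jnp.array([1.0, -2.0, 0.5])
xs = jnp.array([[1.0, 0.0, 2.0],
                [0.0, 1.0, 1.0],
                [3.0, 1.0, 0.0]])
ys = jnp.array([1.0, 0.0, 2.0])

# grad: d(loss)/dw for a single example
# vmap: map over axis 0 of xs and ys while sharing w (in_axes=None)
# jit:  compile the whole batched-gradient computation with XLA
per_example_grads = jit(vmap(grad(loss), in_axes=(None, 0, 0)))(w, xs, ys)
print(per_example_grads.shape)  # (3, 3): one gradient per example
```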

### Using jit() to speed up functions
JAX runs transparently on the GPU or TPU (falling back to CPU if you don’t have one). However, in the above example, JAX is dispatching kernels to the GPU one operation at a time. If we have a sequence of operations, we can use the @jit decorator to compile multiple operations together using XLA. Let’s try that.

In [13]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

2.41 ms ± 87 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
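To see the sequence of operations that jit() hands to XLA as one unit, jax.make_jaxpr() prints the intermediate representation JAX traces selu into. This is a sketch for inspection only; the exact printed names can vary between JAX versions.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.linspace(-2.0, 2.0, 5)

# The trace lists a comparison, an exp, multiplies/subtracts and a select.
# Without jit each of these is dispatched separately; under jit XLA compiles
# the whole program at once.
print(jax.make_jaxpr(selu)(x))
```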


We can speed it up with jit(), which compiles selu the first time selu_jit is called and caches the compiled version thereafter. (Writing @jit above the function definition does the same thing.)

In [18]:
selu_jit = jit(selu)
selu_jit(x).block_until_ready()  # warm-up call: triggers compilation so it is not timed
%timeit selu_jit(x).block_until_ready()

577 µs ± 15.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
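One way to observe the caching is a Python side effect, which runs only while JAX traces the function, not on every compiled call. In this sketch, scaled_relu and trace_count are hypothetical names; the cache is keyed on the input shape and dtype, so a new shape triggers a fresh trace and compile.

```python
import jax.numpy as jnp
from jax import jit

trace_count = []  # appended to only while JAX traces the function

@jit
def scaled_relu(x):
  trace_count.append(x.shape)  # side effect: runs at trace time only
  return 1.05 * jnp.maximum(x, 0.0)

scaled_relu(jnp.ones(5))   # first call: traces and compiles
scaled_relu(jnp.zeros(5))  # same shape and dtype: reuses the cached compilation
scaled_relu(jnp.ones(7))   # new shape: traces and compiles again
print(trace_count)  # [(5,), (7,)]
```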


### Taking derivatives with grad()
In addition to evaluating numerical functions, we also want to transform them. One transformation is automatic differentiation. In JAX, just like in Autograd, you can compute gradients with the grad() function.

In [19]:
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661194 0.10499358]


Let’s verify with finite differences that our result is correct.

In [20]:
def first_finite_differences(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])


print(first_finite_differences(sum_logistic, x_small))

[0.2501011  0.1965761  0.10502338]
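The derivative can also be checked against the closed form for the logistic function, s'(x) = s(x)(1 - s(x)). This sketch compares autodiff with both the closed form (tight tolerance) and the finite differences above (looser tolerance, since eps = 1e-3 in float32 gives errors on the order of 1e-4).

```python
import jax.numpy as jnp
from jax import grad

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def first_finite_differences(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                    for v in jnp.eye(len(x))])

x_small = jnp.arange(3.)
autodiff = grad(sum_logistic)(x_small)

# Closed form, applied elementwise: s'(x) = s(x) * (1 - s(x)).
s = 1.0 / (1.0 + jnp.exp(-x_small))
analytic = s * (1.0 - s)
finite_diff = first_finite_differences(sum_logistic, x_small)

print(jnp.allclose(autodiff, analytic, atol=1e-6))
print(jnp.allclose(autodiff, finite_diff, atol=1e-3))
```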


Taking derivatives is as easy as calling grad(). In the example above we differentiated sum_logistic directly, but grad() and jit() compose and can be mixed arbitrarily. We can go further:

In [21]:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))  # third derivative at x = 1.0

-0.035325583
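The nested result is the third derivative of the logistic function at x = 1.0. A sketch checking it against the closed form, obtained by differentiating s' = s(1 - s) twice more:

```python
import jax.numpy as jnp
from jax import grad, jit

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

third = grad(jit(grad(jit(grad(sum_logistic)))))(1.0)

# Closed form for s(x) = 1 / (1 + e^{-x}):
#   s'   = s (1 - s)
#   s''  = s (1 - s) (1 - 2s)
#   s''' = s (1 - s) (1 - 6s + 6s^2)
s = 1.0 / (1.0 + jnp.exp(-1.0))
closed_form = s * (1 - s) * (1 - 6 * s + 6 * s ** 2)

print(third, closed_form)
```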


For more advanced autodiff, you can use jax.vjp() for reverse-mode vector-Jacobian products and jax.jvp() for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another and with other JAX transformations. jax.jacrev() and jax.jacfwd() build full Jacobians on top of them (in reverse and forward mode, respectively). Here’s one way to compose them into a function that efficiently computes full Hessian matrices:

In [22]:
from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
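A usage sketch for the hessian helper, checked on a hypothetical quadratic f(x) = xᵀAx whose Hessian is known to be A + Aᵀ. It also calls jax.jvp() and jax.vjp() directly to show the lower-level building blocks.

```python
import jax
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
  # Reverse mode for the gradient, forward mode over that gradient.
  return jit(jacfwd(jacrev(fun)))

# Hypothetical test function: f(x) = x^T A x, with Hessian A + A^T.
A = jnp.array([[2.0, 1.0],
               [0.0, 3.0]])

def quadratic(x):
  return x @ A @ x

x0 = jnp.array([1.0, -1.0])
H = hessian(quadratic)(x0)
print(H)  # equals A + A^T, i.e. [[4, 1], [1, 6]]

# One JVP: directional derivative of f at x0 along v.
v = jnp.array([1.0, 0.0])
_, jvp_out = jax.jvp(quadratic, (x0,), (v,))

# One VJP: pulling back the scalar cotangent 1.0 gives the full gradient.
_, vjp_fn = jax.vjp(quadratic, x0)
(vjp_out,) = vjp_fn(jnp.ones(()))
print(jvp_out, vjp_out)
```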

### Auto-vectorization with vmap()
JAX has one more transformation in its API that you might find useful: vmap(), the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance. When composed with jit(), it can be just as fast as adding the batch dimensions by hand.

We’re going to work with a simple example, and promote matrix-vector products into matrix-matrix products using vmap(). Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions.



In [23]:
key_mat, key_x = random.split(key)  # separate keys, rather than reusing `key` for both draws
mat = random.normal(key_mat, (150, 100))
batched_x = random.normal(key_x, (10, 100))

def apply_matrix(v):
  return jnp.dot(mat, v)

Given a function such as apply_matrix, we can loop over a batch dimension in Python, but usually the performance of doing so is poor.

In [24]:
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
3.64 ms ± 194 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


We know how to batch this operation manually. In this case, jnp.dot handles extra batch dimensions transparently.

In [25]:
@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
354 µs ± 37.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


However, suppose we had a more complicated function without batching support. We can use vmap() to add batching support automatically.

In [26]:
@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
The slowest run took 7.66 times longer than the fastest. This could mean that an intermediate result is being cached.
718 µs ± 752 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
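The “slowest run took 7.66 times longer” note most likely reflects the first timed call, which also traces and compiles the jitted function. This sketch warms both versions up before use, checks that vmap reproduces the hand-batched result, and shows in_axes for batching one argument while sharing another (apply_any_matrix is a hypothetical helper).

```python
import jax.numpy as jnp
from jax import jit, random, vmap

key_mat, key_x = random.split(random.PRNGKey(0))
mat = random.normal(key_mat, (150, 100))
batched_x = random.normal(key_x, (10, 100))

def apply_matrix(v):
  return jnp.dot(mat, v)

manual = jit(lambda v_batched: jnp.dot(v_batched, mat.T))
auto = jit(vmap(apply_matrix))

# Warm up: compile both functions before any timing.
manual(batched_x).block_until_ready()
auto(batched_x).block_until_ready()

# vmap should match the hand-batched version up to float32 rounding.
print(jnp.allclose(manual(batched_x), auto(batched_x), rtol=1e-4, atol=1e-4))

# in_axes: share the matrix (None) and map over axis 0 of the vectors.
def apply_any_matrix(m, v):
  return jnp.dot(m, v)

shared = vmap(apply_any_matrix, in_axes=(None, 0))(mat, batched_x)
print(shared.shape)  # (10, 150)
```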
