# JAX as NumPy

JAX can be used as a near drop-in replacement for numpy. JAX can do a lot more than that, but jax.numpy mirrors most of the numpy API and can be used much as you would use numpy.
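As a small illustration of that claim, the sketch below writes one function against a generic array module and runs it with both numpy and jax.numpy. The helper name `rms` and the parameter name `xp` are hypothetical, chosen only for this example.

```python
import numpy as np
import jax.numpy as jnp


def rms(xp, values):
    """Root-mean-square, written once against either array module."""
    arr = xp.asarray(values, dtype=xp.float32)
    return xp.sqrt(xp.mean(arr ** 2))


values = [3.0, 4.0]
rms_np = rms(np, values)    # evaluated by numpy
rms_jnp = rms(jnp, values)  # evaluated by jax.numpy

print(type(rms_np), float(rms_np))
print(type(rms_jnp), float(rms_jnp))
```

Both calls go through the same function names (asarray, mean, sqrt); only the module passed in changes.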

In [21]:
import jax.numpy as jnp
from jax import random as jnr

In [34]:
rng = jnr.key(42)
# Note: x and y are drawn with the same key, so they hold identical samples
# and the dot product below equals the sum of squares of x (about 1e7).
x = jnr.normal(rng, (int(1e7),))  # shape given as a one-element tuple
y = jnr.normal(rng, (int(1e7),))

print("Type of x:", type(x))

print("Dot product of x and y using dot():", jnp.dot(x, y))

print("Dot product of x and y using einsum():", jnp.einsum('i,i', x, y))

print("Dot product of x and y using @:", x @ y)

assert jnp.dot(x, y) == jnp.einsum('i,i', x, y) == x @ y
print("Operations result in equal values.")


Type of x: <class 'jaxlib.xla_extension.ArrayImpl'>
Dot product of x and y using dot(): 9993415.0
Dot product of x and y using einsum(): 9993415.0
Dot product of x and y using @: 9993415.0
Operations result in equal values.
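In the cell above, x and y are drawn from the same key, so they contain the same samples. JAX random functions are deterministic in the key, and the usual way to get independent draws is to split the key first. The following is a minimal sketch of that pattern; the names `key_x`, `key_y`, and the smaller size `n` are assumptions made for this example, not part of the original notebook.

```python
import jax.numpy as jnp
from jax import random as jnr

rng = jnr.key(42)
key_x, key_y = jnr.split(rng)  # derive two independent subkeys

n = 1000  # much smaller than the 1e7 used above, to keep the sketch quick

# Reusing one key reproduces the same samples.
x_same = jnr.normal(rng, (n,))
y_same = jnr.normal(rng, (n,))

# Distinct subkeys give independent samples.
x_ind = jnr.normal(key_x, (n,))
y_ind = jnr.normal(key_y, (n,))

print("Same key gives identical arrays:", bool(jnp.array_equal(x_same, y_same)))
print("Split keys give identical arrays:", bool(jnp.array_equal(x_ind, y_ind)))
```

With independent x and y, the dot product would no longer be the sum of squares of x, so it would not land near 1e7.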


Let's look at the counterparts in numpy and torch, and their out-of-the-box performance.

In [44]:
import numpy as np
import torch

nx = np.array(x) # numpy can directly convert from jax
tx = torch.tensor(nx) # torch cannot directly convert from jax

ny = np.array(y)
ty = torch.tensor(ny)

print("Type of nx:", type(nx))
print("Type of tx:", type(tx))

Type of nx: <class 'numpy.ndarray'>
Type of tx: <class 'torch.Tensor'>


In [45]:
%timeit jnp.dot(x, y).block_until_ready()
%timeit jnp.einsum('i,i', x, y).block_until_ready()
%timeit (x @ y).block_until_ready()

%timeit nx @ ny
%timeit tx @ ty
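tx = tx.to('cuda')
ty = ty.to('cuda')
# CUDA kernels launch asynchronously; without torch.cuda.synchronize() this
# timing may not cover the full computation.
%timeit tx @ ty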


409 μs ± 22.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
703 μs ± 34.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
399 μs ± 14.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
3.34 ms ± 33.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.13 ms ± 54.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
136 μs ± 36.3 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


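JAX is faster than numpy here, by a factor of about 8 (409 μs versus 3.34 ms for the dot product). This is due to JAX running the computation on the GPU. JAX is also faster than torch when torch is on the CPU, by a factor of about 3 (1.13 ms). When torch is on the GPU, its measured time (136 μs) is about 3x lower than JAX's, although that measurement does not synchronize the GPU before the timer stops, so the two numbers are not strictly comparable.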
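As a check on the ratios, the mean timings reported in the %timeit output above can be divided directly. The sketch below only copies those means into a dictionary (the labels are made up for this example) and expresses each one relative to jnp.dot.

```python
# Mean timings copied from the %timeit output above, in microseconds.
timings_us = {
    "jnp.dot": 409.0,
    "jnp.einsum": 703.0,
    "jax @": 399.0,
    "numpy @": 3340.0,
    "torch @ (CPU)": 1130.0,
    "torch @ (CUDA)": 136.0,
}

baseline = timings_us["jnp.dot"]

# Ratio > 1 means slower than jnp.dot, ratio < 1 means faster.
ratios = {name: t / baseline for name, t in timings_us.items()}

for name, ratio in ratios.items():
    print(f"{name:>15}: {ratio:5.2f}x the jnp.dot time")
```

These are single-machine means from one run, so the ratios are indicative rather than general.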

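We can also try to speed up JAX further with just-in-time compilation. In this example the gain is small: the jitted function below takes 387 μs, against 409 μs for the plain jnp.dot call, an improvement of roughly 5%. A single jnp.dot is already dispatched as one compiled operation, so jit has little left to optimize; it tends to help more when a function chains several operations that can be compiled together. Compared to numpy on the CPU, the jitted version is about 8.6x faster, while torch on the GPU remains faster at 136 μs.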

In [53]:
from jax import jit

@jit
def dot_product_examples(x, y):
    return jnp.dot(x, y)


%timeit dot_product_examples(x, y).block_until_ready()


387 μs ± 15.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Just-in-time compilation is one of JAX's distinguishing features: jit traces a function once and compiles it with XLA, which can speed up code, especially functions built from many array operations.
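To see where jit has more room to help, here is a hedged sketch that times a function made of several array operations, with and without jit. The function `normalized_dot`, the helper `time_call`, and the array size are assumptions made for this example; the first jitted call is run once beforehand so that compilation time stays out of the measurement. Actual timings depend on the hardware and backend, so none are quoted here.

```python
import time

import jax.numpy as jnp
from jax import jit, random as jnr


def normalized_dot(x, y):
    # Several elementwise ops and reductions that jit can compile together.
    xn = x / jnp.sqrt(jnp.sum(x ** 2))
    yn = y / jnp.sqrt(jnp.sum(y ** 2))
    return jnp.dot(xn, yn)


normalized_dot_jit = jit(normalized_dot)

key_x, key_y = jnr.split(jnr.key(0))
x = jnr.normal(key_x, (100_000,))
y = jnr.normal(key_y, (100_000,))

# The first call traces and compiles; keep it out of the timed region.
normalized_dot_jit(x, y).block_until_ready()


def time_call(fn, repeats=20):
    """Average wall-clock seconds per call, waiting for each result."""
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x, y).block_until_ready()
    return (time.perf_counter() - start) / repeats


eager_s = time_call(normalized_dot)
jit_s = time_call(normalized_dot_jit)
print(f"eager: {eager_s * 1e6:.1f} us, jit: {jit_s * 1e6:.1f} us")
```

Both versions compute the same value; only the way the operations are dispatched differs.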