In this first section you will learn the very fundamentals of JAX.

## Getting started with JAX numpy
Fundamentally, JAX is a library that enables transformations of array-manipulating programs written with a NumPy-like API.

Over the course of this series of guides, we will unpack exactly what that means. For now, you can think of JAX as differentiable NumPy that runs on accelerators.

The code below shows how to import JAX and create a vector.

In [1]:
import jax
import jax.numpy as jnp

x = jnp.arange(10)
print(x)

[0 1 2 3 4 5 6 7 8 9]


So far, everything is just like NumPy. A big appeal of JAX is that you don’t need to learn a new API. Many common NumPy programs run just as well in JAX if you substitute `jnp` for `np`. However, there are some important differences, which we touch on at the end of this section.
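As a small illustration of that substitution, the sketch below runs one function against both libraries. The helper name `normalize` and the sample data are made up for this example; they are not part of JAX or NumPy.

```python
import numpy as np
import jax.numpy as jnp

def normalize(xp, values):
    # `xp` is whichever array module is passed in: np or jnp.
    arr = xp.asarray(values, dtype=xp.float32)
    return (arr - xp.mean(arr)) / xp.std(arr)

data = [1.0, 2.0, 3.0, 4.0]  # illustrative input
numpy_result = normalize(np, data)
jax_result = normalize(jnp, data)

# Compare the two results up to floating-point tolerance.
print(np.allclose(numpy_result, np.asarray(jax_result), atol=1e-6))
```

The function body is identical in both calls; only the module it is handed changes.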

You can see the first difference by inspecting `x`. It is a `DeviceArray`, the type JAX uses to represent arrays. (Recent JAX releases name this type `jax.Array` and print it as `Array(...)`.)

In [2]:
x

DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

One useful feature of JAX is that the same code can be run on different backends – CPU, GPU and TPU.
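If you are unsure which backend your session is using, a quick check along these lines may help. It only reports what JAX sees; the output depends entirely on your machine or runtime.

```python
import jax

# Name of the platform JAX selected by default, e.g. "cpu", "gpu" or "tpu".
print(jax.default_backend())

# The individual devices JAX can see on that platform.
for device in jax.devices():
    print(device.platform, device.id)
```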

We will now perform a dot product to demonstrate that it can be done on different devices without changing the code. We use `%timeit` to measure the performance.

(Technical detail: when a JAX function is called, the corresponding operation is dispatched to an accelerator to be computed asynchronously when possible. The returned array is therefore not necessarily ‘filled in’ as soon as the function returns, so if we don’t need the result immediately, the computation won’t block Python execution. Unless we call `block_until_ready`, we will only time the dispatch, not the actual computation. See Asynchronous dispatch in the JAX docs.)

In [3]:
long_vector = jnp.arange(int(1e7))

%timeit jnp.dot(long_vector, long_vector).block_until_ready()

408 µs ± 26.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


**Tip**: Try running the code above twice, once without an accelerator, and once with a GPU runtime (while in Colab, click Runtime → Change Runtime Type and choose GPU). Notice how much faster it runs on a GPU.
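To see the effect of asynchronous dispatch described in the technical detail above, you could time the same dot product with and without `block_until_ready`. This is a rough sketch using `time.perf_counter` rather than `%timeit` so that it runs outside a notebook. The gap between the two numbers varies by backend and may be small on CPU.

```python
import time
import jax.numpy as jnp

long_vector = jnp.arange(int(1e7))

# Warm-up call so one-time setup costs are not included in the timings.
jnp.dot(long_vector, long_vector).block_until_ready()

# Without blocking: the call may return before the result is computed.
start = time.perf_counter()
result = jnp.dot(long_vector, long_vector)
dispatch_time = time.perf_counter() - start

# With blocking: wait until the result is actually available.
start = time.perf_counter()
result = jnp.dot(long_vector, long_vector).block_until_ready()
total_time = time.perf_counter() - start

print(f"without block_until_ready: {dispatch_time * 1e6:.1f} µs")
print(f"with block_until_ready:    {total_time * 1e6:.1f} µs")
```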

## JAX’s first transformation: grad

A fundamental feature of JAX is that it allows you to transform functions.

One of the most commonly used transformations is `jax.grad`, which takes a numerical function written in Python and returns a new Python function that computes the gradient of the original function.

To use it, let’s first define a function that takes an array and returns the sum of squares.

In [4]:
def sum_of_squares(x):
  return jnp.sum(x**2)

Applying `jax.grad` to `sum_of_squares` will return a different function, namely the gradient of `sum_of_squares` with respect to its first parameter `x`.

You can then call that function on an array to get the derivative with respect to each element of the array.

In [5]:
sum_of_squares_dx = jax.grad(sum_of_squares)

x = jnp.asarray([1.0, 2.0, 3.0, 4.0])

print(sum_of_squares(x))

print(sum_of_squares_dx(x))

30.0
[2. 4. 6. 8.]
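This output matches what calculus predicts: the derivative of the sum of squares with respect to each element is twice that element. A minimal check of this, repeating the definitions above so the snippet is self-contained, might look like the following.

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    return jnp.sum(x**2)

x = jnp.asarray([1.0, 2.0, 3.0, 4.0])

# Gradient computed by JAX's automatic differentiation.
autodiff_grad = jax.grad(sum_of_squares)(x)

# Gradient derived by hand: d/dx_i of sum_j x_j^2 is 2 * x_i.
analytic_grad = 2 * x

print(jnp.allclose(autodiff_grad, analytic_grad))
```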
