# JAX Basics

In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, jit, grad, vmap

# Root PRNG key for the notebook; the fixed seed 0 makes the sampled arrays reproducible.
# NOTE(review): this same `key` is passed unchanged to each random.normal call shown in
# this notebook, so those draws are not independent of one another — consider
# random.split if independence matters.
key = random.PRNGKey(0)



# JAX as GPU-backed NumPy

In [5]:
# Draw a 3000 x 3000 standard-normal matrix with JAX. The printed type shows that the
# result is a JAX device array rather than a NumPy ndarray.
# NOTE(review): presumably resident on the GPU when one is available — the output alone
# does not show which device holds it.
x_gpu = random.normal(key, (3000,3000))
print(type(x_gpu))

<class 'jaxlib.xla_extension.DeviceArray'>


In [6]:
# Time one matrix product on the JAX array. block_until_ready() waits for the
# asynchronously dispatched computation to finish, so the timer covers the actual work.
# NOTE(review): with a single run (-n 1 -r 1) the figure presumably includes one-time
# compilation overhead — confirm with a warm-up call before comparing against NumPy.
%timeit -n 1 -r 1 jnp.dot(x_gpu, x_gpu.T).block_until_ready()

175 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [7]:
# Copy the JAX array into a NumPy array so the same product can be timed with NumPy.
x_cpu = np.array(x_gpu)

# Single-run NumPy baseline for comparison with the JAX timing.
%timeit -n 1 -r 1 np.dot(x_cpu,x_cpu.T)

127 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


# Automatic differentiation

In [8]:
def sum_of_squares(x):
    """Return the sum of the squared entries of ``x`` as a scalar.

    Written with ``jax.numpy`` operations so that JAX can differentiate it.
    """
    squared = x ** 2
    return jnp.sum(squared)


# grad() builds a new function that evaluates the gradient of sum_of_squares
# with respect to its argument.
sum_of_squares_dx = grad(sum_of_squares)

$$
\begin{aligned}
f(X) &= \sum_{i} x_i^2 = X^\top X \\
\nabla f(X) &= 2X
\end{aligned}
$$

In [9]:
# Small example vector: the value is 1 + 4 + 9 + 16 and the gradient is 2 * x.
x = jnp.asarray([1.0, 2.0, 3.0, 4.0])

# Show the function value and its gradient, one per line.
print(sum_of_squares(x), sum_of_squares_dx(x), sep="\n")

30.0
[2. 4. 6. 8.]


In [10]:
def f(x):
    """Apply the update ``value <- value + value * k + 3`` for k = 0, 1, ..., 7.

    Shows that JAX can differentiate through an ordinary Python loop: the result is
    an affine function of ``x``.
    """
    value = x
    for k in range(8):
        value = value + value * k + 3
    return value

In [11]:
# Each loop step of f multiplies x by (1 + i), so df/dx = 1 * 2 * ... * 8 = 8! = 40320
# for any input; the evaluation point 290. does not affect the result.
grad(f)(290.)

DeviceArray(40320., dtype=float32)

# Using jit() to speed up functions

In [12]:
def mse(x, y):
    """Mean squared error: the average of the element-wise squared difference of ``x`` and ``y``."""
    residual = x - y
    return jnp.mean(residual ** 2)


# jit-wrapped counterpart of mse: same inputs and result, run through jit.
jit_mse = jit(mse)

In [13]:
# 100000 standard-normal samples used as a timing workload. Both arguments below are the
# same array, so the error itself is zero; only the runtime is of interest.
# Note that this rebinds `x`, which earlier held the 4-element example vector.
x = random.normal(key, (100000,))

# Timing of the plain (non-jit) mse.
%timeit mse(x, x).block_until_ready()

415 µs ± 18 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [14]:
# Same measurement through the jit-wrapped version, for comparison with the cell above.
%timeit jit_mse(x, x).block_until_ready()

60.3 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


# Vectorization

In [15]:
def norm(vect):
    """Square root of the sum of squared entries of ``vect`` (the L2 norm of a 1-D vector)."""
    squared = vect ** 2
    return jnp.sqrt(jnp.sum(squared))


def naive_batched_norm(x):
    """Norm of every row of ``x``, computed one row at a time in a Python loop."""
    row_norms = []
    for row in x:
        row_norms.append(norm(row))
    return jnp.stack(row_norms)


# vmap maps norm over the leading axis: one norm per row, without the Python loop.
vmap_norm = vmap(norm)

In [18]:
# 100 x 100 test matrix; the batched norm functions treat each row as one vector.
matrix = random.normal(key, (100, 100))

In [19]:
%timeit naive_batched_norm(matrix).block_until_ready()

39.8 ms ± 1.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [20]:
%timeit vmap_norm(matrix).block_until_ready()

930 µs ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [21]:
# jit and vmap compose: wrap the vectorized norm in jit, then time it on the same matrix.
jit_vmap_norm = jit(vmap_norm)
%timeit jit_vmap_norm(matrix).block_until_ready()

22.2 µs ± 2.45 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [24]:
 naive_batched_norm(matrix)

DeviceArray([ 9.424122 ,  9.48725  ,  9.005592 ,  9.376043 ,  8.738633 ,
              9.67595  , 10.202312 ,  9.295728 , 10.043758 , 11.205047 ,
             10.043984 ,  9.783223 ,  9.775057 ,  9.872046 , 11.055422 ,
             11.051067 ,  9.390705 , 10.003488 , 10.729829 ,  9.728503 ,
             10.564091 ,  9.68822  ,  9.859245 , 10.632973 , 10.256151 ,
             11.143114 , 11.010348 ,  9.898047 , 10.956767 , 10.345362 ,
              9.701541 , 10.382492 ,  9.457232 ,  9.91471  , 11.022385 ,
             10.168085 ,  8.986782 , 10.513088 , 10.427667 ,  9.694277 ,
             10.646457 , 11.129786 , 10.154379 ,  9.784558 ,  9.633135 ,
             10.000055 ,  9.502357 , 10.485467 , 10.889698 , 10.201964 ,
              9.463949 , 10.868552 ,  9.912413 , 10.541245 , 10.432263 ,
             10.040994 ,  8.693418 ,  9.699597 ,  9.186164 ,  9.984744 ,
             11.57897  , 11.4064045,  9.4746065, 10.667576 ,  9.949682 ,
             11.235698 ,  9.689346 , 10.480486 , 10

In [25]:
vmap_norm(matrix)

DeviceArray([ 9.424122 ,  9.48725  ,  9.005592 ,  9.376043 ,  8.738633 ,
              9.67595  , 10.202313 ,  9.295728 , 10.043758 , 11.205046 ,
             10.043985 ,  9.783223 ,  9.775057 ,  9.872046 , 11.055422 ,
             11.051067 ,  9.390705 , 10.003488 , 10.729829 ,  9.728503 ,
             10.564091 ,  9.68822  ,  9.859245 , 10.632973 , 10.25615  ,
             11.143114 , 11.010349 ,  9.898047 , 10.956766 , 10.345362 ,
              9.701541 , 10.382493 ,  9.457232 ,  9.91471  , 11.022385 ,
             10.168085 ,  8.986781 , 10.513088 , 10.427667 ,  9.694277 ,
             10.646457 , 11.129786 , 10.154379 ,  9.784558 ,  9.633134 ,
             10.000055 ,  9.5023575, 10.485467 , 10.889699 , 10.201964 ,
              9.463949 , 10.868552 ,  9.912413 , 10.541246 , 10.432263 ,
             10.040994 ,  8.693418 ,  9.699597 ,  9.186165 ,  9.984744 ,
             11.57897  , 11.406404 ,  9.4746065, 10.667576 ,  9.949682 ,
             11.235698 ,  9.689347 , 10.480486 , 10

In [26]:
jit_vmap_norm(matrix)

DeviceArray([ 9.424122 ,  9.48725  ,  9.005592 ,  9.376043 ,  8.738633 ,
              9.67595  , 10.202313 ,  9.295728 , 10.043758 , 11.205046 ,
             10.043985 ,  9.783223 ,  9.775057 ,  9.872046 , 11.055422 ,
             11.051067 ,  9.390705 , 10.003488 , 10.729829 ,  9.728503 ,
             10.564091 ,  9.68822  ,  9.859245 , 10.632973 , 10.25615  ,
             11.143114 , 11.010349 ,  9.898047 , 10.956766 , 10.345362 ,
              9.701541 , 10.382493 ,  9.457232 ,  9.91471  , 11.022385 ,
             10.168085 ,  8.986781 , 10.513088 , 10.427667 ,  9.694277 ,
             10.646457 , 11.129786 , 10.154379 ,  9.784558 ,  9.633134 ,
             10.000055 ,  9.5023575, 10.485467 , 10.889699 , 10.201964 ,
              9.463949 , 10.868552 ,  9.912413 , 10.541246 , 10.432263 ,
             10.040994 ,  8.693418 ,  9.699597 ,  9.186165 ,  9.984744 ,
             11.57897  , 11.406404 ,  9.4746065, 10.667576 ,  9.949682 ,
             11.235698 ,  9.689347 , 10.480486 , 10

# How does it work?

In [22]:
jax.make_jaxpr(norm)(x)

{ lambda  ; a.
  let b = integer_pow[ y=2 ] a
      c = reduce_sum[ axes=(0,) ] b
      d = sqrt c
  in (d,) }

In [11]:
jax.make_jaxpr(naive_batched_norm)(matrix)

{ lambda  ; a.
  let b = slice[ limit_indices=(1, 100)
                 start_indices=(0, 0)
                 strides=(1, 1) ] a
      c = squeeze[ dimensions=(0,) ] b
      d = slice[ limit_indices=(2, 100)
                 start_indices=(1, 0)
                 strides=(1, 1) ] a
      e = squeeze[ dimensions=(0,) ] d
      f = slice[ limit_indices=(3, 100)
                 start_indices=(2, 0)
                 strides=(1, 1) ] a
      g = squeeze[ dimensions=(0,) ] f
      h = slice[ limit_indices=(4, 100)
                 start_indices=(3, 0)
                 strides=(1, 1) ] a
      i = squeeze[ dimensions=(0,) ] h
      j = slice[ limit_indices=(5, 100)
                 start_indices=(4, 0)
                 strides=(1, 1) ] a
      k = squeeze[ dimensions=(0,) ] j
      l = slice[ limit_indices=(6, 100)
                 start_indices=(5, 0)
                 strides=(1, 1) ] a
      m = squeeze[ dimensions=(0,) ] l
      n = slice[ limit_indices=(7, 100)
                 start_indi

In [23]:
jax.make_jaxpr(vmap_norm)(matrix)

{ lambda  ; a.
  let b = integer_pow[ y=2 ] a
      c = reduce_sum[ axes=(1,) ] b
      d = sqrt c
  in (d,) }

In [25]:
# Records every argument seen by the Python body of `norm`, to reveal how often it runs.
global_list = []

# NOTE: this redefines the `norm` from the vectorization section, now wrapped in jit.
@jit
def norm(vect):
    """Same norm computation as the earlier `norm`, plus a Python-level side effect."""
    # The append is ordinary Python, not a JAX operation, so it only happens when the
    # Python body itself is executed.
    global_list.append(vect)
    return jnp.sqrt(jnp.sum(vect ** 2))

# Call three times with the same argument.
norm(x)
norm(x)
norm(x)

# The recorded output is 1: the Python body ran once (while jit traced the function)
# and the later calls reused the compiled computation without re-running it.
print(len(global_list))

1


In [26]:
jax.make_jaxpr(norm)(x)

{ lambda  ; a.
  let b = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a.
                                 let b = integer_pow[ y=2 ] a
                                     c = reduce_sum[ axes=(0,) ] b
                                     d = sqrt c
                                 in (d,) }
                    device=None
                    donated_invars=(False,)
                    name=norm ] a
  in (b,) }