In [1]:
# Some modules
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

Just-in-time compiling and executing with jit()

In [2]:
# Normal run
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied elementwise.

  Computes ``lmbda * x`` where ``x > 0`` and
  ``lmbda * (alpha * exp(x) - alpha)`` elsewhere.

  Args:
    x: Input scalar or array.
    alpha: Scale of the exponential branch.
    lmbda: Overall output scale.

  Returns:
    An array with the same shape as ``x``.
  """
  is_positive = x > 0
  # jnp.where evaluates both branches. For large positive x, exp(x) overflows
  # to inf; the forward value is discarded, but under grad() the discarded
  # branch contributes 0 * inf = NaN. Feeding 0 to exp() wherever that branch
  # is unused leaves the forward result unchanged and keeps gradients finite.
  safe_x = jnp.where(is_positive, 0.0, x)
  return lmbda * jnp.where(is_positive, x, alpha * jnp.exp(safe_x) - alpha)

key = random.PRNGKey(0)
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()



3.18 ms ± 365 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [3]:
# jit run
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

397 µs ± 25.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


Automatic differentiation with grad()

In [4]:
def sum_logistic(x):
  """Return the sum of the logistic function 1 / (1 + exp(-x)) over all elements of x."""
  denominator = 1.0 + jnp.exp(-x)
  sigmoid_values = 1.0 / denominator
  return jnp.sum(sigmoid_values)

# grad() returns a new function that computes the gradient of sum_logistic
# with respect to its argument; evaluate it at the points [0., 1., 2.].
first_derivative_fn = grad(sum_logistic)
x_small = jnp.arange(3.0)
print(first_derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [5]:
# grad() composes: differentiating the gradient function gives the second
# derivative, evaluated here at the scalar input 1.0.
second_derivative_fn = grad(grad(sum_logistic))
second_derivative_at_one = second_derivative_fn(1.0)
print(second_derivative_at_one)

-0.09085775


In [6]:
# Deliberate failure demo: the inner grad() maps a vector to a vector, and
# grad() is only defined for scalar-output functions, so this call raises.
# Catch and show the error so the notebook still runs top to bottom.
try:
  print(second_derivative_fn(x_small))
except TypeError as err:
  print(f"TypeError: {err}")

TypeError: Gradient only defined for scalar-output functions. Output had shape: (3,).

In [7]:
from jax import jacfwd, jacrev

# Jacobians accept vector-valued functions, so composing forward-mode over
# reverse-mode differentiation yields the full matrix of second derivatives.
reverse_mode_jacobian = jacrev(sum_logistic)
second_derivative_fn = jacfwd(reverse_mode_jacobian)
print(second_derivative_fn(x_small))

[[-0.         -0.         -0.        ]
 [-0.         -0.09085776 -0.        ]
 [-0.         -0.         -0.07996249]]


In [8]:
# Keep only the diagonal of the second-derivative matrix, i.e. the
# per-element second derivatives.
hessian_at_x_small = second_derivative_fn(x_small)
print(jnp.diagonal(hessian_at_x_small))

[-0.         -0.09085776 -0.07996249]


Auto-vectorization with vmap()

In [9]:
# JAX PRNG keys are meant to be consumed once: passing the same key to both
# random.normal calls makes the two draws statistically dependent. Derive a
# separate subkey for each array instead.
mat_key, batch_key = random.split(key)
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(batch_key, (10, 100))

def apply_matrix(v):
  """Apply the module-level matrix `mat` to `v` via jnp.dot(mat, v)."""
  matrix_vector_product = jnp.dot(mat, v)
  return matrix_vector_product

In [10]:
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
1.11 ms ± 32.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [11]:
@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
8.96 µs ± 145 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [12]:
@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
14.4 µs ± 84.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
