In [4]:
import jax 
import jax.numpy as jnp 
from jax import jit 
from jax import random

In [5]:
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit (SELU), applied elementwise.

  Computes ``lmbda * x`` where ``x > 0`` and ``lmbda * alpha * (exp(x) - 1)``
  elsewhere.

  Args:
    x: Input array or scalar; the activation is applied elementwise.
    alpha: Scale of the negative (exponential) branch. The default 1.67 is an
      approximation of the canonical SELU constant (~1.6733).
    lmbda: Overall output scale. The default 1.05 is an approximation of the
      canonical SELU constant (~1.0507).

  Returns:
    An array with the same shape as ``x``.
  """
  # jnp.where evaluates BOTH branches. Passing large positive x to exp() would
  # overflow to inf (x > ~88 in float32); the forward result hides that, but
  # the gradient of the unselected branch becomes 0 * inf = NaN. Clamping the
  # exponential's input to <= 0 first ("double where") keeps both branches
  # finite while selecting exactly the same values.
  x_nonpos = jnp.where(x > 0, 0.0, x)
  # expm1(x) == exp(x) - 1, but without cancellation error for x near 0.
  return lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(x_nonpos))

[0.        1.05      2.1       3.1499999 4.2      ]


In [8]:
key = random.key(1701)
x = random.normal(key, (1_000_000,))
selu_jit = jit(selu)
_ = selu_jit(x)
%timeit selu(x).block_until_ready()
%timeit selu_jit(x).block_until_ready()

410 µs ± 12.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
70.5 µs ± 3.6 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [9]:
from jax import grad 

def sum_logistic(x):
  """Sum of the logistic sigmoid ``1 / (1 + exp(-x))`` over all elements of ``x``.

  Args:
    x: Floating-point array or scalar of any shape.

  Returns:
    A scalar array. Because the output is scalar, the function can be passed
    directly to ``jax.grad``.

  ``jax.nn.sigmoid`` is used instead of the literal formula. With the literal
  formula, ``exp(-x)`` overflows to inf for very negative ``x`` (below about
  -88 in float32), and the gradient evaluates to ``0 * inf = NaN``.
  ``jax.nn.sigmoid`` has a numerically stable derivative.
  """
  return jnp.sum(jax.nn.sigmoid(x))

# grad() turns the scalar-valued sum_logistic into a function that returns
# d(sum_logistic)/dx, an array with the same shape as its input.
derivative_fn = grad(sum_logistic)
# Evaluate the gradient at the points 0., 1., 2.
x_small = jnp.arange(3.0)
print(derivative_fn(x_small))

[0.25       0.19661194 0.10499357]


In [10]:
# Use a typed PRNG key from jax.random.key, matching the random.key(1701) call
# used for the selu benchmark. It produces the same random bits as the legacy
# jax.random.PRNGKey(0).
key1, key2 = jax.random.split(jax.random.key(0))
mat = jax.random.normal(key1, (150, 100))       # maps 100-dim vectors to 150-dim
batched_x = jax.random.normal(key2, (10, 100))  # batch of 10 vectors, 100 entries each


def apply_matrix(x):
  """Multiply the notebook-global ``mat`` by ``x`` via ``jnp.dot``.

  Intended for one vector at a time; the callers below map it over the rows of
  a batch. With the (150, 100) ``mat`` created alongside it, ``x`` should have
  shape (100,) and the result has shape (150,).
  """
  product = jnp.dot(mat, x)
  return product

In [11]:
def naively_batched_apply_matrix(v_batched):
  """Baseline batching: call ``apply_matrix`` once per row in plain Python.

  Iterates over the leading axis of ``v_batched``, then stacks the per-row
  results along a new leading axis. This is the reference point for the
  ``jax.vmap`` version.
  """
  per_row_results = list(map(apply_matrix, v_batched))
  return jnp.stack(per_row_results)

# Benchmark the per-row Python loop. naively_batched_apply_matrix is not wrapped
# in jit, so each row's apply_matrix call runs eagerly. block_until_ready() makes
# %timeit wait for the result rather than timing only JAX's asynchronous dispatch.
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
1.96 ms ± 24.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [13]:
@jit
def vmap_batched_apply_matrix(batched_x):
  """Apply ``apply_matrix`` to every row of ``batched_x`` using ``jax.vmap``.

  vmap vectorizes the single-vector function over the leading axis
  (``in_axes=0``, the default), and ``jit`` compiles the result into one XLA
  computation.

  NOTE: ``apply_matrix`` reads the global ``mat``. jit captures it as a
  constant at trace time, so reassigning ``mat`` afterwards does not affect an
  already-compiled call for the same input shape/dtype.
  """
  row_mapped = jax.vmap(apply_matrix, in_axes=0)
  return row_mapped(batched_x)

vmap_batched_apply_matrix(batched_x)
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

85.2 µs ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
