In [3]:
!pip install jax



In [4]:
import jax.numpy as jnp

In [5]:
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit (SELU), applied elementwise.

  Computes ``lmbda * x`` where ``x > 0`` and ``lmbda * alpha * (exp(x) - 1)``
  elsewhere.

  Args:
    x: array-like input (anything ``x > 0`` and ``jnp.where`` accept).
    alpha: scale of the negative (exponential) branch.
    lmbda: overall output scale.

  Returns:
    A ``jnp`` array with the shape of ``x``.
  """
  positive = x > 0
  # jnp.where evaluates BOTH branches. For large positive x, exp(x) overflows
  # to inf, and under jax.grad the unselected branch contributes 0 * inf = NaN.
  # Feeding the exponential only non-positive values keeps gradients finite
  # without changing the forward result.
  x_nonpos = jnp.where(positive, 0.0, x)
  # expm1(x) is more accurate than exp(x) - 1 for x close to 0.
  return lmbda * jnp.where(positive, x, alpha * jnp.expm1(x_nonpos))

x = jnp.arange(5.0)
print(selu(x))

[0.        1.05      2.1       3.1499999 4.2      ]


In [6]:
from jax import random

key = random.key(1701)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()

1.06 ms ± 18.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [7]:
from jax import jit

selu_jit = jit(selu)
_ = selu_jit(x)  # compiles on first call
%timeit selu_jit(x).block_until_ready()

301 µs ± 7.72 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
