In [1]:
import jax.numpy as jnp

In [2]:
def relu(x):
    """Rectified linear unit: element-wise maximum of 0 and ``x``."""
    rectified = jnp.maximum(0, x)
    return rectified
def softmax(x, axis=0):
    """Numerically stable softmax along ``axis``.

    Args:
        x: Array of logits.
        axis: Axis to normalise over. Defaults to 0, so the result sums to 1
            along the leading axis.

    Returns:
        Array with the same shape as ``x`` whose entries along ``axis`` are
        non-negative and sum to 1.
    """
    # Softmax is invariant to a constant shift along `axis`; subtracting the
    # maximum keeps exp() from overflowing to inf (and inf / inf = nan) on
    # large logits.
    shifted = x - jnp.max(x, axis=axis, keepdims=True)
    exps = jnp.exp(shifted)  # computed once and reused for the normaliser
    return exps / jnp.sum(exps, axis=axis, keepdims=True)
def cross_entropy(x, y):
    """Cross-entropy ``-sum(y * log(x))`` summed over all elements.

    Args:
        x: Predicted values passed to ``log`` (e.g. probabilities).
        y: Target weights, broadcastable against ``x``.

    Returns:
        Scalar array holding the summed cross-entropy.
    """
    # Entries with zero target weight contribute nothing, but 0 * log(0)
    # evaluates to nan. Replace their prediction with 1 (log(1) = 0) so they
    # drop out cleanly in both the value and its gradient.
    safe_x = jnp.where(y == 0, 1.0, x)
    return -jnp.sum(y * jnp.log(safe_x))

In [3]:
# Exercise the three helpers on a small ramp vector, one result per line.
x = jnp.arange(6.0)
print(relu(x), softmax(x), cross_entropy(softmax(x), x), sep="\n")

[0. 1. 2. 3. 4. 5.]
[0.00426978 0.01160646 0.03154963 0.08576079 0.233122   0.6336913 ]
26.8429


In [4]:
from jax import random

# Build a key from seed 0 and split it in a single step; show the second half.
key, subkey = random.split(random.PRNGKey(0))
print(subkey)

[2718843009 1272950319]


In [5]:
x = random.normal(key, (1_000_000,))
%timeit relu(x)

169 μs ± 1.11 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [6]:
from jax import jit

jitted_relu = jit(relu)
a = jitted_relu(x)
%timeit relu(a)

167 μs ± 1.03 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [7]:
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied element-wise.

  Returns ``lmbda * x`` where ``x > 0`` and
  ``lmbda * alpha * (exp(x) - 1)`` elsewhere.

  Args:
    x: Input array (or scalar).
    alpha: Scale of the exponential branch used for ``x <= 0``.
    lmbda: Overall output scale.

  Returns:
    The activation, with the shape of ``x``.
  """
  # Feed the exponential branch 0 wherever its value will be discarded.
  # Otherwise a large positive x overflows exp() to inf; jnp.where hides that
  # in the forward pass, but the backward pass multiplies the inf by a zero
  # cotangent and produces nan gradients.
  safe_x = jnp.where(x > 0, 0.0, x)
  # expm1 avoids the cancellation in exp(x) - 1 for x close to 0.
  return lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(safe_x))

key = random.key(1000)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()

1.34 ms ± 12.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [8]:
# Wrap selu with jit and time the compiled version on x.
selu_jit = jit(selu)
_ = selu_jit(x)  # warm-up: the first call compiles, keeping that cost out of the timing
%timeit selu_jit(x).block_until_ready()

375 μs ± 3.45 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [9]:
def mm(x, y):
  """Return ``jnp.dot(x, y)`` without JIT compilation."""
  product = jnp.dot(x, y)
  return product

@jit
def mm_jit(x, y):
  """JIT-compiled counterpart of ``mm``: returns ``jnp.dot(x, y)``."""
  product = jnp.dot(x, y)
  return product

In [10]:
a = random.normal(key, (1000, 1000))
b = random.normal(key, (1000, 1000))

%timeit mm(a, b)

3.61 ms ± 35.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [11]:
%timeit mm_jit(a, b)

3.63 ms ± 22.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
