# Config TPU

In [1]:
from jax_tpu_util import config_tpu

In [2]:
# Run the project's TPU setup helper before any JAX work.
# NOTE(review): 'moon' is presumably the name of the target TPU — confirm against jax_tpu_util.
config_tpu('moon')

# Setup

In [3]:
import jax.numpy as jnp

from jax import grad, jit, vmap
from jax import random
from jax import device_put

import numpy as np

# JAX vs NumPy

In [40]:
size = 3000
# Create the PRNG key explicitly: no earlier cell defines `key`, so without this
# line the notebook fails with NameError on a fresh kernel (Restart & Run All).
key = random.PRNGKey(0)
# Square float32 matrix sampled with JAX's own random module.
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

5.62 ms ± 78.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [41]:
# Build the same-shaped float32 matrix with NumPy instead of jax.random, so `x`
# is a plain NumPy array when it is handed to jnp.dot in the timing below.
# NOTE(review): the much slower timing presumably reflects a host-to-device copy
# on every call — confirm against the JAX documentation.
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

268 ms ± 15.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [42]:
# Same NumPy-generated float32 matrix as before, but passed through device_put
# once up front, so the timing below operates on the result of device_put
# rather than on the raw NumPy array.
x = device_put(np.random.normal(size=(size, size)).astype(np.float32))
%timeit jnp.dot(x, x.T).block_until_ready()

5.72 ms ± 101 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# Using `jit()` to speed up functions

In [20]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Apply the SELU activation element-wise.

    Positive entries of ``x`` pass through unchanged; the remaining entries are
    mapped to ``alpha * exp(x) - alpha``. The result is scaled by ``lmbda``.

    Args:
        x: Input array.
        alpha: Scale of the exponential branch.
        lmbda: Overall output scale.

    Returns:
        Array with the same shape as ``x``.
    """
    exponential_branch = alpha * jnp.exp(x) - alpha
    activated = jnp.where(x > 0, x, exponential_branch)
    return lmbda * activated

In [23]:
# One million normally distributed samples used as the input for the selu timings.
x = random.normal(key, shape=(1_000_000,))
%timeit selu(x).block_until_ready()

11.8 ms ± 344 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [26]:
# Wrap selu with jit; the next timing cell calls this wrapper instead of selu
# directly so the two timings can be compared.
selu_jit = jit(selu)

In [27]:
%timeit selu_jit(x).block_until_ready()

1.81 ms ± 70.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# Taking derivatives with `grad()`

In [28]:
def sum_logistic(x):
  """Return the sum of the logistic function 1 / (1 + exp(-x)) over all entries of ``x``."""
  logistic = 1.0 / (1.0 + jnp.exp(-x))
  return jnp.sum(logistic)

# grad(sum_logistic) is a new function that returns the gradient of
# sum_logistic with respect to its input; evaluate it at [0., 1., 2.].
derivative_fn = grad(sum_logistic)
x_small = jnp.arange(3.0)
print(derivative_fn(x_small))

[0.25       0.1966118  0.10499343]


### Verify with finite differences

In [29]:
def first_finite_differences(f, x, eps=1e-3):
  """Estimate the gradient of ``f`` at ``x`` using central finite differences.

  For each coordinate direction ``v`` the estimate is
  ``(f(x + eps * v) - f(x - eps * v)) / (2 * eps)``.

  Args:
    f: Function taking an array shaped like ``x`` and returning a scalar.
    x: 1-D array at which the gradient is estimated (``len(x)`` must work).
    eps: Step size of the central difference. Defaults to 1e-3, the value
      that was previously hard-coded.

  Returns:
    Array with one estimated partial derivative per entry of ``x``.
  """
  # Each row of the identity matrix is a unit vector along one coordinate axis.
  directions = jnp.eye(len(x))
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                    for v in directions])

# Numerical cross-check: evaluated at the same x_small, this should be close to
# the values printed by grad(sum_logistic).
print(first_finite_differences(sum_logistic, x_small))

[0.24974345 0.1965761  0.10490417]


In [30]:
# grad and jit compose freely: three nested grad calls (with jit interleaved)
# give the third derivative of sum_logistic, evaluated here at the scalar 1.0.
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

-0.03532532


# Auto-vectorization with `vmap()`

In [31]:
# Draw the matrix and the batch from independent subkeys. Passing the same key
# to two random calls produces correlated samples, so the key is split first.
mat_key, batch_key = random.split(key)
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(batch_key, (10, 100))

def apply_matrix(v):
  """Multiply the notebook-level matrix ``mat`` by the vector ``v``.

  ``mat`` is read from the enclosing notebook scope, not passed as an argument.
  """
  product = jnp.dot(mat, v)
  return product

In [35]:
def naively_batched_apply_matrix(v_batched):
  """Apply ``apply_matrix`` to each row of ``v_batched`` in a Python loop and stack the results."""
  per_row_results = []
  for v in v_batched:
    per_row_results.append(apply_matrix(v))
  return jnp.stack(per_row_results)

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
38.9 ms ± 686 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [36]:
@jit
def batched_apply_matrix(v_batched):
  """Apply the notebook-level ``mat`` to every row of ``v_batched`` with a single matrix product.

  Computes ``v_batched @ mat.T``, so row ``i`` of the result is ``mat`` applied
  to row ``i`` of ``v_batched``.
  """
  mat_transposed = mat.T
  return jnp.dot(v_batched, mat_transposed)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
1.86 ms ± 24.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [37]:
@jit
def vmap_batched_apply_matrix(v_batched):
  """Apply ``apply_matrix`` across the leading axis of ``v_batched`` using ``vmap``."""
  batched_fn = vmap(apply_matrix)
  return batched_fn(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
1.86 ms ± 35.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
