#### Jax

- quickstart https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
- common gotchas https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
    - JAX transformations work on pure functions only: avoid side effects and global variables
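    - arrays are immutable (keep functions pure!): instead of in-place mutation, use functional indexed updates such as `x.at[idx].set(y)` / `x.at[idx].add(y)` (older JAX versions: `jax.ops.index`, `index_add`, `index_update`)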
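    - out-of-bounds indexing does not raise an error: for retrieval the index is clamped to the array bounds (so indexing past the end returns the last value of the array), and out-of-bounds updates are skipped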
    - random number generation is explicit: the RNG state lives in PRNG keys that you pass (and split) yourself
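    - `@jit` traces the function with abstract `ShapedArray` values (fixed shape and dtype), so the compiled function can be reused for inputs with arbitrary values of that shape/dtype. Use `static_argnums` to trace on the concrete values of selected arguments instead (the function is recompiled when those values change)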


In [18]:
# quickstart https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# Randomness is explicit in JAX: every draw takes a PRNG key.
# Note that this cell reuses the same `key` for several draws.
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

## multiply big matrices
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
# block_until_ready() waits for the result, so the timing covers the computation
# itself rather than just the (asynchronous) dispatch.
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU
# 95.9 ms ± 1.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# jnp functions also accept NumPy arrays, but this is slower because every call
# must first copy the host array to the device (GPU).
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
# 130 ms ± 1.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# Use `device_put` to transfer the array to the device once, up front; the timed
# calls then reuse the device-resident copy (timing matches the jnp-native case).
from jax import device_put
x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()
# 96.1 ms ± 2.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

## Use @jit to compile multiple operations using XLA
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, applied elementwise.

    Returns ``lmbda * x`` where ``x > 0`` and ``lmbda * (alpha * exp(x) - alpha)``
    elsewhere. The defaults approximate the standard SELU constants
    (alpha ~ 1.6733, lambda ~ 1.0507).
    """
    negative_branch = alpha * jnp.exp(x) - alpha
    return lmbda * jnp.where(x > 0, x, negative_branch)
x = random.normal(key, (1000000,))
# Un-jitted: each jnp operation inside selu is executed as a separate step.
%timeit selu(x).block_until_ready()
# 2.62 ms ± 122 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# jit compiles selu's operations together with XLA; the first call traces and
# compiles, later calls with the same input shape reuse the compiled version.
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
# 598 µs ± 11.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# Use `grad` to take derivatives
def sum_logistic(x):
    """Return the sum over all elements of the logistic sigmoid 1 / (1 + exp(-x)).

    The result is a scalar, which makes this function a valid target for `grad`.
    """
    sigmoid = 1.0 / (1.0 + jnp.exp(-x))
    return jnp.sum(sigmoid)
x_small = jnp.arange(3.)
# grad(f) returns a new function computing the gradient of the scalar-valued f.
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
# [0.25       0.19661197 0.10499357]
def first_finite_differences(f, x, eps=1e-3):
    """Approximate the gradient of a scalar function with central differences.

    Component ``i`` is ``(f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)``,
    where ``e_i`` is the i-th row of an identity matrix of size ``len(x)``,
    so ``x`` is expected to be 1-D.

    Args:
        f: function mapping a 1-D array to a scalar.
        x: 1-D point at which to estimate the gradient.
        eps: finite-difference step (default 1e-3, the previously hard-coded
            value). Too small a step loses precision to cancellation in
            float32; too large a step increases truncation error.

    Returns:
        A 1-D array with one estimated partial derivative per element of ``x``.
    """
    basis = jnp.eye(len(x))
    return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                      for v in basis])
# Sanity check: the central-difference estimate should closely match the grad
# result printed earlier (here to within about 1e-4).
print(first_finite_differences(sum_logistic, x_small))
# [0.24998187 0.1964569  0.10502338]

# Use `jax.vjp` for reverse-mode vector-jacobian products
# Use `jax.jvp` for forward-mode jacobian-vector products
from jax import jacfwd, jacrev
def hessian(fun):
    """Build a jit-compiled function that returns the Hessian of `fun`.

    Uses forward-over-reverse differentiation: `jacrev` gives the gradient,
    and `jacfwd` differentiates that gradient again.
    """
    reverse_jacobian = jacrev(fun)
    forward_over_reverse = jacfwd(reverse_jacobian)
    return jit(forward_over_reverse)

## auto-vectorization with `vmap`
# `mat` is a module-level global that the apply_matrix helpers read directly.
# Both draws reuse the same `key`; use random.split when independent values matter.
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))  # 10 vectors of length 100

def apply_matrix(v):
    """Multiply the module-level matrix `mat` by a single vector `v`."""
    return jnp.matmul(mat, v)
def naively_batched_apply_matrix(v_batched):
    """Apply `apply_matrix` to each row of `v_batched` in a Python loop and stack the results.

    This is the slow baseline for the vmap comparison: one call per row.
    """
    rows = map(apply_matrix, v_batched)
    return jnp.stack(list(rows))
print('Naively batched')
# Python-level loop over rows: one apply_matrix call per vector, then a stack.
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
@jit
def batched_apply_matrix(v_batched):
    """Hand-batched version: one matrix product of every row of `v_batched` with `mat`.

    Note: under `jit`, the global `mat` is baked in as a constant when the function
    is first traced, so reassigning `mat` later does not affect cached compilations.
    """
    return v_batched @ mat.T
print('Manually batched')
# Single matrix-matrix product for the whole batch (jit-compiled on first call).
%timeit batched_apply_matrix(batched_x).block_until_ready()
@jit
def vmap_batched_apply_matrix(v_batched):
    """Batch `apply_matrix` automatically by mapping it over the leading axis of `v_batched`."""
    batched = vmap(apply_matrix, in_axes=0)
    return batched(v_batched)
print('Auto-vectorized with vmap')
# vmap derives the batched computation from the single-vector function; the
# recorded timings below show it matching the hand-batched version.
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

# Naively batched
# 2.92 ms ± 98.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Manually batched
# 112 µs ± 1.29 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# Auto-vectorized with vmap
# 105 µs ± 4.27 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


[-0.372111    0.2642311  -0.18252774 -0.7368198  -0.44030386 -0.15214427
 -0.6713536  -0.59086424  0.73168874  0.56730247]
95.9 ms ± 1.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
130 ms ± 1.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [31]:
# Neurips 2019 Jax tutorial
# https://slideslive.com/38923687/jax-accelerated-machinelearning-research-via-composable-function-transformations-in-python


x = jnp.array([1,1])  # NOTE(review): the recorded output `(2,)` looks like `x.shape`, which this cell no longer evaluates — confirm



(2,)

In [None]:
## MNIST
import time
import itertools

import numpy.random as npr

import jax.numpy as jnp
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax
import datasets


def loss(params, batch):
    """Mean negative log-likelihood of the targets under the model.

    Args:
        params: network parameters, as produced by `init_random_params`.
        batch: an ``(inputs, targets)`` pair. ``targets`` is presumably one-hot
            per row, so the row-wise sum picks out the true class's log-probability.

    Returns:
        Scalar: minus the batch mean of ``sum(predict(params, inputs) * targets, axis=1)``.
    """
    inputs, targets = batch
    log_probs = predict(params, inputs)  # model ends in LogSoftmax
    per_example = jnp.sum(log_probs * targets, axis=1)
    return -jnp.mean(per_example)

def accuracy(params, batch):
    """Fraction of examples whose highest-scoring predicted class equals the target class.

    Args:
        params: network parameters, as produced by `init_random_params`.
        batch: an ``(inputs, targets)`` pair; the target class is the argmax of
            each row of ``targets``.

    Returns:
        Scalar in [0, 1]: mean of the per-example match indicator.
    """
    inputs, targets = batch
    scores = predict(params, inputs)
    hits = jnp.argmax(scores, axis=1) == jnp.argmax(targets, axis=1)
    return jnp.mean(hits)

# stax.serial chains the layers and returns an (init_fn, apply_fn) pair:
#   init_random_params(rng, input_shape) -> (output_shape, params)
#   predict(params, inputs) -> outputs
# The final LogSoftmax means predict returns log-probabilities, which `loss` relies on.
# NOTE(review): newer JAX releases moved stax/optimizers from jax.experimental to
# jax.example_libraries; the imports at the top of this cell may need updating.
init_random_params, predict = stax.serial(
    Dense(1024), Relu,
    Dense(1024), Relu,
    Dense(10), LogSoftmax)

if __name__ == "__main__":
    # Train the MLP defined above on MNIST with momentum SGD and report
    # train/test accuracy after every epoch.
    rng = random.PRNGKey(0)

    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    # A trailing partial batch (if any) counts as one more batch per epoch.
    num_batches = num_complete_batches + bool(leftover)

    def data_stream():
        """Yield (images, labels) minibatches forever, reshuffling once per pass.

        Uses its own seeded NumPy RandomState (distinct from the JAX key `rng`
        above), so the shuffle sequence is reproducible.
        """
        np_rng = npr.RandomState(0)
        while True:
            perm = np_rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()

    opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

    @jit
    def update(i, opt_state, batch):
        """One optimizer step: take grad(loss) at the current params and apply opt_update."""
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    # Input shape (-1, 28 * 28): unspecified batch size, presumably flattened 28x28 images.
    _, init_params = init_random_params(rng, (-1, 28 * 28))
    opt_state = opt_init(init_params)
    itercount = itertools.count()  # global step index passed to opt_update

    print("\nStarting training...")
    for epoch in range(num_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            opt_state = update(next(itercount), opt_state, next(batches))
        epoch_time = time.time() - start_time

        params = get_params(opt_state)
        train_acc = accuracy(params, (train_images, train_labels))
        test_acc = accuracy(params, (test_images, test_labels))
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set accuracy {}".format(train_acc))
        print("Test set accuracy {}".format(test_acc))