# Getting started with JAX

Ref. [JAX quick start](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)  
Ref. [Getting started with JAX (MLPs, CNNs & RNNs)](https://roberttlange.github.io/posts/2020/03/blog-post-10/)

In [None]:
%matplotlib inline

In [None]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.scipy.optimize import minimize

# Fixed seed: every random draw below is derived from this key, so runs are reproducible.
key = random.PRNGKey(1)
plt.style.use('ggplot')

## Multiplying matrices

In [None]:
N = 10  # length of the random vector drawn in the next cell

In [None]:
# Draw an N-element standard-normal vector and show it.
x = random.normal(key, shape=(N,))
print(x)

### Multiplying big matrices

In [None]:
N = 3000
x = random.normal(key, (N, N))
# JAX arrays default to float32 while NumPy defaults to float64; cast the NumPy
# matrix so the timing cells compare matrix products of the same precision.
x_numpy = np.random.normal(size=(N, N)).astype(np.float32)





In [None]:
%timeit -n 100 -r 5 np.dot(x_numpy, x_numpy.T)
%timeit -n 100 -r 5 jnp.dot(x, x.T).block_until_ready()
%timeit -n 100 -r 5 jnp.dot(x_numpy, x_numpy.T).block_until_ready()


In [None]:
N = 150
D_features = 100
N_batch = 10

# Give each draw its own subkey: reusing one PRNG key for several draws makes
# the resulting arrays statistically dependent rather than independent samples.
key_x, key_batch = random.split(key)
x = random.normal(key_x, (N, D_features))
batched_x = random.normal(key_batch, (N_batch, D_features))

In [None]:
@jit
def apply_matrix(x, v):
    """Return the product of matrix ``x`` with a single vector ``v`` (jit-compiled)."""
    product = jnp.dot(x, v)
    return product

In [None]:
@jit
def naively_batched_apply_matrix(x, v):
    """Apply ``apply_matrix`` to each row of ``v`` with a Python-level loop.

    The loop runs at trace time, so the compiled program contains one
    matrix-vector product per row of ``v``; the results are stacked row-wise.
    """
    per_row = [apply_matrix(x, row) for row in v]
    return jnp.stack(per_row)

print('Naively batched')
%timeit naively_batched_apply_matrix(x, batched_x).block_until_ready()

In [None]:
@jit
def batched_apply_matrix(x, v):
    """Apply matrix ``x`` to every row of ``v`` with a single matrix product."""
    # Each row of v is one vector; right-multiplying by x^T applies x to all rows at once.
    return jnp.dot(v, jnp.transpose(x))

print('Manually batched')
%timeit batched_apply_matrix(x, batched_x).block_until_ready()

In [None]:
@jit
def v_apply_matrix(v_batched):
    """Multiply the notebook-level matrix ``x`` by one vector ``v_batched``.

    Despite the argument name, this is written for a single vector; batching
    is meant to be added by wrapping it in ``vmap``.
    """
    # NOTE: ``x`` is a global read while ``jit`` traces this function, so the
    # compiled version keeps the value ``x`` had at trace time; rebinding ``x``
    # later is not reflected for already-traced input shapes.
    result = jnp.dot(x, v_batched)
    return result


In [None]:
@jit
def vmap_batched_apply_matrix(v_batched):
    """Map ``v_apply_matrix`` over the leading axis of ``v_batched`` with ``vmap``."""
    batched_fn = vmap(v_apply_matrix)
    return batched_fn(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

In [None]:
N_dim = 1000  # side length of the square matrix sampled for the ReLU timing cells

In [None]:
x = random.uniform(key, shape=(N_dim, N_dim))  # N_dim x N_dim matrix of uniform samples

In [None]:
def ReLU(x):
    """Element-wise rectified linear unit: ``max(0, x)`` for every entry of ``x``."""
    return jnp.maximum(0, x)


# Jit-compiled variant of the same activation, timed against the plain one.
jit_ReLU = jit(ReLU)

In [None]:
%time out = ReLU(x).block_until_ready()
%time jit_ReLU(x).block_until_ready()
%time out2 = jit_ReLU(x).block_until_ready()

In [None]:
def finite_grad(x, f=None, eps=1e-3):
    """Estimate the derivative of ``f`` at ``x`` with a central finite difference.

    Args:
        x: Point at which to estimate the derivative.
        f: Function to differentiate. Defaults to ``ReLU``, looked up at call
            time so a redefined ``ReLU`` is picked up.
        eps: Half-width of the difference stencil (default ``1e-3``).

    Returns:
        ``(f(x + eps) - f(x - eps)) / (2 * eps)`` as a JAX array.
    """
    if f is None:
        f = ReLU
    return jnp.asarray((f(x + eps) - f(x - eps)) / (2 * eps))

print(f"JAX grad: {jit(grad(jit(ReLU)))(2.)}")
print(f"Finite grad: {finite_grad(2.)}")

## Minimizing a loss function

In [None]:
def target_func(x, w):
    """Quadratic model ``a * (x - b)**2 + c`` where ``a, b, c = w[0], w[1], w[2]``."""
    scale, shift, offset = w[0], w[1], w[2]
    return scale * (x - shift) ** 2 + offset

In [None]:
# Ground-truth (a, b, c) used to synthesize the noisy training data.
a_true, b_true, c_true = 2., 1., 3.

In [None]:
n_points = 100
x = jnp.linspace(-4., 4., n_points)
# The noise shape is taken from x, so changing n_points cannot leave the
# targets and the noise with mismatched lengths.
y = target_func(x, [a_true, b_true, c_true]) + 2.5*random.normal(key, shape=x.shape)

In [None]:
# Scatter plot of the noisy training samples.
fig = plt.figure(figsize=(16, 9))
plt.scatter(x=x, y=y, color='blue')
plt.show()

In [None]:
def rmse_loss(w, x, y):
    """Root-mean-square error between ``y`` and the model ``target_func(x, w)``."""
    residuals = y - target_func(x, w)
    mean_sq = jnp.sum(residuals**2) / len(x)
    return jnp.sqrt(mean_sq)

In [None]:
w_init = jnp.ones(3)  # start the optimizer from a = b = c = 1

In [None]:
results = minimize(
    rmse_loss,
    w_init,
    args=(x, y),
    method='BFGS',
    tol=1e-7 * x.shape[0],  # tolerance scales with the number of samples
    options={'maxiter': 20000},
)

In [None]:
# Report whether BFGS converged and the fitted (a, b, c) parameters.
print(f"if succeeded: {results.success}\nparameters: {results.x}")

In [None]:
# Overlay the fitted quadratic on the noisy samples.
fig = plt.figure(figsize=(16, 9))
plt.plot(x, target_func(x, w=results.x))
plt.scatter(x=x, y=y, color='blue')
plt.show()

## vmap

`vmap` lets you write a computation for a single sample and then wrap it so that it works on a whole batch.  
It is as easy as that. Say you have a 100-dimensional feature vector that you want to pass through a linear layer with 512 hidden units followed by your ReLU activation.  
And say you want to compute the layer activations for a batch of size 32.

In [None]:
N_dim = 100  # 100-dimensional input features
N_hidden_dim = 512
N_batch_dim = 32

# Draw inputs, weights and biases from separate subkeys; reusing a single
# PRNG key for all three would make them statistically dependent.
key_X, key_W, key_b = random.split(key, 3)

# Generate a batch of vectors to process
X = random.normal(key_X, (N_batch_dim, N_dim))

# Generate Gaussian weights and biases
params = [random.normal(key_W, (N_hidden_dim, N_dim)),
          random.normal(key_b, (N_hidden_dim,))]


def relu_layer(params, x):
    """Dense ReLU layer for one sample: ``ReLU(W @ x + b)`` with ``W, b = params[0], params[1]``."""
    weights, bias = params[0], params[1]
    pre_activation = jnp.dot(weights, x) + bias
    return ReLU(pre_activation)


def batch_version_relu_layer(params, x):
    """Hand-vectorized ReLU layer applied to a whole batch at once.

    Args:
        params: ``[W, b]``, the weight matrix and bias of the layer.
        x: Batch of inputs, one sample per row (assumed ``(batch, features)``,
            matching how ``X`` is built in this notebook).

    Returns:
        ``ReLU(x @ W.T + b)``, one row of activations per input sample.
    """
    # Use the argument, not the notebook-level ``X``, so any batch passed in is
    # the batch that gets processed. W is transposed because samples are rows.
    return ReLU(jnp.dot(x, params[0].T) + params[1])


# Build the batched, compiled layer once. Creating a fresh ``jit(vmap(...))``
# wrapper inside every call would force JAX to re-trace and recompile each time.
_vmapped_relu_layer = jit(vmap(relu_layer, in_axes=(None, 0), out_axes=0))


def vmap_relu_layer(params, x):
    """Apply ``relu_layer`` to every row of ``x`` using ``vmap``.

    Args:
        params: ``[W, b]`` shared by all samples (``in_axes=None``, not mapped).
        x: Batch of inputs, mapped over its leading axis (``in_axes=0``).

    Returns:
        The layer activations, stacked along axis 0 (``out_axes=0``).
    """
    return _vmapped_relu_layer(params, x)

In [None]:
out = jnp.stack([relu_layer(params, row) for row in X])  # reference: one sample at a time

In [None]:
jnp.shape(out)  # expected (N_batch_dim, N_hidden_dim)

In [None]:
out_batch_version = batch_version_relu_layer(params, x=X)  # manually vectorized result

In [None]:
# Compare element-wise. Comparing ``.all()`` only checks whether each array is
# entirely nonzero, which ReLU's zeros make trivially equal. The tolerance
# absorbs float32 rounding differences between per-row and batched matmuls.
assert jnp.allclose(out, out_batch_version, atol=1e-4)

In [None]:
out_vmap = vmap_relu_layer(params, x=X)  # vmap-vectorized result

In [None]:
out_vmap  # display the result; calling it with no arguments raises a TypeError

In [None]:
# Element-wise comparison of all three implementations. ``.all()`` would only
# compare "is every entry nonzero" flags. The tolerance absorbs float32
# rounding differences between the per-row and batched computations.
assert jnp.allclose(out, out_batch_version, atol=1e-4)
assert jnp.allclose(out, out_vmap, atol=1e-4)

In [None]:
[1, 2, 3, 4][0:-1]  # scratch: a stop index of -1 drops the last element -> [1, 2, 3]