# JAX workshop

28 November 2022, part of "A computational introduction to stochastic differential equations"

Zheng Zhao


# Installation

You should already have `jax` installed. If not, please install it now with the command

`pip install jaxlib jax`

or use Google Colab.
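After installing, a quick check confirms that JAX imports and shows which devices it will run on. This is a minimal sketch; on Colab you may need to select a GPU or TPU runtime before those devices appear.

```python
# Quick sanity check that JAX is importable, and which devices it can use.
import jax
import jax.numpy as jnp

print(jax.__version__)   # installed JAX version string
print(jax.devices())     # CPU devices on a plain laptop; GPU/TPU if available

# A tiny computation to confirm the backend works end to end.
x = jnp.arange(3.0)
print(jnp.sum(x ** 2))   # 0 + 1 + 4 = 5.0
```

On a CPU-only machine, `jax.devices()` lists only CPU devices. That is all this workshop needs.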

# Why/when should you use JAX?

JAX is a good choice if Python is your primary language and your project needs any of:

1. just-in-time (JIT) compilation for fast computation.
2. automatic differentiation.
3. vectorisation.
4. acceleration on (multiple) GPUs/TPUs (not covered today).
5. ...
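These features compose freely. The minimal sketch below combines items 1 to 3 on a toy function `loss`. The function is made up purely for illustration and is not part of the workshop material.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Toy scalar function of a parameter w for a single input x (illustrative only).
    return jnp.tanh(w * x) ** 2

# 2. automatic differentiation: d loss / d w (grad differentiates w.r.t. the first argument)
dloss_dw = jax.grad(loss)

# 3. vectorisation: same w for every example, batch over the x's along axis 0
batched = jax.vmap(dloss_dw, in_axes=(None, 0))

# 1. JIT: compile the whole composed function with XLA
fast_batched = jax.jit(batched)

xs = jnp.linspace(-1.0, 1.0, 5)
out = fast_batched(0.5, xs)   # one gradient per element of xs
print(out)
```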

# Easy to use

JAX = NumPy + SciPy + essential functionality (e.g., auto-diff, JIT compilation, and parallelisation).

If you can use `numpy`, you should have no problem using `jax`.


| NumPy | JAX |
| --- | --- |
| `import numpy as np` | `import jax.numpy as jnp` |
| `np.array()` | `jnp.array()` |
| `np.ones()` | `jnp.ones()` |
| `np.zeros()` | `jnp.zeros()` |
| `np.eye()` | `jnp.eye()` |
| `np.sin()` | `jnp.sin()` |
| `np.linalg....` | `jnp.linalg....` |
| ... | ... |
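The sketch below puts the table into practice on an arbitrary example matrix. It also shows the one difference that most often trips up NumPy users: JAX arrays are immutable. That point is covered in the JAX sharp bits notes under Useful materials.

```python
import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# The same expression, written once with NumPy and once with jax.numpy.
A_np = np.eye(3) + np.ones((3, 3))
A_jnp = jnp.eye(3) + jnp.ones((3, 3))

# Linear algebra follows the NumPy names.
print(np.linalg.det(A_np), jnp.linalg.det(A_jnp))   # both approximately 4.0

# jax.scipy mirrors parts of SciPy, e.g. a numerically stable log-sum-exp.
v = jnp.array([1.0, 2.0, 3.0])
print(logsumexp(v), np.log(np.sum(np.exp(np.asarray(v)))))

# Key difference: JAX arrays are immutable, so `A_jnp[0, 0] = 10.0` raises an error.
# Use the functional .at[] syntax instead, which returns a new array.
A_np[0, 0] = 10.0
B_jnp = A_jnp.at[0, 0].set(10.0)   # A_jnp itself is unchanged
```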

# Galleries

## Auto-differentiation

```python
import jax
import jax.numpy as jnp

def f(x):
    return x * jnp.sin(x) + jnp.cos(x)

grad_of_f = jax.grad(f)  # jax.grad returns a function
grad_of_f(1.)
```
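For this particular `f`, the derivative is easy by hand. The product rule gives f'(x) = sin(x) + x cos(x) - sin(x) = x cos(x), so `grad_of_f(1.)` should return cos(1) ≈ 0.5403. A short check comparing the two at a few arbitrary points:

```python
import jax
import jax.numpy as jnp

def f(x):
    return x * jnp.sin(x) + jnp.cos(x)

def df_by_hand(x):
    # Hand-derived derivative: the sin(x) terms cancel, leaving x cos(x).
    return x * jnp.cos(x)

for x in [0.0, 1.0, 2.5]:
    print(x, jax.grad(f)(x), df_by_hand(x))   # the last two columns should agree
```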

You say you can compute the gradient by hand? How about something deep...

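```python
def something_deep(x):
    # Apply f 100 times: f(f(...f(x)...))
    for _ in range(100):
        x = f(x)
    return x

grad_of_deep = jax.grad(something_deep)  # jax.grad differentiates straight through the Python loop
grad_of_deep(1.)
```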
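What `jax.grad` does here is the chain rule. For a composition, the derivative is the product of f' evaluated along the iterates: f'(x_0) · f'(x_1) · … with x_{k+1} = f(x_k). The sketch below checks this by hand. The helpers `compose` and `grad_by_chain_rule` are hypothetical and exist only for this check. They expose the depth as a parameter, and a small depth is used on purpose.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x * jnp.sin(x) + jnp.cos(x)

def df(x):
    # Hand-derived derivative of f.
    return x * jnp.cos(x)

def compose(x, depth):
    # Hypothetical helper: same structure as something_deep, with adjustable depth.
    for _ in range(depth):
        x = f(x)
    return x

def grad_by_chain_rule(x, depth):
    # Hypothetical helper: d/dx f(f(...f(x))) = f'(x_0) * f'(x_1) * ... * f'(x_{depth-1}),
    # where x_0 = x and x_{k+1} = f(x_k).
    deriv = 1.0
    for _ in range(depth):
        deriv = deriv * df(x)
        x = f(x)
    return deriv

depth = 3
auto = jax.grad(lambda x: compose(x, depth))(1.0)
manual = grad_by_chain_rule(1.0, depth)
print(auto, manual)   # should agree up to float32 rounding
```

Why a small depth? Starting from 1.0, the iterates quickly approach the fixed point π/2, where f'(π/2) = (π/2) cos(π/2) = 0. So with the original depth of 100, `grad_of_deep(1.)` should come out essentially zero, which makes it a poor basis for comparison.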

## JIT speed-up

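```python
import numpy as np

def softmax_np(x):
    """Softmax over axis 0 in NumPy."""
    z = np.exp(x)
    return z / np.sum(z, axis=0)

@jax.jit  # compile with XLA on the first call
def softmax_jax(x):
    """Softmax over axis 0 in JAX."""
    # Like the NumPy version, this skips the max-subtraction trick, so very large
    # inputs can overflow; that is fine for this timing comparison.
    z = jnp.exp(x)
    return z / jnp.sum(z, axis=0)
```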

```python
x = np.ones((10, 1000))
%timeit softmax_np(x)
```

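```python
x = jnp.ones((10, 1000))
softmax_jax(x)  # warm-up: the first call traces and compiles, so keep it out of the timing
%timeit softmax_jax(x).block_until_ready()  # wait for JAX's asynchronous dispatch to finish
```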
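Why does the warm-up call matter? `jax.jit` traces and compiles once per input shape and dtype, then reuses the compiled code for later calls with the same shape and dtype. The sketch below makes this visible with a Python side effect, which runs only while JAX is tracing. The list `trace_count` is a hypothetical illustration helper, not a JAX API.

```python
import jax
import jax.numpy as jnp

trace_count = []   # hypothetical helper: records the input shape each time JAX traces

@jax.jit
def softmax_jax(x):
    trace_count.append(x.shape)   # Python side effect: runs at trace time only
    z = jnp.exp(x)
    return z / jnp.sum(z, axis=0)

softmax_jax(jnp.ones((10, 1000))).block_until_ready()   # first call: trace + compile
softmax_jax(jnp.ones((10, 1000))).block_until_ready()   # same shape/dtype: cached, no retrace
softmax_jax(jnp.ones((5, 1000))).block_until_ready()    # new shape: traced and compiled again

print(trace_count)   # [(10, 1000), (5, 1000)]
```

Relying on side effects inside `jit` is normally a bug. Here it is used deliberately to reveal when tracing happens.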

# Agenda

1. JIT compilation
2. Generate random numbers
3. Vectorisation

-- break --

4. Autodiff
5. Control flow (i.e., loops and if-else)
6. SDE examples
7. Homework

Useful materials:

1. JAX - The Sharp Bits: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
2. Workshop by Adrien Corenflos: https://github.com/AdrienCorenflos/jax-workshop.