# Vectorisation

Vectorisation is one of the most helpful features in my past projects (e.g., PDEs, GPs), and one reason that I completely abandoned Matlab.

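Note: vectorisation is not the same as parallelisation. Vectorisation expresses a computation over many inputs as operations on whole arrays, while parallelisation splits work across multiple processors or devices.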

# Example

Suppose that we have two arrays `a` and `b` of shapes `(10, 5, 3, 6)` and `(10, 3, 7, 6)`, respectively.

How can we compute `a @ b` so that the matrix multiplication applies to the `(5, 3) x (3, 7)` dimensions, while the other dimensions are treated as batch (broadcasting) dimensions? The desired result is an array of shape `(10, 5, 7, 6)`.

How do we do this in NumPy/Matlab?

Hint: use `einsum`.
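A minimal sketch of the `einsum` hint, assuming random inputs of the shapes above (the seed is arbitrary). The subscript letters are just labels: `p` and `q` are the batch axes kept in the output, and `j` is the contracted axis. As a cross-check, the same result can be obtained by moving the batch axes in front of the matrix axes and using `np.matmul`, which broadcasts over leading dimensions.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((10, 5, 3, 6))
b = rng.standard_normal((10, 3, 7, 6))

# p (10) and q (6) are batch axes; j (3) is summed over
c = np.einsum('pijq,pjkq->pikq', a, b)
print(c.shape)  # (10, 5, 7, 6)

# Cross-check with np.matmul: move the size-6 axis in front of the matrix axes
a_t = np.moveaxis(a, -1, 1)               # (10, 6, 5, 3)
b_t = np.moveaxis(b, -1, 1)               # (10, 6, 3, 7)
c_matmul = np.moveaxis(a_t @ b_t, 1, -1)  # (10, 6, 5, 7) -> (10, 5, 7, 6)
print(np.allclose(c, c_matmul))  # True
```

`einsum` states the contraction explicitly, whereas the `matmul` route needs the batch axes rearranged first.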

# Example

```python
import numpy as np

def func(x, y):
    """Arguments x and y have the same shape (2, ).
    Return a (2, 2) matrix.
    """
    z = x * y
    return np.array([[x[0] ** 2, x[0] * x[1]], 
                     [np.sin(x[1]), x[0] + x[1]]]) + np.outer(z, z)
```

Now, if the inputs `x` and `y` are of shape `(100, 2)`, how do we batch over the first dimension and return a tensor of shape `(100, 2, 2)`?

Even more complicated: if the inputs `x` and `y` are of shapes `(100, 2)` and `(300, 2)`, respectively, how do we loop over all pairs of the 100 and 300 rows and return a tensor of shape `(100, 300, 2, 2)`?

Good luck with Matlab.

How do we do these in NumPy? For example, for the first case:

```python
import numpy as np

np_vectorized_func = np.vectorize(func, signature='(n),(n)->(n,n)')

np_vectorized_func(np.ones((5, 2)), np.ones((5, 2))).shape  # (5, 2, 2)
```
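For the second case, `np.vectorize` with a `signature` broadcasts the non-core (loop) dimensions like any other NumPy function, so inserting singleton axes is enough. A sketch, assuming random inputs (the seed is arbitrary):

```python
import numpy as np

def func(x, y):
    """Arguments x and y have the same shape (2, ). Return a (2, 2) matrix."""
    z = x * y
    return np.array([[x[0] ** 2, x[0] * x[1]],
                     [np.sin(x[1]), x[0] + x[1]]]) + np.outer(z, z)

np_vectorized_func = np.vectorize(func, signature='(n),(n)->(n,n)')

rng = np.random.default_rng(0)
x = rng.standard_normal((100, 2))
y = rng.standard_normal((300, 2))

# Loop dimensions (100, 1) and (1, 300) broadcast to (100, 300);
# the trailing core dimension (n,) = (2,) is what func receives.
out = np_vectorized_func(x[:, None, :], y[None, :, :])
print(out.shape)  # (100, 300, 2, 2)
```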

Cool! NumPy has a concise way to express the vectorisation.

However, please note that `np.vectorize` is merely syntactic sugar for a Python loop: `func` is still called once per row. It is **not vectorisation at the computation level**.
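One way to see the loop is to count how often `func` is called. The sketch below adds a hypothetical call counter (`calls`, not part of the original `func`) and compares the result with an explicit Python loop:

```python
import numpy as np

calls = {"n": 0}

def func(x, y):
    """Same func as above, plus a counter of how many times it is called."""
    calls["n"] += 1
    z = x * y
    return np.array([[x[0] ** 2, x[0] * x[1]],
                     [np.sin(x[1]), x[0] + x[1]]]) + np.outer(z, z)

np_vectorized_func = np.vectorize(func, signature='(n),(n)->(n,n)')

x = np.random.default_rng(0).standard_normal((5, 2))
out = np_vectorized_func(x, x)
calls_vectorized = calls["n"]
print(calls_vectorized)  # 5: one Python call per row of x

# The equivalent explicit Python loop
out_loop = np.stack([func(xi, yi) for xi, yi in zip(x, x)])
print(np.allclose(out, out_loop))  # True
```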

The JAX implementation:

```python
import jax
import jax.numpy as jnp

def func(x, y):
    """Arguments x and y are of the same shape (2, ).
    Return a (2, 2) matrix.
    """
    z = x * y
    return jnp.array([[x[0] ** 2, x[0] * x[1]], 
                      [jnp.sin(x[1]), x[0] + x[1]]]) + jnp.outer(z, z)

jax_vectorized_func = jax.vmap(func, in_axes=[0, 0])

jax_vectorized_func(jnp.ones((5, 2)), jnp.ones((5, 2))).shape  # (5, 2, 2)
```
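To see that `jax.vmap` vectorises at the computation level, one can inspect the traced program with `jax.make_jaxpr`: the batched function is built from array primitives acting on the whole `(n, 2)` input, with no loop construct. A sketch, redefining the JAX `func` so the block runs on its own, and checking the result against an explicit Python loop (inputs are arbitrary random arrays):

```python
import jax
import jax.numpy as jnp
import numpy as np

def func(x, y):
    z = x * y
    return jnp.array([[x[0] ** 2, x[0] * x[1]],
                      [jnp.sin(x[1]), x[0] + x[1]]]) + jnp.outer(z, z)

jax_vectorized_func = jax.vmap(func, in_axes=[0, 0])

x = jnp.asarray(np.random.default_rng(0).standard_normal((5, 2)))
y = jnp.asarray(np.random.default_rng(1).standard_normal((5, 2)))

# Print the traced program of the batched function
jaxpr = jax.make_jaxpr(jax_vectorized_func)(x, y)
print(jaxpr)

# Same values as applying func row by row in a Python loop
out = jax_vectorized_func(x, y)
out_loop = jnp.stack([func(x[i], y[i]) for i in range(x.shape[0])])
print(jnp.allclose(out, out_loop, atol=1e-5))
```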

Compare speed

```python
a = np.ones((10000, 2))
%timeit np_vectorized_func(a, a)
```

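```python
a = jnp.asarray(a)
# Warm-up call, so that one-off dispatch and compilation costs are not timed
_ = jax_vectorized_func(a, a)
# JAX dispatches asynchronously; block_until_ready waits for the actual result
%timeit jax_vectorized_func(a, a).block_until_ready()
```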

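`jax.vmap(func, in_axes) -> Callable` returns a function that takes the same arguments as `func`, but with extra batch axes.

- `func`: the function you want to vectorize.
- `in_axes`: a tuple/list with one entry per argument of `func`, giving the axis of that argument to vectorize over (`None` means the argument is not vectorized). A single integer applies the same axis to every argument; the default is `0`.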

Recall that our `func(x, y)` takes two `(2, )` arrays as inputs and returns a `(2, 2)` matrix. 

- To vectorize over `x: (n, 2)` and `y: (n, 2)` for some `n`, use `in_axes=[0, 0]`. The output has shape `(n, 2, 2)`.

- To vectorize over `x: (2, n)` and `y: (n, 2)` for some `n`, use `in_axes=[1, 0]`. The output has shape `(n, 2, 2)`.

- To vectorize over `x: (n, 2)` with a fixed `y: (2, )`, use `in_axes=[0, None]`. The output has shape `(n, 2, 2)`.
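The three cases above can be checked directly. A sketch with arbitrary random inputs, where `x_2n` holds the same data as `x_n2` stored column-wise, and `y_2` is a single shared `y`:

```python
import jax
import jax.numpy as jnp

def func(x, y):
    """x, y: shape (2,). Returns a (2, 2) matrix."""
    z = x * y
    return jnp.array([[x[0] ** 2, x[0] * x[1]],
                      [jnp.sin(x[1]), x[0] + x[1]]]) + jnp.outer(z, z)

n = 4
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x_n2 = jax.random.normal(key1, (n, 2))
y_n2 = jax.random.normal(key2, (n, 2))
x_2n = x_n2.T      # same data as x_n2, batch axis is now axis 1
y_2 = y_n2[0]      # a single (2,) vector shared by all batch elements

out1 = jax.vmap(func, in_axes=[0, 0])(x_n2, y_n2)    # map axis 0 of both
out2 = jax.vmap(func, in_axes=[1, 0])(x_2n, y_n2)    # map axis 1 of x, axis 0 of y
out3 = jax.vmap(func, in_axes=[0, None])(x_n2, y_2)  # y is not mapped
print(out1.shape, out2.shape, out3.shape)  # (4, 2, 2) (4, 2, 2) (4, 2, 2)
```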

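To vectorize over `x: (m, 2)` and `y: (n, 2)` for some `m, n`, nest two `vmap`s:

```python
jax.vmap(jax.vmap(func, in_axes=[None, 0]), in_axes=[0, None])
```

The inner `vmap` maps over `y` with `x` fixed, and the outer one maps over `x`. The output has shape `(m, n, 2, 2)`, with entry `[i, j]` equal to `func(x[i], y[j])`.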
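A quick check of the nested `vmap`, assuming random inputs with `m = 3` and `n = 4`. The outer `vmap` contributes the leading batch axis of the output, so the nesting order decides whether the result is `(m, n, 2, 2)` or `(n, m, 2, 2)`:

```python
import jax
import jax.numpy as jnp

def func(x, y):
    z = x * y
    return jnp.array([[x[0] ** 2, x[0] * x[1]],
                      [jnp.sin(x[1]), x[0] + x[1]]]) + jnp.outer(z, z)

# Inner vmap maps over y (x fixed); outer vmap maps over x
pairwise_func = jax.vmap(jax.vmap(func, in_axes=[None, 0]), in_axes=[0, None])

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (3, 2))  # m = 3
y = jax.random.normal(key_y, (4, 2))  # n = 4

out = pairwise_func(x, y)
print(out.shape)  # (3, 4, 2, 2)
print(jnp.allclose(out[1, 2], func(x[1], y[2]), atol=1e-6))  # entry [i, j] is func(x[i], y[j])

# Swapping the roles of the two vmaps swaps the batch axes
swapped = jax.vmap(jax.vmap(func, in_axes=[0, None]), in_axes=[None, 0])(x, y)
print(swapped.shape)  # (4, 3, 2, 2)
```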

# Exercise

Compute the Monte Carlo approximation $\mathbb{E}[g(X)] \approx \frac{1}{N}\sum^N_{i=1} g(X^i)$, where $X, X^1, X^2, \ldots \sim \mathrm{N}(0, I_2)$ are i.i.d. and $g(X) = \begin{bmatrix} \exp(X_1) \sin(X_2) \\ X_1 \, X_2 + X_1\end{bmatrix}$.

```python
N = 1000
key = ?
samples = jax.random.normal(?)

def g(x):
    return ?
    
vectorised_g = jax.vmap(?)
propagated_samples = vectorised_g(samples)
mean_g = jnp.mean(propagated_samples, axis=0)
```

## Solution

```python
N = 1000

key = jax.random.PRNGKey(999)
samples = jax.random.normal(key, shape=(N, 2))

def g(x):
    return jnp.array([jnp.exp(x[0]) * jnp.sin(x[1]), 
                      x[0] * x[1] + x[0]])
    
vectorised_g = jax.vmap(g, in_axes=[0])

propagated_samples = vectorised_g(samples)

mean_g = jnp.mean(propagated_samples, axis=0)
```
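As a sanity check, the exact value is available here: $X_1$ and $X_2$ are independent, $\sin$ is odd and $X_2$ is symmetric, so $\mathbb{E}[\exp(X_1)\sin(X_2)] = \mathbb{E}[\exp(X_1)]\,\mathbb{E}[\sin(X_2)] = 0$, and $\mathbb{E}[X_1 X_2 + X_1] = 0$. Hence $\mathbb{E}[g(X)] = (0, 0)$. In the sketch below (the sample sizes and seed are arbitrary choices), the estimate should move towards zero as $N$ grows, up to Monte Carlo noise:

```python
import jax
import jax.numpy as jnp

def g(x):
    return jnp.array([jnp.exp(x[0]) * jnp.sin(x[1]),
                      x[0] * x[1] + x[0]])

vectorised_g = jax.vmap(g, in_axes=[0])

key = jax.random.PRNGKey(999)
for N in [1_000, 10_000, 100_000]:
    samples = jax.random.normal(key, shape=(N, 2))
    mean_g = jnp.mean(vectorised_g(samples), axis=0)
    # The Monte Carlo error decays roughly like 1 / sqrt(N)
    print(N, mean_g)
```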

# Exercise

Consider a Matérn 3/2 covariance function $C\colon \mathbb{R}\times \mathbb{R} \to \mathbb{R}$ defined by

$$
C(t, t') = \sigma^2 \, \bigg(1 + \frac{\sqrt{3} \, \lvert t-t'\rvert}{\ell}\bigg) \, \exp\bigg(-\frac{\sqrt{3} \, \lvert t-t'\rvert}{\ell}\bigg)
$$

Suppose we now have $T$ data points $t_1, t_2, \ldots, t_T$. Compute the covariance matrix evaluated on the Cartesian product $(t_1, t_2, \ldots, t_T) \times (t_1, t_2, \ldots, t_T)$, that is,

$$
C_{1:T} = \begin{bmatrix} C(t_1, t_1) & C(t_1, t_2) & \cdots & C(t_1, t_T) \\
C(t_2, t_1) & C(t_2, t_2) & \cdots & C(t_2, t_T) \\
\vdots & \vdots & \ddots & \vdots \\
C(t_T, t_1) & C(t_T, t_2) & \cdots & C(t_T, t_T)\end{bmatrix}.
$$

```python
def cov_func(t1: float, t2: float, ell: float, sigma: float) -> float:
    return ?

vectorised_cov_func = jax.vmap(jax.vmap(?), ?)

ts = jnp.linspace(0.01, 1, 100)

ell, sigma = 0.1, 1.
cov_matrix = vectorised_cov_func(ts, ts, ell, sigma)

import matplotlib.pyplot as plt
plt.contourf(cov_matrix, levels=20)
```

## Solution

```python
import math

def cov_func(t1: float, t2: float, ell: float, sigma: float) -> float:
    z = math.sqrt(3) * jnp.abs(t1 - t2) / ell
    return sigma ** 2 * (1 + z) * jnp.exp(-z)

# Inner vmap maps over t2, outer vmap maps over t1, so entry [i, j] is C(t_i, t_j)
vectorised_cov_func = jax.vmap(jax.vmap(cov_func, 
                                        in_axes=[None, 0, None, None]), 
                               in_axes=[0, None, None, None])

# or equivalently (decorators apply bottom-up, so the lower one is the inner vmap)
# from functools import partial

# @partial(jax.vmap, in_axes=[0, None, None, None])
# @partial(jax.vmap, in_axes=[None, 0, None, None])
# def cov_func(...):
#     ...

ts = jnp.linspace(0.01, 1, 100)

ell, sigma = 0.1, 1.
cov_matrix = vectorised_cov_func(ts, ts, ell, sigma)

import matplotlib.pyplot as plt
plt.contourf(cov_matrix, levels=20)
```
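A sketch of a few sanity checks on the result, assuming the same `ts`, `ell` and `sigma` as above. The `vmap` matrix should agree with a NumPy broadcasting computation, be symmetric, and have $\sigma^2$ on the diagonal because $C(t, t) = \sigma^2$:

```python
import math
import jax
import jax.numpy as jnp
import numpy as np

def cov_func(t1, t2, ell, sigma):
    z = math.sqrt(3) * jnp.abs(t1 - t2) / ell
    return sigma ** 2 * (1 + z) * jnp.exp(-z)

# Entry [i, j] = cov_func(t1s[i], t2s[j])
vectorised_cov_func = jax.vmap(jax.vmap(cov_func, in_axes=[None, 0, None, None]),
                               in_axes=[0, None, None, None])

ts = jnp.linspace(0.01, 1, 100)
ell, sigma = 0.1, 1.
cov_matrix = vectorised_cov_func(ts, ts, ell, sigma)

# Reference computed with NumPy broadcasting in float64
ts_np = np.asarray(ts, dtype=np.float64)
z = math.sqrt(3) * np.abs(ts_np[:, None] - ts_np[None, :]) / ell
reference = sigma ** 2 * (1 + z) * np.exp(-z)

print(np.allclose(cov_matrix, reference, atol=1e-5))        # agrees with broadcasting
print(bool(jnp.allclose(cov_matrix, cov_matrix.T)))           # symmetric
print(bool(jnp.allclose(jnp.diag(cov_matrix), sigma ** 2)))   # C(t, t) = sigma^2
```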

Speed comparison

```python
# NumPy implementation 1. Naive implementation with two loops. Do not use this in practice.

def np_cov_func(t1, t2):
    cov_matrix = np.zeros((t1.size, t2.size))
    for i, s in enumerate(t1):
        for j, t in enumerate(t2):
            z = math.sqrt(3) * np.abs(s - t) / ell
            cov_matrix[i, j] = sigma ** 2 * (1 + z) * np.exp(-z)
    return cov_matrix

ts_np = np.asarray(ts)
%timeit np_cov_func(ts_np, ts_np)
```

```python
# NumPy implementation 2. Using broadcasting.
# This works only in limited cases, such as this exercise, where the kernel
# can be written directly with array operations that broadcast.

def np_cov_func(t1, t2):
    z = math.sqrt(3) * np.abs(t1[:, None] - t2[None, :]) / ell
    return sigma ** 2 * (1 + z) * np.exp(-z)

ts_np = np.asarray(ts)
%timeit np_cov_func(ts_np, ts_np)
```

```python
# NumPy implementation 3. Using scipy's cdist, which expects 2D inputs of shape (T, 1)

import scipy.spatial

def np_cov_func(t1, t2):
    r = scipy.spatial.distance.cdist(t1, t2, 'euclidean')
    z = math.sqrt(3) * r / ell
    return sigma ** 2 * (1 + z) * np.exp(-z)

ts_np = np.asarray(ts).reshape(-1, 1)
%timeit np_cov_func(ts_np, ts_np)
```

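```python
# `jax.vmap` does not compile anything by itself: the vectorised function runs
# op by op. Wrapping it in `jax.jit` compiles the whole computation with XLA,
# which can fuse operations, so the jitted version is usually faster.
f = jax.jit(vectorised_cov_func)

# The first call triggers compilation; exclude it from the timing
f(ts, ts, ell, sigma)

%timeit f(ts, ts, ell, sigma).block_until_ready()
```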

```python
# Without jax.jit: each batched primitive is dispatched separately
vectorised_cov_func(ts, ts, ell, sigma)

%timeit vectorised_cov_func(ts, ts, ell, sigma).block_until_ready()
```

There is also `jax.pmap`, which is for parallelisation across multiple devices, for example, multiple GPUs/TPUs. See details at https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html.
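A minimal sketch of `jax.pmap` applied to the same kind of `func`, assuming a single host. Unlike `vmap`, the mapped (leading) axis of `pmap` is laid out over devices, so its size may not exceed `jax.local_device_count()`. On a plain CPU machine there is one device, so the example runs but gives no speed-up:

```python
import jax
import jax.numpy as jnp

def func(x, y):
    z = x * y
    return jnp.array([[x[0] ** 2, x[0] * x[1]],
                      [jnp.sin(x[1]), x[0] + x[1]]]) + jnp.outer(z, z)

n_devices = jax.local_device_count()  # 1 on a plain CPU setup

# Each slice x[i], y[i] of the leading axis is sent to a separate device
x = jnp.ones((n_devices, 2))
out = jax.pmap(func)(x, x)
print(out.shape)  # (n_devices, 2, 2)
```

Within each device, `vmap` remains the tool for batching; `pmap` only decides which device each slice runs on.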