# Vectorizing and Parallelizing

In [None]:
import jax
import jax.numpy as jnp

## Vectorization using `vmap`

#### Simple function to calculate the cosine of the angle between two vectors

In [None]:
def vector_cos(x, y):
    """Return the cosine of the angle between the vectors ``x`` and ``y``.

    Computed as the dot product divided by the product of the two
    Euclidean norms. Intended for a pair of 1-D vectors of equal length.
    """
    inner = jnp.dot(x, y)
    norm_product = jnp.linalg.norm(x) * jnp.linalg.norm(y)
    return inner / norm_product

In [None]:
# Two orthogonal unit vectors, plus one lying at 45 degrees to x.
x = jnp.array([1.0, 0.0, 0.0])
y = jnp.array([0.0, 1.0, 0.0])
p = jnp.array([1.0, 1.0, 0.0])

# Evaluate both pairs first, then report them.
cos_xy = vector_cos(x, y)
cos_xp = vector_cos(x, p)
print(f'The cosine of the angle between x, y is: {cos_xy}')
print(f'The cosine of the angle between x, p is: {cos_xp:1.2f}')

In [None]:
%timeit vector_cos(x, y)

#### What if we want to calculate the cosine of the angle between multiple vector pairs?

In [None]:
# Fixed seed so the notebook draws the same vectors on every run.
seed = 42
random_key = jax.random.PRNGKey(seed)

# Derive one subkey per batch.
key1, key2 = jax.random.split(random_key, num=2)

# Two batches of 20000 three-component vectors, one vector per row.
batch_shape = (20000, 3)
X = jax.random.uniform(key1, shape=batch_shape)
Y = jax.random.uniform(key2, shape=batch_shape)

#### This is where `vmap` becomes useful: it vectorizes a function to work with "vectors" of its arguments

In [None]:
# Lift vector_cos so that it maps over axis 0 of both arguments
# (in_axes=0 is vmap's default, spelled out here for clarity).
vector_cos_vmapped = jax.vmap(vector_cos, in_axes=0)

# One cosine per row pair (X[i], Y[i]).
xy_cos = vector_cos_vmapped(X, Y)
print(f'The shape of the resulting array is: {xy_cos.shape}')

#### <mark>Hands-on</mark>: time the computation and time it also after jitting it

### `vmap` will assume that the first dimension of the arguments is the one to vectorize. The behavior can be explicitly controlled using `in_axes`.

In [None]:
# NOTE(review): key1 was already used above to draw X; JAX's guidance is not to
# reuse a key after sampling from it. Consider splitting from a fresh key instead.
key3, key4 = jax.random.split(key1, num=2)

# P is laid out transposed relative to Z: its 20000 vectors run along axis 1.
Z = jax.random.uniform(key3, shape=(20000, 3))
P = jax.random.uniform(key4, shape=(3, 20000))

In [None]:
# in_axes=(0, 1): map over axis 0 of the first argument and axis 1 of the second.
vector_cos_vmapped2 = jax.vmap(
    vector_cos,
    in_axes=(0, 1),
)

In [None]:
# Z is batched along axis 0 and P along axis 1; the result has one cosine per pair.
zp_cos = vector_cos_vmapped2(Z, P)
# Report the shape of the array just computed (previously printed xy_cos by mistake).
print(f'The shape of the resulting array is: {zp_cos.shape}')

## Parallelization using `pmap`

#### Up to now everything used a single GPU to carry out the computations. We can use `pmap` to parallelize them across devices

In [None]:
key5, key6 = jax.random.split(key3, num=2)

# One row of 2000000 values per visible GPU; the leading axis is the one
# pmap will distribute.
num_gpus = len(jax.devices('gpu'))
per_device_shape = (num_gpus, 2000000)
L = jax.random.uniform(key5, shape=per_device_shape)
M = jax.random.uniform(key6, shape=per_device_shape)

In [None]:
# The leading dimension equals the number of GPUs used to build L.
print(L.shape)

In [None]:
# pmap maps vector_cos over the leading axis of its arguments, placing each
# slice on a separate device.
vector_cos_pmapped = jax.pmap(vector_cos)

In [None]:
# Each row pair (L[i], M[i]) is handled by its own device; the result holds
# one cosine per row.
vector_cos_pmapped(L, M)

#### The idea is similar to `vmap` but now the computation is spread across devices. We can even combine with `vmap`.

In [None]:
# Outer pmap spreads chunks across devices; the inner vmap vectorizes within
# each chunk.
vector_cos_pvmapped = jax.pmap(vector_cos_vmapped)

# pmap needs the leading axis to fit on the available devices, so split the
# batch by the actual GPU count instead of a hard-coded 4 (the same query is
# used to size L and M). X and Y have 20000 rows, which must divide evenly by
# this count for the reshape to succeed.
n_devices = len(jax.devices('gpu'))

XX = X.reshape(n_devices, -1, 3)
YY = Y.reshape(n_devices, -1, 3)

xy_cos_p = vector_cos_pvmapped(XX, YY)

#### Check the result

In [None]:
# Undo the per-device split so the pmap+vmap result lines up with the
# single-device vmap result, then compare element-wise.
flat_cos_p = xy_cos_p.reshape(20000)
jnp.allclose(flat_cos_p, xy_cos)

#### <mark>Hands-on</mark>: time the computation and run it for an increasing number of vectors. Also check the device on which the result resides