#### Vectorization and Parallelization
1. JAX handles vectorization and parallelization with two transformations: **`vmap`** for vectorizing a function over a batch axis, and **`pmap`** for running it in parallel across devices.
2. To illustrate these functions, we start with a typical dot product example using **`numpy`**.

In [52]:
import numpy as np

# What if we want to do dot products for a batch of vectors?
array1 = np.stack([np.array([1, 2, 3, 4]) for i in range(5)])
array2 = np.stack([np.array([5, 6, 7, 8]) for i in range(5)])

# We can use `einsum`
print(np.einsum('ij,ij->i', array1, array2))

[70 70 70 70 70]
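
To make the batched operation concrete, here is a minimal sketch comparing three NumPy ways of computing the row-wise dot products. The variable names `looped`, `summed`, and `batched` are illustrative and do not come from the original notebook.

```python
import numpy as np

array1 = np.stack([np.array([1, 2, 3, 4]) for _ in range(5)])
array2 = np.stack([np.array([5, 6, 7, 8]) for _ in range(5)])

# Explicit Python loop: one dot product per pair of rows
looped = np.array([np.dot(a, b) for a, b in zip(array1, array2)])

# Elementwise multiply, then sum over the feature axis
summed = (array1 * array2).sum(axis=1)

# einsum: multiply matching (i, j) entries and sum over j for each row i
batched = np.einsum('ij,ij->i', array1, array2)

print(looped)
print(summed)
print(batched)
```

All three give the same result. The loop is the easiest to read, while the other two push the iteration into NumPy itself.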


##### What is vmap
1. In the example above, we used **`einsum`**, a function that operates on whole batches of data, to get the dot product of each pair of corresponding rows in the two matrices.
2. The same can be done with another transformation, **`vmap`**. **`vmap`** takes a function written for a single example, together with the axes of the inputs and outputs to map over, and returns a vectorized function that operates on the whole batch. The syntax looks like this: **`vmap(function, in_axes, out_axes, ...)`**

In [53]:
import jax
import jax.numpy as jnp
from jax import random
from jax import make_jaxpr
from jax import config  # recent JAX releases no longer support `from jax.config import config`
from jax import grad, vmap, pmap, jit

# A batch of vectors
array1 = jnp.stack([jnp.array([1, 2, 3, 4]) for i in range(5)])
array2 = jnp.stack([jnp.array([5, 6, 7, 8]) for i in range(5)])

# Single-example operation that vmap will vectorize
def dot_product(array1, array2):
    """Performs dot product on two jax arrays."""
    return jnp.dot(array1, array2)

# Vmapped function
func = vmap(dot_product, in_axes=(0,0), out_axes=(0))

# Further transformation with jit
jitted_func = jit(func)

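# The first jitted call (%time) includes compilation; later calls (%timeit) reuse the compiled code
%timeit func(array1, array2)
%time res2 = jitted_func(array1, array2)
%timeit jitted_func(array1, array2)

# Assignments made inside %timeit are discarded, so compute the results explicitly
res1 = func(array1, array2)
res3 = jitted_func(array1, array2)

print(res1)
print(res3)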


351 µs ± 28.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
CPU times: user 23.7 ms, sys: 0 ns, total: 23.7 ms
Wall time: 23.2 ms
4.98 µs ± 478 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
[70 70 70 70 70]
[70 70 70 70 70]
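
The `in_axes` argument does not have to be `0` for every input. As a minimal sketch (the arrays `batch` and `single` and the result names are made up for illustration), `None` tells **`vmap`** to pass an argument unchanged to every call, and a nonzero integer maps over a different axis:

```python
import jax.numpy as jnp
from jax import vmap

def dot_product(x, y):
    """Dot product of two 1-D arrays."""
    return jnp.dot(x, y)

batch = jnp.arange(20).reshape(5, 4)   # five vectors of length 4
single = jnp.array([5, 6, 7, 8])       # one vector shared by every call

# Map over axis 0 of the first argument; pass the second argument as-is
rows_vs_single = vmap(dot_product, in_axes=(0, None))(batch, single)

# The same vectors stored as columns: map over axis 1 instead
batch_t = batch.T                      # shape (4, 5)
cols_vs_single = vmap(dot_product, in_axes=(1, None))(batch_t, single)

print(rows_vs_single.shape)   # (5,)
print(jnp.array_equal(rows_vs_single, cols_vs_single))
```

Both calls compute the same five dot products as `batch @ single`; only the layout of the batched input differs.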


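3. Now, let's look at what is happening in the backend. Notice that the **`dimension_numbers`** of the **`dot_general`** operation now specify batch axes as well as contraction axes: the first pair, `((1,), (1,))`, lists the axes that are summed over in each input, and the second pair, `((0,), (0,))`, lists the batch axes that **`vmap`** mapped over.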

In [54]:
make_jaxpr(vmap(dot_product, in_axes=(0,0), out_axes=(0)))(array1, array2)

{ lambda ; a:i32[5,4] b:i32[5,4]. let
    c:i32[5] = dot_general[
      dimension_numbers=(((1,), (1,)), ((0,), (0,)))
      precision=None
      preferred_element_type=None
    ] a b
  in (c,) }
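
As a rough check of how to read the jaxpr above, the same `dimension_numbers` can be passed to `jax.lax.dot_general` directly. This is only a sketch; the names `contracting_dims`, `batch_dims`, `direct`, and `via_vmap` are illustrative.

```python
import jax.numpy as jnp
from jax import lax, vmap

array1 = jnp.stack([jnp.array([1, 2, 3, 4]) for _ in range(5)])
array2 = jnp.stack([jnp.array([5, 6, 7, 8]) for _ in range(5)])

# Same structure as in the jaxpr: (contracting axes, batch axes),
# each given as (axes of the first input, axes of the second input)
contracting_dims = ((1,), (1,))
batch_dims = ((0,), (0,))

direct = lax.dot_general(
    array1, array2, dimension_numbers=(contracting_dims, batch_dims)
)
via_vmap = vmap(jnp.dot)(array1, array2)  # in_axes defaults to 0

print(direct)
print(via_vmap)
```

Both calls should print the same batch of dot products, which suggests that **`vmap`** here amounts to adding a batch axis to a single `dot_general` call rather than looping over the rows.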