# 03 - Automatic vectorization in JAX

Let's look into how vectorization works in JAX today.

## Manual vectorization

Here's a simple function that convolves two one-dimensional vectors: it slides a three-element window over `x` and takes the window's dot product with `w` at each position.

In [2]:
import jax
import jax.numpy as jnp

In [3]:
x = jnp.arange(5)
x

DeviceArray([0, 1, 2, 3, 4], dtype=int32)

In [4]:
w = jnp.array([2., 3., 4.])
w

DeviceArray([2., 3., 4.], dtype=float32)

In [5]:
def convolve(x, w):
    output = []
    for i in range(1, len(x)-1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

In [6]:
convolve(x,w)

DeviceArray([11., 20., 29.], dtype=float32)
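Strictly speaking, the loop never flips `w`, so it computes a sliding dot product, i.e. a cross-correlation in `"valid"` mode (only positions where the window fits entirely inside `x`). As a quick sanity check, here is a sketch comparing the loop against `jnp.correlate`, and against `jnp.convolve` with the kernel reversed:

```python
import jax.numpy as jnp

def convolve(x, w):
    output = []
    for i in range(1, len(x)-1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

loop_result = convolve(x, w)
# Cross-correlation: slide w over x without flipping it.
correlate_result = jnp.correlate(x, w, mode="valid")
# True convolution flips the kernel, so reverse w first to get the same numbers.
convolve_result = jnp.convolve(x, w[::-1], mode="valid")

print(loop_result)       # [11. 20. 29.]
print(correlate_result)  # [11. 20. 29.]
```

The library calls are handy for checking a hand-written loop, but the loop version is kept below because it makes the batching problem easy to see.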

Now, say we want to apply this to a batch of vectors `xs` and a batch of weights `ws`.

In [7]:
xs = jnp.stack([x,x,x])
ws = jnp.stack([w,w,w])

In [8]:
xs

DeviceArray([[0, 1, 2, 3, 4],
             [0, 1, 2, 3, 4],
             [0, 1, 2, 3, 4]], dtype=int32)

In [9]:
ws

DeviceArray([[2., 3., 4.],
             [2., 3., 4.],
             [2., 3., 4.]], dtype=float32)

The naive option would be to loop over the batch in Python.

In [20]:
def manual_batch_convolve(xs, ws):
    output = []
    for i in range(xs.shape[0]):
        output.append(convolve(xs[i], ws[i]))
    return jnp.stack(output)

In [21]:
manual_batch_convolve(xs, ws)

DeviceArray([[11., 20., 29.],
             [11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

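It gives the correct result, but it isn't efficient: every batch element goes through its own Python loop and its own series of small array operations.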
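Doing better by hand means rewriting `convolve` so that each step operates on the whole batch at once. Below is a minimal sketch (the name `manually_vectorized_convolve` is our own, not a JAX function) that still loops over output positions, but replaces the per-example loop with batched array operations:

```python
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
xs = jnp.stack([x, x, x])  # shape (3, 5): batch on axis 0
ws = jnp.stack([w, w, w])  # shape (3, 3): batch on axis 0

def manually_vectorized_convolve(xs, ws):
    # Loop only over output positions; each step handles the whole batch.
    output = []
    for i in range(1, xs.shape[-1] - 1):
        window = xs[:, i-1:i+2]                      # shape (batch, 3)
        output.append(jnp.sum(window * ws, axis=1))  # shape (batch,)
    # Stack positions along axis 1 so the result is (batch, positions).
    return jnp.stack(output, axis=1)

result = manually_vectorized_convolve(xs, ws)
print(result)
# [[11. 20. 29.]
#  [11. 20. 29.]
#  [11. 20. 29.]]
```

This works, but it means rewriting the function's logic and keeping track of which axis is the batch axis by hand, which is exactly the bookkeeping `vmap` automates.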

## Automatic vectorization

In [22]:
auto_batch_convolve = jax.vmap(convolve)

In [23]:
auto_batch_convolve(xs, ws)

DeviceArray([[11., 20., 29.],
             [11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

💥

`vmap` works in a similar way to `jit`: it traces the function, and then automatically adds a batch dimension, by default at the beginning of each input.
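One way to see the tracing at work: a Python side effect inside the function runs only once, not once per batch element, and the shape it observes is the per-example shape without the batch axis. The list `seen_shapes` and the function `sum_of_squares` below are hypothetical helpers for illustration only:

```python
import jax
import jax.numpy as jnp

seen_shapes = []  # hypothetical helper: records shapes seen while tracing

def sum_of_squares(x):
    # This Python statement runs while JAX traces the function, not per example.
    seen_shapes.append(x.shape)
    return jnp.sum(x ** 2)

xs = jnp.arange(15.0).reshape(3, 5)  # batch of 3 vectors of length 5
out = jax.vmap(sum_of_squares)(xs)

print(seen_shapes)  # [(5,)]
print(out.shape)    # (3,)
```

The function body was written for a single vector of length 5; the batch axis only reappears in the output.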

By default, JAX assumes the batch dimension is the first axis of every input and places it first in the output. If that's not the case, the `in_axes` and `out_axes` arguments specify where the batch dimension is located.

In [24]:
auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)

xst = jnp.transpose(xs)
wst = jnp.transpose(ws)

In [25]:
auto_batch_convolve_v2(xst, wst)

DeviceArray([[11., 11., 11.],
             [20., 20., 20.],
             [29., 29., 29.]], dtype=float32)
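`in_axes` can also give a different axis for each positional argument. A sketch, assuming a batch of 4 vectors stored with the batch first and the matching weights stored with the batch along axis 1 (`ws_t` and `mixed_batch_convolve` are illustrative names):

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
    output = []
    for i in range(1, len(x)-1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

xs = jnp.stack([x] * 4)            # shape (4, 5): batch on axis 0
ws_t = jnp.stack([w] * 4, axis=1)  # shape (3, 4): batch on axis 1

# One in_axes entry per positional argument: map xs over axis 0 and ws_t over
# axis 1. out_axes=0 (the default) puts the batch axis first in the result.
mixed_batch_convolve = jax.vmap(convolve, in_axes=(0, 1), out_axes=0)
result = mixed_batch_convolve(xs, ws_t)
print(result.shape)  # (4, 3)
```

Both inputs must have the same size along their mapped axes (4 here), otherwise `vmap` raises an error.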

`vmap` also supports the case where only one of the arguments is batched, for example convolving a single set of weights `w` with a batch of vectors. Passing `None` in `in_axes` marks an argument as unbatched.

In [26]:
auto_batch_convolve_v3 = jax.vmap(convolve, in_axes=(0, None))

In [27]:
auto_batch_convolve_v3(xs, w)

DeviceArray([[11., 20., 29.],
             [11., 20., 29.],
             [11., 20., 29.]], dtype=float32)
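`None` can sit in either position. The sketch below does the reverse: one vector `x` shared across a batch of two different weight vectors (`ws2` and `batch_over_weights` are illustrative names). It also checks that this matches broadcasting `x` to an explicit batch first:

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
    output = []
    for i in range(1, len(x)-1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
ws2 = jnp.stack([w, 2 * w])  # two different weight vectors, shape (2, 3)

# None: reuse the same x for every batch element; 0: map over the rows of ws2.
batch_over_weights = jax.vmap(convolve, in_axes=(None, 0))
result = batch_over_weights(x, ws2)
print(result)
# [[11. 20. 29.]
#  [22. 40. 58.]]

# Same numbers as copying x into an explicit batch and mapping over both inputs.
explicit = jax.vmap(convolve)(jnp.broadcast_to(x, (2, 5)), ws2)
```

Using `None` avoids materializing the copies of `x` that the explicit version needs.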

## Combining transformations

My favorite part of JAX is that these transformations compose: you can jit a vmapped function, vmap a jitted function, and so on.

In [28]:
jitted_batch_convolve = jax.jit(auto_batch_convolve)

In [30]:
jitted_batch_convolve(xs, ws)

DeviceArray([[11., 20., 29.],
             [11., 20., 29.],
             [11., 20., 29.]], dtype=float32)
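The cell above jits a vmapped function. The other order works too, and `jax.grad` composes in the same way. A sketch, where `loss` is a hypothetical scalar function chosen only to have something to differentiate, computing per-example gradients with respect to the weights:

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
    output = []
    for i in range(1, len(x)-1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
xs = jnp.stack([x, x, x])
ws = jnp.stack([w, w, w])

# The other order: vmap over a jitted per-example function.
vmap_of_jit = jax.vmap(jax.jit(convolve))
batched = vmap_of_jit(xs, ws)

def loss(w, x):
    # Hypothetical scalar loss: sum of the convolution outputs.
    return jnp.sum(convolve(x, w))

# grad differentiates w.r.t. the first argument (w); vmap maps over both
# arguments on axis 0; jit compiles the whole composed function.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss)))(ws, xs)
print(per_example_grads)
# [[3. 6. 9.]
#  [3. 6. 9.]
#  [3. 6. 9.]]
```

Each gradient entry is the sum of the `x` values that meet that weight as the window slides (for the first weight, 0 + 1 + 2 = 3).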

This is all for today. Tomorrow, we'll look into (pseudo) random numbers.