# Manual Vectorization

In [2]:
import jax
import jax.numpy as jnp

# Example signal: the integers 0 through 4.
x = jnp.arange(5)
# Three weights, one for each element of the 3-wide window used by convolve.
w = jnp.array([2., 3., 4.])

def convolve(x, w):
    """Dot a sliding 3-element window of ``x`` with the weights ``w``.

    Every interior index ``c`` of ``x`` (the first and last positions are
    skipped) contributes one value: ``jnp.dot(x[c-1:c+2], w)``.

    Args:
        x: Array indexed along its first axis; ``len(x)`` sets the number
            of windows.
        w: Weights dotted against each 3-element window.

    Returns:
        A ``jnp`` array holding one entry per interior position of ``x``.
    """
    interior = range(1, len(x) - 1)
    return jnp.array([jnp.dot(x[c - 1:c + 2], w) for c in interior])

# Apply the convolution to the example signal; the cell displays the result.
convolve(x, w)

DeviceArray([11., 20., 29.], dtype=float32)

In [3]:
# Build a batch of two examples by stacking copies of x and w along a new leading axis.
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

In [4]:
def manually_batched_convolve(xs, ws):
    """Apply ``convolve`` to each example of a batch, one example at a time.

    Args:
        xs: Batch of inputs; the leading axis indexes examples.
        ws: Batch of weights, indexed by the same example index as ``xs``.

    Returns:
        The per-example ``convolve`` results stacked along a new leading axis.
    """
    batch_size = xs.shape[0]
    per_example = [convolve(xs[b], ws[b]) for b in range(batch_size)]
    return jnp.stack(per_example)

# Run the per-example Python loop over the stacked batch.
manually_batched_convolve(xs, ws)

DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

In [5]:
def manually_vectorized_convolve(xs, ws):
    """Batched 3-tap convolution that handles the whole batch per window position.

    For each interior position ``c`` along the last axis of ``xs``, the
    window ``xs[:, c-1:c+2]`` is multiplied elementwise by ``ws`` and summed
    over axis 1, giving one value per example.

    Args:
        xs: 2-D batch of inputs, indexed as ``xs[example, position]``.
        ws: Weights multiplied against each ``(batch, 3)`` window.

    Returns:
        The per-position results stacked along axis 1, i.e. one row per
        example and one column per interior position.
    """
    length = xs.shape[-1]
    columns = [
        jnp.sum(xs[:, c - 1:c + 2] * ws, axis=1)
        for c in range(1, length - 1)
    ]
    return jnp.stack(columns, axis=1)

# Same batch, now processed with whole-batch array operations per window position.
manually_vectorized_convolve(xs, ws)

DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

# Automatic Vectorization

In [6]:
# With no in_axes given, jax.vmap maps convolve over axis 0 of every argument.
auto_batch_convolve = jax.vmap(convolve)

auto_batch_convolve(xs, ws)

DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

If the batch dimension is not the first, you may use the <code>in_axes</code> and <code>out_axes</code> arguments to specify the location of the batch dimension in inputs and outputs. Each may be a single integer if the batch axis is the same for all inputs and outputs, or a list with one entry per input or output otherwise.

In [7]:
# in_axes=1 / out_axes=1: the batch lives on axis 1 of both inputs and of the output.
auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)

# Transpose so the batch axis moves from axis 0 to axis 1.
xst = jnp.transpose(xs)
wst = jnp.transpose(ws)

auto_batch_convolve_v2(xst, wst)

DeviceArray([[11., 11.],
             [20., 20.],
             [29., 29.]], dtype=float32)

<code>jax.vmap</code> also supports the case where only one of the arguments is batched: for example, if you would like to convolve a single set of weights <code>w</code> with a batch of vectors <code>x</code>. In this case, set the <code>in_axes</code> entry for the unbatched argument to <code>None</code>:

In [8]:
# Map over axis 0 of the first argument; None leaves the second argument (w) unbatched.
batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])

batch_convolve_v3(xs, w)

DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

# Combining transformations

In [9]:
# Transformations compose: wrap the vmapped function in jax.jit.
jitted_batch_convolve = jax.jit(auto_batch_convolve)

jitted_batch_convolve(xs, ws)

DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)