This is adapted from https://wangkuiyi.github.io/vmap.html

Consider a function $f$ that processes a single vector $x$. We can batch several vectors into a matrix, either as rows or as columns. In the first case we vectorize along the first axis (`in_axes=0`); in the second, along the second axis (`in_axes=1`):

In [1]:
import jax
import jax.numpy as jnp

def f(x):
    return x.sum()

key = jax.random.PRNGKey(0)  # PRNG key derived from seed 0
v = jax.random.normal(key, shape=(3, 2))  # 3x2 array of standard-normal samples
print(f"Batched vectors: v =\n{v}\n")

fv_rows = jax.vmap(f, in_axes=0)
fv_cols = jax.vmap(f, in_axes=1)

print("Vectorization along rows results in: ", fv_rows(v))
print("Vectorization along columns results in: ", fv_cols(v))

Batched vectors: v =
[[ 1.6226422   2.0252647 ]
 [-0.43359444 -0.07861735]
 [ 0.1760909  -0.97208923]]

Vectorization along rows results in:  [ 3.6479068  -0.5122118  -0.79599833]
Vectorization along columns results in:  [1.3651385 0.9745582]
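Conceptually, `jax.vmap` behaves like a Python loop that applies `f` to each slice along the mapped axis and stacks the results. JAX does not actually loop: it traces `f` once and runs a single batched computation, but the values should match. The minimal sketch below checks this reading against the `in_axes=0` and `in_axes=1` results; the names `rows_loop` and `cols_loop` are illustrative only.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x.sum()

v = jax.random.normal(jax.random.PRNGKey(0), shape=(3, 2))

# Loop-based equivalents: apply f to every row / every column, then stack.
rows_loop = jnp.stack([f(v[i, :]) for i in range(v.shape[0])])  # shape (3,)
cols_loop = jnp.stack([f(v[:, j]) for j in range(v.shape[1])])  # shape (2,)

rows_vmap = jax.vmap(f, in_axes=0)(v)
cols_vmap = jax.vmap(f, in_axes=1)(v)

print(jnp.allclose(rows_loop, rows_vmap))
print(jnp.allclose(cols_loop, cols_vmap))

# For this particular f, mapping over rows sums along axis 1, and vice versa.
print(jnp.allclose(rows_vmap, v.sum(axis=1)))
print(jnp.allclose(cols_vmap, v.sum(axis=0)))
```

Note the apparent swap in the last two lines: mapping over axis 0 means each call to `f` sees one row, so the reduction inside `f` runs along axis 1 of `v`.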


JAX arrays are immutable, so calling a vectorized function always produces a *new* array. The `out_axes` parameter controls where the batch dimension is placed in that output array. To stack the per-example results as rows we set `out_axes=0`; to stack them as columns we use `out_axes=1`:

In [2]:
def g(x):
    return x

gv_rows = jax.vmap(g, out_axes=0)
gv_cols = jax.vmap(g, out_axes=1)
print(f"Batched vectors: v =\n{v}\n")
print("Packing along rows:\n", gv_rows(v))
print("Packing along columns:\n", gv_cols(v))

Batched vectors: v =
[[ 1.6226422   2.0252647 ]
 [-0.43359444 -0.07861735]
 [ 0.1760909  -0.97208923]]

Packing along rows:
 [[ 1.6226422   2.0252647 ]
 [-0.43359444 -0.07861735]
 [ 0.1760909  -0.97208923]]
Packing along columns:
 [[ 1.6226422  -0.43359444  0.1760909 ]
 [ 2.0252647  -0.07861735 -0.97208923]]
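Because `g` is the identity, the effect of `out_axes` is easy to state exactly: `out_axes=0` reproduces `v`, while `out_axes=1` yields its transpose. The short sketch below checks both claims, and also that leaving out `in_axes` is the same as passing its default value, `in_axes=0`.

```python
import jax
import jax.numpy as jnp

def g(x):
    return x

v = jax.random.normal(jax.random.PRNGKey(0), shape=(3, 2))

packed_rows = jax.vmap(g, out_axes=0)(v)              # batch axis first: shape (3, 2)
packed_cols = jax.vmap(g, out_axes=1)(v)              # batch axis second: shape (2, 3)
packed_cols_explicit = jax.vmap(g, in_axes=0, out_axes=1)(v)

print(jnp.array_equal(packed_rows, v))                # rows stacked as rows: v itself
print(jnp.array_equal(packed_cols, v.T))              # rows stacked as columns: v transposed
print(jnp.array_equal(packed_cols, packed_cols_explicit))  # in_axes defaults to 0
```

Exact equality (`array_equal`) is safe here because the identity function performs no floating-point arithmetic.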


Here we *implicitly* vectorized along the rows, because `in_axes` defaults to `0`. We can also set both parameters explicitly:

In [3]:
def h(x):
    return x + x

hv_rows_rows = jax.vmap(h, in_axes=0, out_axes=0)
hv_rows_cols = jax.vmap(h, in_axes=0, out_axes=1)
hv_cols_rows = jax.vmap(h, in_axes=1, out_axes=0)
hv_cols_cols = jax.vmap(h, in_axes=1, out_axes=1)

print(f"Batched vectors: v =\n{v}\n")
print("Vectorizing along rows and packing along rows:\n", hv_rows_rows(v))
print("Vectorizing along rows and packing along cols:\n", hv_rows_cols(v))
print("Vectorizing along cols and packing along rows:\n", hv_cols_rows(v))
print("Vectorizing along cols and packing along cols:\n", hv_cols_cols(v))

Batched vectors: v =
[[ 1.6226422   2.0252647 ]
 [-0.43359444 -0.07861735]
 [ 0.1760909  -0.97208923]]

Vectorizing along rows and packing along rows:
 [[ 3.2452843  4.0505295]
 [-0.8671889 -0.1572347]
 [ 0.3521818 -1.9441785]]
Vectorizing along rows and packing along cols:
 [[ 3.2452843 -0.8671889  0.3521818]
 [ 4.0505295 -0.1572347 -1.9441785]]
Vectorizing along cols and packing along rows:
 [[ 3.2452843 -0.8671889  0.3521818]
 [ 4.0505295 -0.1572347 -1.9441785]]
Vectorizing along cols and packing along cols:
 [[ 3.2452843  4.0505295]
 [-0.8671889 -0.1572347]
 [ 0.3521818 -1.9441785]]
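All four combinations appear to follow one rule: move the `in_axes` axis to the front, apply `h` to each slice, stack the results along a new leading batch axis, then move that batch axis to position `out_axes`. The helper `vmap_by_hand` below is a hypothetical, loop-based reference written only for this illustration and assumes a single array argument and a single array output. It is not how JAX implements `vmap`, but it should agree with it on these inputs.

```python
import itertools

import jax
import jax.numpy as jnp

def h(x):
    return x + x

v = jax.random.normal(jax.random.PRNGKey(0), shape=(3, 2))

def vmap_by_hand(fn, arr, in_axis, out_axis):
    """Hypothetical loop-based stand-in for
    jax.vmap(fn, in_axes=in_axis, out_axes=out_axis)(arr)."""
    # Bring the mapped axis to the front so iteration walks along it.
    slices = jnp.moveaxis(arr, in_axis, 0)
    # Apply fn to each slice; stacking creates a new leading batch axis.
    stacked = jnp.stack([fn(s) for s in slices])
    # Move the batch axis to the requested output position.
    return jnp.moveaxis(stacked, 0, out_axis)

for in_axis, out_axis in itertools.product([0, 1], repeat=2):
    expected = vmap_by_hand(h, v, in_axis, out_axis)
    actual = jax.vmap(h, in_axes=in_axis, out_axes=out_axis)(v)
    print(in_axis, out_axis, actual.shape, jnp.allclose(actual, expected))
```

The rule also explains why two pairs of outputs above coincide. Rows-in/columns-out and columns-in/rows-out both produce the transpose of `2 * v`. The two matching combinations both produce `2 * v` itself.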
