# `vmap`: write code for a single sample (datum), and JAX batches it automatically
This automatic batching is what makes per-sample gradients practical; see the example below.

In [4]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# Single global seed for the notebook; random arrays below derive from `key`.
SEED = 0
key = random.PRNGKey(SEED)


In [5]:
# Split the root key so `mat` and `batched_x` come from independent random
# streams; reusing one key for two draws is discouraged by JAX's PRNG design.
mat_key, x_key = random.split(key)
mat = random.normal(mat_key, (150, 100))     # fixed (150, 100) linear map
batched_x = random.normal(x_key, (10, 100))  # batch of 10 vectors of length 100


def apply_matrix(v):
    """Apply the global matrix `mat` to one un-batched vector.

    Args:
        v: a single input vector; `jnp.dot(mat, v)` requires its length to
            match `mat`'s second dimension (100).

    Returns:
        The product `mat @ v`, a vector of length 150.
    """
    return jnp.dot(mat, v)

In [6]:
def naively_batched_apply_matrix(v_batched):
    return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
2.44 ms ± 5.37 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
@jit
def batched_apply_matrix(v_batched):
    return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
54.7 µs ± 6.82 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
@jit
def vmap_batched_apply_matrix(v_batched):
    return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
35 µs ± 533 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


# Deep learning: per-sample gradients

In [9]:
def predict(params, inputs):
    """Forward pass of a fully connected network with tanh hidden activations.

    Hidden layers compute `tanh(x @ W + b)`; the final layer is linear
    (no activation), so its raw `x @ W + b` is returned.

    Args:
        params: sequence of `(W, b)` pairs, one per layer, in order.
        inputs: input vector (or batch of row vectors) fed to the first layer.

    Returns:
        The output of the last layer, without an activation applied.

    Raises:
        ValueError: if `params` contains no layers.
    """
    layers = list(params)
    if not layers:
        raise ValueError("predict: params must contain at least one (W, b) layer")
    # Split off the output layer so no tanh is computed and then discarded.
    *hidden_layers, (W_out, b_out) = layers
    activations = inputs
    for W, b in hidden_layers:
        activations = jnp.tanh(jnp.dot(activations, W) + b)  # inputs to the next layer
    return jnp.dot(activations, W_out) + b_out

In [10]:
def loss(params, inputs, targets):
    """Sum of squared errors between the network's predictions and `targets`.

    Args:
        params: per-layer `(W, b)` pairs, passed straight to `predict`.
        inputs: network inputs, passed straight to `predict`.
        targets: values the predictions are compared against; must broadcast
            against the prediction array.

    Returns:
        A scalar: the sum over all elements of `(predictions - targets) ** 2`.
    """
    residuals = predict(params, inputs) - targets
    return (residuals ** 2).sum()

In [12]:
# Compiled gradient of `loss` with respect to its first argument, `params`
# (argnums=0 is grad's default, spelled out here for clarity).
grad_loss = jit(grad(loss, argnums=0))

In [13]:
# Per-example gradients: `params` is shared across the batch (in_axes None),
# while `inputs` and `targets` are mapped over their leading axis (0); the
# results are stacked along axis 0 (out_axes=0, vmap's default) and the whole
# batched computation is compiled with jit.
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0), out_axes=0))

Reference: JAX Quickstart — https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

JAX source repository — https://github.com/google/jax