# The Autodiff Cookbook

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.PRNGKey(0)

## Gradients

### Starting with <code>grad</code>

In [2]:
# Build the derivative of tanh as a new function, then evaluate it at x = 2.0.
grad_tanh = grad(jnp.tanh)
tanh_slope_at_two = grad_tanh(2.0)
print(tanh_slope_at_two)

0.070650935


In [3]:
# Each additional grad() wrapper differentiates once more:
# second derivative first, then the third derivative built on top of it.
d2_tanh = grad(grad(jnp.tanh))
d3_tanh = grad(d2_tanh)
print(d2_tanh(2.0))
print(d3_tanh(2.0))

-0.13621888
0.2526544


In [4]:
def sigmoid(x):
  """Logistic function, written in terms of tanh: 0.5 * (tanh(x / 2) + 1)."""
  tanh_of_half = jnp.tanh(x / 2)
  return 0.5 * (tanh_of_half + 1)

# Outputs the probability of a label being true.
def predict(W, b, inputs):
  """Apply the sigmoid to the affine score `dot(inputs, W) + b`."""
  scores = jnp.dot(inputs, W) + b
  return sigmoid(scores)

# Build a toy dataset: four examples with three features each, and one
# boolean label per example.
feature_rows = [
    [0.52, 1.12, 0.77],
    [0.88, -1.08, 0.15],
    [0.52, 0.06, -1.30],
    [0.74, -2.49, 1.39],
]
label_values = [True, True, False, True]

inputs = jnp.array(feature_rows)
targets = jnp.array(label_values)

# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
  """Negative log-likelihood of the global `targets` under `predict(W, b, inputs)`."""
  prob_true = predict(W, b, inputs)
  prob_false = 1 - prob_true
  # Select, per example, the probability assigned to the observed label.
  label_probs = prob_true * targets + prob_false * (1 - targets)
  log_likelihood = jnp.sum(jnp.log(label_probs))
  return -log_likelihood

# Initialize random model coefficients. The key is split three ways: the
# first piece replaces `key` for later cells, the other two seed W and b.
split_keys = random.split(key, 3)
key, W_key, b_key = split_keys
W = random.normal(W_key, shape=(3,))
b = random.normal(b_key, shape=())

In [5]:
# Differentiate 'loss' with respect to the first positional argument (W):
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)

# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)

# But we can choose different values too, and drop the keyword
# (here 1 selects b, the second positional argument of 'loss'):
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)

# Including tuple values, which yield one gradient per listed argument:
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)

W_grad [-0.16965583 -0.8774647  -1.4901344 ]
W_grad [-0.16965583 -0.8774647  -1.4901344 ]
b_grad -0.29227248
W_grad [-0.16965583 -0.8774647  -1.4901344 ]
b_grad -0.29227248


### Differentiating with respect to nested lists, tuples, and dicts

In [6]:
def loss2(params_dict):
  """Negative log-likelihood with the parameters supplied as a dict.

  Reads the weights from params_dict['W'] and the bias from params_dict['b'].
  """
  W_param, b_param = params_dict['W'], params_dict['b']
  prob_true = predict(W_param, b_param, inputs)
  label_probs = prob_true * targets + (1 - prob_true) * (1 - targets)
  log_likelihood = jnp.sum(jnp.log(label_probs))
  return -log_likelihood

# The gradient comes back as a dict with the same keys as the input.
param_tree = {'W': W, 'b': b}
print(grad(loss2)(param_tree))

{'W': DeviceArray([-0.16965583, -0.8774647 , -1.4901344 ], dtype=float32), 'b': DeviceArray(-0.29227248, dtype=float32)}


### Evaluate a function and its gradient using <code>value_and_grad</code>

In [7]:
from jax import value_and_grad

# One call yields both the loss and the (W, b) gradients.
loss_and_grads = value_and_grad(loss, (0, 1))
loss_value, Wb_grad = loss_and_grads(W, b)
print('loss value', loss_value)
# Evaluating loss directly gives the same value.
print('loss value', loss(W, b))

loss value 3.051939
loss value 3.051939


### Checking against numerical differences

In [8]:
# Set a step size for finite differences calculations
eps = 1e-4
half_eps = eps / 2.

# Check b_grad with scalar central finite differences
loss_b_plus = loss(W, b + half_eps)
loss_b_minus = loss(W, b - half_eps)
b_grad_numerical = (loss_b_plus - loss_b_minus) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# Check W_grad with finite differences in a random direction:
# draw a random vector, normalize it, and step half of eps each way along it.
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_step = half_eps * unitvec
W_grad_numerical = (loss(W + W_step, b) - loss(W - W_step, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
# The autodiff directional derivative is the gradient dotted with the direction.
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))

b_grad_numerical -0.29563904
b_grad_autodiff -0.29227248
W_dirderiv_numerical -0.19788742
W_dirderiv_autodiff -0.19909093


In [9]:
from jax.test_util import check_grads
# Check the derivatives of loss at the point (W, b), up to 2nd order.
point = (W, b)
check_grads(loss, point, order=2)

### Hessian-vector products with <code>grad</code>-of-<code>grad</code>

In [10]:
def hvp(f, x, v):
  """Hessian-vector product of the scalar function f at x, applied to v.

  Differentiates the scalar map x -> vdot(grad(f)(x), v) and evaluates
  that derivative at x, so the full Hessian is never formed.
  """
  # The inner gradient must be taken at the lambda's own argument, and the
  # outer grad(...) must be *called* at x (not the vdot result).
  return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

### Jacobians and Hessians using <code>jacfwd</code> and <code>jacrev</code>

In [11]:
from jax import jacfwd, jacrev

# Isolate the function from the weight matrix to the predictions
# (b and inputs stay fixed at their current values).
f = lambda W: predict(W, b, inputs)

# Jacobian via jacfwd.
jacobian_fwd_fn = jacfwd(f)
J = jacobian_fwd_fn(W)
print("jacfwd result, with shape", J.shape)
print(J)

# Jacobian of the same function via jacrev; J is overwritten.
jacobian_rev_fn = jacrev(f)
J = jacobian_rev_fn(W)
print("jacrev result, with shape", J.shape)
print(J)

jacfwd result, with shape (4, 3)
[[ 0.05981757  0.12883784  0.08857601]
 [ 0.04015914 -0.04928622  0.00684531]
 [ 0.12188288  0.01406341 -0.3047072 ]
 [ 0.00140429 -0.00472523  0.00263778]]
jacrev result, with shape (4, 3)
[[ 0.05981756  0.12883782  0.088576  ]
 [ 0.04015914 -0.04928622  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140429 -0.00472523  0.00263778]]


In [12]:
def predict_dict(params, inputs):
  """Call predict with the weights and bias taken from params['W'] and params['b']."""
  weights, bias = params['W'], params['b']
  return predict(weights, bias, inputs)

# Differentiating with respect to a dict argument yields a dict of Jacobians,
# one entry per parameter.
param_tree = {'W': W, 'b': b}
J_dict = jacrev(predict_dict)(param_tree, inputs)
for k, v in J_dict.items():
  header = "Jacobian from {} to logits is".format(k)
  print(header)
  print(v)

Jacobian from W to logits is
[[ 0.05981756  0.12883782  0.088576  ]
 [ 0.04015914 -0.04928622  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140429 -0.00472523  0.00263778]]
Jacobian from b to logits is
[0.11503378 0.04563539 0.23439017 0.00189768]


In [13]:
def hessian(f):
  """Return a function computing the Hessian of f as jacfwd applied to jacrev(f)."""
  first_derivative = jacrev(f)
  return jacfwd(first_derivative)

# f is vector-valued here, so the result stacks one matrix per output of f.
hessian_fn = hessian(f)
H = hessian_fn(W)
print("hessian, with shape", H.shape)
print(H)

hessian, with shape (4, 3, 3)
[[[ 0.02285465  0.0492254   0.03384246]
  [ 0.04922541  0.10602394  0.07289145]
  [ 0.03384247  0.07289146  0.05011287]]

 [[-0.03195214  0.03921399 -0.00544639]
  [ 0.03921399 -0.04812626  0.0066842 ]
  [-0.00544639  0.0066842  -0.00092836]]

 [[-0.01583708 -0.00182736  0.03959271]
  [-0.00182736 -0.00021085  0.00456839]
  [ 0.03959271  0.00456839 -0.09898177]]

 [[-0.00103522  0.00348338 -0.00194454]
  [ 0.00348338 -0.01172109  0.0065431 ]
  [-0.00194453  0.0065431  -0.00365257]]]


## How it's made: two foundational autodiff functions

### Jacobian-Vector Products (JVPs, aka forward-mode autodiff)

#### JVPs in JAX code

In [14]:
from jax import jvp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

# Draw a random tangent vector with the same shape as W.
key, subkey = random.split(key)
v = random.normal(subkey, shape=W.shape)

# Push forward the vector 'v' along 'f' evaluated at 'W'.
# jvp takes tuples of primals and tangents and returns (f(W), pushed-forward v).
primals, tangents = (W,), (v,)
y, u = jvp(f, primals, tangents)

### Vector-Jacobian products (VJPs, aka reverse-mode autodiff)

#### VJPs in JAX code

In [15]:
from jax import vjp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

# Evaluate f at W; vjp hands back the output together with a pullback function.
output_and_pullback = vjp(f, W)
y, vjp_fun = output_and_pullback

# Draw a random covector with the same shape as the output y.
key, subkey = random.split(key)
u = random.normal(subkey, shape=y.shape)

# Pull back the covector 'u' along 'f' evaluated at 'W'.
# The pullback returns a tuple (one entry per input of f), so 'v' is a 1-tuple.
v = vjp_fun(u)

#### Vector-valued gradients with VJPs

In [16]:
from jax import vjp

def vgrad(f, x):
  """Gradient of sum(f(x)) with respect to x, for array-valued f.

  Pulls an all-ones cotangent back through f at x. The result has the
  same shape as x.
  """
  y, vjp_fn = vjp(f, x)
  # ones_like keeps the cotangent's dtype equal to y's dtype; a default-dtype
  # jnp.ones(y.shape) is rejected by the pullback when y is e.g. float16.
  return vjp_fn(jnp.ones_like(y))[0]

print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))

[[6. 6.]
 [6. 6.]]


### Hessian-vector products using both forward- and reverse-mode

In [17]:
def hvp(f, x, v):
  """Hessian-vector product of scalar f at x applied to v, via grad-of-grad."""
  def grad_dot_v(point):
    # Scalar whose gradient with respect to `point` is the Hessian times v.
    return jnp.vdot(grad(f)(point), v)
  return grad(grad_dot_v)(x)

In [18]:
from jax import jvp, grad

# forward-over-reverse
def hvp(f, primals, tangents):
  """Hessian-vector product: push `tangents` forward through grad(f) at `primals`.

  `primals` and `tangents` are tuples, as jvp expects.
  """
  _, tangent_out = jvp(grad(f), primals, tangents)
  return tangent_out

In [19]:
def f(X):
  """Sum of squared tanh over every entry of X."""
  tanh_X = jnp.tanh(X)
  return jnp.sum(tanh_X**2)

# Random evaluation point X and direction V of the same shape.
matrix_shape = (30, 40)
key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, shape=matrix_shape)
V = random.normal(subkey2, shape=matrix_shape)

# Compare the matrix-free HVP with materializing the full Hessian and
# contracting its last two axes against V.
ans1 = hvp(f, (X,), (V,))
full_hessian = hessian(f)(X)
ans2 = jnp.tensordot(full_hessian, V, axes=2)
print(jnp.allclose(ans1, ans2, rtol=1e-4, atol=1e-4))

True


In [20]:
# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
  """Hessian-vector product as the gradient of a forward-mode directional derivative.

  `primals` and `tangents` are tuples. The gradient is taken with respect
  to the `primals` tuple itself, so the result is a tuple with the same
  structure as `primals`.
  """
  def directional_derivative(primals):
    _, tangent_out = jvp(f, primals, tangents)
    return tangent_out
  return grad(directional_derivative)(primals)

That’s not quite as good, though, because forward-mode has less overhead than reverse-mode, and since the outer differentiation operator here has to differentiate a larger computation than the inner one, keeping forward-mode on the outside works best:

In [21]:
# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
  """Hessian-vector product via grad-of-grad; `primals` and `tangents` are 1-tuples."""
  (x,) = primals
  (v,) = tangents
  grad_f = grad(f)
  def grad_dot_v(point):
    return jnp.vdot(grad_f(point), v)
  return grad(grad_dot_v)(x)


# Time the three HVP implementations on the same f, X and V
# (-n10 -r3: 10 loops per run, 3 runs each).
print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))

# Baseline: build the full Hessian with hessian(f)(X), then contract it with V.
print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)

Forward over reverse
10 loops, best of 3: 9.34 ms per loop
Reverse over forward
10 loops, best of 3: 10.9 ms per loop
Reverse over reverse
10 loops, best of 3: 13.6 ms per loop
Naive full Hessian materialization
10 loops, best of 3: 22.3 ms per loop


## Composing VJPs, JVPs, and <code>vmap</code>

In [22]:
# Isolate the function from the weight matrix to the predictions
def f(W):
  """Predictions as a function of W only; b and inputs are read from the notebook's globals."""
  return predict(W, b, inputs)

# Pull back the covectors 'm_i' along 'f', evaluated at 'W', for all 'i'.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
  """Matrix-Jacobian product computed one row of M at a time.

  Each row of M is pulled back through f at x, and the results are stacked
  into one array with a row per covector.
  """
  y, vjp_fun = vjp(f, x)
  # vjp_fun returns a 1-tuple (one cotangent per input of f); take the array
  # out explicitly rather than relying on vstack to coerce tuples.
  return jnp.vstack([vjp_fun(mi)[0] for mi in M])

# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.

def vmap_mjp(f, x, M):
  """Matrix-Jacobian product: pull every row of M back through f at x in one vmapped call."""
  _, pullback = vjp(f, x)
  batched_pullback = vmap(pullback)
  # The pullback returns a 1-tuple (one cotangent per input of f).
  (cotangents_wrt_x,) = batched_pullback(M)
  return cotangents_wrt_x

key = random.PRNGKey(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)

loop_vs = loop_mjp(f, W, M=U)
print('None-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)

print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'

None-vmapped Matrix-Jacobian product
10 loops, best of 3: 225 ms per loop

Vmapped Matrix-Jacobian product
10 loops, best of 3: 8.33 ms per loop
