# Autodiff

1. Gradient
2. Jacobian
3. Hessian
4. Jacobian-vector, vector-Jacobian, and Hessian-vector products

In [None]:
import jax
import jax.numpy as jnp

Consider a function $f\colon \mathbb{R}^d \to \mathbb{R}$. JAX computes the gradient of $f$ with
`grad_f = jax.grad(f)`, which returns a new function: `grad_f(x)` evaluates $\nabla f(x) \in \mathbb{R}^d$.

# Example

$$
f(x) = x^\top \, A \, x 
$$

For symmetric $A$, compute $\nabla f$ and compare it with the analytical gradient $2 \, A \, x$ (in general, $\nabla f(x) = (A + A^\top)\, x$).

In [None]:
A = jnp.eye(2)

def f(x):
    return jnp.dot(x, A @ x)

grad_f = jax.grad(f)
grad_f(jnp.array([1., 2.]))
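With $A = I$ the comparison is not very telling. A minimal sketch with two illustrative matrices (the names `A_sym` and `A_gen` are assumptions for this check): for a symmetric $A$ the autodiff gradient matches $2Ax$, while for a non-symmetric $A$ only $(A + A^\top)x$ matches.

```python
import jax
import jax.numpy as jnp

x0 = jnp.array([1., 2.])

# Symmetric, non-identity A (illustrative): the gradient is 2 A x.
A_sym = jnp.array([[2., 1.],
                   [1., 3.]])
autodiff_grad = jax.grad(lambda x: jnp.dot(x, A_sym @ x))(x0)
analytic_grad = 2. * A_sym @ x0
print(autodiff_grad, analytic_grad)   # both equal [8., 14.]

# Non-symmetric A (illustrative): 2 A x no longer applies, (A + A^T) x does.
A_gen = jnp.array([[1., 2.],
                   [0., 1.]])
gen_grad = jax.grad(lambda x: jnp.dot(x, A_gen @ x))(x0)
gen_expected = (A_gen + A_gen.T) @ x0
print(gen_grad, gen_expected)         # both equal [6., 6.]
print(2. * A_gen @ x0)                # [10., 4.], which does not match
```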

# Exercise

Consider an MLE-type objective function (the quadratic term of a Gaussian negative log-likelihood)

$$
\ell(\theta) = x^\top A^{-1}(\theta) \, x,
$$

where $A(\theta) = \begin{bmatrix}2 & \mathrm{sigmoid}(\theta) \\ \mathrm{sigmoid}(\theta) & 3 \end{bmatrix}$, and $x\in\mathbb{R}^2$ is given. Compute $\nabla \ell$ at $\theta=2$, and compare the result with a finite-difference approximation.

```python
import jax.scipy

key = ?
x = jax.random.normal(?)

def sigmoid(theta):
    return ?

def nll(theta):
    A = ?

    chol = jax.scipy.linalg.cho_factor(A, lower=True)
    return jnp.dot(x, jax.scipy.linalg.cho_solve(chol, x))

grad_nll = ?

grad_nll(2.)

# compute gradient at 2. using finite difference
?
```

## Solution

In [None]:
import jax.scipy

key = jax.random.PRNGKey(999)
x = jax.random.normal(key, (2, ))

def sigmoid(theta):
    return 1 / (1 + jnp.exp(-theta))

def nll(theta):
    A = jnp.array([[2., sigmoid(theta)], 
                   [sigmoid(theta), 3.]])
    chol = jax.scipy.linalg.cho_factor(A, lower=True)
    return jnp.dot(x, jax.scipy.linalg.cho_solve(chol, x))

grad_nll = jax.grad(nll)

In [None]:
grad_nll(2.)

In [None]:
(nll(2. + 1e-3) - nll(2.)) / 1e-3
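The forward difference above has $O(h)$ error. A minimal sketch of a tighter check, assuming the same `x` and `nll` as in the solution: since $\partial_\theta A^{-1} = -A^{-1} (\partial_\theta A) A^{-1}$ and $A$ is symmetric, the closed-form derivative is $\partial_\theta \ell = -y^\top (\partial_\theta A)\, y$ with $y = A^{-1}(\theta)\, x$ and $\partial_\theta A = \mathrm{sigmoid}(\theta)\big(1 - \mathrm{sigmoid}(\theta)\big) \begin{bmatrix}0 & 1 \\ 1 & 0\end{bmatrix}$. A central difference, whose error is $O(h^2)$, is shown alongside. The helper names `A_of` and `nll_grad_analytic` and the step `h = 1e-2` are illustrative choices.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

key = jax.random.PRNGKey(999)
x = jax.random.normal(key, (2,))

def sigmoid(theta):
    return 1 / (1 + jnp.exp(-theta))

def A_of(theta):
    # hypothetical helper building A(theta)
    s = sigmoid(theta)
    return jnp.array([[2., s],
                      [s, 3.]])

def nll(theta):
    chol = jax.scipy.linalg.cho_factor(A_of(theta), lower=True)
    return jnp.dot(x, jax.scipy.linalg.cho_solve(chol, x))

def nll_grad_analytic(theta):
    # d/dtheta x^T A^{-1} x = -y^T (dA/dtheta) y, with y = A^{-1} x
    s = sigmoid(theta)
    dA = s * (1 - s) * jnp.array([[0., 1.],
                                  [1., 0.]])
    y = jnp.linalg.solve(A_of(theta), x)
    return -jnp.dot(y, dA @ y)

theta0 = 2.
h = 1e-2
g_auto = jax.grad(nll)(theta0)
g_exact = nll_grad_analytic(theta0)
g_central = (nll(theta0 + h) - nll(theta0 - h)) / (2 * h)
print(g_auto, g_exact, g_central)
```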

# Jacobian and Hessian

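JAX computes the Jacobian of a function $f\colon \mathbb{R}^{n} \to \mathbb{R}^m$ with `jax.jacfwd` or `jax.jacrev`. Both return the same $m \times n$ matrix (up to floating-point round-off) but are implemented differently:

- `jax.jacfwd`: forward-mode autodiff; builds the Jacobian column by column, one forward pass per input dimension.
- `jax.jacrev`: reverse-mode autodiff; builds the Jacobian row by row, one reverse pass per output dimension.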

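I am no expert in autodiff, but as a rule of thumb, use `jacfwd` when $n \ll m$ (a "tall" Jacobian) and `jacrev` when $n \gg m$ (a "wide" Jacobian), because the cost of forward mode scales with the number of inputs and that of reverse mode with the number of outputs.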
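A minimal sketch comparing the two on an illustrative map $g\colon \mathbb{R}^3 \to \mathbb{R}^2$ (the function `g` and the evaluation point are assumptions for illustration). Both calls return the $2 \times 3$ Jacobian, with rows indexed by outputs and columns by inputs.

```python
import jax
import jax.numpy as jnp

def g(x):
    # illustrative map R^3 -> R^2: g(x) = (x_0 * x_1, sin x_2)
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x0 = jnp.array([1., 2., 3.])

J_fwd = jax.jacfwd(g)(x0)  # forward mode: conceptually one pass per input (3 here)
J_rev = jax.jacrev(g)(x0)  # reverse mode: conceptually one pass per output (2 here)

# Closed-form Jacobian for comparison: [[x_1, x_0, 0], [0, 0, cos x_2]]
J_expected = jnp.array([[2., 1., 0.],
                        [0., 0., jnp.cos(3.)]])

print(J_fwd.shape)  # (2, 3)
print(jnp.allclose(J_fwd, J_rev), jnp.allclose(J_fwd, J_expected))
```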

To obtain the Hessian of a scalar function, use either `jax.jacfwd(jax.jacrev(f))` (forward-over-reverse) or `jax.hessian(f)`, which is implemented as exactly that composition.
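A quick sketch checking that both routes agree on an illustrative scalar function $f(x) = \sum_i x_i^3$ (an assumption chosen because its Hessian, $\mathrm{diag}(6x)$, is easy to write down):

```python
import jax
import jax.numpy as jnp

def cubic(x):
    # illustrative scalar function f(x) = sum_i x_i^3, whose Hessian is diag(6 x)
    return jnp.sum(x ** 3)

x0 = jnp.array([1., -2., 0.5])

H_hessian = jax.hessian(cubic)(x0)
H_fwd_rev = jax.jacfwd(jax.jacrev(cubic))(x0)
H_exact = jnp.diag(6. * x0)  # diag(6, -12, 3)

print(jnp.allclose(H_hessian, H_fwd_rev), jnp.allclose(H_hessian, H_exact))
```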

# Example

Take the Jacobian of $\nabla f$ from the first example (that is, the Hessian of $f$, which equals $2A$ for symmetric $A$ and does not depend on $x$), where

$$
f(x) = x^\top \, A \, x 
$$


In [None]:
jac_of_gradf = jax.jacfwd(grad_f)

# Evaluate at x=[0.1, 0.2]
jac_of_gradf(jnp.array([0.1, 0.2]))

# Exercise

Consider a simple perceptron

$$
\mathrm{NN}(x) = \mathrm{sigmoid} (w^\top \, x + b)
$$

Compute its gradient and Hessian w.r.t. the weight $w$.

Hint: look up the `argnums` argument of `jax.grad` and `jax.hessian`, which specifies the positional argument(s) to differentiate with respect to.

```python
def sigmoid(x):
    return ?

def nn(x, weights, b):
    """This function has three arguments. How to specify to which argument we differentiate?
    """
    return ?

key = jax.random.PRNGKey(999)
ws = jax.random.normal(key, (10, ))

key, _ = jax.random.split(key)
xs = jax.random.normal(key, (10, ))

grad_of_nn = ?
hessian_of_nn = ?

print(grad_of_nn(xs, ws, 1.))
print(hessian_of_nn(xs, ws, 1.))
```

## Solution

In [None]:
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

def nn(x, weights, b):
    return sigmoid(jnp.dot(x, weights) + b)

key = jax.random.PRNGKey(999)
ws = jax.random.normal(key, (10, ))

key, _ = jax.random.split(key)
xs = jax.random.normal(key, (10, ))

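# argnums=1 differentiates w.r.t. the second positional argument, `weights`;
# an int (rather than [1]) returns the array itself instead of a (nested) tuple
grad_of_nn = jax.grad(nn, argnums=1)
hessian_of_nn = jax.hessian(nn, argnums=1)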

print(grad_of_nn(xs, ws, 1.))
print(hessian_of_nn(xs, ws, 1.))
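As a sanity check on the solution, write $s = \mathrm{sigmoid}(w^\top x + b)$. Since $\mathrm{sigmoid}' = s(1-s)$, the chain rule gives

$$
\nabla_w \mathrm{NN} = s(1-s)\, x, \qquad \nabla_w^2 \mathrm{NN} = s(1-s)(1-2s)\, x x^\top,
$$

so the Hessian is rank one. A minimal sketch comparing these closed forms with autodiff, reusing the same setup and `b = 1.`:

```python
import jax
import jax.numpy as jnp

def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

def nn(x, weights, b):
    return sigmoid(jnp.dot(x, weights) + b)

key = jax.random.PRNGKey(999)
ws = jax.random.normal(key, (10,))
key, _ = jax.random.split(key)
xs = jax.random.normal(key, (10,))
b = 1.

# Closed forms from the chain rule
s = sigmoid(jnp.dot(xs, ws) + b)
grad_exact = s * (1 - s) * xs
hess_exact = s * (1 - s) * (1 - 2 * s) * jnp.outer(xs, xs)

# Autodiff w.r.t. the second argument (weights)
grad_auto = jax.grad(nn, argnums=1)(xs, ws, b)
hess_auto = jax.hessian(nn, argnums=1)(xs, ws, b)

print(jnp.allclose(grad_auto, grad_exact), jnp.allclose(hess_auto, hess_exact, atol=1e-6))
```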

# Jacobian-vector product (JVP), vector-Jacobian product (VJP), and Hessian-vector product (HVP)

In many applications, what you need is not the Jacobian/Hessian itself but its product with a vector,

$$
\mathrm{J} x
$$

for some Jacobian or Hessian $\mathrm{J}$ and a vector $x$. For example, the generator of a diffusion process, an operator commonly seen in SDEs/PDEs:

$$
(A \phi)(x) = \nabla_x \phi \cdot a(x) + \frac{1}{2} \mathrm{tr}\big(\Gamma(x) \, \mathrm{H}_x \phi\big),
$$

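The first term is a JVP of $\phi$ in the direction $a(x)$. If $\Gamma(x) = B(x) B(x)^\top$, the trace term equals $\sum_j b_j^\top \, (\mathrm{H}_x \phi) \, b_j$ over the columns $b_j$ of $B(x)$, i.e., a sum of HVPs, so the Hessian never has to be formed. Other examples: Gauss-Newton and quasi-Newton methods, the extended Kalman filter, etc.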
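A minimal sketch of evaluating $(A\phi)(x)$ without forming the Hessian, assuming a constant diffusion factor $B$ with $\Gamma = B B^\top$. The test function `phi`, the drift `drift`, the matrix `B`, and the helper `hvp` are illustrative assumptions. The drift term is a single `jax.jvp` of $\phi$ along $a(x)$; the trace term is a sum of forward-over-reverse HVPs along the columns of $B$. The last lines compute a reference that forms $\nabla\phi$ and $\mathrm{H}_x\phi$ explicitly.

```python
import jax
import jax.numpy as jnp

def phi(x):
    # illustrative test function phi: R^2 -> R
    return jnp.sin(x[0]) * jnp.exp(x[1])

def drift(x):
    # illustrative drift a(x)
    return -x

# illustrative constant diffusion factor, so that Gamma = B B^T
B = jnp.array([[1., 0.],
               [0.5, 2.]])

def hvp(f, x, v):
    # Hessian-vector product H_f(x) v via forward-over-reverse (JVP of the gradient)
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def generator(phi, x):
    # grad(phi)(x) . a(x) as a single JVP of phi in the direction a(x)
    _, drift_term = jax.jvp(phi, (x,), (drift(x),))
    # tr(B B^T H) = sum_j b_j^T H b_j over the columns b_j of B (rows of B.T)
    quad_forms = jax.vmap(lambda b: jnp.dot(b, hvp(phi, x, b)))(B.T)
    return drift_term + 0.5 * jnp.sum(quad_forms)

x0 = jnp.array([0.3, -0.4])
value = generator(phi, x0)

# Reference that forms the gradient and Hessian explicitly
reference = (jnp.dot(jax.grad(phi)(x0), drift(x0))
             + 0.5 * jnp.trace(B @ B.T @ jax.hessian(phi)(x0)))
print(value, reference)
```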


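These products can be computed efficiently with `jax.jvp` ($\mathrm{J}x$, forward mode) and `jax.vjp` ($x^\top \mathrm{J}$, reverse mode) without ever forming $\mathrm{J}$; an HVP can be obtained by composing the two modes, e.g., forward-over-reverse. For details, see https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#how-it-s-made-two-foundational-autodiff-functions.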
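A minimal sketch of all three products, checked against the explicitly formed matrices. The map `g`, the scalar function `scalar_f`, the vectors `v` and `u`, and the helper `hvp` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def g(x):
    # illustrative map R^3 -> R^2
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

def scalar_f(x):
    # illustrative scalar function R^3 -> R
    return jnp.sum(jnp.sin(x) * x ** 2)

x0 = jnp.array([1., 2., 3.])
v = jnp.array([0.1, -0.2, 0.3])  # tangent in input space (R^3)
u = jnp.array([1., -1.])         # cotangent in output space (R^2)

# JVP: J v in one forward pass, without forming J
_, Jv = jax.jvp(g, (x0,), (v,))

# VJP: u^T J in one reverse pass, without forming J
_, g_vjp = jax.vjp(g, x0)
(uJ,) = g_vjp(u)  # one cotangent per primal input, hence the 1-tuple

# HVP via forward-over-reverse: JVP of the gradient
def hvp(f, x, v):
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

Hv = hvp(scalar_f, x0, v)

# Check against the explicitly formed Jacobian and Hessian
J = jax.jacfwd(g)(x0)
H = jax.hessian(scalar_f)(x0)
print(jnp.allclose(Jv, J @ v, atol=1e-5),
      jnp.allclose(uJ, u @ J, atol=1e-5),
      jnp.allclose(Hv, H @ v, atol=1e-5))
```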