In [None]:
import jax
import jax.numpy as jnp
from IPython.display import display, Latex

## If-else condition with `jax.lax.cond`

$$
f(\mathbf{x}) = \sum_{x \in \mathbf{x}} \begin{cases}
    x^2,& \text{if } x \gt 5\\
    x^3,             & \text{otherwise}
\end{cases}
$$

In [None]:
x = [jnp.array(10.0), jnp.array(2.0)]

@jax.jit
@jax.value_and_grad
def f(x):
  """Sum of a piecewise function over the leaves of the pytree ``x``.

  Each leaf ``val`` contributes ``val**2`` if ``val > 5`` and ``val**3``
  otherwise. Under ``jax.jit`` the comparison yields a traced value, so the
  branch is selected with ``jax.lax.cond`` instead of a Python ``if``.

  Because of the decorators, ``f(x)`` returns ``(value, grad)``, where
  ``grad`` has the same pytree structure as ``x``.
  """
  # jax.tree_util.tree_map is the stable spelling; the jax.tree_map alias is
  # deprecated and has been removed in recent JAX releases.
  ans = jax.tree_util.tree_map(
      lambda val: jax.lax.cond(val > 5.0, lambda: val**2, lambda: val**3), x)
  return jax.tree_util.tree_reduce(lambda a, b: a + b, ans)

value, grad = f(x)

# Raw f-strings keep LaTeX backslashes literal (a plain "\m" is an invalid
# escape sequence and warns on Python 3.12+); doubled braces are literal braces.
display(Latex(rf"$f(\mathbf{{x}}) = {value}$"))
for idx, grad_i in enumerate(grad):
  # Brace the subscript so indices with several digits render as x_{10}, not x_1 0.
  display(Latex(rf"$\frac{{df}}{{dx_{{{idx}}}}} = {grad_i}$"))



<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

## Pairwise distances with `vmap`

In [None]:
# create vour pairwise function
def distance(a, b):
    return jnp.linalg.norm(a - b)


# map based combinator to operate on all pairs
def all_pairs(f):
    f = jax.vmap(f, in_axes=(None, 0))
    f = jax.vmap(f, in_axes=(0, None))
    return f


# transform to operate over sets
distances = all_pairs(distance)

# Example
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([3.0, 4.0, 5.0])
distances(x, y)

DeviceArray([[2., 3., 4.],
             [1., 2., 3.],
             [0., 1., 2.]], dtype=float32)

## Compute the Hessian with `jax`

Let us consider the linear regression (sum of squared errors) loss function. Its gradient and Hessian are shown below it:

\begin{align}
\mathcal{L}(\boldsymbol{\theta}) &= (\boldsymbol{y} - X\boldsymbol{\theta})^T(\boldsymbol{y} - X\boldsymbol{\theta})\\
\frac{d\mathcal{L}}{d\boldsymbol{\theta}} &= -2X^T\boldsymbol{y} + 2X^TX\boldsymbol{\theta}\\
H_{\mathcal{L}}(\boldsymbol{\theta}) &= 2X^TX
\end{align}

In [None]:
def loss_function_per_point(theta, x, y):
  """Squared error of a linear model on a single data point.

  Args:
    theta: parameter vector (assumed shape (D,)).
    x: a single feature row (assumed shape (D,); ``loss_function`` vmaps
      this function over the rows of the design matrix).
    y: the scalar target for this row.

  Returns:
    ``(x @ theta - y) ** 2``.
  """
  # x is one 1-D row here, so the former ``x.T`` was a no-op transpose. A plain
  # dot product states the intent and matches ``gt_loss``.
  y_pred = x @ theta
  return jnp.square(y_pred - y)

def loss_function(theta, x, y):
  """Total squared-error loss, built by vmapping the per-point loss.

  ``theta`` is shared by every point (``in_axes=None``). The rows of ``x`` and
  the entries of ``y`` are mapped together along their leading axis, and the
  per-point losses are summed.
  """
  batched_loss = jax.vmap(loss_function_per_point, in_axes=(None, 0, 0))
  return batched_loss(theta, x, y).sum()

def gt_loss(theta, x, y):
  """Reference loss: the sum of squared residuals of ``x @ theta - y``."""
  residual = x @ theta - y
  return jnp.square(residual).sum()

def gt_grad(theta, x, y):
  """Reference gradient of the squared-error loss: ``2 (X^T X theta - X^T y)``."""
  xt = x.T
  return 2 * (xt @ x @ theta - xt @ y)

def gt_hess(theta, x, y):
  """Reference Hessian of the squared-error loss: ``2 X^T X``.

  ``theta`` and ``y`` are unused because the Hessian of this quadratic loss is
  constant. They are accepted so the signature mirrors ``gt_loss`` and
  ``gt_grad``.
  """
  scaled_xt = 2 * x.T
  return scaled_xt @ x

### Simulate a dataset

In [None]:
key = jax.random.PRNGKey(0)
# Give every random draw its own fresh subkey and keep `key` only as the carry
# for future splits. Previously the carried key was also consumed by
# `uniform`, which JAX's PRNG guidance advises against (never reuse a key).
key, subkey1, subkey2, subkey3 = jax.random.split(key, num=4)
N = 100  # number of data points
D = 11   # number of features
x = jax.random.uniform(subkey1, shape=(N, D))
y = jax.random.uniform(subkey2, shape=(N,))
theta = jax.random.uniform(subkey3, shape=(D,))

### Verify loss and gradient values

In [None]:
# Differentiate the vmap-built loss with respect to its first argument, theta.
loss_and_grad_function = jax.value_and_grad(loss_function, argnums=0)
loss_val, grad = loss_and_grad_function(theta, x, y)

# Autodiff results must match the closed-form expressions.
assert jnp.allclose(loss_val, gt_loss(theta, x, y))
assert jnp.allclose(grad, gt_grad(theta, x, y))

### Verify the Hessian matrix

#### Way 1: `jax.hessian`

In [None]:
# jax.hessian differentiates twice with respect to argument 0 (theta).
hess = jax.hessian(loss_function, argnums=0)(theta, x, y)
assert jnp.allclose(hess, gt_hess(theta, x, y))

#### Way 2: forward-over-reverse with `jax.jacfwd(jax.jacrev(...))`

In [None]:
# Forward-over-reverse: jacrev yields the gradient w.r.t. theta, and jacfwd of
# that gradient w.r.t. theta yields the Hessian.
hess = jax.jacfwd(jax.jacrev(loss_function, argnums=0), argnums=0)(theta, x, y)
assert jnp.allclose(hess, gt_hess(theta, x, y))