Automatic differentiation with `grad`

$$
\nabla_x f = \frac{df}{dx}
$$

In [None]:
import jax
import jax.numpy as jnp

# List the devices JAX will run on (CPU, GPU or TPU).
# On an NVIDIA GPU runtime, `!nvidia-smi` in its own cell shows driver and memory status.
print(jax.devices())

In [None]:
def fn(x):
    """f(x) = sin(x^2); analytically f'(x) = 2x cos(x^2)."""
    inner = x**2
    return jnp.sin(inner)

# jax.grad turns a scalar-valued fn into a new function returning df/dx.
grad_fn = jax.grad(fn)
grad_fn(1.0)

In [None]:
# Third derivative: apply jax.grad three times in a row.
# For f(x) = sin(x^2), analytically f'''(x) = -12x sin(x^2) - 8x^3 cos(x^2).
grad3_fn = fn
for _order in range(3):
    grad3_fn = jax.grad(grad3_fn)
grad3_fn(1.0)

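From float32 to float64: JAX computes in float32 by default. Enabling `jax_enable_x64` switches the default to float64, ideally at startup before any arrays are created.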

In [None]:
import jax
# Enable 64-bit mode. Without it, JAX truncates float64 requests to float32 (with a warning).
# JAX documents this flag as a startup setting, to be set before any arrays are created.
# The assertion below confirms the switch really took effect this late in the notebook.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

assert jnp.asarray(1.0).dtype == jnp.float64, "jax_enable_x64 did not take effect"

In [None]:
def fn(x):
    """f(x) = sin(x^2), redefined so this cell is self-contained after enabling x64."""
    return jnp.sin(x**2)

grad_fn = jax.grad(fn)

# Show both results side by side. A bare Python float is weakly typed and takes the
# default float dtype; the second input is an explicit float64 array.
grad_fn(1.0), grad_fn(jnp.array(1.0, dtype=jnp.float64))

Every function you differentiate must be written with `jax.numpy`, not plain NumPy. The NumPy version below fails under `jax.grad`.

In [None]:
import numpy as np

def fn(x):
    """sin(x^2) written with NumPy instead of jax.numpy.

    NumPy functions cannot consume JAX tracers, so this version cannot be
    differentiated with jax.grad.
    """
    return np.sin(x**2)

try:
    jax.grad(fn)(1.0)
except jax.errors.TracerArrayConversionError as err:
    # Expected failure, displayed instead of halting "Run All": np.sin tries to
    # convert the tracer that jax.grad passes in into a concrete NumPy array.
    print(f"{type(err).__name__}: {str(err).partition(chr(10))[0]}")

Computing the Jacobian with `jacfwd` and `jacrev`

$$
J = \frac{\partial (z_1, z_2, z_3)}{\partial (x_1, x_2, x_3)}
=
\left(
\begin{matrix}
\frac{\partial z_1}{\partial x_1} & \frac{\partial z_1}{\partial x_2} & \frac{\partial z_1}{\partial x_3} \\
\frac{\partial z_2}{\partial x_1} & \frac{\partial z_2}{\partial x_2} & \frac{\partial z_2}{\partial x_3} \\
\frac{\partial z_3}{\partial x_1} & \frac{\partial z_3}{\partial x_2} & \frac{\partial z_3}{\partial x_3} \\
\end{matrix}
\right)
$$

In [None]:
def cartesian_to_spherical(x):
    """Convert a Cartesian point (x, y, z) to spherical coordinates (r, theta, phi).

    theta is the polar angle arccos(z / r), and phi is the azimuth arctan2(y, x).
    At the origin r = 0, so z / r is 0/0 and theta comes out as NaN.
    """
    px, py, pz = x[0], x[1], x[2]
    radius = jnp.sqrt(px**2 + py**2 + pz**2)
    polar = jnp.arccos(pz / radius)
    azimuth = jnp.arctan2(py, px)
    return jnp.array([radius, polar, azimuth])

In [None]:
# Evaluation point (1, 1, 1) for the Jacobians below.
x = jnp.ones(3, dtype=jnp.float64)

In [None]:
# Forward-mode Jacobian: builds J column by column (one JVP per input coordinate).
jacobian_fwd = jax.jacfwd(cartesian_to_spherical)
jacobian_fwd(x)

In [None]:
# Reverse-mode Jacobian: builds J row by row (one VJP per output); it should match jacfwd.
jacobian_rev = jax.jacrev(cartesian_to_spherical)
jacobian_rev(x)

Compilation with `jit`

(Just-in-time compilation)

In [None]:
def slow_f(x):
    """Return x**2 + 2*x elementwise, dispatched op by op (no compilation)."""
    squared = x * x
    doubled = x * 2.0
    return squared + doubled

# Same function, traced once per input shape/dtype and compiled with XLA.
fast_f = jax.jit(slow_f)

x = jnp.ones((5000, 5000))
%timeit -n10 -r10 fast_f(x)
%timeit -n10 -r10 slow_f(x)

The Hamiltonian of the classical Coulomb gas:
$$H= \sum_{i<j} \frac{1}{|\boldsymbol{x}_i - \boldsymbol{x}_j|} + \sum_i  \boldsymbol{x}_i^2 . $$
The second term is a harmonic trapping potential.

It confines the particles, which keeps the problem simple: there is no need for periodic boundary conditions or an Ewald sum to handle the long-range Coulomb interaction.


In [None]:
n = 10
# Indices of the strict upper triangle (k=1 skips the diagonal).
# Each pair i < j appears exactly once.
i, j = jnp.triu_indices(n, k=1)

print(i, j, sep="\n")

In [None]:
def energy_fun(x, n, dim):
    """Energy H of n unit charges in a harmonic trap.

    H = sum_{i<j} 1/|x_i - x_j| + sum_i |x_i|^2

    Args:
        x: particle positions; reshaped to (n, dim), so it must hold n*dim values.
        n: number of particles.
        dim: spatial dimension.

    Returns:
        Scalar energy: trapping term plus Coulomb repulsion.
    """
    row, col = jnp.triu_indices(n, k=1)
    # All pairwise displacements x_a - x_b via broadcasting: shape (n, n, dim).
    displacement = jnp.reshape(x, (n, 1, dim)) - jnp.reshape(x, (1, n, dim))
    # Keep only the i<j pairs, so each pair counts once and the zero diagonal is skipped.
    pair_dist = jnp.linalg.norm(displacement[row, col], axis=-1)
    trap = jnp.sum(x**2)
    coulomb = jnp.sum(1/pair_dist)
    return trap + coulomb


# n and dim fix array shapes (reshape, triu_indices), so jit treats them as
# compile-time constants. Each new (n, dim) pair triggers a fresh compilation.
fast_energy_fun = jax.jit(energy_fun, static_argnums=(1, 2))

In [None]:
# JAX uses explicit PRNG state. A key is never mutated, so every draw needs a fresh key.
key = jax.random.PRNGKey(42)
print(key)
# Split off a subkey to consume now; keep `key` only for further splitting.
key, subkey = jax.random.split(key)
print(key, subkey)

n = 100
dim = 2
x = jax.random.normal(subkey, (n, dim))

In [None]:
%timeit -n10 -r10 energy_fun(x, n, dim)
%timeit -n10 -r10 fast_energy_fun(x, n, dim)

In [None]:
def energy_fn(x, n=n, dim=dim):
    """energy_fun with n and dim frozen to their values when this cell runs.

    Default arguments are evaluated once, at definition time. Later cells that
    reassign the globals n/dim therefore do not silently change what these
    gradients compute; a lambda reading the globals would pick up the new values.
    """
    return energy_fun(x, n, dim)

# Three ways to get dH/dx. For a scalar-valued function they should agree.
grad_fn = jax.grad(energy_fn)
grad_fn_jacfwd = jax.jacfwd(energy_fn)
grad_fn_jacrev = jax.jacrev(energy_fn)

print(grad_fn)

In [None]:
g1 = grad_fn(x)
g2 = grad_fn_jacfwd(x)
g3 = grad_fn_jacrev(x)
# For a scalar function, the Jacobian has the same shape as x and equals the gradient.
# allclose reduces the elementwise comparison to one bool instead of printing a full mask.
print(jnp.allclose(g1, g2))
print(jnp.allclose(g1, g3))

In [None]:
%timeit -n10 -r100 grad_fn(x)
%timeit -n10 -r100 grad_fn_jacfwd(x)
%timeit -n10 -r100 grad_fn_jacrev(x)

Auto-vectorization with `vmap`

In [None]:
n, dim = 6, 2
# Derive a dedicated key rather than reusing `key`, which earlier cells may already
# have consumed. fold_in does not modify `key`, so re-running this cell is idempotent.
sample_key = jax.random.fold_in(key, 1)
x = jax.random.normal(sample_key, (n, dim))
print("x:", x.shape, x)
E = energy_fun(x, n, dim)
print("E:", E)

In [None]:
batch = 1024
n = 6
dim = 2

# Map over the leading (batch) axis of x only; n and dim are passed through unbatched.
# The output is stacked along axis 0, giving one energy per configuration.
energy_fun_vmap = jax.vmap(energy_fun, in_axes=(0, None, None), out_axes=0)

# Independent key for this batch (distinct fold_in data), leaving `key` untouched.
batch_key = jax.random.fold_in(key, 2)
x = jax.random.normal(batch_key, (batch, n, dim))
print("x:", x.shape)
E = energy_fun_vmap(x, n, dim)
print("E:", E.shape, E)