
In [1]:
import jax

f = lambda x: x**3 + 2*x**2 - 3*x + 1

dfdx = jax.grad(f)

In [3]:
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

In [5]:
dfdx(1.)

Array(4., dtype=float32, weak_type=True)

In [6]:
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))

10.0
6.0
0.0
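Repeatedly calling `jax.grad` by hand gets tedious for higher orders. Below is a minimal sketch that composes `jax.grad` `n` times in a loop. The helper name `nth_derivative` is hypothetical and not part of JAX. It reuses the polynomial above, whose derivatives are $f'(x) = 3x^2 + 4x - 3$, $f''(x) = 6x + 4$, $f'''(x) = 6$ and $f''''(x) = 0$.

```python
import jax

def nth_derivative(fn, n):
  """Hypothetical helper: n-th derivative of a scalar function via n nested jax.grad calls."""
  for _ in range(n):
    fn = jax.grad(fn)
  return fn

f = lambda x: x**3 + 2*x**2 - 3*x + 1

# Evaluate f', f'', f''', f'''' at x = 1.
derivs = [float(nth_derivative(f, n)(1.)) for n in range(1, 5)]
print(derivs)  # [4.0, 10.0, 6.0, 0.0]
```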


### Multivariate Case
 - The Hessian of a real-valued function of several variables can be identified with the Jacobian of its gradient.

In [7]:
def hessian(f):
  return jax.jacfwd(jax.grad(f))

In [8]:
import jax.numpy as jnp

def f(x):
  return jnp.dot(x, x)

hessian(f)(jnp.array([1., 2., 3.]))

Array([[2., 0., 0.],
       [0., 2., 0.],
       [0., 0., 2.]], dtype=float32)
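The `hessian` above is forward-over-reverse (`jacfwd` of `grad`). A short sketch comparing it with reverse-over-reverse (`jacrev` of `grad`) and with JAX's built-in `jax.hessian`, on a function with a non-constant Hessian: for $f(x) = \sum_i \sin x_i$ the Hessian is $\mathrm{diag}(-\sin x)$. The test point `x` is chosen arbitrarily for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sum(jnp.sin(x))  # gradient cos(x), Hessian diag(-sin(x))

x = jnp.array([0.5, 1.0, 2.0])

H_fwd_rev = jax.jacfwd(jax.grad(f))(x)  # forward-over-reverse, as in hessian() above
H_rev_rev = jax.jacrev(jax.grad(f))(x)  # reverse-over-reverse
H_builtin = jax.hessian(f)(x)           # JAX's built-in Hessian transform

print(jnp.allclose(H_fwd_rev, jnp.diag(-jnp.sin(x))))  # True
print(jnp.allclose(H_fwd_rev, H_rev_rev), jnp.allclose(H_fwd_rev, H_builtin))
```

All three agree. Forward-over-reverse is usually the preferred composition for Hessians, which is why the hand-written `hessian` uses it.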

Often we are not interested in the full Hessian itself. Instead, we can use tricks such as the Hessian-vector product, which computes $Hv$ without ever materializing the matrix $H$.
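A minimal Hessian-vector product sketch, assuming the forward-over-reverse formulation: taking `jax.jvp` of the gradient function at `x` in direction `v` yields $\frac{d}{d\epsilon}\nabla f(x + \epsilon v)\big|_{\epsilon=0} = H(x)\,v$, so only gradient-sized arrays are ever built. The name `hvp` and the test function are illustrative.

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
  """Hessian-vector product H(x) @ v via forward-over-reverse, without forming H."""
  # jvp returns (primal_out, tangent_out); the tangent of grad(f) along v is H(x) v.
  return jax.jvp(jax.grad(f), (x,), (v,))[1]

def f(x):
  return jnp.sum(x ** 3)  # Hessian is diag(6 x)

x = jnp.array([1., 2., 3.])
v = jnp.array([1., 0., -1.])

print(hvp(f, x, v))            # expected H v = [6, 0, -18]
print(jax.hessian(f)(x) @ v)   # same values, but materializes the 3x3 Hessian
```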

### Stopping Gradients
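`jax.lax.stop_gradient(x)` returns `x` unchanged in the forward pass, but differentiation treats it as a constant, so no gradient flows through it. A tiny sketch (the function `g` is purely illustrative): $g(x) = x \cdot \mathrm{sg}(x)$ has the value $x^2$, but its derivative is $\mathrm{sg}(x) = x$ instead of $2x$.

```python
import jax

def g(x):
  # Forward: x * x.  Backward: the second factor is treated as a constant.
  return x * jax.lax.stop_gradient(x)

x = 3.0
print(g(x))                          # 9.0 (value is unchanged)
print(jax.grad(g)(x))                # 3.0 (d/dx of x * c with c = 3)
print(jax.grad(lambda y: y * y)(x))  # 6.0 (without stop_gradient)
```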

In [9]:
value_fn = lambda theta, state: jnp.dot(theta, state)  # assume the value estimate is a linear function of the state
theta = jnp.array([0.1, -0.1, 0.])

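Consider a transition from a state $s_{t-1}$ to a state $s_t$ during which we observed the reward $r_t$. The TD(0) update to the parameters is $\Delta\theta = (r_t + v_\theta(s_t) - v_\theta(s_{t-1}))\,\nabla_\theta v_\theta(s_{t-1})$: the target $r_t + v_\theta(s_t)$ is treated as a constant, and only $v_\theta(s_{t-1})$ is differentiated.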

In [10]:
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])

In [12]:
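def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  # stop_gradient makes the target a constant, so only v_tm1 is differentiated.
  # The -0.5 factor makes grad(td_loss) equal the TD(0) update delta * grad v(s_tm1).
  return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2)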

td_update = jax.grad(td_loss)
delta_update = td_update(theta, s_tm1, r_t, s_t)

delta_update

Array([ 1.2,  2.4, -1.2], dtype=float32)
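To see what `stop_gradient` changes, here is a sketch of the same pseudo-loss with the `stop_gradient` removed. The name `td_loss_no_stop` is hypothetical, and the block redefines `value_fn` and the example transition so that it runs on its own. With $\delta = 1.2$ for this transition, the gradient now also flows through $v_\theta(s_t)$ and becomes $\delta\,(s_{t-1} - s_t)$ rather than the TD(0) update $\delta\, s_{t-1}$.

```python
import jax
import jax.numpy as jnp

value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])

def td_loss_no_stop(theta, s_tm1, r_t, s_t):
  # Same pseudo-loss as td_loss, but the target is NOT wrapped in stop_gradient,
  # so the gradient also flows through value_fn(theta, s_t).
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  return -0.5 * ((target - v_tm1) ** 2)

grad_no_stop = jax.grad(td_loss_no_stop)(theta, s_tm1, r_t, s_t)
print(grad_no_stop)  # approximately 1.2 * (s_tm1 - s_t) = [-1.2, 1.2, -1.2]
```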

### Straight-through estimator using `stop_gradient`

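The straight-through estimator is a trick for defining a "gradient" of a function that is otherwise non-differentiable. The forward pass uses the function's exact value, while the backward pass behaves as if the function were the identity.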

In [14]:
def f(x):
  return jnp.round(x) # non-differentiable

def straight_through_f(x):
  # create an exactly-zero expression with Sterbenz lemma that has
  # an exactly-one gradient
  zero = x - jax.lax.stop_gradient(x)
  return zero + jax.lax.stop_gradient(f(x))

print("f(x): ", f(3.2))
print("straight_through_f(x)", straight_through_f(3.2))

print("grad(f)(x): ", jax.grad(f)(3.2))
print("grad(straight_through_f)(x)", jax.grad(straight_through_f)(3.2))

f(x):  3.0
straight_through_f(x) 3.0
grad(f)(x):  0.0
grad(straight_through_f)(x) 1.0
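The same construction can be packaged as a reusable wrapper. A minimal sketch, assuming a hypothetical `straight_through` decorator and a hypothetical `quantize` function that snaps values to the nearest multiple of `step`. Its forward pass returns the quantized values, and its backward pass passes gradients through unchanged, whereas the raw quantizer has zero gradient.

```python
import jax
import jax.numpy as jnp

def straight_through(fn):
  """Hypothetical wrapper: forward returns fn(x), backward acts like the identity."""
  def wrapped(x):
    zero = x - jax.lax.stop_gradient(x)         # value 0, gradient 1
    return zero + jax.lax.stop_gradient(fn(x))  # value fn(x), gradient 0
  return wrapped

def quantize(x, step=0.5):
  # Hypothetical quantizer: round to the nearest multiple of `step`.
  return jnp.round(x / step) * step

ste_quantize = straight_through(quantize)
x = jnp.array([0.2, 0.7, 1.3])

print(ste_quantize(x))                                  # quantized values 0, 0.5, 1.5
print(jax.grad(lambda x: jnp.sum(quantize(x)))(x))      # all zeros
print(jax.grad(lambda x: jnp.sum(ste_quantize(x)))(x))  # all ones
```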


### Per-example gradients

JAX makes it easy and efficient to compute per-example gradients by composing `jax.grad`, `jax.vmap`, and `jax.jit`:

In [15]:
perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))

batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
batched_r_t = jnp.stack([r_t, r_t])
batched_s_t = jnp.stack([s_t, s_t])

perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)

Array([[ 1.2,  2.4, -1.2],
       [ 1.2,  2.4, -1.2]], dtype=float32)
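The batch above repeats the same transition twice, so both rows are identical. Here is a sanity check on two *different* transitions (values chosen arbitrarily for illustration, with `mean_td_loss` a hypothetical helper). Because differentiation is linear, the mean of the per-example gradients should equal the gradient of the mean loss over the batch. The block redefines `value_fn` and `td_loss` so it runs on its own.

```python
import jax
import jax.numpy as jnp

value_fn = lambda theta, state: jnp.dot(theta, state)

def td_loss(theta, s_tm1, r_t, s_t):
  v_tm1 = value_fn(theta, s_tm1)
  target = r_t + value_fn(theta, s_t)
  return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2)

theta = jnp.array([0.1, -0.1, 0.])
batched_s_tm1 = jnp.array([[1., 2., -1.], [0., 1., 1.]])
batched_r_t = jnp.array([1., -0.5])
batched_s_t = jnp.array([[2., 1., 0.], [1., 0., 1.]])

perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))
per_example = perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)  # one row per example

def mean_td_loss(theta, s_tm1, r_t, s_t):
  # Average of the per-example losses; theta is shared across the batch.
  return jnp.mean(jax.vmap(td_loss, in_axes=(None, 0, 0, 0))(theta, s_tm1, r_t, s_t))

batch_grad = jax.grad(mean_td_loss)(theta, batched_s_tm1, batched_r_t, batched_s_t)
print(per_example.shape)                                   # (2, 3)
print(jnp.allclose(per_example.mean(axis=0), batch_grad))  # True
```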

In [16]:
dtdloss_dtheta = jax.grad(td_loss)  # gradient of the loss w.r.t. theta for a single example
dtdloss_dtheta(theta, s_tm1, r_t, s_t)

Array([ 1.2,  2.4, -1.2], dtype=float32)

In [17]:
almost_perex_grads = jax.vmap(dtdloss_dtheta)  # maps over axis 0 of every argument, so theta must be batched too

batched_theta = jnp.stack([theta, theta])
almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)

Array([[ 1.2,  2.4, -1.2],
       [ 1.2,  2.4, -1.2]], dtype=float32)

In [19]:
# But we want to share a single, unbatched theta across all examples:
# in_axes=None for theta, 0 (the batch axis) for the data arguments.
inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))

inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)

Array([[ 1.2,  2.4, -1.2],
       [ 1.2,  2.4, -1.2]], dtype=float32)

In [20]:
perex_grads = jax.jit(inefficient_perex_grads)

In [21]:
perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)

Array([[ 1.2,  2.4, -1.2],
       [ 1.2,  2.4, -1.2]], dtype=float32)

In [23]:
%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()

12.3 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
10.2 µs ± 2.09 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)

Here `perex_grads` is just `inefficient_perex_grads` wrapped in `jax.jit`. The un-jitted version dispatches each primitive operation eagerly from Python, while the jitted version runs the whole vectorized gradient computation as a single compiled XLA program. This accounts for the roughly thousand-fold difference in these timings.
