# Automatic differentiation with JAX

In [3]:
import jax
import jax.numpy as jnp
from jax import random
from jax import grad

In [2]:
def f(x):
    """Piecewise scalar function used as the running example for `grad`.

    Returns ``2 * x**3`` when ``x > 0`` and ``3 * x`` otherwise (so ``x == 0``
    falls on the linear branch). The branch is chosen by an ordinary Python
    comparison on the value of ``x``.
    """
    return 2*x**3 if x > 0 else 3 * x

In [6]:
# Seed JAX's explicit PRNG state; the later random draws in this notebook
# are all derived from this key.
key = random.PRNGKey(seed=0)

# Draw one standard-normal scalar (empty shape, default float dtype) and show it.
x = random.normal(key, shape=())
x

DeviceArray(-0.20584226, dtype=float32)

In [11]:
# First order derivative

# 2*3*x**2 if x > 0
# 3.0 if x < 0

In [7]:
# Evaluate the first derivative on both sides of zero: at the sampled x and at -x.
print(
    grad(f)(x),
    grad(f)(-x),
    sep="\n",
)

3.0
0.2542262


In [None]:
# second order derivative

# 2*3*2*x if x > 0
# 0 if x < 0

# third order derivative 
# 2*3*2 = 12.0 if x > 0
# 0 if x < 0

In [10]:
# Nesting `grad` gives higher-order derivatives; both are evaluated at -x.
print(
    grad(grad(f))(-x),        # second derivative
    grad(grad(grad(f)))(-x),  # third derivative
    sep="\n",
)

2.470107
12.0


## Other JAX autodiff highlights:

<ul>
    <li>Forward- and reverse-mode, totally composable</li>
    <li>Fast Jacobians and Hessians </li>
    <li>Complex number support (holomorphic and non-holomorphic) </li>
    <li>Jacobian pre-accumulation for elementwise operations (like gelu) </li>
</ul>

In [15]:
# `grad` also works on library functions: build the derivative of tanh
# and evaluate it at 2.0.
grad_tanh = jax.grad(jnp.tanh)
print(grad_tanh(2.0))

0.070650935


In [16]:
# Second and third derivatives of tanh at 2.0, obtained by nesting `grad`.
print(
    grad(grad(jnp.tanh))(2.0),        # second derivative
    grad(grad(grad(jnp.tanh)))(2.0),  # third derivative
    sep="\n",
)

-0.13621888
0.2526544


In [17]:
# Logistic Regression model 
def sigmoid(x):
    """Logistic function, written via the identity sigmoid(x) = (tanh(x / 2) + 1) / 2."""
    shifted_tanh = jnp.tanh(x / 2) + 1
    return 0.5 * shifted_tanh

def predict(W, b, inputs):
    """Return the predicted probability that each example's label is true.

    Computes the linear score ``inputs . W + b`` and passes it through `sigmoid`.
    """
    logits = jnp.dot(inputs, W) + b
    return sigmoid(logits)


# Toy dataset: four examples with three features each,
# and one boolean label per example.
inputs = jnp.asarray([
    [0.52,  1.12,  0.77],
    [0.88, -1.08,  0.15],
    [0.52,  0.06, -1.30],
    [0.74, -2.49,  1.39],
])
targets = jnp.asarray([True, True, False, True])



def loss(W, b):
    """Training loss: negative log-likelihood of the toy dataset under the model.

    Uses the notebook-level `inputs` and `targets`. For each example the
    probability assigned to its actual label is the prediction when the target
    is true and one minus the prediction when it is false.
    """
    p_true = predict(W, b, inputs)
    likelihood_if_true = p_true * targets
    likelihood_if_false = (1 - p_true) * (1 - targets)
    return -jnp.sum(jnp.log(likelihood_if_true + likelihood_if_false))


# Split the PRNG key three ways: a replacement for `key` plus one subkey each
# for the weights and the bias, then draw standard-normal initial values.
key, W_key, b_key = random.split(key, num=3)
W = random.normal(W_key, shape=(3,))
b = random.normal(b_key, shape=())

## Use the `grad` function with its `argnums` argument to differentiate a function with respect to positional arguments.

In [18]:
# Differentiate `loss` with respect to its first positional argument (W):
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)

# argnums=0 is the default, so omitting it gives the same gradient:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)

# argnums can select a different argument and may be passed positionally;
# here 1 selects the second argument (b):
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)

# A tuple of indices returns a tuple of gradients, one per selected argument:
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)

W_grad [-0.16965583 -0.8774647  -1.4901344 ]
W_grad [-0.16965583 -0.8774647  -1.4901344 ]
b_grad -0.29227248
W_grad [-0.16965583 -0.8774647  -1.4901344 ]
b_grad -0.29227248


In [19]:
def loss2(params_dict):
    """Negative log-likelihood of the toy dataset, with parameters packed in a dict.

    `params_dict` must contain the entries 'W' and 'b'; the computation is the
    same as in `loss`.
    """
    weights, bias = params_dict['W'], params_dict['b']
    p_true = predict(weights, bias, inputs)
    label_probs = p_true * targets + (1 - p_true) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

# Differentiating with respect to a dict argument yields a dict of gradients
# with the same keys.
print(grad(loss2)(dict(W=W, b=b)))

{'W': DeviceArray([-0.16965583, -0.8774647 , -1.4901344 ], dtype=float32), 'b': DeviceArray(-0.29227248, dtype=float32)}


# Hessian-vector products with grad-of-grad


In [22]:
from IPython.display import Image
# Display the image at the relative path images/hessianvector.PNG at 800x400
# (presumably the Hessian-vector product illustration for this section).
Image(url="images/hessianvector.PNG", width=800, height=400)


In [23]:
def hvp(f, x, v):
    """Hessian-vector product of `f` at `x` with the vector `v`, via grad-of-grad.

    Differentiates the scalar map ``x -> vdot(grad(f)(x), v)`` with respect to
    ``x``; the full Hessian matrix is never built explicitly.
    """
    def directional_derivative(point):
        return jnp.vdot(grad(f)(point), v)

    return grad(directional_derivative)(x)

In [24]:
from jax import jacfwd, jacrev

# Predictions as a function of the weights only; `b` and `inputs` come from
# the notebook namespace. Note that this rebinds the name `f`, replacing the
# scalar example function defined near the top of the notebook.
def f(W):
    return predict(W, b, inputs)

# Jacobian of the predictions with respect to W, forward mode.
J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)

# The same Jacobian computed in reverse mode.
J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)

jacfwd result, with shape (4, 3)
[[ 0.05981982  0.12885131  0.08857825]
 [ 0.04015379 -0.04928978  0.00684642]
 [ 0.12188746  0.01406284 -0.30466145]
 [ 0.00140473 -0.00472568  0.00263711]]
jacrev result, with shape (4, 3)
[[ 0.05981756  0.12883782  0.088576  ]
 [ 0.04015914 -0.04928622  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140429 -0.00472523  0.00263778]]


In [25]:
def hessian(f):
    """Return a function that evaluates the Hessian of `f`.

    Built as a forward-mode Jacobian (`jacfwd`) of the reverse-mode
    Jacobian (`jacrev`) of `f`.
    """
    first_derivative = jacrev(f)
    return jacfwd(first_derivative)

# `f` at this point is the vector-valued prediction function (one output per
# training example), so the result holds one 3x3 matrix of second derivatives
# with respect to W for each of the four examples.
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)

hessian, with shape (4, 3, 3)
[[[ 0.02285551  0.04923055  0.03384332]
  [ 0.04922726  0.10603505  0.07289331]
  [ 0.03384374  0.07289909  0.05011415]]

 [[-0.03194788  0.03921682 -0.00544727]
  [ 0.03920877 -0.04812973  0.00668529]
  [-0.00544566  0.00668469 -0.00092851]]

 [[-0.01583768 -0.00182728  0.03958676]
  [-0.00182742 -0.00021084  0.0045677 ]
  [ 0.03959419  0.0045682  -0.09896691]]

 [[-0.00103555  0.0034837  -0.00194404]
  [ 0.00348448 -0.01172219  0.00654144]
  [-0.00194515  0.00654371 -0.00365165]]]


# JVPs in JAX code - forward mode autodiff




In [26]:
from jax import jvp

# Predictions as a function of the weights alone (`b` and `inputs` are taken
# from the notebook namespace).
def f(W):
    return predict(W, b, inputs)

# Fresh subkey for a random tangent vector with the same shape as W.
key, subkey = random.split(key)
v = random.normal(subkey, shape=W.shape)

# Forward mode: `y` is f(W) and `u` is the push-forward of `v` along `f` at `W`.
y, u = jvp(f, (W,), (v,))

In [32]:
# Show the primal output `y` and the tangent output `u` returned by `jvp`.
# NOTE(review): the execution counter shows this cell ran (In [32]) after the
# VJP cell (In [30]), which rebinds `u` to a random covector, so the saved
# output may not be the JVP result. Restart and run all cells to confirm.
y, u

(DeviceArray([0.13262251, 0.952067  , 0.6249393 , 0.99809873], dtype=float32),
 DeviceArray([-0.6668207 ,  0.36615133,  0.35238004, -0.10669762], dtype=float32))

# VJPs in JAX code - reverse mode autodiff

In [30]:
from jax import vjp

# Predictions as a function of the weights alone (`b` and `inputs` are taken
# from the notebook namespace).
def f(W):
    return predict(W, b, inputs)

# Reverse mode: `y` is f(W) and `vjp_fun` maps output-space covectors back to
# the space of W.
y, vjp_fun = vjp(f, W)

# Random covector with the same shape as the output. This rebinds `u` (and,
# below, `v`), which the JVP example also uses.
key, subkey = random.split(key)
u = random.normal(subkey, shape=y.shape)

# Pull `u` back along `f` at `W`.
v = vjp_fun(u)

In [31]:
# Show the pulled-back covector. `vjp_fun` returns a tuple with one entry per
# primal input of `f`, so here it is a 1-tuple holding the result for W.
v

(DeviceArray([ 0.01761599, -0.09849812, -0.16421208], dtype=float32),)


Refer here for more: 
https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
https://github.com/google/jax/blob/main/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb