#### Auto differentiation
1. Automatic differentiation (autodiff) is used to compute *gradients*, *Jacobians* and *Hessians*. Derivatives can also be obtained analytically (by hand or symbolically) or numerically with finite differences, but this tutorial covers only automatic differentiation. Comparisons of these approaches are easy to find online and are not repeated here.
2. Let's start with a simple example: taking the *gradient* of a scalar-valued function and inspecting what happens in the background.

In [37]:
import time
import numpy as np

import jax
import jax.numpy as jnp
from jax import random
from jax import make_jaxpr
from jax import vmap, pmap, jit
from jax import grad, value_and_grad
from jax.test_util import check_grads

def product(x, y):
    z = x * y
    return z


x = 3.0
y = 4.0

z = product(x, y)

print(f"Input Variable x: {x}")
print(f"Input Variable y: {y}")
print(f"Product z: {z}\n")

# dz / dx
dx = grad(product, argnums=0)(x, y) # diff wrt first arg, default of argnums is 0
print(f"Gradient of z wrt x: {dx}")

# dz / dy
dy = grad(product, argnums=1)(x, y) # diff wrt second arg
print(f"Gradient of z wrt y: {dy}")

# Both partial derivatives at once: argnums=(0, 1) returns the tuple (dz/dx, dz/dy)
p = grad(product, argnums=(0,1))(x,y)
print(f"Gradient of z wrt x: {p[0]}, gradient of z wrt y: {p[1]}")

print("Differentiating wrt x")
print(make_jaxpr(grad(product, argnums=0))(x,y))

print("Differentiating wrt y")
print(make_jaxpr(grad(product, argnums=1))(x,y))

print("Differentiating wrt x, y")
print(make_jaxpr(grad(product, argnums=(0,1)))(x,y))

Input Variable x: 3.0
Input Variable y: 4.0
Product z: 12.0

Gradient of z wrt x: 4.0
Gradient of z wrt y: 3.0
Gradient of z wrt x: 4.0, gradient of z wrt y: 3.0
Differentiating wrt x
{ lambda ; a:f32[] b:f32[]. let _:f32[] = mul a b; c:f32[] = mul 1.0 b in (c,) }
Differentiating wrt y
{ lambda ; a:f32[] b:f32[]. let _:f32[] = mul a b; c:f32[] = mul a 1.0 in (c,) }
Differentiating wrt x, y
{ lambda ; a:f32[] b:f32[]. let
    _:f32[] = mul a b
    c:f32[] = mul a 1.0
    d:f32[] = mul 1.0 b
  in (d, c) }


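3. Notice that in each jaxpr the argument we differentiate with respect to is replaced by the constant 1.0. This is the seed value dz/dz = 1 that reverse-mode differentiation starts from; multiplying it by the other factor gives the partial derivative, e.g. dz/dx = 1.0 * y.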
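
The constant 1.0 in the jaxprs above can be made explicit with `jax.vjp`, which returns the function value together with a function that maps an output seed (cotangent) to input gradients. The following is a minimal sketch, assuming the same `product` function as above; the names `vjp_fn`, `dx`, `dy`, `dx2` and `dy2` are illustrative only.

```python
import jax


def product(x, y):
    return x * y


x, y = 3.0, 4.0

# jax.vjp evaluates the function and returns a "pullback" function
z, vjp_fn = jax.vjp(product, x, y)

# Seeding the pullback with 1.0 reproduces what grad does for a scalar output
dx, dy = vjp_fn(1.0)

# A different seed simply scales both partial derivatives
dx2, dy2 = vjp_fn(2.0)

# grad with argnums=(0, 1) should agree with the seed-1.0 result
gx, gy = jax.grad(product, argnums=(0, 1))(x, y)

print(f"z = {z}, dz/dx = {dx}, dz/dy = {dy}")
print(f"seed 2.0 -> {dx2}, {dy2}")
```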
4. Using **`vmap`**, we can compute the gradients of a scalar-valued function at many points in a single batched call.

In [38]:
def activate(x):
    """Applies tanh activation."""
    return jnp.tanh(x)

key = random.PRNGKey(1234)
x = random.normal(key=key, shape=(5,))
activations = activate(x)

grads_batch = vmap(grad(activate))(x)
print("Gradients for the batch: ", grads_batch)

print("Jaxpr:\n", make_jaxpr(vmap(grad(activate)))(x))

Gradients for the batch:  [0.482287   0.45585027 0.99329686 0.09532695 0.8153717 ]
Jaxpr:
 { lambda ; a:f32[5]. let
    b:f32[5] = tanh a
    c:f32[5] = sub 1.0 b
    d:f32[5] = mul 1.0 c
    e:f32[5] = mul d b
    f:f32[5] = add_any d e
  in (f,) }
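
As a rough illustration of why `vmap` is needed here: `grad` only accepts functions with a scalar output, so applying it directly to a batch fails, while `vmap(grad(...))` differentiates each element separately. The sketch below also compares the result with the known derivative of tanh, 1 - tanh(x)^2. The seed `0` and the names `direct_grad_failed`, `per_example`, `analytic` and `via_sum` are assumptions made for this example.

```python
import jax.numpy as jnp
from jax import grad, random, vmap


def activate(x):
    """Applies tanh activation."""
    return jnp.tanh(x)


key = random.PRNGKey(0)  # arbitrary seed chosen for this sketch
x = random.normal(key, shape=(5,))

# grad on its own rejects the batch because the output is not a scalar
direct_grad_failed = False
try:
    grad(activate)(x)
except TypeError:
    direct_grad_failed = True

# vmap maps the scalar gradient function over the leading axis of x
per_example = vmap(grad(activate))(x)

# Known closed-form derivative of tanh, used as a reference
analytic = 1.0 - jnp.tanh(x) ** 2

# Alternative for elementwise functions: differentiate the sum of the outputs
via_sum = grad(lambda v: jnp.sum(activate(v)))(x)

print("grad alone failed:", direct_grad_failed)
print("vmap(grad) matches analytic:", bool(jnp.allclose(per_example, analytic, atol=1e-5)))
```

The sum-based variant works only because `tanh` acts on each element independently; for functions that mix elements of the batch, the two approaches compute different things.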


5. Again, we can add further transformations, such as **`jit`**.

In [39]:
jitted_grads_batch = jit(vmap(grad(activate)))

# The first call includes tracing and compilation; later calls reuse the compiled function
for _ in range(3):
    start_time = time.time()
    print("Gradients for the batch: ", jitted_grads_batch(x))
    print(f"Time taken: {time.time() - start_time:.4f} seconds")
    print("="*50)

Gradients for the batch:  [0.482287   0.45585027 0.99329686 0.09532695 0.8153717 ]
Time taken: 0.0219 seconds
==================================================
Gradients for the batch:  [0.482287   0.45585027 0.99329686 0.09532695 0.8153717 ]
Time taken: 0.0007 seconds
==================================================
Gradients for the batch:  [0.482287   0.45585027 0.99329686 0.09532695 0.8153717 ]
Time taken: 0.0003 seconds
==================================================
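
The timings above include the cost of printing, which also forces the result to be computed. When timing a jitted function without printing, it is usually safer to call `block_until_ready()` on the result, because JAX dispatches work asynchronously. The following is a minimal sketch under that assumption; the helper `timed_call` and the input size of 1000 are hypothetical choices made for this example.

```python
import time

import jax.numpy as jnp
from jax import grad, jit, vmap


def activate(x):
    """Applies tanh activation."""
    return jnp.tanh(x)


batched_grad = jit(vmap(grad(activate)))
x = jnp.linspace(-2.0, 2.0, 1000)


def timed_call(fn, *args):
    """Runs fn and waits for the device result before stopping the clock."""
    start = time.perf_counter()
    out = fn(*args).block_until_ready()
    return out, time.perf_counter() - start


# First call: tracing and compilation are included in the measurement
out, first_call = timed_call(batched_grad, x)
# Second call: the compiled function is reused
out, second_call = timed_call(batched_grad, x)

print(f"First call:  {first_call:.4f} seconds")
print(f"Second call: {second_call:.4f} seconds")
```

Exact timings depend on the hardware and backend, so no specific numbers are claimed here.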


6. JAX also provides a finite-difference utility, **`check_grads`**, that developers can use to verify gradient computations numerically.

In [40]:
try:
    # check_grads raises an error if the derivatives disagree with finite differences
    check_grads(jitted_grads_batch, (x,), order=1)
    print("Gradients match those calculated using finite differences")
except Exception as ex:
    print(type(ex).__name__, ex)

Gradients match those calculated using finite differences
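
To make concrete what a finite-difference check compares, here is a minimal hand-rolled sketch. It is not how `check_grads` is implemented internally; the helper `central_difference`, the step size `eps`, the evaluation point `x0` and the tolerance are all illustrative assumptions.

```python
import jax.numpy as jnp
from jax import grad


def activate(x):
    """Applies tanh activation."""
    return jnp.tanh(x)


def central_difference(fn, x, eps=1e-3):
    """Approximates d fn / dx at x with a symmetric difference quotient."""
    return (fn(x + eps) - fn(x - eps)) / (2.0 * eps)


x0 = 0.5

autodiff_grad = grad(activate)(x0)
numeric_grad = central_difference(activate, x0)

# A loose tolerance allows for float32 rounding and truncation error
matches = bool(jnp.allclose(autodiff_grad, numeric_grad, atol=1e-2))

print(f"Autodiff gradient: {autodiff_grad}")
print(f"Finite-difference estimate: {numeric_grad}")
print(f"Match within tolerance: {matches}")
```

The step size trades truncation error against rounding error, which is one reason a library utility with tuned defaults is preferable to a hand-written check in practice.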
