# Autodiff

For more, see the [JAX Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html).

In [1]:
from jax import grad
import jax.numpy as jnp

## grad

In [6]:
def operation(x):
    """Return twice the input, i.e. f(x) = 2x."""
    doubled = 2 * x
    return doubled


# `grad` wraps `operation` in a new function that evaluates df/dx at a point.
grad_operation = grad(operation)

In [8]:
# d/dx (2x) = 2 at every point, so the printed value does not depend on the input.
print(grad_operation(100.))



2.0


Let's check the result against a derivative we can compute by hand:

$f(x)=x^2+5x; f'(x) = 2x+5$

In [29]:
def operation(x):
    """Evaluate f(x) = x^2 + 5x, whose analytic derivative is f'(x) = 2x + 5."""
    quadratic_term = jnp.power(x, 2)
    linear_term = 5 * x
    return quadratic_term + linear_term


grad_operation = grad(operation)

# Expected from the closed form: f'(100) = 205 and f'(6) = 17.
for point in (100., 6.):
    print(grad_operation(point))

205.0
17.0


**The output of the function must be a scalar — that is the value we backpropagate from — not a vector or matrix.**

In [12]:
# A vector input makes `operation` return a vector, which `grad` rejects.
# The exception is caught on purpose so that its message is the cell output.
try:
    vector_input = jnp.array([10., 20., 30.])
    vector_grad = grad_operation(vector_input)
    print(vector_grad)
except Exception as err:
    print(err)

Gradient only defined for scalar-output functions. Output had shape: (3,).


The input, however, can have a richer shape (for example a vector).

In [36]:
def relu(x):
    """Element-wise rectified linear unit: negative entries of `x` become zero."""
    floor = jnp.zeros(x.shape)
    return jnp.maximum(x, floor)

def operation(w):
    """Multiply a fixed 5-vector by `w` four times, applying ReLU after each
    of the first three products, and return the sum of the result (a scalar)."""
    x = jnp.array([100.,50.,30.,20.,20])
    for _ in range(3):
        x = relu(x * w)
    return jnp.sum(x * w)


grad_operation = grad(operation)
# The output is still a scalar, so `grad` works; the gradient has one entry per entry of `w`.
print(grad_operation(jnp.array([-1,1,0,-.5,.6])))

[  0.       200.         0.         0.        17.280003]


## Weights as list of things

In [37]:
import jax

In [41]:
# Explicit PRNG state: jax.random functions take the key as an argument
# instead of relying on a hidden global seed.
key = jax.random.PRNGKey(42)

In [49]:
def relu(x):
    """Clamp every entry of `x` from below at zero (rectified linear unit)."""
    zeros = jnp.zeros(x.shape)
    return jnp.maximum(x, zeros)

# Weight matrices of a 28 -> 16 -> 16 -> 10 network.
# Each matrix is drawn from its own subkey: reusing one key for several draws
# produces correlated values rather than independent samples.
layer_shapes = [(28, 16), (16, 16), (16, 10)]
layer_keys = jax.random.split(key, len(layer_shapes))
layers = [
    jax.random.normal(layer_key, shape)
    for layer_key, shape in zip(layer_keys, layer_shapes)
]

def operation(w):
    """Scalar loss of a small ReLU network whose layer matrices are given by `w`.

    Args:
        w: sequence of weight matrices applied in order; the first must accept
            28 input features and consecutive shapes must chain.

    Returns:
        A scalar: the batch mean of (row sum of the final activations - 1).
    """
    # Fixed batch of 8 inputs; the same key gives the same batch on every call.
    x = jax.random.normal(key, (8, 28))
    # Read the weights from the argument `w`, not from the global `layers`:
    # `grad` differentiates with respect to the argument, so a body that ignores
    # it yields all-zero gradients.
    for layer_weights in w:
        x = relu(x @ layer_weights)
    return (x.sum(-1) - jnp.ones(x.shape[0])).mean()


grad_operation = grad(operation)
# The gradient mirrors the structure of the input: one array per layer matrix.
g1, g2, g3 = grad_operation(layers)

## Weights as dictionary

In [57]:
# Same 28 -> 16 -> 16 -> 10 weights, stored under the names "layer_0", "layer_1", ...
# Each matrix is drawn from its own subkey: reusing one key for several draws
# produces correlated values rather than independent samples.
layer_shapes = [(28, 16), (16, 16), (16, 10)]
layer_keys = jax.random.split(key, len(layer_shapes))
layers = {
    f"layer_{i}": jax.random.normal(layer_key, shape)
    for i, (layer_key, shape) in enumerate(zip(layer_keys, layer_shapes))
}

def operation(w):
    """Scalar loss of a small ReLU network whose layer matrices are stored in a dict.

    Args:
        w: dict mapping "layer_0", "layer_1", ... to weight matrices, applied in
            index order; the first must accept 28 input features.

    Returns:
        A scalar: the batch mean of (row sum of the final activations - 1).
    """
    # Fixed batch of 8 inputs; the same key gives the same batch on every call.
    x = jax.random.normal(key, (8, 28))
    # Read the weights from the argument `w`, not from the global `layers`:
    # `grad` differentiates with respect to the argument, so a body that ignores
    # it yields all-zero gradients.
    for i in range(len(w)):
        x = relu(x @ w[f"layer_{i}"])
    return (x.sum(-1) - jnp.ones(x.shape[0])).mean()


grad_operation = grad(operation)
# The gradient comes back as a dict with the same keys as the input.
grad_dict = grad_operation(layers)

In [58]:
# Each gradient entry has the same shape as the weight matrix stored under that key.
for k, w in grad_dict.items():
    print(f"{k}:\t{w.shape}")

layer_0:	(28, 16)
layer_1:	(16, 16)
layer_2:	(16, 10)


## Jacobian

You can view the difference between [Gradient, Jacobian, Hessian, Laplacian](https://najeebkhan.github.io/blog/VecCal.html)

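Derivative, gradient and Jacobian differ in the dimensions of the input and output: a derivative maps a scalar to a scalar, a gradient belongs to a function from a vector to a scalar, and a Jacobian belongs to a function from a vector to a vector.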

In [62]:
# n is the input dimension and m the output dimension of the map x -> x @ W.
n, m = 5, 10

W = jax.random.normal(key, (n, m))

In [78]:
def matmul(x):
    """Apply the linear map x -> x @ W, using the module-level matrix `W`."""
    return jnp.matmul(x, W)


# Forward-mode Jacobian of `matmul`, evaluated at a random 5-vector.
jacobian_fwd = jax.jacfwd(matmul)
J = jacobian_fwd(jax.random.normal(key, (5,)))

`J` has shape `(m, n)`: one row per output component and one column per input component.

In [81]:
# Show the shape next to the values; for x @ W the Jacobian with respect to x is W.T.
J.shape, J

((10, 5),
 DeviceArray([[ 2.3575046 , -0.13301466, -0.44822735,  1.226089  ,
               -0.9021458 ],
              [-0.0676787 , -1.8649583 , -0.54558253, -0.58878934,
                1.5137577 ],
              [-0.36718026,  0.53279227, -0.47208768,  0.20114596,
               -0.06903279],
              [-0.03573515,  1.4438008 , -0.6238551 ,  0.39448676,
               -0.760346  ],
              [-0.30748853, -0.95318866,  0.07758974, -0.4147216 ,
                0.9220106 ],
              [ 1.1332345 ,  1.2754252 , -1.1710014 , -0.15716387,
               -1.2646029 ],
              [-0.5726349 , -0.35616288,  0.5531245 , -0.76754636,
                3.132822  ],
              [-0.44970438, -1.4492625 , -0.3600667 ,  0.5882494 ,
               -0.53693825],
              [-0.35575888,  0.5369553 ,  0.07277003,  0.43019757,
                1.9336216 ],
              [ 0.25047532, -0.3869919 , -0.8419892 , -0.99975586,
               -0.28862053]], dtype=float32))

In [84]:
# Reverse-mode Jacobian of the same function at the same point; it matches the jacfwd result.
jacobian_rev = jax.jacrev(matmul)
J = jacobian_rev(jax.random.normal(key, (5,)))
J.shape, J

((10, 5),
 DeviceArray([[ 2.3575046 , -0.13301466, -0.44822735,  1.226089  ,
               -0.9021458 ],
              [-0.0676787 , -1.8649583 , -0.54558253, -0.58878934,
                1.5137577 ],
              [-0.36718026,  0.53279227, -0.47208768,  0.20114596,
               -0.06903279],
              [-0.03573515,  1.4438008 , -0.6238551 ,  0.39448676,
               -0.760346  ],
              [-0.30748853, -0.95318866,  0.07758974, -0.4147216 ,
                0.9220106 ],
              [ 1.1332345 ,  1.2754252 , -1.1710014 , -0.15716387,
               -1.2646029 ],
              [-0.5726349 , -0.35616288,  0.5531245 , -0.76754636,
                3.132822  ],
              [-0.44970438, -1.4492625 , -0.3600667 ,  0.5882494 ,
               -0.53693825],
              [-0.35575888,  0.5369553 ,  0.07277003,  0.43019757,
                1.9336216 ],
              [ 0.25047532, -0.3869919 , -0.8419892 , -0.99975586,
               -0.28862053]], dtype=float32))

### Hessian matrix

The Hessian has shape `(m, n, n)`.

In [87]:
# Second derivative: forward-mode Jacobian on the inside, reverse-mode on the outside.
hessian_fn = jax.jacrev(jax.jacfwd(matmul))
H = hessian_fn(jax.random.normal(key, (5,)))
H.shape

(10, 5, 5)