In [1]:
from jax import numpy as jnp

## A Simple Function

In [2]:
# Float input on purpose: jax.grad only differentiates w.r.t. inexact (float) dtypes.
x = jnp.arange(4, dtype=float)
x

Array([0., 1., 2., 3.], dtype=float32)

In [3]:
def y(x):
    """Return ``2 * dot(x, x)`` (twice the squared L2 norm for a 1-D ``x``)."""
    return 2 * jnp.dot(x, x)

In [4]:
# Forward pass: 2 * (0**2 + 1**2 + 2**2 + 3**2) = 28 for x = [0, 1, 2, 3].
y(x)

Array(28., dtype=float32)

In [5]:
from jax import grad

In [6]:
# grad(y) builds a new function computing dy/dx; argnums=0 (the default) selects x.
x_grad = grad(y, argnums=0)(x)

In [7]:
# Gradient of y(x) = 2 * dot(x, x); analytically this is 4 * x.
x_grad

Array([ 0.,  4.,  8., 12.], dtype=float32)

In [8]:
# Analytic check: d/dx [2 * dot(x, x)] = 4x. Values here are small integers,
# so exact elementwise equality is safe in float32.
x_grad == x * 4

Array([ True,  True,  True,  True], dtype=bool)

In [9]:
def y(x):
    """Return the sum of all elements of ``x``; its gradient is all ones."""
    return x.sum()

grad(y)(x)

Array([1., 1., 1., 1.], dtype=float32)

## Backward for Non-Scalar Variables

In [10]:
def y(x):
    """Return the elementwise square of ``x`` (a non-scalar output)."""
    return x * x

In [11]:
# grad requires a scalar-valued function, so reduce y(x) with a sum first.
# This is equivalent to back-propagating a vector of ones through y.
grad(lambda v: jnp.sum(y(v)))(x)

Array([0., 2., 4., 6.], dtype=float32)

## Detaching Computation

In [12]:
import jax

In [13]:
# Restates the elementwise square so this section does not rely on an earlier cell.
def y(x):
    """Return the elementwise square of ``x``."""
    return x * x

In [15]:
def z(x):
    """Return ``y(x) * x`` with ``y(x)`` detached from the autodiff graph.

    ``stop_gradient`` must be applied *inside* the function being
    differentiated: only there is ``y(x)`` a traced value whose dependence on
    ``x`` can be cut. Computing it beforehand on the concrete ``x`` turns it
    into a constant closed over by ``z``, making ``stop_gradient`` a no-op.
    That version also stays stale if ``z`` is evaluated at a different point.
    """
    u = jax.lax.stop_gradient(y(x))
    return u * x

# With u treated as a constant, d/dx sum(u * x) = u = y(x). Without
# stop_gradient (and y(x) = x * x) it would be d/dx sum(x**3) = 3 * x**2.
grad(lambda x: z(x).sum())(x) == y(x)

Array([ True,  True,  True,  True], dtype=bool)

## Gradients and Python Control Flow

In [16]:
def f(a):
    """Scale ``a`` by a data-dependent factor chosen with plain Python control flow.

    ``a`` is doubled once, then doubled again while the L2 norm of the result
    stays below 1000. If the resulting values sum to a positive number they
    are returned unchanged; otherwise they are multiplied by 100. The result
    is therefore ``k * a`` for a factor ``k`` decided at run time.

    Note: if ``a`` is all zeros the norm never grows and the loop never ends.
    """
    scaled = 2 * a
    while jnp.linalg.norm(scaled) < 1000:
        scaled = 2 * scaled
    return scaled if scaled.sum() > 0 else 100 * scaled

In [17]:
from jax import random

In [18]:
# A random scalar input; the fixed key makes every run draw the same value.
a = random.normal(key=random.PRNGKey(1), shape=())

In [19]:
d = f(a)
d_grad = grad(f)(a)

# f(a) is k * a for a factor k fixed by the control flow, so df/da == d / a.
# Compare with a tolerance rather than `==`: on the `100 * b` branch the
# float32 product and the division each round, so exact equality is fragile.
jnp.isclose(d_grad, d / a)

Array(True, dtype=bool)