# JAX differentiation

In [22]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

## Scalar function

In [23]:
def simplefun(x, w):
    """Evaluate ``(x + w) * x**2``.

    Only the ``+``, ``*`` and ``**`` operators are used, so the function
    accepts plain Python numbers as well as arrays; for array inputs the
    operators apply elementwise.

    Args:
        x: Value that is both shifted by ``w`` and squared.
        w: Offset added to ``x`` before the multiplication.

    Returns:
        The product of ``x + w`` and ``x**2``.
    """
    shifted = x + w
    squared = x**2
    return shifted * squared

In [8]:
# Gradient of simplefun at a single scalar point.
x = jnp.array(3.0)
w = jnp.array(1.0)
# argnums=(0, 1) asks for the partial derivatives w.r.t. x and w in one call;
# the result is a tuple in the same order.
x_grad, w_grad = grad(simplefun, argnums=(0, 1))(x, w)
print(x_grad)
print(w_grad)

33.0
9.0


## Vector function

From the [documentation](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html): These two functions (`jacfwd`, `jacrev`) compute the same values (up to machine numerics), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while jacrev uses reverse-mode, which is more efficient for “wide” Jacobian matrices.

In [13]:
from jax import jacfwd, jacrev

x = jnp.array([3.0, 1.0, 2.0])
w = jnp.array([1.0, -1.0, 2.2])

# Reverse-mode Jacobians of simplefun w.r.t. each argument.
# simplefun acts elementwise here, so both Jacobians come out diagonal.
x_grad_rev = jacrev(simplefun, argnums=0)(x, w)
w_grad_rev = jacrev(simplefun, argnums=1)(x, w)
print("jacrev result:")
print(x_grad_rev)
print(w_grad_rev)

# Forward-mode Jacobians of the same function at the same point.
print("\njacfwd result:")
x_grad = jacfwd(simplefun, argnums=0)(x, w)
w_grad = jacfwd(simplefun, argnums=1)(x, w)
print(x_grad)
print(w_grad)

# The two modes should agree up to floating-point error; assert it rather
# than relying on a visual comparison of the printed matrices.
assert jnp.allclose(x_grad, x_grad_rev), "jacfwd and jacrev disagree on d/dx"
assert jnp.allclose(w_grad, w_grad_rev), "jacfwd and jacrev disagree on d/dw"

jacrev result:
[[33.   0.   0. ]
 [ 0.   1.   0. ]
 [ 0.   0.  20.8]]
[[9. 0. 0.]
 [0. 1. 0.]
 [0. 0. 4.]]

jacfwd result:
[[33.   0.   0. ]
 [ 0.   1.   0. ]
 [ 0.   0.  20.8]]
[[9. 0. 0.]
 [0. 1. 0.]
 [0. 0. 4.]]


## How Far Can This Go?

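The answer is: as far as you'd like! `jacrev` and `jacfwd` implement the full tensor-valued Jacobian: the result has the shape of the output followed by the shape of the input, so a matrix-valued function of a matrix input yields a rank-4 array. I am not going to work through the math by hand, because I don't know it.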

In [19]:
# 3x3 input; the length-3 w is broadcast against the last axis of x.
x = jnp.array([
    [1.0, 4.0, 3.0],
    [12.0, 42.0, -1.0],
    [-0.5, -2.0, 3.0],
])
w = jnp.array([1.0, -1.0, 2.2])

# Build the Jacobian function first, then evaluate it at (x, w).
jacobian_wrt_x = jacrev(simplefun, argnums=0)
x_grad = jacobian_wrt_x(x, w)
print(x_grad)
print(x_grad.shape)

[[[[ 5.0000000e+00  0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]

  [[ 0.0000000e+00  4.0000000e+01  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]

  [[ 0.0000000e+00  0.0000000e+00  4.0199997e+01]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]]


 [[[ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
   [ 4.5600000e+02  0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]

  [[ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  5.2080000e+03  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]

  [[ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00 -1.4000001e+00]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]]


 [[[ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00 

In [22]:
# Rank-3 input: two stacked 3x3 matrices. The Jacobian w.r.t. x is then a
# rank-6 array (output axes followed by input axes).
x = jnp.array([
    [
        [1.0, 4.0, 3.0],
        [12.0, 42.0, -1.0],
        [-0.5, -2.0, 3.0],
    ],
    [
        [5.0, 4.0, 3.0],
        [19.0, 32.0, -1.0],
        [-0.5, 1.3, 3.0],
    ],
])
w = jnp.array([1.0, -1.0, 2.2])

jacobian_wrt_x = jacrev(simplefun, argnums=0)
x_grad = jacobian_wrt_x(x, w)
print(x_grad)
print(x_grad.shape)

[[[[[[ 5.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]

    [[ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]]


   [[[ 0.0000000e+00  4.0000000e+01  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]

    [[ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]]


   [[[ 0.0000000e+00  0.0000000e+00  4.0199997e+01]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]

    [[ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]]]



  [[[[ 0.0000000e+00  0.0000000e+00  0.00000

# JAX Jacobian for Operations

This is so simple and cool!

In [23]:
def multiply(a, b):
    """Return ``jnp.dot(a, b)``.

    Args:
        a: Left operand passed to ``jnp.dot``.
        b: Right operand passed to ``jnp.dot``.

    Returns:
        The dot product of ``a`` and ``b`` as computed by ``jnp.dot``.
    """
    product = jnp.dot(a, b)
    return product

In [27]:
# grad() returns a new callable wrapping multiply; displaying it shows a
# function object rather than a numeric gradient.
grad(multiply)

<function __main__.multiply(a, b)>

In [30]:
# With the default argnums=0 the gradient is taken w.r.t. the first operand,
# so for a dot product the result equals the second operand.
lhs = jnp.array([1.0, 2.0])
rhs = jnp.array([3.0, 4.0])
grad(multiply)(lhs, rhs)

DeviceArray([3., 4.], dtype=float32)

In [31]:
import jax
# Display grad's signature; note that argnums defaults to 0, i.e. the first
# positional argument is the one differentiated unless told otherwise.
jax.grad

<function jax._src.api.grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[Any] = ()) -> Callable>

# Some other tests

In [40]:
from jax import jacfwd, jacrev

x = jnp.array([3.0, 4.0, 5.0])
w = jnp.array([1.0])  # length-1 array, broadcast against x inside simplefun

# Run the identical computation with reverse mode first, then forward mode.
for banner, jacobian in (("jacrev result:", jacrev), ("\njacfwd result:", jacfwd)):
    x_grad = jacobian(simplefun, argnums=0)(x, w)
    w_grad = jacobian(simplefun, argnums=1)(x, w)
    print(banner)
    print(x_grad)
    print(w_grad)

jacrev result:
[[33.  0.  0.]
 [ 0. 56.  0.]
 [ 0.  0. 85.]]
[[ 9.]
 [16.]
 [25.]]

jacfwd result:
[[33.  0.  0.]
 [ 0. 56.  0.]
 [ 0.  0. 85.]]
[[ 9.]
 [16.]
 [25.]]


In [33]:
# Forward evaluation for reference: (x + w) * x**2 at the current x and w.
simplefun(x,w)

DeviceArray([ 36.,  80., 150.], dtype=float32)

In [34]:
# First factor of simplefun.
x+w

DeviceArray([4., 5., 6.], dtype=float32)

In [35]:
# Second factor of simplefun; it is also the derivative of (x + w) * x**2
# with respect to w.
x**2

DeviceArray([ 9., 16., 25.], dtype=float32)

In [38]:
from jax import jacfwd, jacrev

x = jnp.array([
    [3.0, 4.0, 5.0],
    [5.0, 8.0, 10.0],
    [3.0, 3.0, -7.0],
])
w = jnp.array([1.0])  # length-1 array, broadcast against every entry of x

# A single jacrev call with argnums=(0, 1) returns the Jacobians w.r.t. x and
# w as a tuple, in that order.
x_grad, w_grad = jacrev(simplefun, argnums=(0, 1))(x, w)

print(simplefun(x, w))
print("jacrev result:")
print(x_grad)
print(x_grad.shape)
print(w_grad)

[[  36.   80.  150.]
 [ 150.  576. 1100.]
 [  36.   36. -294.]]
jacrev result:
[[[[ 33.   0.   0.]
   [  0.   0.   0.]
   [  0.   0.   0.]]

  [[  0.  56.   0.]
   [  0.   0.   0.]
   [  0.   0.   0.]]

  [[  0.   0.  85.]
   [  0.   0.   0.]
   [  0.   0.   0.]]]


 [[[  0.   0.   0.]
   [ 85.   0.   0.]
   [  0.   0.   0.]]

  [[  0.   0.   0.]
   [  0. 208.   0.]
   [  0.   0.   0.]]

  [[  0.   0.   0.]
   [  0.   0. 320.]
   [  0.   0.   0.]]]


 [[[  0.   0.   0.]
   [  0.   0.   0.]
   [ 33.   0.   0.]]

  [[  0.   0.   0.]
   [  0.   0.   0.]
   [  0.  33.   0.]]

  [[  0.   0.   0.]
   [  0.   0.   0.]
   [  0.   0. 133.]]]]
(3, 3, 3, 3)
[[[  9.]
  [ 16.]
  [ 25.]]

 [[ 25.]
  [ 64.]
  [100.]]

 [[  9.]
  [  9.]
  [ 49.]]]


# VJP

In [13]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax import vjp
# Root PRNG key built from the fixed seed 0; later cells split it to obtain
# the subkeys used for each random draw.
key = random.PRNGKey(0)

In [4]:
def sigmoid(x):
    """Logistic sigmoid expressed through tanh: ``0.5 * (tanh(x / 2) + 1)``.

    Args:
        x: Scalar or array input.

    Returns:
        The sigmoid of ``x``, computed elementwise for array inputs.
    """
    half_x = x / 2
    return 0.5 * (jnp.tanh(half_x) + 1)

def predict(W, b, inputs):
    """Output the probability of a label being true.

    Args:
        W: Weights combined with ``inputs`` through ``jnp.dot``.
        b: Bias added to the dot product.
        inputs: Input features.

    Returns:
        ``sigmoid(jnp.dot(inputs, W) + b)``.
    """
    logits = jnp.dot(inputs, W) + b
    return sigmoid(logits)

In [5]:
def loss(W, b):
    """Negative log-likelihood of the notebook-level ``targets`` under ``predict``.

    ``inputs`` and ``targets`` are not parameters; they are read from the
    enclosing (notebook) scope when the function is called.

    Args:
        W: Weights forwarded to ``predict``.
        b: Bias forwarded to ``predict``.

    Returns:
        Minus the sum of the log-probabilities assigned to the observed labels.
    """
    prob_true = predict(W, b, inputs)
    # Probability given to the label actually observed: prob_true where the
    # target is True, 1 - prob_true where it is False.
    prob_observed = prob_true * targets + (1 - prob_true) * (1 - targets)
    log_likelihood = jnp.sum(jnp.log(prob_observed))
    return -log_likelihood

In [6]:
def f(W):
    """Predictions as a function of W alone.

    ``b`` and ``inputs`` are looked up in the notebook's global scope at call
    time, so they only need to exist before ``f`` is first called.
    """
    return predict(W, b, inputs)

In [14]:
# Split the current key three ways: a replacement for `key` plus one subkey
# each for W and b.
key, W_key, b_key = random.split(key, num=3)
W = random.normal(W_key, shape=(3,))  # one weight per input column
b = random.normal(b_key, shape=())  # scalar bias

# Four examples with three features each, and one boolean label per example.
inputs = jnp.array([
    [0.52, 1.12, 0.77],
    [0.88, -1.08, 0.15],
    [0.52, 0.06, -1.30],
    [0.74, -2.49, 1.39],
])
targets = jnp.array([True, True, False, True])

In [16]:
# vjp returns the primal output y = f(W) together with vjp_fun, a function
# that is later applied to a vector shaped like y.
y, vjp_fun = vjp(f, W)

In [18]:
# Take a fresh subkey and draw a random vector u with the same shape as y;
# it is the argument handed to vjp_fun in the next cell.
key, subkey = random.split(key)
u = random.normal(subkey, shape=y.shape)
print(key, subkey)
print(u)

[1082916127 1719789088] [3156601263 2504717133]
[-0.7607929  -0.5476383   0.08130554  0.2596837 ]


In [20]:
# Apply the pullback to u. The result is a tuple with one entry per argument
# of f (here only W), so v[0] is the vector-Jacobian product w.r.t. W.
v = vjp_fun(u)
v

(DeviceArray([-0.05722704, -0.07111153, -0.09522615], dtype=float32),)