# JAX differentiation

In [22]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

## Scalar function

In [23]:
def simplefun(x,w):
    return (x + w) * x**2

In [8]:
# gradient
x = jnp.array(3.0)
w = jnp.array(1.0)
x_grad = grad(simplefun, argnums=0)(x,w)
w_grad = grad(simplefun, argnums=1)(x,w)
print(x_grad)
print(w_grad)

33.0
9.0


## Vector function

From the [documentation](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html): These two functions (`jacfwd`, `jacrev`) compute the same values (up to machine numerics), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while jacrev uses reverse-mode, which is more efficient for “wide” Jacobian matrices.

In [13]:
from jax import jacfwd, jacrev
x = jnp.array([3.0, 1.0, 2.0])
w = jnp.array([1.0, -1.0, 2.2])
x_grad = jacrev(simplefun, argnums=0)(x,w)
w_grad = jacrev(simplefun, argnums=1)(x,w)
print("jacrev result:")
print(x_grad)
print(w_grad)

print("\njacfwd result:")
x_grad = jacfwd(simplefun, argnums=0)(x,w)
w_grad = jacfwd(simplefun, argnums=1)(x,w)
print(x_grad)
print(w_grad)

jacrev result:
[[33.   0.   0. ]
 [ 0.   1.   0. ]
 [ 0.   0.  20.8]]
[[9. 0. 0.]
 [0. 1. 0.]
 [0. 0. 4.]]

jacfwd result:
[[33.   0.   0. ]
 [ 0.   1.   0. ]
 [ 0.   0.  20.8]]
[[9. 0. 0.]
 [0. 1. 0.]
 [0. 0. 4.]]


## How Far Can This Go?

The answer is: as far as you'd like! That is the real tensor-based Jacobian that is implemented. I am not going to do it, because I don't know the math.

In [19]:
x = jnp.array([[1.,4.,3.],
               [12.,42.,-1.],
               [-.5, -2, 3]])
w = jnp.array([1.0, -1.0, 2.2])
x_grad = jacrev(simplefun, argnums=0)(x,w)
print(x_grad)
print(x_grad.shape)

[[[[ 5.0000000e+00  0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]

  [[ 0.0000000e+00  4.0000000e+01  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]

  [[ 0.0000000e+00  0.0000000e+00  4.0199997e+01]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]]


 [[[ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
   [ 4.5600000e+02  0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]

  [[ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  5.2080000e+03  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]

  [[ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00 -1.4000001e+00]
   [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]]


 [[[ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00 

In [22]:
x = jnp.array([[[1.,4.,3.],
                [12.,42.,-1.],
                [-.5, -2, 3]],
              
               [[5.,4.,3.],
                [19.,32.,-1.],
                [-.5, 1.3, 3]]])
w = jnp.array([1.0, -1.0, 2.2])
x_grad = jacrev(simplefun, argnums=0)(x,w)
print(x_grad)
print(x_grad.shape)

[[[[[[ 5.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]

    [[ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]]


   [[[ 0.0000000e+00  4.0000000e+01  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]

    [[ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]]


   [[[ 0.0000000e+00  0.0000000e+00  4.0199997e+01]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]

    [[ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]]]]



  [[[[ 0.0000000e+00  0.0000000e+00  0.00000

# JAX Jacobian for Operations

This is so simple and cool!

In [23]:
def multiply(a,b):
    return jnp.dot(a, b)

In [27]:
grad(multiply)

<function __main__.multiply(a, b)>

In [30]:
grad(multiply)(jnp.array([1.,2.]), jnp.array([3.,4]))

DeviceArray([3., 4.], dtype=float32)

In [31]:
import jax
jax.grad

<function jax._src.api.grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[Any] = ()) -> Callable>

# Some other tests

In [40]:
from jax import jacfwd, jacrev
x = jnp.array([3., 4., 5.])
w = jnp.array([1.0])
x_grad = jacrev(simplefun, argnums=0)(x,w)
w_grad = jacrev(simplefun, argnums=1)(x,w)
print("jacrev result:")
print(x_grad)
print(w_grad)

print("\njacfwd result:")
x_grad = jacfwd(simplefun, argnums=0)(x,w)
w_grad = jacfwd(simplefun, argnums=1)(x,w)
print(x_grad)
print(w_grad)

jacrev result:
[[33.  0.  0.]
 [ 0. 56.  0.]
 [ 0.  0. 85.]]
[[ 9.]
 [16.]
 [25.]]

jacfwd result:
[[33.  0.  0.]
 [ 0. 56.  0.]
 [ 0.  0. 85.]]
[[ 9.]
 [16.]
 [25.]]


In [33]:
simplefun(x,w)

DeviceArray([ 36.,  80., 150.], dtype=float32)

In [34]:
x+w

DeviceArray([4., 5., 6.], dtype=float32)

In [35]:
x**2

DeviceArray([ 9., 16., 25.], dtype=float32)

In [38]:
from jax import jacfwd, jacrev
x = jnp.array([[3., 4., 5.],
               [5., 8., 10.],
               [3., 3., -7.]])
w = jnp.array([1.0])
x_grad = jacrev(simplefun, argnums=0)(x,w)
w_grad = jacrev(simplefun, argnums=1)(x,w)
print(simplefun(x,w))
print("jacrev result:")
print(x_grad)
print(x_grad.shape)
print(w_grad)

[[  36.   80.  150.]
 [ 150.  576. 1100.]
 [  36.   36. -294.]]
jacrev result:
[[[[ 33.   0.   0.]
   [  0.   0.   0.]
   [  0.   0.   0.]]

  [[  0.  56.   0.]
   [  0.   0.   0.]
   [  0.   0.   0.]]

  [[  0.   0.  85.]
   [  0.   0.   0.]
   [  0.   0.   0.]]]


 [[[  0.   0.   0.]
   [ 85.   0.   0.]
   [  0.   0.   0.]]

  [[  0.   0.   0.]
   [  0. 208.   0.]
   [  0.   0.   0.]]

  [[  0.   0.   0.]
   [  0.   0. 320.]
   [  0.   0.   0.]]]


 [[[  0.   0.   0.]
   [  0.   0.   0.]
   [ 33.   0.   0.]]

  [[  0.   0.   0.]
   [  0.   0.   0.]
   [  0.  33.   0.]]

  [[  0.   0.   0.]
   [  0.   0.   0.]
   [  0.   0. 133.]]]]
(3, 3, 3, 3)
[[[  9.]
  [ 16.]
  [ 25.]]

 [[ 25.]
  [ 64.]
  [100.]]

 [[  9.]
  [  9.]
  [ 49.]]]


# VJP

In [13]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax import vjp, jacrev
key = random.PRNGKey(0)

In [2]:
def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

In [3]:
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

In [4]:
f = lambda W: predict(W, b, inputs)

In [5]:
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

In [6]:
y, vjp_fun = vjp(f, W)

In [16]:
key, subkey = random.split(key)
u = random.normal(W_key, (4,))
print(key, subkey)
print(u)
print(y)
print(f(W))

[2384771982 3928867769] [1278412471 2182328957]
[-0.36838785  0.35917208  0.01144757 -0.12499736]
[0.13262251 0.952067   0.6249393  0.99809873]
[0.13262251 0.952067   0.6249393  0.99809873]


vjp is equivalent to vector multiplying a jacobian

In [17]:
v = vjp_fun(u)
print(v)
print(jnp.dot(u, jacrev(f, argnums=0)(W)))

(DeviceArray([-0.00639229, -0.06441288, -0.03398955], dtype=float32),)
[-0.00639229 -0.06441288 -0.03398955]


vjp for tensors

In [24]:
def f(x1,x2):
    return jnp.dot(x1, x2)

x1 = random.normal(key, (3,3))
x2 = random.normal(key, (3,3))
_, vjp_fun = vjp(f, x1, x2)
# u*df/dx1
u = random.normal(key, (3,3))
print("====== vjp =============")
print(vjp_fun(u))
print("====== u * jacobian ====")
print(jnp.tensordot(u, jacrev(f, argnums=0)(x1,x2)))

(DeviceArray([[ 3.2109752 , -1.2424183 , -2.8282166 ],
             [-1.2424183 ,  1.0281112 ,  0.97847855],
             [-2.8282166 ,  0.97847855,  2.5156035 ]], dtype=float32), DeviceArray([[ 2.282812 , -1.9776087,  2.136005 ],
             [-1.9776087,  2.468872 , -1.7930149],
             [ 2.136005 , -1.7930149,  2.0030057]], dtype=float32))
[[ 3.2109752  -1.2424183  -2.8282166 ]
 [-1.2424183   1.0281112   0.97847855]
 [-2.8282166   0.97847855  2.5156035 ]]


In [31]:
x1 = random.normal(key, (3,3,3))
x2 = random.normal(key, (3,3,3))
_, vjp_fun = vjp(f, x1, x2)
# u*df/dx1
u = random.normal(key, (3,3,3,3))
print("====== vjp =============")
print(vjp_fun(u))
print("====== u * jacobian ====")
print(jacrev(f, argnums=0)(x1,x2))

(DeviceArray([[[ 1.0267432 , -1.8556583 , -0.80159235],
              [ 1.1297358 ,  1.2436607 , -1.1522541 ],
              [ 2.1084635 , -3.6130662 ,  1.6428506 ]],

             [[ 3.4298165 , -0.04724604,  4.1091905 ],
              [-3.4140415 , -3.7138467 ,  1.3979788 ],
              [-0.58238673, -2.0433717 , -3.1903377 ]],

             [[ 4.791793  ,  3.3656254 ,  1.732102  ],
              [-1.9840615 , -2.3594637 ,  0.41650546],
              [ 0.86179733, -3.2646542 ,  1.2744228 ]]], dtype=float32), DeviceArray([[[-0.470088  , -0.9264761 ,  2.8457048 ],
              [-0.17020647,  1.8542234 ,  0.3037143 ],
              [-0.25616166, -0.42286313, -2.552411  ]],

             [[-0.9706892 ,  5.6405954 ,  0.43347734],
              [ 2.0002906 ,  1.9044861 ,  4.6858964 ],
              [-0.47892877, -6.030055  , -3.258853  ]],

             [[-1.08567   ,  4.992403  , -1.984975  ],
              [-1.1949979 ,  5.896005  , -3.2992806 ],
              [ 2.5832283 , -6.9359612

From Matthieu Blondel's example. I am matching his definition of vjp with the built-in vjp function

In [1]:
def dot(x, W):
    return jnp.dot(W, x)

In [19]:
def dot_make_vjp(x, W):
    def vjp(u):
        return W.T.dot(u), jnp.outer(u, x)
    return vjp

In [22]:
dot.make_vjp = dot_make_vjp

In [32]:
x = jnp.array([0.52, 1.12,  0.77])

In [9]:
dot.make_vjp(x, W)

<function __main__.dot_make_vjp.<locals>.vjp(u)>

In [33]:
y, vjp_fun = vjp(dot, x, W)

In [34]:
u = random.normal(subkey, y.shape)

In [35]:
vjp_fun(u)

(DeviceArray([ 0.10106743,  0.62433666, -0.00314065], dtype=float32),
 DeviceArray([-0.14266232, -0.3072727 , -0.21124996], dtype=float32))

In [36]:
print(y)

-2.7315187


In [29]:
custom_vjp_fun = dot.make_vjp(x, W)

In [30]:
custom_vjp_fun(u)

(DeviceArray(5.314601, dtype=float32),
 DeviceArray([[-0.19156167, -0.41259438, -0.28365862],
              [-1.1833582 , -2.5487716 , -1.7522804 ],
              [ 0.00595273,  0.01282127,  0.00881463]], dtype=float32))