In [1]:
import jax
import jax.numpy as jnp

# gradients

In [52]:
def f(x):
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2)
    return ln_x/ln_2

from jax import make_jaxpr

x_0 = 2.0

make_jaxpr(f)(x_0)  # let's watch JAX trace the function! 

{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }

In [53]:
gradient = jax.grad(f)
make_jaxpr(gradient)(x_0)  # this is the JAX expression for the gradient
# note how the original expression is embedded in the gradient expression!
# this is JAX tracing the computation graph and then using the primitive rules
# for log and division to compute the gradient (d/dx log(x) = 1/x, d/dx x/y = 1/y)

{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    _:f32[] = div b c
    d:f32[] = div 1.0 c
    e:f32[] = div d a
  in (e,) }

Here's the step-by-step representation of the given expression using LaTeX equations:

Step 1:
$ b = \log(a) $

Step 2:
$ c = \log(2.0) $

Step 3:
$ \_ = \frac{b}{c} $

Step 4:
$ d = \frac{1.0}{c} = \frac{1.0}{\log(2.0)} $, which is $\frac{\partial \_}{\partial b}$

Step 5:
$ e = \frac{d}{a} $, which is $\frac{\partial \_}{\partial b} \cdot \frac{\partial b}{\partial a}$ ,

since $\frac{\partial b}{\partial a} = \frac{\partial \log(a)}{\partial a} = \frac{1}{a}$

Step 6:
 return $e $


In [54]:
gradient(x_0)

Array(0.7213475, dtype=float32, weak_type=True)

In [55]:
1 / (x_0 * jnp.log(2))

Array(0.7213475, dtype=float32, weak_type=True)

In [56]:
# here's the gradient of the gradient
make_jaxpr(jax.grad(gradient))(x_0)

{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    _:f32[] = div b c
    d:f32[] = div 1.0 c
    _:f32[] = div d a
    e:f32[] = integer_pow[y=-2] a
    f:f32[] = mul 1.0 e
    g:f32[] = mul f d
    h:f32[] = neg g
  in (h,) }

In [57]:
# which is the derivative of 1/x = -1/x^2
-1 / (x_0**2 * jnp.log(2)) == jax.grad(gradient)(x_0)

Array(True, dtype=bool, weak_type=True)

In [58]:
# we could go forever:
make_jaxpr(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(f)))))))))(x_0)

{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    _:f32[] = div b c
    d:f32[] = div 1.0 c
    _:f32[] = div d a
    _:f32[] = integer_pow[y=-2] a
    _:f32[] = integer_pow[y=-2] a
    e:f32[] = integer_pow[y=-3] a
    _:f32[] = mul -2.0 e
    _:f32[] = integer_pow[y=-2] a
    f:f32[] = integer_pow[y=-3] a
    _:f32[] = mul -2.0 f
    g:f32[] = integer_pow[y=-3] a
    h:f32[] = integer_pow[y=-4] a
    _:f32[] = mul -3.0 h
    _:f32[] = mul -2.0 g
    _:f32[] = integer_pow[y=-2] a
    i:f32[] = integer_pow[y=-3] a
    _:f32[] = mul -2.0 i
    j:f32[] = integer_pow[y=-3] a
    k:f32[] = integer_pow[y=-4] a
    _:f32[] = mul -3.0 k
    _:f32[] = mul -2.0 j
    l:f32[] = integer_pow[y=-3] a
    m:f32[] = integer_pow[y=-4] a
    _:f32[] = mul -3.0 m
    n:f32[] = integer_pow[y=-4] a
    o:f32[] = integer_pow[y=-5] a
    _:f32[] = mul -4.0 o
    _:f32[] = mul -3.0 n
    _:f32[] = mul -2.0 l
    _:f32[] = integer_pow[y=-2] a
    p:f32[] = integer_pow[y=-3] a
    _:f32[] =

In [27]:
# here's another function: a layer of a neural network
def mlp(x, w, b):
    return jnp.dot(w, x) + b

x = 9.0
w = 3.0
b = 1.0

make_jaxpr(mlp)(x, w, b)

{ lambda ; a:f32[] b:f32[] c:f32[]. let
    d:f32[] = mul b a
    e:f32[] = add d c
  in (e,) }

In [28]:
# the gradient of f with respect to w
make_jaxpr(jax.grad(mlp, argnums=1))(x, w, b)

{ lambda ; a:f32[] b:f32[] c:f32[]. let
    d:f32[] = mul b a
    _:f32[] = add d c
    e:f32[] = mul 1.0 a
  in (e,) }

In [32]:
# analytical gradient is just x
x == jax.grad(mlp, argnums=1)(x, w, b)

Array(True, dtype=bool, weak_type=True)

In [80]:
# now let's get crazy: here's the attention mechanism from the Transformer
# (https://arxiv.org/abs/1706.03762)

def softmax(x):
    return jnp.exp(x) / jnp.sum(jnp.exp(x))

def scaled_dot_product_attention(q, k, v):
    qk = jnp.dot(q, k.T) / jnp.sqrt(q.shape[-1])
    weights = softmax(qk)
    return jnp.dot(weights, v)

def attention(weights, e):
    Wq = weights["Wq"]
    Wk = weights["Wk"]
    Wv = weights["Wv"]

    q = jnp.dot(Wq, e)
    k = jnp.dot(Wk, e)
    v = jnp.dot(Wv, e)
    return scaled_dot_product_attention(q, k, v)

Wq = jnp.array([[1.0, 0.1, 0.0],
                [0.0, 1.0, 0.1],
                [0.1, 0.0, 1.0]])

Wk = jnp.array([[1.0, 0.0, 0.1],
                [0.1, 1.0, 0.0],
                [0.0, 0.1, 1.0]])

Wv = jnp.array([[1.0, 0.1, 0.0],
                [0.0, 1.0, 0.1],
                [0.1, 0.0, 1.0]])


e = jnp.array([1.0, 2.0, 3.0])

make_jaxpr(attention)(dict(Wq=Wq, Wk=Wk, Wv=Wv), e)

{ lambda ; a:f32[3,3] b:f32[3,3] c:f32[3,3] d:f32[3]. let
    e:f32[3] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b d
    f:f32[3] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d
    g:f32[3] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c d
    h:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] e f
    i:f32[] = sqrt 3.0
    j:f32[] = convert_element_type[new_dtype=float32 weak_type=False] i
    k:f32[] = div h j
    l:f32[] = exp k
    m:f32[] = exp k
    n:f32[] = reduce_sum[axes=()] m
    o:f32[] = div l n
    p:f32[3] = mul o g
  in (p,) }

In [82]:
attention(dict(Wq=Wq, Wk=Wk, Wv=Wv), e)

Array([1.2, 2.3, 3.1], dtype=float32)

In [81]:
# the gradient of the attention mechanism with respect to the weights
# note that we have to use the full jacobian, since the output of our function
# is a vector, not a scalar
# the result is a 3x3 matrix, where each row is the gradient of the output
# with respect to the corresponding row of the weight matrix
# note how simple the Wv gradient is, since all we do is a dot product
jax.jacobian(attention, argnums=0)(dict(Wq=Wq, Wk=Wk, Wv=Wv), e)

{'Wk': Array([[[-8.2590617e-08, -1.6518123e-07, -2.4777185e-07],
         [-1.5829868e-07, -3.1659735e-07, -4.7489601e-07],
         [-2.1335909e-07, -4.2671817e-07, -6.4007725e-07]],
 
        [[-1.6518123e-07, -3.3036247e-07, -4.9554370e-07],
         [-3.1659735e-07, -6.3319470e-07, -9.4979202e-07],
         [-4.2671817e-07, -8.5343635e-07, -1.2801545e-06]],
 
        [[ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
         [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
         [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00]]], dtype=float32),
 'Wq': Array([[[-8.9473168e-08, -1.7894634e-07, -2.6841951e-07],
         [-1.4453357e-07, -2.8906715e-07, -4.3360072e-07],
         [-2.2024165e-07, -4.4048329e-07, -6.6072494e-07]],
 
        [[-1.7894634e-07, -3.5789267e-07, -5.3683902e-07],
         [-2.8906715e-07, -5.7813429e-07, -8.6720144e-07],
         [-4.4048329e-07, -8.8096658e-07, -1.3214499e-06]],
 
        [[ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
         [ 0.0

# vmap (automatic vectorization)

have you ever written a very simple function that you wanted to apply to a large array of inputs? numpy makes this easy with vectorization, but jax takes it a step further with vmap.

vmap is a function transformation that allows you to take a function that operates on a single input, and transform it into a function that operates on many inputs in a vectorized manner.

for example, consider the following function that would be a litte tricky to vectorize with numpy:


In [59]:
def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)
  return outputs

params = [
    (jnp.array([[1., 2.], [3., 4.]]), jnp.array([1., 1.])),
]

inputs = jnp.array([[1., 2.]])
predict(params, inputs)


we can use vmap to vectorize this function over multiple inputs:



In [64]:
batched_predict = jax.vmap(predict, in_axes=(None, 0))

many_inputs = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
batched_predict(params, many_inputs)

Array([[ 8., 11.],
       [16., 23.],
       [24., 35.]], dtype=float32)

with literally zero brain cells, we have implemented batching! :)

the `in_axes` argument specifies which of the function’s arguments to batch over. in this case, we want to batch over the second argument, which is the inputs. the first argument, params, will be left unchanged -- it will be as if we called predict in a for loop over each of the inputs.


here's another one of my favorite examples -- operating over pairs of inputs using nested vmaps:

```python

