In [1]:
import jax
import jax.numpy as jnp

# Root PRNG key, created from seed 0; it is split further down whenever fresh subkeys are needed.
key = jax.random.key(0)

def sigmoid(x):
    """Logistic sigmoid of `x`, evaluated through the identity 0.5 * (tanh(x / 2) + 1)."""
    half_x = x / 2
    return 0.5 * (jnp.tanh(half_x) + 1)

def predict(W, b, inputs):
    """Return `sigmoid(jnp.dot(inputs, W) + b)`: the model output for `inputs`.

    Args:
        W: weights that `inputs` is dotted with.
        b: bias added to the dot product.
        inputs: input features.

    Returns:
        The sigmoid of the affine score, i.e. values between 0 and 1.
    """
    scores = jnp.dot(inputs, W) + b
    return sigmoid(scores)

# Toy dataset: four examples (rows) with three features each (columns).
inputs = jnp.array([
    [0.52, 1.12, 0.77],
    [0.88, -1.08, 0.15],
    [0.52, 0.06, -1.30],
    [0.74, -2.49, 1.39],
])

# One boolean label per row of `inputs`.
targets = jnp.array([True, True, False, True])

def loss(W, b):
    """Negative log-likelihood of the module-level `targets` given `inputs`.

    Args:
        W: weights forwarded to `predict`.
        b: bias forwarded to `predict`.

    Returns:
        Minus the sum, over all examples, of the log of the probability that
        the model assigns to each example's label.
    """
    preds = predict(W, b, inputs)
    # Where the label is True the first term keeps `preds`; where it is False
    # the second term keeps `1 - preds`.
    prob_if_true = preds * targets
    prob_if_false = (1 - preds) * (1 -  targets)
    label_probs = prob_if_true + prob_if_false
    return -jnp.sum(jnp.log(label_probs))


# Split off one subkey per parameter and keep a fresh `key` for later use.
key, W_key, b_key = jax.random.split(key, num=3)
# Initial parameters drawn from a standard normal: a 3-vector W and a scalar b.
W = jax.random.normal(W_key, shape=(3,))
b = jax.random.normal(b_key, shape=())


In [2]:
# Build the gradient function of `loss` with respect to its first positional
# argument (W), then evaluate it at the current parameters:
dloss_dW = jax.grad(loss, argnums=0)
W_grad = dloss_dW(W, b)
print(f'{W_grad=}')

W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)


In [3]:
# `argnums` defaults to 0, so leaving it out differentiates with respect to W again:
dloss_dW_default = jax.grad(loss)
W_grad = dloss_dW_default(W, b)
print(f'{W_grad=}')

W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)


In [4]:
# `argnums` may also be given positionally; 1 selects the second argument (b):
dloss_db = jax.grad(loss, 1)
b_grad = dloss_db(W, b)
print(f'{b_grad=}')

b_grad=Array(-0.29227245, dtype=float32)


In [5]:
# A tuple of argument indices yields a tuple of gradients, one per selected argument:
dloss_dWb = jax.grad(loss, (0, 1))
W_grad, b_grad = dloss_dWb(W, b)
print(f'{W_grad=}')
print(f'{b_grad=}')

W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32)


In [6]:
def loss2(params_dict):
    """Same loss as `loss`, with the parameters supplied as a single dict.

    Args:
        params_dict: mapping with keys "W" and "b" holding the values that
            `loss` takes as its two positional arguments.

    Returns:
        The value of `loss(params_dict["W"], params_dict["b"])`.
    """
    # Delegate instead of repeating the body of `loss`, so the two entry
    # points cannot drift apart.
    return loss(params_dict["W"], params_dict["b"])


# Differentiating with respect to a dict argument gives back a dict of gradients with the same keys.
params_dict = {"W": W, "b": b}
print(jax.grad(loss2)(params_dict))

{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}


In [7]:
# Evaluate the loss and its gradients with respect to both W and b in one call.
loss_and_grads = jax.value_and_grad(loss, (0, 1))
loss_value, Wb_grad = loss_and_grads(W, b)
print('loss value', loss_value)
# Calling `loss` directly prints the same value as the line above.
print('loss value', loss(W, b))
print("loss grad:", Wb_grad)

loss value 3.0519385
loss value 3.0519385
loss grad: (Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), Array(-0.29227245, dtype=float32))


In [8]:
# Step size for the central finite-difference checks
eps = 1e-4
half_eps = eps / 2.

# Central difference in the scalar b, printed next to the autodiff gradient
loss_b_plus = loss(W, b + half_eps)
loss_b_minus = loss(W, b - half_eps)
b_grad_numerical = (loss_b_plus - loss_b_minus) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', jax.grad(loss, 1)(W, b))

# Check W_grad through the directional derivative along a random unit vector
key, subkey = jax.random.split(key)
vec = jax.random.normal(subkey, W.shape)
vec_norm = jnp.sqrt(jnp.vdot(vec, vec))
unitvec = vec / vec_norm
half_step = eps / 2. * unitvec
# Despite its name, W_grad_numerical holds a scalar directional derivative
# (central difference along `unitvec`), not a full gradient vector.
W_grad_numerical = (loss(W + half_step, b) - loss(W - half_step, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
# Autodiff counterpart: project the gradient onto the same direction.
print('W_dirderiv_autodiff', jnp.vdot(jax.grad(loss)(W, b), unitvec))

b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117


In [9]:
from jax.test_util import check_grads

# Check the derivatives of `loss` at (W, b) up to second order.
check_args = (W, b)
check_grads(loss, check_args, order=2)