In [None]:
import jax
import jax.numpy as jnp
from jax import grad
from jax.test_util import check_grads

In [None]:
# Differentiate tanh with jax.grad; nesting grad gives higher-order derivatives.
# Prints the 1st, 2nd and 3rd derivatives of tanh evaluated at x = 2.0.

grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0), grad(grad_tanh)(2.0), grad(grad(grad_tanh))(2.0))

In [None]:
# Differentiate a cubic polynomial repeatedly with jax.grad.
# For f(x) = x^3 + 2x^2 - 3x + 1 at x = 1: f' = 4, f'' = 10, f''' = 6, f'''' = 0.

def f(x):
    """Cubic test polynomial f(x) = x^3 + 2x^2 - 3x + 1."""
    return x**3 + 2*x**2 - 3*x + 1

# Each derivative is jax.grad applied to the previous one.
dfdx = jax.grad(f)
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

print(*(deriv(1.) for deriv in (dfdx, d2fdx, d3fdx, d4fdx)))

In [None]:
# Seed for the PRNG key that is later split to initialize parameters and draw test directions.
SEED = 0
key = jax.random.key(SEED)

def sigmoid(x):
  """Logistic sigmoid expressed through tanh: sigmoid(x) = (tanh(x / 2) + 1) / 2."""
  half_tanh = jnp.tanh(x / 2)
  return 0.5 * (half_tanh + 1)

# Logistic-regression model: probability that each example's label is True.
def predict(W, b, inputs):
  """Return sigmoid(inputs @ W + b).

  In this notebook `inputs` holds one example per row and `W` one weight per
  column, so the result has one probability per example.
  """
  logits = jnp.dot(inputs, W) + b
  return sigmoid(logits)

# Toy dataset: four examples with three features each, and a boolean label per example.
inputs = jnp.array([
    [0.52,  1.12,  0.77],
    [0.88, -1.08,  0.15],
    [0.52,  0.06, -1.30],
    [0.74, -2.49,  1.39],
])
targets = jnp.array([True, True, False, True])

# Training objective: negative log-likelihood of the toy dataset.
def loss(W, b):
  """Sum over examples of -log P(observed label) under predict(W, b, inputs)."""
  p_true = predict(W, b, inputs)
  # Probability of the observed label: p where the target is True, 1 - p where it is False.
  p_observed = p_true * targets + (1 - p_true) * (1 - targets)
  # NOTE(review): log of a probability gives -inf if a prediction saturates to exactly
  # 0 or 1; a log-sigmoid on the logits would be more robust if that ever happens.
  return -jnp.sum(jnp.log(p_observed))

# Initialize random model coefficients from standard normals:
# a weight vector of length 3 and a scalar bias, each from its own subkey.
key, W_key, b_key = jax.random.split(key, 3)
b = jax.random.normal(b_key, shape=())
W = jax.random.normal(W_key, shape=(3,))

In [None]:
# argnums=0 spelled out: gradient with respect to the first positional argument, W.
grad_wrt_W = grad(loss, argnums=0)
W_grad = grad_wrt_W(W, b)
print(f'{W_grad=}')

# argnums defaults to 0, so leaving it out also differentiates with respect to W.
W_grad = grad(loss)(W, b)
print(f'{W_grad=}')

# argnums can be passed positionally; 1 selects the bias b.
grad_wrt_b = grad(loss, 1)
b_grad = grad_wrt_b(W, b)
print(f'{b_grad=}')

# A tuple of argnums yields one gradient per selected argument.
grads_Wb = grad(loss, (0, 1))(W, b)
W_grad, b_grad = grads_Wb
print(f'{W_grad=}')
print(f'{b_grad=}')

In [None]:
# jax.grad also works when the parameters are packed into a dict.
def loss2(params_dict):
    """Same objective as `loss`, with parameters given as {'W': ..., 'b': ...}.

    Delegates to `loss` instead of duplicating its body, so the two cannot drift apart.
    """
    return loss(params_dict['W'], params_dict['b'])

print(grad(loss2)({'W': W, 'b': b}))

In [None]:
# jax.value_and_grad returns the function value and its gradient in a single call.
# argnums=(0, 1) is a tuple, so gradients are taken with respect to both W and b.
loss_value, Wb_grad = jax.value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)

# Recompute the loss directly; it should agree with the value returned above.
# Distinct label so the two printed lines can be told apart.
loss_direct = loss(W, b)
print('loss(W, b)', loss_direct)
assert jnp.allclose(loss_value, loss_direct)
print('Wb_grad', Wb_grad)

# With a single integer argnums, only that argument's gradient is returned.
_, W_grad = jax.value_and_grad(loss, 0)(W, b)
print('W_grad', W_grad)

_, b_grad = jax.value_and_grad(loss, 1)(W, b)
print('b_grad', b_grad)

In [None]:
# Verify the autodiff gradients by hand with central finite differences.

# Step size for the finite-difference estimates, and half of it for the +/- offsets.
eps = 1e-4
half_eps = eps / 2.

# b is a scalar, so a plain central difference estimates d(loss)/db.
b_grad_numerical = (loss(W, b + half_eps) - loss(W, b - half_eps)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# W is a vector: compare the directional derivative along a random unit vector.
key, subkey = jax.random.split(key)
vec = jax.random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
step = half_eps * unitvec
W_grad_numerical = (loss(W + step, b) - loss(W - step, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))

In [None]:
# Use JAX's built-in check_grads to verify the gradients of `loss` numerically.
# Per the jax.test_util documentation it compares autodiff derivatives against
# numerical estimates and raises on a mismatch, so a silent run means the check passed.
check_grads(loss, (W, b), order=2)  # check derivatives up to 2nd order