In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# Root PRNG key for the notebook; later cells split it before drawing samples.
key = random.key(0)

In [None]:
# 预热：计算二阶导数

def hessian(f):
  """Return a function that evaluates the Hessian of the scalar function `f`.

  The Hessian is built forward-over-reverse: the reverse-mode gradient of `f`
  is differentiated once more with forward-mode `jax.jacfwd`.
  """
  gradient_fn = jax.grad(f)
  return jax.jacfwd(gradient_fn)

def f(x):
  """Return the dot product of `x` with itself."""
  self_dot = jnp.dot(x, x)
  return self_dot

# f(x) = x . x, so the Hessian is 2 * I regardless of the evaluation point.
hessian(f)(jnp.array([1., 2., 3.]))

# Higher-order optimization

## Stopping gradients

In [None]:
# Value function and initial parameters
# The value function is linear: value_fn(theta, state) = theta . state.
value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])

# An example transition.
# These three arrays are passed to td_loss as (s_tm1, r_t, s_t).
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])

In [None]:
def td_loss(theta, s_tm1, r_t, s_t):
  """Naive TD pseudo-loss: -0.5 * (target - v_tm1) ** 2.

  Both `v_tm1` and `target` are computed from `theta`, so differentiating this
  function with respect to `theta` propagates gradients through both terms.
  """
  previous_value = value_fn(theta, s_tm1)
  bootstrap_target = r_t + value_fn(theta, s_t)
  td_error = bootstrap_target - previous_value
  return -0.5 * (td_error ** 2)

# jax.grad differentiates with respect to the first positional argument (theta).
td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

# Display the resulting parameter update.
delta_theta

In [None]:
def td_loss(theta, s_tm1, r_t, s_t):
  """TD pseudo-loss whose gradient ignores the dependence of the target on theta.

  Same expression as the naive version, except the target is wrapped in
  `jax.lax.stop_gradient`, so only `v_tm1` contributes to the gradient.
  """
  previous_value = value_fn(theta, s_tm1)
  bootstrap_target = r_t + value_fn(theta, s_t)
  # This is the most important line in the code: the target is held constant.
  frozen_target = jax.lax.stop_gradient(bootstrap_target)
  return -0.5 * ((frozen_target - previous_value) ** 2)

# Re-derive the update from the stop_gradient version of td_loss.
td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

print(delta_theta)

# Verify the result above by hand: (TD error) * d value_fn / d theta at s_tm1.
grad_fn = jax.grad(value_fn)
(r_t+value_fn(theta, s_t)-value_fn(theta, s_tm1))*grad_fn(theta, s_tm1)

## Straight-through estimator using stop_gradient

In [None]:
def f(x):
  """Round `x` with `jnp.round` (non-differentiable)."""
  rounded = jnp.round(x)  # non-differentiable
  return rounded

def straight_through_f(x):
  """Evaluate `f(x)` in the forward pass while passing gradients straight through.

  The returned value equals `f(x)`; the only differentiable path is the
  `x - stop_gradient(x)` term, whose gradient with respect to `x` is one.
  """
  # Create an exactly-zero expression with Sterbenz lemma that has
  # an exactly-one gradient.
  exact_zero = x - jax.lax.stop_gradient(x)
  forward_value = jax.lax.stop_gradient(f(x))
  return exact_zero + forward_value

# Forward values: both functions produce the same rounded output.
print("f(x): ", f(3.2))
print("straight_through_f(x):", straight_through_f(3.2))

# Gradients: compare the gradient of plain rounding with the straight-through
# version, which is constructed to have gradient one.
print("grad(f)(x):", jax.grad(f)(3.2))
print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))

# 前向传播就是函数求值
# 反向传播就是函数求导（用链式法则）

## Per-example gradients

In [None]:
# Per-example gradient of td_loss: theta is shared (in_axes=None) while the
# transition arrays are mapped over their leading axis; the whole thing is jitted.
perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))

# Test it:
# Build a batch of two identical transitions by stacking along a new axis 0.
batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
batched_r_t = jnp.stack([r_t, r_t])
batched_s_t = jnp.stack([s_t, s_t])

perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)

## Hessian-vector products with jax.grad-of-jax.grad

In [None]:
def hvp(f, x, v):
    """Hessian-vector product of `f` at `x` with `v`, via grad-of-grad.

    Differentiates the scalar function `x -> vdot(grad(f)(x), v)`, so the full
    Hessian is never materialized.
    """
    def directional_derivative(point):
        return jnp.vdot(grad(f)(point), v)

    return grad(directional_derivative)(x)

def gprod(f, x, v):
    """Return the scalar `grad(f)(x) @ v` (directional derivative of `f` along `v`)."""
    return grad(f)(x) @ v

# Alternative to the lambda-based version: differentiate the named helper instead.
def hvp2(f, x, v):
    """Hessian-vector product of `f` at `x` with `v`, without a lambda.

    Differentiates `gprod(f, x, v)` with respect to its second argument `x`
    (`argnums=1`). All three arguments must be forwarded to the differentiated
    function, and the gradient of the scalar `grad(f)(x) @ v` is already the
    product H(x) @ v, so no further contraction with `v` is applied.
    """
    return grad(gprod, argnums=1)(f, x, v)

In [None]:
# jax.grad 是什么形状的？

import jax.numpy as jnp
from jax import grad

# Objective function
def f(x):
    """Return the sum of the squared entries of `x` (a simple quadratic)."""
    squared_entries = x ** 2
    return jnp.sum(squared_entries)

# Input variable (a 2x2 matrix)
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# Compute the gradient of f at x
grad_f = grad(f)(x)

# Print both shapes to see how the gradient's shape relates to the input's.
print("输入变量 x 的形状:", x.shape)
print("梯度向量 grad_f 的形状:", grad_f.shape)

In [None]:
# 补课：点积和矩阵乘法

import jax.numpy as jnp

# Dot product of 1-D vectors (a scalar)
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
print("一维向量点积（标量）")
print(jnp.dot(a, b))   # 32
print(a @ b)    # 32
print(jnp.matmul(a, b))  # 32 (matmul of two 1-D vectors is the inner product, not the elementwise [4 10 18])

print(jnp.inner(a, b))  # 32
print(jnp.vdot(a, b))   # 32

# 2-D matrix multiplication
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
print(jnp.dot(A, B))   # works
print(A @ B)           # works
print(jnp.matmul(A, B))   # works

# For higher-dimensional arrays, prefer jnp.matmul or @
# jnp.dot is not recommended there

# jnp.tensordot is more involved; to be studied further

## Jacobians and Hessians using jax.jacfwd and jax.jacrev

In [None]:
from jax import jacfwd, jacrev

# Define a sigmoid function.
def sigmoid(x):
    """Logistic sigmoid expressed through tanh: 0.5 * (tanh(x / 2) + 1)."""
    half_angle_tanh = jnp.tanh(x / 2)
    return 0.5 * (half_angle_tanh + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
    """Apply the sigmoid to the affine map `jnp.dot(inputs, W) + b`."""
    logits = jnp.dot(inputs, W) + b
    return sigmoid(logits)

# Build a toy dataset.
# Four examples with three features each, i.e. shape (4, 3).
inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])

# Initialize random model coefficients
# Split the key so W and b are drawn from independent subkeys.
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

# Forward-mode Jacobian of the predictions with respect to W.
J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)

# Reverse-mode Jacobian of the same function; J is rebound to the new result.
J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)

These two functions compute the same values (up to machine numerics), but differ in their implementation: jax.jacfwd() uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices (more outputs than inputs), while jax.jacrev() uses reverse-mode, which is more efficient for “wide” Jacobian matrices (more inputs than outputs). For matrices that are near-square, jax.jacfwd() probably has an edge over jax.jacrev().

In [None]:
def predict_dict(params, inputs):
    """Call `predict` with weights and bias read from `params['W']` and `params['b']`."""
    weights = params['W']
    bias = params['b']
    return predict(weights, bias, inputs)

# Differentiating with respect to a dict argument yields a dict of Jacobians
# with the same keys (iterated over below).
J_dict_fwd = jacfwd(predict_dict)({'W': W, 'b': b}, inputs)
J_dict_rev = jacrev(predict_dict)({'W': W, 'b': b}, inputs)

# Forward-mode results, one Jacobian per parameter.
for k, v in J_dict_fwd.items():
    print(f"Jacobian from {k} to logits is")
    print(v)
# Reverse-mode results, printed in the same format for comparison.
for k, v in J_dict_rev.items():
    print(f"Jacobian from {k} to logits is")
    print(v)

# How it’s made: Two foundational autodiff functions

## Jacobian-Vector products (JVPs, a.k.a. forward-mode autodiff)

In [None]:
from jax import jvp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

# Draw a random tangent vector with the same shape as W.
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)

# Push forward the vector `v` along `f` evaluated at `W`
# jvp returns the primal output y = f(W) together with the tangent output u.
y, u = jvp(f, (W,), (v,))
print("使用 jvp 计算的结果")
print("y:", y)
print("u:", u)

# Manual check: evaluate f directly and multiply the full Jacobian by v.
print("下面是手动验证的代码")
print(f"y: {f(W)}\nu: {jax.jacobian(f)(W)@v}")

# 教程中的 “FLOP” 是 Floating Point Operation（浮点运算） 的缩写，是衡量计算量的一个基本单位。

## Vector-Jacobian products (VJPs, a.k.a. reverse-mode autodiff)

In [None]:
from jax import vjp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

# vjp returns the primal output and a function that pulls cotangents back.
y, vjp_fun = vjp(f, W)

# NOTE(review): `subkey` is only used by the commented-out random covector below.
key, subkey = random.split(key)
#u = random.normal(subkey, y.shape)
u = jnp.array([1., 0., 0., 0.])  # select only the gradient of the first output element

# Pull back the covector `u` along `f` evaluated at `W`
# The result is a tuple with one cotangent per input of f (here just W).
v = vjp_fun(u)

u, v

## Vector-valued gradients with VJPs

In [None]:
def vgrad(f, x):
  """Pull an all-ones cotangent back through `f` at `x`.

  Returns the first (and only) element of the VJP result, i.e. the cotangent
  for `x` when every output entry of `f(x)` is weighted by one.
  """
  primal_out, pullback = vjp(f, x)
  ones_cotangent = jnp.ones(primal_out.shape)
  cotangents = pullback(ones_cotangent)
  return cotangents[0]

# 2x2 input; the function below is applied elementwise.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# The derivative of 3*x**2 is 6*x, so this prints [[6, 12], [18, 24]].
print(vgrad(lambda x: 3*x**2, x))

In [None]:
# Alternatively, the same thing can be done by hand

# Option 1: vmap the scalar gradient over the flattened input.
f = lambda x:3*x**2
vdf = jax.vmap(grad(f))
# Note: `vdf` is rebound from the vmapped function to its (flattened) result.
vdf = vdf(x.reshape(-1))
print(vdf, "但是这个形状变了！")

# Option 2: sum the outputs so the function is scalar-valued, then use jax.grad.
varf = lambda x: jnp.sum(3 * x ** 2)
print(varf(x), "这个函数变了！")
vardf = jax.grad(varf)
print(vardf(x))  # output has shape (2, 2) and equals 6 * x, i.e. [[6, 12], [18, 24]]

## Hessian-vector products using both forward- and reverse-mode

In [None]:
def f(X):
  """Return the sum over all entries of tanh(X) squared."""
  squashed = jnp.tanh(X)
  return jnp.sum(squashed**2)

# Draw the evaluation point X and the direction V from independent subkeys.
key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))

# forward-over-reverse
def hvp(f, primals, tangents):
  """Hessian-vector product: forward-mode JVP of the reverse-mode gradient of `f`.

  `primals` and `tangents` are tuples, as expected by `jvp`; only the tangent
  output of the JVP is returned.
  """
  _, tangent_out = jvp(grad(f), primals, tangents)
  return tangent_out

# Reverse-over-forward
def hvp_revfwd(f, primals, tangents):
  """Hessian-vector product: reverse-mode gradient of a forward-mode JVP.

  The gradient is taken with respect to the `primals` tuple itself, so the
  result is a tuple with the same structure as `primals`.
  """
  def directional_derivative(point):
    _, tangent_out = jvp(f, point, tangents)
    return tangent_out

  return grad(directional_derivative)(primals)

# Reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
  """Hessian-vector product via grad-of-grad for a single-argument `f`.

  `primals` and `tangents` must each be 1-tuples; they are unpacked and the
  scalar `vdot(grad(f)(x), v)` is differentiated with respect to `x`.
  """
  (point,) = primals
  (direction,) = tangents

  def directional_derivative(p):
    return jnp.vdot(grad(f)(p), direction)

  return grad(directional_derivative)(point)

print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))
print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)

# Composing VJPs, JVPs, and jax.vmap

## Jacobian-Matrix and Matrix-Jacobian products

In [43]:
# Isolate the function from the weight matrix to the predictions
# (rebinds the global `f`; `b` and `inputs` come from earlier cells).
f = lambda W: predict(W, b, inputs)

# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
    """Matrix-Jacobian product computed one row at a time.

    Args:
      f: function of a single array argument.
      x: point at which `f` is linearized.
      M: matrix whose rows are covectors with the shape of `f(x)`.

    Returns:
      The pulled-back rows stacked vertically, one row per row of `M`.
    """
    _, vjp_fun = vjp(f, x)
    # vjp_fun returns a 1-tuple (one cotangent per input of f). Unpack it so
    # jnp.vstack receives arrays rather than tuples, which JAX only converts
    # implicitly with a deprecation warning.
    return jnp.vstack([vjp_fun(mi)[0] for mi in M])

# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
    """Matrix-Jacobian product: pull every row of `M` back through `f` at `x` using vmap."""
    _, pullback = vjp(f, x)
    batched_pullback = vmap(pullback)
    # The pullback returns a 1-tuple (one cotangent per input of f).
    (cotangents,) = batched_pullback(M)
    return cotangents

def vmap_mjp2(f, x, M):
    """Push every row of `M` forward through `f` at `x` with a vmapped `jvp`.

    `f` and `x` are held fixed (in_axes=None) while the tangent argument is
    mapped over the leading axis of `M`. Only the batched tangent outputs are
    returned; the batched primal outputs are discarded.
    """
    batched_jvp = vmap(jvp, in_axes=(None, None, 0))
    _, tangents_out = batched_jvp(f, (x,), (M,))
    return tangents_out

key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)
V = random.normal(key, (num_covecs,) + W.shape)


loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)

print('\nVmapped Matrix-Jacobian product, reverse-mode')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)

print('\nVmapped Matrix-Jacobian product forward-mode')
vmap_vs2 = vmap_mjp2(f, W, M=V)
%timeit -n10 -r3 vmap_mjp2(f, W, M=V)

# assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'

  return jnp.vstack([vjp_fun(mi) for mi in M])


Non-vmapped Matrix-Jacobian product
63.4 ms ± 5.27 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Matrix-Jacobian product, reverse-mode
2.65 ms ± 187 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Matrix-Jacobian product forward-mode
1.49 ms ± 68.4 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)


## The implementation of jax.jacfwd and jax.jacrev

In [40]:
from jax import jacrev as builtin_jacrev

def our_jacrev(f):
    """Build a reverse-mode Jacobian function for `f` from `vjp` and `vmap`."""
    def jacfun(x):
        primal_out, pullback = vjp(f, x)
        # Use vmap to do a matrix-Jacobian product.
        # Here, the matrix is the Euclidean basis, so we get all
        # entries in the Jacobian at once.
        basis = jnp.eye(len(primal_out))
        (jacobian,) = vmap(pullback, in_axes=0)(basis)
        return jacobian
    return jacfun

# Compare against the built-in jacrev (uses the `f` and `W` defined in earlier cells).
assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'

In [41]:
from jax import jacfwd as builtin_jacfwd

def our_jacfwd(f):
    """Build a forward-mode Jacobian function for `f` from `jvp` and `vmap`."""
    def jacfun(x):
        def push_forward(basis_vector):
            _, tangent_out = jvp(f, (x,), (basis_vector,))
            return tangent_out

        # Map over the columns of the identity; each pushed-forward basis vector
        # becomes a row of the intermediate result, hence the final transpose.
        basis = jnp.eye(len(x))
        jacobian_transposed = vmap(push_forward, in_axes=1)(basis)
        return jnp.transpose(jacobian_transposed)
    return jacfun

# Compare against the built-in jacfwd (uses the `f` and `W` defined in earlier cells).
assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'

In [47]:
# Refresher: closures and chained calls of nested functions

def func1(a):
    """Return a function that adds the captured `a` to its argument."""
    return lambda b: a + b

# func1(1) returns the inner function with a=1; calling it with 2 adds them.
func1(1)(2)  # 3

3