The following code is adapted from the implicit function theorem (IFT) tutorial at http://implicit-layers-tutorial.org/implicit_functions/.

In [27]:
from functools import partial
import jax
import jax.numpy as jnp
from jax import random

# Fixed-point solvers: forward iteration, Newton's method, and Anderson acceleration

In [1]:
def fwd_solver(f, z_init, tol=1e-5, max_iter=None):
    """Find a fixed point z = f(z) by plain forward iteration.

    Iterates z <- f(z) until successive iterates differ (in L2 norm) by at
    most ``tol``. The stopping test is evaluated in Python, so this works
    under ``jax.grad`` (the loop is unrolled into the trace) but not under
    ``jax.jit``.

    Args:
        f: Map whose fixed point is sought; must preserve the shape of ``z``.
        z_init: Initial guess.
        tol: Convergence threshold on ``||z_prev - z||``.
        max_iter: Optional cap on the total number of ``f`` evaluations.
            ``None`` (the default) iterates until convergence, which never
            terminates if the iteration does not converge.

    Returns:
        The last iterate.

    Raises:
        ValueError: If ``max_iter`` is given and is smaller than 1.
    """
    if max_iter is not None and max_iter < 1:
        raise ValueError(f"max_iter must be >= 1 or None, got {max_iter}")

    z_prev, z = z_init, f(z_init)
    n_evals = 1
    while jnp.linalg.norm(z_prev - z) > tol:
        # Bail out instead of spinning forever on a non-contracting map.
        if max_iter is not None and n_evals >= max_iter:
            break
        z_prev, z = z, f(z)
        n_evals += 1
    return z


def newton_solver(f, z_init):
    """Find a fixed point of ``f`` by Newton's method on the residual f(z) - z.

    Each Newton update z - J(z)^{-1} r(z), with r(z) = f(z) - z and J its
    Jacobian, is itself treated as a map and iterated with ``fwd_solver``
    until successive iterates stop changing.
    """
    def residual(z):
        return f(z) - z

    def newton_step(z):
        jac = jax.jacobian(residual)(z)
        return z - jnp.linalg.solve(jac, residual(z))

    return fwd_solver(newton_step, z_init)


def anderson_solver(f, z_init, m=5, lam=1e-4, max_iter=50, tol=1e-5, beta=1.0):
    """Find a fixed point z = f(z) with Anderson acceleration.

    Keeps ring buffers of the last ``m`` iterates ``X`` and their images
    ``F = f(X)``. Each step solves a small regularized bordered linear
    system for mixing weights ``alpha`` (constrained to sum to 1) that make
    the combined residual ``sum_i alpha_i (F_i - X_i)`` small. It then forms
    the next iterate as a mix of the stored images (and, if ``beta < 1``,
    the stored iterates).

    Args:
        f: Map whose fixed point is sought; must preserve the shape of ``z``.
        z_init: Initial guess; any array shape is supported.
        m: History size; must be at least 2.
        lam: Regularization added to the diagonal of the bordered system.
        max_iter: Iteration budget, counting the two plain warm-up steps.
        tol: Stop once ||f(z) - z|| / (1e-5 + ||f(z)||) drops below this.
        beta: Mixing factor; 1.0 mixes only the images ``F``.

    Returns:
        The last accelerated iterate, or ``f(f(z_init))`` when
        ``max_iter <= 2`` leaves no room for an accelerated step.

    Raises:
        ValueError: If ``m < 2``.
    """
    if m < 2:
        raise ValueError(f"anderson_solver requires m >= 2, got m={m}")

    x0 = z_init
    x1 = f(x0)
    x2 = f(x1)
    # Two plain fixed-point steps seed the history; the other slots start at zero.
    X = jnp.concatenate([jnp.stack([x0, x1]), jnp.zeros((m - 2, *jnp.shape(x0)))])
    F = jnp.concatenate([jnp.stack([x1, x2]), jnp.zeros((m - 2, *jnp.shape(x0)))])

    xk = x2  # returned as-is if the loop below never runs (max_iter <= 2)
    for k in range(2, max_iter):
        n = min(k, m)  # number of filled history slots
        G = F[:n] - X[:n]  # residuals f(x_i) - x_i
        GTG = jnp.tensordot(G, G, [list(range(1, G.ndim))] * 2)
        # Bordered system [[0, 1^T], [1, G G^T]] enforces sum(alpha) == 1.
        H = jnp.block(
            [[jnp.zeros((1, 1)), jnp.ones((1, n))], [jnp.ones((n, 1)), GTG]]
        ) + lam * jnp.eye(n + 1)
        alpha = jnp.linalg.solve(H, jnp.zeros(n + 1).at[0].set(1))[1:]

        # Contract alpha against the history axis only, so z of any rank works
        # (jnp.dot would contract F's second-to-last axis when z has rank >= 2).
        xk = beta * jnp.tensordot(alpha, F[:n], axes=1) + (1 - beta) * jnp.tensordot(
            alpha, X[:n], axes=1
        )
        X = X.at[k % m].set(xk)
        F = F.at[k % m].set(f(xk))

        res = jnp.linalg.norm(F[k % m] - X[k % m]) / (1e-5 + jnp.linalg.norm(F[k % m]))
        if res < tol:
            break
    return xk

# Fixed-point layer

In [3]:
def fixed_point_layer(solver, f, params, x):
    """Return z* satisfying z* = f(params, x, z*), as found by ``solver``.

    The solver is called on the map z -> f(params, x, z), starting from an
    all-zeros guess with the same shape and dtype as ``x``.
    """
    def layer_map(z):
        return f(params, x, z)

    return solver(layer_map, z_init=jnp.zeros_like(x))


def f(W, x, z):
    """Layer map tanh(W @ z + x) whose fixed point the layer computes."""
    pre_activation = jnp.dot(W, z) + x
    return jnp.tanh(pre_activation)

In [5]:
# Problem size, a Gaussian weight matrix scaled by 1/sqrt(ndim), and an input vector.
ndim = 10
W = random.normal(random.PRNGKey(0), shape=(ndim,) * 2) / jnp.sqrt(ndim)
x = random.normal(random.PRNGKey(1), shape=(ndim,))

In [6]:
# Forward solve with plain fixed-point iteration.
z_star = fixed_point_layer(fwd_solver, f, params=W, x=x)
print(z_star)

[ 0.00649598 -0.7015958  -0.984715   -0.04196562 -0.61522174 -0.4818382
  0.5783123   0.9556705  -0.08373147  0.8447805 ]


In [7]:
# Forward solve with Newton's method; should agree with the iteration above.
z_star = fixed_point_layer(newton_solver, f, params=W, x=x)
print(z_star)

[ 0.00649405 -0.701595   -0.98471504 -0.04196506 -0.6152211  -0.4818385
  0.57831246  0.9556705  -0.08372928  0.8447799 ]


# Naive automatic differentiation through iterative solvers

In [8]:
# Differentiate by unrolling every forward-solver iteration.
g = jax.grad(
    lambda params: fixed_point_layer(fwd_solver, f, params, x).sum()
)(W)
print(g[0])

[ 0.0075667  -0.8125902  -1.1404794  -0.04861286 -0.71255237 -0.5580556
  0.66978824  1.1068414  -0.09702272  0.97842246]


In [9]:
# Same naive gradient, now unrolled through the Newton iterations.
g = jax.grad(
    lambda params: fixed_point_layer(newton_solver, f, params, x).sum()
)(W)
print(g[0])

[ 0.00752129 -0.81257427 -1.1404787  -0.04860315 -0.7125375  -0.5580563
  0.66979074  1.1068397  -0.09697369  0.97840846]


# VJP and JVP

In [14]:
def f(x):
    """Scalar test function x**2 * sin(x)."""
    return x**2 * jnp.sin(x)


# Evaluate the scalar test function f at a single point and show the value.
x = 2.0
y = f(x)
print(y)

3.6371896


In [15]:
# Reverse mode: jax.vjp returns f(x) plus a function mapping an output
# cotangent w to the input cotangent (a one-element tuple).
w = 1.0
y, f_vjp = jax.vjp(f, x)
lmbda, = f_vjp(w)
print(y, lmbda, sep="\n")

3.6371896
1.9726022


In [20]:
# Forward mode on a composition f = g(h(x)) with h = sin and g = cube.
h = jnp.sin


def g(x):
    return x**3


def f(x):
    return g(h(x))


z, delta_z = jax.jvp(f, (1.0,), (1.0,))
print(z, delta_z, sep="\n")

0.59582317
1.1477209


In [21]:
def f_jvp(x, delta_x):
    """Push (x, delta_x) forward through h, then g, one JVP at a time.

    Returns the primal g(h(x)) and the tangent J_g(h(x)) @ J_h(x) @ delta_x.
    """
    primal, tangent = x, delta_x
    for stage in (h, g):
        primal, tangent = jax.jvp(stage, (primal,), (tangent,))
    return primal, tangent


# Manual chaining must match jax.jvp on the composed function above.
z, delta_z = f_jvp(1.0, 1.0)
print(z, delta_z, sep="\n")

0.59582317
1.1477209


In [26]:
def f_vjp(x, w):
    """Reverse-mode derivative of g(h(x)), seeded with output cotangent ``w``.

    Returns ``(z, x_bar)`` where ``z = g(h(x))`` and
    ``x_bar = w^T J_g(h(x)) J_h(x)``. Cotangents flow from the output
    backward: g's VJP is applied to ``w`` first, then h's VJP to that result.
    Applying h's VJP first would only give the right answer when the
    Jacobian factors commute (e.g. scalar-to-scalar functions).
    """
    y, h_vjp = jax.vjp(h, x)
    z, g_vjp = jax.vjp(g, y)
    (y_bar,) = g_vjp(w)  # cotangent w.r.t. the intermediate y = h(x)
    (x_bar,) = h_vjp(y_bar)  # cotangent w.r.t. the input x
    return z, x_bar


# For a scalar function, the VJP with w = 1 equals the JVP with delta_x = 1.
z, delta_z = f_vjp(1.0, 1.0)
print(z, delta_z, sep="\n")

0.59582317
1.1477209


# Implicit function theorem

In [31]:
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def fixed_point_layer(solver, f, params, x):
    """Fixed-point layer whose gradient is supplied by a custom VJP rule.

    Forward: returns z* with z* = f(params, x, z*), found by ``solver``
    starting from zeros shaped like ``x``. ``solver`` and ``f`` are Python
    callables, so they are declared non-differentiable via ``nondiff_argnums``.
    """
    start = jnp.zeros_like(x)
    return solver(lambda z: f(params, x, z), z_init=start)


def fixed_point_layer_fwd(solver, f, params, x):
    """Forward rule for the custom VJP: solve, then stash residuals.

    Returns the fixed point and the tuple (params, x, z_star) that the
    backward rule unpacks.
    """
    z_star = fixed_point_layer(solver, f, params, x)
    residuals = (params, x, z_star)
    return z_star, residuals


def fixed_point_layer_bwd(solver, f, res, z_star_bar):
    """Backward rule via the implicit function theorem.

    Solves the adjoint fixed-point equation u = (df/dz)^T u + z_star_bar at
    z_star, reusing ``solver``. It then returns the cotangents for
    (params, x) as u^T df/d(params, x). No solver iterations are unrolled.
    """
    params, x, z_star = res
    _, vjp_params_x = jax.vjp(lambda p, xx: f(p, xx, z_star), params, x)
    _, vjp_z = jax.vjp(lambda z: f(params, x, z), z_star)

    def adjoint_map(u):
        (u_bar,) = vjp_z(u)
        return u_bar + z_star_bar

    u_star = solver(adjoint_map, z_init=jnp.zeros_like(z_star))
    return vjp_params_x(u_star)


# Attach the forward/backward rules so jax.grad uses implicit differentiation.
fixed_point_layer.defvjp(fwd=fixed_point_layer_fwd, bwd=fixed_point_layer_bwd)

In [32]:
# Re-create the layer map f and the inputs W, x used in the earlier cells.
ndim = 10
W = random.normal(random.PRNGKey(0), shape=(ndim, ndim)) / jnp.sqrt(ndim)


def f(W, x, z):
    """Layer map tanh(W @ z + x)."""
    return jnp.tanh(jnp.dot(W, z) + x)


x = random.normal(random.PRNGKey(1), shape=(ndim,))

In [34]:
# Forward pass through the custom-VJP layer; the custom rule only changes
# how gradients are computed, so the value matches the plain solve above.
fixed_point_layer(fwd_solver, f, W, x)

Array([ 0.00649598, -0.7015958 , -0.984715  , -0.04196562, -0.61522174,
       -0.4818382 ,  0.5783123 ,  0.9556705 , -0.08373147,  0.8447805 ],      dtype=float32)

In [35]:
# Gradient through the implicit-function-theorem backward rule (no unrolling).
g = jax.grad(
    lambda params: fixed_point_layer(fwd_solver, f, params, x).sum()
)(W)
print(g[0])

[ 0.0075235  -0.812573   -1.1404755  -0.04860367 -0.7125365  -0.55805457
  0.6697887   1.1068368  -0.09697597  0.9784065 ]
