In [1]:
from functools import partial
import timeit
import os
import sys

from typing import Tuple, List
from jaxtyping import Float, Array
import equinox as eqx

import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Image, display

import jax
from jax import jit, vmap, lax, jacfwd, jacrev, grad, vjp, jvp, random
import jax.numpy as jnp
from jax.config import config

# Global JAX settings for the whole notebook, reached through `jax.config`.
# NOTE(review): the submodule import form (`from jax.config import config`) is
# reported as deprecated in recent JAX releases - confirm against the installed version.
jax.config.update("jax_enable_x64", True)  # use float64 instead of the float32 default
jax.config.update("jax_debug_infs", True)  # raise as soon as an inf is produced
jax.config.update("jax_debug_nans", True)  # raise as soon as a NaN is produced

# Problem Definition
This notebook explains how to implement backpropagation through implicit layers. As an example, we consider the following fixed-point problem:
$$
\begin{align}
% \mathbf{F}(\mathbf{x}, \alpha, \beta) = \mathbf{0},
\mathbf{F}(\mathbf{x}, \alpha, \beta) = \mathbf{x},
\end{align}
$$
where $\mathbf{x} = (x_1, x_2)$ and
$$
\begin{align}
\mathbf{F}(\mathbf{x}, \alpha, \beta) = \left[\begin{matrix}
% x_1^2 + x_2^2 - \alpha \\
% \beta x_1^3 - x_2  
% x_1^2 + x_2^2 + x_1 - \alpha \\
% \beta x_1^3
\sqrt{\alpha-x_1x_2} \\
\sqrt{\frac{57-x_2}{\beta x_1}}
\end{matrix}\right]
\end{align}
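$$
When $\alpha = 10, \beta = 3$ (the values used in the code below), the solution is $(x_1, x_2) = (2, 3)$, which is the point the fixed-point iteration below converges to. Note that the implementation evaluates the second component with the freshly computed first component in place of $x_1$; this does not change the fixed point, because the two coincide there.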

In [2]:
class Para(eqx.Module):
    """Parameters of the example fixed-point map, stored as an equinox pytree.

    Fields are declared in the order (alpha, beta), which is also the order
    of the positional constructor arguments.
    """

    # Constant inside the first square root: sqrt(alpha - x_0 * x_1).
    alpha: Float[Array, ""]
    # Multiplier in the denominator of the second component.
    beta: Float[Array, ""]


class Fc(eqx.Module):
    """Callable pytree implementing the two-component map x -> F(x; alpha, beta).

    The parameters live in the ``para`` field so that they can be swapped
    out or differentiated as pytree leaves.
    """

    para: Para

    def __call__(self, x: Array) -> Array:
        """Apply the map once.

        Args:
            x: Array indexed as ``x[0]`` and ``x[1]``.

        Returns:
            Array ``[first, second]`` holding the two updated components.
        """
        x_first, x_second = x[0], x[1]
        first = jnp.sqrt(self.para.alpha - x_first * x_second)
        # The second component reuses the freshly computed first component.
        ratio = (57 - x_second) / (self.para.beta * first)
        second = jnp.sqrt(ratio)
        return jnp.array([first, second])

In [3]:
# Parameter values used for the worked example.
alpha_v = 10.0
beta_v = 3.0

# NOTE(review): plain Python floats are passed here although the fields are
# annotated as scalar arrays - confirm this is intended.
para = Para(alpha_v, beta_v)

func = Fc(para)

# One application of the map at the initial guess (the result differs from
# the input, so this point is not yet a fixed point).
func(jnp.array([1.5, 3.5]))

Array([2.17944947, 2.86050599], dtype=float64)

# Fixed-point method to solve the system of non-linear equations
$$
\begin{align}
% \mathbf{F}(\mathbf{x}, \alpha, \beta) = \mathbf{0},
\mathbf{x_{n+1}} = \mathbf{F}(\mathbf{x_n}, \alpha, \beta),
\end{align}
$$
where $n$ is the number of iterations.

In [4]:
def fixed_point(f, x_guess, num_steps=1000):
    """Iterate ``x <- f(x)`` a fixed number of times and return the last iterate.

    The loop is a ``jax.lax.scan`` with a constant trip count; there is no
    convergence test, so the caller must choose ``num_steps`` large enough.

    Args:
        f: Callable mapping the state to a new state with the same pytree
            structure, shapes and dtypes (required for a scan carry).
        x_guess: Initial state.
        num_steps: Number of applications of ``f``. Must be a Python int
            because it is used as the static scan length. Defaults to 1000.

    Returns:
        The state after ``num_steps`` applications of ``f``.
    """

    def body_fun(x, _):
        # Emit no per-step output: stacking the step differences only to
        # throw them away costs memory, and subtracting states would also
        # rule out pytree-valued states.
        return f(x), None

    x_star, _ = jax.lax.scan(body_fun, x_guess, xs=None, length=num_steps)
    return x_star

In [5]:
# Alternative starting point, kept for experimentation:
# x_guess = jnp.array([1.0, 1.0])
x_guess = jnp.array([1.5, 3.5])
# Run the solver from the initial guess and display the final iterate.
fixed_point(func, x_guess)

Array([2., 3.], dtype=float64)

# Compute the gradient through the fixed-point iteration solver

In [6]:
# Initial guess for the differentiation examples (same value as in the solver demo).
x_guess = jnp.array([1.5, 3.5])

In [7]:
def implicit_func_naive(func, x_guess):
    """Solve ``x = func(x)`` with the plain fixed-point solver.

    No custom derivative rule is attached, so differentiating this function
    differentiates through the solver itself.

    Args:
        func: Callable applied repeatedly by ``fixed_point``.
        x_guess: Initial guess handed to ``fixed_point``.

    Returns:
        Whatever ``fixed_point(func, x_guess)`` returns.
    """
    return fixed_point(func, x_guess)


# Sanity check: display the solution returned by the naive wrapper.
implicit_func_naive(func, x_guess)

Array([2., 3.], dtype=float64)

In [8]:
# Forward-mode Jacobian of the naive solver with respect to argument 0 (the
# callable pytree `func`). Because argnums is a 1-tuple, the result is a 1-tuple.
derivative_implicit_naive = jacfwd(implicit_func_naive, argnums=(0,), has_aux=False)

In [9]:
# Unpack the 1-tuple; `jac` mirrors the structure of `func`, so the
# sensitivities to alpha and beta are read from its `para` field.
(jac,) = derivative_implicit_naive(func, x_guess)
jac.para.alpha, jac.para.beta

(Array([ 0.1804878 , -0.13170732], dtype=float64),
 Array([ 0.17560976, -0.61463415], dtype=float64))

# Differentiating through the implicit function via the implicit function theorem

Next, we will implement a custom derivative rule for the fixed-point solver so that we can use `jax.grad` or `jax.jacrev` through the solver without differentiating its iterations. At the solution,
$$
\begin{align}
F(\mathbf{x}, \mathbf{W}) = \mathbf{x},
\end{align}
$$
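where $\mathbf{W} = (\alpha, \beta)$. Because this identity holds for every $\mathbf{W}$ (with the solution $\mathbf{x}$ itself a function of $\mathbf{W}$), both sides have the same derivative with respect to $\mathbf{W}$. Using the chain rule on the left-hand side,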
$$
\begin{align}
\frac{\partial F}{\partial \mathbf{x}} \frac{\partial \mathbf{x}}{\partial \mathbf{W}} + \frac{\partial F}{\partial \mathbf{W}} = \frac{\partial \mathbf{x}}{\partial \mathbf{W}}
\end{align}
$$
Thus,
$$
\begin{align}
\frac{\partial \mathbf{x}}{\partial \mathbf{W}} = \left[\mathbf{I}-\frac{\partial F}{\partial \mathbf{x}}\right]^{-1}\frac{\partial F}{\partial \mathbf{W}}
\end{align}
$$
The Jacobian-vector product ($w$ is the vector to be multiplied) is
$$
\begin{align}
\frac{\partial \mathbf{x}}{\partial \mathbf{W}}w = \left[\mathbf{I}-\frac{\partial F}{\partial \mathbf{x}}\right]^{-1}u,
\end{align}
$$
where $u = \frac{\partial F}{\partial \mathbf{W}}w$ can be computed by Jacobian vector product of $\mathbf{F}$. Then, the Jacobian vector product $\frac{\partial \mathbf{x}}{\partial \mathbf{W}}w$ is the solution to the linear system 
$$
\left[\mathbf{I}-\frac{\partial F}{\partial \mathbf{x}}\right] \frac{\partial \mathbf{x}}{\partial \mathbf{W}}w = u.
$$


### Implementing the implicit differentiation in JAX by Jacobian-vector product

In [10]:
@partial(jax.custom_jvp, nondiff_argnums=(0,))
def implicit_func(func, x_guess, para):
    """Solve ``x = func(x)`` with ``func.para`` replaced by ``para``.

    Declared as a ``custom_jvp`` function so that a hand-written derivative
    rule can be attached; ``func`` itself is marked non-differentiable.

    Args:
        func: Callable pytree with a ``para`` field.
        x_guess: Initial guess for the fixed-point iteration.
        para: Parameters substituted into ``func`` before solving.

    Returns:
        The result of ``fixed_point`` for the re-parameterised callable.
    """
    func_with_para = eqx.tree_at(lambda module: module.para, func, para)
    return fixed_point(func_with_para, x_guess)

In [13]:
@implicit_func.defjvp
def implicit_func_jvp(func, primals, tangents):
    """JVP rule for ``implicit_func`` based on the implicit function theorem.

    At the solution ``x = F(x, W)`` the output tangent ``dx`` solves the
    linear system ``(I - dF/dx) dx = (dF/dW) dW``, so the solver iterations
    are never differentiated.

    Args:
        func: Non-differentiable callable pytree with a ``para`` field.
        primals: ``(x_guess, para)``.
        tangents: Tangents matching ``primals``. The tangent of ``x_guess``
            is ignored, i.e. the solution is treated as independent of the
            initial guess.

    Returns:
        ``(x_solution, tangent_out)``.
    """
    x_guess, para = primals
    _, para_dot = tangents

    x_solution = implicit_func(func, x_guess, para)

    def apply_with_para(p):
        # F(x_solution, p): the map at the solution as a function of the parameters only.
        return eqx.tree_at(lambda t: t.para, func, p)(x_solution)

    # u = (dF/dW) dW
    _, u = jvp(apply_with_para, (para,), (para_dot,), has_aux=False)

    # dF/dx must be evaluated with the primal parameters; the parameters
    # stored inside `func` may differ from `para`.
    func_at_para = eqx.tree_at(lambda t: t.para, func, para)
    J = jacfwd(func_at_para, argnums=0, has_aux=False)(x_solution)

    # Build the identity in J's dtype so the tangent dtype is not promoted
    # away from the dtype of the primal output.
    identity = jnp.eye(J.shape[0], dtype=J.dtype)
    tangent_out = jnp.linalg.solve(identity - J, u)
    return (
        x_solution,
        tangent_out,
    )  # you don't need to add None, see the discussion here (https://github.com/google/jax/discussions/16871)

In [14]:
# Jacobians of the solution with respect to x_guess (argument 1) and para
# (argument 2), obtained through the custom JVP rule.
# NOTE: despite the `_fwd` suffix in the name, this uses reverse mode (jacrev).
derivative_implicit_fwd = jacrev(implicit_func, argnums=(1, 2), has_aux=False)
jac_x, jac_para = derivative_implicit_fwd(func, x_guess, para)
# jac_x is all zeros because the rule ignores the tangent of x_guess.
jac_x, jac_para.alpha, jac_para.beta

(Array([[0., 0.],
        [0., 0.]], dtype=float64),
 Array([ 0.1804878 , -0.13170732], dtype=float64, weak_type=True),
 Array([ 0.17560976, -0.61463415], dtype=float64, weak_type=True))

In [95]:
%timeit derivative_implicit_naive(func, x_guess)

45.8 ms ± 628 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [96]:
%timeit derivative_implicit_fwd(func, x_guess, para)

32 ms ± 144 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


# With default initial guess
This variant builds the initial guess inside the solver function instead of taking it as an argument, because in canoak the initial guess is generated internally.

In [103]:
@partial(jax.custom_jvp, nondiff_argnums=(0,))
def implicit_func_nox(func, para):
    """Solve ``x = func(x)`` from a built-in initial guess.

    Same as ``implicit_func`` except that the starting point is created
    inside the function rather than supplied by the caller.

    Args:
        func: Callable pytree with a ``para`` field (non-differentiable).
        para: Parameters substituted into ``func`` before solving.

    Returns:
        The result of ``fixed_point`` started from ``[1.5, 3.5]``.
    """
    func_with_para = eqx.tree_at(lambda module: module.para, func, para)
    start = jnp.array([1.5, 3.5])
    return fixed_point(func_with_para, start)


@implicit_func_nox.defjvp
def implicit_func_nox_jvp(func, primals, tangents):
    """JVP rule for ``implicit_func_nox`` based on the implicit function theorem.

    The output tangent ``dx`` solves ``(I - dF/dx) dx = (dF/dW) dW`` at the
    solution, so the solver iterations are never differentiated.

    Args:
        func: Non-differentiable callable pytree with a ``para`` field.
        primals: 1-tuple ``(para,)``.
        tangents: 1-tuple with the tangent of ``para``.

    Returns:
        ``(x_solution, tangent_out)``.
    """
    (para,) = primals
    (para_dot,) = tangents

    x_solution = implicit_func_nox(func, para)

    def apply_with_para(p):
        # F(x_solution, p): the map at the solution as a function of the parameters only.
        return eqx.tree_at(lambda t: t.para, func, p)(x_solution)

    # u = (dF/dW) dW
    _, u = jvp(apply_with_para, (para,), (para_dot,), has_aux=False)

    # dF/dx must be evaluated with the primal parameters; the parameters
    # stored inside `func` may differ from `para`.
    func_at_para = eqx.tree_at(lambda t: t.para, func, para)
    J = jacfwd(func_at_para, argnums=0, has_aux=False)(x_solution)

    # Build the identity in J's dtype so the tangent dtype is not promoted
    # away from the dtype of the primal output.
    identity = jnp.eye(J.shape[0], dtype=J.dtype)
    tangent_out = jnp.linalg.solve(identity - J, u)
    return (
        x_solution,
        tangent_out,
    )  # you don't need to add None, see the discussion here (https://github.com/google/jax/discussions/16871)

In [104]:
# Jacobian of the solution with respect to para (argument 1) through the
# custom JVP rule. `argnums=(1)` is the plain integer 1, not a tuple, so the
# result is returned directly rather than wrapped in a tuple.
# NOTE: despite the `_fwd` suffix in the name, this uses reverse mode (jacrev).
derivative_implicit_nox_fwd = jacrev(implicit_func_nox, argnums=(1), has_aux=False)
jac_para = derivative_implicit_nox_fwd(func, para)
jac_para.alpha, jac_para.beta

(Array([ 0.1804878 , -0.13170732], dtype=float64, weak_type=True),
 Array([ 0.17560976, -0.61463415], dtype=float64, weak_type=True))

# (TODO) What if I only want to compute the gradient against one output?
- Here, I have to create a function that outputs only one value. This function is essentially another callable pytree; however, it is also an `eqx.Partial` or `jax.tree_util.Partial` instance, so the IFT gradient function needs to be revised accordingly.
- Also, this time, I embedded the iteration inside the main function, in a way similar to canoak.

In [128]:
# # Let's redefine the classes
# class Para(eqx.Module):
#     alpha: Float[Array, ""]
#     beta: Float[Array, ""]

# class Fc(eqx.Module):
#     para: Para

#     def __call__(self) -> Array:
#         x_guess = jnp.array([1.5, 3.5])
#         def func(x):
#             F_0 = jnp.sqrt(self.para.alpha-x[0]*x[1])
#             F_1 = jnp.sqrt((57-x[1])/(self.para.beta*F_0))
#             return jnp.array([F_0, F_1])
#         return fixed_point(func, x_guess)

#     def output1(self) -> Array:
#         return self()[0]

# alpha_v = 10.
# beta_v = 3.0

# para = Para(alpha_v, beta_v)

# func = Fc(para)

# # func(jnp.array([1.5, 3.5])), func.output1(jnp.array([1.5, 3.5]))
# func(), func.output1()

(Array([2., 3.], dtype=float64), Array(2., dtype=float64))

In [131]:
# @partial(jax.custom_jvp, nondiff_argnums=(0,))
# def implicit_func_nox_partial(func, para):
#     func = eqx.tree_at(lambda t: t.args[0].para, func, para)
#     return func()

# @implicit_func_nox_partial.defjvp
# def implicit_func_nox_partial_jvp(func, primals, tangents):
#     def func_para(func, para):
#         func = eqx.tree_at(lambda t: t.args[0].para, func, para)
#         return func()
#     args = primals
#     tangents_args = tangents

#     para = args[0]
#     x_solution = implicit_func_nox_partial(func, para)
#     print(x_solution)

#     _, u = jvp(partial(func_para, func), args, tangents_args, has_aux=False)
#     Jacobian_JAX = jacfwd(func, argnums=0, has_aux=False)
#     J = Jacobian_JAX(x_solution)
#     I = jnp.eye(J.shape[0])
#     tangent_out = jnp.linalg.solve(I-J, u)
#     return x_solution, tangent_out # you don't need to add None, see the discussion here (https://github.com/google/jax/discussions/16871)

In [133]:
# derivative_implicit_nox_partial_fwd = jacrev(implicit_func_nox_partial, argnums=(1), has_aux=False)
# jac_para = derivative_implicit_nox_partial_fwd(func.output1, para)
# jac_para.alpha,jac_para.beta

# (TODO) What if the outputs are two pytrees?

In [138]:
class Para(eqx.Module):
    """Parameters of the example fixed-point map (pytree-state section).

    NOTE: this repeats the field layout of the ``Para`` class defined earlier
    in the notebook and rebinds the name.
    """

    # Constant inside the first square root.
    alpha: Float[Array, ""]
    # Multiplier in the denominator of the second component.
    beta: Float[Array, ""]


class Var(eqx.Module):
    """Pytree wrapper around a single value, used as one component of the state."""

    # The wrapped value (annotated as a scalar array).
    value: Float[Array, ""]


class Fc(eqx.Module):
    """Fixed-point map whose state is a list of two ``Var`` pytrees.

    Same formulas as the array-based version, but each component is read
    from and written to the ``value`` field of a ``Var``.
    """

    para: Para

    def __call__(self, x: List[Var]) -> List[Var]:
        """Apply the map once.

        Args:
            x: Two-element sequence of ``Var``; ``x[0].value`` and
                ``x[1].value`` are the current components.

        Returns:
            A list ``[Var, Var]`` with the updated components. A list is
            returned so the output has the same container type as the list
            state this notebook feeds through ``jax.lax.scan``.
        """
        F_0 = jnp.sqrt(self.para.alpha - x[0].value * x[1].value)
        # The second component reuses the freshly computed first component.
        F_1 = jnp.sqrt((57 - x[1].value) / (self.para.beta * F_0))
        return [Var(value=F_0), Var(value=F_1)]

In [142]:
# Rebuild the example with a pytree-valued state; this rebinds x_guess, para
# and func for the cells that follow.
alpha_v = 10.0
beta_v = 3.0
# The state is a list of two Var objects instead of a length-2 array.
x_guess = [Var(1.5), Var(3.5)]
para = Para(alpha_v, beta_v)
func = Fc(para)
# One application of the map at the initial guess.
func(x_guess)

[Var(value=f64[]), Var(value=f64[])]

In [145]:
def fixed_point_pytree(f, x_guess, num_steps=1000):
    """Iterate ``x <- f(x)`` on a pytree-valued state a fixed number of times.

    The loop is a ``jax.lax.scan`` with a constant trip count and no
    per-step output; there is no convergence test.

    Args:
        f: Callable mapping the state to a new state with the same pytree
            structure, shapes and dtypes (required for a scan carry).
        x_guess: Initial state (any pytree).
        num_steps: Number of applications of ``f``. Must be a Python int
            because it is used as the static scan length. Defaults to 1000.

    Returns:
        The state after ``num_steps`` applications of ``f``.
    """

    def body_fun(state, _):
        return f(state), None

    x_star, _ = jax.lax.scan(body_fun, x_guess, xs=None, length=num_steps)
    return x_star


# Run the pytree solver; the result has the same list-of-Var structure as x_guess.
fixed_point_pytree(func, x_guess)

[Var(value=f64[]), Var(value=f64[])]

In [146]:
# Forward-mode Jacobian of the pytree-valued map with respect to its pytree input.
# The result is a nested pytree (output structure wrapping input structure),
# not a 2-D array.
Jacobian_JAX_pytree = jacfwd(func, argnums=0, has_aux=False)
Jacobian_JAX_pytree(x_guess)

[Var(value=[Var(value=f64[]), Var(value=f64[])]),
 Var(value=[Var(value=f64[]), Var(value=f64[])])]

In [None]:
@partial(jax.custom_jvp, nondiff_argnums=(0,))
def implicit_func(func, x_guess, para):
    """Solve ``x = func(x)`` for a pytree-valued state, with ``func.para`` set to ``para``.

    NOTE: this rebinds the ``implicit_func`` name defined earlier in the
    notebook.

    Args:
        func: Callable pytree with a ``para`` field (non-differentiable).
        x_guess: Initial state; may be any pytree accepted by ``func``.
        para: Parameters substituted into ``func`` before solving.

    Returns:
        The final iterate, with the same pytree structure as ``x_guess``.
    """
    func = eqx.tree_at(lambda t: t.para, func, para)
    # Use the pytree-capable solver: the array solver subtracts successive
    # states, which is undefined for a list of Var objects.
    x_solution = fixed_point_pytree(func, x_guess)
    return x_solution

In [None]:
@implicit_func.defjvp
def implicit_func_jvp(func, primals, tangents):
    """JVP rule for the pytree-state ``implicit_func`` (implicit function theorem).

    The state and the tangent of the map are flattened to 1-D vectors so
    that the linear system ``(I - dF/dx) dx = (dF/dW) dW`` can be assembled
    and solved with dense linear algebra. The solution is then restored to
    the pytree structure of the state.

    Args:
        func: Non-differentiable callable pytree with a ``para`` field.
        primals: ``(x_guess, para)``.
        tangents: Tangents matching ``primals``. The tangent of ``x_guess``
            is ignored, i.e. the solution is treated as independent of the
            initial guess.

    Returns:
        ``(x_solution, tangent_out)``, both with the pytree structure of the
        state.
    """
    # Local import keeps this cell self-contained.
    from jax.flatten_util import ravel_pytree

    x_guess, para = primals
    _, para_dot = tangents

    x_solution = implicit_func(func, x_guess, para)

    def apply_with_para(p):
        # F(x_solution, p): the map at the solution as a function of the parameters only.
        return eqx.tree_at(lambda t: t.para, func, p)(x_solution)

    # u = (dF/dW) dW, a pytree with the structure of the state.
    _, u = jvp(apply_with_para, (para,), (para_dot,), has_aux=False)

    x_flat, unravel = ravel_pytree(x_solution)
    u_flat, _ = ravel_pytree(u)

    # dF/dx must be evaluated with the primal parameters; the parameters
    # stored inside `func` may differ from `para`.
    func_at_para = eqx.tree_at(lambda t: t.para, func, para)

    def flat_map(z):
        # Same map, expressed on flat vectors so that its Jacobian is a 2-D array.
        out_flat, _ = ravel_pytree(func_at_para(unravel(z)))
        return out_flat

    J = jacfwd(flat_map)(x_flat)
    identity = jnp.eye(J.shape[0], dtype=J.dtype)
    tangent_flat = jnp.linalg.solve(identity - J, u_flat)
    return (
        x_solution,
        unravel(tangent_flat),
    )  # you don't need to add None, see the discussion here (https://github.com/google/jax/discussions/16871)