# Implicit Differentiation
This notebook briefly introduces implicit differentiation. Implicit differentiation has numerous applications; here we apply it to a simple fixed-point iteration.

For a more in-depth guide to its applications in machine learning, see the NeurIPS 2020 tutorial on [deep implicit layers](http://implicit-layers-tutorial.org/). Full credit goes to that tutorial, from which we borrow examples and code to provide a primer on the implicit function theorem and its applications.

In [1]:
import os
# Set before JAX initializes its backend so that GPU memory is not preallocated
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from functools import partial
from jax import random

%matplotlib inline

First, let's consider a simple fixed point iteration problem: $z = \tanh(Wz + x)$. Formally, we wish to solve the problem:

Find $z$ such that $g(x, z) = 0$ where $g(x, z) = z - \tanh(Wz + x)$
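Why should the naive iteration $z \leftarrow \tanh(Wz + x)$ converge at all? Because $\tanh$ is 1-Lipschitz, a spectral norm $\|W\|_2 < 1$ makes the map a contraction, and the Banach fixed-point theorem then guarantees a unique fixed point that the iteration converges to. This is a sufficient condition, not a necessary one: the random $W$ used below may violate it and still converge. The sketch below is illustrative only; `g` and `is_contraction` are hypothetical helper names, and the diagonal $W$ in the usage lines is an assumption chosen so that the condition holds.

```python
import jax.numpy as jnp

def g(W, x, z):
    # Residual of the fixed-point problem; zero exactly when z = tanh(Wz + x)
    return z - jnp.tanh(W @ z + x)

def is_contraction(W):
    # tanh is 1-Lipschitz, so ||tanh(Wz + x) - tanh(Wz' + x)|| <= ||W||_2 * ||z - z'||.
    # A spectral norm below 1 is therefore sufficient (not necessary) for convergence.
    return bool(jnp.linalg.norm(W, ord=2) < 1.0)

# Usage: with a contractive W, iterating from zero drives the residual toward zero
W = 0.5 * jnp.eye(3)
x = jnp.array([0.1, -0.2, 0.3])
z = jnp.zeros(3)
for _ in range(100):
    z = jnp.tanh(W @ z + x)
print(is_contraction(W), jnp.linalg.norm(g(W, x, z)))
```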

In [2]:
func = lambda w, x, z: jnp.tanh(w @ z + x)  # the map z -> tanh(Wz + x)
ndim = 10
W = random.normal(random.PRNGKey(0), (ndim, ndim)) / jnp.sqrt(ndim)
x = random.normal(random.PRNGKey(1), (ndim,)) / jnp.sqrt(ndim)
z_init = jnp.zeros_like(x)

def fixed_point(func, z_0, tol=1e-5):
    # Iterate z <- func(z) until successive iterates differ by at most tol
    z_i = func(z_0)
    z_prev = z_0
    num_iteration = 1
    while jnp.linalg.norm(z_prev - z_i) > tol:
        z_prev = z_i
        z_i = func(z_i)
        num_iteration += 1
    return z_i, num_iteration
z_star_naive, naive_num_iteration = fixed_point(lambda z: func(W, x, z), z_init)
print(naive_num_iteration, 'Fixed point iterations')
print(z_star_naive)


35 Fixed point iterations
[-0.16798191 -0.39130193 -0.83032316 -0.15473165 -0.4114282  -0.38461813
  0.07424818  0.4667457   0.02648443  0.686279  ]


Alternatively, we can use Newton's method, which iterates
$z_{k+1} = z_k - \left(\frac{\partial g}{\partial z}(x, z_k)\right)^{-1} g(x, z_k)$. Since $g$ has a closed form, we could compute $\frac{\partial g}{\partial z}$ by hand:

$\frac{\partial g}{\partial z} = I - \text{diag}(\tanh'(Wz + x))W$, where $\tanh'(u) = 1 - \tanh^2(u)$.

In practice, we can let autodiff (`jax.jacobian`) compute this Jacobian for us instead.

In [3]:
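def dg_dz(w, x, z):
    # Closed-form Jacobian I - diag(tanh'(Wz + x)) W, with tanh' = 1 / cosh^2.
    # [:, None] scales the rows of W, i.e. left-multiplies W by the diagonal matrix.
    # (Shown for reference; newton_solver below uses autodiff instead.)
    return jnp.eye(z.shape[0]) - (1 / jnp.cosh(w @ z + x) ** 2)[:, None] * w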

def newton_solver(func, z_0, tol=1e-5):
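    # func(z) - z equals -g(x, z): same root, and the sign flip cancels in the Newton step
    func_root = lambda z_i: func(z_i) - z_i
    # Newton update, with the Jacobian computed by autodiff
    newton_eqn = lambda z_i: z_i - jnp.linalg.solve(jax.jacobian(func_root)(z_i), func_root(z_i))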
    return fixed_point(newton_eqn, z_0, tol=tol)

newton_z_star, newton_num_iterations = newton_solver(lambda z: func(W, x, z),
                                         z_init)

print("Difference between newton's method and naive fixed point iteration:", jnp.linalg.norm(newton_z_star - z_star_naive))
print("Number of iterations for newton's method:", newton_num_iterations)


Difference between newton's method and naive fixed point iteration: 4.425548e-06
Number of iterations for newton's method: 5
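As a sanity check on the hand-derived Jacobian $I - \text{diag}(\tanh'(Wz + x))W$, the sketch below compares it with `jax.jacobian` and runs a few Newton steps that use the closed form instead of autodiff. The helper names (`g`, `dg_dz_closed_form`, `newton_step`) and the rescaling of $W$ to spectral norm 0.5 are assumptions for illustration, not part of the original notebook.

```python
import jax
import jax.numpy as jnp

def g(w, x, z):
    # Residual g(x, z) = z - tanh(Wz + x); its root is the fixed point
    return z - jnp.tanh(w @ z + x)

def dg_dz_closed_form(w, x, z):
    # I - diag(tanh'(Wz + x)) W, using tanh'(u) = 1 - tanh(u)^2.
    # d[:, None] * w scales row i of W by d[i], i.e. left-multiplies by diag(d).
    d = 1.0 - jnp.tanh(w @ z + x) ** 2
    return jnp.eye(z.shape[0]) - d[:, None] * w

def newton_step(w, x, z):
    # One Newton update z <- z - (dg/dz)^{-1} g(x, z) with the hand-derived Jacobian
    return z - jnp.linalg.solve(dg_dz_closed_form(w, x, z), g(w, x, z))

# Usage: compare the closed form with autodiff at an arbitrary point
A = jax.random.normal(jax.random.PRNGKey(0), (6, 6))
w = 0.5 * A / jnp.linalg.norm(A, ord=2)
x = jax.random.normal(jax.random.PRNGKey(1), (6,)) / jnp.sqrt(6.0)
z_test = jax.random.normal(jax.random.PRNGKey(2), (6,))
auto = jax.jacobian(g, argnums=2)(w, x, z_test)
print(jnp.max(jnp.abs(auto - dg_dz_closed_form(w, x, z_test))))
```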


So far so good! What if we want to find $\frac{\partial z^*}{\partial x}$? This is where we will leverage the **implicit function theorem** (IFT). Before stating the theorem, we'll work through an example:

Since $z^*(x)$ is a fixed point for every $x$, we have $g(x, z^*(x)) = 0$ identically, so its total derivative with respect to $x$ vanishes:

$\frac{d}{dx} g(x, z^*(x)) = 0$

and expanding via the chain rule:

$\frac{\partial g(x, z^*)}{\partial x}  + \frac{\partial g(x, z^*)}{\partial z^*} \cdot \frac{\partial z^*(x)}{\partial x}= 0$.

Thus:

$\frac{\partial z^*(x)}{\partial x} = - \left(\frac{\partial g(x, z^*)}{\partial z^*}\right) ^{-1} \cdot  \frac{\partial g(x, z^*)}{\partial x}$.
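For this particular $g$, both Jacobians have closed forms, which makes the general formula concrete. Writing $D = \text{diag}(\tanh'(Wz^* + x))$ and using $\frac{\partial g}{\partial z} = I - \text{diag}(\tanh'(Wz + x))W$ from the Newton section:

$$\frac{\partial g(x, z^*)}{\partial z^*} = I - DW, \qquad \frac{\partial g(x, z^*)}{\partial x} = -D \quad\Longrightarrow\quad \frac{\partial z^*(x)}{\partial x} = (I - DW)^{-1} D.$$

So a single $n \times n$ linear solve at $z^*$ gives the whole Jacobian, regardless of how many iterations the solver needed to find $z^*$.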

In code:

In [4]:
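def compute_dz_dx(func, w, x, z_star):
    # Residual g(x, z) = z - func(w, x, z), which vanishes at the fixed point z*
    g = lambda x_, z_: z_ - func(w, x_, z_)
    # Jacobians of g at (x, z*) with respect to z and x
    dg_dz = jax.jacobian(g, argnums=1)(x, z_star)
    dg_dx = jax.jacobian(g, argnums=0)(x, z_star)
    # IFT: dz*/dx = -(dg/dz)^{-1} dg/dx, via a linear solve rather than an explicit inverse
    return -jnp.linalg.solve(dg_dz, dg_dx)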

dz_dx = compute_dz_dx(func, W, x, newton_z_star)


def newton_forward(func, w, x, z_init):
    return newton_solver(lambda z: func(w, x, z), z_init)[0]

dz_dx_autograd = jax.jacobian(newton_forward, argnums=2)(func, W, x, z_init)

## Is this actually faster?

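# JAX dispatches asynchronously; block_until_ready() makes %timeit wait for the result
print('Time for IFT:')
%timeit compute_dz_dx(func, W, x, newton_z_star).block_until_ready()

print('Time for autograd:')
%timeit jax.jacobian(newton_forward, argnums=2)(func, W, x, z_init).block_until_ready()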

Time for IFT:
20.7 ms ± 2.82 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
Time for autograd:
217 ms ± 8.55 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


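Given the already-computed $z^*$, the IFT Jacobian is roughly an order of magnitude faster (about 10× in this run) than backpropagating through Newton's iterations. Note that the IFT timing excludes the forward solve (it reuses `newton_z_star`), while the autograd timing includes it; the IFT's cost also does not depend on how many iterations the solver took.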
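The cell above computes `dz_dx_autograd` but never compares it with `dz_dx`. A minimal self-contained check follows, assuming a contractive $W$ (rescaled to spectral norm 0.5, which is not part of the original setup) so that the iteration is guaranteed to converge. `ift_dz_dx` applies the IFT formula, and `unrolled_dz_dx` is a hypothetical reference that backpropagates through 200 unrolled iterations via `jax.lax.fori_loop`; for a contraction, the two should agree closely.

```python
import jax
import jax.numpy as jnp

def func(w, x, z):
    # The fixed-point map z -> tanh(Wz + x)
    return jnp.tanh(w @ z + x)

def solve_fixed_point(w, x, tol=1e-6, max_iter=1000):
    # Naive fixed-point iteration (forward pass only; never differentiated here)
    z = jnp.zeros_like(x)
    for _ in range(max_iter):
        z_next = func(w, x, z)
        if jnp.linalg.norm(z_next - z) < tol:
            return z_next
        z = z_next
    return z

def ift_dz_dx(w, x, z_star):
    # dz*/dx = -(dg/dz)^{-1} dg/dx with g(x, z) = z - func(w, x, z), evaluated at z*
    g = lambda x_, z_: z_ - func(w, x_, z_)
    dg_dz = jax.jacobian(g, argnums=1)(x, z_star)
    dg_dx = jax.jacobian(g, argnums=0)(x, z_star)
    return -jnp.linalg.solve(dg_dz, dg_dx)

def unrolled_dz_dx(w, x, num_steps=200):
    # Reference: backpropagate through a fixed number of unrolled iterations
    solve = lambda x_: jax.lax.fori_loop(
        0, num_steps, lambda _, z: func(w, x_, z), jnp.zeros_like(x_))
    return jax.jacobian(solve)(x)

# Usage
A = jax.random.normal(jax.random.PRNGKey(0), (10, 10))
W = 0.5 * A / jnp.linalg.norm(A, ord=2)  # spectral norm 0.5: a contraction
x = jax.random.normal(jax.random.PRNGKey(1), (10,))
z_star = solve_fixed_point(W, x)
print(jnp.max(jnp.abs(ift_dz_dx(W, x, z_star) - unrolled_dz_dx(W, x))))
```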

# IFT Formally...

**Implicit Function Theorem:** Let $f: \mathbb{R}^p \times \mathbb{R}^n \rightarrow \mathbb{R}^n$ and $a_0 \in \mathbb{R}^p$, $z_0 \in \mathbb{R}^n$ such that:

1. $f(a_0, z_0) = 0$
2. $f$ is continuously differentiable, and its Jacobian with respect to $z$, $\partial_z f(a_0, z_0) \in \mathbb{R}^{n\times n}$, is non-singular.

Then there exist open sets $S_{a_0} \subset \mathbb{R}^p$ and $S_{z_0} \subset \mathbb{R}^n$ containing $a_0$ and $z_0$, respectively, and a unique continuous function $z^*: S_{a_0} \rightarrow S_{z_0}$ such that:

1. $z_0 = z^*(a_0)$
2. $f(a, z^*(a)) = 0$ for all $a \in S_{a_0}$
3. $z^*$ is differentiable on $S_{a_0}$.
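Differentiating condition 2 with the chain rule, exactly as in the example above, gives the derivative promised by condition 3. Writing $\partial_a f$ and $\partial_z f$ for the Jacobians of $f$ with respect to its first and second arguments:

$$\frac{\partial z^*(a)}{\partial a} = -\left[\partial_z f(a, z^*(a))\right]^{-1} \partial_a f(a, z^*(a)), \qquad a \in S_{a_0},$$

where $S_{a_0}$ may need to be shrunk so that $\partial_z f(a, z^*(a))$ stays non-singular, which continuity allows. With $f = g$, $a = x$ and $z = z^*$, this is the formula used in `compute_dz_dx`.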

We've covered a simple example of the IFT and shown how it can substantially reduce the cost of differentiating through an iterative solver: instead of backpropagating through every iteration, we only need Jacobians of $g$ at the solution $z^*$. This notebook only scratches the surface of what is possible with the IFT. For additional references, see:

- [Jaxopt](https://github.com/google/jaxopt) is a library that uses the IFT for differentiable optimization. It registers custom gradient definitions with JAX's autodiff engine and provides drop-in solvers for many common problems (e.g., constrained QPs, root finding). A comparable library for PyTorch is [Theseus](https://github.com/facebookresearch/theseus).
- [NeurIPS 2020 tutorial](http://implicit-layers-tutorial.org/) on implicit layers.
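To make "registering custom gradient definitions" concrete, here is a minimal hand-rolled sketch using `jax.custom_vjp` for the fixed-point problem from this notebook. It does not reproduce JAXopt's API; the names `layer` and `fixed_point_layer`, the 500-iteration cap, and the dense `jnp.linalg.solve` in the backward pass are illustrative assumptions. A dense solve suits small states; larger problems would typically solve the adjoint system iteratively.

```python
import jax
import jax.numpy as jnp

def layer(W, x, z):
    # One application of the fixed-point map z -> tanh(Wz + x)
    return jnp.tanh(W @ z + x)

@jax.custom_vjp
def fixed_point_layer(W, x):
    # Forward pass: naive fixed-point iteration inside lax.while_loop,
    # capped at 500 iterations so a non-contractive W cannot loop forever
    def cond(carry):
        i, z_prev, z = carry
        return (i < 500) & (jnp.linalg.norm(z - z_prev) > 1e-6)

    def body(carry):
        i, _, z = carry
        return i + 1, z, layer(W, x, z)

    z0 = jnp.zeros_like(x)
    _, _, z_star = jax.lax.while_loop(cond, body, (jnp.int32(0), z0, layer(W, x, z0)))
    return z_star

def fixed_point_layer_fwd(W, x):
    z_star = fixed_point_layer(W, x)
    # Save only the inputs and the solution; no iteration history is kept
    return z_star, (W, x, z_star)

def fixed_point_layer_bwd(res, z_bar):
    W, x, z_star = res
    # Jacobian of the map with respect to z at the fixed point
    J = jax.jacobian(layer, argnums=2)(W, x, z_star)
    # IFT adjoint: solve (I - J)^T u = z_bar
    u = jnp.linalg.solve((jnp.eye(J.shape[0]) - J).T, z_bar)
    # Pull u back through one application of the map to get cotangents for (W, x)
    _, pullback = jax.vjp(lambda W_, x_: layer(W_, x_, z_star), W, x)
    return pullback(u)

fixed_point_layer.defvjp(fixed_point_layer_fwd, fixed_point_layer_bwd)

# Usage: gradients flow through the custom rule rather than through the loop
A = jax.random.normal(jax.random.PRNGKey(0), (5, 5))
W = 0.5 * A / jnp.linalg.norm(A, ord=2)
x = jax.random.normal(jax.random.PRNGKey(1), (5,))
dz_dx = jax.jacobian(fixed_point_layer, argnums=1)(W, x)
```

Without the custom rule, reverse-mode differentiation would fail here, because JAX cannot reverse-differentiate `lax.while_loop`. With it, the backward pass needs only $z^*$, so memory does not grow with the number of forward iterations.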


# An exercise
We leave the following as an exercise. Consider the same fixed-point problem we solved with Newton's method, but this time solve it with [Anderson acceleration](https://en.wikipedia.org/wiki/Anderson_acceleration). We provide the forward solver below; your task is to use the IFT to compute $\frac{\partial z^*}{\partial x}$.
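For reference, each iteration of `anderson_solver` keeps the last $n = \min(k, m)$ iterates $x_i$ and their images $f(x_i)$ (the ring buffers `X` and `F`), forms residuals $r_i = f(x_i) - x_i$, and picks mixing weights $\alpha$ by solving a small equality-constrained least-squares problem. The matrix `H` in the code is the KKT system of this problem, with `lam` added along its diagonal as regularization:

$$\alpha = \arg\min_{\alpha \in \mathbb{R}^n} \Big\| \sum_{i=1}^{n} \alpha_i r_i \Big\|_2^2 \quad \text{subject to} \quad \sum_{i=1}^{n} \alpha_i = 1,$$

$$x_{\text{new}} = \beta \sum_{i=1}^{n} \alpha_i f(x_i) + (1 - \beta) \sum_{i=1}^{n} \alpha_i x_i.$$

With the default $\beta = 1$, the update is simply $\sum_i \alpha_i f(x_i)$.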

In [5]:
def anderson_solver(func, z_init, m=5, lam=1e-4, max_iter=50, tol=1e-5, beta=1.0):
    x0 = z_init
    x1 = func(x0)
    x2 = func(x1)
    X = jnp.concatenate([jnp.stack([x0, x1]), jnp.zeros((m - 2, *jnp.shape(x0)))])
    F = jnp.concatenate([jnp.stack([x1, x2]), jnp.zeros((m - 2, *jnp.shape(x0)))])

    res = []
    for k in range(2, max_iter):
        n = min(k, m)
        G = F[:n] - X[:n]
        GTG = jnp.tensordot(G, G, [list(range(1, G.ndim))] * 2)
        H = jnp.block([[jnp.zeros((1, 1)), jnp.ones((1, n))],
                    [ jnp.ones((n, 1)), GTG]]) + lam * jnp.eye(n + 1)
        alpha = jnp.linalg.solve(H, jnp.zeros(n+1).at[0].set(1))[1:]

        xk = beta * jnp.dot(alpha, F[:n]) + (1-beta) * jnp.dot(alpha, X[:n])
        X = X.at[k % m].set(xk)
        F = F.at[k % m].set(func(xk))

        res = jnp.linalg.norm(F[k % m] - X[k % m]) / (1e-5 + jnp.linalg.norm(F[k % m]))
        if res < tol:
            break
    return xk

f = lambda z: func(W, x, z)

anderson_solution = anderson_solver(f, z_init)
print(anderson_solution)

[-0.16797939 -0.3912952  -0.83031964 -0.15471825 -0.41141713 -0.38461316
  0.07425945  0.46674597  0.02648038  0.6862731 ]

