# Scalar root finding pullback / VJP rule

In [13]:
import jax
import jax.numpy as jnp  # conventional alias; avoids confusion with NumPy's `np`

from functools import partial

## Theory

**Function:**

$$f(\theta)=\{ \text{solve } g(x, \theta) = 0 \text{ for } x\} \triangleq x$$

where $x\in \mathbb{R}$, $\theta \in \mathbb{R}$, $g: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}$ and $f: \mathbb{R} \rightarrow \mathbb{R}$. The solve can be carried out with an iterative solver such as the bisection method or the Newton-Raphson method.

**Task:** Backpropagate the cotangent $\bar{x} \in \mathbb{R}$ to $\bar{\theta} \in \mathbb{R}$ with reverse-mode AD without differentiating through the solver iterations (sometimes called unrolling or piggybacking):

$$\bar{\theta} = \bar{x} \frac{\partial x}{\partial \theta}$$

We need to find $\frac{\partial x}{\partial \theta}$.

**Derivation:** The input and output are related by the following relation: 

$$g(x, \theta) = 0.$$

This is an interesting way to view a solver: its output is defined implicitly by an equation rather than by an explicit formula. For some intuition, recall that a circle of radius $r$ centered at the origin is defined implicitly by $x^2 + y^2 - r^2 = 0$.

Taking the total derivative of both sides with respect to $\theta$ (implicit differentiation), and assuming $\frac{\partial g}{\partial x} \neq 0$ at the root, we get

\begin{align}
\frac{d}{d\theta} \{ g(x, \theta) \} &= 0 \\
\frac{\partial g}{\partial x} \frac{\partial x}{\partial \theta} + \frac{\partial g}{\partial \theta} 1 &= 0 \\
\frac{\partial x}{\partial \theta} &= - \frac{\frac{\partial g}{\partial \theta}}{\frac{\partial g}{\partial x}}.
\end{align}
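
The result above can be packaged as a pullback so that reverse-mode AD never sees the solver iterations. The following is a minimal sketch using `jax.custom_vjp`, not a definitive implementation. The names `solve`, `solve_fwd` and `solve_bwd`, the residual $g(x, \theta) = x^2 - \theta$, the fixed 50 Newton-Raphson steps and the starting value of 1 are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def g(x, theta):
    # Assumed residual for illustration: its root satisfies x**2 = theta.
    return x ** 2 - theta


dg_dx = jax.grad(g, argnums=0)
dg_dtheta = jax.grad(g, argnums=1)


@jax.custom_vjp
def solve(theta):
    # Forward pass: plain Newton-Raphson with an assumed fixed iteration count.
    def step(_, x):
        return x - g(x, theta) / dg_dx(x, theta)

    return jax.lax.fori_loop(0, 50, step, jnp.ones_like(theta))


def solve_fwd(theta):
    # Run the solver and save (x, theta) as residuals for the backward pass.
    x = solve(theta)
    return x, (x, theta)


def solve_bwd(residuals, x_bar):
    # Pullback: theta_bar = x_bar * dx/dtheta = -x_bar * (dg/dtheta) / (dg/dx).
    x, theta = residuals
    theta_bar = -x_bar * dg_dtheta(x, theta) / dg_dx(x, theta)
    return (theta_bar,)


solve.defvjp(solve_fwd, solve_bwd)

x_star = solve(9.0)
theta_bar = jax.grad(solve)(9.0)  # uses solve_bwd, not the unrolled loop
```

With this rule registered, `jax.grad(solve)` evaluates only two scalar derivatives of `g` at the converged root, regardless of how many iterations the forward solve took.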

## Simple example: inverting $f(x) = x^2$

In [197]:
def f(x):
    return x ** 2

In [210]:
def g(x, θ):  # the roots of this function would be solutions to f(x) = theta
    return f(x) - θ

dg_dx = jax.grad(g, argnums=0)
dg_dθ = jax.grad(g, argnums=1)

In [211]:
def nr_one_step(_, x, θ):
    return x - g(x, θ) / dg_dx(x, θ)

In [216]:
def nr(θ, n_iters=1000, init_val=1.):
    nr_one_step_partial = partial(nr_one_step, θ=θ)
    return jax.lax.fori_loop(0, n_iters, nr_one_step_partial, init_val=init_val)

In [217]:
nr(9.)

Array(3., dtype=float32, weak_type=True)
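
For this example, $g(x, \theta) = x^2 - \theta$, so the implicit differentiation formula from the theory section can be evaluated by hand, which gives a reference value for the derivative at $\theta = 9$:

$$\frac{\partial x}{\partial \theta} = - \frac{\frac{\partial g}{\partial \theta}}{\frac{\partial g}{\partial x}} = - \frac{-1}{2x} = \frac{1}{2x}, \qquad \left.\frac{\partial x}{\partial \theta}\right|_{x=3,\ \theta=9} = \frac{1}{6} \approx 0.1667.$$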

Two ways of computing $\frac{\partial x}{\partial \theta}$:

In [218]:
dx_dθ_unrolled = jax.jit(jax.grad(nr))

In [219]:
%timeit dx_dθ_unrolled(9.)

19.1 µs ± 49.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [220]:
@jax.jit
def dx_dθ_implicit(x, theta):
    return - dg_dθ(x, theta) / dg_dx(x, theta)

In [221]:
%timeit dx_dθ_implicit(3., 9.)

3.3 µs ± 27.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
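
Before comparing speed, it is worth checking that both approaches return the same number. The sketch below redefines the functions from this section (with ASCII names) so that it runs on its own; the closed-form value $\frac{1}{2\sqrt{\theta}}$ holds only for this particular $g$.

```python
import jax


def g(x, theta):
    # Same residual as in this section: roots solve x**2 = theta.
    return x ** 2 - theta


dg_dx = jax.grad(g, argnums=0)
dg_dtheta = jax.grad(g, argnums=1)


def nr(theta, n_iters=1000, init_val=1.0):
    # Newton-Raphson with a static trip count, so reverse-mode AD can unroll it.
    def step(_, x):
        return x - g(x, theta) / dg_dx(x, theta)

    return jax.lax.fori_loop(0, n_iters, step, init_val)


theta = 9.0
x_star = nr(theta)

unrolled = jax.grad(nr)(theta)  # differentiates through all iterations
implicit = -dg_dtheta(x_star, theta) / dg_dx(x_star, theta)  # implicit rule
analytic = 0.5 / theta ** 0.5  # closed form for this specific g

print(unrolled, implicit, analytic)
```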


But to what extent is this truly useful? If the forward pass itself is too slow, we probably would not want to use the solver anyway.