# 8.3 Reverse-Mode "Adjoint" Differentiation

Assume that we have a matrix $A(p) \in \mathbb{R}^{n \times n}$ that depends on parameters $p \in \mathbb{R}^{m}$. Assume furthermore that we form a linear system with $A(p)$ and a fixed right-hand side $b \in \mathbb{R}^n$:

$$
\begin{align*}
A(p) x = b \\
\implies x = A(p)^{-1} b
\end{align*}
$$
where $x \in \mathbb{R}^n$.

Suppose we then use $x$ in some scalar function $f(x)$, for example the Euclidean norm:
$$ f(x) = \| x \|_2 = \langle x, x \rangle^{0.5} $$

Next, we want to minimize $f$ with respect to $p$, so we are interested in the differential $df$ and the gradient $\nabla_p f$. By the chain rule:
$$
\frac{df}{dp} = \frac{df}{dx} \frac{dx}{dp}
$$

We know that:
$$
\begin{align*}
\frac{df}{dx} 
&= f'(x) \\
&= 0.5 \langle x,x \rangle ^{-0.5} \cdot 2 x \\
&= \frac{x}{f(x)}
\end{align*}
$$

Equivalently, write $f(x) = g(h(x))$, where $h(x) = \langle x, x \rangle$ and $g(u) = u^{0.5}$.

Then $g'(u) = 0.5\, u^{-0.5}$ and $h'(x) = 2x$.

Then:

$$ 
\begin{align*}
f'(x) 
&= g'(h(x)) \cdot h'(x) \\
&= 0.5 {\langle x, x \rangle}^{-0.5} \cdot 2x
\end{align*}
$$
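As a quick check of $f'(x) = \frac{x}{f(x)}$, the sketch below evaluates the chain-rule form $g'(h(x)) \cdot h'(x)$ and compares it with central finite differences. It uses NumPy rather than the notebook's PyTorch (an assumption for a standalone example), and `fprime_chain` and `fprime_fd` are hypothetical helper names. The test point is the exact solution $x = (14/5,\ 7/15,\ 4/5,\ -23/15)$ of the notebook's system, which matches the printed output of `x(A(p))` further down.

```python
import numpy as np

def f(x):
    # f(x) = ||x||_2 = <x, x>^{1/2}
    return np.sqrt(x @ x)

def fprime_chain(x):
    # Chain rule with h(x) = <x, x> and g(u) = u^{1/2}
    h = x @ x
    g_prime = 0.5 * h ** -0.5   # g'(u) evaluated at u = h(x)
    h_prime = 2.0 * x           # h'(x) = 2x
    return g_prime * h_prime

def fprime_fd(x, eps=1e-6):
    # Central differences, one coordinate at a time
    grad = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = eps
        grad[i] = (f(x + e) - f(x - e)) / (2 * eps)
    return grad

x = np.array([14 / 5, 7 / 15, 4 / 5, -23 / 15])
print(np.allclose(fprime_chain(x), x / f(x)))                 # True
print(np.allclose(fprime_chain(x), fprime_fd(x), atol=1e-8))  # True
```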

Now back to $\frac{dx}{dp}$. Write $A := A(p)$. Differentiating $A A^{-1} = I$ gives $dA\, A^{-1} + A\, d(A^{-1}) = 0$, i.e. $d(A^{-1}) = -A^{-1}\, dA\, A^{-1}$. Since $x = A^{-1} b$ and $b$ is fixed:
$$
\begin{align*}
dx
&= d(A^{-1})\, b + A^{-1} \underbrace{db}_0 \\
&= \underbrace{d(A^{-1})}_{-A^{-1}\, dA\, A^{-1}}\, b \\
&= - A^{-1}\, dA\, \underbrace{A^{-1} b}_x \\
&= - A^{-1}\, dA\, x
\end{align*}
$$
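To see that $dx = -A^{-1}\, dA\, x$ is the correct first-order change, the NumPy sketch below perturbs $A$ by a small random `dA`, re-solves, and compares the actual change in $x$ with the linear prediction. NumPy and the name `A_of` (a hypothetical NumPy port of the notebook's `A(p)`) are assumptions. The prediction is computed with a solve rather than an explicit inverse.

```python
import numpy as np

def A_of(p):
    # NumPy port of the notebook's A(p): a 2x2 block tiled to 4x4, plus I
    q = np.array([[p[0] + p[1], p[2] ** 2],
                  [p[3] - p[2], p[0] ** 3]])
    return np.tile(q, (2, 2)) + np.eye(4)

p = np.array([1.0, 2.0, 3.0, 4.0])
b = np.array([4.0, 3.0, 2.0, 1.0])
A = A_of(p)
x = np.linalg.solve(A, b)

rng = np.random.default_rng(0)
dA = 1e-6 * rng.standard_normal((4, 4))    # small arbitrary perturbation of A

dx_exact = np.linalg.solve(A + dA, b) - x  # actual change in x
dx_linear = -np.linalg.solve(A, dA @ x)    # first-order prediction -A^{-1} dA x

# Relative residual is second order in dA, so it should be tiny
rel_err = np.linalg.norm(dx_exact - dx_linear) / np.linalg.norm(dx_linear)
print(rel_err)
```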

Since $df = f'(x)^T dx$, substituting $f'(x) = \frac{x}{f(x)}$ and $dx$ gives the differential $df$:
$$
\begin{align*}
df 
&= \frac{x^T}{f(x)} \left( -A^{-1}\, dA\, x \right) \\
&= \underbrace{-\frac{x^T}{f(x)} A^{-1}}_{v^T} dA\, x \\
&= v^T dA\, x
\end{align*}
$$
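Where the parentheses go is what distinguishes forward mode from the reverse ("adjoint") mode. For a single parameter $p_k$, the same partial derivative can be grouped in two ways:
$$
\begin{align*}
\text{forward:} \quad \frac{\partial f}{\partial p_k} &= \frac{x^T}{f(x)} \underbrace{\left( -A^{-1} \frac{\partial A}{\partial p_k} x \right)}_{\partial x / \partial p_k} \\
\text{reverse:} \quad \frac{\partial f}{\partial p_k} &= \underbrace{\left( -\frac{x^T}{f(x)} A^{-1} \right)}_{v^T} \frac{\partial A}{\partial p_k} x
\end{align*}
$$
The forward grouping needs a new solve with $A$ for every $p_k$, i.e. $m$ solves, while the reverse grouping computes the row vector $v^T$ once and then only needs matrix-vector products for each $k$.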

The vector $v$ can be found like this:
$$
\begin{align*}
v^T
&= - \frac{x^T}{f(x)} A^{-1} \\
\implies v^T A &= - \frac{x^T}{f(x)} \\
\implies A^T v &= - \frac{x}{f(x)} \\
\implies v &= - (A^T)^{-1} \frac{x}{f(x)}
\end{align*}
$$
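In other words, $v$ is the solution of the adjoint (transposed) linear system $A^T v = -f'(x)$. This is a single solve, and $v$ does not depend on which parameter $p_k$ we later differentiate with respect to.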
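A NumPy sketch of the adjoint solve (NumPy is an assumption; the matrix is the notebook's $A(p)$ at $p = (1, 2, 3, 4)$, written out explicitly). It checks that solving $A^T v = -f'(x)$ gives the same $v$ as the explicit formula $v^T = -f'(x)^T A^{-1}$, and that $v^T dA\, x$ reproduces $f'(x)^T dx$ for an arbitrary `dA`.

```python
import numpy as np

# A(p) at p = [1, 2, 3, 4]
A = np.array([[4.0, 9.0, 3.0, 9.0],
              [1.0, 2.0, 1.0, 1.0],
              [3.0, 9.0, 4.0, 9.0],
              [1.0, 1.0, 1.0, 2.0]])
b = np.array([4.0, 3.0, 2.0, 1.0])

x = np.linalg.solve(A, b)
fprime = x / np.linalg.norm(x)              # f'(x) = x / ||x||_2

v_adjoint = np.linalg.solve(A.T, -fprime)   # solve A^T v = -f'(x)
v_explicit = -(fprime @ np.linalg.inv(A))   # v^T = -f'(x)^T A^{-1}

# The same v works for any perturbation dA
rng = np.random.default_rng(1)
dA = rng.standard_normal((4, 4))
df_adjoint = v_adjoint @ dA @ x                         # v^T dA x
df_direct = fprime @ (-np.linalg.solve(A, dA @ x))      # f'(x)^T dx

print(v_adjoint)  # approximately [-0.1711, -1.0255, 0.4305, -0.4239]
```

The printed vector matches the notebook's `v(p)` output up to float32 rounding. Because $A$ and $A^T$ share a factorization, a larger code could factor $A$ once and reuse it for both solves (for example SciPy's `lu_factor` followed by `lu_solve` with `trans=1` for the transposed system); the sketch simply calls `np.linalg.solve` twice.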

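Finally, we are interested in $\frac{df}{dp}$. Since $df = v^T dA\, x$ holds for any perturbation $dA$, we can take $dA = \frac{\partial A}{\partial p_k} dp_k$ for each parameter:
$$
\begin{align*}
\frac{df}{dp}
&= v^T \frac{dA}{dp} x \\
&= v^T \underbrace{A'(p)}_{(n \times n \times m) \text{ tensor}} x
\end{align*}
$$
Component-wise, $\frac{\partial f}{\partial p_k} = v^T \frac{\partial A}{\partial p_k} x$ for $k = 1, \dots, m$, with the same $v$ reused for every $k$.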
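The sketch below puts the pieces together for the notebook's $A(p)$, in NumPy rather than PyTorch (an assumption for a standalone example). `dA_dp` is a hypothetical helper holding the hand-derived $(n \times n \times m)$ tensor $A'(p)$, and the gradient is computed three ways: with one adjoint solve, in forward mode with one solve per parameter, and by central differences on $p$.

```python
import numpy as np

b = np.array([4.0, 3.0, 2.0, 1.0])

def A_of(p):
    # NumPy port of the notebook's A(p): 2x2 block tiled to 4x4, plus I
    q = np.array([[p[0] + p[1], p[2] ** 2],
                  [p[3] - p[2], p[0] ** 3]])
    return np.tile(q, (2, 2)) + np.eye(4)

def dA_dp(p):
    # Hand-derived (n, n, m) tensor A'(p): dA_dp(p)[:, :, k] = dA/dp_k
    dq = np.zeros((2, 2, 4))
    dq[0, 0, 0] = 1.0               # d(p0 + p1)/dp0
    dq[0, 0, 1] = 1.0               # d(p0 + p1)/dp1
    dq[0, 1, 2] = 2.0 * p[2]        # d(p2^2)/dp2
    dq[1, 0, 2] = -1.0              # d(p3 - p2)/dp2
    dq[1, 0, 3] = 1.0               # d(p3 - p2)/dp3
    dq[1, 1, 0] = 3.0 * p[0] ** 2   # d(p0^3)/dp0
    # Tile like A itself; the identity term has zero derivative
    return np.tile(dq, (2, 2, 1))

def f_of_p(p):
    # f(x(p)) = ||A(p)^{-1} b||_2
    return np.linalg.norm(np.linalg.solve(A_of(p), b))

def grad_adjoint(p):
    A = A_of(p)
    x = np.linalg.solve(A, b)                          # forward solve
    v = np.linalg.solve(A.T, -x / np.linalg.norm(x))   # single adjoint solve
    # Contract v and x against the first two axes: v^T (dA/dp_k) x for every k
    return np.einsum("i,ijk,j->k", v, dA_dp(p), x)

def grad_forward(p):
    A = A_of(p)
    x = np.linalg.solve(A, b)
    T = dA_dp(p)
    # dx/dp_k = -A^{-1} (dA/dp_k) x needs one solve per parameter
    dx_dp = np.stack([-np.linalg.solve(A, T[:, :, k] @ x)
                      for k in range(T.shape[2])], axis=1)
    return (x / np.linalg.norm(x)) @ dx_dp

def grad_fd(p, eps=1e-6):
    # Central differences on p, as an independent check
    g = np.zeros_like(p)
    for k in range(p.size):
        e = np.zeros_like(p)
        e[k] = eps
        g[k] = (f_of_p(p + e) - f_of_p(p - e)) / (2 * eps)
    return g

p = np.array([1.0, 2.0, 3.0, 4.0])
print(grad_adjoint(p))
print(grad_forward(p))
print(grad_fd(p))
```

The three gradients should agree closely in float64. Summing `dA_dp(p)` over its first two axes gives `[16, 4, 20, 4]`, the same numbers autograd writes into `p.grad` in `In [1]` below. The design point is the solve count: the adjoint route needs one extra solve regardless of $m$, whereas the forward route needs $m$.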

In [1]:
import torch

p = torch.tensor([1.0, 2, 3, 4], requires_grad=True)
m = 4
assert p.size(0) == m

b = torch.tensor([4.0, 3, 2, 1])
n = 4
assert b.size(0) == n

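def A(p):
    # 2x2 block of parameter-dependent entries
    row1 = torch.stack([p[0] + p[1], p[2]**2])
    row2 = torch.stack([p[3] - p[2], p[0]**3])
    quarter = torch.stack([row1, row2])
    # Tile the block to n x n; the tiled matrix has rank at most 2 and is
    # singular on its own, so the identity is added (A(p) is invertible here)
    half = torch.cat([quarter, quarter])
    full = torch.cat([half, half], 1)
    return full + torch.eye(n)

# Autograd sanity check: backpropagating a tensor of ones
# accumulates sum_ij dA_ij/dp_k into p.grad
A_p = A(p)
A_p.backward(gradient=torch.ones_like(A_p))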

p.grad

tensor([16.,  4., 20.,  4.])

In [2]:
from torch.linalg import solve

def x(A):
    return solve(A, b)

x(A(p))

tensor([ 2.8000,  0.4667,  0.8000, -1.5333], grad_fn=<LinalgSolveExBackward0>)

In [3]:
def f(x):
    # f(x) = ||x||_2
    return torch.norm(x)

def fprime(x):
    # f'(x) = x / ||x||_2
    return x / f(x)

In [4]:
def v(p):
    # Adjoint vector: solve A^T v = -f'(x)
    return -solve(A(p).T, fprime(x(A(p))))

v(p)

tensor([-0.1711, -1.0255,  0.4305, -0.4239], grad_fn=<NegBackward0>)

# Numerical Comparison

In [5]:
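def dA():
    # Small random perturbation of the n x n matrix A; reseeding makes
    # every call return the same dA
    torch.manual_seed(0)
    return torch.randn(n, n) * 1e-5

def finite_difference(p):
    # Actual change in f when A is perturbed by dA
    A_ = A(p)
    return f(x(A_ + dA())) - f(x(A_))

def df(p):
    # First-order (adjoint) prediction: df = v^T dA x
    return v(p) @ dA() @ x(A(p))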

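The adjoint prediction $v^T dA\, x$ agrees with the finite difference to within about 0.1%; the remaining discrepancy is consistent with float32 rounding in the subtraction $f(x(A + dA)) - f(x(A))$.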

In [6]:
finite_difference(p), df(p)

(tensor(-6.1512e-05, grad_fn=<SubBackward0>),
 tensor(-6.1578e-05, grad_fn=<DotBackward0>))