# Matrix-free Differentiable Linear Solver

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Zeqiang-Lai/Delta-Prox/blob/master/notebooks/csmri.ipynb) 


In this tutorial, we provide a step-by-step derivation of the matrix-free differentiable linear solver used in ∇-Prox.

Recall that our goal is to find the gradient of the solution $\bar{x}$ returned by a linear solver
$$
\bar{x} = \text{Solve}(Kx=b)
$$
with respect to the parameters of the linear system, such as $\frac{\partial \bar{x}}{\partial K}$ and $\frac{\partial \bar{x}}{\partial b}$.

In [None]:
# uncomment the following line to install dprox if you are running this notebook in Google Colab
# !pip install dprox

In [1]:
import torch
from torch.autograd.functional import jacobian

## Naive Approach with Auto-Diff

Let us first compute the gradients with auto-differentiation, which will serve as a reference.

In [2]:
torch.manual_seed(0)
theta = torch.randn((32,32), requires_grad=True)  # define parameter of the linOp K
K = theta * 2
x = torch.randn((32))
b = K @ x
b = b.clone().detach().requires_grad_(True)

xhat = torch.linalg.solve(K, b)

loss = xhat.mean()
loss.backward()

print(theta.grad.shape)
print(b.grad.shape)

torch.Size([32, 32])
torch.Size([32])


## Implicit Differentiation

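Auto-diff can efficiently differentiate fast direct linear solvers. For iterative linear solvers, however, it is often intractable, because backpropagating through the solver requires storing the intermediate states of every iteration.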
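To make this concrete, here is a minimal sketch that differentiates through an iterative solver by unrolling it. Richardson iteration stands in for a generic iterative solver, and the symmetric positive definite test matrix is an arbitrary choice. Neither is taken from ∇-Prox. Autograd produces the correct gradient here, but only by recording every iterate. The result is compared with the implicit formula $K^{-T}\frac{\partial \mathcal{L}}{\partial \bar{x}}$ derived below.

```python
import torch

torch.manual_seed(0)
n = 5
dtype = torch.float64
A = torch.randn(n, n, dtype=dtype)
# Assumed SPD test operator, so that the Richardson iteration converges.
K = A @ A.T + n * torch.eye(n, dtype=dtype)
b = torch.randn(n, dtype=dtype, requires_grad=True)

# omega = 1 / lambda_max lies below 2 / lambda_max, which guarantees convergence for SPD K.
omega = 1.0 / torch.linalg.matrix_norm(K, ord=2)


def richardson(K, b, omega, iters=500):
    """Unrolled Richardson iteration x <- x + omega * (b - K x)."""
    x = torch.zeros_like(b)
    for _ in range(iters):
        # Every step adds nodes to the autograd graph, so the stored
        # intermediate states grow linearly with the number of iterations.
        x = x + omega * (b - K @ x)
    return x


x_bar = richardson(K, b, omega)
x_bar.mean().backward()  # backpropagates through all 500 iterations

# Implicit-differentiation result for comparison: K^{-T} dL/dx_bar with L = mean(x_bar)
g = torch.full((n,), 1.0 / n, dtype=dtype)
grad_implicit = torch.linalg.solve(K.T, g)
print(torch.allclose(b.grad, grad_implicit, atol=1e-10))
```

The gradients agree, but the memory used by the backward pass scales with `iters`. The implicit approach below avoids this cost.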

In ∇-Prox, we provide an optimized routine that computes the analytic derivatives of the outputs of (iterative) linear solvers with respect to the parameters $\theta$ of the linear operator $K$ and the right-hand side $b$.

###  Derivation of $\frac{\partial \bar{x}}{\partial b}$

Specifically, we differentiate both sides of $K\bar{x} = b$ to obtain the derivatives $\frac{\partial \bar{x}}{\partial b}$ and $\frac{\partial \bar{x}}{\partial \theta}$. This gives

$$
\begin{aligned}
\partial K \, \bar{x} + K \, \partial \bar{x} &= \partial b, \\
\partial \bar{x} &= K^{-1} \left( -\partial K \, \bar{x} + \partial b \right),
\end{aligned}
$$

from which the gradient $\frac{\partial \bar{x}}{\partial b} = K^{-1}$ follows directly. Typically, we are more interested in the gradient of a scalar loss function $\mathcal{L}$ with respect to $b$, which can be obtained with the chain rule:

$$
\frac{\partial \mathcal{L}}{\partial b} =  \left (\frac{\partial \bar{x}}{\partial b} \right )^T \frac{\partial \mathcal{L}}{\partial \bar{x}} =  K^{-T} \frac{\partial \mathcal{L}}{\partial \bar{x}}
$$
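The differential relation $\partial \bar{x} = K^{-1}(-\partial K\,\bar{x} + \partial b)$ is a first-order statement, so it can be sanity-checked with finite perturbations. The check perturbs $K$ and $b$ by small random $\partial K$ and $\partial b$, re-solves the system, and compares the actual change of $\bar{x}$ with the prediction. This is a minimal sketch in float64. The test matrix and the perturbation size `eps` are arbitrary choices.

```python
import torch

torch.manual_seed(0)
n = 5
dtype = torch.float64
# Arbitrary well-conditioned test system (diagonal shift keeps K invertible).
K = torch.randn(n, n, dtype=dtype) + n * torch.eye(n, dtype=dtype)
b = torch.randn(n, dtype=dtype)
x_bar = torch.linalg.solve(K, b)

# Small random perturbations dK and db
eps = 1e-6
dK = eps * torch.randn(n, n, dtype=dtype)
db = eps * torch.randn(n, dtype=dtype)

# Actual change of the solution under the perturbation
dx_exact = torch.linalg.solve(K + dK, b + db) - x_bar
# First-order prediction: dx = K^{-1} (-dK x_bar + db)
dx_pred = torch.linalg.solve(K, -dK @ x_bar + db)

rel_err = ((dx_exact - dx_pred).norm() / dx_exact.norm()).item()
print(rel_err)  # expected to be small, roughly of the order of eps
```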

Since all the linear operators in our system are matrix-free, $K^{-T}$ is never available explicitly, so the formula above cannot be evaluated directly. Instead, we multiply both sides by $K^T$ and obtain
$$
K^T \frac{\partial \mathcal{L}}{\partial b} = \frac{\partial \mathcal{L}}{\partial \bar{x}},
$$
where the right-hand side is the gradient of $\mathcal{L}$ with respect to $\bar{x}$, which the auto-diff system provides during backpropagation. Computing $\frac{\partial \mathcal{L}}{\partial b}$ has thus been converted into solving a linear system with $K^T$. This only requires applying $K^T$ to vectors and does not store the intermediate iterates of the forward solve, so it needs significantly less memory than backpropagating through the solver.
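The following is a minimal sketch of this idea. It assumes that $K = 2\theta$ is accessible only through a matrix-vector product. The helpers `matvec`, `rmatvec`, and `cg` are hypothetical and were written for this example. $K^T v$ is obtained as a vector-Jacobian product of `matvec` via `torch.autograd.grad`. The system $K^T u = \frac{\partial \mathcal{L}}{\partial \bar{x}}$ is solved with conjugate gradients on the symmetric positive definite normal equations $K K^T u = K \frac{\partial \mathcal{L}}{\partial \bar{x}}$, which is only a simple stand-in for a general iterative solver. The result is then compared with autograd through a dense solve.

```python
import torch

torch.manual_seed(0)
n = 5
dtype = torch.float64
# Hypothetical operator parameters; the solver below only touches K through matvec/rmatvec.
theta = torch.randn(n, n, dtype=dtype) + n * torch.eye(n, dtype=dtype)


def matvec(v):
    """Matrix-free forward operator K v, here with K = 2 * theta."""
    return 2 * (theta @ v)


def rmatvec(v):
    """K^T v as a vector-Jacobian product of the linear map matvec."""
    z = torch.zeros(n, dtype=dtype, requires_grad=True)
    (out,) = torch.autograd.grad(matvec(z), z, grad_outputs=v)
    return out


def cg(apply_A, rhs, iters=100, tol=1e-12):
    """Conjugate gradient for a symmetric positive definite operator apply_A."""
    x = torch.zeros_like(rhs)
    r = rhs.clone()
    p = r.clone()
    rs = r @ r
    for _ in range(iters):
        if rs.sqrt() <= tol * rhs.norm():
            break
        Ap = apply_A(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


# Upstream gradient dL/dx_bar for L = mean(x_bar)
g = torch.full((n,), 1.0 / n, dtype=dtype)

# Solve K^T u = g through the SPD normal equations (K K^T) u = K g, matrix-free
u = cg(lambda v: matvec(rmatvec(v)), matvec(g))

# Reference: autograd through a dense direct solve
b = (2 * theta @ torch.randn(n, dtype=dtype)).requires_grad_()
torch.linalg.solve(2 * theta, b).mean().backward()
print(torch.allclose(u, b.grad, atol=1e-10))
```

Forming the normal equations squares the condition number. For general nonsymmetric operators, it is more common to apply a solver such as GMRES or BiCGSTAB directly to $K^T$.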

> The above derivation assumes the numerator layout for the Jacobian:
> 
> dx/db = [dx1/db1, dx1/db2, ..., dx1/dbn; dx2/db1, ..., dx2/dbn; ...; dxn/db1, ..., dxn/dbn]
> 
> i.e., row i holds the derivatives of xi. torch.autograd.functional.jacobian uses the same layout: its output has shape `output.shape + input.shape`.
>
> See also: https://en.wikipedia.org/wiki/Matrix_calculus#Layout_conventions
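Here is a quick check of this layout convention and of $\frac{\partial \bar{x}}{\partial b} = K^{-1}$ from the derivation above. The Jacobian of the map $b \mapsto \text{Solve}(Kx=b)$ returned by `torch.autograd.functional.jacobian` should equal $K^{-1}$ entry by entry. Its transpose applied to an upstream gradient should match $K^{-T} \frac{\partial \mathcal{L}}{\partial \bar{x}}$. The diagonally shifted random matrix is just an arbitrary well-conditioned test case.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
n = 5
dtype = torch.float64
K = torch.randn(n, n, dtype=dtype) + n * torch.eye(n, dtype=dtype)
b = torch.randn(n, dtype=dtype)

# J[i, j] = d xbar_i / d b_j: output index first, input index second
J = jacobian(lambda rhs: torch.linalg.solve(K, rhs), b)
print(J.shape)  # torch.Size([5, 5])
print(torch.allclose(J, torch.linalg.inv(K)))  # True: dxbar/db = K^{-1}

# Chain rule with an upstream gradient g = dL/dxbar (here for L = mean(xbar))
g = torch.full((n,), 1.0 / n, dtype=dtype)
print(torch.allclose(J.T @ g, torch.linalg.solve(K.T, g)))  # True
```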

In [3]:
torch.manual_seed(0)
theta = torch.randn((5,5), requires_grad=True)
K = theta * 2
x = torch.randn(5)
b = K @ x
b = b.clone().detach().requires_grad_(True)

xhat = torch.linalg.solve(K, b)
xhat.retain_grad()  # retain non-leaf gradient for analytical compute

loss = xhat.mean()
loss.backward()

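# analytical gradient using implicit differentiation: dL/db = K^{-T} dL/dxhat
db = torch.inverse(K.T) @ xhat.grad  # explicit inverse, for reference only
db2 = torch.linalg.solve(K.T, xhat.grad)  # linear solve with K^T, as a matrix-free solver would do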

# analytical gradient versus auto-grad 
print(b.grad)
print((b.grad - db).abs().max())  
print(torch.allclose(b.grad, db, rtol=1e-6))
print(torch.allclose(b.grad, db2, rtol=1e-6))

tensor([ 0.1406, -0.0938,  0.2671, -0.1739, -0.1323])
tensor(1.4901e-08, grad_fn=<MaxBackward1>)
True
True


###  Derivation of $\frac{\partial \bar{x}}{\partial \theta}$

Similarly, setting $\partial b = 0$ (the right-hand side does not depend on $\theta$), the gradient of $\mathcal{L}$ with respect to the parameters $\theta$ of the linear operator $K$ can be derived as
$$
\frac{\partial \mathcal{L}}{\partial \theta} = \left (\frac{\partial \bar{x}}{\partial \theta} \right )^T  \frac{\partial \mathcal{L}}{\partial \bar{x}}  \quad \text{s.t.} \quad \frac{\partial \bar{x}}{\partial \theta} = -K^{-1}  \frac{\partial K}{\partial \theta} \bar{x} \,.
$$

Again, $\frac{\partial K}{\partial \theta}$ cannot be evaluated directly because we consider matrix-free linear operators. To circumvent this obstacle, we regard $b = K\bar{x}$ as a function of $\theta$ with $\bar{x}$ held fixed, so that
$$\frac{\partial K}{\partial \theta}\bar{x}=\frac{\partial b}{\partial \theta},$$
and rewrite the expression for $\frac{\partial \bar{x}}{\partial \theta}$ as
$$
K \frac{\partial \bar{x}}{\partial \theta} = -  \frac{\partial b}{\partial \theta} ,
$$
where $\frac{\partial b}{\partial \theta}$ can be computed by applying auto-diff to the forward operator application $\theta \mapsto K\bar{x}$. As such, computing the gradients $\frac{\partial \mathcal{L}}{\partial b}$ and $\frac{\partial \mathcal{L}}{\partial \theta}$ is converted into solving linear systems during backpropagation. This does not require storing the intermediate states of the iterative solver, which significantly reduces memory consumption and saves computation time.
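For the toy operator $K = 2\theta$ used in the code cells of this tutorial, both relations can be worked out by hand. Let $e_r$ denote the $r$-th standard basis vector and $\delta_{ar}$ the Kronecker delta, and hold $\bar{x}$ fixed:

$$
\begin{aligned}
b = 2\theta\bar{x} \;&\Rightarrow\; \frac{\partial b_a}{\partial \theta_{rc}} = 2\,\delta_{ar}\,\bar{x}_c, \\
K \frac{\partial \bar{x}}{\partial \theta_{rc}} = -2\,\bar{x}_c\, e_r \;&\Rightarrow\; \frac{\partial \bar{x}}{\partial \theta_{rc}} = -2\,\bar{x}_c\, K^{-1} e_r, \\
\frac{\partial \mathcal{L}}{\partial \theta_{rc}} = -2\,\bar{x}_c\, e_r^T K^{-T} \frac{\partial \mathcal{L}}{\partial \bar{x}} \;&\Rightarrow\; \frac{\partial \mathcal{L}}{\partial \theta} = -2\, u\, \bar{x}^T, \quad \text{where } K^T u = \frac{\partial \mathcal{L}}{\partial \bar{x}}.
\end{aligned}
$$

For this operator, then, a single solve with $K^T$ gives the full gradient with respect to $\theta$. This is the same system that yields $\frac{\partial \mathcal{L}}{\partial b} = u$.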

> Note that we assume $\theta$ to have the same shape as $K$, so that the shapes in $\frac{\partial K}{\partial \theta}\bar{x}=\frac{\partial b}{\partial \theta}$ are consistent. If the operator is actually parameterized by other parameters $\bar{\theta}$, the gradient with respect to $\bar{\theta}$ can be tracked automatically by auto-diff, provided the differentiable function that maps $\bar{\theta}$ to $\theta$ is known.
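Below is a minimal sketch of the chain-rule step described in the note, under made-up assumptions. The "real" parameters $\bar{\theta}$ (hypothetical `theta_bar`) are a learnable diagonal added to a fixed matrix, $\theta = \text{base} + \operatorname{diag}(\bar{\theta})$, and $K = 2\theta$ as in the notebook. For this $K$, the equations above reduce to $\frac{\partial \mathcal{L}}{\partial \theta} = -2\,u\,\bar{x}^T$ with $K^T u = \frac{\partial \mathcal{L}}{\partial \bar{x}}$. A dense solve stands in for the matrix-free one to keep the example short. `theta.backward(...)` then lets auto-diff carry the gradient from $\theta$ to $\bar{\theta}$.

```python
import torch

torch.manual_seed(0)
n = 5
dtype = torch.float64
base = torch.randn(n, n, dtype=dtype) + n * torch.eye(n, dtype=dtype)
# Hypothetical "real" parameters theta_bar: a learnable diagonal added to a fixed matrix.
theta_bar = torch.randn(n, dtype=dtype, requires_grad=True)
b = torch.randn(n, dtype=dtype)

# Known differentiable map theta_bar -> theta (theta has the same shape as K)
theta = base + torch.diag(theta_bar)

# Implicit gradient w.r.t. theta, computed outside the autograd graph
with torch.no_grad():
    K = 2 * theta                                  # linOp K(theta), as in the notebook
    x_bar = torch.linalg.solve(K, b)
    g = torch.full((n,), 1.0 / n, dtype=dtype)     # dL/dx_bar for L = mean(x_bar)
    u = torch.linalg.solve(K.T, g)                 # adjoint solve K^T u = g
    dL_dtheta = -2 * torch.outer(u, x_bar)         # valid for K = 2 * theta only

# Auto-diff carries the gradient the rest of the way, from theta to theta_bar
theta.backward(dL_dtheta)

# Reference: differentiate the whole pipeline directly
theta_bar_ref = theta_bar.detach().clone().requires_grad_()
K_ref = 2 * (base + torch.diag(theta_bar_ref))
torch.linalg.solve(K_ref, b).mean().backward()
print(torch.allclose(theta_bar.grad, theta_bar_ref.grad))
```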

**Reference Implementation with the Explicit Matrix**

Suppose $K \in \mathbb{R}^{R\times C}$, $\theta \in \mathbb{R}^{R2\times C2}$, $x \in \mathbb{R}^{N}$, and $b \in \mathbb{R}^{N}$.

Since $K$ is square and $\theta$ has the same shape as $K$, $R$, $C$, $R2$, $C2$, and $N$ all have the same value. We use different symbols only to make the gradient layout explicit.

$$
\frac{\partial \bar{x}}{\partial \theta} = -K^{-1}  \frac{\partial K}{\partial \theta} \bar{x}
$$


$$
\frac{\partial \mathcal{L}}{\partial \theta} = \left (\frac{\partial \bar{x}}{\partial \theta} \right )^T  \frac{\partial \mathcal{L}}{\partial \bar{x}}
$$

The tensor shapes in these expressions can be confusing. Keep in mind that the symbols only keep track of the gradient layout. Since $R$, $C$, $R2$, $C2$, and $N$ all have the same value, every product is a valid matrix multiplication.

In [4]:
# define a linOp depending on the parameter theta
Kmat = lambda theta: theta * 2

# jacobian uses the layout [R x C] x [R2 x C2]:
# dK_dtheta[r, c, r2, c2] = dK[r, c] / dtheta[r2, c2]
dK_dtheta = jacobian(Kmat, theta)

# Move the theta dimensions to the front ([R2 x C2] x [R x C]) so that PyTorch
# treats R2 x C2 as batch dimensions; "@ xhat" then contracts C, which gives
# (dK/dtheta[r2, c2]) @ xhat for every entry of theta: [R2 x C2] x N
dK_dtheta_xhat = dK_dtheta.permute(2, 3, 0, 1) @ xhat

# Method 1: explicit inverse (for reference only)
# R x C @ [R2 x C2] x N x 1 = [R2 x C2] x N x 1
dxhat_dtheta = (-K.inverse() @ dK_dtheta_xhat.unsqueeze(-1)).squeeze(-1)

# Method 2: batched linear solve with one right-hand side per entry of theta
# In theory, dxhat_dtheta should be N x [R2 x C2] (numerator layout),
# but torch.linalg.solve returns the batch dimensions first: [R2 x C2] x N
dxhat_dtheta = torch.linalg.solve(K, -dK_dtheta_xhat.unsqueeze(-1)).squeeze(-1)

# Since N is already the last dimension, "@ xhat.grad" computes
# (dxhat_dtheta)^T dL/dxhat without an explicit transpose.
dloss_dtheta = dxhat_dtheta @ xhat.grad

print(torch.mean(torch.abs(dloss_dtheta - theta.grad)))
print(torch.allclose(dloss_dtheta, theta.grad, rtol=1e-6))

tensor(2.3991e-08, grad_fn=<MeanBackward0>)
True


In [5]:
torch.manual_seed(0)

theta = torch.randn((5,5), requires_grad=True)
x = torch.randn(5)

def f(theta):
    K = theta * 2
    b = K @ x
    b = b.clone().detach().requires_grad_(True)
    xhat = torch.linalg.solve(K, b)
    return xhat

# Directly evaluate dxhat/dtheta using auto-grad (note that this naive approach scales very poorly)
jab = jacobian(f, theta)  # layout: N x [R2 x C2]
xhat = f(theta).detach().requires_grad_()

loss = xhat.mean()
loss.backward()


dtheta = jab.permute(1,2,0) @ xhat.grad
print(torch.mean(torch.abs(dtheta - dloss_dtheta)))
print(jab.permute(1,2,0)[0])
print(dxhat_dtheta[0])

tensor(3.3677e-08, grad_fn=<MeanBackward0>)
tensor([[-1.0499e-01,  1.0606e-02,  3.8049e-02,  9.7144e-05, -9.9203e-01],
        [ 3.4649e-02, -3.5001e-03, -1.2557e-02, -3.2060e-05,  3.2739e-01],
        [-1.1986e-01,  1.2108e-02,  4.3439e-02,  1.1091e-04, -1.1326e+00],
        [ 4.9314e-02, -4.9815e-03, -1.7872e-02, -4.5629e-05,  4.6596e-01],
        [-2.3773e-01,  2.4014e-02,  8.6152e-02,  2.1996e-04, -2.2462e+00]])
tensor([[-1.0499e-01,  1.0606e-02,  3.8049e-02,  9.7131e-05, -9.9203e-01],
        [ 3.4649e-02, -3.5002e-03, -1.2557e-02, -3.2051e-05,  3.2739e-01],
        [-1.1986e-01,  1.2108e-02,  4.3439e-02,  1.1087e-04, -1.1326e+00],
        [ 4.9314e-02, -4.9815e-03, -1.7871e-02, -4.5622e-05,  4.6596e-01],
        [-2.3773e-01,  2.4014e-02,  8.6152e-02,  2.1991e-04, -2.2462e+00]],
       grad_fn=<SelectBackward0>)


**Matrix-Free Reference Implementation**

Suppose $K \in \mathbb{R}^{R\times C}$, $\theta \in \mathbb{R}^{R2 \times C2}$, $x \in \mathbb{R}^{N}$, and $b \in \mathbb{R}^{N}$.

As before, $R$, $C$, $R2$, $C2$, and $N$ all have the same value. The different symbols only make the gradient layout explicit.

$$
K \frac{\partial \bar{x}}{\partial \theta} = -  \frac{\partial b}{\partial \theta} ,
$$

$$
\frac{\partial \mathcal{L}}{\partial \theta} = \left(\frac{\partial \bar{x}}{\partial \theta} \right )^T  \frac{\partial \mathcal{L}}{\partial \bar{x}}
$$


In [6]:
torch.manual_seed(0)
theta = torch.randn((5,5), requires_grad=True)
K = theta * 2
x = torch.randn(5)
b = K @ x
b = b.clone().detach().requires_grad_(True)

xhat = torch.linalg.solve(K, b)
xhat.retain_grad()

loss = xhat.mean()
loss.backward()

# forward operator application b(theta) = K(theta) @ xhat, with xhat held fixed
def linop(theta):
    return (theta * 2) @ xhat.detach()

# jacobian returns N x [R2 x C2]; move R2 x C2 to the front so it acts as a batch dimension
db_dtheta = jacobian(linop, theta).permute(1, 2, 0).unsqueeze(-1)  # [R2 x C2] x N x 1

# In theory, dxhat_dtheta should be N x [R2 x C2] (numerator layout),
# but torch.linalg.solve returns the batch dimensions first: [R2 x C2] x N.
# Note: K = [R x C], db_dtheta = [R2 x C2] (batch size) x N x 1
dxhat_dtheta = torch.linalg.solve(K, -db_dtheta).squeeze(-1)

# Therefore, we don't need to transpose dxhat_dtheta here.
dLoss_dtheta = dxhat_dtheta @ xhat.grad

print(dLoss_dtheta.shape)
print(torch.mean(torch.abs(dLoss_dtheta - theta.grad) / torch.max(torch.abs(theta.grad))))
print(torch.allclose(dLoss_dtheta, theta.grad, rtol=1e-6))

torch.Size([5, 5])
tensor(2.6605e-08, grad_fn=<MeanBackward0>)
True
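The batched solve in the cell above solves one linear system per entry of $\theta$. Combining the two equations of this section gives $\frac{\partial \mathcal{L}}{\partial \theta} = -\left(\frac{\partial b}{\partial \theta}\right)^T u$ with $K^T u = \frac{\partial \mathcal{L}}{\partial \bar{x}}$. A single adjoint solve plus one vector-Jacobian product is therefore enough. The sketch below wraps this standard adjoint-method idea in a custom `torch.autograd.Function`. `AdjointSolve`, `apply_K`, `apply_KT`, and `cg` are hypothetical names, not the ∇-Prox API. $K(\theta) = 2\theta$ is the toy operator from this tutorial. CG on the normal equations is only a simple stand-in for whichever iterative solver is actually used.

```python
import torch


def apply_K(theta, v):
    """Hypothetical matrix-free linOp: K(theta) v = 2 * (theta @ v)."""
    return 2 * (theta @ v)


def apply_KT(theta, v):
    """K(theta)^T v, obtained as a vector-Jacobian product of v -> K(theta) v."""
    with torch.enable_grad():
        z = torch.zeros_like(v, requires_grad=True)
        (out,) = torch.autograd.grad(apply_K(theta.detach(), z), z, grad_outputs=v.detach())
    return out


def cg(apply_A, rhs, iters=200, tol=1e-12):
    """Conjugate gradient for a symmetric positive definite operator apply_A."""
    x = torch.zeros_like(rhs)
    r = rhs.clone()
    p = r.clone()
    rs = r @ r
    for _ in range(iters):
        if rs.sqrt() <= tol * rhs.norm():
            break
        Ap = apply_A(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


class AdjointSolve(torch.autograd.Function):
    """x_bar = K(theta)^{-1} b; the backward pass uses one adjoint solve plus one VJP."""

    @staticmethod
    def forward(ctx, theta, b):
        th = theta.detach()
        # CG on the normal equations K^T K x = K^T b (stand-in for a general solver)
        x_bar = cg(lambda v: apply_KT(th, apply_K(th, v)), apply_KT(th, b.detach()))
        ctx.save_for_backward(theta, x_bar)
        return x_bar

    @staticmethod
    def backward(ctx, grad_x):
        theta, x_bar = (t.detach() for t in ctx.saved_tensors)
        grad_x = grad_x.detach()
        # Adjoint solve K^T u = dL/dx_bar via the SPD system (K K^T) u = K dL/dx_bar
        u = cg(lambda v: apply_K(theta, apply_KT(theta, v)), apply_K(theta, grad_x))
        # dL/dtheta = -(d(K(theta) x_bar)/dtheta)^T u: a single VJP, no Jacobian is formed
        with torch.enable_grad():
            th = theta.clone().requires_grad_()
            (grad_theta,) = torch.autograd.grad(apply_K(th, x_bar), th, grad_outputs=-u)
        return grad_theta, u  # dL/db = u


torch.manual_seed(0)
n = 5
dtype = torch.float64
theta = (torch.randn(n, n, dtype=dtype) + n * torch.eye(n, dtype=dtype)).requires_grad_()
b = torch.randn(n, dtype=dtype, requires_grad=True)

AdjointSolve.apply(theta, b).mean().backward()

# Reference: autograd through a dense direct solve
theta_ref = theta.detach().clone().requires_grad_()
b_ref = b.detach().clone().requires_grad_()
torch.linalg.solve(2 * theta_ref, b_ref).mean().backward()

print(torch.allclose(theta.grad, theta_ref.grad, atol=1e-10))
print(torch.allclose(b.grad, b_ref.grad, atol=1e-10))
```

Compared with the batched version, this backward pass costs one extra solve regardless of the size of $\theta$. The price is that $K^T$ must be applied, and this sketch obtains it from autograd instead of a hand-written adjoint operator.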
