## Background

In Stochastic Reconfiguration (SR), Time-Dependent Variational Principle (TDVP), and Natural Gradient Descent (NGD), the common objective is to solve the following linear equation:
$$
(O^{\top}O) \dot{\theta} = F,
$$
where $O$ is an $N \times M$ matrix, and $\dot{\theta}$ and $F$ are both $M$-dimensional vectors.
Usually we cannot invert $O^{\top} O$ directly since it is ill-conditioned (and singular whenever $M > N$, because its rank is at most $N$), so we regularize it:
$$
(O^{\top} O + \lambda I)\dot{\theta} = F.
$$
In SR and other applications, $N$ denotes the batch size or number of samples, and $M$ is the total number of variational parameters.
Here we deal with the case when $M \gg N$.
Here $\dot{\theta}$ is the parameter update of interest and $O^{\top} O$ is the (empirical) Fisher information matrix, where
$$
O_{ij} = \frac{1}{\sqrt{N}} \frac{\partial \log P_{\theta}(x_i)}{\partial \theta_j},
$$
and $F$ is the gradient of the loss function with respect to the parameters,
$$
F_j = \frac{\partial L(\theta)}{\partial \theta_j}.
$$
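
The need for the damping term can be checked numerically. The snippet below is a minimal sketch on a toy problem: the sizes, the random seed, and the `1e-10` threshold are illustrative assumptions, not values from the text. It counts the non-zero eigenvalues of $O^{\top}O$ and shows that adding $\lambda I$ lifts the smallest eigenvalue to about $\lambda$.

```python
import torch

torch.manual_seed(0)
N, M = 5, 20          # toy sizes with M > N (illustrative only)
lambd = 1e-3

# random stand-in for the N x M matrix O, in double precision
O = torch.randn(N, M, dtype=torch.float64) / N**0.5
G = O.T @ O           # M x M empirical Fisher matrix

# O has only N rows, so G has at most N non-zero eigenvalues
eigs = torch.linalg.eigvalsh(G)
num_nonzero = int((eigs > 1e-10).sum())

# damping shifts every eigenvalue up by lambd, making the matrix invertible
damped = G + lambd * torch.eye(M, dtype=torch.float64)
min_eig_damped = torch.linalg.eigvalsh(damped).min().item()

print(f"non-zero eigenvalues of O^T O: {num_nonzero} (N = {N}, M = {M})")
print(f"smallest eigenvalue after damping: {min_eig_damped:.3e}")
```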

## SVD-based method

Suppose we can perform the SVD of $O$:
$$
O = U \Sigma V^{\top},
$$
then the solution is given by
$$
\dot{\theta} = V(\Sigma^2+\lambda \tilde{I})^{-1}V^{\top}F + \frac{1}{\lambda}(F - VV^{\top}F).
$$
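This is easily verified by noting that $O^{\top}O=V\Sigma^2 V^{\top}$ and $V^{\top}V=\tilde{I}$, where $\tilde{I}$ is the $N \times N$ identity and $V$ is the $M \times N$ factor of the thin SVD.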
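
Spelling the check out (a sketch, assuming the thin SVD, so that $VV^{\top}$ is the orthogonal projector onto the row space of $O$ and not the full $M \times M$ identity), split the damped matrix into its parts inside and outside that row space:
$$
\begin{aligned}
O^{\top}O + \lambda I
&= V\Sigma^2 V^{\top} + \lambda VV^{\top} + \lambda (I - VV^{\top}) \\
&= V(\Sigma^2 + \lambda \tilde{I})V^{\top} + \lambda (I - VV^{\top}).
\end{aligned}
$$
Applying this to the proposed solution, and using $V^{\top}(I - VV^{\top}) = 0$, $(I - VV^{\top})V = 0$ and $(I - VV^{\top})^2 = I - VV^{\top}$, the cross terms vanish:
$$
\begin{aligned}
(O^{\top}O + \lambda I)\,\dot{\theta}
&= V(\Sigma^2 + \lambda \tilde{I})(\Sigma^2 + \lambda \tilde{I})^{-1}V^{\top}F + \lambda \cdot \frac{1}{\lambda}(I - VV^{\top})F \\
&= VV^{\top}F + (I - VV^{\top})F = F.
\end{aligned}
$$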

## Cholesky-based method

The Cholesky solve of the damped Fisher system is given by the following algorithm:

1. $W \leftarrow OO^{\top} + \lambda \tilde{I}$, $~~~~~W$ is an $N \times N$ matrix
2. $L \leftarrow \text{Chol}(W)$, $~~~~~L$ is an $N \times N$ lower triangular matrix
3. $Q \leftarrow L^{-1} O$, $~~~~~Q$ is an $N \times M$ matrix
4. Finally, $\dot{\theta} \leftarrow \frac{1}{\lambda}(F - Q^{\top}QF)$

See [arXiv:2310.17556](https://arxiv.org/abs/2310.17556) for the proof.
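
For orientation, here is a short sketch of why step 4 works. It is one possible argument written in this document's notation and is not copied from the reference. Since $W = LL^{\top}$, we have $Q^{\top}Q = O^{\top}L^{-\top}L^{-1}O = O^{\top}W^{-1}O$. The push-through identity
$$
(O^{\top}O + \lambda I)\,O^{\top} = O^{\top}(OO^{\top} + \lambda \tilde{I}) = O^{\top}W
$$
then gives
$$
\begin{aligned}
(O^{\top}O + \lambda I)\,\frac{1}{\lambda}\left(F - O^{\top}W^{-1}OF\right)
&= \frac{1}{\lambda}\left[(O^{\top}O + \lambda I)F - O^{\top}W W^{-1}OF\right] \\
&= \frac{1}{\lambda}\left[O^{\top}OF + \lambda F - O^{\top}OF\right] = F.
\end{aligned}
$$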

Note: In practice, $Q^\top Q v = (L^{-1}O)^\top (L^{-1}O)v = O^\top W^{-1} Ov$, so $W^{-1}$ can be applied with a Cholesky solve instead of forming $L^{-1}$ explicitly.
See [https://pytorch.org/docs/stable/generated/torch.cholesky_solve.html#torch-cholesky-solve](https://pytorch.org/docs/stable/generated/torch.cholesky_solve.html#torch-cholesky-solve)
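
The two ways of applying $Q^\top Q$ to a vector can be compared on a small example. This is a minimal sketch: the toy sizes, the seed, and the use of double precision are assumptions made for illustration. It forms $Q$ with a triangular solve in one route and calls `torch.cholesky_solve` on the vector $Ov$ in the other.

```python
import torch

torch.manual_seed(0)
N, M, lambd = 8, 50, 1e-3   # toy sizes (illustrative only)
O = torch.randn(N, M, dtype=torch.float64) / N**0.5
v = torch.randn(M, dtype=torch.float64)

W = O @ O.T + lambd * torch.eye(N, dtype=torch.float64)
L = torch.linalg.cholesky(W)            # W = L L^T, L lower triangular

# Route 1: Q = L^{-1} O via a triangular solve, then Q^T (Q v)
Q = torch.linalg.solve_triangular(L, O, upper=False)
via_Q = Q.T @ (Q @ v)

# Route 2: O^T W^{-1} (O v); cholesky_solve expects a 2-D right-hand side
rhs = (O @ v).unsqueeze(-1)             # shape (N, 1)
via_chol_solve = O.T @ torch.cholesky_solve(rhs, L).squeeze(-1)

max_diff = (via_Q - via_chol_solve).abs().max().item()
print(f"max |route 1 - route 2| = {max_diff:.3e}")
```

Route 2 only ever solves against an $N$-vector, so it avoids both the explicit inverse and the $N \times M$ matrix $Q$.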

In [1]:
import torch
import math

N = 1000
M = 10000
lambd = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"

# define O matrix (N by M) and F vector (M)
O_mat = torch.randn(N, M, device=device) / math.sqrt(N)
F_vec = torch.rand(M, device=device)


def svd_solve(O_mat, F_vec, lambd=1e-3):
    # solve the linear system (O^T O + lambda I) dtheta = F by SVD
    # compute the eigenvalue decomposition of O O^T = U Sigma^2 U^T
    Sigma2, U = torch.linalg.eigh(O_mat @ O_mat.T)
    # V = O^T U Sigma^{-1}
    V = O_mat.T @ U @ torch.diag(1.0 / torch.sqrt(Sigma2))

    return V @ torch.diag(1.0 / (Sigma2 + lambd)) @ V.T @ F_vec + (F_vec - V @ V.T @ F_vec) / lambd


def cholesky_solve(O_mat, F_vec, lambd=1e-3):
    # solve the linear system (O^T O + lambda I) dtheta = F by Cholesky decomposition
    # TODO: optimize this function
    N, _ = O_mat.size()
    W = O_mat @ O_mat.T + lambd * torch.eye(N, device=O_mat.device)
    L = torch.linalg.cholesky(W)
    Q = torch.linalg.inv(L) @ O_mat

    return (F_vec - Q.T @ Q @ F_vec) / lambd


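def cholesky_solve_fast(O_mat, F_vec, lambd=1e-3):
    # same system as cholesky_solve, but W^{-1} O is obtained with torch.cholesky_solve
    # instead of forming L^{-1} explicitly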
    N, _ = O_mat.size()
    W = O_mat @ O_mat.T + lambd * torch.eye(N, device=O_mat.device)
    L = torch.linalg.cholesky(W)
    QTQv = O_mat.T @ torch.cholesky_solve(O_mat, L) @ F_vec

    return (F_vec - QTQv) / lambd


dtheta_svd = svd_solve(O_mat, F_vec, lambd)
dtheta_chol = cholesky_solve(O_mat, F_vec, lambd)
dtheta_chol_fast = cholesky_solve_fast(O_mat, F_vec, lambd)
# print(f"Distance between dtheta_svd and dtheta_chol: {torch.dist(dtheta_svd, dtheta_chol)}")
print(f"Distance between dtheta_chol and dtheta_chol_fast: {torch.dist(dtheta_chol, dtheta_chol_fast)}")

F_svd = (O_mat.T @ O_mat + torch.eye(M, device=device) * lambd) @ dtheta_svd
F_chol = (O_mat.T @ O_mat + torch.eye(M, device=device) * lambd) @ dtheta_chol
print(f"Distance between F_svd and F_exact: {torch.dist(F_svd, F_vec)}")
print(f"Distance between F_chol and F_exact: {torch.dist(F_chol, F_vec)}")


import torch.utils.benchmark as benchmark

# benchmark
bench_svd_solve = benchmark.Timer(
    stmt="svd_solve(O_mat, F_vec, lambd)",
    setup="from __main__ import svd_solve",
    globals={"O_mat": O_mat, "F_vec": F_vec, "lambd": lambd},
)

bench_cholesky_solve = benchmark.Timer(
    stmt="cholesky_solve(O_mat, F_vec, lambd)",
    setup="from __main__ import cholesky_solve",
    globals={"O_mat": O_mat, "F_vec": F_vec, "lambd": lambd},
)

bench_cholesky_solve_fast = benchmark.Timer(
    stmt="cholesky_solve_fast(O_mat, F_vec, lambd)",
    setup="from __main__ import cholesky_solve_fast",
    globals={"O_mat": O_mat, "F_vec": F_vec, "lambd": lambd},
)

print(bench_svd_solve.timeit(10))
print(bench_cholesky_solve.timeit(10))
print(bench_cholesky_solve_fast.timeit(10))

Distance between dtheta_chol and dtheta_chol_fast: 0.017137594521045685
Distance between F_svd and F_exact: 0.390462726354599
Distance between F_chol and F_exact: 0.21405534446239471
<torch.utils.benchmark.utils.common.Measurement object at 0x7f8092142cb0>
svd_solve(O_mat, F_vec, lambd)
setup: from __main__ import svd_solve
  92.01 ms
  1 measurement, 10 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f819a8358d0>
cholesky_solve(O_mat, F_vec, lambd)
setup: from __main__ import cholesky_solve
  41.32 ms
  1 measurement, 10 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f80a54ff2b0>
cholesky_solve_fast(O_mat, F_vec, lambd)
setup: from __main__ import cholesky_solve_fast
  39.06 ms
  1 measurement, 10 runs , 1 thread
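
The residuals printed above are fairly large. One plausible contributor (an assumption, not verified against that run) is single precision combined with the division by $\lambda$. Separately, Python's `@` is left-associative, so an expression like `Q.T @ Q @ F_vec` builds an $M \times M$ product before touching the vector. The sketch below addresses both points on a smaller problem. The names `cholesky_solve_matvec` and `residual` are hypothetical helpers, and the sizes and seed are illustrative.

```python
import torch


def cholesky_solve_matvec(O_mat, F_vec, lambd=1e-3):
    # solve (O^T O + lambda I) dtheta = F using only N x N and matrix-vector work
    N = O_mat.shape[0]
    eye = torch.eye(N, dtype=O_mat.dtype, device=O_mat.device)
    W = O_mat @ O_mat.T + lambd * eye
    L = torch.linalg.cholesky(W)
    # y = W^{-1} (O F), computed on an (N, 1) right-hand side
    y = torch.cholesky_solve((O_mat @ F_vec).unsqueeze(-1), L).squeeze(-1)
    return (F_vec - O_mat.T @ y) / lambd


def residual(O_mat, F_vec, dtheta, lambd):
    # || (O^T O + lambda I) dtheta - F ||, without forming the M x M matrix
    lhs = O_mat.T @ (O_mat @ dtheta) + lambd * dtheta
    return torch.dist(lhs, F_vec).item()


torch.manual_seed(0)
N, M, lambd = 100, 1000, 1e-3
O64 = torch.randn(N, M, dtype=torch.float64) / N**0.5
F64 = torch.rand(M, dtype=torch.float64)
O32, F32 = O64.float(), F64.float()

res64 = residual(O64, F64, cholesky_solve_matvec(O64, F64, lambd), lambd)
res32 = residual(O32, F32, cholesky_solve_matvec(O32, F32, lambd), lambd)
print(f"residual in float64: {res64:.3e}")
print(f"residual in float32: {res32:.3e}")
```

Comparing the two printed residuals gives a quick indication of how much of the error is due to precision rather than to the algorithm itself.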
