In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # display the last expression or assignment of every cell

from itertools import product

import jax
import jax.numpy as jnp
import torch
from torch import Tensor, dot, eye, jit, outer, tensordot
from torch.linalg import lstsq, solve
from tqdm.autonotebook import tqdm

# Torch

In [None]:
# Problem setup: a random m x n matrix A0 and random weights (xi, phi, psi)
# of the scalar test function f(A) = xi*s + u.phi + v.psi, where (u, s, v)
# is the leading singular triplet of A.
# Seeding the global RNG makes A0, xi, phi and psi -- and therefore every
# number printed by the later cells -- identical after each Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

m, n = 64, 32
A0 = torch.randn(m, n)
print(f"{torch.linalg.cond(A0)=}")
xi = torch.randn(1)
phi = torch.randn(m)
psi = torch.randn(n)
# Column 0 of the first SVD factor and row 0 of the third factor are used as
# the leading left / right singular vectors.
U, S, V = torch.linalg.svd(A0)
u, s, v = U[:, 0], S[0], V[0, :]
R = xi * s + u.dot(phi) + v.dot(psi)
print(f"f(A)={R}")

## numerical gradient

In [None]:
# Random perturbation E (standard-normal entries scaled by eps) and the value
# of the test function at the perturbed matrix A0 + E.
eps = 10**-5
E = torch.randn(m, n) * eps

UE, SE, VE = torch.linalg.svd(A0 + E)
ue = UE[:, 0]
se = SE[0]
ve = VE[0]
F2 = xi * se + torch.dot(ue, phi) + torch.dot(ve, psi)
print(f"f(A+∆A) = {F2}")

In [None]:
%%script echo skipping

@jit.script
def g_value(A: Tensor, outer_grad: tuple[Tensor, Tensor, Tensor]) -> Tensor:
    """Evaluate xi * s + u.phi + v.psi for the leading singular triplet of A.

    Args:
        A: matrix passed to torch.linalg.svd.
        outer_grad: the tuple (xi, phi, psi) of weights for the singular
            value, the left vector and the right vector respectively.

    Returns:
        xi * S[0] + dot(U[:, 0], phi) + dot(V[0, :], psi), where (U, S, V)
        is the result of torch.linalg.svd(A).
    """
    w_sigma, w_left, w_right = outer_grad
    left, sigmas, right_h = torch.linalg.svd(A)
    sigma_term = w_sigma * sigmas[0]
    left_term = torch.dot(left[:, 0], w_left)
    right_term = torch.dot(right_h[0, :], w_right)
    return sigma_term + left_term + right_term

# Central finite differences of g_value, one matrix entry at a time:
# G_numerical[i, j] = (g(A0 + eps*e_ij) - g(A0 - eps*e_ij)) / (2*eps).
X = torch.zeros((m, n))
G_numerical = torch.zeros((m, n))

for i, j in tqdm(product(range(m), range(n))):
    X[i, j] = 1  # X is now the one-hot direction e_ij
    step = eps * X
    GL = g_value(A0 + step, (xi, phi, psi))
    GR = g_value(A0 - step, (xi, phi, psi))
    G_numerical[i, j] = (GL - GR) / (2 * eps)
    X[i, j] = 0  # reset so the next iteration starts from an all-zero X

## Torch autograd

In [None]:
# Reference gradient of f at A0 obtained with reverse-mode autograd on the CPU.
device = torch.device("cpu")
A = torch.nn.Parameter(A0.clone().to(device))
xi, phi, psi = xi.to(device), phi.to(device), psi.to(device)

U, S, V = torch.linalg.svd(A)
u = U[:, 0]
s = S[0]
v = V[0]
r = xi * s + torch.dot(phi, u) + torch.dot(psi, v)
r.backward()
print(r)
G_torch = A.grad.detach().clone().cpu()
print(G_torch)
# First-order Taylor remainder f(A+E) - f(A) - <G, E>; note that the value
# printed below is that remainder divided by eps.
remainder = F2 - r.cpu() - tensordot(G_torch, E)
diff_y = abs(remainder).item()
print(f"|f(A+∆A) - f(A) - ∇f(A)∆A|={diff_y/eps}")
# assert torch.allclose(A.grad, torch.outer(u, v))

# Jax variant

In [None]:
%%script echo skipping

# Move the weights to the CPU; they are converted with .numpy() further down
# in this cell.
device = torch.device("cpu")
xi, phi, psi = (t.to(device) for t in (xi, phi, psi))


def svd_grad(X, xi, phi, psi):
    """Return xi * s + u.phi + v.psi for the leading singular triplet of X.

    The triplet is taken from ``jnp.linalg.svd(X, full_matrices=False)``:
    entry 0 of the singular values, column 0 of the first factor and row 0 of
    the third factor. ``xi`` must provide ``.item()``.
    """
    left, sigmas, right_h = jnp.linalg.svd(X, full_matrices=False, compute_uv=True)
    sigma_term = xi.item() * sigmas[0]
    left_term = left[:, 0].dot(phi)
    right_term = right_h[0, :].dot(psi)
    return sigma_term + left_term + right_term


# Value and gradient of svd_grad through JAX, compared with the torch
# autograd gradient G_torch and checked against the perturbed value F2.
f = jax.value_and_grad(svd_grad)

numpy_inputs = (A0.numpy(), xi.numpy(), phi.numpy(), psi.numpy())
value, G_jax = f(*numpy_inputs)
print(value)
print(G_jax)
torch_grad = G_torch.numpy()
err_grad = jnp.linalg.norm(G_jax - torch_grad) / jnp.linalg.norm(torch_grad)
print(f"diff to torch {err_grad:.4%}")
# Unlike the torch autograd cell, this prints the remainder without dividing by eps.
taylor_gap = F2.numpy() - value - (G_jax * E.numpy()).sum()
diff_y = abs(taylor_gap).item()
print(f"|f(A+∆A) - f(A) - ∇f(A)∆A|={diff_y}")

## manual computation

In [None]:
# Candidate gradient built from the square system K [p; q] = [phi; psi] with
# K = [[s*I_m, -A], [-A^T, s*I_n]], compared against the autograd result G_torch.
#
# Prefer the GPU, but fall back to the CPU so the cell also runs on machines
# without CUDA (a hard-coded "cuda" device raises there).
# NOTE: with driver=None, torch.linalg.lstsq uses 'gels' on CUDA and 'gelsy'
# on CPU, so the numbers printed below can depend on the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

A = torch.nn.Parameter(A0.clone()).to(device)
xi = xi.to(device)
phi = phi.to(device)
psi = psi.to(device)

# Leading singular triplet and the value of the test function at A.
U, S, V = torch.linalg.svd(A)
u, s, v = U[:, 0], S[0], V[0, :]
r = xi * s + phi.dot(u) + psi.dot(v)

I_m = eye(m, device=device)
I_n = eye(n, device=device)

# K has shape (m + n, m + n).
K = torch.cat(
    [
        torch.cat([s * I_m, -A], dim=-1),
        torch.cat([-A.T, s * I_n], dim=-1),
    ],
    dim=0,
)
print(torch.linalg.cond(K))
# x = torch.linalg.solve(K, torch.cat([phi, psi]))
x = torch.linalg.lstsq(K, torch.cat([phi, psi]))[0]
# First m entries of the solution form p, the remaining n entries form q.
p = x[:m]
q = x[m:]

# Assemble the gradient: xi * u v^T plus rank-one terms built from p and q
# after removing their components along u and v respectively.
g_sigma = xi * outer(u, v)
g_u = outer(p - dot(u, p) * u, v)
g_v = outer(u, q - dot(q, v) * v)
G = (g_sigma + g_u + g_v).cpu()

print(G)
err_grad = (G - G_torch).norm() / G_torch.norm()
print(f"diff to torch {err_grad:.4%}")
# The value printed below is the first-order Taylor remainder divided by eps.
diff_y = abs(F2 - r.cpu() - tensordot(G, E)).item()
print(f"|f(A+∆A) - f(A) - ∇f(A)∆A|={diff_y/eps}")

In [None]:
# PB starts as the (m+n) identity; u u^T is subtracted from its top-left
# m x m block and v v^T from its bottom-right n x n block.
PB = torch.eye(m + n, device=device)
PB[:m, :m].sub_(outer(u, u))
PB[m:, m:].sub_(outer(v, v))
# z stacks u on top of -v; L is K minus the rank-one term s * z z^T.
z = torch.cat((u, -v))
L = K - outer(z, z) * s

In [None]:
# Pseudo-inverses of K and L.
Kp = torch.pinverse(K)
Lp = torch.pinverse(L)

# Condition numbers of K, PB @ K and L, displayed as the cell result.
tuple(torch.linalg.cond(M) for M in (K, PB @ K, L))

In [None]:
# Element-wise ratio between the difference of the two pseudo-inverses and z z^T.
torch.div(Lp - Kp, outer(z, z))

In [None]:
# Recompute the pseudo-inverses of K and L (same statements as an earlier cell).
Kp = torch.pinverse(K)
Lp = torch.pinverse(L)

In [None]:
# 1 - s * z^T Kp z. z is one-dimensional, so no transpose is needed for the
# quadratic form; `.T` on a non-2-D tensor is deprecated in PyTorch.
# Evaluation order is unchanged: ((s * z) @ Kp) @ z.
1 - s * z @ Kp @ z

In [None]:
# Pseudo-inverse of K minus the rank-one matrix z z^T / (4 s).
Kp - torch.div(outer(z, z), 4 * s)

## Augmented K matrix  !!! ONLY CORRECT APPROACH !!!

In [None]:
# Augmented system: K gets two extra rows, [u, 0] and [0, v], and the
# least-squares solve is done against K^T with right-hand side c = [phi; psi].
#
# Prefer the GPU, but fall back to the CPU so the cell also runs on machines
# without CUDA (a hard-coded "cuda" device raises there).
# NOTE: with driver=None, torch.linalg.lstsq uses 'gels' on CUDA and 'gelsy'
# on CPU, so the solution can depend on the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

A = torch.nn.Parameter(A0.clone()).to(device)
xi = xi.to(device)
phi = phi.to(device)
psi = psi.to(device)

# Leading singular triplet and the value of the test function at A.
U, S, V = torch.linalg.svd(A)
u, s, v = U[:, 0], S[0], V[0, :]
r = xi * s + phi.dot(u) + psi.dot(v)

I_m = eye(m, device=device)
I_n = eye(n, device=device)
O_m = torch.zeros(m, device=device)
O_n = torch.zeros(n, device=device)

# K has shape (m + n + 2, m + n): the square block matrix plus the two extra rows.
K = torch.cat(
    [
        torch.cat([s * I_m, -A], dim=-1),
        torch.cat([-A.T, s * I_n], dim=-1),
        torch.cat([u, O_n], dim=0).unsqueeze(0),
        torch.cat([O_m, v], dim=0).unsqueeze(0),
    ],
    dim=0,
)
c = torch.cat([phi, psi])
print(torch.linalg.cond(K))
# x = torch.linalg.solve(K, torch.cat([phi, psi]))
x = torch.linalg.lstsq(K.T, c)[0]
# x has m + n + 2 entries: p (first m), q (next n), then the two coefficients
# that multiply the extra rows.
p = x[:m]
q = x[m : m + n]
λ = x[-2]
μ = x[-1]
print(s, λ, μ)
# Assemble the gradient: xi * u v^T plus rank-one terms built from p and q
# after removing their components along u and v respectively.
g_sigma = xi * outer(u, v)
g_u = outer(p - dot(u, p) * u, v)
g_v = outer(u, q - dot(q, v) * v)
G = (g_sigma + g_u + g_v).cpu()

print(G)
err_grad = (G - G_torch).norm() / G_torch.norm()
print(f"diff to torch {err_grad:.4%}")
# First-order Taylor remainder, printed as-is (not divided by eps).
diff_y = abs(F2 - r.cpu() - tensordot(G, E)).item()
print(f"|f(A+∆A) - f(A) - ∇f(A)∆A|={diff_y}")

In [None]:
# Residual norm of the augmented least-squares solve, K^T x - c.
residual = K.T @ x - c
residual.norm()

## manual computation Block inversion

In [None]:
# Block-inversion variant: four separate least-squares solves against
# P = s^2 I_m - A A^T and Q = s^2 I_n - A^T A replace the single solve with K.
#
# Prefer the GPU, but fall back to the CPU so the cell also runs on machines
# without CUDA (a hard-coded "cuda" device raises there). The lstsq driver is
# set explicitly below, so it does not change with the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

A = torch.nn.Parameter(A0.clone()).to(device)
xi = xi.to(device)
phi = phi.to(device)
psi = psi.to(device)

# Leading singular triplet and the value of the test function at A.
U, S, V = torch.linalg.svd(A)
u, s, v = U[:, 0], S[0], V[0, :]
r = xi * s + phi.dot(u) + psi.dot(v)

I_m = eye(m, device=device)
I_n = eye(n, device=device)

P = s**2 * I_m - A @ A.T
Q = s**2 * I_n - A.T @ A

driver = "gels"

# NOTE: x and z are reused names; they overwrite the x / z of earlier cells.
x = lstsq(P, s * phi, driver=driver)[0]
y = lstsq(P, A.mv(psi), driver=driver)[0]
w = lstsq(Q, A.T.mv(phi), driver=driver)[0]
z = lstsq(Q, s * psi, driver=driver)[0]

p = x + y
q = w + z

# Assemble the gradient: xi * u v^T plus rank-one terms built from p and q
# after removing their components along u and v respectively.
g_sigma = xi * outer(u, v)
g_u = outer(p - dot(u, p) * u, v)
g_v = outer(u, q - dot(q, v) * v)
G = (g_sigma + g_u + g_v).cpu()

print(G)
err_grad = (G - G_torch).norm() / G_torch.norm()
print(f"diff to torch {err_grad:.4%}")
# The value printed below is the first-order Taylor remainder divided by eps.
diff_y = abs(F2 - r.cpu() - tensordot(G, E)).item()
print(f"|f(A+∆A) - f(A) - ∇f(A)∆A|={diff_y/eps}")

In [None]:
# Residual norms of the four least-squares solves against P and Q,
# printed one per line.
dx = (P @ x - s * phi).norm().item()
dy = (P @ y - A.mv(psi)).norm().item()
dw = (Q @ w - A.T.mv(phi)).norm().item()
dz = (Q @ z - s * psi).norm().item()
print(dx, dy, dw, dz, sep="\n")

## manual Backward mixed approach

In [None]:
# Mixed variant: x and z come from direct solves with s^2 I - A A^T and
# s^2 I - A^T A, while y and w come from least-squares solves with the
# stacked matrices H = [A^T; s^2 I_m] and J = [A; s^2 I_n].
#
# Prefer the GPU, but fall back to the CPU so the cell also runs on machines
# without CUDA (a hard-coded "cuda" device raises there).
# NOTE: with driver=None, torch.linalg.lstsq uses 'gels' on CUDA and 'gelsy'
# on CPU, so the numbers printed below can depend on the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

A = torch.nn.Parameter(A0.clone()).to(device)
xi = xi.to(device)
phi = phi.to(device)
psi = psi.to(device)

# Leading singular triplet and the value of the test function at A.
U, S, V = torch.linalg.svd(A)
u, s, v = U[:, 0], S[0], V[0, :]
r = xi * s + phi.dot(u) + psi.dot(v)

I_m = eye(m, device=device)
I_n = eye(n, device=device)
O_m = torch.zeros(m, device=device)
O_n = torch.zeros(n, device=device)

# H is (n + m) x m with right-hand side h; J is (m + n) x n with right-hand side j.
H = torch.cat([A.T, s**2 * I_m], dim=0)
h = torch.cat([-psi, O_m], dim=0)
J = torch.cat([A, s**2 * I_n], dim=0)
j = torch.cat([-phi, O_n], dim=0)

x = solve(s**2 * I_m - A @ A.T, s * phi)
y = lstsq(H, h)[0]
w = lstsq(J, j)[0]
z = solve(s**2 * I_n - A.T @ A, s * psi)

# NOTE(review): K is not assigned in this cell, so the first number is the
# condition number of whichever K an earlier cell left in the kernel. The
# other two matrices are rebuilt from the same expressions as H and J.
print(
    torch.linalg.cond(K),
    torch.linalg.cond(torch.cat([A.T, s**2 * I_m], dim=0)),
    torch.linalg.cond(torch.cat([A, s**2 * I_n], dim=0)),
)

p = x + y
q = w + z

# Assemble the gradient: xi * u v^T plus rank-one terms built from p and q
# after removing their components along u and v respectively.
g_sigma = xi * outer(u, v)
g_u = outer(p - dot(u, p) * u, v)
g_v = outer(u, q - dot(q, v) * v)
G = (g_sigma + g_u + g_v).cpu()

print(G)
err_grad = (G - G_torch).norm() / G_torch.norm()
print(f"diff to torch {err_grad:.4%}")
# First-order Taylor remainder, printed as-is (not divided by eps).
diff_y = abs(F2 - r.cpu() - tensordot(G, E)).item()
print(f"|f(A+∆A) - f(A) - ∇f(A)∆A|={diff_y}")

In [None]:
# Residual norms of the two direct solves (x, z) and the two least-squares
# solves (y, w) of the mixed approach, printed one per line.
dx = ((s**2 * I_m - A @ A.T) @ x - s * phi).norm().item()
dy = (H @ y - h).norm().item()
dw = (J @ w - j).norm().item()
dz = ((s**2 * I_n - A.T @ A) @ z - s * psi).norm().item()
print(dx, dy, dw, dz, sep="\n")

## manual Backward substitution

In [None]:
# Backward-substitution variant: two-column least-squares solves against the
# stacked matrices Q = [A; s^2 I_n] and P = [A^T; s^2 I_m], followed by two
# further least-squares solves against A^T and A.
#
# Prefer the GPU, but fall back to the CPU so the cell also runs on machines
# without CUDA (a hard-coded "cuda" device raises there).
# NOTE: with driver=None, torch.linalg.lstsq uses 'gels' on CUDA and 'gelsy'
# on CPU, so the numbers printed below can depend on the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

A = torch.nn.Parameter(A0.clone()).to(device)
xi = xi.to(device)
phi = phi.to(device)
psi = psi.to(device)

# Leading singular triplet and the value of the test function at A.
U, S, V = torch.linalg.svd(A)
u, s, v = U[:, 0], S[0], V[0, :]
r = xi * s + phi.dot(u) + psi.dot(v)

I_m = eye(m, device=device)
I_n = eye(n, device=device)
O_m = torch.zeros(m, device=device)
O_n = torch.zeros(n, device=device)

# P is (n + m) x m and Q is (m + n) x n.
P = torch.cat([A.T, s**2 * I_m], dim=0)
Q = torch.cat([A, s**2 * I_n], dim=0)

# Two right-hand-side columns each, zero-padded to m + n rows.
Y = torch.cat(
    [
        torch.stack([-s * phi, -phi], dim=-1),
        torch.zeros((n, 2), device=device),
    ],
    dim=0,
)
Z = torch.cat(
    [
        torch.stack([-psi, -s * psi], dim=-1),
        torch.zeros((m, 2), device=device),
    ],
    dim=0,
)

# mu and w (length n) are the two solution columns for Q against Y;
# y and nu (length m) are the two solution columns for P against Z.
mu, w = lstsq(Q, Y)[0].T
y, nu = lstsq(P, Z)[0].T
x = lstsq(A.T, mu)[0]
z = lstsq(A, nu)[0]

p = x + y
q = w + z

# Assemble the gradient: xi * u v^T plus rank-one terms built from p and q
# after removing their components along u and v respectively.
g_sigma = xi * outer(u, v)
g_u = outer(p - dot(u, p) * u, v)
g_v = outer(u, q - dot(q, v) * v)
G = (g_sigma + g_u + g_v).cpu()

print(G)
err_grad = (G - G_torch).norm() / G_torch.norm()
print(f"diff to torch {err_grad:.4%}")
# First-order Taylor remainder, printed as-is (not divided by eps).
diff_y = abs(F2 - r.cpu() - tensordot(G, E)).item()
print(f"|f(A+∆A) - f(A) - ∇f(A)∆A|={diff_y}")

In [None]:
# Residual norms of the six solves of the backward-substitution cell,
# printed one per line.
#
# Each residual uses the matrix its unknown was actually solved against:
# mu and w are the columns of lstsq(Q, Y), y and nu the columns of
# lstsq(P, Z). Pairing w with P or y with Q multiplies an (m+n, m) matrix by
# a length-n vector (or vice versa) and raises a shape error when m != n.
dx = (A.T @ x - mu).norm().item()
dz = (A @ z - nu).norm().item()
dmu = (Q @ mu - Y[:, 0]).norm().item()
dnu = (P @ nu - Z[:, 1]).norm().item()
dw = (Q @ w - Y[:, 1]).norm().item()
dy = (P @ y - Z[:, 0]).norm().item()
print(dx, dz, dmu, dnu, dw, dy, sep="\n")

## manual Forward substitution

In [None]:
# Forward-substitution variant: first solve against A and A^T for mu and nu,
# then run two-column least-squares solves against the stacked matrices
# P = [A^T; s^2 I_m] and Q = [A; s^2 I_n]; p and q are the column sums.
#
# Prefer the GPU, but fall back to the CPU so the cell also runs on machines
# without CUDA (a hard-coded "cuda" device raises there).
# NOTE: with driver=None, torch.linalg.lstsq uses 'gels' on CUDA and 'gelsy'
# on CPU, so the numbers printed below can depend on the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
A = torch.nn.Parameter(A0.clone()).to(device)
xi = xi.to(device)
phi = phi.to(device)
psi = psi.to(device)

# Leading singular triplet and the value of the test function at A.
U, S, V = torch.linalg.svd(A)
u, s, v = U[:, 0], S[0], V[0, :]
r = xi * s + phi.dot(u) + psi.dot(v)

I_m = eye(m, device=device)
I_n = eye(n, device=device)

# mu has length n, nu has length m.
mu = lstsq(A, s * phi)[0]
nu = lstsq(A.T, s * psi)[0]

# P is (n + m) x m and Q is (m + n) x n.
P = torch.cat([A.T, s**2 * I_m], dim=0)
Q = torch.cat([A, s**2 * I_n], dim=0)

# Two right-hand-side columns each, zero-padded to m + n rows.
Y = torch.cat(
    [
        torch.stack([-mu, -psi], dim=-1),
        torch.zeros((m, 2), device=device),
    ],
    dim=0,
)
Z = torch.cat(
    [
        torch.stack([-phi, -nu], dim=-1),
        torch.zeros((n, 2), device=device),
    ],
    dim=0,
)

# x is m x 2 and y is n x 2 (one column per right-hand side).
x = lstsq(P, Y)[0]
y = lstsq(Q, Z)[0]

p = x.sum(dim=-1)
q = y.sum(dim=-1)

# Assemble the gradient: xi * u v^T plus rank-one terms built from p and q
# after removing their components along u and v respectively.
g_sigma = xi * outer(u, v)
g_u = outer(p - dot(u, p) * u, v)
g_v = outer(u, q - dot(q, v) * v)
G = (g_sigma + g_u + g_v).cpu()

print(G)
err_grad = (G - G_torch).norm() / G_torch.norm()
print(f"diff to torch {err_grad:.4%}")
# First-order Taylor remainder, printed as-is (not divided by eps).
diff_y = abs(F2 - r.cpu() - tensordot(G, E)).item()
print(f"|f(A+∆A) - f(A) - ∇f(A)∆A|={diff_y}")

In [None]:
# Residual norms of the forward-substitution solves: the two single-vector
# solves (mu, nu) and the two two-column solves (x against P, y against Q),
# printed one per line.
dmu = (A @ mu - s * phi).norm().item()
dnu = (A.T @ nu - s * psi).norm().item()
dp = (P @ x - Y).norm().item()
dq = (Q @ y - Z).norm().item()
print(dmu, dnu, dp, dq, sep="\n")