In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import torch
import torch.autograd.functional as F
from jax import lax
# Short alias for torch.autograd.functional.jvp (presumably to keep it visually distinct from jax.jvp).
torch_jvp = F.jvp

## Cases
$$
\newcommand{\d}{{\mathrm{d}}}
\newcommand{\A}{{\mathbf{A}}}
\newcommand{\U}{{\mathbf{U}}}
\newcommand{\S}{{\mathbf{S}}}
\newcommand{\V}{{\mathbf{V}}}
\newcommand{\F}{{\mathbf{F}}}
\newcommand{\I}{{\mathbf{I}}}
\newcommand{\dA}{{\d\A}}
\newcommand{\dU}{{\d\U}}
\newcommand{\dS}{{\d\S}}
\newcommand{\dV}{{\d\V}}
\newcommand{\Ut}{{\U^{\top}}}
\newcommand{\Vt}{{\V^{\top}}}
\newcommand{\Vh}{{\V^{H}}}
\newcommand{\dAt}{{\dA^{\top}}}
\newcommand{\dVt}{{\dV^{\top}}}
\newcommand{\gA}{{\overline{\A}}}
\newcommand{\gU}{{\overline{\U}}}
\newcommand{\gUt}{{\gU^{\top}}}
\newcommand{\gS}{{\overline{\S}}}
\newcommand{\gSt}{{\gS^{\top}}}
\newcommand{\gV}{{\overline{\V}}}
\newcommand{\gVt}{{\gV^{\top}}}
$$
The form of the SVD derivative depends on three choices:

1. Computing the full SVD vs the "thin"/"partial" SVD
2. Computing the complete factorization $\U\S\Vh$ vs computing just the singular values $\S$
3. Complex vs real inputs

These three binary choices give $2^3 = 8$ variants of the SVD derivative, each with its own differential formulas for the forward-mode AD update and its own adjoint formula for the reverse-mode AD update.

## Numerical instability
TODO

## Real valued partial SVD

Reference: https://j-towns.github.io/papers/svd-derivative.pdf

### Forward mode

The differential formulas $\dU$, $\dS$, and $\dV$ in terms of $\dA$, $\U$, $\S$, and $\V$ are obtained by differentiating $\A = \U \S \Vt$ with the product rule and solving for each differential (TODO: verify this derivation against the reference).

Product rule applied to $\A = \U \S \Vt$

$\dA = \dU \S \Vt + \U \dS \Vt + \U \S \dVt$

Differential formulas

$\dU = \U ( \F \circ [\Ut \dA \V \S + \S \Vt \dAt \U] ) + (\I_m - \U \Ut ) \dA \V \S^{-1}$

$\dS = \I_k \circ [\Ut \dA \V]$

$\dV = \V (\F \circ [\S \Ut \dA \V + \Vt \dAt \U \S]) + (\I_n - \V \Vt) \dAt \U \S^{-1}$

where

$F_{ij} = \frac{1}{s_j^2 - s_i^2}, i \neq j$

$F_{ij} = 0$ otherwise

In [24]:
def svd_jvp_real_valued_partial(A, dA):
    """Forward-mode (JVP) rule for the real-valued thin SVD ``A = U @ diag(s) @ Vt``.

    Implements the differential formulas

        dU = U (F o [Ut dA V S + S Vt dAt U]) + (I_m - U Ut) dA V S^-1
        dS = I_k o [Ut dA V]
        dV = V (F o [S Ut dA V + Vt dAt U S]) + (I_n - V Vt) dAt U S^-1

    with ``F[i, j] = 1 / (s_j**2 - s_i**2)`` for ``i != j`` and ``0`` on the diagonal
    (``o`` is the elementwise product).

    Args:
        A: real matrix of shape ``(m, n)``.
        dA: tangent (perturbation direction) with the same shape as ``A``.

    Returns:
        ``((U, s, Vt), (dU, ds, dVt))`` where ``U`` is ``(m, k)``, ``s`` is ``(k,)``,
        ``Vt`` is ``(k, n)`` with ``k = min(m, n)``, and every tangent has the shape of
        (and is the tangent of) the primal in the same position. In particular the last
        tangent is ``dV.T``, matching the returned ``Vt``.

    Raises:
        ValueError: if ``A`` is not 2-D or ``dA`` does not have the same shape as ``A``.

    Note:
        The formulas divide by ``s_j**2 - s_i**2`` and by ``s_i``, so the result contains
        inf/nan when singular values are repeated or zero.
    """
    A = jnp.asarray(A)
    dA = jnp.asarray(dA)
    if A.ndim != 2 or A.shape != dA.shape:
        raise ValueError(
            f"A and dA must be 2-D arrays of identical shape, got {A.shape} and {dA.shape}"
        )

    U, S_vals, Vt = jnp.linalg.svd(A, compute_uv=True, full_matrices=False)

    S = jnp.diag(S_vals)
    Ut = U.T
    V = Vt.T
    dAt = dA.T

    k = S_vals.shape[0]
    m = U.shape[0]
    # V is (n, k); Vt is (k, n), so the row count of V (not of Vt) is n.
    n = V.shape[0]

    I_m = jnp.eye(m, dtype=U.dtype)
    I_n = jnp.eye(n, dtype=V.dtype)

    # S is diagonal, so its inverse is just the reciprocal of the singular values.
    S_inv = jnp.diag(1.0 / S_vals)

    # gaps[i, j] = s_j^2 - s_i^2. The inner `where` swaps the zero diagonal for 1 before
    # dividing so no inf is produced there; the outer `where` then sets the diagonal to 0.
    s_sq = S_vals ** 2
    gaps = s_sq[None, :] - s_sq[:, None]
    off_diag = ~jnp.eye(k, dtype=bool)
    F_mat = jnp.where(off_diag, 1.0 / jnp.where(off_diag, gaps, 1.0), 0.0)

    dU = U @ (F_mat * ((Ut @ dA @ V @ S) + (S @ Vt @ dAt @ U))) + ((I_m - U @ Ut) @ dA @ V @ S_inv)

    # Only the diagonal of Ut dA V is needed; masking with I_k first would be redundant.
    dS_vals = jnp.diagonal(Ut @ dA @ V)

    dV = V @ (F_mat * (S @ Ut @ dA @ V + Vt @ dAt @ U @ S)) + (I_n - V @ Vt) @ dAt @ U @ S_inv

    # The primal output is Vt, so return the tangent of Vt (= dV transposed).
    return (U, S_vals, Vt), (dU, dS_vals, dV.T)

In [28]:
def assert_svd_uv(a, b):
    """Assert that two SVD JVP results agree element by element.

    Each of ``a`` and ``b`` is a pair ``(primals, tangents)`` where both members are
    3-tuples: ``(U, S, Vt)`` and ``(dU, dS, dVt)``. Corresponding entries are compared
    with ``assert_allclose`` in the order U, S, Vt, dU, dS, dVt.
    """
    primals_a, tangents_a = a
    u_a, s_a, vt_a = primals_a
    du_a, ds_a, dvt_a = tangents_a

    primals_b, tangents_b = b
    u_b, s_b, vt_b = primals_b
    du_b, ds_b, dvt_b = tangents_b

    expected_vs_actual = (
        (u_a, u_b),
        (s_a, s_b),
        (vt_a, vt_b),
        (du_a, du_b),
        (ds_a, ds_b),
        (dvt_a, dvt_b),
    )
    for left, right in expected_vs_actual:
        assert_allclose(left, right)

def assert_allclose(l, r, rtol=1e-05, atol=1e-08):
    """Assert that two array-likes have the same shape and numerically close values.

    Args:
        l: left operand; anything ``jnp.array`` accepts.
        r: right operand; anything ``jnp.array`` accepts.
        rtol: relative tolerance forwarded to ``jnp.allclose`` (same default).
        atol: absolute tolerance forwarded to ``jnp.allclose`` (same default).

    Raises:
        AssertionError: if the shapes differ or any element is outside the tolerance.
    """
    left = jnp.array(l)
    right = jnp.array(r)

    # jnp.allclose broadcasts its operands, so e.g. shapes (1,) and (2,) could compare
    # as "close"; require identical shapes explicitly.
    if left.shape != right.shape:
        raise AssertionError(f"shape mismatch: {left.shape} vs {right.shape}")

    if not bool(jnp.allclose(left, right, rtol=rtol, atol=atol)):
        max_abs_diff = float(jnp.max(jnp.abs(left - right)))
        raise AssertionError(
            f"arrays differ (max abs diff {max_abs_diff}, rtol={rtol}, atol={atol})"
        )
        
def jax_real_valued_partial(A, dA):
    """Reference JVP of the thin SVD computed by JAX.

    ``A`` and ``dA`` are converted with ``jnp.array`` and pushed through ``jax.jvp`` of
    ``jnp.linalg.svd(..., compute_uv=True, full_matrices=False)``.

    Returns:
        The ``(primals_out, tangents_out)`` pair produced by ``jax.jvp``.
    """
    def thin_svd(matrix):
        return jnp.linalg.svd(matrix, compute_uv=True, full_matrices=False)

    primals = (jnp.array(A),)
    tangents = (jnp.array(dA),)
    return jax.jvp(thin_svd, primals, tangents)

def pytorch_real_valued_partial(A, dA):
    """Reference JVP of the thin SVD computed by PyTorch.

    ``A`` and ``dA`` are converted with ``torch.tensor`` and passed to ``torch_jvp``
    together with ``torch.linalg.svd(..., full_matrices=False)``.

    Returns:
        Whatever ``torch_jvp`` returns for that function, input and tangent.
    """
    def thin_svd(matrix):
        return torch.linalg.svd(matrix, full_matrices=False)

    primal = torch.tensor(A)
    tangent = torch.tensor(dA)
    return torch_jvp(thin_svd, primal, tangent)
    
def check_real_valued_partial(A, dA, verify=False):
    """Print the JAX, PyTorch and hand-written SVD JVP results side by side.

    The primal outputs (U, S, Vt) of JAX and PyTorch are printed first, then the
    tangents (dU, dS, dVt) of JAX, PyTorch and ``svd_jvp_real_valued_partial``.

    Args:
        A: input matrix; passed to ``jnp.array`` / ``torch.tensor`` by the helpers.
        dA: tangent with the same layout as ``A``.
        verify: when True, additionally assert via ``assert_svd_uv`` that JAX agrees
            with PyTorch and with the hand-written rule. Off by default, so the
            function only prints.
            NOTE(review): cross-library agreement may fail even for correct results
            because singular vectors are only determined up to sign; confirm before
            enabling by default.

    Returns:
        None.
    """
    jax_res = jax_real_valued_partial(A, dA)
    torch_res = pytorch_real_valued_partial(A, dA)
    res = svd_jvp_real_valued_partial(jnp.array(A), jnp.array(dA))

    def print_side_by_side(*factorizations):
        # One group per factor (index 0, 1, 2), with a separator line before groups 1 and 2.
        for idx, separator in enumerate((None, '*********', '********')):
            if separator is not None:
                print(separator)
            for factorization in factorizations:
                print(factorization[idx])

    # Primal outputs: JAX vs PyTorch.
    print_side_by_side(jax_res[0], torch_res[0])

    for _ in range(3):
        print('**********')

    # Tangents: JAX vs PyTorch vs the hand-written rule.
    print_side_by_side(jax_res[1], torch_res[1], res[1])

    if not verify:
        return

    assert_svd_uv(jax_res, torch_res)
    assert_svd_uv(jax_res, res)

# Example 3x3 input and an all-ones perturbation direction, both as JAX arrays.
A = jnp.array([
    [1., 2., 3.],
    [5., 4., 3.],
    [10., 15., 12.],
])
dA = jnp.ones_like(A)

# Side-by-side comparison against JAX / PyTorch is currently disabled:
# check_real_valued_partial(A.tolist(), dA.tolist())

# SVD of the example; S holds the singular values as a 1-D vector.
U, S, Vt = jnp.linalg.svd(A)

In [70]:
# Zero-safe reciprocal of the singular values: entries equal to 0 map to 0 instead of inf.
s = S
s_zeros = jnp.equal(s, 0).astype(s.dtype)
# Where s == 0 this evaluates 1 / (0 + 1) - 1 = 0; elsewhere it is exactly 1 / s.
s_inv = 1 / (s + s_zeros) - s_zeros
# jnp.vectorize with this signature turns each length-k vector into a k x k diagonal matrix.
batched_diag = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')
s_inv_mat = batched_diag(s_inv)

In [69]:
# Display the zero-safe reciprocals of the singular values.
s_inv

DeviceArray([0.04355856, 0.44321346, 1.0791255 ], dtype=float32)

In [71]:
# Display the diagonal matrix built from s_inv.
s_inv_mat

DeviceArray([[0.04355856, 0.        , 0.        ],
             [0.        , 0.44321346, 0.        ],
             [0.        , 0.        , 1.0791255 ]], dtype=float32)

In [73]:
# Build the diagonal matrix under its own name so the singular-value vector `S` is not
# overwritten: rebinding `S = jnp.diag(S)` would make this cell non-idempotent (a second
# run would turn the matrix back into a vector) and would break cells that expect `S`
# to be 1-D.
S_mat = jnp.diag(S)
print(S_mat)
# Inverse via the general matrix routine, for comparison with the zero-safe reciprocal.
S_inv = jnp.linalg.inv(S_mat)
S_inv

[[22.957602   0.         0.       ]
 [ 0.         2.2562492  0.       ]
 [ 0.         0.         0.9266763]]


DeviceArray([[0.04355856, 0.        , 0.        ],
             [0.        , 0.44321346, 0.        ],
             [0.        , 0.        , 1.0791255 ]], dtype=float32)

The reverse-mode (adjoint) rule for the real-valued partial SVD is given by the formula

$$
\newcommand{\gAa}{{[\U (\F \circ [\Ut \gU - \gUt \U]) \S + (\I_m - \U \Ut) \gU \S^{-1} ] \Vt}}
\newcommand{\gAb}{{\U (\I_k \circ \gS ) \Vt}}
\newcommand{\gAc}{{\U [\S (\F \circ [\Vt \gV - \gVt \V]) \Vt + \S^{-1} \gVt (\I_n - \V \Vt)]}}
\gA = \gAa \newline + \gAb \newline + \gAc
$$

In [None]:
def svd_vjp(A, U, S, Vh, gU, gS, gVh):
    """Placeholder for the reverse-mode (VJP) rule of the SVD; not implemented yet.

    Currently does nothing and returns ``None``.

    NOTE(review): presumably meant to return the adjoint of ``A`` given the factors
    ``U``, ``S``, ``Vh`` and their cotangents ``gU``, ``gS``, ``gVh``. The intended
    conventions (``S`` / ``gS`` as vectors or diagonal matrices, ``gVh`` vs ``gV``)
    are not visible here; confirm before implementing.
    """
    pass

In [None]:
# 4x2 example matrix for the NumPy reference SVD.
A = np.array([[1., 2.], [5., 2.], [10., 23.], [10., 23.]])

# np.linalg.svd defaults to full_matrices=True: U is 4x4, S holds the
# min(4, 2) = 2 singular values as a vector, and Vh is 2x2.
U, S, Vh = np.linalg.svd(A)

print(U)
print(S)
print(Vh)

Sources:

SVD real differentiability:
- https://j-towns.github.io/papers/svd-derivative.pdf
- https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
- https://arxiv.org/pdf/1509.07838.pdf

SVD complex differentiability:
- https://arxiv.org/pdf/1909.02659.pdf
- https://giggleliu.github.io/2019/04/02/einsumbp.html

Existing implementations:

Jax forward:
- https://github.com/google/jax/blob/2a00533e3e686c1c9d7dfe9ed2a3b19217cfe76f/jax/_src/lax/linalg.py#L1578
- JAX implements only the forward-mode (JVP) rule; the reverse-mode (VJP) rule is derived automatically by transposing the linear JVP.

Pytorch forward:
- https://github.com/pytorch/pytorch/blob/7a8152530d490b30a56bb090e9a67397d20e16b1/torch/csrc/autograd/FunctionsManual.cpp#L3122

Pytorch backward:
- https://github.com/pytorch/pytorch/blob/7a8152530d490b30a56bb090e9a67397d20e16b1/torch/csrc/autograd/FunctionsManual.cpp#L3228

Tensorflow forward:
- https://github.com/tensorflow/tensorflow/blob/bbe41abdcb2f7e923489bfa21cfb546b6022f330/tensorflow/python/ops/linalg_grad.py#L815

General complex differentiability:
- https://mediatum.ub.tum.de/doc/631019/631019.pdf

In [None]:
import jax.numpy as jnp
import jax

def f(A, B):
    """Return the trace of the matrix product ``A @ B``."""
    product = jnp.matmul(A, B)
    return jnp.trace(product)

# Small inputs: A is 2x2 and B is 2x3, so A @ B is 2x3.
A = jnp.array([[1, 2],
               [3, 4]], dtype=float)
B = jnp.array([[1, 2, 3],
               [4, 5, 6]], dtype=float)

trace_value = f(A, B)
print(trace_value)

# Forward-mode Jacobian of f with respect to both arguments (one entry per argument).
jacobian_wrt_both = jax.jacfwd(f, argnums=(0, 1))
jacobian_wrt_both(A, B)

In [None]:
from jax import custom_jvp
import jax.numpy as jnp

# f :: a -> b
@custom_jvp
def f(x):
    """Squared Euclidean norm of ``x`` (``x . x``), with a hand-written JVP rule."""
    return jnp.dot(x, x)


# f_jvp :: (a, T a) -> (b, T b)
@f.defjvp
def f_jvp(primals, tangents):
    """JVP rule for ``f``: returns ``(f(x), (2 * x) . t)`` for primal ``x`` and tangent ``t``."""
    (x,) = primals
    (t,) = tangents
    primal_out = f(x)
    tangent_out = (2 * x) @ t
    return primal_out, tangent_out

In [None]:
# Reverse-mode gradient of the custom-JVP f at [1, 2, 4]; since f(x) = x . x this should equal 2 * x.
jax.grad(f)(jnp.array([1.,2.,4.]))

In [None]:
# jax.vjp returns (primal_out, vjp_fn); trace vjp_fn with cotangent 1. to inspect the
# backward pass JAX derives for the custom-JVP f.
jax.make_jaxpr(jax.vjp(f, jnp.array([1.,2.,3.]))[1])(1.)

In [None]:
# Same inspection for the plain x @ x (no custom rule), for comparison with the cell above.
jax.make_jaxpr(jax.vjp(lambda x: x @ x, jnp.array([1.,2.,3.]))[1])(1.)

In [None]:
# Reverse-mode Jacobian of the custom-JVP f at [1, 2, 3]; should equal 2 * x.
jax.jacrev(f)(jnp.array([1.,2.,3.]))

In [None]:
# Reverse-mode Jacobian of the plain x @ x at the same point, for comparison with the custom-JVP version.
jax.jacrev(lambda x: x @ x)(jnp.array([1., 2., 3.]))

In [None]:
# A Python scalar times a JAX array multiplies elementwise.
2 * jnp.array([1.,2.])