In [74]:
import jax.numpy as jnp
from jax import lax
import jax
import torch
import torch.autograd.functional as F
torch_jvp = F.jvp

## Cases
$$
\newcommand{\d}{{\mathrm{d}}}
\newcommand{\A}{{\mathbf{A}}}
\newcommand{\U}{{\mathbf{U}}}
\newcommand{\S}{{\mathbf{S}}}
\newcommand{\V}{{\mathbf{V}}}
\newcommand{\F}{{\mathbf{F}}}
\newcommand{\I}{{\mathbf{I}}}
\newcommand{\dA}{{\d\A}}
\newcommand{\dU}{{\d\U}}
\newcommand{\dS}{{\d\S}}
\newcommand{\dV}{{\d\V}}
\newcommand{\Ut}{{\U^{\top}}}
\newcommand{\Vt}{{\V^{\top}}}
\newcommand{\Vh}{{\V^{H}}}
\newcommand{\dAt}{{\dA^{\top}}}
\newcommand{\dVt}{{\dV^{\top}}}
\newcommand{\gA}{{\overline{\A}}}
\newcommand{\gU}{{\overline{\U}}}
\newcommand{\gUt}{{\gU^{\top}}}
\newcommand{\gS}{{\overline{\S}}}
\newcommand{\gSt}{{\gS^{\top}}}
\newcommand{\gV}{{\overline{\V}}}
\newcommand{\gVt}{{\gV^{\top}}}
$$
The form of the SVD derivative depends on three independent choices:

1. Computing the full SVD vs the "thin"/"partial" SVD
2. Computing the complete factorization $\U\S\Vh$ vs computing just the singular values $\S$
3. Complex vs real inputs

These three binary choices give $2^3 = 8$ cases for the SVD derivative. Each case has its own differential formula for the forward-mode (JVP) update and its own adjoint formula for the reverse-mode (VJP) update.

## Numerical instability
TODO

## Real valued partial SVD

Reference: https://j-towns.github.io/papers/svd-derivative.pdf

### Forward mode

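The differentials $\dU$, $\dS$, and $\dV$ are expressed in terms of $\dA$, $\U$, $\S$, and $\V$ as follows. Apply the product rule to $\A = \U \S \Vt$, then use the orthogonality constraints $\Ut \U = \I_k$ and $\Vt \V = \I_k$, which make $\Ut \dU$ and $\Vt \dV$ antisymmetric. (TODO: write out this derivation step by step.)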

Product rule applied to $\A = \U \S \Vt$

$\dA = \dU \S \Vt + \U \dS \Vt + \U \S \dVt$

Differential formulas

$\dU = \U ( \F \circ [\Ut \dA \V \S + \S \Vt \dAt \U] ) + (\I_m - \U \Ut ) \dA \V \S^{-1}$

$\dS = \I_k \circ [\Ut \dA \V]$

$\dV = \V (\F \circ [\S \Ut \dA \V + \Vt \dAt \U \S]) + (\I_n - \V \Vt) \dAt \U \S^{-1}$

where

$F_{ij} = \frac{1}{s_j^2 - s_i^2}, i \neq j$

$F_{ij} = 0$ otherwise

In [66]:
def svd_jvp_real_valued_partial(A, dA):
    """Forward-mode (JVP) rule for the thin SVD of a real matrix.

    Implements Townsend's differential formulas for ``A = U S V^T``
    (https://j-towns.github.io/papers/svd-derivative.pdf):

        dU = U (F o [U^T dA V S + S V^T dA^T U]) + (I_m - U U^T) dA V S^{-1}
        dS = I_k o [U^T dA V]
        dV = V (F o [S U^T dA V + V^T dA^T U S]) + (I_n - V V^T) dA^T U S^{-1}

    where ``o`` is the elementwise product and ``F[i, j] = 1 / (s_j^2 - s_i^2)`` for
    ``i != j`` and 0 on the diagonal.

    Parameters
    ----------
    A : array, shape (m, n)
        Real input matrix. The formulas assume its singular values are distinct and
        non-zero; otherwise ``F`` or ``S^{-1}`` contain infinities.
    dA : array, shape (m, n)
        Tangent (perturbation) of ``A``.

    Returns
    -------
    (U, S_vals, Vt)
        Thin SVD of ``A`` with shapes (m, k), (k,), (k, n), where k = min(m, n).
    (dU, dS_vals, dV)
        Tangents of ``U``, ``S_vals`` and ``V = Vt.T``, with shapes (m, k), (k,), (n, k).
        Note that this is the tangent of ``V``, not of ``Vt``.

    Raises
    ------
    ValueError
        If ``A`` is not 2-D or ``dA`` does not have the same shape as ``A``.
    """
    A = jnp.asarray(A)
    dA = jnp.asarray(dA)
    if A.ndim != 2 or jnp.shape(dA) != A.shape:
        raise ValueError(
            "svd_jvp_real_valued_partial expects 2-D A and dA of the same shape; "
            f"got A {A.shape} and dA {jnp.shape(dA)}"
        )

    U, S_vals, Vt = jnp.linalg.svd(A, compute_uv=True, full_matrices=False)

    # Take m and n from A itself: Vt.shape[0] is k, which differs from n for wide inputs.
    m, n = A.shape
    k = S_vals.shape[0]

    S = jnp.diag(S_vals)
    S_inv = jnp.diag(1.0 / S_vals)
    Ut = U.T
    V = Vt.T
    dAt = dA.T

    I_m = jnp.eye(m, dtype=U.dtype)
    I_n = jnp.eye(n, dtype=U.dtype)

    # F[i, j] = 1 / (s_j^2 - s_i^2) off the diagonal, 0 on it. The inner `where`
    # puts 1 in the diagonal denominators, so no inf/nan is ever formed there.
    off_diag = ~jnp.eye(k, dtype=bool)
    s_sq = S_vals ** 2
    s_sq_diff = s_sq[None, :] - s_sq[:, None]
    F = jnp.where(off_diag, 1.0 / jnp.where(off_diag, s_sq_diff, 1.0), 0.0)

    # P = U^T dA V is shared by all three formulas (P.T == V^T dA^T U).
    P = Ut @ dA @ V

    dU = U @ (F * (P @ S + S @ P.T)) + (I_m - U @ Ut) @ dA @ V @ S_inv
    dS_vals = jnp.diagonal(P)  # diagonal of I_k o P
    dV = V @ (F * (S @ P + P.T @ S)) + (I_n - V @ Vt) @ dAt @ U @ S_inv

    return (U, S_vals, Vt), (dU, dS_vals, dV)

In [78]:
# Preview the 3x3 test matrix as a torch tensor (row-major reshape of the flat values).
torch.tensor([1., 2., 3., 5., 4., 3., 10., 15., 12.]).reshape(3, 3)

tensor([[ 1.,  2.,  3.],
        [ 5.,  4.,  3.],
        [10., 15., 12.]])

In [82]:
def check_real_valued_partial(A, dA, rtol=1e-5, atol=1e-5):
    """Check ``svd_jvp_real_valued_partial`` against ``jax.jvp`` of the thin SVD.

    Parameters
    ----------
    A : array, shape (m, n)
        Point at which the SVD is linearized.
    dA : array, shape (m, n)
        Tangent direction.
    rtol, atol : float, optional
        Tolerances forwarded to ``jnp.allclose``. ``atol`` defaults to 1e-5 rather
        than ``jnp.allclose``'s 1e-8, because the two implementations accumulate
        float32 round-off differently.

    Raises
    ------
    AssertionError
        If any primal or tangent output differs from JAX's reference. The message
        names the output and gives the maximum absolute error.
    """
    (U, S, Vt), (dU, dS, dVt) = jax.jvp(
        lambda X: jnp.linalg.svd(X, compute_uv=True, full_matrices=False),
        (A,),
        (dA,),
    )
    (U_, S_, Vt_), (dU_, dS_, dV_) = svd_jvp_real_valued_partial(A, dA)

    # jax.jvp returns the tangent of Vt, while the hand-written rule returns the
    # tangent of V. Transpose before comparing.
    dVt_ = dV_.T

    comparisons = (
        ("U", U, U_),
        ("S", S, S_),
        ("Vt", Vt, Vt_),
        ("dU", dU, dU_),
        ("dS", dS, dS_),
        ("dVt", dVt, dVt_),
    )
    for name, expected, actual in comparisons:
        assert jnp.allclose(expected, actual, rtol=rtol, atol=atol), (
            f"{name} mismatch: max abs error {jnp.max(jnp.abs(expected - actual))}"
        )

# Shared 3x3 test problem for the JAX and PyTorch JVP comparisons.
A_ = [[1., 2., 3.], [5., 4., 3.], [10., 15., 12.]]

A_jax = jnp.array(A_)
# Build the tangent from A_jax. A bare `A` would pick up whatever matrix an unrelated
# cell last bound to that name, or raise NameError on a fresh kernel.
dA_jax = jnp.ones_like(A_jax)

A_torch = torch.tensor(A_)
dA_torch = torch.ones_like(A_torch)

# PyTorch reference: torch_jvp returns (outputs, jvp) = ((U, S, Vh), (dU, dS, dVh)).
torch_jvp(lambda A: torch.linalg.svd(A, full_matrices=False), A_torch, dA_torch)

# Uncomment to compare the hand-written JVP with jax.jvp on the same inputs.
# check_real_valued_partial(A_jax, dA_jax)

((tensor([[-0.1524,  0.4672, -0.8709],
          [-0.2956, -0.8624, -0.4109],
          [-0.9431,  0.1949,  0.2695]]),
  tensor([22.9576,  2.2562,  0.9267]),
  tensor([[-0.4818, -0.6810, -0.5515],
          [-0.8404,  0.1808,  0.5110],
          [-0.2483,  0.7096, -0.6594]])),
 (tensor([[-0.0589, -0.0750, -0.0299],
          [-0.0451, -0.0318,  0.0992],
          [ 0.0237,  0.0390,  0.0545]]),
  tensor([2.3847, 0.0298, 0.2004]),
  tensor([[-0.0126,  0.0126, -0.0045],
          [-0.0083,  0.0456, -0.0298],
          [ 0.0527,  0.0005, -0.0193]])))

In [61]:
jnp.allclose

<CompiledFunction of <function allclose at 0x7facef2aa1f0>>

For the real-valued partial SVD, the reverse-mode (adjoint) update maps the cotangents $\gU$, $\gS$, $\gV$ to $\gA$ as follows:

$$
\newcommand{\gAa}{{[\U (\F \circ [\Ut \gU - \gUt \U]) \S + (\I_m - \U \Ut) \gU \S^{-1} ] \Vt}}
\newcommand{\gAb}{{\U (\I_k \circ \gS ) \Vt}}
\newcommand{\gAc}{{\U [\S (\F \circ [\Vt \gV - \gVt \V]) \Vt + \S^{-1} \gVt (\I_n - \V \Vt)]}}
\begin{aligned} \gA = {} & \gAa \\ & + \gAb \\ & + \gAc \end{aligned}
$$

In [3]:
def svd_vjp(A, U, S, Vh, gU, gS, gVh):
    """Reverse-mode (VJP) rule for the thin SVD of a real matrix.

    Given the thin factorization ``A = U @ diag(S) @ Vh`` and cotangents of its three
    outputs, return the cotangent of ``A`` using Townsend's adjoint formula
    (https://j-towns.github.io/papers/svd-derivative.pdf). With ``V = Vh.T``,
    ``gV = gVh.T`` and ``o`` for the elementwise product:

        gA = U [ (F o (U^T gU - gU^T U)) S + I_k o gS + S (F o (V^T gV - gV^T V)) ] V^T
             + (I_m - U U^T) gU S^{-1} V^T
             + U S^{-1} gV^T (I_n - V V^T)

    where ``F[i, j] = 1 / (s_j^2 - s_i^2)`` for ``i != j`` and 0 on the diagonal.
    Singular values are assumed distinct and non-zero.

    Parameters
    ----------
    A : array, shape (m, n)
        The factorized matrix. The factors fully determine the result, so ``A`` is
        not used; it is kept for interface compatibility.
    U : array, shape (m, k)
        Thin left singular vectors (``full_matrices=False``).
    S : array, shape (k,) or (k, k)
        Singular values, either as a vector or as a diagonal matrix.
    Vh : array, shape (k, n)
        Thin right singular vectors, transposed.
    gU : array, shape (m, k)
        Cotangent of ``U``.
    gS : array, shape (k,) or (k, k)
        Cotangent of the singular values. A matrix is masked to its diagonal (I_k o gS).
    gVh : array, shape (k, n)
        Cotangent of ``Vh``.

    Returns
    -------
    gA : array, shape (m, n)
        Cotangent of ``A``.

    Raises
    ------
    ValueError
        If the factor shapes are not those of a thin SVD.
    """
    U = jnp.asarray(U)
    Vh = jnp.asarray(Vh)
    S = jnp.asarray(S)
    s = jnp.diagonal(S) if S.ndim == 2 else S
    k = s.shape[0]
    if U.ndim != 2 or Vh.ndim != 2 or U.shape[1] != k or Vh.shape[0] != k:
        raise ValueError(
            "svd_vjp expects thin SVD factors U (m, k), S (k,), Vh (k, n); "
            f"got U {U.shape}, S {S.shape}, Vh {Vh.shape}"
        )
    m, n = U.shape[0], Vh.shape[1]

    gU = jnp.asarray(gU)
    gS = jnp.asarray(gS)
    gV = jnp.asarray(gVh).T
    V = Vh.T

    # I_k o gS: a vector becomes a diagonal matrix, a matrix is masked to its diagonal.
    gS_mat = jnp.diag(gS) if gS.ndim == 1 else jnp.eye(k, dtype=U.dtype) * gS

    # F[i, j] = 1 / (s_j^2 - s_i^2) off the diagonal, 0 on it (no inf formed on the diagonal).
    off_diag = ~jnp.eye(k, dtype=bool)
    s_sq = s ** 2
    s_sq_diff = s_sq[None, :] - s_sq[:, None]
    F = jnp.where(off_diag, 1.0 / jnp.where(off_diag, s_sq_diff, 1.0), 0.0)

    S_mat = jnp.diag(s)
    S_inv = jnp.diag(1.0 / s)

    UtgU = U.T @ gU  # (gU^T U) == UtgU.T
    VtgV = V.T @ gV  # (gV^T V) == VtgV.T
    core = (F * (UtgU - UtgU.T)) @ S_mat + gS_mat + S_mat @ (F * (VtgV - VtgV.T))

    gA = U @ core @ Vh
    # Components of gU / gV outside the column spaces of U / V (non-zero only when m > k / n > k).
    gA = gA + (jnp.eye(m, dtype=U.dtype) - U @ U.T) @ gU @ S_inv @ Vh
    gA = gA + U @ S_inv @ gV.T @ (jnp.eye(n, dtype=U.dtype) - V @ Vh)
    return gA

In [27]:
# numpy is not imported anywhere else in this notebook, so this cell would raise
# NameError on a fresh kernel without a local import.
import numpy as np

# Full (not thin) SVD of a tall 4x2 matrix. With the default full_matrices=True,
# U is 4x4 even though only its first two columns pair with singular values.
# The last two columns span the orthogonal complement of A's column space and are
# not unique.
A = np.array([[1.,2.], [5., 2.], [10., 23.], [10., 23.]])

U, S, Vh = np.linalg.svd(A)

print(U)
print(S)
print(Vh)

[[-0.06249881 -0.0248252  -0.70550606 -0.70550606]
 [-0.10831214 -0.99361751  0.02227914  0.02227914]
 [-0.70155626  0.07780729  0.52970552 -0.47029448]
 [-0.70155626  0.07780729 -0.47029448  0.52970552]]
[35.74646973  3.76694863]
[[-0.40941623 -0.91234772]
 [-0.91234772  0.40941623]]


Sources:

SVD real differentiability:
- https://j-towns.github.io/papers/svd-derivative.pdf
- https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
- https://arxiv.org/pdf/1509.07838.pdf

SVD complex differentiability:
- https://arxiv.org/pdf/1909.02659.pdf
- https://giggleliu.github.io/2019/04/02/einsumbp.html

Existing implementations:

Jax forward:
- https://github.com/google/jax/blob/2a00533e3e686c1c9d7dfe9ed2a3b19217cfe76f/jax/_src/lax/linalg.py#L1578
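- JAX only implements the forward (JVP) rule. It derives the reverse-mode (VJP) rule from it automatically, by linearizing and transposing. The converse does not hold: a reverse rule alone does not give JAX a forward rule.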

Pytorch forward:
- https://github.com/pytorch/pytorch/blob/7a8152530d490b30a56bb090e9a67397d20e16b1/torch/csrc/autograd/FunctionsManual.cpp#L3122

Pytorch backward:
- https://github.com/pytorch/pytorch/blob/7a8152530d490b30a56bb090e9a67397d20e16b1/torch/csrc/autograd/FunctionsManual.cpp#L3228

TensorFlow backward (registered gradient):
- https://github.com/tensorflow/tensorflow/blob/bbe41abdcb2f7e923489bfa21cfb546b6022f330/tensorflow/python/ops/linalg_grad.py#L815

General complex differentiability:
- https://mediatum.ub.tum.de/doc/631019/631019.pdf

In [35]:
import jax.numpy as jnp
import jax

def f(A, B):
    """Return ``trace(A @ B)``. For a non-square product, ``jnp.trace`` sums the main diagonal."""
    product = jnp.matmul(A, B)
    return jnp.trace(product)

# A is 2x2 and B is 2x3, so A @ B is 2x3; only entries (0, 0) and (1, 1) enter the trace.
A = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=float)
B = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=float)

print(f(A, B))

# Forward-mode Jacobians of the scalar f with respect to both arguments.
jax.jacfwd(f, argnums=(0, 1))(A, B)

35.0


(DeviceArray([[1., 4.],
              [2., 5.]], dtype=float32),
 DeviceArray([[1., 3., 0.],
              [2., 4., 0.]], dtype=float32))

In [93]:
from jax import custom_jvp
import jax.numpy as jnp

# f :: a -> b
@custom_jvp
def f(x):
    """Return ``x . x``; JAX differentiates it through ``f_jvp`` rather than its own rule."""
    return jnp.dot(x, x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
    """Hand-written JVP for ``f``: d(x . x) = (2 x) . t.

    ``2 * x`` depends only on the primal, so only the final product with ``t`` is
    linear in the tangent.
    """
    (x,) = primals
    (t,) = tangents
    scaled_x = 2 * x
    return f(x), scaled_x @ t

f.defjvp(f_jvp)

<function __main__.f_jvp(primals, tangents)>

In [94]:
# Reverse-mode gradient through the custom JVP rule (expected: 2 * x).
jax.grad(f)(jnp.array([1.0, 2.0, 4.0]))

DeviceArray([2., 4., 8.], dtype=float32)

In [95]:
# Jaxpr of the pullback JAX derives (by transposition) from the custom JVP rule of f.
jax.make_jaxpr(jax.vjp(f, jnp.array([1.0, 2.0, 3.0]))[1])(1.0)

{ [34m[22m[1mlambda [39m[22m[22ma[35m:f32[3][39m; b[35m:f32[][39m. [34m[22m[1mlet
    [39m[22m[22mc[35m:f32[3][39m = dot_general[
      dimension_numbers=(((), ()), ((), ()))
      precision=None
      preferred_element_type=None
    ] b a
  [34m[22m[1min [39m[22m[22m(c,) }

In [96]:
# Same pullback jaxpr, but using JAX's built-in rule for x @ x (two dot_generals + add).
jax.make_jaxpr(jax.vjp(lambda v: v @ v, jnp.array([1.0, 2.0, 3.0]))[1])(1.0)

{ [34m[22m[1mlambda [39m[22m[22ma[35m:f32[3][39m; b[35m:f32[][39m. [34m[22m[1mlet
    [39m[22m[22mc[35m:f32[3][39m = dot_general[
      dimension_numbers=(((), ()), ((), ()))
      precision=None
      preferred_element_type=None
    ] b a
    d[35m:f32[3][39m = dot_general[
      dimension_numbers=(((), ()), ((), ()))
      precision=None
      preferred_element_type=None
    ] b a
    e[35m:f32[3][39m = add_any d c
  [34m[22m[1min [39m[22m[22m(e,) }

In [99]:
# Reverse-mode Jacobian of the custom-JVP f (expected: 2 * x).
jax.jacrev(f)(jnp.array([1.0, 2.0, 3.0]))

DeviceArray([2., 4., 6.], dtype=float32)

In [100]:
# Reverse-mode Jacobian of x . x with JAX's built-in rule, for comparison.
jax.jacrev(lambda v: jnp.dot(v, v))(jnp.array([1.0, 2.0, 3.0]))

DeviceArray([2., 4., 6.], dtype=float32)

In [90]:
# Sanity check of elementwise scaling.
jnp.array([1.0, 2.0]) * 2

DeviceArray([2., 4.], dtype=float32)