In [23]:
import jax.numpy as jnp
from jax import lax
import jax
from jax import random

import torch
import torch.autograd.functional as TF
# Short alias for torch.autograd.functional.jvp (PyTorch's functional
# Jacobian-vector-product helper), used for the PyTorch reference results.
torch_jvp = TF.jvp

## Cases
$$
\newcommand{\tr}{{\mathrm{tr}}}
\newcommand{\d}{{\mathrm{d}}}
\newcommand{\A}{{\mathbf{A}}}
\newcommand{\U}{{\mathbf{U}}}
\newcommand{\S}{{\mathbf{S}}}
\newcommand{\V}{{\mathbf{V}}}
\newcommand{\F}{{\mathbf{F}}}
\newcommand{\I}{{\mathbf{I}}}
\newcommand{\dA}{{\d\A}}
\newcommand{\dU}{{\d\U}}
\newcommand{\dS}{{\d\S}}
\newcommand{\dV}{{\d\V}}
\newcommand{\Ut}{{\U^{\top}}}
\newcommand{\Vt}{{\V^{\top}}}
\newcommand{\Vh}{{\V^{H}}}
\newcommand{\dAt}{{\dA^{\top}}}
\newcommand{\dVt}{{\dV^{\top}}}
\newcommand{\gA}{{\overline{\A}}}
\newcommand{\gAt}{{\gA^{\top}}}
\newcommand{\gU}{{\overline{\U}}}
\newcommand{\gUt}{{\gU^{\top}}}
\newcommand{\gS}{{\overline{\S}}}
\newcommand{\gSt}{{\gS^{\top}}}
\newcommand{\gV}{{\overline{\V}}}
\newcommand{\gVt}{{\gV^{\top}}}
$$
The derivative of the SVD operation is determined by

1. Computing the full SVD vs the "thin"/"partial" SVD
2. Computing the complete factorization $\U\S\Vh$ vs computing just the singular values $\S$
3. complex vs real inputs

These three binary choices give 8 different cases for the SVD derivative, each with its own differential formulas for the forward-mode AD update and its own adjoint formulas for the reverse-mode AD update.

## Numerical instability
TODO

## Real valued partial SVD

Reference: https://j-towns.github.io/papers/svd-derivative.pdf

### Forward mode

The differential formulas $\dU$, $\dS$, and $\dV$ in terms of $\dA$, $\U$, $\S$, and $\V$ are found from the chain rule (TODO check this from the differential formula).

##### Chain rule

$\dA = \dU \S \Vt + \U \dS \Vt + \U \S \dVt$

##### Differential formulas

$\dU = \U ( \F \circ [\Ut \dA \V \S + \S \Vt \dAt \U] ) + (\I_m - \U \Ut ) \dA \V \S^{-1}$

$\dS = \I_k \circ [\Ut \dA \V]$

$\dV = \V (\F \circ [\S \Ut \dA \V + \Vt \dAt \U \S]) + (\I_n - \V \Vt) \dAt \U \S^{-1}$

where

$F_{ij} = \frac{1}{s_j^2 - s_i^2}, \quad i \neq j$

$F_{ij} = 0$ otherwise

In [24]:
def svd_jvp_real_valued_partial(A, dA):
    """Forward-mode (JVP) rule for the real-valued thin SVD ``A = U S Vt``.

    Implements the differential formulas for ``dU``, ``dS`` and ``dV`` from
    https://j-towns.github.io/papers/svd-derivative.pdf.

    Args:
        A: real 2-D array of shape ``(m, n)`` at which the SVD is evaluated.
        dA: tangent of ``A``; same shape as ``A``.

    Returns:
        ``((U, S_vals, Vt), (dU, dS_vals, dVt))``: the thin SVD factors of
        ``A`` followed by their tangents. Singular values and their tangents
        are returned as 1-D vectors.

    Note:
        The formulas divide by ``s_j**2 - s_i**2`` (inside ``F``) and by
        ``s_i``, so the result is only meaningful when the singular values
        are distinct and non-zero.
    """
    U, S_vals, Vt = jnp.linalg.svd(A, compute_uv=True, full_matrices=False)

    S = jnp.diag(S_vals)
    # S is diagonal, so its inverse is just the element-wise reciprocal; this
    # avoids a general dense matrix inversion.
    S_inv = jnp.diag(1.0 / S_vals)

    Ut = U.T
    V = Vt.T
    dAt = dA.T

    k = S_vals.shape[0]
    m = A.shape[0]
    n = A.shape[1]

    I_m = jnp.eye(m)
    I_n = jnp.eye(n)

    F_ = F(S_vals, k)

    # Ut @ dA @ V appears in all three formulas; its transpose is Vt @ dAt @ U.
    proj = Ut @ dA @ V
    proj_t = proj.T

    dU = U @ (F_ * (proj @ S + S @ proj_t)) + (I_m - U @ Ut) @ dA @ V @ S_inv

    # dS = I_k * (Ut @ dA @ V); only its diagonal is returned, so the mask
    # with I_k is unnecessary and the diagonal is taken directly.
    dS_vals = jnp.diagonal(proj)

    dV = V @ (F_ * (S @ proj + proj_t @ S)) + (I_n - V @ Vt) @ dAt @ U @ S_inv

    return (U, S_vals, Vt), (dU, dS_vals, dV.T)

def F(S_vals, k):
    """Build the ``k x k`` matrix ``F`` used by the SVD derivative formulas.

    ``F[i, j] = 1 / (S_vals[j]**2 - S_vals[i]**2)`` for ``i != j`` and
    ``F[i, i] = 0``.

    Args:
        S_vals: 1-D array of singular values.
        k: number of leading singular values to use; ``F`` is ``(k, k)``.

    Returns:
        The ``(k, k)`` array described above.

    Note:
        Off-diagonal entries are infinite when two singular values coincide.
    """
    squared = jnp.asarray(S_vals)[:k] ** 2
    # gap[i, j] = s_j^2 - s_i^2
    gap = squared[None, :] - squared[:, None]
    off_diagonal = ~jnp.eye(k, dtype=bool)

    # Replace the (zero) diagonal gaps by 1 before dividing so that 1/0 is
    # never evaluated: otherwise the discarded branch yields inf in the
    # forward pass and NaN in any gradient taken through this function.
    safe_gap = jnp.where(off_diagonal, gap, 1.0)
    return jnp.where(off_diagonal, 1.0 / safe_gap, 0.0)

In [25]:
def assert_svd_uv(a, b):
    """Assert that two SVD JVP results agree component by component.

    Each argument has the form ``((U, S, Vt), (dU, dS, dVt))``. The primal
    factors are compared first, then the tangents, in that order, using
    ``assert_allclose``.
    """
    (U_a, S_a, Vt_a), (dU_a, dS_a, dVt_a) = a
    (U_b, S_b, Vt_b), (dU_b, dS_b, dVt_b) = b

    pairs = (
        (U_a, U_b),
        (S_a, S_b),
        (Vt_a, Vt_b),
        (dU_a, dU_b),
        (dS_a, dS_b),
        (dVt_a, dVt_b),
    )
    for left, right in pairs:
        assert_allclose(left, right)

def assert_allclose(l, r):
    """Assert that two array-likes have the same shape and close values.

    Both inputs are converted with ``jnp.array`` and compared with
    ``jnp.allclose`` using ``atol=1e-05`` (default ``rtol``).

    Args:
        l: left-hand array-like.
        r: right-hand array-like.

    Raises:
        AssertionError: if the shapes differ or any element is outside the
            tolerance. The message reports the shapes or the largest
            absolute difference.
    """
    left = jnp.array(l)
    right = jnp.array(r)

    # jnp.allclose broadcasts its arguments, so without this check a shape
    # mismatch (e.g. a vector against a matrix) could pass silently.
    assert left.shape == right.shape, (
        f"shape mismatch: {left.shape} vs {right.shape}"
    )
    assert jnp.allclose(left, right, atol=1e-05), (
        f"arrays differ: max abs difference "
        f"{float(jnp.max(jnp.abs(left - right)))}"
    )
        
def jax_real_valued_partial(A, dA):
    """Reference JVP of the thin SVD computed by JAX's built-in rule.

    Args:
        A: array-like primal matrix.
        dA: array-like tangent with the same shape as ``A``.

    Returns:
        The ``(primals_out, tangents_out)`` pair from ``jax.jvp`` applied to
        ``jnp.linalg.svd(..., compute_uv=True, full_matrices=False)``.
    """
    def thin_svd(matrix):
        return jnp.linalg.svd(matrix, compute_uv=True, full_matrices=False)

    primal = jnp.array(A)
    tangent = jnp.array(dA)

    return jax.jvp(thin_svd, (primal,), (tangent,))

def pytorch_real_valued_partial(A, dA):
    """Reference JVP of the thin SVD computed by PyTorch.

    Args:
        A: data accepted by ``torch.tensor`` for the primal matrix.
        dA: data accepted by ``torch.tensor`` for the tangent.

    Returns:
        The result of ``torch_jvp`` applied to
        ``torch.linalg.svd(..., full_matrices=False)`` at ``A`` with tangent
        ``dA``.
    """
    def thin_svd(matrix):
        return torch.linalg.svd(matrix, full_matrices=False)

    primal = torch.tensor(A)
    tangent = torch.tensor(dA)

    return torch_jvp(thin_svd, primal, tangent)
    
def check_real_valued_partial(A, dA):
    """Check the hand-written JVP against JAX and PyTorch for one input.

    ``A`` and ``dA`` are passed unchanged to the JAX and PyTorch helpers and
    wrapped in ``jnp.array`` for the hand-written rule. The JAX result is the
    common reference for both comparisons.
    """
    reference = jax_real_valued_partial(A, dA)
    torch_result = pytorch_real_valued_partial(A, dA)
    handwritten = svd_jvp_real_valued_partial(jnp.array(A), jnp.array(dA))

    # example == jax == torch
    for candidate in (torch_result, handwritten):
        assert_svd_uv(reference, candidate)

def check_real_valued_partials(seed=0, min_dim=3, max_dim=10):
    """Run ``check_real_valued_partial`` over a grid of random matrices.

    For every ``(m, n)`` with ``min_dim <= m, n <= max_dim`` a standard-normal
    ``m x n`` matrix is drawn and checked with an all-ones tangent.

    Args:
        seed: seed for the JAX PRNG key (default reproduces the original run).
        min_dim: smallest row/column count tested (inclusive).
        max_dim: largest row/column count tested (inclusive).

    Raises:
        AssertionError: propagated from the per-matrix check on disagreement.
    """
    key = random.PRNGKey(seed)
    dims = range(min_dim, max_dim + 1)

    for m in dims:
        for n in dims:
            # Split so that every matrix is drawn from a fresh subkey.
            key, subkey = random.split(key)

            A = random.normal(subkey, (m, n))
            dA = jnp.ones_like(A)

            # Hand over nested Python lists so that each backend (JAX and
            # PyTorch) constructs its own array/tensor from the same data.
            check_real_valued_partial(A.tolist(), dA.tolist())


# Run the whole sweep; any disagreement between the hand-written, JAX and
# PyTorch results surfaces as an AssertionError.
check_real_valued_partials()

### Reverse mode

##### Chain rule

$ \tr(\gAt \dA) = \tr(\gUt \dU) + \tr(\gSt \dS) + \tr(\gVt \dV) $

##### Gradient formula

The formula for A's gradient has terms $\mathrm{term}_U$, $\mathrm{term}_S$, and $\mathrm{term}_V$ found from the respective trace terms in the chain rule

$\gA = \mathrm{term}_U + \mathrm{term}_S + \mathrm{term}_V$

$\mathrm{term}_U = [\U (\F \circ [\Ut \gU - \gUt \U]) \S + (\I_m - \U \Ut) \gU \S^{-1} ] \Vt$

$\mathrm{term}_S = \U (\I_k \circ \gS ) \Vt$

$\mathrm{term}_V = \U [\S (\F \circ [\Vt \gV - \gVt \V]) \Vt + \S^{-1} \gVt (\I_n - \V \Vt)]$


In [26]:
def svd_vjp_real_valued_partial(A, U, S_vals, Vt, gU, gS_vals, gVt):
    """Reverse-mode (VJP) rule for the real-valued thin SVD ``A = U S Vt``.

    Implements the gradient formula
    ``gA = term_U + term_S + term_V`` from
    https://j-towns.github.io/papers/svd-derivative.pdf.

    Args:
        A: the 2-D input matrix; only its shape ``(m, n)`` is used.
        U, S_vals, Vt: thin SVD factors of ``A`` (``S_vals`` is 1-D).
        gU: cotangent of ``U`` (same shape as ``U``).
        gS_vals: cotangent of the singular values (1-D, like ``S_vals``).
        gVt: cotangent of ``Vt`` (same shape as ``Vt``).

    Returns:
        ``gA``, the cotangent of ``A``.

    Note:
        The formulas divide by ``s_j**2 - s_i**2`` (inside ``F``) and by
        ``s_i``, so the result is only meaningful when the singular values
        are distinct and non-zero.
    """
    S = jnp.diag(S_vals)
    # S is diagonal, so its inverse is just the element-wise reciprocal; this
    # avoids a general dense matrix inversion.
    S_inv = jnp.diag(1.0 / S_vals)
    # gS is diagonal by construction, so masking it with I_k is unnecessary.
    gS = jnp.diag(gS_vals)

    k = S_vals.shape[0]
    m = A.shape[0]
    n = A.shape[1]

    I_m = jnp.eye(m)
    I_n = jnp.eye(n)

    V = Vt.T
    Ut = U.T
    gV = gVt.T

    F_ = F(S_vals, k)

    # (Ut @ gU).T == gUt @ U and (Vt @ gV).T == gVt @ V, so each
    # antisymmetric bracket needs only one matrix product.
    UtgU = Ut @ gU
    VtgV = Vt @ gV

    term_U = (U @ (F_ * (UtgU - UtgU.T)) @ S + (I_m - U @ Ut) @ gU @ S_inv) @ Vt

    term_S = U @ gS @ Vt

    term_V = U @ (S @ (F_ * (VtgV - VtgV.T)) @ Vt + S_inv @ gVt @ (I_n - V @ Vt))

    gA = term_U + term_S + term_V

    return gA

In [28]:
# Example: pull all-ones cotangents for U, S and Vt back to A.
A = jnp.array([[1., 2., 3.], [5., 4., 6.], [10., 9., 8.]])

# svd_vjp_real_valued_partial implements the thin ("partial") SVD formulas,
# so request the thin factorization explicitly. For this square A the
# factors are the same as those of the full SVD.
U, S, Vt = jnp.linalg.svd(A, full_matrices=False)

gU = jnp.ones_like(U)
gS = jnp.ones_like(S)
gVt = jnp.ones_like(Vt)

gA = svd_vjp_real_valued_partial(A, U, S, Vt, gU, gS, gVt)

gA

DeviceArray([[-0.604877  ,  0.4115927 ,  0.2906126 ],
             [ 0.369074  , -0.5303678 ,  0.9319923 ],
             [ 0.78237915,  0.78433084, -0.0146427 ]], dtype=float32)

Sources:

SVD real differentiability:
- https://j-towns.github.io/papers/svd-derivative.pdf
- https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
- https://arxiv.org/pdf/1509.07838.pdf

SVD complex differentiability:
- https://arxiv.org/pdf/1909.02659.pdf
- https://giggleliu.github.io/2019/04/02/einsumbp.html

Existing implementations:

Jax forward:
- https://github.com/google/jax/blob/2a00533e3e686c1c9d7dfe9ed2a3b19217cfe76f/jax/_src/lax/linalg.py#L1578
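- JAX only implements the forward (JVP) rule because it can derive the reverse-mode (VJP) rule automatically by transposing the linear JVP; the converse does not hold (a function defined only through a custom VJP cannot be differentiated in forward mode).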

Pytorch forward:
- https://github.com/pytorch/pytorch/blob/7a8152530d490b30a56bb090e9a67397d20e16b1/torch/csrc/autograd/FunctionsManual.cpp#L3122

Pytorch backward:
- https://github.com/pytorch/pytorch/blob/7a8152530d490b30a56bb090e9a67397d20e16b1/torch/csrc/autograd/FunctionsManual.cpp#L3228

Tensorflow forward:
- https://github.com/tensorflow/tensorflow/blob/bbe41abdcb2f7e923489bfa21cfb546b6022f330/tensorflow/python/ops/linalg_grad.py#L815

General complex differentiability:
- https://mediatum.ub.tum.de/doc/631019/631019.pdf

In [5]:
import jax.numpy as jnp
import jax

def f(A, B):
    """Return the trace of the matrix product ``A @ B``."""
    product = A @ B
    return jnp.trace(product)

# A is 2x2 and B is 2x3, so A @ B is 2x3 (not square); the trace then sums
# the main diagonal of that rectangular product.
# NOTE: this rebinds the notebook-level name A used by earlier cells.
A = jnp.array([[1, 2], [3, 4]], dtype=float)
B = jnp.array([[1,2,3], [4,5,6]], dtype=float)

print(f(A, B))

# Forward-mode Jacobians of f with respect to both arguments (argnums 0 and 1).
jax.jacfwd(f, argnums=(0, 1))(A, B)

35.0


(DeviceArray([[1., 4.],
              [2., 5.]], dtype=float32),
 DeviceArray([[1., 3., 0.],
              [2., 4., 0.]], dtype=float32))

In [6]:
from jax import custom_jvp
import jax.numpy as jnp

# f :: a -> b
@custom_jvp
def f(x):
    """Return ``jnp.dot(x, x)`` (the squared Euclidean norm for a 1-D ``x``)."""
    self_product = jnp.dot(x, x)
    return self_product

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
    """Custom JVP rule for ``f``: returns ``(f(x), 2 * x @ t)``.

    ``primals`` and ``tangents`` are one-element tuples holding the input
    ``x`` and its tangent ``t``.
    """
    (x,) = primals
    (t,) = tangents
    primal_out = f(x)
    tangent_out = 2 * x @ t
    return primal_out, tangent_out

# Register the rule; as the last expression this also displays f_jvp.
f.defjvp(f_jvp)

<function __main__.f_jvp(primals, tangents)>

In [7]:
# Reverse-mode gradient of the custom-JVP f; only a forward rule was
# registered, so JAX derives the backward pass from it. Expected: 2 * x.
jax.grad(f)(jnp.array([1.,2.,4.]))

DeviceArray([2., 4., 8.], dtype=float32)

In [8]:
# Trace the pullback (second element returned by jax.vjp) of the custom-JVP f
# to inspect the backward computation JAX derived from the JVP rule.
jax.make_jaxpr(jax.vjp(f, jnp.array([1.,2.,3.]))[1])(1.)

{ lambda a:f32[3]; b:f32[]. let
    c:f32[3] = dot_general[
      dimension_numbers=(((), ()), ((), ()))
      precision=None
      preferred_element_type=None
    ] b a
  in (c,) }

In [9]:
# Same trace for a plain x @ x (no custom rule), for comparison with the
# pullback of the custom-JVP version.
jax.make_jaxpr(jax.vjp(lambda x: x @ x, jnp.array([1.,2.,3.]))[1])(1.)

{ lambda a:f32[3]; b:f32[]. let
    c:f32[3] = dot_general[
      dimension_numbers=(((), ()), ((), ()))
      precision=None
      preferred_element_type=None
    ] b a
    d:f32[3] = dot_general[
      dimension_numbers=(((), ()), ((), ()))
      precision=None
      preferred_element_type=None
    ] b a
    e:f32[3] = add_any d c
  in (e,) }

In [10]:
# Reverse-mode Jacobian of the custom-JVP f; expected to equal 2 * x.
jax.jacrev(f)(jnp.array([1.,2.,3.]))

DeviceArray([2., 4., 6.], dtype=float32)

In [11]:
# Reverse-mode Jacobian of the plain x @ x, for comparison with the
# custom-JVP result.
jax.jacrev(lambda x: x @ x)(jnp.array([1., 2., 3.]))

DeviceArray([2., 4., 6.], dtype=float32)

In [12]:
# Scratch check: multiplying by a Python scalar scales every element.
2 * jnp.array([1.,2.])

DeviceArray([2., 4.], dtype=float32)