In [1]:
from jax import custom_vjp
import jax.numpy as jnp
from jax import grad, jit, vmap, ops
from jax import random
import numpy as np
import jax
from jax.scipy.linalg import expm
from jax.scipy import linalg
key = random.PRNGKey(0)

In [3]:

@custom_vjp
def eig(mat):
    return jnp.linalg.eig(mat)


def eig_fwd(mat):
    res = jnp.linalg.eig(mat)
    return res,res

def eig_bwd(res, g):
    """Gradient of a general square (complex valued) matrix"""
    e, u = res # eigenvalues as 1d array, eigenvectors in columns
    n = e.shape[-1]
    ge, gu = g
    ge = jnp.diag(ge)
    f = 1/(e[..., jnp.newaxis, :] - e[..., :, jnp.newaxis] + 1.e-20)
    f -= jnp.diag(f)
    ut = jnp.swapaxes(u, -1, -2)
    r = jnp.linalg.inv(ut)@(ge+f*(ut@gu) - f*(ut@jnp.conj(u)@(jnp.real(ut@gu)*jnp.eye(n))))@ut
    r = jnp.real(r)
    return (r,)

eig.defvjp(eig_fwd, eig_bwd)

In [4]:
eig= jit(eig, backend='cpu')

def loss(mat):
    e,u = eig(mat)
    return jnp.real(jnp.conjugate(e[0]) * e[0])

In [5]:
# Set a step size for finite differences calculations
eps = 1e-4

key, subkey = random.split(key)
mat = random.normal(subkey, (5,5))
unitvec = mat / jnp.sqrt(jnp.vdot(mat, mat))
W_grad_numerical = (loss(mat + eps / 2. * unitvec) - loss(mat - eps / 2. * unitvec)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(mat), unitvec))

W_dirderiv_numerical 1.502037
W_dirderiv_autodiff 1.4994137


In [9]:
def loss(mat):
    e,u = eig(mat)
    return jnp.real(jnp.conjugate(u[0,0]) * u[0,0])

In [10]:
# Set a step size for finite differences calculations
eps = 1e-4

key, subkey = random.split(key)
mat = random.normal(subkey, (5,5))
unitvec = mat / jnp.sqrt(jnp.vdot(mat, mat))
W_grad_numerical = (loss(mat + eps / 2. * unitvec) - loss(mat - eps / 2. * unitvec)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(mat), unitvec))

W_dirderiv_numerical 0.001527369
W_dirderiv_autodiff -242210590000.0
