In [3]:
from jax import custom_vjp
import jax.numpy as jnp
from jax import grad, jit, vmap, ops
from jax import random
import numpy as np
import jax
from jax.scipy.linalg import expm
from jax.scipy import linalg


In [112]:

@custom_vjp
def eig(mat):
    """Eigendecomposition of a general square matrix via ``jnp.linalg.eig``.

    Declared as a ``custom_vjp`` function so that its reverse-mode derivative
    can be supplied by hand.

    Args:
        mat: the matrix to decompose.

    Returns:
        Whatever ``jnp.linalg.eig(mat)`` returns (eigenvalues, eigenvectors).
    """
    eigenpairs = jnp.linalg.eig(mat)
    return eigenpairs


def eig_fwd(mat):
    """Forward rule for the custom VJP of ``eig``.

    Args:
        mat: the matrix to decompose.

    Returns:
        A pair ``(primal_out, residuals)``; both entries are the same
        ``jnp.linalg.eig(mat)`` result, which the backward rule unpacks as
        eigenvalues and eigenvectors.
    """
    eigenpairs = jnp.linalg.eig(mat)
    # The backward rule only needs the decomposition itself, so the primal
    # output doubles as the saved residuals.
    return (eigenpairs, eigenpairs)

def eig_bwd(res, g):
    """Backward rule (VJP) for the eigendecomposition of a general square matrix.

    Args:
        res: residuals saved by the forward rule, ``(e, u)`` with the
            eigenvalues as a 1-D array and the eigenvectors in the columns
            of ``u``.
        g: cotangents ``(ge, gu)`` for the eigenvalues and the eigenvectors.

    Returns:
        A 1-tuple holding the cotangent for the input matrix. Only the real
        part is returned, i.e. the rule is written for real-valued inputs.
    """
    e, u = res
    n = e.shape[-1]
    ge, gu = g
    ge = jnp.diag(ge)

    # f[i, j] = 1 / (e[j] - e[i]) off the diagonal and 0 on it.  The tiny
    # offset avoids an exact division by zero for coinciding eigenvalues.
    # The mask replaces the former ``jax.ops.index_update`` call (that API has
    # been removed from JAX) and avoids forming 1/inf on complex values.
    off_diag = ~jnp.eye(n, dtype=bool)
    delta = e[..., jnp.newaxis, :] - e[..., :, jnp.newaxis] + 1.e-20
    f = jnp.where(off_diag, 1. / jnp.where(off_diag, delta, 1.), 0.)

    ut = jnp.swapaxes(u, -1, -2)
    ut_gu = jnp.dot(ut, gu)
    r1 = f * ut_gu
    r2 = -f * jnp.dot(jnp.dot(ut, jnp.conj(u)), jnp.real(ut_gu) * jnp.eye(n))
    r = jnp.dot(jnp.dot(jnp.linalg.inv(ut), ge + r1 + r2), ut)
    # The derivative is still complex for real input (an imaginary delta is
    # allowed); when imaginary perturbations are forbidden, as for real
    # input, the derivative must be real, so keep the real part only.
    r = jnp.real(r)
    return (r,)

# Attach the hand-written forward/backward rules to the custom_vjp function
# `eig`, so reverse-mode differentiation of `eig` goes through `eig_bwd`.
eig.defvjp(eig_fwd, eig_bwd)

In [95]:
@custom_vjp
def eig(mat):
    """Eigendecomposition via ``jnp.linalg.eigh``.

    Declared as a ``custom_vjp`` function so that its reverse-mode derivative
    can be supplied by hand.

    Args:
        mat: the matrix to decompose.

    Returns:
        Whatever ``jnp.linalg.eigh(mat)`` returns (eigenvalues, eigenvectors).
    """
    eigenpairs = jnp.linalg.eigh(mat)
    return eigenpairs


def eig_fwd(mat):
    """Forward rule for the custom VJP of the ``eigh``-based ``eig``.

    Args:
        mat: the matrix to decompose.

    Returns:
        A pair ``(primal_out, residuals)``; both entries are the same
        ``jnp.linalg.eigh(mat)`` result, which the backward rule unpacks as
        eigenvalues and eigenvectors.
    """
    eigenpairs = jnp.linalg.eigh(mat)
    # The backward rule only needs the decomposition itself, so the primal
    # output doubles as the saved residuals.
    return (eigenpairs, eigenpairs)

def eig_bwd(res, g):
    """Backward rule (VJP) for the ``eigh`` decomposition of a real symmetric matrix.

    Args:
        res: residuals saved by the forward rule, ``(e, u)`` with the
            eigenvalues as a 1-D array and the eigenvectors in the columns
            of ``u``.
        g: cotangents ``(ge, gu)`` for the eigenvalues and the eigenvectors.

    Returns:
        A 1-tuple holding the (symmetric) cotangent for the input matrix.
    """
    e, u = res  # eigenvalues as 1d array, eigenvectors in columns
    n = e.shape[-1]
    ge, gu = g

    # f[i, j] = 1 / (e[j] - e[i]) off the diagonal.  The diagonal entries are
    # irrelevant (they multiply the zero diagonal of an antisymmetric matrix)
    # and are set to 1 so that no division by zero occurs there.  The tiny
    # offset avoids an exact division by zero for coinciding eigenvalues.
    # The mask replaces the former ``jax.ops.index_update`` call, an API that
    # has been removed from JAX.
    off_diag = ~jnp.eye(n, dtype=bool)
    delta = e[..., jnp.newaxis, :] - e[..., :, jnp.newaxis] + 1.e-20
    f = 1. / jnp.where(off_diag, delta, 1.)

    ut = jnp.swapaxes(u, -1, -2)
    ut_gu = ut @ gu
    # The eigenvalue cotangent must sit on the diagonal.  It is expanded
    # exactly once: applying ``jnp.diag`` twice would collapse it back to a
    # vector that then broadcasts across rows and corrupts the gradient.
    r = u @ (jnp.diag(ge) + f * (ut_gu - ut_gu.T) / 2) @ ut
    return (r,)

# Attach the hand-written forward/backward rules to the custom_vjp function
# `eig`, so reverse-mode differentiation of `eig` goes through `eig_bwd`.
eig.defvjp(eig_fwd, eig_bwd)

In [113]:
# Rebind `eig` to a jit-compiled version pinned to the CPU backend.
# NOTE(review): presumably CPU is forced because the eigensolver was not
# available on accelerators -- confirm against the installed JAX version.
# NOTE(review): this rebinding is not idempotent; re-running the cell wraps
# the already-jitted function in another `jit`.
eig= jit(eig, backend='cpu')

def loss(mat):
    """Squared modulus of the first eigenvalue of the symmetrised input.

    Args:
        mat: square matrix; it is symmetrised as ``mat.T + mat`` before the
            decomposition.

    Returns:
        ``real(conj(e[0]) * e[0])`` where ``e`` are the eigenvalues returned
        by ``eig``.
    """
    symmetrised = mat.T + mat
    eigenvalues, _ = eig(symmetrised)
    leading = eigenvalues[0]
    return jnp.real(jnp.conjugate(leading) * leading)

In [114]:
# Fixed PRNG seed so the test matrix is reproducible.
key = random.PRNGKey(5)
# Width of the central finite-difference stencil.
eps = 1e-4

key, subkey = random.split(key)
mat = random.normal(subkey, (5,5))
# Direction of differentiation: `mat` divided by sqrt(vdot(mat, mat)).
unitvec = mat / jnp.sqrt(jnp.vdot(mat, mat))

# Central finite difference of the loss along `unitvec`.
half_step = eps / 2. * unitvec
loss_plus = loss(mat + half_step)
loss_minus = loss(mat - half_step)
W_grad_numerical = (loss_plus - loss_minus) / eps
print('W_dirderiv_numerical', W_grad_numerical)

# The same directional derivative obtained from reverse-mode autodiff.
W_dirderiv_autodiff = jnp.vdot(grad(loss)(mat), unitvec)
print('W_dirderiv_autodiff', W_dirderiv_autodiff)

W_dirderiv_numerical 18.463135
W_dirderiv_autodiff 17.493876


In [115]:
def loss(mat):
    """Overlap of the two eigenvectors with the smallest eigenvalues.

    The input is symmetrised as ``mat.T + mat`` before the decomposition.
    Eigenvectors of a symmetric matrix that belong to distinct eigenvalues
    are orthogonal, so this value is expected to be close to zero.

    Args:
        mat: square matrix.

    Returns:
        ``real(conj(u1) @ u2)`` for the eigenvectors ``u1``, ``u2`` selected
        by the first two entries of ``jnp.argsort`` of the eigenvalues.
    """
    symmetrised = mat.T + mat
    eigenvalues, eigenvectors = eig(symmetrised)
    order = jnp.argsort(eigenvalues)
    lowest = eigenvectors[:, order[0]]
    second_lowest = eigenvectors[:, order[1]]
    # `.T` is a no-op on the 1-D column slice, so a plain matmul suffices.
    overlap = jnp.conjugate(lowest) @ second_lowest
    return jnp.real(overlap)

In [116]:

# Fixed PRNG seed so the test matrix is reproducible.
key = random.PRNGKey(1)
# Width of the central finite-difference stencil.
eps = 1e-4

key, subkey = random.split(key)
mat = random.normal(subkey, (5,5))
# Direction of differentiation: `mat` divided by sqrt(vdot(mat, mat)).
unitvec = mat / jnp.sqrt(jnp.vdot(mat, mat))

# Central finite difference of the loss along `unitvec`.
half_step = eps / 2. * unitvec
loss_plus = loss(mat + half_step)
loss_minus = loss(mat - half_step)
W_grad_numerical = (loss_plus - loss_minus) / eps
print('W_dirderiv_numerical', W_grad_numerical)

# The same directional derivative obtained from reverse-mode autodiff.
W_dirderiv_autodiff = jnp.vdot(grad(loss)(mat), unitvec)
print('W_dirderiv_autodiff', W_dirderiv_autodiff)


W_dirderiv_numerical 0.00070780516
W_dirderiv_autodiff -4.2068127e-08
