In [3]:
from jax import custom_vjp
import jax.numpy as jnp
from jax import grad, jit, vmap, ops
from jax import random
import numpy as np
import jax
from jax.scipy.linalg import expm
from jax.scipy import linalg


In [150]:

@custom_vjp
def eig(mat):
    return jnp.linalg.eig(mat)


def eig_fwd(mat):
    res = jnp.linalg.eig(mat)
    return res,res

def eig_bwd(res, g):
    e,u = res
    n = e.shape[0]
    ge, gu = g
    ge = jnp.diag(ge)

    f = e[..., jnp.newaxis, :] - e[..., :, jnp.newaxis] + 1.e-20
    diag_elements = jnp.diag_indices_from(f)
    f = jax.ops.index_update(f, diag_elements, jnp.inf)
    f= 1./f

    #f = 1/(e[..., jnp.newaxis, :] - e[..., :, jnp.newaxis] + 1.e-20)
    #f -= jnp.diag(f)


    ut = jnp.swapaxes(u, -1, -2)
    r1 = f * jnp.dot(ut, gu)
    r2 = -f * (jnp.dot(jnp.dot(ut, jnp.conj(u)), jnp.real(jnp.dot(ut,gu)) * jnp.eye(n)))
    r = jnp.dot(jnp.dot(jnp.linalg.inv(ut), ge + r1 + r2), ut)
    r = jnp.real(r)
    # the derivative is still complex for real input (imaginary delta is allowed), real output
    # but the derivative should be real in real input case when imaginary delta is forbidden
    return (r,)

eig.defvjp(eig_fwd, eig_bwd)

In [158]:
eig= jit(eig, backend='cpu')

def loss(mat):
    mat = mat.T + mat
    e,u = eig(mat)
    return jnp.real(jnp.conjugate(e[0]) * e[0])

In [161]:


key = random.PRNGKey(23)
# Set a step size for finite differences calculations
eps = 1e-4

key, subkey = random.split(key)
mat = random.normal(subkey, (5,5))
unitvec = mat / jnp.sqrt(jnp.vdot(mat, mat))
W_grad_numerical = (loss(mat + eps / 2. * unitvec) - loss(mat - eps / 2. * unitvec)) / eps
nu =  W_grad_numerical
ad =  jnp.vdot(grad(loss)(mat), unitvec)
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(mat), unitvec))

W_dirderiv_numerical 145452.72
W_dirderiv_autodiff 2.5335462


In [153]:
def loss(mat):
    e,u = eig(mat)
    inds = jnp.argsort(e)
    u1 = u[:,inds[0]]
    u2 = u[:,inds[1]]
    return jnp.real(jnp.conjugate(u1).T @ u2)

In [154]:

key = random.PRNGKey(1)
# Set a step size for finite differences calculations
eps = 1e-4

key, subkey = random.split(key)
mat = random.normal(subkey, (5,5))
unitvec = mat / jnp.sqrt(jnp.vdot(mat, mat))
W_grad_numerical = (loss(mat + eps / 2. * unitvec) - loss(mat - eps / 2. * unitvec)) / eps
#print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(mat), unitvec))


W_dirderiv_numerical -0.00059604645
W_dirderiv_autodiff 7.674403e-09


In [132]:
eig(mat)

[DeviceArray([-3.1788244+0.j       ,  0.5754158+0.7541545j,
               0.5754158-0.7541545j, -1.7028826+0.j       ,
              -0.9303485+0.j       ], dtype=complex64),
 DeviceArray([[-0.57083344+0.j        , -0.49249107+0.03827315j,
               -0.49249107-0.03827315j,  0.09055754+0.j        ,
                0.07526224+0.j        ],
              [-0.05209664+0.j        ,  0.4238208 +0.24178955j,
                0.4238208 -0.24178955j, -0.28192702+0.j        ,
                0.1292539 +0.j        ],
              [ 0.5457766 +0.j        ,  0.06265222-0.04868248j,
                0.06265222+0.04868248j,  0.74575   +0.j        ,
                0.60991126+0.j        ],
              [-0.5359167 +0.j        ,  0.69301116+0.j        ,
                0.69301116-0.j        , -0.1301903 +0.j        ,
                0.6920431 +0.j        ],
              [-0.29386458+0.j        , -0.08337508-0.15617388j,
               -0.08337508+0.15617388j, -0.5824293 +0.j        ,
          

In [None]:
@custom_vjp
def eig(mat):
    return jnp.linalg.eigh(mat)


def eig_fwd(mat):
    res = jnp.linalg.eigh(mat)
    return res,res

def eig_bwd(res, g):
    """Gradient of a general square (complex valued) matrix"""
    e, u = res # eigenvalues as 1d array, eigenvectors in columns
    n = e.shape[-1]
    ge, gu = g
    ge = jnp.diag(ge)

    f = e[..., jnp.newaxis, :] - e[..., :, jnp.newaxis] + 1.e-20
    diag_elements = jnp.diag_indices_from(f)
    #f = jax.ops.index_update(f, diag_elements, 1.)
    f= 1./f

    ut = jnp.swapaxes(u, -1, -2)
    r =  u@(jnp.diag(ge)+f*(ut@gu-(ut@gu).T)/2)@ut
    return (r,)

eig.defvjp(eig_fwd, eig_bwd)