# Computation of the gradient of the Spectral Radius in a `NaN` -safe way

In [1]:
from jax import numpy as jnp
import jax
import numpy as np
import scipy

First we look a small matrix with a few elements. 
Note that depending on the weights `weights_normal` or `weights_nan` the gradient of the largest eigenvalue with respect to the input parameters is either solved by jax or reported as nan.


In [2]:

ind = jnp.array([[2,0,1,1,0,3],[3,1,2,0,3,1]])

M = jnp.zeros([4,4])

#weights = jnp.array([1., 2., 3., 4.])
weights_normal = jnp.array([1.,1.,1.,1.,1.,1.])
weights_nan = jnp.array([1.,1.,0.,0.,1.,1.])

M = M.at[ind[0], ind[1]].set(weights_normal)
M

DeviceArray([[0., 1., 0., 1.],
             [1., 0., 1., 0.],
             [0., 0., 0., 1.],
             [0., 1., 0., 0.]], dtype=float32)

In [3]:
def spectralRadius(w,ind,M):
    M = M.at[ind[0], ind[1]].set(w)
    e_val = jnp.linalg.eigvals(M)
    r = jnp.max(jnp.abs(e_val))
    return r

First we compute teh largest eigenvalue and gradient: results in normal numerical values:

In [4]:
jax.value_and_grad(spectralRadius,argnums=0)(weights_normal,ind,M)

(DeviceArray(1.5213797, dtype=float32),
 DeviceArray([0.16824293, 0.25596112, 0.1682428 , 0.42420426, 0.16824281,
              0.3364857 ], dtype=float32))

Then we use the `weights_nan` and notice that the gradient is `nan`:

In [5]:
jax.value_and_grad(spectralRadius,argnums=0)(weights_nan,ind,M)

(DeviceArray(0., dtype=float32),
 DeviceArray([nan, nan, nan, nan, nan, nan], dtype=float32))

Now we redefine the `jvp` jacobian-vector product function to give null gradient is the gradient is `nan`

In [6]:

@jax.custom_jvp
def spectralRadiusZeroNaNGrad(w, ind, M):
    M = M.at[ind[0], ind[1]].set(w)
    e_val = jnp.linalg.eigvals(M)
    r = jnp.max(jnp.abs(e_val))
    return r

def spectralRadiusZeroNaNGrad_aux(w, ind, M):
    M = M.at[ind[0], ind[1]].set(w)
    e_val = jnp.linalg.eigvals(M)
    r = jnp.max(jnp.abs(e_val))

    return r

@spectralRadiusZeroNaNGrad.defjvp
def spectralRadiusZeroNaNGrad_jvp(primals, tangents):
  w, ind, M = primals
  w_dot, ind_dot, M_dot  = tangents
  primal_out = spectralRadiusZeroNaNGrad(w, ind, M )
  rep_values = jnp.nan_to_num(jax.grad(spectralRadiusZeroNaNGrad_aux)(w,ind,M))
  tangent_out = jnp.dot(rep_values,w_dot)

  return primal_out, tangent_out

Run the two examples:

In [7]:
jax.value_and_grad(spectralRadiusZeroNaNGrad,argnums=0)(weights_normal,ind,M)

(DeviceArray(1.5213797, dtype=float32),
 DeviceArray([0.16824293, 0.25596112, 0.1682428 , 0.42420426, 0.16824281,
              0.3364857 ], dtype=float32))

In [8]:
jax.value_and_grad(spectralRadiusZeroNaNGrad,argnums=0)(weights_nan,ind,M)

(DeviceArray(0., dtype=float32),
 DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32))

# Testing real matrix from synthetic computation

In [37]:
M = np.loadtxt("M.csv", delimiter=",")
M = jnp.array(M)
ind = np.loadtxt("ind.csv", delimiter=",").astype(np.int32)
w = np.loadtxt("w.csv", delimiter=",")
w = jnp.array(w)

In [38]:
all(M[ind[0],ind[1]] == w)

True

In [18]:

def spectralRadiusN2(w, ind, M):
    M = M.at[ind[0], ind[1]].set(w)
    e_val = jnp.linalg.eigvals(M)
    r = jnp.linalg.norm(e_val)
    return r


def myeig(w, ind, M):
    M = M.at[ind[0], ind[1]].set(w)
    e_val = jnp.linalg.eigvals(M)
    
    return e_val

In [40]:
v,g = jax.value_and_grad(spectralRadiusN2,argnums=0)(w,ind,M)
print(v)
g[1:10]

1.3774141


DeviceArray([nan, nan, nan, nan, nan, nan, nan, nan, nan], dtype=float32)

In [57]:
e = myeig(w,ind,M)

In [43]:
J = jax.jacfwd(myeig)(w,ind,M)
J.shape

(409, 832)

In [47]:
jnp.isnan(J).all()

DeviceArray(True, dtype=bool)

Since all the jacobian are `nan`s we cannot do much with it

# Avlant's code



In [28]:
def lreig(A):
    # for DENSE Matrix
    #fall back if eigs fails
    e, w, v = scipy.linalg.eig(A, left = True)
    selected = np.argmax(np.abs(e))
    eValue = e[selected]
    # selected = (e == eValue)

    # if numpy.sum(selected) == 1:
    w = w[:,selected]
    v = v[:,selected]
    # else:
    #     w = numpy.sum(w[:,selected], axis=1, keepdims=True)
    #     v = numpy.sum(v[:,selected], axis=1, keepdims=True)
    #     w = w/norm(w)
    #     v = v/norm(v)
    return eValue, v, w


In [78]:
# A: sparse matrix
@jax.custom_jvp
def AvlantsSpectralRadius(weights, ind, M):
    
    A = sp.sparse.csr_matrix((weights, ind), shape=M.shape, dtype='float32')
    tolerance = 10**-6

    try:
        e, v = scipy.sparse.linalg.eigs(A, k=1, which='LM', ncv=100, tol = tolerance)
        v = v[:,0]
        e = e[0]
    except  (KeyboardInterrupt, SystemExit):
        raise
    except:
        print('Forward fail (did not find any eigenvalue with eigs)')
        tmpA = A.toarray()
        e, v, w = lreig(tmpA) #fall back to solving full eig problem

    spectralRadius = np.abs(e)
    #ctx.e = e
    #ctx.v = v
    #ctx.w = np.empty(0)

    return spectralRadius

def AvlantsSpectralRadius_save_out(weights, ind, M):
    
    A = sp.sparse.csr_matrix((weights, ind), shape=M.shape, dtype='float32')
    tolerance = 10**-6

    try:
        e, v = scipy.sparse.linalg.eigs(A, k=1, which='LM', ncv=100, tol = tolerance)
        v = v[:,0]
        e = e[0]
    except  (KeyboardInterrupt, SystemExit):
        raise
    except:
        print('Forward fail (did not find any eigenvalue with eigs)')
        tmpA = A.toarray()
        e, v, w = lreig(tmpA) #fall back to solving full eig problem

    spectralRadius = np.abs(e)
    e = e
    v = v
    w = np.empty(0)

    return spectralRadius, e, v, w

@AvlantsSpectralRadius.defjvp
def AvlantsSpectralRadius_jvp(primals, tangents):
    weights, ind, M = primals
    w_dot, ind_dot, M_dot  = tangents
    primal_out, e, v, w = AvlantsSpectralRadius_save_out(weights, ind, M)

    tolerance = 10**-6
    networkList = ind
    A = sp.sparse.csr_matrix((weights, ind), shape=M.shape, dtype='float32')
    
    tmpA = A
    tmpA = tmpA.T  #tmpA.T.toarray()

    if w.shape[0]==0:
        try:
            eT = e
            if np.isreal(eT): #does for some reason not converge if imag = 0
                eT = eT.real
            e2, w = scipy.sparse.linalg.eigs(tmpA, k=1, sigma=eT, OPpart='r', tol=tolerance)
            selected = 0 #numpy.argmin(numpy.abs(e2-eT))
            w = w[:,selected]
            e2 = e2[selected]
            #Check if same eigenvalue
            if abs(e-e2)>(tolerance*10):
                print('Backward fail (eigs left returned different eigenvalue)')
                w = np.empty(0)
                #e, v, w = lreig(tmpA) #fall back to solving whole eig problem
        except (KeyboardInterrupt, SystemExit):
            raise
        except:
            print('Backward fail (did not find any eigenvalue with eigs)')
            #e, v, w = lreig(tmpA) #fall back to solving full eig problem
            delta = np.zeros(weights.shape)


    if w.shape[0] != 0:
        divisor = w.T.dot(v).flatten()
        if abs(divisor) == 0:
            delta = np.zeros(weights.shape)
            print('Empty eig')
        else:
            delta = np.multiply(w[networkList[0]], v[networkList[1]])/divisor
            direction = e/np.abs(e)
            delta = (delta/direction).real
    else:
        #print('Empty eig')
        delta = np.zeros(weights.shape)

    #deltaFilter = numpy.not_equal(numpy.sign(delta), numpy.sign(ctx.weights))
    #delta[deltaFilter] = 0

    #delta = torch.tensor(delta, dtype = grad_output.dtype)

    constrainNorm = True
    if constrainNorm:
        norm = np.linalg.norm(delta)
        if norm>10:
            delta = delta/norm #typical seems to be ~0.36
        #delta = delta * numpy.abs(ctx.weights)
        #delta = delta/norm(delta)

    tangent_out = jnp.dot(delta,w_dot)

    return primal_out, tangent_out

In [None]:
# %%timeit 
jax.value_and_grad(spectralRadius,argnums=0)(weights_normal,ind,M)

(DeviceArray(5.6821356, dtype=float32),
 DeviceArray([0.20256595, 0.25577933, 0.06752197, 0.46549958, 0.270088  ,
              0.25883427], dtype=float32))

Compute gradient with finite differences:

In [None]:
# %%timeit 
NP_cfd_spectralRadius(weights_normal,ind,np.array(M))

array([0.20253658, 0.25582314, 0.06754398, 0.46544075, 0.27008057,
       0.25873184])

In [None]:
# %%timeit 
cfd_spectralRadius(weights_normal,ind,M)

array([0.20256042, 0.25577545, 0.06744862, 0.46527386, 0.2699852 ,
       0.25897026])

In [80]:
#AvlantsSpectralRadius(weights_normal, ind, M)
jax.value_and_grad(AvlantsSpectralRadius,argnums=0)(weights_normal, ind, M)




TypeError: Incompatible shapes for dot: got (6, 1) and (6,).

In [17]:
def getSpecRad(A):
    tolerance = 10**-6
    try:
        e, v = scipy.sparse.linalg.eigs(A, k=1, which='LM', ncv=100, tol = tolerance)
        v = v[:,0]
        e = e[0]
    except  (KeyboardInterrupt, SystemExit):
        raise
    except:
        print('Forward fail (did not find any eigenvalue with eigs)')
        tmpA = A.toarray()
        e, v, w = lreig(tmpA) #fall back to solving full eig problem
    return np.abs(e)

def NP_cfd_spectralRadius(w,ind,M):

    dw = 0.01
    A = M
    
    A[ind[0],ind[1]] = w
    
    grad = np.zeros(w.shape)

    for i in range(len(w)):
        w_working = w.copy()
        w_working = w_working.at[i].set(w_working[i] + dw)
        A[ind[0],ind[1]] = w_working

        e_forward = getSpecRad(A)
        
        w_working = w.copy()
        w_working = w_working.at[i].set(w_working[i] - dw)
        A[ind[0],ind[1]] = w_working

        e_backward = getSpecRad(A)
        grad[i] = (e_forward - e_backward)/(2*dw)
        
    return grad

In [18]:
def cfd_spectralRadius(w,ind,M):

    dw = 0.01
    A = scipy.sparse.csr_matrix((w, ind), shape=M.shape)

    A[ind[0],ind[1]] = w.copy()
    
    grad = np.zeros(w.shape)

    for i in range(len(w)):
        w_working = w.copy()
        w_working = w_working.at[i].set(w_working[i] + dw)
        A[ind[0],ind[1]] = w_working

        e_forward = getSpecRad(A)
        
        w_working = w.copy()
        w_working = w_working.at[i].set(w_working[i] - dw)
        A[ind[0],ind[1]] = w_working

        e_backward = getSpecRad(A)
        grad[i] = (e_forward - e_backward)/(2*dw)
        
    return grad


In [59]:

ind = jnp.array([[2,0,1,1,0,3],[3,1,2,0,3,1]])
M = jnp.zeros([4,4])

weights_normal = jnp.array([1.,2.,3.,4.,5.,6.])

M = M.at[ind[0], ind[1]].set(weights_normal)

def spectralRadius(w,ind,M):
    M = M.at[ind[0], ind[1]].set(w)
    e_val = jnp.linalg.eigvals(M)
    r = jnp.max(jnp.abs(e_val))
    return r

def NPspectralRadius(M):
    
    e_val = np.linalg.eigvals(M)
    r = np.max(np.abs(e_val))
    return r

M

DeviceArray([[0., 2., 0., 5.],
             [4., 0., 3., 0.],
             [0., 0., 0., 1.],
             [0., 6., 0., 0.]], dtype=float32)

Compute the gradient via JAX:

In [62]:
# %%timeit 
jax.value_and_grad(spectralRadius,argnums=0)(weights_normal,ind,M)

(DeviceArray(5.6821356, dtype=float32),
 DeviceArray([0.20256595, 0.25577933, 0.06752197, 0.46549958, 0.270088  ,
              0.25883427], dtype=float32))

Compute gradient with finite differences:

In [63]:
# %%timeit 
NP_cfd_spectralRadius(weights_normal,ind,np.array(M))

array([0.20253658, 0.25582314, 0.06754398, 0.46544075, 0.27008057,
       0.25873184])

In [64]:
# %%timeit 
cfd_spectralRadius(weights_normal,ind,M)

array([0.20256042, 0.25577545, 0.06744862, 0.46527386, 0.2699852 ,
       0.25897026])

In [65]:
# %%timeit 
jax.value_and_grad(AvlantsSpectralRadius,argnums=0)(weights_normal, ind, M)

(5.682134,
 DeviceArray([0.20256588, 0.2557794 , 0.06752203, 0.46549955, 0.27008793,
              0.25883427], dtype=float32))

Time the real case

In [54]:
M = np.loadtxt("M.csv", delimiter=",")
M_large = jnp.array(M)
ind_large = np.loadtxt("ind.csv", delimiter=",").astype(np.int32)
w = np.loadtxt("w.csv", delimiter=",")
w_large = jnp.array(w)

In [57]:
# %%timeit 
v,g = jax.value_and_grad(spectralRadius,argnums=0)(w_large,ind_large,M_large)
# print(v)
# print(g[0])
# 0.37025064
# nan
# 38.3 ms ± 1.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Compute gradient with finite differences:

In [44]:
# %%timeit 
g = NP_cfd_spectralRadius(w_large,ind_large,np.array(M_large))
# g[1:5]
# array([ 2.68220901e-05,  3.57627869e-05, -1.63912773e-05,  2.08616257e-05])
# 30.5 s ± 1.25 s per loop (mean ± std. dev. of 7 runs, 1 loop each)

30.5 s ± 1.25 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [84]:
#%%timeit 
g = cfd_spectralRadius(w_large,ind_large,M_large)
g[1:5]
# array([-1.78813934e-05,  5.96046448e-06,  1.49011612e-06, -8.94069672e-05])

# 25 sec

array([-5.96046448e-06, -2.23517418e-05, -1.49011612e-06,  1.78813934e-05])

In [88]:
%%timeit 
v,g2 = jax.value_and_grad(AvlantsSpectralRadius,argnums=0)(w_large, ind_large, M_large)
#print(v)
#print(g2[1:5])

21.4 ms ± 339 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


The jax and CFD based gradient are very similar:

In [90]:
jnp.linalg.norm(g-g2)/jnp.linalg.norm(0.5*(g+g2))

DeviceArray(0.00168966, dtype=float32)

In [103]:
%%timeit
A = scipy.sparse.csr_matrix((w_large, ind_large), shape=M_large.shape, dtype='float32')


363 µs ± 4.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [94]:
%%timeit
B = scipy.sparse.csr_matrix(M_large)


1.65 ms ± 84.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [101]:
(A.todense() == M_large).all()

DeviceArray(True, dtype=bool)

In [102]:
(B.todense() == M_large).all()

DeviceArray(True, dtype=bool)