In [58]:
import numpy as np 
from functools import reduce
from sps4lat import model as mod

ref_param = {'a':1.,'b':np.array([1.2,2.6]),'c':np.array([[1.,2.7],[7.8,2.3]])}
theta_list = [np.array(vv).flatten() for vv in ref_param.values()]

In [2]:
def pos_def(n):
    A = np.random.rand(10000,n,n)
    return np.einsum('lab,lcb->lac',A,A)


In [65]:
def KL_1(A,B):
    inv_A = np.linalg.inv(A)
    m = A.shape[1]
    sign,logdet = np.linalg.slogdet(np.einsum('lab,lbc->lac',inv_A,B)) 
    return .5*(np.einsum('lab,lba->l',inv_A,B) )

In [55]:
def KL_2(A,B):
    m = A.shape[1]
    chol_A = np.linalg.cholesky(A)
    chol_B = np.linalg.cholesky(B)
    z  = np.linalg.solve(chol_A, chol_B)
    k  = np.einsum('laa->l',z) 
    return .5*k

In [56]:
A = pos_def(10)
B = pos_def(10)

In [57]:
%%time
res_1 = KL_1(A,B)
print(res_1)

[ 67.94833834  28.06858598 210.22632242 ...  55.89697828 228.20550718
 276.64319614]
CPU times: user 170 ms, sys: 25.4 ms, total: 195 ms
Wall time: 107 ms


In [58]:
%%time
res_2 = KL_2(A,B)
print(res_2)

[4.94637707 4.68578956 4.43734394 ... 4.42768365 5.30389291 7.68950807]
CPU times: user 107 ms, sys: 16.8 ms, total: 124 ms
Wall time: 74.1 ms


In [37]:
print(np.abs(res_1-res_2))

[254.0746173  363.10812259 778.25755095 ...  44.03968552  41.29215547
 567.57236305]


In [85]:
pl = mod.PowerLaw(nu_0=50., ell_0=10.)
freqs = np.array([100.,200.,300.])
ells = np.linspace(2,100,99)
cov1 = pl.eval(nu=freqs,ell=ells, alpha=3.,beta=2.5) + np.broadcast_to(1e-10 * np.identity(3),(99,3,3))
cov2 = pl.eval(nu=freqs,ell=ells, alpha=4.2,beta=3.6) + np.broadcast_to(1e-10 * np.identity(3),(99,3,3))


In [86]:
print(np.linalg.det(cov2))
print(np.linalg.det(cov1))


[ 2.00637562e-16  1.10407653e-15  3.70177939e-15  9.42088470e-15
  2.00109673e-14  3.82335478e-14  7.36822177e-14  1.09862979e-13
  1.71014459e-13  1.02042005e-13  5.88235096e-13  8.22669439e-13
  5.61527511e-13  2.25080479e-12  3.92955744e-12  2.54215097e-12
  0.00000000e+00  8.08732392e-12  0.00000000e+00  0.00000000e+00
 -2.99395294e-11  0.00000000e+00 -4.36674789e-11  0.00000000e+00
 -1.22233138e-10 -1.46638454e-10  0.00000000e+00  0.00000000e+00
  4.14052567e-10 -4.87372151e-10  5.56892674e-10  0.00000000e+00
 -7.18378039e-10 -1.29821722e-09  0.00000000e+00  4.09872545e-10
 -2.06302840e-09  0.00000000e+00  2.84330570e-09  6.30803051e-10
 -5.58391236e-09 -7.70493715e-10  0.00000000e+00  5.59558369e-09
  0.00000000e+00 -8.95576270e-09 -7.33780277e-09  1.06687886e-08
  1.16135660e-08 -2.52415668e-08  0.00000000e+00  0.00000000e+00
 -3.20904170e-08  5.19919522e-08 -3.73861964e-08 -8.05426583e-08
  0.00000000e+00  0.00000000e+00 -4.99524532e-08  0.00000000e+00
  1.71983950e-07  2.45251

In [67]:
KL_1(cov2,cov1)

LinAlgError: Singular matrix