In [39]:
import numpy as np 
import matplotlib.pyplot as plt 
import time

In [45]:
def posdef():
    A = np.random.rand(100,100)
    return np.einsum('ab,cb->ac',A,A)


In [3]:
def KL_1(A,B):
    m = A.shape[1]
    inv_A = np.linalg.solve(A,np.identity(m))
    _, logdet = np.linalg.slogdet(
        np.einsum('ab,bc->ac', inv_A, B))
    kl = .5 * (np.einsum('ab,ba', inv_A, B) - logdet - m)
    return (kl)

In [33]:
def KL_2(A,B):
    m = A.shape[1]
    Z  = np.linalg.solve(np.linalg.cholesky(A) , np.linalg.cholesky(B))
    k  = np.dot(Z.flatten(), Z.flatten().T) - 2.*np.sum(np.log(np.diag(Z))) -m
    return 0.5*k

In [34]:
def KL_3(A,B,exact=True):
    R1 = np.mat(A)
    R2 = np.mat(B)

    if (np.isnan(R2.sum())):
        return nan

    R1 = 0.5 * (R1 + R1.T)
    R2 = 0.5 * (R2 + R2.T)
    fudge_factor = 1000
    u,v = np.linalg.eigh(R2) # eig for sym matrix. Now u is real.
    eps = 2e-16#numpy.lib.getlimits.finfo(R1.dtype).eps
    pos = np.argwhere(u > fudge_factor*eps*sum(u) )

    if np.size(pos) == 0:
        print("numerical issue in kullback")
        vp = v
        up = u
    else:
        vp = v[:,pos]
        up = u[pos]

    k = 0.5* (np.sum(np.diag(vp.T*R1*vp)/up) + np.sum(np.log(up)))
    if exact:
        k -= 0.5* (np.log (np.linalg.det (R1)) + np.shape(R1)[1])

    # FIXME
    if np.iscomplex(k):
        print("The criterion is complex !!!")        #k = real(k)

    return k

In [48]:
test1 = []
test2 = []
for _ in range(100):
    A = posdef()
    B = posdef()
    test1.append(KL_2(A,B))
    test2.append(KL_1(A,B))
   


In [49]:
test1 = []
test2 = []
for _ in range(1000):
    A = posdef()
    B = posdef()
    time_start = time.time()
    KL_2(A,B)
    time_end = time.time()
    test2.append(time_end-time_start)
    
    time_start = time.time()
    KL_1(A,B)
    time_end = time.time()
    test1.append(time_end-time_start)
    
    
   



In [50]:
print(np.array(test1).mean())
print(np.array(test2).mean())

0.0008087773323059082
0.0004889364242553711
