In [1]:
import warnings
import numpy as np
from six.moves import reload_module as reload
from sklearn.utils.extmath import fast_logdet
from sklearn.utils import check_random_state
from sklearn.datasets import make_sparse_spd_matrix
from sklearn.covariance import empirical_covariance
from functools import partial

from network_inference.prox import prox_logdet, soft_thresholding_od
import network_inference.datasets; reload(network_inference.datasets)
from network_inference.datasets import is_pos_def, is_pos_semi_def
from network_inference.utils import _scalar_product, update_rho, convergence

In [2]:
np.random.seed(0)
random_state=0
A = make_sparse_spd_matrix(dim=15, alpha=0.7, random_state=random_state)

T_true = A[5:,5:]
K_true = A[10:,5:,]
H_true = A[0:5,0:5]
print(is_pos_def(H_true))
per_cov = K_true*0.3
T_obs = T_true - per_cov.T.dot(np.linalg.inv(H_true)).dot(per_cov)

samples = np.random.multivariate_normal(np.zeros(10), np.linalg.inv(T_obs), 50)

True


In [3]:
def objective_H(H, R=None, T=None, K=None, U= None,_rho=1, _mu=1):
    return _rho/2* np.linalg.norm(R - T + U+ K.T.dot(np.linalg.pinv(H).dot(K)))**2 + _mu*np.linalg.norm(H, 1)


In [13]:
def _choose_lambda(lamda, R, T, K, H, U,  _rho, _mu, prox, grad, gamma, delta=1e-4, eps=0.9, max_iter=500):
    """Choose lambda for backtracking.

    References
    ----------
    Salzo S. (2017). https://doi.org/10.1137/16M1073741

    """
    partial_f = partial(objective_H, R=R, T=T, K=K, U=U, _rho=_rho, _mu=_mu)
    fx = partial_f(H)

    y_minus_x = prox - H
    tolerance = _scalar_product(y_minus_x, grad)
    tolerance += delta / gamma * _scalar_product(y_minus_x, y_minus_x)
    #print("Tolerance:", tolerance)
    for i in range(max_iter):
        # line-search
        x1 = H + lamda * y_minus_x

        loss_diff = partial_f(x1) - fx
        #print("Loss diff:", loss_diff)
        if loss_diff <= lamda * tolerance:
              break
        lamda *= eps
    else:
        return False, i+1
    return lamda, i + 1

In [14]:
def _choose_gamma(gamma, H, R, T, K, U, _rho, _mu, _lambda, grad,
                 eps=0.9, max_iter=500):
    """Choose gamma for backtracking.

    References
    ----------
    Salzo S. (2017). https://doi.org/10.1137/16M1073741

    """
    partial_f = partial(objective_H, R=R, T=T, K=K, U=U, _rho=_rho, _mu=_mu)
    fx = partial_f(H)
    for i in range(max_iter):
        prox = soft_thresholding_od(H - gamma * grad, _mu * gamma)
        if is_pos_def(prox):
            break
        gamma *= eps
    else:
        print("not found gamma")
    return gamma, prox

In [15]:
def _upgrade_H(R, T, K, U, _rho, _mu, verbose=0, random_state=None):
    H = make_sparse_spd_matrix(dim=K.shape[0], alpha=0.5, random_state=random_state)
    _lambda = 1
    gamma = 1
    obj = 1e+10
    for iter_ in range(2000):
        H_old = H
        Hinv = np.linalg.inv(H)
        gradient = -_rho* K.dot(R - T + U + np.linalg.multi_dot((K.T, Hinv, K))).dot(K.T).dot(Hinv).dot(Hinv)
        gamma, _ = _choose_gamma(gamma, H, R, T, K, U, _rho,_mu, _lambda, gradient)
        print(gamma)
        Y = soft_thresholding_od(H - gamma*gradient, gamma*_mu)
        _lambda,_ = _choose_lambda(_lambda, R, T, K, H, U,_rho, _mu, Y, gradient, 1, max_iter=1000, delta=1e-2)
        if not _lambda:
            
        print(_lambda)
        H = H + _lambda*(Y - H)
        obj_old = obj
        obj = objective_H(H, R, T, K, U,_rho=_rho, _mu=_mu)
        obj_diff = obj_old - obj
        iter_diff =np.linalg.norm(H - H_old) 
        if verbose:
            print("Iter: %d, obj: %.5f, iter_diff: %.5f, obj_diff:%.10f"%(iter_, obj, iter_diff, obj_diff))
        if(obj_diff<1e-4): 
            break
    else:
        print("Did not converge")
    return H

In [18]:
R = np.random.rand(10,10)
T = np.random.rand(10,10)
U = np.zeros((10,10))
K = per_cov
H = make_sparse_spd_matrix(dim=K.shape[0], alpha=0.5)
R = T - K.T.dot(np.linalg.pinv(H)).dot(K)
H_found = _upgrade_H(R, T, K, U, 1,0.5, 1, random_state=0)
H - H_found, H, H_found

1
False
Iter: 0, obj: 1.98929, iter_diff: 0.00000, obj_diff:9999999998.0107116699
1
False
Iter: 1, obj: 1.98929, iter_diff: 0.00000, obj_diff:0.0000000000


(array([[ 0.        , -0.12240039,  0.64545624, -0.49486625,  0.22317527],
        [-0.12240039, -0.17894971, -0.51510919,  0.37147863,  0.47651558],
        [ 0.64545624, -0.51510919, -0.76901556,  0.03059548, -0.15915288],
        [-0.49486625,  0.37147863,  0.03059548,  0.33168167,  0.07044981],
        [ 0.22317527,  0.47651558, -0.15915288,  0.07044981, -0.04996649]]),
 array([[ 1.        , -0.58732065,  0.        , -0.49486625, -0.20855428],
        [-0.58732065,  1.34494555,  0.        , -0.21819795,  0.12248824],
        [ 0.        ,  0.        ,  1.        , -0.41903009, -0.26809666],
        [-0.49486625, -0.21819795, -0.41903009,  1.67940014,  0.07044981],
        [-0.20855428,  0.12248824, -0.26809666,  0.07044981,  1.13642391]]),
 array([[ 1.        , -0.46492027, -0.64545624,  0.        , -0.43172955],
        [-0.46492027,  1.52389526,  0.51510919, -0.58967658, -0.35402734],
        [-0.64545624,  0.51510919,  1.76901556, -0.44962556, -0.10894379],
        [ 0.        ,

In [23]:
def objective(emp_cov, K, R, T, H, mu, eta, rho):
    res = fast_logdet(R)
    res += np.sum(R*emp_cov)
    res += rho/2 * np.linalg.norm(R - T + U + K.T.dot(np.linalg.inv(H)).dot(K))**2 
    res += mu*np.linalg.norm(H,1)
    res += eta*np.linalg.norm(T,1)
    return res

In [24]:
def fixed_interlinks_graphical_lasso(X, K, mu=0.01, eta=0.01, rho=1., 
        tol=1e-3, rtol=1e-5, max_iter=100, verbose=False, return_n_iter=True,
        return_history=False, compute_objective=False, compute_emp_cov=False,
        random_state=None):
    
    random_state = check_random_state(random_state)
    if compute_emp_cov:
        n = X.shape[0] 
        emp_cov = empirical_covariance(X, assume_centered=False)
    else:
        emp_cov = X

    H = make_sparse_spd_matrix(K.shape[0], alpha=0.5)
    T = emp_cov.copy()
    T = (T + T.T)/2
    R = T - np.linalg.multi_dot((K.T, np.linalg.inv(H), K))
    U = np.zeros((K.shape[1], K.shape[1]))
    
    checks = []
    for iteration_ in range(max_iter):
        R_old = R.copy()
        
        # R update
        M = T - U - K.T.dot(np.linalg.inv(H)).dot(K)
        M = (M + M.T)/2
        R = prox_logdet(emp_cov - rho*M, 1/rho)
        assert is_pos_def(R), "iter %d"%iteration_
        #print("----------------------R---------------------\n", R)
        # T update
        M = - R - U - K.T.dot(np.linalg.inv(H)).dot(K)
        M = (T + T.T)/2
        T = soft_thresholding_od(M, eta/rho)
        assert is_pos_def(T), "teta iter %d"%iteration_
       
        #print("----------------------T---------------------\n",T)
        # H update
        H = _upgrade_H(R, T, K, U, rho, mu, verbose=0)
        assert(is_pos_def(H))
        #print("----------------------H---------------------\n",H)
        # U update
        KHK = np.linalg.multi_dot((K.T, np.linalg.inv(H), K))
       # assert is_pos_semi_def(KHK)
        U += R - T + KHK

        # diagnostics, reporting, termination checks
        
        obj = objective(emp_cov, K, R, T, H, mu, eta, rho) \
            if compute_objective else np.nan
        rnorm = np.linalg.norm(R - T + KHK)
        snorm = rho *np.linalg.norm(R - R_old)
        check = convergence(
            obj=obj, rnorm=rnorm, snorm=snorm,
            e_pri=(np.sqrt(R.size) * tol + rtol *
                   max(np.sqrt(np.linalg.norm(R)**2 + np.linalg.norm(U)**2),
                       np.linalg.norm(T - KHK))),
            e_dual=(np.sqrt(R.size) * tol + rtol * rho *
                    np.linalg.norm(U))
        )

        if verbose:
            print("obj: %.4f, rnorm: %.4f, snorm: %.4f,"
                  "eps_pri: %.4f, eps_dual: %.4f" % check)

        checks.append(check)
        if check.rnorm <= check.e_pri and check.snorm <= check.e_dual:
            break
        rho_new = update_rho(rho, rnorm, snorm, iteration=iteration_)
        # scaled dual variables should be also rescaled
        U *= rho / rho_new
        rho = rho_new
    else:
        warnings.warn("Objective did not converge.")

    return_list = [R, T, H, emp_cov]
    if return_n_iter:
        return_list.append(iteration_)
    if return_history:
        return_list.append(checks)
    return return_list


In [25]:
res = fixed_interlinks_graphical_lasso(samples, per_cov, mu=1, eta=1, rho=1., 
        verbose=1, compute_objective=1, compute_emp_cov=1,
        random_state=0)

obj: 18.8770, rnorm: 2.0269, snorm: 4.0423,eps_pri: 0.0100, eps_dual: 0.0100
not found lambda
not found lambda
obj: 22.6988, rnorm: 2.5417, snorm: 2.1123,eps_pri: 0.0101, eps_dual: 0.0100
obj: 21.6447, rnorm: 1.3931, snorm: 1.1561,eps_pri: 0.0101, eps_dual: 0.0100
obj: 23.6592, rnorm: 0.6238, snorm: 0.8942,eps_pri: 0.0101, eps_dual: 0.0101
not found lambda
obj: 25.9274, rnorm: 0.5373, snorm: 0.4487,eps_pri: 0.0101, eps_dual: 0.0101
not found lambda
not found lambda
obj: 23.5654, rnorm: 0.6475, snorm: 0.2877,eps_pri: 0.0101, eps_dual: 0.0101
not found lambda
obj: 24.3972, rnorm: 0.5841, snorm: 0.3449,eps_pri: 0.0101, eps_dual: 0.0101
not found lambda
obj: 24.5371, rnorm: 0.5819, snorm: 0.1776,eps_pri: 0.0101, eps_dual: 0.0101
not found lambda
obj: 25.7510, rnorm: 1.5623, snorm: 0.2123,eps_pri: 0.0101, eps_dual: 0.0101
not found lambda
obj: 24.3098, rnorm: 0.5989, snorm: 0.3677,eps_pri: 0.0101, eps_dual: 0.0101
not found lambda
obj: 24.3914, rnorm: 1.2198, snorm: 0.4178,eps_pri: 0.0101, 

not found lambda
obj: 21.5800, rnorm: 0.5756, snorm: 0.8720,eps_pri: 0.0102, eps_dual: 0.0105
not found lambda
not found lambda
obj: 23.0748, rnorm: 0.7658, snorm: 0.5617,eps_pri: 0.0102, eps_dual: 0.0105
obj: 20.5121, rnorm: 0.4491, snorm: 1.0163,eps_pri: 0.0102, eps_dual: 0.0105
not found lambda
not found lambda
obj: 23.4222, rnorm: 0.5330, snorm: 0.8181,eps_pri: 0.0102, eps_dual: 0.0105
not found lambda
not found lambda
obj: 23.7160, rnorm: 0.5299, snorm: 0.7290,eps_pri: 0.0102, eps_dual: 0.0105
not found lambda
not found lambda
obj: 23.4418, rnorm: 1.0722, snorm: 0.7961,eps_pri: 0.0102, eps_dual: 0.0105
not found lambda
not found lambda
obj: 21.8537, rnorm: 0.8662, snorm: 1.0684,eps_pri: 0.0102, eps_dual: 0.0105
not found lambda
not found lambda
obj: 20.6573, rnorm: 0.5104, snorm: 0.9025,eps_pri: 0.0102, eps_dual: 0.0105
not found lambda
obj: 21.5801, rnorm: 0.3426, snorm: 0.8971,eps_pri: 0.0102, eps_dual: 0.0105
not found lambda
obj: 20.5299, rnorm: 0.3655, snorm: 0.5795,eps_pri: 

