In [1]:
import warnings
import numpy as np
from six.moves import reload_module as reload
from sklearn.utils.extmath import fast_logdet
from sklearn.utils import check_random_state
from sklearn.datasets import make_sparse_spd_matrix
from sklearn.covariance import empirical_covariance
from functools import partial

from network_inference.prox import prox_logdet, soft_thresholding_od
import network_inference.datasets; reload(network_inference.datasets)
from network_inference.datasets import is_pos_def, is_pos_semi_def
from network_inference.utils import _scalar_product, update_rho, convergence

In [25]:
np.random.seed(0)
random_state=0
A = make_sparse_spd_matrix(dim=15, alpha=0.7, random_state=random_state)

T_true = A[5:,5:]
K_true = A[10:,5:,]
H_true = A[0:5,0:5]
print(is_pos_def(H_true))
per_cov = K_true*0.3
T_obs = T_true - per_cov.T.dot(np.linalg.inv(H_true)).dot(per_cov)

samples = np.random.multivariate_normal(np.zeros(10), np.linalg.inv(T_obs), 500)

True


In [8]:
from sklearn.utils.extmath import squared_norm
from regain.norm import l1_od_norm
from scipy import linalg

In [9]:
def objective_H(H, R=None, T=None, K=None, U= None,_rho=1, _mu=1):
    if not is_pos_def(H):
        return np.inf
    return 0.5 * _rho * squared_norm(R - T + U + np.linalg.multi_dot((K.T, linalg.pinvh(H), K))) \
            + _mu * l1_od_norm(H)

In [10]:
def _choose_lambda(lamda, R, T, K, H, U,  _rho, _mu, prox, grad, gamma, delta=1e-4, eps=0.9, max_iter=500):
    """Choose lambda for backtracking.

    References
    ----------
    Salzo S. (2017). https://doi.org/10.1137/16M1073741

    """
    partial_f = partial(objective_H, R=R, T=T, K=K, U=U, _rho=_rho, _mu=_mu)
    fx = partial_f(H)

    y_minus_x = prox - H
    tolerance = _scalar_product(y_minus_x, grad)
    tolerance += delta / gamma * _scalar_product(y_minus_x, y_minus_x)
    #print("Tolerance:", tolerance)
    for i in range(max_iter):
        # line-search
        x1 = H + lamda * y_minus_x

        loss_diff = partial_f(x1) - fx
        #print("Loss diff:", loss_diff)
        if loss_diff <= lamda * tolerance:
              break
        lamda *= eps
    else:
        return False, i+1
    return lamda, i + 1

In [11]:
def _choose_gamma(gamma, H, R, T, K, U, _rho, _mu, _lambda, grad,
                 eps=0.9, max_iter=500):
    """Choose gamma for backtracking.

    References
    ----------
    Salzo S. (2017). https://doi.org/10.1137/16M1073741

    """
    partial_f = partial(objective_H, R=R, T=T, K=K, U=U, _rho=_rho, _mu=_mu)
    fx = partial_f(H)
    for i in range(max_iter):
        prox = soft_thresholding_od(H - gamma * grad, _mu * gamma)
        if is_pos_def(prox):
            break
        gamma *= eps
    else:
        print("not found gamma")
    return gamma, prox

In [48]:
def _upgrade_H(H, R, T, K, U, _rho, _mu, verbose=0, random_state=None):
    # H = make_sparse_spd_matrix(dim=K.shape[0], alpha=0.5, random_state=random_state)
    _lambda = 1
    gamma = 1
    obj = 1e+10
    for iter_ in range(2000):
        H_old = H.copy()
        Hinv = linalg.pinvh(H)
        gradient = - _rho * K.dot(R - T + U + np.linalg.multi_dot((K.T, Hinv, K))).dot(K.T).dot(Hinv).dot(Hinv)
        gamma, _ = _choose_gamma(gamma, H, R, T, K, U, _rho,_mu, _lambda, gradient)
        # print(gamma)
        Y = soft_thresholding_od(H - gamma * gradient, gamma * _mu)
        _lambda,_ = _choose_lambda(_lambda, R, T, K, H, U,_rho, _mu, Y, gradient, 1, max_iter=1000, delta=1e-2)

        H = H + _lambda * (Y - H)
        
        obj_old = obj
        obj = objective_H(H, R, T, K, U,_rho=_rho, _mu=_mu)
        obj_diff = obj_old - obj
        iter_diff =np.linalg.norm(H - H_old) 
        if verbose:
            print("Iter: %d, obj: %.5f, iter_diff: %.5f, obj_diff:%.10f"%(iter_, obj, iter_diff, obj_diff))
        if(obj_diff<1e-4): 
            break
    else:
        print("Did not converge")
    return H

In [49]:
# R = np.random.rand(10,10)
# T = np.random.rand(10,10)
# U = np.zeros((10,10))
# K = per_cov
# H = make_sparse_spd_matrix(dim=K.shape[0], alpha=0.5)
# R = T - K.T.dot(np.linalg.pinv(H)).dot(K)
# H_found = _upgrade_H(R, T, K, U, 1,0.5, 1, random_state=0)
# H - H_found, H, H_found

In [50]:
def objective(emp_cov, K, R, T, H, mu, eta, rho):
    res = - fast_logdet(R) + np.sum(R * emp_cov)
    res += rho / 2. * squared_norm(R - T + U + np.linalg.multi_dot((K.T, linalg.pinvh(H), K)))
    res += mu * l1_od_norm(H)
    res += eta * l1_od_norm(T)
    return res

In [51]:
def fixed_interlinks_graphical_lasso(X, K, mu=0.01, eta=0.01, rho=1., 
        tol=1e-3, rtol=1e-5, max_iter=100, verbose=False, return_n_iter=True,
        return_history=False, compute_objective=False, compute_emp_cov=False,
        random_state=None):
    
    random_state = check_random_state(random_state)
    if compute_emp_cov:
        n = X.shape[0] 
        emp_cov = empirical_covariance(X, assume_centered=False)
    else:
        emp_cov = X

    # H = make_sparse_spd_matrix(K.shape[0], alpha=0.5)
    H = np.eye(K.shape[0])
    T = emp_cov.copy()
    T = (T + T.T) / 2.
    R = T - np.linalg.multi_dot((K.T, linalg.pinvh(H), K))
    U = np.zeros((K.shape[1], K.shape[1]))
    
    checks = []
    for iteration_ in range(max_iter):
        R_old = R.copy()
        
        # R update
        Hinv = linalg.pinvh(H)
        M = T - U - K.T.dot(Hinv).dot(K)
        M = (M + M.T)/2
        R = prox_logdet(emp_cov - rho * M, 1. / rho)
        assert is_pos_def(R), "iter %d"%iteration_
        #print("----------------------R---------------------\n", R)
        
        # T update
        M = R + U + K.T.dot(Hinv).dot(K)
        M = (M + M.T) / 2.
        # print(M)
        T = soft_thresholding_od(M, eta / rho)
        assert is_pos_def(T, tol=1e-8), "teta iter %d"%iteration_
       
        #print("----------------------T---------------------\n",T)
        # H update
        H = _upgrade_H(H, R, T, K, U, rho, mu, verbose=0)
        assert(is_pos_def(H))
        #print("----------------------H---------------------\n",H)
        
        # U update
        KHK = np.linalg.multi_dot((K.T, linalg.pinvh(H), K))
       # assert is_pos_semi_def(KHK)
        U += R - T + KHK

        # diagnostics, reporting, termination checks
        
        obj = objective(emp_cov, K, R, T, H, mu, eta, rho) \
            if compute_objective else np.nan
        rnorm = np.linalg.norm(R - T + KHK)
        snorm = rho * np.linalg.norm(R - R_old)
        check = convergence(
            obj=obj, rnorm=rnorm, snorm=snorm,
            e_pri=(np.sqrt(R.size) * tol + rtol *
                   max(np.linalg.norm(R),
                       np.linalg.norm(T - KHK))),
            e_dual=(np.sqrt(R.size) * tol + rtol * rho *
                    np.linalg.norm(U))
        )

        if verbose:
            print("obj: %.4f, rnorm: %.4f, snorm: %.4f,"
                  "eps_pri: %.4f, eps_dual: %.4f" % check)

        checks.append(check)
        if check.rnorm <= check.e_pri and check.snorm <= check.e_dual:
            break
        rho_new = update_rho(rho, rnorm, snorm, iteration=iteration_)
        # scaled dual variables should be also rescaled
        U *= rho / rho_new
        rho = rho_new
    else:
        warnings.warn("Objective did not converge.")

    return_list = [R, T, H, emp_cov]
    if return_n_iter:
        return_list.append(iteration_)
    if return_history:
        return_list.append(checks)
    return return_list


In [52]:
res = fixed_interlinks_graphical_lasso(samples, per_cov, mu=1, eta=1, rho=1., 
        verbose=1, compute_objective=1, compute_emp_cov=1,
        random_state=0)

obj: 13.4987, rnorm: 0.3683, snorm: 3.6179,eps_pri: 0.0100, eps_dual: 0.0100
obj: 11.1296, rnorm: 1.0062, snorm: 1.0862,eps_pri: 0.0100, eps_dual: 0.0100
obj: 11.1274, rnorm: 0.6145, snorm: 0.4625,eps_pri: 0.0100, eps_dual: 0.0100
obj: 11.3102, rnorm: 0.4136, snorm: 0.2503,eps_pri: 0.0100, eps_dual: 0.0100
obj: 11.5045, rnorm: 0.2795, snorm: 0.1989,eps_pri: 0.0100, eps_dual: 0.0100
obj: 11.6839, rnorm: 0.1892, snorm: 0.1202,eps_pri: 0.0100, eps_dual: 0.0100
obj: 11.8161, rnorm: 0.1374, snorm: 0.0587,eps_pri: 0.0100, eps_dual: 0.0100
obj: 11.9141, rnorm: 0.1009, snorm: 0.0405,eps_pri: 0.0100, eps_dual: 0.0100
obj: 11.9857, rnorm: 0.0748, snorm: 0.0285,eps_pri: 0.0100, eps_dual: 0.0100
obj: 12.0372, rnorm: 0.0561, snorm: 0.0203,eps_pri: 0.0100, eps_dual: 0.0100
obj: 12.0744, rnorm: 0.0425, snorm: 0.0146,eps_pri: 0.0100, eps_dual: 0.0100
obj: 12.1011, rnorm: 0.0324, snorm: 0.0106,eps_pri: 0.0100, eps_dual: 0.0100
obj: 12.1203, rnorm: 0.0250, snorm: 0.0078,eps_pri: 0.0100, eps_dual: 0.0100