In [1]:
from __future__ import print_function

import numpy as np
import random

from matplotlib import pyplot as plt

from selectinf.nbd_lasso import nbd_lasso
from selectinf.Utils.discrete_family import discrete_family
from statsmodels.distributions.empirical_distribution import ECDF
from instance import GGM_instance

from nbd_naive_and_ds import *
from scipy.integrate import quad
from scipy.optimize import root_scalar

In [208]:
def print_nonzero_intervals(nonzero, intervals, prec, X):
    # Intervals, prec, X are all in their original scale
    n, p = X.shape
    S = X.T @ X / n

    for i in range(p):
        for j in range(i + 1, p):
            if nonzero[i, j]:
                print("(", i, ",", j, ")", "selected")
                print("Theta", "(", i, ",", j, ")", "interval:", intervals[i, j, :])
                print("Theta", "(", i, ",", j, ")", prec[i, j])
                print("S/n", "(", i, ",", j, ")", S[i, j])

In [254]:
# TODO: Add root n to the randomization covariance
# Remark: Not needed (?) anymore since X is now scaled
prec,cov,X = GGM_instance(n=200,p=50, max_edges=1)
nbd_instance = nbd_lasso.gaussian(X, n_scaled=False, weights_const=0.4)
active_signs_nonrandom = nbd_instance.fit(perturb=np.zeros((50,49)))
active_signs_random = nbd_instance.fit()
print(active_signs_nonrandom.shape)
print(np.abs(active_signs_nonrandom).sum())
print(np.abs(active_signs_random).sum())
print(np.abs(prec != 0).sum() - 10)

(50, 49)
10.0
78.0
42


In [255]:
n,p = X.shape
X_n = X / np.sqrt(n)
prec_n = prec * n
nonzero = get_nonzero(active_signs_nonrandom)

In [256]:
intervals = conditional_inference(X_n, nonzero)
# coverage is upper-triangular
coverage = get_coverage(nonzero, intervals, prec_n, n, p)

interval_len = 0
nonzero_count = 0  # nonzero_count is essentially upper-triangular
for i in range(p):
    for j in range(i+1,p):
        if nonzero[i,j]:
            interval = intervals[i,j,:]
            interval_len = interval_len + (interval[1] - interval[0])
            nonzero_count = nonzero_count + 1
if nonzero_count > 0:
    avg_len = interval_len / nonzero_count
    cov_rate = coverage.sum() / nonzero_count
    print(cov_rate)
else:
    print("No selection")

Normalized pdf is nan
theta: 1.0855805792587502e+308
suff stat max: 3.0 suff stat min: -3.0
Min log order: -inf
Min _thetaX: -inf
Min log weights: -inf
Max log order: nan
# nan in _thetaX: 0
Normalized pdf is nan
theta: 8.141854344440626e+307
suff stat max: 3.0 suff stat min: -3.0
Min log order: -inf
Min _thetaX: -inf
Min log weights: -inf
Max log order: nan
# nan in _thetaX: 0
Normalized pdf is nan
theta: 1.0855805792587502e+308
suff stat max: 3.0 suff stat min: -3.0
Min log order: -inf
Min _thetaX: -inf
Min log weights: -inf
Max log order: nan
# nan in _thetaX: 0
Normalized pdf is nan
theta: 8.141854344440626e+307
suff stat max: 3.0 suff stat min: -3.0
Min log order: -inf
Min _thetaX: -inf
Min log weights: -inf
Max log order: nan
# nan in _thetaX: 0
0.625


In [257]:
print_nonzero_intervals(nonzero, intervals, prec, X)

( 0 , 45 ) selected
Theta ( 0 , 45 ) interval: [0.05307822 0.3456344 ]
Theta ( 0 , 45 ) -0.0
S/n ( 0 , 45 ) -0.24586886933880886
( 2 , 21 ) selected
Theta ( 2 , 21 ) interval: [-0.2293921   0.07628459]
Theta ( 2 , 21 ) -0.0
S/n ( 2 , 21 ) 0.34674454756492196
( 2 , 39 ) selected
Theta ( 2 , 39 ) interval: [-0.16304337  0.1381958 ]
Theta ( 2 , 39 ) 0.0
S/n ( 2 , 39 ) 0.33812718963534566
( 14 , 39 ) selected
Theta ( 14 , 39 ) interval: [-0.39140583 -0.1118494 ]
Theta ( 14 , 39 ) -0.0
S/n ( 14 , 39 ) 0.2536207573301948
( 16 , 39 ) selected
Theta ( 16 , 39 ) interval: [-0.09312802  0.22181834]
Theta ( 16 , 39 ) -0.0
S/n ( 16 , 39 ) -0.23210903600317706
( 19 , 39 ) selected
Theta ( 19 , 39 ) interval: [-0.06249556  0.23747341]
Theta ( 19 , 39 ) -0.0
S/n ( 19 , 39 ) -0.2699399168282541
( 21 , 33 ) selected
Theta ( 21 , 33 ) interval: [-0.04540548  0.24123728]
Theta ( 21 , 33 ) -0.0
S/n ( 21 , 33 ) -0.23383202003838882
( 21 , 39 ) selected
Theta ( 21 , 39 ) interval: [-4.07092717e+305 -4.07092

In [262]:
def edge_inference(j0k0, S, n, p, var=None,
                   ngrid=10000):
    j0 = j0k0[0]
    k0 = j0k0[1]
    # n_total: the total data points in data splitting
    #        : the raw dimension of X in naive
    inner_prod = S[j0,k0]
    # print("inner_prod", "(", j0, ",", k0, "):" , inner_prod)
    # print("var:", var)

    S_copy = np.copy(S)

    #stat_grid = np.zeros((ngrid,))
    #print("n=100 assumed")
    stat_grid = np.linspace(-10,10,num=ngrid)
    def log_det_S_j_k(s_val):
        S_j_k = S_copy
        S_j_k[j0,k0] = s_val
        S_j_k[k0,j0] = s_val
        if np.linalg.det(S_j_k) < 0:
            #print("negative det", np.linalg.det(S_j_k),
            #      "grid", s_val)
            return -np.inf
        return np.log((np.linalg.det(S_j_k))) * (n-p-1)/2

    logWeights = np.zeros((ngrid,))
    for g in range(ngrid):
        logWeights[g] = log_det_S_j_k(stat_grid[g])

    # normalize logWeights
    logWeights = logWeights - np.max(logWeights)
    # Set extremely small values (< e^-500) to e^-500 for numerical stability
    # logWeights_zero = (logWeights < -500)
    # logWeights[logWeights_zero] = -500

    condlWishart = discrete_family(stat_grid, np.exp(logWeights),
                                   logweights=logWeights)

    neg_interval = condlWishart.equal_tailed_interval(observed=inner_prod,
                                                      alpha=0.1)
    if np.isnan(neg_interval[0]) or np.isnan(neg_interval[1]):
        print("Failed to construct intervals: nan")

    interval = invert_interval(neg_interval)

    pivot = condlWishart.ccdf(theta=0)

    return pivot, interval[0], interval[1]#neg_interval, condlWishart

In [263]:
S_ = X.T @ X / n

In [264]:
pivot, lcb, ucb = edge_inference(j0k0=(21,39), S=S_, n=n, p=p, ngrid=10000)

In [265]:
lcb / n , ucb / n

(-1.1691259993492786, -0.7780059277458823)

In [120]:
print(nonzero)

[[False False False False False False False False False False False False
  False False False False False False False False]
 [False False False False False False False False False False False False
  False False False False False False False False]
 [False False False False False False False False False False False False
  False False False False False  True False False]
 [False False False False False False False False False False False False
  False False False False False False False False]
 [False False False False False False False False False False False False
  False False False False False False False False]
 [False False False False False False False False False False False False
  False False False False False False False False]
 [False False False False False False False False False False False False
  False False False False False False False False]
 [False False False False False False False False False False False False
  False False False False False False False False]
