In [32]:
import numpy as np
import pandas as pd

import regreg.api as rr

from selectinf.group_lasso_query_quasi import (group_lasso_quasi)

from selectinf.base import (selected_targets_quasi)

from selectinf.base import restricted_estimator
import scipy.stats

In [49]:
# Read the design matrix and the response from CSV; in both files the first
# column is consumed as the row index and there is no header row.
X = np.asarray(pd.read_csv("X_quasi.csv", index_col=0, header=None))
Y = np.asarray(pd.read_csv("Y_quasi.csv", index_col=0, header=None))
# Group labels 0..49, each repeated four times in a row (200 labels total).
groups = np.repeat(np.arange(50), 4)
n, p = X.shape
# Flatten the response to a 1-D vector of length n.
Y = Y.reshape((n,))

In [50]:
# Show the loaded design matrix (the cell output is its array repr).
X

array([[ 1.11387862e-01,  0.00000000e+00,  0.00000000e+00, ...,
         3.51420355e-03, -2.73864428e-02,  1.09097755e-02],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
        -2.31232150e-02, -6.57981512e-03,  2.00652073e-02],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
        -8.57872700e-03, -1.64771378e-02,  2.27324254e-03],
       ...,
       [ 0.00000000e+00,  0.00000000e+00,  1.11387862e-01, ...,
        -4.14751581e-03,  4.37496513e-02,  6.94383824e-02],
       [ 1.11387862e-01,  0.00000000e+00,  0.00000000e+00, ...,
        -1.61901403e-02, -4.33401497e-02, -8.94975675e-02],
       [ 0.00000000e+00,  1.13542566e-01,  0.00000000e+00, ...,
         4.67863991e-05,  5.89862924e-02,  5.81135345e-02]])

In [51]:
# Show the response vector after it has been flattened to 1-D.
Y

array([0., 0., 0., 0., 2., 0., 0., 0., 0., 0., 1., 2., 1., 1., 1., 0., 2.,
       0., 0., 0., 0., 2., 2., 0., 1., 2., 1., 1., 1., 2., 0., 0., 0., 1.,
       0., 2., 1., 0., 0., 0., 0., 0., 1., 0., 1., 0., 3., 0., 1., 1., 1.,
       0., 2., 2., 0., 1., 0., 0., 1., 0., 0., 0., 1., 1., 1., 0., 1., 0.,
       0., 0., 0., 1., 1., 0., 4., 0., 0., 2., 0., 0., 1., 0., 0., 2., 2.,
       1., 0., 0., 3., 0., 0., 3., 2., 1., 2., 4., 0., 1., 0., 0., 0., 1.,
       2., 0., 0., 4., 4., 0., 3., 0., 0., 4., 3., 1., 3., 1., 3., 0., 0.,
       0., 0., 1., 0., 1., 0., 0., 0., 0., 2., 5., 1., 0., 2., 0., 4., 0.,
       2., 0., 2., 1., 1., 0., 8., 2., 0., 1., 1., 0., 0., 0., 0., 0., 0.,
       0., 4., 1., 0., 1., 0., 1., 3., 3., 1., 1., 0., 0., 5., 4., 0., 1.,
       0., 0., 0., 1., 1., 0., 0., 0., 2., 0., 3., 0., 4., 0., 1., 1., 2.,
       0., 0., 3., 2., 0., 0., 1., 0., 1., 4., 2., 0., 0., 2., 1., 1., 4.,
       0., 0., 0., 0., 0., 2., 1., 0., 1., 0., 3., 0., 9., 0., 0., 0., 0.,
       2., 0., 0., 1., 4.

In [52]:
# Check the shape of the response after the reshape to (n,).
Y.shape

(500,)

In [53]:
def estimate_hess():
    """Return the weighted Gram matrix ``X.T @ diag(exp(X @ beta)) @ X``.

    Reads the notebook-level globals ``X`` (2-D design matrix), ``Y``
    (counts) and ``p`` (number of columns of ``X``). ``beta`` is whatever
    ``restricted_estimator`` returns for the Poisson log-likelihood when
    every one of the ``p`` coordinates is marked active.

    Returns
    -------
    numpy.ndarray
        A ``(p, p)`` array.
    """
    loglike = rr.glm.poisson(X, counts=Y)
    # All-True mask: no coordinate is excluded from the fit.
    # NOTE(review): presumably this is the unpenalized Poisson fit on the
    # full design -- confirm against selectinf's restricted_estimator.
    beta_full = restricted_estimator(loglike, np.ones(p, dtype=bool))
    # One weight per row of X: exp of the fitted linear predictor.
    # Assumes beta_full is 1-D of length p (np.diag in the earlier version
    # relied on the same thing).
    row_weights = np.exp(X @ beta_full)
    # Scale the rows of X by broadcasting rather than materialising an
    # n-by-n diagonal matrix; the product is the same but needs O(n*p)
    # memory instead of O(n^2).
    return X.T @ (row_weights[:, None] * X)

# The estimated Hessian is reused unchanged as the randomization covariance.
hess = estimate_hess()
cov_rand = hess

# Every group receives the same penalty weight:
#     weight_frac * std(Y) * sqrt(2 * log(p))
sigma_ = np.std(Y)
weight_frac = 1.
weights = {
    group_id: weight_frac * sigma_ * np.sqrt(2 * np.log(p))
    for group_id in np.unique(groups)
}

conv = group_lasso_quasi.quasipoisson(
    X=X,
    counts=Y,
    groups=groups,
    weights=weights,
    useJacobian=True,
    cov_rand=cov_rand,
)

# `nonzero` flags the entries of `signs` that are not exactly zero.
signs, _ = conv.fit()
nonzero = signs != 0

# The dispersion is fixed at 1 both when setting up inference and when
# building the target specification.
conv.setup_inference(dispersion=1)
cov_score = conv._unscaled_cov_score
target_spec = selected_targets_quasi(
    conv.loglike,
    conv.observed_soln,
    cov_score=cov_score,
    dispersion=1,
)

result, _ = conv.inference(
    target_spec,
    method='selective_MLE',
    level=0.9,
)

# Keep the p-values and the (lower, upper) confidence bounds for the
# reporting cells below.
pval = result['pvalue']
intervals = np.asarray(result[['lower_confidence', 'upper_confidence']])

  loss_terms = - coef * ((counts - 1) * np.log(counts))


In [54]:
# Print the fitted object's `_ordered_groups`. This is a private attribute
# (leading underscore), so it may change between library versions.
print("Selected Group Indices:", conv._ordered_groups)

Selected Group Indices: [0, 3, 5, 6, 14, 42, 48, 49]


In [55]:
# Print the 'pvalue' column of the inference result.
print(pval)

0     0.074363
1     0.987501
2     0.100155
3     0.000348
4     0.039442
5     0.008595
6     0.099150
7     0.049677
8     0.723087
9     0.958617
10    0.162142
11    0.374249
12    0.089566
13    0.974900
14    0.366692
15    0.317703
16    0.581260
17    0.230448
18    0.668651
19    0.024849
20    0.336551
21    0.154597
22    0.279856
23    0.555181
24    0.977142
25    0.092882
26    0.097432
27    0.283116
28    0.313683
29    0.632966
30    0.367639
31    0.649634
Name: pvalue, dtype: float64


In [56]:
# Print the confidence bounds: column 0 is 'lower_confidence', column 1 is
# 'upper_confidence' (inference was run with level=0.9).
print(intervals)

[[ 5.26629948e-01  1.29440096e+01]
 [-4.49486979e+00  4.41005724e+00]
 [-2.77136865e-03  1.21417158e+01]
 [ 7.95918460e+00  2.15104121e+01]
 [-1.02370738e+01 -1.14600990e+00]
 [-1.33262581e+01 -3.06563978e+00]
 [-1.05781789e+01 -1.32840480e-02]
 [-1.23051516e+01 -1.08426499e+00]
 [-5.02164929e+00  3.24157561e+00]
 [-5.64479343e+00  6.01253905e+00]
 [-1.06007538e+01  8.60350917e-01]
 [-2.14653101e+00  7.19018881e+00]
 [-9.52839911e+00 -1.50617003e-01]
 [-4.81026796e+00  4.62969630e+00]
 [-1.92208959e+00  6.59771913e+00]
 [-1.68851988e+00  6.91462433e+00]
 [-6.32205809e+00  3.14695404e+00]
 [-1.21278930e+00  7.73981832e+00]
 [-8.75845502e+00  5.14158577e+00]
 [ 1.65176191e+00  1.07248923e+01]
 [-1.90293138e+00  7.25123532e+00]
 [-1.06429331e+01  7.67869126e-01]
 [-2.66371749e+00  1.28674430e+01]
 [-2.64191816e+00  5.59737867e+00]
 [-4.03170991e+00  3.89365334e+00]
 [ 9.92283950e-02  9.28558286e+00]
 [-9.51247476e+00 -3.62398648e-02]
 [-1.73949817e+00  8.27337383e+00]
 [-7.38443719e+00  1

In [57]:
# Print the 'MLE' column of the same inference result.
print(result['MLE'])

0      6.735320
1     -0.042406
2      6.069472
3     14.734798
4     -5.691542
5     -8.195949
6     -5.295731
7     -6.694708
8     -0.890037
9      0.183873
10    -4.870201
11     2.521829
12    -4.839508
13    -0.090286
14     2.337815
15     2.613052
16    -1.587552
17     3.263515
18    -1.808435
19     6.188327
20     2.674152
21    -4.937532
22     5.101863
23     1.477730
24    -0.069028
25     4.692406
26    -4.774357
27     3.266938
28    -2.805030
29     1.763976
30     2.238045
31    -1.269197
Name: MLE, dtype: float64
