# Knockoff on Linear Gaussian Model


+ generate knockoff with closed formula, in case x ~ Normal

In [8]:
import os
import sys

module_paths = [
    '../external/deepknockoffs/DeepKnockoffs/DeepKnockoffs',  # kfilter
    '../external/deepknockoffs/examples'
]
module_paths = [os.path.abspath(os.path.join(x)) for x in module_paths]
for module_path in module_paths:
    if module_path not in sys.path:
        sys.path.append(module_path)


import math
import numpy as np
import matplotlib.pyplot as plt

from DeepKnockoffs.gaussian import GaussianKnockoffs
from DeepKnockoffs.kfilter import kfilter

import shap

In [2]:
from sklearn import preprocessing
from glmnet import LogitNet

def lasso_stats(X,Xk,y,alpha=0.1,scale=True):
    #
    # W: lasso difference coefficient
    # Z: 
    #
    X  = X.astype("float")
    Xk = Xk.astype("float")
    p = X.shape[1]
    if scale:
        X_concat = preprocessing.scale(np.concatenate((X,Xk),1))
    else:
        X_concat = np.concatenate((X,Xk),1)
    cols_order = np.random.choice(X_concat.shape[1],X_concat.shape[1],replace=False)
    m = LogitNet(alpha=0.1, n_splits=3)
    m.fit(X_concat[:,cols_order].copy(), y.copy())
    Z = np.zeros((2*p,))
    Z[cols_order] = m.coef_.squeeze()
    W = np.abs(Z[0:p]) - np.abs(Z[p:(2*p)])
    return W.squeeze(), Z

In [88]:
# number of data point
n = 1000
# number of variables 
p = 1000
# number of variables with nonzero coefficients
k = 60
# magnitude of nonzero coefficients
amplitude = 15
# noise level 
sigma = 1
# target FDR 
q = 0.5

In [89]:
np.random.seed(0)

Sigma = np.eye(p)
mu = np.zeros(p)
S0 = np.random.choice(p, k, replace=False)
beta = np.zeros(p)
beta[S0] = amplitude/np.sqrt(n)

def sigmoid(x):
    return np.exp(x)/(1 + np.exp(x))

def logistic_model(X):
    return sigmoid(X@beta)

def sample_logistic(X):
    return np.random.binomial(1, logistic_model(X))

def summary(S):
    
    true_discovery = sum(beta[S] > 0)
    power = true_discovery*100/k
    FDP = sum(beta[S] == 0) / max(1, np.size(S))
    
    print(f"""
        true_discovery = {true_discovery}/{k} (power = {power}%)
        FDP = {FDP} % (target FDR = {q})
    """)

In [90]:
# Data
X = np.random.multivariate_normal(mu, Sigma, size=n)
y = sample_logistic(X)
np.shape(X), np.shape(y)

((1000, 1000), (1000,))

In [91]:
# 1) Generate knockoffs X_k
knockoff_generator = GaussianKnockoffs(Sigma, method='equi') # method='equi' too slow
X_k = knockoff_generator.generate(X)
# 2) Compute pairwise statistics W
W, Z = lasso_stats(X,X_k,y,alpha=0,scale=False)
# 3) Compute threshold 
t = kfilter(W, q=q)
print(f'threshold: {t}')
# 4) Perform Test
S = np.where(W >= t)[0]
summary(S)

threshold: 0.00018660335538271783

        true_discovery = 46/60 (power = 76.66666666666667%)
        FDP = 0.23333333333333334 % (target FDR = 0.5)
    


In [116]:
# Use SHAP to compute pairwise statistics
#
np.random.seed(0)

# 1) Generate knockoffs X_k
knockoff_generator = GaussianKnockoffs(Sigma, method='equi') # method='equi' too slow
X_k = knockoff_generator.generate(X)

# for sample_id in range(30):
#     print(f'sample_id={sample_id}')

#     # 2) Compute pairwise statistics W, based on X[sample_id,:] X_k[sample_id,:]
#     explainer   = shap.KernelExplainer(logistic_model, shap.kmeans(X, 3))
#     explainer_k = shap.KernelExplainer(logistic_model, shap.kmeans(X_k, 3))
#     shap_values = explainer.shap_values(X[sample_id,:])
#     shap_values_k = explainer.shap_values(X_k[sample_id,:])
#     # # difference in magnitude
#     W = np.abs(shap_values) - np.abs(shap_values_k)

#     # 3) Compute threshold 
#     t = kfilter(W, q=q,offset=0)
#     print(f'threshold: {t}')
#     # 4) Perform Test
#     S = np.where(W >= t)[0]
#     summary(S)

explainer   = shap.KernelExplainer(logistic_model, shap.kmeans(X, 3))
explainer_k = shap.KernelExplainer(logistic_model, shap.kmeans(X_k, 3))
shap_values = explainer.shap_values(X[:50,:])
shap_values_k = explainer.shap_values(X_k[:50,:])   
W_acc = np.mean(shap_values,axis=0) - np.mean(shap_values_k,axis=0)

print('accumulated pairwise statistics')
# 3) Compute threshold 
t = kfilter(W_acc, q=q,offset=0)
print(f'threshold: {t}')
# 4) Perform Test
S = np.where(W >= t)[0]
summary(S)

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))


accumulated pairwise statistics
threshold: 0.014718101639042153

        true_discovery = 33/60 (power = 55.0%)
        FDP = 0.0 % (target FDR = 0.5)
    


In [130]:
W_acc = np.mean(np.abs(shap_values),axis=0) - np.mean(np.abs(shap_values_k),axis=0)

print('accumulated pairwise statistics')
# 3) Compute threshold 
t = kfilter(W_acc, q=q,offset=0)
print(f'threshold: {t}')
# 4) Perform Test
S = np.where(W >= t)[0]
summary(S)

accumulated pairwise statistics
threshold: 0.01856918258411911

        true_discovery = 32/60 (power = 53.333333333333336%)
        FDP = 0.0 % (target FDR = 0.5)
    


In [None]:
I = np.argsort(W_acc)
np.intersect1d(S0,I[:100]), S0, I[:100], W_acc[I[:100]]


# shap_values[:30], shap_values_k[:30]

In [None]:

compute_first_n_shap_values = 10
shap_values = explainer.shap_values(X[:compute_first_n_shap_values,:])

In [None]:
shap.initjs()
shap.force_plot(float(explainer.expected_value), shap_values[:10,:], X[:10,:])
# shap.force_plot(explainer.expected_value[0], shap_values[0], X_test, link="logit")


In [None]:
sample_ind = 0
shap.waterfall_plot(explainer.expected_value, shap_values[sample_ind], X[sample_ind,:], max_display=14)


In [None]:
shap.summary_plot(shap_values, X[:10,:], max_display=14)