# Main Logic

In [None]:
# Reload imported modules (e.g. proximalde) automatically before each cell runs,
# so edits to the library take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
from joblib import Parallel, delayed
from proximalde.gen_data import gen_data_complex, gen_data_no_controls, load_ukbb_data
from proximalde.proximal import proximal_direct_effect, ProximalDE, residualizeW
from sklearn.linear_model import LinearRegression
from proximalde.crossfit import fit_predict
import os
# Show every DataFrame column when displaying results (no column truncation).
pd.options.display.max_columns = None

In [None]:
SAVE_PATH = './results/'
from tqdm import tqdm

# Fit the adversarial-IV ProximalDE estimator for every (treatment, outcome) pair and
# both dual types, print the summary tables, and write them into a per-combination
# directory under SAVE_PATH.
for D_label in ['Female', 'Black', 'Obese', 'Asian']:
    print(D_label)
    for Y_label in tqdm(['OA', 'RA', 'myoc', 'copd', 'deprs', 'back']):
        W, W_feats, X, X_feats, Z, Z_feats, Y, D = load_ukbb_data(D_label=D_label, Y_label=Y_label)

        for dual_type in ['Z', 'Q']:
            for ivreg_type in ['adv']:
                save_dir = f'{SAVE_PATH}/ivreg={ivreg_type}_dual={dual_type}_D={D_label}_Y={Y_label}'
                print(save_dir)
                # summary() writes its tables into save_dir, so the directory (and
                # SAVE_PATH itself on a fresh checkout) must exist first.
                os.makedirs(save_dir, exist_ok=True)
                np.random.seed(4)
                est = ProximalDE(cv=3, semi=True, dual_type=dual_type, ivreg_type=ivreg_type,
                                 multitask=False, n_jobs=-1, random_state=3, verbose=1)
                est.fit(W, D, Z, X, Y, D_label=D_label, Y_label=Y_label)
                sm = est.summary(decimals=5, save_dir=save_dir)
                print(sm.tables[0])
                print(sm.tables[2])


In [None]:
SAVE_PATH = './results/'
from tqdm import tqdm

# Resumable sweep over every (treatment, outcome, dual, ivreg) combination. A combination
# is refit only if one of its cached outputs in save_dir is missing:
#   tests.csv (and the other summary tables) -- written by est.summary(save_dir=...)
#   covrank_test.npy -- rank-test singular values followed by the critical value
#                       (only for dual='Z', ivreg='adv')
for D_label in ['Black', 'Female', 'Obese', 'Asian']:
    print(D_label)
    for Y_label in tqdm(['OA', 'RA', 'myoc', 'copd', 'deprs', 'back']):
        W, W_feats, X, X_feats, Z, Z_feats, Y, D = load_ukbb_data(D_label=D_label, Y_label=Y_label)

        for dual_type in ['Z', 'Q']:
            for ivreg_type in ['adv', '2sls']:
                save_dir = f'{SAVE_PATH}/ivreg={ivreg_type}_dual={dual_type}_D={D_label}_Y={Y_label}'
                print(save_dir)
                # makedirs (unlike mkdir) also creates SAVE_PATH on a fresh checkout.
                os.makedirs(save_dir, exist_ok=True)

                tests_path = save_dir + '/tests.csv'
                # Must be the same file np.save writes below (and the plotting cell loads);
                # checking for a '.csv' name meant this cache was never hit.
                covrank_path = save_dir + '/covrank_test.npy'
                need_tests = not os.path.exists(tests_path)
                need_covrank = (dual_type == 'Z' and ivreg_type == 'adv'
                                and not os.path.exists(covrank_path))
                if not (need_tests or need_covrank):
                    continue  # everything for this combination is already cached

                # Always fit an estimator for *this* combination, so the rank test never
                # runs on a stale `est` left over from a previous iteration.
                np.random.seed(4)
                est = ProximalDE(cv=3, semi=True, dual_type=dual_type, ivreg_type=ivreg_type,
                                 multitask=False, n_jobs=-1, random_state=3, verbose=1)
                est.fit(W, D, Z, X, Y, D_label=D_label, Y_label=Y_label)

                if need_tests:
                    sm = est.summary(decimals=5, save_dir=save_dir)
                    print(sm.tables[0])
                    print(sm.tables[2])

                if need_covrank:
                    # NOTE(review): assumes covariance_rank_test only needs a fitted
                    # estimator (not a prior summary() call) -- confirm.
                    svalues, svalues_crit = est.covariance_rank_test(calculate_critical=True)
                    np.save(covrank_path, np.concatenate([svalues, [svalues_crit]]))

## Analyze results 

In [None]:
### Load all data into a single dataframe

In [None]:
# Gather the saved summary tables of every adversarial-IV run into two DataFrames:
#   point_df -- one row per run from point_est.csv
#   test_df  -- tests.csv flattened to a single row per run, columns named '<row>_<col>'
# Both get identifying columns D_Y, dual and ivreg, and are written next to the runs.
test_dfs = []
point_dfs = []
for D_label in ['Black', 'Female', 'Obese', 'Asian']:
    print(D_label)
    for Y_label in ['OA', 'RA', 'myoc', 'copd', 'deprs', 'back']:

        for dual_type in ['Z', 'Q']:
            for ivreg_type in ['adv']:
                save_dir = f'{SAVE_PATH}/ivreg={ivreg_type}_dual={dual_type}_D={D_label}_Y={Y_label}'
                try:
                    test_df = pd.read_csv(save_dir + '/tests.csv', header=1, index_col=1)
                    point_df = pd.read_csv(save_dir + '/point_est.csv', header=1, index_col=1)
                    test_df = test_df.drop(columns=['0'])
                    point_df = point_df.drop(columns=['0'])
                except (OSError, ValueError, KeyError) as err:
                    # Run missing or its files malformed: skip it, but report why instead of
                    # silently swallowing every error (which also hid bugs in this cell).
                    print(f'Skipping {save_dir}: {type(err).__name__}: {err}')
                    continue

                test_df_flat = test_df.T.unstack().to_frame().sort_index(level=1).T
                test_df_flat.columns = test_df_flat.columns.map('_'.join)
                for df in (test_df_flat, point_df):
                    df['D_Y'] = f'{D_label}_{Y_label}'
                    df['dual'] = dual_type
                    df['ivreg'] = ivreg_type
                point_dfs.append(point_df)
                test_dfs.append(test_df_flat)

if not point_dfs:
    # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
    raise RuntimeError(f'No saved results found under {SAVE_PATH}; run the estimation cells first.')

point_df = pd.concat(point_dfs)
test_df = pd.concat(test_dfs)
test_df = test_df.reindex(sorted(test_df.columns), axis=1)
point_df.to_csv(os.path.join(SAVE_PATH, 'all_point_est.csv'))
test_df.to_csv(os.path.join(SAVE_PATH, 'all_tests.csv'))

In [None]:
ss_DY = point_df[(np.sign(point_df.ci_lower) == np.sign(point_df.ci_upper)) & (np.abs(point_df.point)>.05)].D_Y.unique()
point_df[(np.sign(point_df.ci_lower) == np.sign(point_df.ci_upper)) & (np.abs(point_df.point)>.05)]


In [None]:
test_df = test_df.reindex(sorted(test_df.columns), axis=1)
test_df[test_df.D_Y.isin(ss_DY)] 

In [None]:
# For each treatment (outcome OA only), plot the saved covariance-rank-test singular
# values against their critical threshold.
for D_label in ['Black', 'Female', 'Obese', 'Asian']:
    print(D_label)
    for Y_label in ['OA']:
        rank_file = f'{SAVE_PATH}/ivreg=adv_dual=Z_D={D_label}_Y={Y_label}/covrank_test.npy'
        stored = np.load(rank_file)
        # The file holds the singular values followed by the critical value as last entry.
        svalues = stored[:-1]
        svalues_crit = stored[-1]
        n_above = np.sum(svalues >= svalues_crit)
        plt.title(f"D={D_label}_Y={Y_label}\nNumber of singular values above threshold: {n_above}. "
                  f"\nThreshold={svalues_crit:.3f}. Top singular value={svalues[0]:.3f}")
        plt.scatter(np.arange(len(svalues)), svalues)
        plt.axhline(svalues_crit)
        plt.show()

In [None]:
SAVE_PATH = './results/'
from tqdm import tqdm

# Refit every (treatment, outcome, dual) combination with the adversarial IV and inspect
# its unusual-data diagnostics. For each one, print the size of the influential set at
# alpha=0.05 and show three influence plots titled with the run's directory name.
for D_label in ['Black', 'Female', 'Obese', 'Asian']:
    print(D_label)
    for Y_label in tqdm(['OA', 'RA', 'myoc', 'copd', 'deprs', 'back']):
        W, W_feats, X, X_feats, Z, Z_feats, Y, D = load_ukbb_data(D_label=D_label, Y_label=Y_label)

        for dual_type in ['Z', 'Q']:
            for ivreg_type in ['adv']:
                save_dir = f'{SAVE_PATH}/ivreg={ivreg_type}_dual={dual_type}_D={D_label}_Y={Y_label}'
                print(save_dir)
                np.random.seed(4)

                est = ProximalDE(cv=3, semi=True, dual_type=dual_type, ivreg_type=ivreg_type,
                                 multitask=False, n_jobs=-1, random_state=3, verbose=1)
                est.fit(W, D, Z, X, Y, D_label=D_label, Y_label=Y_label)
                diag = est.run_diagnostics()
                inds = est.influential_set(alpha=0.05)
                print(len(inds))
                diag.cookd_plot()
                plt.title(save_dir)
                plt.show()
                diag.l2influence_plot()
                plt.title(save_dir)
                plt.show()
                diag.influence_plot(influence_measure='cook', npoints=10)
                plt.title(save_dir)
                plt.show()
        print()
        print()

In [None]:
# Single-case deep dive: Asian treatment, OA outcome, Z dual, adversarial IV regression.
D_label, Y_label = 'Asian', 'OA'
dual_type, ivreg_type = 'Z', 'adv'
W, W_feats, X, X_feats, Z, Z_feats, Y, D = load_ukbb_data(D_label=D_label, Y_label=Y_label)

np.random.seed(4)
est = ProximalDE(
    cv=3, semi=True,
    dual_type=dual_type, ivreg_type=ivreg_type,
    multitask=False, n_jobs=-1,
    random_state=3, verbose=1,
)
est.fit(W, D, Z, X, Y, D_label=D_label, Y_label=Y_label)

In [None]:
np.random.seed(4)
diag = est.run_diagnostics()
inds = est.influential_set(alpha=0.05)
len(inds)  # size of influential set that can flip the result

In [None]:
test_df[test_df.D_Y == 'Asian_OA']

In [None]:
from sklearn.base import clone
# let's re-train a clone of the estimator on all the data
# except the influential set
np.random.seed(4)
est2 = clone(est)
est2.fit(np.delete(W, inds, axis=0), np.delete(D, inds, axis=0),
         np.delete(Z, inds, axis=0), np.delete(X, inds, axis=0),
         np.delete(Y, inds, axis=0),D_label=D_label, Y_label=Y_label, save_fname_addn=f'_rmInf{dual_type}')
est2.summary(alpha=0.05)

In [None]:
diag.cookd_plot()
plt.show()

In [None]:
diag.l2influence_plot()
plt.show()

In [None]:
diag.influence_plot(influence_measure='cook', npoints=10)
plt.show()

In [None]:
diag.influence_plot(influence_measure='l2influence', npoints=10)
plt.show()

In [None]:
# tests can also be accessed individually
display(est.weakiv_test(alpha=0.05))
display(est.idstrength_violation_test(alpha=0.05))
display(est.primal_violation_test(alpha=0.05))
display(est.dual_violation_test(alpha=0.05))

#### Confidence Intervals and Robust Confidence Intervals

In [None]:
est.conf_int(alpha=.05) # 95% confidence interval

In [None]:
# 95% confidence interval, robust to weak identification
est.robust_conf_int(alpha=0.05, lb=.1, ub=1.0, ngrid=1000)

#### Unusual Data Diagnostics

In [None]:
diag = est.run_diagnostics()

In [None]:
inds = est.influential_set(alpha=0.05)
len(inds)  # size of influential set that can flip the result

In [None]:
from sklearn.base import clone
# let's re-train a clone of the estimator on all the data
# except the influential set
est2 = clone(est)
est2.fit(np.delete(W, inds, axis=0), np.delete(D, inds, axis=0),
         np.delete(Z, inds, axis=0), np.delete(X, inds, axis=0),
         np.delete(Y, inds, axis=0))
est2.summary(alpha=0.05)

In [None]:
diag.cookd_plot()
plt.show()

In [None]:
diag.l2influence_plot()
plt.show()

In [None]:
diag.influence_plot(influence_measure='cook', npoints=10)
plt.show()

In [None]:
diag.influence_plot(influence_measure='l2influence', npoints=10)
plt.show()

### Subsample-Based Inference

In [None]:
inf = est.bootstrap_inference(stage=3, n_subsamples=1000, fraction=0.5, replace=False, verbose=3, random_state=123)
inf.summary()

In [None]:
plt.hist(inf.point_dist)
plt.vlines([inf.point], 0, 300, color='r')
plt.show()

In [None]:
inf = est.bootstrap_inference(stage=2, n_subsamples=100, fraction=0.5, replace=False, verbose=3, random_state=123)
inf.summary()

In [None]:
plt.hist(inf.point_dist)
plt.vlines([inf.point], 0, 300, color='r')
plt.show()

In [None]:
inf = est.bootstrap_inference(stage=1, n_subsamples=10, fraction=0.5, replace=False, verbose=3, random_state=123)
inf.summary()

In [None]:
plt.hist(inf.point_dist)
plt.vlines([inf.point], 0, 300, color='r')
plt.show()

In [None]:
inf.summary(pivot=True)

# Quality of Procedure and Diagnostics Across Many Experiments

In [None]:
# Reload imported modules (e.g. proximalde) automatically before each cell runs,
# so edits to the library take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
from proximalde.gen_data import gen_data_complex, gen_data_no_controls
from proximalde.proximal import ProximalDE

In [None]:
def exp_res(it, n, pw, pz, px, a, b, c, d, e, f, g, sm, *,
            dual_type='Z', ivreg_type='adv', n_splits=5, semi=False,
            multitask=False, n_jobs=-1, verbose=0):
    """Run one synthetic replication and return the estimate plus all diagnostics.

    Data come from ``gen_data_complex`` when ``pw > 0`` (observed controls W), and
    from ``gen_data_no_controls`` otherwise, in which case the estimator is fit with
    ``W=None``. The global NumPy RNG is seeded with ``it`` before data generation,
    and ``it`` is also passed as the estimator's ``random_state``.

    Parameters
    ----------
    it : int
        Replication index; seeds data generation and the estimator.
    n, pw, pz, px : int
        Sample size and the W/Z/X dimensions, forwarded to the data generator.
    a, b, c, d, e, f, g, sm : float
        Data-generating parameters, forwarded unchanged to the data generator.
    dual_type, ivreg_type, n_splits, semi, multitask, n_jobs, verbose
        Forwarded to ``ProximalDE`` (``n_splits`` is passed as ``cv``).

    Returns
    -------
    tuple
        ``(point, stderr, r2D, r2Z, r2X, r2Y, idstr, idstr_crit, point_pre,
        stderr_pre, primal_stat, primal_crit, dual_stat, dual_crit,
        weakiv_stat, weakiv_crit, robust_lb, robust_ub)``.
    """
    np.random.seed(it)
    if pw > 0:
        # M is unobserved so we omit it from the return variables
        W, D, _, Z, X, Y = gen_data_complex(n, pw, pz, px, a, b, c, d, e, f, g, sm=sm)
    else:
        _, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g, sm=sm)
        W = None  # no controls: fit without W

    # Built once for both data regimes; previously duplicated in each branch.
    est = ProximalDE(cv=n_splits, semi=semi,
                     dual_type=dual_type, ivreg_type=ivreg_type,
                     multitask=multitask, n_jobs=n_jobs,
                     random_state=it, verbose=verbose)
    est.fit(W, D, Z, X, Y)

    weakiv_stat, _, _, weakiv_crit = est.weakiv_test(alpha=0.05)
    idstr, _, _, idstr_crit = est.idstrength_violation_test(alpha=0.05)
    pval, _, _, pval_crit = est.primal_violation_test(alpha=0.05)
    dval, _, _, dval_crit = est.dual_violation_test(alpha=0.05)
    lb, ub = est.robust_conf_int(lb=-2, ub=2)
    return (est.point_, est.stderr_, est.r2D_, est.r2Z_, est.r2X_, est.r2Y_,
            idstr, idstr_crit, est.point_pre_, est.stderr_pre_,
            pval, pval_crit, dval, dval_crit, weakiv_stat, weakiv_crit,
            lb, ub)

In [None]:
# Data-generating parameters for the simulation study.
a, b = 1.0, 1.0  # a*b is the indirect effect through the mediator
c = .5           # the direct effect we want to estimate
d = .5           # may be zero without hurting estimation
e, f = .5, .5    # a small product e*f means a weak instrument
g = .9           # may be zero without hurting estimation
sm = .35         # mediator noise strength; must be non-zero for identifiability
n = 50000
pw = 0           # no observed controls W
pz = px = 5

# 100 independent replications in parallel; each inner estimator runs with n_jobs=1.
dgp_args = (n, pw, pz, px, a, b, c, d, e, f, g, sm)
results = Parallel(n_jobs=-1, verbose=3)(
    delayed(exp_res)(rep, *dgp_args,
                     dual_type='Z', ivreg_type='adv',
                     n_splits=3, semi=True, n_jobs=1)
    for rep in range(100)
)

#### Summarize

In [None]:
# Unpack the per-replication tuples from exp_res into one NumPy array per quantity.
# NOTE: rmseD..rmseY hold the nuisance R^2 values (est.r2*_), not RMSEs.
points_base, stderrs_base, rmseD, rmseZ, rmseX, rmseY, \
    idstr, idstr_crit, points_alt, stderrs_alt, \
    pval, pval_crit, dval, dval_crit, wiv_stat, wiv_crit, \
    rlb, rub = map(np.array, zip(*results))

print("Estimation Quality")
for name, points, stderrs in [('Debiased', points_base, stderrs_base), ('Regularized', points_alt, stderrs_alt)]:
    print(f"\n{name} Estimate")
    # Nominal 95% normal-approximation interval: point +/- 1.96 * stderr; c is the truth.
    half_width = 1.96 * stderrs
    coverage = np.mean((points + half_width >= c) & (points - half_width <= c))
    rmse = np.sqrt(np.mean((points - c)**2))
    bias = np.abs(np.mean(points) - c)
    std = np.std(points)
    mean_stderr = np.mean(stderrs)
    ci_lengths = 2 * 1.96 * stderrs
    mean_length = np.mean(ci_lengths)
    median_length = np.median(ci_lengths)
    print(f"Coverage: {coverage:.3f}")
    print(f"RMSE: {rmse:.3f}")
    print(f"Bias: {bias:.3f}")
    print(f"Std: {std:.3f}")
    print(f"Mean CI length: {mean_length:.3f}")
    print(f"Median CI length: {median_length:.3f}")
    print(f"Mean Estimated Stderr: {mean_stderr:.3f}")
    print(f"Nuisance R^2 (D, Z, X, Y): {np.mean(rmseD):.3f}, {np.mean(rmseZ):.3f}, {np.mean(rmseX):.3f}, {np.mean(rmseY):.3f}")

print("\nRobust ConfInt Coverage")
rcoverage = np.mean((rub >= c) & (rlb <= c))
print(f"Robust Coverage: {rcoverage:.3f}")

print("\nViolations")
# Strength tests flag a violation when the statistic does not exceed its critical value.
for name, stat, crit in [('Id-Strength', idstr, idstr_crit), ('WeakIV F-test', wiv_stat, wiv_crit)]:
    violation = np.mean(stat <= crit)
    print(f"% Violations of {name}: {violation:.3f}")
# Existence tests flag a violation when the statistic reaches its critical value.
for name, stat, crit in [('Primal Existence', pval, pval_crit), ('Dual Existence', dval, dval_crit)]:
    violation = np.mean(stat >= crit)
    print(f"% Violations of {name}: {violation:.3f}")

In [None]:
# Overlay the two estimate distributions across replications and mark the true effect c.
plt.hist(points_base, label='Distribution of Estimates: debiased')
plt.hist(points_alt, label='Distribution of Estimates: original', alpha=.3)
y_top = plt.ylim()[1]  # read after both histograms so the line spans their full height
plt.vlines([c], 0, y_top, color='red', label='truth')
plt.legend()
plt.show()