## Setup Only for Colab

In [None]:
from IPython.display import clear_output
from google.colab import drive
import time
# Mount Google Drive at /content/drive (Colab prompts for authorization).
drive.mount('/content/drive') # First mount drive
# Work from the "Colab Notebooks" folder on Drive; the space in the path is backslash-escaped for the %cd magic.
%cd /content/drive/MyDrive/Colab\ Notebooks

In [None]:
## Run if you haven't set up hidden_mediators
# Clone the repository into the current working directory, enter it, and install
# its requirements plus the package itself.
! git clone https://github.com/syrgkanislab/hidden_mediators
%cd hidden_mediators
! pip install -r requirements.txt
# NOTE(review): `setup.py install` is deprecated by setuptools; `pip install .` is the usual replacement — confirm before switching.
! python setup.py install
# Short pause, then wipe the (long) installation log from the cell output.
time.sleep(2)
clear_output()

In [None]:
## Run if you have already set up hidden_mediators
# Only enter the existing checkout; nothing is cloned or installed here.
%cd hidden_mediators
# Short pause, then clear the cell output.
time.sleep(2)
clear_output()

# Main Logic

In [None]:
# Enable IPython's autoreload so edits to imported modules (e.g. proximalde)
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
from proximalde.gen_synthetic_data import gen_data
from proximalde.proximal import ProximalDE

# Running a Single Experiment

In [None]:
# --- Edge strengths of the synthetic graph ---
a, b = 1.0, 1.0  # a*b is the indirect effect that flows through the mediator
c = 0.5          # direct effect: the quantity we want to estimate
d, g = 0, 0      # either may be zero without hurting estimation
e, f = 0.5, 0.5  # a small product e*f means a weak instrument
# --- Sample size and dimensions ---
n = 50_000       # number of samples
pw = 0           # dimension of the controls / confounders W
pz, px = 5, 5    # dimensions of the proxies Z and X
pm = 1           # dimension of the mediator M; should not exceed max(pz, px)
sm = 1.0         # mediator-noise strength; must be non-zero for identifiability (only used when pm == 1)

In [None]:
# Simulate n samples with a fixed seed: controls W, outcome proxies X, treatment proxies Z, treatment D and outcome Y.
W, X, Z, D, Y = gen_data(a, b, c, d, e, f, g, pm, pz, px, pw, n, sm=sm, seed=42)

### Using the ProximalDE Estimator Class

In [None]:
# model_classification: classification model used for the W-regressions that involve
#     binary variables. If Z or X contain binary columns, their indices must be
#     passed via binary_Z / binary_X.
# model_regression: regression model used for the W-regressions that involve
#     continuous variables.
# For both, options are 'xgb', 'linear', or a custom model (see custom_regression_models.ipynb).
# ivreg_type='adv' selects adversarial nuisance estimation (alternative: '2sls'),
# semi=True enables semi-cross-fitting, and cv=3 sets the number of cross-fitting splits.
est = ProximalDE(model_regression='xgb', model_classification='xgb', binary_Z=[], binary_X=[], 
                 ivreg_type='adv', semi=True, cv=3, random_state=4)
est.fit(W, D, Z, X, Y)

In [None]:
# The simulated data has no assumption violations, so the point estimate should be
# accurate and every diagnostic test in the summary should pass.
est.summary(decimals=5)

#### Rank Diagnostic for the Covariance of Proxies

In [None]:
# Singular values of the proxy covariance, together with a critical threshold
# (requested via calculate_critical=True) to compare them against.
svalues, svalues_crit = est.covariance_rank_test(calculate_critical=True)

In [None]:
# Scatter each singular value against its index and draw the critical threshold
# as a horizontal line; the title counts how many values reach the threshold.
plt.scatter(np.arange(svalues.shape[0]), svalues)
plt.axhline(svalues_crit)
plt.title("Number of singular values above threshold: "
          f"{(svalues >= svalues_crit).sum()}. "
          f"Threshold={svalues_crit:.3f}. "
          f"Top singular value={svalues[0]:.3f}")
plt.show()

#### Confidence Intervals and Robust Confidence Intervals

In [None]:
# Standard confidence interval for the direct effect (alpha=.05 -> 95% level);
# the next cell gives a version robust to weak identification.
est.conf_int(alpha=.05) # 95% confidence interval

In [None]:
# 95% confidence interval, robust to weak identification
# NOTE(review): lb/ub/ngrid appear to define a grid of candidate effect values
# (here 1000 points on [0.1, 1.0]); the true c=.5 should lie inside it — confirm
# against the robust_conf_int implementation.
est.robust_conf_int(alpha=0.05, lb=.1, ub=1.0, ngrid=1000)

#### Influence Data Diagnostics
In addition to the influence score presented in the paper, we also provide tools for computing 
other commonly used influence measures, such as Cook's distance and L2 influence. 

In [None]:
# Compute the influence diagnostics for the fitted estimator and keep them in `diag`.
diag = est.run_influence_diagnostics()

In [None]:
# Indices of the influential samples at level alpha=0.05.
inds = est.influential_set(alpha=0.05)
len(inds)  # size of influential set that can flip the result

In [None]:
from sklearn.base import clone
# Refit a fresh clone of the estimator on every sample except the influential
# set; if that set really drives the result, the new effect should be ~0.
est2 = clone(est)
# W only exists when there are controls (pw > 0); otherwise pass None.
W_inf = np.delete(W, inds, axis=0) if pw > 0 else None
est2.fit(W_inf, *(np.delete(arr, inds, axis=0) for arr in (D, Z, X, Y)))
est2.summary(alpha=0.05)

### Subsample-Based Inference
Re-estimation on random subsamples, restarting the procedure at different stages.
Confidence intervals can be computed either with the percentile method
or with the pivot bootstrap.

In [None]:
# Re-run the estimation on random half-size subsamples drawn without replacement,
# restarting the procedure at stage 1, 2 and 3 with 10, 100 and 1000 subsamples.
for stage, n_subsamples in zip([1, 2, 3], [10, 100, 1000]):
    if stage == 1 and pw == 0:
        # Without controls, re-estimating from stage 1 is the same as from stage 2.
        print("Re-estimating at stage 1 equivalent to stage 2, as no controls exist. Skipping...")
        continue
    bs = est.bootstrap_inference(stage=stage, n_subsamples=n_subsamples, fraction=0.5,
                                 replace=False, verbose=3, random_state=123)
    display(bs.summary())  # percentile CI
    display(bs.summary(pivot=True))  # pivot bootstrap CI
    # Distribution of subsample estimates, with bs.point (presumably the
    # full-sample estimate) marked in red.
    plt.hist(bs.point_dist)
    plt.axvline(bs.point, color='r')
    plt.title(f"Bootstrap inference re-estimated at stage {stage} with {n_subsamples} subsamples")
    plt.show()

# Quality of Procedure and Diagnostics Across Many Experiments

In [None]:
def run_experiment(seed, n, pw, pm, pz, px, a, b, c, d, e, f, g, sm, *,
                   ivreg_type='adv', n_splits=3, semi=True,
                   n_jobs=-1, verbose=0):
    """Simulate one synthetic dataset, fit ProximalDE on it and run all diagnostics.

    seed: random seed, used both for data generation and for the estimator
    n: number of samples
    pw: dimension of controls
    pm: dimension of mediator
    pz: dimension of treatment proxies ("instruments")
    px: dimension of outcome proxies ("treatments")
    a : strength of D -> M edge
    b : strength of M -> Y edge
    c : strength of D -> Y edge
    d : strength of D -> Z edge
    e : strength of M -> Z edge
    f : strength of M -> X edge
    g : strength of X -> Y edge
    sm : scale of noise of M
    ivreg_type : how to estimate nuisance params.
        Options are 'adv' = adversarial or '2sls'.
    n_splits : number of cross validation splits
    semi : flag to use semi-cross fitting
    n_jobs : number of jobs to parallelize regressing W over.
    verbose : verbosity. 0 = silent; >0 sets the frequency of
        print statements for parallelization.

    Returns a flat tuple of 18 values, in this order:
        point_, stderr_             -- point estimate and its standard error
        r2D_, r2Z_, r2X_, r2Y_      -- nuisance R^2 values reported by the estimator
        idstr, idstr_crit           -- id-strength statistic and critical value
        point_pre_, stderr_pre_     -- the estimator's alternative ("pre") estimate and stderr
        pval, pval_crit             -- primal violation statistic and critical value
        dval, dval_crit             -- dual violation statistic and critical value
        weakiv_stat, weakiv_crit    -- weak-IV statistic and critical value
        lb, ub                      -- robust confidence interval searched within [-2, 2]
    All tests are run at alpha=0.05.
    """
    # Pass sm by keyword (as the single-experiment cell does) so it cannot bind
    # to the wrong positional parameter of gen_data.
    W, X, Z, D, Y = gen_data(a, b, c, d, e, f, g, pm, pz, px, pw, n, sm=sm, seed=seed)
    # NOTE(review): binary_D=True is set here but not in the single-experiment
    # cell — confirm that gen_data produces a binary treatment D.
    est = ProximalDE(model_regression='linear', model_classification='linear',
                     binary_Z=[], binary_X=[],
                     cv=n_splits, semi=semi, binary_D=True,
                     ivreg_type=ivreg_type,
                     n_jobs=n_jobs, random_state=seed, verbose=verbose)
    est.fit(W, D, Z, X, Y)
    # Each test returns four values; only the statistic (first) and the critical
    # value (last) are kept.
    weakiv_stat, _, _, weakiv_crit = est.weakiv_test(alpha=0.05)
    idstr, _, _, idstr_crit = est.idstrength_violation_test(alpha=0.05)
    pval, _, _, pval_crit = est.primal_violation_test(alpha=0.05)
    dval, _, _, dval_crit = est.dual_violation_test(alpha=0.05)
    lb, ub = est.robust_conf_int(lb=-2, ub=2)
    return (est.point_, est.stderr_, est.r2D_, est.r2Z_, est.r2X_, est.r2Y_,
            idstr, idstr_crit, est.point_pre_, est.stderr_pre_,
            pval, pval_crit, dval, dval_crit, weakiv_stat, weakiv_crit,
            lb, ub)

In [None]:
# --- Edge strengths: e*f = 0.01 is small, i.e. a weak-instrument setting ---
a, b = 1.0, 1.0  # a*b is the indirect effect that flows through the mediator
c = 0.5          # direct effect: the quantity we want to estimate
d, g = 0, 0      # either may be zero without hurting estimation
e, f = 0.1, 0.1  # a small product e*f means a weak instrument
# --- Sample size and dimensions ---
n = 50_000       # number of samples
pw = 10          # dimension of the controls / confounders W
pz, px = 10, 10  # dimensions of the proxies Z and X
pm = 1           # dimension of the mediator M; should not exceed max(pz, px)
sm = 2.0         # mediator-noise strength; must be non-zero for identifiability (only used when pm == 1)

# Run 100 independent replications (seeds 0..99) in parallel. Each replication
# fits with n_jobs=1, presumably to avoid nesting parallelism inside the outer
# Parallel.
results = Parallel(n_jobs=-1, verbose=3)(
    delayed(run_experiment)(
        seed, n, pw, pm, pz, px, a, b, c, d, e, f, g, sm,
        ivreg_type='adv', n_splits=3, semi=True, n_jobs=1,
    )
    for seed in range(100)
)

#### Summarize

In [None]:
# Stack the per-seed tuples returned by run_experiment into one numpy array per quantity.
# NOTE: rmseD/rmseZ/rmseX/rmseY hold the nuisance R^2 values (est.r2D_ ... est.r2Y_),
# not RMSEs; the names are kept so later cells keep working.
points_base, stderrs_base, rmseD, rmseZ, rmseX, rmseY, \
    idstr, idstr_crit, points_alt, stderrs_alt, \
    pval, pval_crit, dval, dval_crit, wiv_stat, wiv_crit, \
    rlb, rub = map(np.array, zip(*results))

print("Estimation Quality")
for name, points, stderrs in [('Debiased', points_base, stderrs_base), ('Regularized', points_alt, stderrs_alt)]:
    print(f"\n{name} Estimate")
    # Normal-approximation 95% intervals: point +/- 1.96 * stderr.
    coverage = np.mean((points + 1.96 * stderrs >= c) & (points - 1.96 * stderrs <= c))
    rmse = np.sqrt(np.mean((points - c)**2))
    bias = np.abs(np.mean(points) - c)
    std = np.std(points)
    mean_stderr = np.mean(stderrs)
    ci_lengths = 2 * 1.96 * stderrs
    mean_length = np.mean(ci_lengths)
    median_length = np.median(ci_lengths)
    print(f"Coverage: {coverage:.3f}")
    print(f"RMSE: {rmse:.3f}")
    print(f"Bias: {bias:.3f}")
    print(f"Std: {std:.3f}")
    print(f"Mean CI length: {mean_length:.3f}")
    print(f"Median CI length: {median_length:.3f}")
    print(f"Mean Estimated Stderr: {mean_stderr:.3f}")
# The nuisance fits are shared by both estimates, so report them once.
print(f"\nNuisance R^2 (D, Z, X, Y): {np.mean(rmseD):.3f}, {np.mean(rmseZ):.3f}, {np.mean(rmseX):.3f}, {np.mean(rmseY):.3f}")

print("\nRobust ConfInt Coverage")
rcoverage = np.mean((rub >= c) & (rlb <= c))
print(f"Robust Coverage: {rcoverage:.3f}")

print("\nViolations")
# Strength-type tests: a violation is a statistic at or below its critical value.
for name, stat, crit in [('Id-Strength', idstr, idstr_crit), ('WeakIV F-test', wiv_stat, wiv_crit)]:
    violation = np.mean(stat <= crit)
    print(f"% Violations of {name}: {violation:.3f}")
# Existence tests: a violation is a statistic at or above its critical value.
for name, stat, crit in [('Primal Existence', pval, pval_crit), ('Dual Existence', dval, dval_crit)]:
    violation = np.mean(stat >= crit)
    print(f"% Violations of {name}: {violation:.3f}")

In [None]:
# Overlay the histograms of both estimates across seeds (the second one
# semi-transparent) and mark the true direct effect c with a red vertical line.
# NOTE(review): points_alt is called 'Regularized' in the summary cell but
# 'original' here — consider aligning the two names.
plt.hist(points_base,
         label='Distribution of Estimates: debiased')
plt.hist(points_alt,
         alpha=.3,
         label='Distribution of Estimates: original')
plt.vlines([c], 0, plt.ylim()[1],
           color='red',
           label='truth')
plt.legend()
plt.show()

In [None]:
from statsmodels.graphics.gofplots import qqplot
import scipy.stats
# Q-Q plots of the simulated test statistics against chi-squared reference
# distributions: the dual statistic vs chi2(px) on the left and the primal
# statistic vs chi2(pz + 1) on the right, each with a 45-degree reference line.
plt.figure(figsize=(15, 5))
qqplot(np.array(dval), dist=scipy.stats.chi2(df=px), line='45',
       ax=plt.subplot(1, 2, 1))
ax = plt.subplot(1, 2, 2)
qqplot(np.array(pval), dist=scipy.stats.chi2(df=pz + 1), line='45',
       ax=ax)
plt.show()