## Setup Only for Colab

In [None]:
# prompt: mount drive

from google.colab import drive
drive.mount('/content/drive')

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks/hidden_mediators

In [None]:
%ls

In [None]:
from IPython.display import clear_output

In [None]:
import time
!pip install -r requirements.txt
time.sleep(2)
clear_output()

In [None]:
import time
# replace `develop` with `install` if you wont make library code changes
!python setup.py develop
time.sleep(2)
clear_output()
# Restart the session after running this

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks

In [None]:
# Scratch cell: import numpy locally because this cell runs before the main
# import cell, and would otherwise raise NameError on a fresh kernel.
import numpy as np
A = np.random.normal(0, 1, size=(3, 3))

# Main Logic

In [None]:
# Re-import edited modules (e.g. the `proximalde` library sources) before each
# cell executes, so library code changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
import pytest
from joblib import Parallel, delayed
from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
from proximalde.gen_data import gen_data_complex, gen_data_no_controls
from proximalde.proximal import proximal_direct_effect, ProximalDE, estimate_nuisances, \
    estimate_final, second_stage, _check_input, residualizeW, _gen_subsamples
from proximalde.inference import NormalInferenceResults, pvalue
from proximalde.ivreg import Regularized2SLS
from proximalde.ivtests import weakiv_tests
from proximalde.tests.utilities import gen_iv_data
from sklearn.linear_model import RidgeCV, Ridge, LinearRegression, LassoCV
from sklearn.base import clone

```
a : strength of D -> M edge
b : strength of M -> Y edge
c : strength of D -> Y edge
d : strength of D -> Z edge
e : strength of M -> Z edge
f : strength of M -> X edge
g : strength of X -> Y edge
```

In [None]:
# W, D, _, Z, X, Y = gen_data_complex(n, pw, pz, px, a, b, c, d, e, f, g)

In [None]:
def exp_summary(it, n, pw, pz, px, a, b, c, d, e, f, g):
    """Run a single simulation replicate and return the estimator diagnostics.

    Seeds the global numpy RNG with ``it``, draws a dataset with
    ``gen_data_no_controls`` and fits a ``ProximalDE`` estimator on it.

    Returns
    -------
    tuple
        ``(stderr_, idstrength_, primal_violation_, dual_violation_, point_,
        robust_lower, robust_upper)`` from the fitted estimator.
    """
    np.random.seed(it)
    W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g)
    model = ProximalDE(dual_type='Z', cv=3, semi=True, multitask=False,
                       n_jobs=1, random_state=3, verbose=0)
    model.fit(W, D, Z, X, Y)
    # lb/ub presumably bound the search range of the robust interval -- confirm.
    robust_lower, robust_upper = model.robust_conf_int(lb=-2, ub=2)
    return (model.stderr_,
            model.idstrength_,
            model.primal_violation_,
            model.dual_violation_,
            model.point_,
            robust_lower,
            robust_upper)

def run_summary(n, pw, pz, px, a, b, c, d, e, f, g, n_exps=100, stderr_thresh=2.0):
    """Monte Carlo summary of ProximalDE on ``gen_data_no_controls`` data.

    Runs ``exp_summary`` for seeds ``0, ..., n_exps - 1`` in parallel. Prints:

    * the mean point estimate;
    * its bias and RMSE with respect to the true direct effect ``c``;
    * the coverage of the normal interval ``point +/- 1.96 * stderr``;
    * the coverage of the identification-robust interval.

    It then plots histograms of the standard error and of the three test
    statistics.

    Parameters
    ----------
    n, pw, pz, px, a, b, c, d, e, f, g
        Forwarded unchanged to ``exp_summary``; ``c`` is also used as the
        true direct effect when computing bias, RMSE and coverage.
    n_exps : int, default 100
        Number of simulated replicates.
    stderr_thresh : float, default 2.0
        The stderr panel reports the share of replicates whose stderr exceeds
        this value; the same number is used in the label and the computation.
    """
    res = np.array(Parallel(n_jobs=-1, verbose=3)(
        delayed(exp_summary)(it, n, pw, pz, px, a, b, c, d, e, f, g)
        for it in range(n_exps)))
    # Column order follows the tuple returned by `exp_summary`.
    stderr, idstrength, primal_viol, dual_viol, point, robust_lb, robust_ub = (
        res[:, k] for k in range(7))

    print(f"Mean estimate: {np.mean(point):.3f}")
    print(f"Bias: {np.mean(point - c):.3f}")
    print(f"RMSE: {np.sqrt(np.mean((point - c)**2)):.3f}")
    cov = np.mean((point + 1.96 * stderr >= c) & (point - 1.96 * stderr <= c))
    print(f"Coverage: {cov:.3f}")
    rcov = np.mean((robust_lb <= c) & (robust_ub >= c))
    print(f"ID-Robust Coverage: {rcov:.3f}")

    # 3.85 and 5.99 are approximately the 95% chi-square critical values with
    # 1 and 2 degrees of freedom -- presumably the reference distributions of
    # these statistics; confirm against proximalde.
    panels = [
        ("stderr", stderr, f"%>{stderr_thresh}", np.mean(stderr > stderr_thresh)),
        ("idstrength", idstrength, "%<3.85", np.mean(idstrength < 3.85)),
        ("primal_violation", primal_viol, "%>5.99", np.mean(primal_viol > 5.99)),
        ("dual_violation", dual_viol, "%>3.85", np.mean(dual_viol > 3.85)),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(10, 5))
    for ax, (name, values, tail_label, tail_frac) in zip(axes.flat, panels):
        ax.set_title(f"{name}: mean={np.mean(values):.3f}, {tail_label}={tail_frac:.3f}")
        ax.hist(values)
    fig.tight_layout()
    plt.show()

In [None]:
# Simulation settings shared by every experiment below.
n = 10_000            # number of samples per simulated dataset
pw, pz, px = 1, 1, 1  # presumably the dimensions of W, Z and X -- confirm in gen_data
n_splits = 3          # NOTE(review): not referenced by any visible cell

## Failure mode 1

In [None]:
np.random.seed(123)
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# D has direct relationship to Z, Z has no relationship to M, 
# X has direct relationship to M, X has no direct relationship to Y
d, e, f, g = 1.0, 0.0, 1.0, 0.0
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g)
est = ProximalDE(dual_type='Z', cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
est.fit(W, D, Z, X, Y)
display(est.summary().tables[0])
display(est.summary().tables[1])
display(est.summary().tables[2])
# we see that even though both primal and dual have feasible solutions
# the identification is very weak, since Z is only driven by D and
# hence D can almost perfectly predict D. So the id_stregth test failed
# and caught this failure mode. Also confidence intervals are quite large
# so we are not artificially confident about a wrong result.

In [None]:
run_summary(n, pw, pz, px, a, b, c, d, e, f, g)

## Failure mode 2

In [None]:
np.random.seed(123)
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# D has no direct relationship to Z, Z has no direct relationship to M, 
# X has direct relationship to M, X has no direct relationship to Y
d, e, f, g = 0.0, 0.0, 1.0, 0.0
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g)
est = ProximalDE(dual_type='Q', cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
est.fit(W, D, Z, X, Y)
display(est.summary().tables[0])
display(est.summary().tables[1])
display(est.summary().tables[2])
# we see that in this case even though the idstrength test passed
# the dual violation test did not pass. In this case the dual problem
# does not admit a solution and this was detected by the dual_violation
# statistic which was very high

In [None]:
run_summary(n, pw, pz, px, a, b, c, d, e, f, g)

## Failure mode 3

In [None]:
np.random.seed(123)
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# D has no direct relationship to Z, Z has direct relationship to M, 
# X has no direct relationship to M, X has no direct relationship to Y
d, e, f, g = 0.0, 1.0, 0.0, 0.0
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g)
est = ProximalDE(dual_type='Q', cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
est.fit(W, D, Z, X, Y)
display(est.summary().tables[0])
display(est.summary().tables[1])
display(est.summary().tables[2])
# we see that in this case even though the idstrength test passed
# and the dual violation test passed, the primal violation almost did
# not pass. This test should be catching this failure.

In [None]:
run_summary(n, pw, pz, px, a, b, c, d, e, f, g)

## Failure mode 4

In [None]:
np.random.seed(123)
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# D has no direct relationship to Z, Z has no direct relationship to M, 
# X has no direct relationship to M, X has no direct relationship to Y
d, e, f, g = 0.0, 0.0, 0.0, 0.0
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g)
est = ProximalDE(dual_type='Q', cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
est.fit(W, D, Z, X, Y)
display(est.summary().tables[0])
display(est.summary().tables[1])
display(est.summary().tables[2])

In [None]:
np.mean(est.Dres_ * est.Xres_), np.mean(est.Zres_ * est.Xres_), np.mean(est.Dres_ * est.Zres_)

In [None]:
est.eta_, est.gamma_

In [None]:
run_summary(n, pw, pz, px, a, b, c, d, e, f, g)

## Failure mode 5

In [None]:
np.random.seed(123)
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# D has no direct relationship to Z, Z has no direct relationship to M, 
# X has no direct relationship to M, X has direct relationship to Y
d, e, f, g = 0.0, 0.0, 0.0, 5.0
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g)
est = ProximalDE(dual_type='Q', cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
est.fit(W, D, Z, X, Y)
display(est.summary().tables[0])
display(est.summary().tables[1])
display(est.summary().tables[2])

In [None]:
run_summary(n, pw, pz, px, a, b, c, d, e, f, g)

## No failure mode

In [None]:
np.random.seed(123)
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# D has direct relationship to Z, Z has direct relationship to M, 
# X has direct relationship to M, X has direct relationship to Y
d, e, f, g = 1.0, 1.0, 1.0, 1.0
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g)
est = ProximalDE(dual_type='Z', cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
est.fit(W, D, Z, X, Y)
display(est.summary().tables[0])
display(est.summary().tables[1])
display(est.summary().tables[2])

In [None]:
run_summary(n, pw, pz, px, a, b, c, d, e, f, g)

# OLD (earlier exploratory cells, kept for reference)

In [None]:
# Same data-generating configuration and estimator as "Failure mode 1".
np.random.seed(123)
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# Z has no relationship to M, but X has relationship to M
# D also has direct relationship to Z, X doesn't have direct relationship to Y
d, e, f, g = 1.0, 0.0, 1.0, 0.0
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g)
est = ProximalDE(dual_type='Z', cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
est.fit(W, D, Z, X, Y)
# Build the summary once and display its first three tables.
summary = est.summary()
display(summary.tables[0])
display(summary.tables[1])
display(summary.tables[2])
# we see that even though both primal and dual have feasible solutions
# the identification is very weak, since Z is only driven by D and
# hence D can almost perfectly predict Z. So the idstrength test failed
# and caught this failure mode.

In [None]:
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# Z has no relationship to M, but X has relationship to M
# D also has direct relationship to Z, X doesn't have direct relationship to Y
d, e, f, g = 1.0, 0.0, 1.0, 0.0
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g)
est = ProximalDE(dual_type='Z', cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
est.fit(W, D, Z, X, Y)
display(est.summary().tables[0])
display(est.summary().tables[1])
display(est.summary().tables[2])

In [None]:
np.mean(est.Dres_ * est.Xres_)

In [None]:
np.mean(est.Zres_ * est.Xres_)

In [None]:
est.summary()

In [None]:
diag = est.run_diagnostics()

In [None]:
inds = est.influential_set(alpha=0.05)

In [None]:
from sklearn.base import clone
est2 = clone(est)
est2.fit(np.delete(W, inds, axis=0), np.delete(D, inds, axis=0),
         np.delete(Z, inds, axis=0), np.delete(X, inds, axis=0),
         np.delete(Y, inds, axis=0))

In [None]:
est2.summary(alpha=0.05)

#### Verifying that we reproduce the results of `ivreg` on the Kmenta dataset used in the `ivreg` documentation

In [None]:
from proximalde.tests.utilities import gen_kmenta_data
# Load the Kmenta example dataset. NOTE: this overwrites the simulated Z, X, Y
# from the cells above. Presumably Z are instruments and X endogenous regressors -- confirm.
Z, X, Y, labels, controls = gen_kmenta_data()

In [None]:
from proximalde.ivtests import weakiv_tests
# Weak-instrument tests on the Kmenta data (result is the cell's displayed output).
weakiv_tests(Z, X, Y, controls=controls)

In [None]:
from proximalde.diagnostics import IVDiagnostics
# NOTE: rebinds `diag`, which earlier held the result of est.run_diagnostics().
diag = IVDiagnostics(add_constant=True).fit(Z, X, Y)

In [None]:
# Influence plot with points annotated by the dataset's row labels.
diag.influence_plot(labels=labels)
plt.show()