## Setup Only for Colab

In [None]:
# prompt: mount drive
# Mount Google Drive so the project folder used below is reachable from Colab.

from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Move into the project checkout on Drive; the install cells below read
# requirements.txt and setup.py relative to this directory.
%cd /content/drive/MyDrive/Colab\ Notebooks/hidden_mediators

In [None]:
# List the directory contents to confirm the checkout is where we expect.
%ls

In [None]:
# Used by the install cells below to hide their (long) output.
from IPython.display import clear_output

In [None]:
import time
!pip install -r requirements.txt
time.sleep(2)
clear_output()

In [None]:
import time
# replace `develop` with `install` if you wont make library code changes
!python setup.py develop
time.sleep(2)
clear_output()
# Restart the session after running this

In [None]:
# Return to the parent "Colab Notebooks" folder once the package is installed.
%cd /content/drive/MyDrive/Colab\ Notebooks

In [None]:
# Draw a 3x3 standard-normal matrix. numpy is imported here because this cell
# runs before the main import cell below; without it a fresh kernel raises
# NameError. NOTE(review): `A` is not used by any other visible cell.
import numpy as np
A = np.random.normal(0, 1, size=(3, 3))

# Main Logic

In [None]:
# Re-import edited modules (e.g. the editable `proximalde` install) automatically
# before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
import pytest
from joblib import Parallel, delayed
from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
from proximalde.gen_data import gen_data_complex, gen_data_no_controls
from proximalde.proximal import proximal_direct_effect, ProximalDE, estimate_nuisances, \
    estimate_final, second_stage, _check_input, residualizeW, _gen_subsamples, RegularizedDualIVSolver
from proximalde.inference import NormalInferenceResults, pvalue
from proximalde.ivreg import Regularized2SLS, AdvIV
from proximalde.ivtests import weakiv_tests
from proximalde.tests.utilities import gen_iv_data
from sklearn.linear_model import RidgeCV, Ridge, LinearRegression, LassoCV
from sklearn.base import clone

```
a : strength of D -> M edge
b : strength of M -> Y edge
c : strength of D -> Y edge
d : strength of D -> Z edge
e : strength of M -> Z edge
f : strength of M -> X edge
g : strength of X -> Y edge
sm : scale (standard deviation) of the noise in M
```

In [None]:
# W, D, _, Z, X, Y = gen_data_complex(n, pw, pz, px, a, b, c, d, e, f, g)

In [None]:
def true_params(pw, pz, px, a, b, c, d, e, f, g, sm, sz=1.0, sd=0.5):
    """Population dual coefficient gamma and identification strength.

    Closed form is available only for scalar proxies (pz == 1 and px == 1).
    The moments below assume Var(D) = sd**2 and Var(M) = a**2 * sd**2 + sm**2,
    with edge strengths as in the edge list above. NOTE(review): this presumably
    mirrors `gen_data_no_controls`; confirm sd matches its treatment noise scale.

    `pw`, `b`, `c`, `g` and `sz` do not enter the result; they are accepted so the
    argument list matches the data-generating functions.

    Returns:
        (gamma, strength) where gamma = Cov(D, X) / Cov(X, Z) (np.inf when
        Cov(X, Z) == 0) and strength = Var(D) - gamma * Cov(D, Z). With an
        infinite gamma the strength is +/-inf, or nan when Cov(D, Z) == 0.

    Raises:
        AttributeError: if pz != 1 or px != 1.
    """
    if not (pz == 1 and px == 1):
        raise AttributeError("Not available")

    var_d = sd**2
    var_m = sm**2 + a**2 * var_d
    cov_dz = (a * e + d) * var_d
    cov_dx = a * f * var_d
    cov_xz = f * (e * var_m + d * a * var_d)

    # gamma solves E[X (D - gamma Z)] = 0, i.e. D*X / X*Z; it is undefined when
    # X is uncorrelated with Z, which is signalled with +inf.
    gamma = np.inf if cov_xz == 0 else cov_dx / cov_xz
    # strength = E[D (D - gamma Z)] = D^2 - gamma D*Z
    strength = var_d - gamma * cov_dz
    return gamma, strength

In [None]:
def exp_summary(it, n, pw, pz, px, a, b, c, d, e, f, g, sm):
    """Run one Monte-Carlo replication of ProximalDE and return its diagnostics.

    The global numpy RNG is seeded with `it`, a dataset is drawn with
    `gen_data_no_controls`, and a ProximalDE model is fit without controls.

    Returns a flat tuple whose positions are the columns `run_summary` reads:
        0 stderr_, 1 idstrength_, 2 primal_violation_, 3 dual_violation_,
        4 point_, 5/6 identification-robust interval (lb, ub),
        7 weak-IV statistic, 8 first entry of covariance_rank_test's eigenvalues,
        9 pi, 10 var(pi), then the flattened gamma_ followed by the flattened
        ivreg_gamma_.stderr_ (run_summary assumes pz entries each), then
        idstrength_std_, then the rank-test critical value.
    NOTE(review): column 8 is treated as the largest eigenvalue; confirm the
    eigenvalues come back sorted in descending order.
    """
    np.random.seed(it)
    _, treatment, _, proxy_z, proxy_x, outcome = gen_data_no_controls(
        n, pw, pz, px, a, b, c, d, e, f, g, sm=sm)

    model = ProximalDE(dual_type='Z', cv=3, semi=True, ivreg_type='adv',
                       multitask=False, n_jobs=1, random_state=3, verbose=0)
    model.fit(None, treatment, proxy_z, proxy_x, outcome)

    robust_lb, robust_ub = model.robust_conf_int(lb=-2, ub=2)
    weakiv_stat, _, pi, var_pi = model.weakiv_test(return_pi_and_var=True)
    eigvals, eig_crit = model.covariance_rank_test()

    head = (model.stderr_, model.idstrength_, model.primal_violation_,
            model.dual_violation_, model.point_, robust_lb, robust_ub,
            weakiv_stat, eigvals[0], pi, var_pi)
    gammas = tuple(model.gamma_.flatten())
    gamma_stderrs = tuple(model.ivreg_gamma_.stderr_.flatten())
    tail = (model.idstrength_std_, eig_crit)
    return head + gammas + gamma_stderrs + tail

def _plot_diagnostics(res, pz, px, stderr_thresh=2.0, weakiv_crit=23):
    """Plot a 3x2 grid of histograms of the per-replication diagnostics in `res`.

    `res` holds one `exp_summary` tuple per row. Each panel title reports the
    mean and the fraction of replications beyond the stated threshold; the
    threshold value is interpolated into the title so label and test agree.
    """
    stderr = res[:, 0]
    id_stat = res[:, 1] / res[:, 11 + 2 * pz]        # idstrength_ / idstrength_std_
    primal = res[:, 2]
    dual = res[:, 3]
    weakiv = res[:, 7]
    eig_stat = res[:, 8] / res[:, 11 + 2 * pz + 1]   # leading eigenvalue / rank-test critical value

    id_crit = np.round(scipy.stats.foldnorm(c=10).ppf(.95), 2)
    primal_crit = np.round(scipy.stats.chi2(df=pz + 1).ppf(.95), 2)
    dual_crit = np.round(scipy.stats.chi2(df=px).ppf(.95), 2)

    panels = [
        (f"stderr: mean={np.mean(stderr):.3f}, %>{stderr_thresh}={np.mean(stderr > stderr_thresh):.3f}", stderr),
        (f"idstrength: mean={np.mean(id_stat):.3f}, %<{id_crit}={np.mean(id_stat < id_crit):.3f}", id_stat),
        (f"primal_violation: mean={np.mean(primal):.3f}, %>{primal_crit}={np.mean(primal > primal_crit):.3f}", primal),
        (f"dual_violation: mean={np.mean(dual):.3f}, %>{dual_crit}={np.mean(dual > dual_crit):.3f}", dual),
        (f"weakiv_test: mean={np.mean(weakiv):.3f}, %<{weakiv_crit}={np.mean(weakiv < weakiv_crit):.3f}", weakiv),
        (f"rank_test: mean={np.mean(eig_stat):.3f}, %<1={np.mean(eig_stat < 1):.3f}", eig_stat),
    ]
    fig, axes = plt.subplots(3, 2, figsize=(10, 10))
    for ax, (title, values) in zip(axes.flat, panels):
        ax.set_title(title)
        ax.hist(values)
    fig.tight_layout()
    plt.show()


def _print_variance_calibration(name, values, var_values, n):
    """Compare the Monte-Carlo n * Var of an estimate with its estimated n * Var.

    `values` are the per-replication estimates and `var_values` the matching
    n-scaled variance estimates; reductions are along axis 0, so vector-valued
    estimates (e.g. gamma) are reported coordinate-wise.
    """
    print(f"mean({name}) = ", np.round(np.mean(values, axis=0), 3))
    print(f"True n * var({name})=", np.round(n * np.var(values, axis=0), 3))
    print(f"n * hatvar({name}): mean=", np.round(np.mean(var_values, axis=0), 3),
          "median=", np.round(np.median(var_values, axis=0), 3),
          "(5%, 95%)=(", np.round(np.percentile(var_values, 5, axis=0), 3),
          np.round(np.percentile(var_values, 95, axis=0), 3), ")")


def _print_true_pi(pw, pz, px, a, b, c, d, e, f, g, sm, n_large=100000):
    """Approximate the population pi by OLS of D on (D - gamma * Z) in one large sample."""
    true_gamma, _ = true_params(pw, pz, px, a, b, c, d, e, f, g, sm)
    if not np.isfinite(true_gamma):
        # D - inf * Z is not a valid regressor (sklearn rejects inf/nan input).
        print("true(pi) undefined: true gamma is infinite")
        return
    _, D, _, Z, _, _ = gen_data_no_controls(n_large, pw, pz, px, a, b, c, d, e, f, g, sm=sm)
    D = D.reshape(-1, 1)
    D = D - D.mean(axis=0)
    Z = Z - Z.mean(axis=0)
    ols = LinearRegression(fit_intercept=False).fit(D - true_gamma * Z, D.flatten())
    print(f"true(pi)={ols.coef_[0]:.3f}")


def run_summary(n, pw, pz, px, a, b, c, d, e, f, g, sm, n_iter=100):
    """Monte-Carlo study of ProximalDE on `gen_data_no_controls` data.

    Runs `exp_summary` for seeds 0..n_iter-1 in parallel, then:
      * prints mean, bias and RMSE of the point estimate against the true direct
        effect `c`, plus Wald and identification-robust interval coverage;
      * plots histograms of the diagnostics;
      * compares the Monte-Carlo variance of pi, gamma and idstrength with their
        estimated variances;
      * when pz == px == 1, prints a large-sample approximation of the true pi.

    `n_iter` (default 100, the previously hard-coded value) sets the number of
    replications; the other arguments are forwarded to `exp_summary`.
    """
    res = np.array(Parallel(n_jobs=-1, verbose=3)(
        delayed(exp_summary)(it, n, pw, pz, px, a, b, c, d, e, f, g, sm)
        for it in range(n_iter)))

    point, stderr = res[:, 4], res[:, 0]
    print(f"Mean estimate: {np.mean(point):.3f}")
    print(f"Bias: {np.mean(point - c):.3f}")
    print(f"RMSE: {np.sqrt(np.mean((point - c)**2)):.3f}")
    # Coverage of the 95% Wald interval point +/- 1.96 * stderr.
    cov = np.mean((point + 1.96 * stderr >= c) & (point - 1.96 * stderr <= c))
    print(f"Coverage: {cov:.3f}")
    # Coverage of the identification-robust interval [lb, ub].
    rcov = np.mean((res[:, 5] <= c) & (res[:, 6] >= c))
    print(f"ID-Robust Coverage: {rcov:.3f}")

    _plot_diagnostics(res, pz, px)

    pi = res[:, 9]
    var_pi = n * res[:, 10]
    print(f"mean(pi) = {np.mean(pi):.3f}")
    print(f"True n * var(pi)={n * np.var(pi):.3f}")
    print(f"n * hatvar(pi): mean={np.mean(var_pi):.3f}, median={np.median(var_pi):.3f}, "
          f"(5%, 95%)=({np.percentile(var_pi, 5):.3f}, {np.percentile(var_pi, 95):.3f}), "
          f"(1%, 99%)=({np.percentile(var_pi, 1):.3f}, {np.percentile(var_pi, 99):.3f})")

    # The stderr columns hold standard errors: square and scale by n to get n * var.
    _print_variance_calibration("gamma", res[:, 11:11 + pz], n * res[:, 11 + pz:11 + 2 * pz]**2, n)
    _print_variance_calibration("str", res[:, 1], n * res[:, 11 + 2 * pz]**2, n)

    if pz == 1 and px == 1:
        _print_true_pi(pw, pz, px, a, b, c, d, e, f, g, sm)

In [None]:
def test_pi_and_var_pi():
    """Monte-Carlo check that the estimated variance of pi is calibrated.

    Runs 100 replications of `exp_summary` in a design where every proxy edge
    is active (pz = px = 3). It asserts that the empirical n * Var(pi) across
    replications matches each of three statistics of the per-replication
    estimates n * hatVar(pi), within atol=5e-3:
      * the mean;
      * the 1st percentile;
      * the 99th percentile.
    """
    np.random.seed(123)
    n = 10000
    pw = 1
    pz, px = 3, 3
    # Indirect effect is a*b, direct effect is c
    a, b, c = 1.0, 1.0, .5
    # D has direct relationship to Z, Z has direct relationship to M,
    # X has direct relationship to M, X has direct relationship to Y
    d, e, f, g = 1.0, 1.0, 1.0, 1.0
    sm = 2.0
    res = np.array(Parallel(n_jobs=-1, verbose=3)(
        delayed(exp_summary)(it, n, pw, pz, px, a, b, c, d, e, f, g, sm)
        for it in range(100)))

    pi = res[:, 9]
    var_pi = n * res[:, 10]      # estimated variance of pi, scaled by n
    true_var = n * np.var(pi)    # Monte-Carlo variance of pi, scaled by n
    estimates = {
        "mean": np.mean(var_pi),
        "1st percentile": np.percentile(var_pi, 1),
        "99th percentile": np.percentile(var_pi, 99),
    }
    print(np.mean(pi))
    for value in estimates.values():
        print(true_var, value)
    for label, value in estimates.items():
        assert np.allclose(true_var, value, atol=5e-3), \
            f"n*Var(pi)={true_var:.4f} differs from {label} of n*hatVar(pi)={value:.4f}"

In [None]:
# test_pi_and_var_pi()

In [None]:
# exp_summary(0, n, pw, pz, px, a, b, c, d, e, f, g, sm)

In [None]:
# Baseline Monte-Carlo study: every proxy edge is active (d = e = f = g = 1)
# with pz = px = 20 proxies.
n = 10000
pw = 1
pz, px = 20, 20
# NOTE(review): n_splits is not read by any visible cell.
n_splits = 3
# NOTE(review): exp_summary reseeds with np.random.seed(it) in every
# replication, so this seed has no visible effect on the replications.
np.random.seed(123)
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
d, e, f, g = 1.0, 1.0, 1.0, 1.0
sm = 2.0
run_summary(n, pw, pz, px, a, b, c, d, e, f, g, sm)

In [None]:
# Shared sample size and proxy dimensions; the failure-mode cells below read
# these globals.
n = 10000
pw = 1
pz, px = 20, 20
n_splits = 3

## Failure mode 1

D has a direct effect on Z, but M has no direct effect on Z, so Z is not a good proxy treatment. However, X is a good proxy outcome, which allows us to detect the failure mode. Because Z is correlated with D, the dual assumption is not violated. However, the solution essentially yields an orthogonal instrument that perfectly predicts the treatment, so the identification-strength assumption is violated.

In [None]:
# Single fit on one dataset, followed by the Monte-Carlo study in the last cell.
np.random.seed(123)
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# D has direct relationship to Z, Z has no relationship to M,
# X has direct relationship to M, X has no direct relationship to Y
d, e, f, g = 1.0, 0.0, 1.0, 0.0
sm = 2.0
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g, sm=sm)
est = ProximalDE(dual_type='Z', cv=3, semi=True, ivreg_type='adv',
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
# Controls are not passed (None), matching exp_summary.
est.fit(None, D, Z, X, Y)
# NOTE(review): est.summary() is recomputed for each table; it could be called once.
display(est.summary().tables[0])
display(est.summary().tables[1])
display(est.summary().tables[2])
# we see that even though both primal and dual have feasible solutions
# the identification is very weak, since Z is only driven by D and
# hence Z can almost perfectly predict D. So the id_strength test failed
# and caught this failure mode. Also confidence intervals are quite large
# so we are not artificially confident about a wrong result.

In [None]:
est.summary()

In [None]:
run_summary(n, pw, pz, px, a, b, c, d, e, f, g, sm)

## Failure mode 2

D has no direct effect on Z, and M has no direct effect on Z either, so Z is not a good proxy treatment. However, X is a good proxy outcome, which allows us to detect the failure mode. Here Z is uncorrelated with D, which violates the dual assumption.

Because the dual doesn't have a solution, the weakIV test is invalid, as the covariance calculation depends on the validity of the dual moment.

In [None]:
np.random.seed(123)
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# D has no direct relationship to Z, Z has no direct relationship to M,
# X has direct relationship to M, X has no direct relationship to Y
d, e, f, g = 0.0, 0.0, 1.0, 0.0
sm = 2.0
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g, sm=sm)
# NOTE(review): unlike exp_summary (used by run_summary below), ivreg_type is
# left at its default here, so this fit and the Monte-Carlo study may differ.
est = ProximalDE(dual_type='Z', cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
est.fit(None, D, Z, X, Y)
display(est.summary().tables[0])
display(est.summary().tables[1])
display(est.summary().tables[2])
# we see that in this case even though the idstrength test passed
# the dual violation test did not pass. In this case the dual problem
# does not admit a solution and this was detected by the dual_violation
# statistic which was very high

In [None]:
run_summary(n, pw, pz, px, a, b, c, d, e, f, g, sm)

## Failure Mode 3


Z is a good proxy treatment, but X is not a good proxy outcome, since it is unrelated to M. This violates the assumption that the primal IV problem has a solution.

In [None]:
np.random.seed(123)
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# D has no direct relationship to Z, Z has direct relationship to M,
# X has no direct relationship to M, X has no direct relationship to Y
d, e, f, g = 0.0, 1.0, 0.0, 0.0
sm = 2.0
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g, sm=sm)
# NOTE(review): ivreg_type is left at its default (exp_summary uses 'adv').
est = ProximalDE(dual_type='Z', cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
est.fit(None, D, Z, X, Y)
display(est.summary().tables[0])
display(est.summary().tables[1])
display(est.summary().tables[2])
# we see that in this case even though the idstrength test passed
# and the dual violation test passed, the primal violation almost did
# not pass. This test should be catching this failure.

In [None]:
run_summary(n, pw, pz, px, a, b, c, d, e, f, g, sm)

## Failure mode 4


Z and X are unrelated to all other variables!

The rank test can catch this failure mode.

Sometimes the dual violation will catch this failure mode.

In [None]:
np.random.seed(123)
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# D has no direct relationship to Z, Z has no direct relationship to M,
# X has no direct relationship to M, X has no direct relationship to Y
d, e, f, g = 0.0, 0.0, 0.0, 0.0
sm = 2.0
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g, sm=sm)
# NOTE(review): ivreg_type is left at its default (exp_summary uses 'adv').
est = ProximalDE(dual_type='Z', cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
est.fit(None, D, Z, X, Y)
display(est.summary().tables[0])
display(est.summary().tables[1])
display(est.summary().tables[2])

In [None]:
run_summary(n, pw, pz, px, a, b, c, d, e, f, g, sm)

## Failure Mode 5

Almost the same as mode 4, except that X now has a direct effect on Y.

The rank test catches this failure mode.

Sometimes the primal/dual violation tests also catch this failure mode.

In [None]:
np.random.seed(123)
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# D has no direct relationship to Z, Z has no direct relationship to M,
# X has no direct relationship to M, X has direct relationship to Y
# (the only change from failure mode 4 is the strong X -> Y edge, g = 5)
d, e, f, g = 0.0, 0.0, 0.0, 5.0
sm = 2.0
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g, sm=sm)
# NOTE(review): ivreg_type is left at its default (exp_summary uses 'adv').
est = ProximalDE(dual_type='Z', cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
est.fit(None, D, Z, X, Y)
display(est.summary().tables[0])
display(est.summary().tables[1])
display(est.summary().tables[2])

In [None]:
run_summary(n, pw, pz, px, a, b, c, d, e, f, g, sm)

## Failure Mode 6

X and Z are good proxies, but the mediator is very strongly correlated with the treatment, so the direct effect is not identified.

The weakIV and idstrength tests will catch this case.

In [None]:
np.random.seed(123)
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# D has direct relationship to Z, Z has direct relationship to M,
# X has direct relationship to M, X has direct relationship to Y
d, e, f, g = 1.0, 1.0, 1.0, 1.0
# Tiny mediator noise: M is almost a deterministic function of D, so the
# direct and indirect paths are hard to separate.
sm = 0.05
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g, sm=sm)
# NOTE(review): ivreg_type is left at its default (exp_summary uses 'adv').
est = ProximalDE(dual_type='Z', cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
est.fit(None, D, Z, X, Y)
display(est.summary().tables[0])
display(est.summary().tables[1])
display(est.summary().tables[2])

In [None]:
run_summary(n, pw, pz, px, a, b, c, d, e, f, g, sm)

## No Failure Mode

In [None]:
np.random.seed(1)
# NOTE(review): this rebinds the global n, which later cells (e.g. the "OLD"
# section) also read; run_summary below also runs 100 fits at this size.
n = 500000
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# D has direct relationship to Z, Z has direct relationship to M,
# X has direct relationship to M, X has direct relationship to Y
d, e, f, g = 1.0, 1.0, 1.0, 1.0
sm = 2.0
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g, sm=sm)
# NOTE(review): ivreg_type is left at its default (exp_summary uses 'adv').
est = ProximalDE(dual_type='Z', cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
est.fit(None, D, Z, X, Y)
display(est.summary().tables[0])
display(est.summary().tables[1])
display(est.summary().tables[2])

In [None]:
run_summary(n, pw, pz, px, a, b, c, d, e, f, g, sm)

# OLD

In [None]:
# NOTE(review): relies on n, pw, pz, px left over from earlier cells
# (n == 500000 after the "No Failure Mode" cell). sm is not passed, so
# gen_data_no_controls's default is used. Unlike the cells above, W is passed to fit.
np.random.seed(123)
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# Z has no relationship to M, but X has relationship to M
# D also has direct relationship to Z, X doesn't have direct relationship to Y
d, e, f, g = 1.0, 0.0, 1.0, 0.0
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g)
est = ProximalDE(dual_type='Z', cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
est.fit(W, D, Z, X, Y)
display(est.summary().tables[0])
display(est.summary().tables[1])
display(est.summary().tables[2])
# we see that even though both primal and dual have feasible solutions
# the identification is very weak, since Z is only driven by D and
# hence Z can almost perfectly predict D. So the id_strength test failed
# and caught this failure mode.

In [None]:
# Same design as the cell above, but without reseeding, so a fresh dataset is drawn.
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# Z has no relationship to M, but X has relationship to M
# D also has direct relationship to Z, X doesn't have direct relationship to Y
d, e, f, g = 1.0, 0.0, 1.0, 0.0
W, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g)
est = ProximalDE(dual_type='Z', cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=0)
est.fit(W, D, Z, X, Y)
display(est.summary().tables[0])
display(est.summary().tables[1])
display(est.summary().tables[2])

In [None]:
# Sample mean of the product of the fitted D and X residual attributes
# (presumably residualized D and X — confirm in ProximalDE).
np.mean(est.Dres_ * est.Xres_)

In [None]:
# Same for the Z and X residual attributes.
np.mean(est.Zres_ * est.Xres_)

In [None]:
est.summary()

In [None]:
diag = est.run_diagnostics()

In [None]:
# Indices of influential samples at level 0.05; they are dropped in the refit below.
inds = est.influential_set(alpha=0.05)

In [None]:
# NOTE(review): clone is already imported in the main import cell.
from sklearn.base import clone
# Refit an unfitted copy of the estimator on the data with the influential rows removed.
# NOTE(review): assumes W is an array here; np.delete on W=None would fail.
est2 = clone(est)
est2.fit(np.delete(W, inds, axis=0), np.delete(D, inds, axis=0),
         np.delete(Z, inds, axis=0), np.delete(X, inds, axis=0),
         np.delete(Y, inds, axis=0))

In [None]:
est2.summary(alpha=0.05)

#### Verifying that we get the same results as `ivreg` on the Kmenta dataset used in the `ivreg` documentation

In [None]:
# NOTE(review): this rebinds the global Z, X, Y used by the simulation cells above.
from proximalde.tests.utilities import gen_kmenta_data
Z, X, Y, labels, controls = gen_kmenta_data()

In [None]:
# NOTE(review): weakiv_tests is already imported in the main import cell.
from proximalde.ivtests import weakiv_tests
weakiv_tests(Z, X, Y, controls=controls)

In [None]:
from proximalde.diagnostics import IVDiagnostics
diag = IVDiagnostics(add_constant=True).fit(Z, X, Y)

In [None]:
diag.influence_plot(labels=labels)
plt.show()