## Setup Only for Colab

In [None]:
def exp_summary_multidim(it, n, pm, pw, pz, px, a, b, c, d, E, F, g):
    """Run one replicate of the multi-dimensional-mediator simulation.

    Seeds NumPy's global RNG with ``it``, simulates data with
    ``gen_data_no_controls_discrete_m`` (``pm`` forwarded as the keyword
    ``pm``), and fits a semiparametric ``ProximalDE`` with 3-fold
    cross-fitting, passing ``None`` for W.

    ``pw`` is accepted for signature compatibility but is not used.

    Returns
    -------
    tuple
        ``(primal_violation_, dual_violation_)`` of the fitted estimator.
    """
    np.random.seed(it)
    _, treat, _, zproxy, xproxy, outcome = gen_data_no_controls_discrete_m(
        n, pz, px, a, b, c, d, E, F, g, pm=pm)
    model = ProximalDE(cv=3, semi=True, n_jobs=1, random_state=3, verbose=0)
    model.fit(None, treat, zproxy, xproxy, outcome)
    return model.primal_violation_, model.dual_violation_


def test_multidim_mediator_violations_nominal_failure_prob():
    """Check that the primal/dual violation tests reject at about the nominal 5% rate.

    For each ``(n, pz, px)`` configuration:

    * draw random loading matrices ``E`` (pm x pz) and ``F`` (pm x px) until
      both have numerical rank ``pm`` (tolerance 0.5);
    * run 100 seeded replicates of ``exp_summary_multidim`` in parallel;
    * assert that the fraction of replicates whose statistic exceeds the 95%
      chi-squared critical value is close to 0.05.

    Raises
    ------
    AssertionError
        If an empirical rejection rate is outside its tolerance.
    """
    # Local import so this cell does not depend on `chi2` being defined elsewhere.
    from scipy.stats import chi2

    np.random.seed(12)
    pw = 1
    pm = 7
    for n, pz, px in [(10000, 20, 10), (10000, 80, 50)]:
        # Indirect effect is a*b, direct effect is c
        a, b, c = 1.0, 1.0, .5
        # No direct D -> Z edge (d=0) and no direct X -> Y edge (g=0).
        # E and F take the positions of the scalar e (M -> Z) and f (M -> X)
        # in the generator's argument list.
        d, g = 0.0, 0.0
        full_rank = False
        while not full_rank:
            E = np.random.normal(0, 2, (pm, pz))
            F = np.random.normal(0, 2, (pm, px))
            full_rank = ((np.linalg.matrix_rank(E, tol=0.5) == pm)
                         and (np.linalg.matrix_rank(F, tol=0.5) == pm))

        res = np.array(Parallel(n_jobs=-1, verbose=3)(
            delayed(exp_summary_multidim)(it, n, pm, pw, pz, px, a, b, c, d, E, F, g)
            for it in range(100)))
        # exp_summary_multidim returns (primal_violation_, dual_violation_).
        primal_stat, dual_stat = map(np.array, zip(*res))

        dual_crit = chi2(df=px).ppf(.95)
        primal_crit = chi2(df=pz + 1).ppf(.95)
        dual_rate = np.mean(dual_stat > dual_crit)
        primal_rate = np.mean(primal_stat > primal_crit)
        print(dual_rate)
        print(primal_rate)
        # Fail loudly instead of dropping into a debugger.
        assert np.isclose(dual_rate, 0.05, atol=2e-2), (
            f"dual rejection rate {dual_rate:.3f} (n={n}, pz={pz}, px={px})")
        assert np.isclose(primal_rate, 0.05, atol=3e-2), (
            f"primal rejection rate {primal_rate:.3f} (n={n}, pz={pz}, px={px})")
# Run the multi-dimensional mediator check: 2 configurations x 100 parallel replicates.
test_multidim_mediator_violations_nominal_failure_prob()

In [None]:
def idstrength_violation_z(sm):
    """Sanity-check moment formulas and the id-strength safeguard for one ``sm``.

    For each of 10 seeds, simulate a scalar design (pz = px = 1, n = 100000).
    a, b and c are fixed; d, e, f, g and the noise scales sz, sx, sy are drawn
    at random. Then:

    * compare the empirical second moments of the centered data with their
      closed-form population values;
    * fit a semiparametric ``ProximalDE`` and compare ``gamma_`` and the
      identification strength ``mean(Dres_ * Dbar_)`` with their closed forms;
    * require that whenever the +/-2 stderr interval misses ``c``, or the point
      error is at least 0.4, the estimated id strength is below the critical
      value ``id_cv``.

    Parameters
    ----------
    sm : float
        Noise scale of the mediator. It enters ``true_Msq = sm**2 + a**2 * sd**2``
        and is forwarded to ``gen_data_no_controls``.

    Returns
    -------
    errors, strengths, covs : list
        Per-seed absolute error ``|point_ - c|``, ``idstrength_`` and coverage
        indicator, accumulated over all 10 seeds.

    Raises
    ------
    AssertionError
        If any moment check or safeguard check fails.
    """
    n = 100000
    pz, px = 1, 1
    # Indirect effect is a*b, direct effect is c (held fixed across seeds).
    a, b, c = .7, .8, .5
    # Accumulate across all seeds, not just the last one.
    errors, strengths, covs = [], [], []
    for seed in range(10):
        np.random.seed(seed)
        # Random edge strengths and noise scales. The draw order fixes the
        # design for a given seed, so keep it as d, e, f, g, sz, sx, sy.
        d = np.random.uniform(.5, 2)
        e = np.random.uniform(.5, 2)
        f = (2 * np.random.binomial(1, .5) - 1) * np.random.uniform(.5, 2)
        g = (2 * np.random.binomial(1, .5) - 1) * np.random.uniform(.5, 2)
        sz, sx, sy = np.random.uniform(.5, 2), np.random.uniform(.5, 2), np.random.uniform(.5, 2)
        _, D, _, Z, X, Y = gen_data_no_controls(n, pz, px, a, b, c, d, e, f, g,
                                                sm=sm, sz=sz, sx=sx, sy=sy)
        W = None

        # Center everything so the sample second moments below are covariances.
        D = D.reshape(-1, 1)
        D = D - D.mean(axis=0)
        sd = np.sqrt(np.mean(D**2))
        X = X - X.mean(axis=0)
        Z = Z - Z.mean(axis=0)
        Y = Y.reshape(-1, 1)
        Y = Y - Y.mean(axis=0)

        # Closed-form population moments vs. their sample counterparts.
        true_Zsq = (e * a + d)**2 * sd**2 + e**2 * sm**2 + sz**2
        print('Z**2', np.mean(Z**2), true_Zsq)
        assert np.allclose(np.mean(Z**2), true_Zsq, atol=7e-2)
        true_Msq = sm**2 + a**2 * sd**2
        true_XZ = f * (e * true_Msq + d * a * sd**2)
        print('X*Z', np.mean(X*Z), true_XZ)
        assert np.allclose(np.mean(X*Z), true_XZ, atol=7e-2)
        true_DZ = (a * e + d) * sd**2
        print('D*Z', np.mean(D * Z), true_DZ)
        assert np.allclose(np.mean(D * Z), true_DZ, atol=5e-2)
        true_DX = a * f * sd**2
        print('D*X', np.mean(D * X), true_DX)
        assert np.allclose(np.mean(D * X), true_DX, atol=5e-2)
        # gamma = E[D X] / E[X Z]
        true_gamma = true_DX / true_XZ
        print("gamma", true_gamma)
        # strength = E[D^2] - gamma * E[D Z]
        true_strength = sd**2 - true_gamma * true_DZ
        print("strength", true_strength)

        est = ProximalDE(cv=3, semi=True, n_jobs=-1, random_state=3, verbose=0)
        est.fit(W, D, Z, X, Y)
        print('point, std', est.point_, est.stderr_)
        assert np.allclose(est.gamma_, true_gamma, rtol=1e-1, atol=8e-2)
        assert np.allclose(np.mean(est.Dres_ * est.Dbar_), true_strength, rtol=1e-1, atol=5e-2)
        print("id strength", est.idstrength_, est.primal_violation_, est.dual_violation_)

        # The critical value must be computed for THIS fit before any check
        # uses it.
        # NOTE(review): the folded-normal location here is the true direct
        # effect c; confirm this is intended rather than a fixed threshold.
        id_cv = scipy.stats.foldnorm(c=c / est.idstrength_std_,
                                     scale=est.idstrength_std_).ppf(1 - .05)
        weak_id = est.idstrength_ < id_cv
        cov = (est.point_ - 2 * est.stderr_ <= c) & (est.point_ + 2 * est.stderr_ >= c)
        error = np.abs(est.point_ - c)
        # A miss (no coverage, or a large error) is acceptable only when the
        # id-strength test flags weak identification.
        assert cov or weak_id, (
            f"sm={sm}, seed={seed}: interval misses c but idstrength="
            f"{est.idstrength_} >= id_cv={id_cv}")
        assert (error < .4) or weak_id, (
            f"sm={sm}, seed={seed}: error={error}, idstrength={est.idstrength_}")
        errors.append(error)
        strengths.append(est.idstrength_)
        covs.append(cov)
    return errors, strengths, covs
def test_strength_violation_z():
    """Run ``idstrength_violation_z`` for 10 mediator noise scales in [0, 0.5].

    The per-seed errors, id strengths and coverage flags from every grid
    point are pooled into local lists. Nothing is returned; failures surface
    as AssertionError raised by ``idstrength_violation_z``.
    """
    all_errors, all_strengths, all_covs = [], [], []
    for sm in np.linspace(0, .5, 10):
        print(sm)
        errs, strs, cvs = idstrength_violation_z(sm)
        all_errors.extend(errs)
        all_strengths.extend(strs)
        all_covs.extend(cvs)
# Sweep sm over np.linspace(0, .5, 10); each value runs 10 seeds of idstrength_violation_z.
test_strength_violation_z()

In [None]:
# Colab-only setup: import the notebook helpers (`clear_output` and `time` are
# used by the setup cells), mount Google Drive, and move into the notebooks folder.
from IPython.display import clear_output
from google.colab import drive
import time
drive.mount('/content/drive') # First mount drive
%cd /content/drive/MyDrive/Colab\ Notebooks

In [None]:
## Run if you haven't set up hidden_mediators
# Clone the repository into the current directory, install its requirements
# and the package itself, then clear the verbose install output.
# NOTE(review): `python setup.py install` is deprecated by setuptools;
# `pip install .` is the usual replacement. Confirm before switching.
! git clone https://github.com/syrgkanislab/hidden_mediators
%cd hidden_mediators
! pip install -r requirements.txt
! python setup.py install
time.sleep(2)
clear_output()

In [None]:
## Run if you have already set up hidden_mediators
# Enter the existing clone (path relative to the current directory) and clear the cell output.
%cd hidden_mediators
time.sleep(2)
clear_output()

# Main Logic

In [None]:
# Enable IPython's autoreload so edits to imported modules (e.g. proximalde)
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats
from joblib import Parallel, delayed
from proximalde.gen_synthetic_data import gen_data_no_controls
from proximalde.proximal import ProximalDE

In [None]:
def exp_summary(it, n, pz, px, a, b, c, d, e, f, g, sm):
    """Fit one simulated replicate and collect the diagnostics used by ``run_summary``.

    Seeds NumPy's global RNG with ``it``, simulates data with
    ``gen_data_no_controls`` and fits a semiparametric ``ProximalDE`` using the
    adversarial IV regression (``ivreg_type='adv'``).

    Returns
    -------
    tuple
        A flat tuple, in this order:

        * stderr_, idstrength_, primal_violation_, dual_violation_, point_;
        * the robust confidence interval (lower, upper) searched over [-2, 2];
        * the weak-IV statistic, the first covariance-rank eigenvalue, pi, var_pi;
        * the flattened entries of gamma_, then of ivreg_gamma_.stderr_;
        * idstrength_std_ and the rank-test critical value.

        ``run_summary`` indexes into this exact ordering.
    """
    np.random.seed(it)
    _, D, _, Z, X, Y = gen_data_no_controls(n, pz, px, a, b, c, d, e, f, g, sm=sm)
    est = ProximalDE(cv=3, semi=True, ivreg_type='adv', n_jobs=1, random_state=3, verbose=0)
    est.fit(None, D, Z, X, Y)

    # Diagnostics are queried in the same order as before.
    robust_lo, robust_hi = est.robust_conf_int(lb=-2, ub=2)
    weak_stat, _, _, _, pi, var_pi = est.weakiv_test(return_pi_and_var=True)
    eigvals, rank_crit = est.covariance_rank_test(calculate_critical=True)
    top_eig = eigvals[0]

    head = (est.stderr_, est.idstrength_, est.primal_violation_, est.dual_violation_,
            est.point_, robust_lo, robust_hi, weak_stat, top_eig, pi, var_pi)
    gamma_part = tuple(est.gamma_.flatten())
    gamma_se_part = tuple(est.ivreg_gamma_.stderr_.flatten())
    tail = (est.idstrength_std_, rank_crit)
    return head + gamma_part + gamma_se_part + tail

def run_summary(n, pz, px, a, b, c, d, e, f, g, sm, n_reps=100):
    """Simulate ``n_reps`` replicates in parallel and summarize the estimator and its tests.

    Prints the following statistics for the direct effect ``c``:

    * mean estimate, bias and RMSE;
    * coverage of the Wald interval point +/- 1.96 * stderr;
    * coverage of the ID-robust interval.

    It then draws a 3x2 grid of histograms: stderr, id strength divided by its
    std, primal violation, dual violation, weak-IV statistic, and first
    eigenvalue divided by the rank-test critical value. Each panel title shows
    the mean and the share of replicates that fail the corresponding test.

    Parameters
    ----------
    n, pz, px, a, b, c, d, e, f, g, sm :
        Forwarded unchanged to ``exp_summary``. ``c`` is also the true direct
        effect used for bias and coverage.
    n_reps : int, default 100
        Number of Monte Carlo replicates (seeds ``0 .. n_reps - 1``).
    """
    res = np.array(Parallel(n_jobs=-1, verbose=3)(
        delayed(exp_summary)(it, n, pz, px, a, b, c, d, e, f, g, sm)
        for it in range(n_reps)))

    # Column layout of exp_summary's output: 11 leading scalars, then pz
    # entries of gamma_, pz entries of its stderr, idstrength_std_, eig_crit.
    stderr, point = res[:, 0], res[:, 4]
    lb, ub = res[:, 5], res[:, 6]
    idstrength_std_col = 11 + 2 * pz
    eig_crit_col = idstrength_std_col + 1

    print(f"Mean estimate: {np.mean(point):.3f}")
    print(f"Bias: {np.mean(point - c):.3f}")
    print(f"RMSE: {np.sqrt(np.mean((point - c)**2)):.3f}")
    cov = np.mean((point + 1.96 * stderr >= c) & (point - 1.96 * stderr <= c))
    print(f"Coverage: {cov:.3f}")
    rcov = np.mean((lb <= c) & (ub >= c))
    print(f"ID-Robust Coverage: {rcov:.3f}")

    # Normalized test statistics and their (rounded, for display) critical values.
    id_ratio = res[:, 1] / res[:, idstrength_std_col]
    eig_stat = res[:, 8] / res[:, eig_crit_col]
    id_crit = np.round(scipy.stats.foldnorm(c=10).ppf(.95), 2)
    primal_crit = np.round(scipy.stats.chi2(df=pz + 1).ppf(.95), 2)
    dual_crit = np.round(scipy.stats.chi2(df=px).ppf(.95), 2)

    plt.figure(figsize=(10, 10))
    plt.subplot(3, 2, 1)
    plt.title(f"stderr: mean={np.mean(stderr):.3f}\n%>1.0={np.mean(stderr > 1.0):.3f}")
    plt.hist(stderr)
    plt.subplot(3, 2, 2)
    plt.title(f"idstrength: mean={np.mean(id_ratio):.3f}\n%Fail (<{id_crit})={np.mean(id_ratio < id_crit):.3f}")
    plt.hist(id_ratio)
    plt.subplot(3, 2, 3)
    plt.title(f"primal_violation: mean={np.mean(res[:, 2]):.3f}\n%Fail (>{primal_crit})={np.mean(res[:, 2] > primal_crit):.3f}")
    plt.hist(res[:, 2])
    plt.subplot(3, 2, 4)
    plt.title(f"dual_violation: mean={np.mean(res[:, 3]):.3f}\n%Fail (>{dual_crit})={np.mean(res[:, 3] > dual_crit):.3f}")
    plt.hist(res[:, 3])
    plt.subplot(3, 2, 5)
    plt.title(f"weakiv_test: mean={np.mean(res[:, 7]):.3f}\n%Fail (<23)={np.mean(res[:, 7] < 23):.3f}")
    plt.hist(res[:, 7])
    plt.subplot(3, 2, 6)
    plt.title(f"rank_test: mean={np.mean(eig_stat):.3f}\n%Fail (<1)={np.mean(eig_stat < 1):.3f}")
    plt.hist(eig_stat)
    plt.tight_layout()
    plt.show()


```
a : strength of D -> M edge
b : strength of M -> Y edge
c : strength of D -> Y edge
d : strength of D -> Z edge
e : strength of M -> Z edge
f : strength of M -> X edge
g : strength of X -> Y edge
```

## No Failure Mode

In [None]:
np.random.seed(1)
# Sample size and proxy dimensions are set here so the cell runs on a fresh
# kernel without relying on a later cell; the values match the failure-mode runs.
n = 10000
px, pz = 5, 5
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# D has direct relationship to Z, Z has direct relationship to M,
# X has direct relationship to M, X has direct relationship to Y (but doesn't matter w/ validity of result)
d, e, f, g = 1.0, 1.0, 1.0, 1.0
sm = 2.0
run_summary(n, pz, px, a, b, c, d, e, f, g, sm)

## Failure mode 1: Z is a bad proxy (no M->Z)

Z has a direct relationship to D but no direct relationship to M. So Z is not a good proxy treatment. However, X is a good proxy outcome and this can allow us to detect the failure mode. 

Here, because Z is correlated with D, the dual assumption is not violated, and the primal test should pass because X is a good proxy outcome. However, the solution essentially yields an orthogonal instrument V = D - gamma Z that perfectly predicts the treatment D (since Z is driven only by D), so the identification strength assumption is violated. We therefore expect the idstrength test to fail.

In [None]:
np.random.seed(123)
n = 10000
pz, px = 5, 5
sm = 2.
# Direct effect c; indirect (mediated) effect a * b.
a, b, c = 1.0, 1.0, .5
# Active edges: D -> Z (d=1) and M -> X (f=1).
# Absent edges: M -> Z (e=0) and X -> Y (g=0), so Z reflects D only.
d, e, f, g = 1, 0, 1, 0
run_summary(n, pz, px, a, b, c, d, e, f, g, sm=sm)

## Failure mode 2: Z is a bad proxy (no M->Z or D->Z)

Z has no direct relationship to D or M. So Z is not a good proxy treatment. However, X is a good proxy outcome and this can allow us to detect the failure mode. 

Unlike failure mode 1, Z is uncorrelated with D, which violates the dual assumption. Because the dual has no solution, the weakIV test is invalid: its covariance calculation depends on the validity of the dual moment. We also expect the covariance rank test to fail, as there is no relationship between Z and X.

In [None]:
np.random.seed(123)
sm = 2.0
# Direct effect c; indirect (mediated) effect a * b.
a, b, c = 1.0, 1.0, .5
# Only the M -> X edge is active (f=1). Z receives nothing from D (d=0) or
# M (e=0), and X has no direct edge to Y (g=0).
d, e, f, g = 0.0, 0.0, 1.0, 0.0
# n, pz, px reuse the globals set by an earlier cell.
run_summary(n, pz, px, a, b, c, d, e, f, g, sm=sm)

## Failure Mode 3: X is a bad proxy (no M->X or X->Y)

Z is a good proxy, but X is not: it is unrelated to M. This violates the assumption that the primal IV problem has a solution. We also expect the covariance rank test to fail, as there is no relationship between Z and X.

In [None]:
np.random.seed(123)
sm = 2.0
# Direct effect c; indirect (mediated) effect a * b.
a, b, c = 1.0, 1.0, .5
# Only the M -> Z edge is active (e=1). There is no D -> Z edge (d=0), X
# receives nothing from M (f=0), and X has no direct edge to Y (g=0).
d, e, f, g = 0.0, 1.0, 0.0, 0.0
# n, pz, px reuse the globals set by an earlier cell.
run_summary(n, pz, px, a, b, c, d, e, f, g, sm=sm)

## Failure mode 4: X and Z are completely unrelated


Z and X are unrelated to all other variables (both are bad proxies). We expect the covariance rank test to catch this failure mode.

In [None]:
np.random.seed(123)
sm = 2.0
# Direct effect c; indirect (mediated) effect a * b.
a, b, c = 1.0, 1.0, .5
# Every proxy edge is switched off: Z gets nothing from D or M, and X gets
# nothing from M. The X -> Y edge is off too, but turning it on would not
# change the outcome; the rank test still catches this failure.
d, e, f, g = 0.0, 0.0, 0.0, 0.0
# n, pz, px reuse the globals set by an earlier cell.
run_summary(n, pz, px, a, b, c, d, e, f, g, sm=sm)

## Failure Mode 5: M ~= D.

X and Z are good proxies, but the mediator is highly correlated with the treatment, which prevents identification of the direct effect. The weakIV and idstrength tests will catch this case.

In [None]:
np.random.seed(123)
# Tiny mediator noise scale, so M is nearly determined by D (the "M ~= D" regime).
sm = 0.05
# Direct effect c; indirect (mediated) effect a * b.
a, b, c = 1.0, 1.0, .5
# All proxy edges are active: D -> Z, M -> Z, M -> X and X -> Y.
d, e, f, g = 1.0, 1.0, 1.0, 1.0
# n, pz, px reuse the globals set by an earlier cell.
run_summary(n, pz, px, a, b, c, d, e, f, g, sm=sm)