## Setup Only for Colab

In [None]:
# prompt: mount drive

from google.colab import drive
drive.mount('/content/drive')

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks/hidden_mediators

In [None]:
%ls

In [None]:
from IPython.display import clear_output

In [None]:
import time
!pip install -r requirements.txt
time.sleep(2)
clear_output()

In [None]:
import time
# replace `develop` with `install` if you won't make library code changes
!python setup.py develop
time.sleep(2)
clear_output()
# Restart the session after running this

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks

# Main Logic

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
from joblib import Parallel, delayed
from proximalde.gen_data import gen_data_complex, gen_data_no_controls, gen_data_no_controls_discrete_m
from proximalde.proximal import proximal_direct_effect, ProximalDE, residualizeW
from sklearn.linear_model import LinearRegression
from proximalde.crossfit import fit_predict

# Running a Single Experiment

In [None]:
a = 1.0  # a*b is the indirect effect through mediator
b = 1.0
c = .5  # this is the direct effect we want to estimate
d = .6  # this can be zero; does not hurt
e = .7  # if the product of e*f is small, then we have a weak instrument
f = .5  # if the product of e*f is small, then we have a weak instrument
g = .9  # this can be zero; does not hurt

In [None]:
n = 100000
pw = 100
pz, px = 2, 2

In [None]:
np.random.seed(2)
W, D, _, Z, X, Y = gen_data_complex(n, pw, pz, px, a, b, c, d, e, f, g)

## for no controls un-comment this
# _, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g)
# W = None

## for multi-dimensional mediator uncomment this
# pm = 5
# full_rank = False
# while not full_rank:
#     E = np.random.normal(0, 2, (pm, pz))
#     F = np.random.normal(0, 2, (pm, px))
#     if (np.linalg.matrix_rank(E, tol=0.5) == pm) and (np.linalg.matrix_rank(F, tol=0.5) == pm):
#         full_rank = True
# W, D, _, Z, X, Y = gen_data_no_controls_discrete_m(n, pw, pz, px, a, b, c, d, e*E, f*F, g, pm=pm)

### Using the ProximalDE Estimator Class

In [None]:
est = ProximalDE(cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=3)
est.fit(W, D, Z, X, Y)

In [None]:
est.summary(decimals=5)

In [None]:
# tests can also be accessed individually
display(est.weakiv_test(alpha=0.05))
display(est.idstrength_violation_test(alpha=0.05))
display(est.primal_violation_test(alpha=0.05))
display(est.dual_violation_test(alpha=0.05))

#### Covariance Rank Diagnostic for Covariance of Proxies

In [None]:
svalues, svalues_crit = est.covariance_rank_test(calculate_critical=True)

In [None]:
plt.title(f"Number of singular values above threshold: {np.sum(svalues >= svalues_crit)}. "
          f"Threshold={svalues_crit:.3f}. Top singular value={svalues[0]:.3f}")
plt.scatter(np.arange(len(svalues)), svalues)
plt.axhline(svalues_crit)
plt.show()

#### Confidence Intervals and Robust Confidence Intervals

In [None]:
est.conf_int(alpha=.05) # 95% confidence interval

In [None]:
# 95% confidence interval, robust to weak identification
est.robust_conf_int(alpha=0.05, lb=.1, ub=1.0, ngrid=1000)

#### Unusual Data Diagnostics

In [None]:
diag = est.run_diagnostics()

In [None]:
inds = est.influential_set(alpha=0.05)
len(inds)  # size of influential set that can flip the result

In [None]:
from sklearn.base import clone

# Re-train a clone of the estimator on all the data except the influential
# set, to see whether dropping those samples changes the conclusions.
# W is None in the "no controls" variant of the data-generation cell above;
# np.delete cannot drop rows from None, so pass None through unchanged.
W_kept = None if W is None else np.delete(W, inds, axis=0)
est2 = clone(est)
est2.fit(W_kept, np.delete(D, inds, axis=0),
         np.delete(Z, inds, axis=0), np.delete(X, inds, axis=0),
         np.delete(Y, inds, axis=0))
est2.summary(alpha=0.05)

In [None]:
diag.cookd_plot()
plt.show()

In [None]:
diag.l2influence_plot()
plt.show()

In [None]:
diag.influence_plot(influence_measure='cook', npoints=10)
plt.show()

In [None]:
diag.influence_plot(influence_measure='l2influence', npoints=10)
plt.show()

### Subsample-Based Inference

In [None]:
inf = est.bootstrap_inference(stage=3, n_subsamples=1000, fraction=0.5, replace=False, verbose=3, random_state=123)
inf.summary()

In [None]:
plt.hist(inf.point_dist)
plt.axvline(inf.point, color='r')
plt.show()

In [None]:
inf = est.bootstrap_inference(stage=2, n_subsamples=100, fraction=0.5, replace=False, verbose=3, random_state=123)
inf.summary()

In [None]:
plt.hist(inf.point_dist)
plt.axvline(inf.point, color='r')
plt.show()

In [None]:
inf = est.bootstrap_inference(stage=1, n_subsamples=10, fraction=0.5, replace=False, verbose=3, random_state=123)
inf.summary()

In [None]:
plt.hist(inf.point_dist)
# axvline spans the whole axis height, like the stage-2 and stage-3 plots
# above. A hard-coded 0..300 segment stretches the y-axis far beyond the
# histogram bars when only a handful of subsamples are drawn.
plt.axvline(inf.point, color='r')
plt.show()

In [None]:
inf.summary(pivot=True)

# Quality of Procedure and Diagnostics Across Many Experiments

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
from proximalde.gen_data import gen_data_complex, gen_data_no_controls, gen_data_no_controls_discrete_m
from proximalde.proximal import ProximalDE

In [None]:
def exp_res(it, n, pw, pm, pz, px, a, b, c, d, e, f, g, sm, *,
            dual_type='Z', ivreg_type='adv', n_splits=3, semi=True,
            multitask=False, n_jobs=-1, verbose=0):
    """Run one synthetic experiment: generate data, fit ProximalDE, collect results.

    The global numpy seed is set to `it`, so each experiment index gives its
    own reproducible dataset; `it` is also the estimator's random_state.

    it : experiment index / seed
    n : number of samples
    pw, pm, pz, px : dimensions of controls, mediator, Z proxies and X proxies
    a, b, c, d, e, f, g : edge strengths passed to the data generators
    sm : mediator noise strength (only forwarded to the pm == 1 generators)
    dual_type, ivreg_type, semi, multitask, n_jobs, verbose : forwarded to ProximalDE
    n_splits : forwarded to ProximalDE as `cv`

    Returns an 18-tuple: point, stderr, the four nuisance R^2 attributes
    (D, Z, X, Y), id-strength statistic and critical value, the pre-debiasing
    point and stderr, primal statistic/critical value, dual statistic/critical
    value, weak-IV statistic/critical value, and the robust confidence
    interval bounds.
    """
    np.random.seed(it)

    if pm > 1:
        # Multi-dimensional mediator: redraw the loading matrices until both
        # have rank pm (at tolerance 0.5).
        while True:
            E = np.random.normal(0, 2, (pm, pz))
            F = np.random.normal(0, 2, (pm, px))
            if (np.linalg.matrix_rank(E, tol=0.5) == pm) and (np.linalg.matrix_rank(F, tol=0.5) == pm):
                break
        W, D, _, Z, X, Y = gen_data_no_controls_discrete_m(n, pw, pz, px, a, b, c, d, e*E, f*F, g, pm=pm)
        if pw == 0:
            W = None
    elif pw > 0:
        # M is unobserved so we omit it from the return variables
        W, D, _, Z, X, Y = gen_data_complex(n, pw, pz, px, a, b, c, d, e, f, g, sm=sm)
    else:
        # no controls: discard whatever the generator returns in the W slot
        _, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g, sm=sm)
        W = None

    est = ProximalDE(cv=n_splits, semi=semi,
                     dual_type=dual_type, ivreg_type=ivreg_type,
                     multitask=multitask, n_jobs=n_jobs,
                     random_state=it, verbose=verbose)
    est.fit(W, D, Z, X, Y)

    # Each test returns four values; only the first (statistic) and the last
    # (critical value) are kept.
    wiv_value, _, _, wiv_threshold = est.weakiv_test(alpha=0.05)
    ids_value, _, _, ids_threshold = est.idstrength_violation_test(alpha=0.05)
    primal_value, _, _, primal_threshold = est.primal_violation_test(alpha=0.05)
    dual_value, _, _, dual_threshold = est.dual_violation_test(alpha=0.05)
    robust_lb, robust_ub = est.robust_conf_int(lb=-2, ub=2)

    return (est.point_, est.stderr_, est.r2D_, est.r2Z_, est.r2X_, est.r2Y_,
            ids_value, ids_threshold, est.point_pre_, est.stderr_pre_,
            primal_value, primal_threshold, dual_value, dual_threshold,
            wiv_value, wiv_threshold,
            robust_lb, robust_ub)

```
a : strength of D -> M edge
b : strength of M -> Y edge
c : strength of D -> Y edge
d : strength of D -> Z edge
e : strength of M -> Z edge
f : strength of M -> X edge
g : strength of X -> Y edge
```

In [None]:
a = 1.0  # a*b is the indirect effect through mediator
b = 1.0
c = .5  # this is the direct effect we want to estimate
d = .5  # this can be zero; does not hurt
e = .5  # if the product of e*f is small, then we have a weak instrument
f = .5  # if the product of e*f is small, then we have a weak instrument
g = .5  # this can be zero; does not hurt
sm = 2.0  # strength of mediator noise; needs to be non-zero for identifiability; only used when pm=1.
n = 50000
pw = 0
pm = 1
pz, px = 20, 10

results = Parallel(n_jobs=-1, verbose=3)(delayed(exp_res)(i, n, pw, pm, pz, px, a, b, c, d, e, f, g, sm,
                                                          dual_type='Z', ivreg_type='adv',
                                                          n_splits=3, semi=True, n_jobs=1)
                                          for i in range(100))

#### Summarize

In [None]:
# Transpose the list of per-experiment 18-tuples into one numpy array per
# statistic. NOTE: rmseD/rmseZ/rmseX/rmseY actually hold the nuisance R^2
# values returned by exp_res (est.r2D_, ...), not RMSEs.
points_base, stderrs_base, rmseD, rmseZ, rmseX, rmseY, \
    idstr, idstr_crit, points_alt, stderrs_alt, \
    pval, pval_crit, dval, dval_crit, wiv_stat, wiv_crit, \
    rlb, rub = map(np.array, zip(*results))

print("Estimation Quality")
for name, points, stderrs in [('Debiased', points_base, stderrs_base), ('Regularized', points_alt, stderrs_alt)]:
    print(f"\n{name} Estimate")
    # fraction of experiments whose 95% normal-based interval contains the truth c
    coverage = np.mean((points + 1.96 * stderrs >= c) & (points - 1.96 * stderrs <= c))
    rmse = np.sqrt(np.mean((points - c)**2))
    bias = np.abs(np.mean(points) - c)
    std = np.std(points)
    mean_stderr = np.mean(stderrs)
    mean_length = np.mean(2 * 1.96 * stderrs)
    median_length = np.median(2 * 1.96 * stderrs)
    print(f"Coverage: {coverage:.3f}")
    print(f"RMSE: {rmse:.3f}")
    print(f"Bias: {bias:.3f}")
    print(f"Std: {std:.3f}")
    print(f"Mean CI length: {mean_length:.3f}")
    # previously this line printed mean_length a second time
    print(f"Median CI length: {median_length:.3f}")
    print(f"Mean Estimated Stderr: {mean_stderr:.3f}")
    print(f"Nuisance R^2 (D, Z, X, Y): {np.mean(rmseD):.3f}, {np.mean(rmseZ):.3f}, {np.mean(rmseX):.3f}, {np.mean(rmseY):.3f}")

print("\nRobust ConfInt Coverage")
rcoverage = np.mean((rub >= c) & (rlb <= c))
print(f"Robust Coverage: {rcoverage:.3f}")

print("\nViolations")
# strength-type tests: a statistic at or below the critical value counts as a violation
for name, stat, crit in [('Id-Strenth', idstr, idstr_crit), ('WeakIV F-test', wiv_stat, wiv_crit)]:
    violation = np.mean(stat <= crit)
    print(f"% Violations of {name}: {violation:.3f}")
# existence tests: a statistic at or above the critical value counts as a violation
for name, stat, crit in [('Primal Existence', pval, pval_crit), ('Dual Existence', dval, dval_crit)]:
    violation = np.mean(stat >= crit)
    print(f"% Violations of {name}: {violation:.3f}")

In [None]:
import scipy.stats
plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.title(f"{np.mean(dval > scipy.stats.chi2(df=px).ppf(.95))} vs 0.05, {np.mean(dval, axis=0):.3f} vs {px}, {np.var(dval, axis=0):.3f} vs {2*px}")
plt.hist(dval)
plt.axvline(scipy.stats.chi2(df=px).ppf(.95), color='r')
plt.subplot(1, 2, 2)
plt.title(f"{np.mean(pval > scipy.stats.chi2(df=pz + 1).ppf(.95))} vs 0.05, "
          f"{np.mean(pval, axis=0):.3f} vs {pz + 1}, {np.var(pval, axis=0):.3f} vs {2*(pz + 1)}")
plt.hist(pval)
plt.axvline(scipy.stats.chi2(df=pz + 1).ppf(.95), color='r')
plt.show()

In [None]:
from statsmodels.graphics.gofplots import qqplot
import scipy.stats
plt.figure(figsize=(15, 5))
ax = plt.subplot(1, 2, 1)
qqplot(np.array(dval), dist=scipy.stats.chi2(df=px), line='45', ax=ax)
ax = plt.subplot(1, 2, 2)
qqplot(np.array(pval), dist=scipy.stats.chi2(df=pz+1), line='45', ax=ax)
plt.show()

In [None]:
import scipy.stats
plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.title(f"{np.mean(idstr / idstr_crit > 1)} vs. 0.05, {np.mean(idstr / idstr_crit, axis=0):.3f}")
plt.hist(idstr)
plt.axvline(np.mean(idstr_crit), color='r')
plt.subplot(1, 2, 2)
plt.title(f"{np.mean(wiv_stat):.3f}, {np.mean(wiv_stat / wiv_crit, axis=0):.3f}")
plt.hist(wiv_stat)
plt.axvline(np.mean(wiv_crit), color='r')
plt.show()

In [None]:
plt.hist(points_base, label='Distribution of Estimates: debiased')
plt.hist(points_alt, label='Distribution of Estimates: original', alpha=.3)
plt.vlines([c], 0, plt.ylim()[1], color='red', label='truth')
plt.legend()
plt.show()

# Mediations that Trigger Violations of Both Assumptions

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
from joblib import Parallel, delayed
from proximalde.proximal import ProximalDE

In [None]:
def gen_data_with_mediator_violations(n, pw, pz, px, a, b, c, d, e, f, g, *, sm=2, sz=1, sx=1, sy=1):
    ''' Synthetic data with two extra mediation paths that break the
    assumptions of the method:
        D -> Mp -> X   (can trigger a violation of the dual test)
        Z -> Mpp -> Y  (can trigger a violation of the primal test)
    The controls W are drawn independently of everything else.

    n: number of samples
    pw: dimension of controls
    pz: dimension of treatment proxies ("instruments")
    px: dimension of outcome proxies ("treatments")
    a : strength of D -> M edge (also used for the D -> Mp edge)
    b : strength of M -> Y edge (also used for the Mpp -> Y edge)
    c : strength of D -> Y edge
    d : strength of D -> Z edge
    e : strength of M -> Z edge
    f : strength of M -> X edge (also used for the Mp -> X edge)
    g : strength of X -> Y edge (only the first column of X enters Y)
    sm, sz, sx, sy : scale of the standard normal noise added to the
        mediators (M, Mp, Mpp), to Z, to the first column of X, and to Y

    Returns (W, D, M, Z, X, Y). Draws come from the global numpy random
    state, in the order they appear below.
    '''
    W = np.random.normal(0, 1, size=(n, pw))
    D = np.random.binomial(1, .5 * np.ones(n,))
    M = a * D + sm * np.random.normal(0, 1, (n,))
    Mp = a * D + sm * np.random.normal(0, 1, (n,))

    # every column of Z shares the same signal and gets its own noise
    z_signal = (e * M + d * D).reshape(-1, 1)
    Z = z_signal + sz * np.random.normal(0, 1, (n, pz))

    # first column of X is a noisy proxy of M; all remaining columns are
    # noiseless copies of f * Mp
    x_first = f * M + sx * np.random.normal(0, 1, (n))
    X = np.zeros((n, px))
    X[:, 0] = x_first
    X[:, 1:] = f * Mp.reshape(-1, 1)

    # second violating mediator, driven by the first column of Z
    Mpp = Z[:, 0] + sm * np.random.normal(0, 1, (n,))
    y_noise = sy * np.random.normal(0, 1, n)
    Y = b * M + b * Mpp + c * D + g * X[:, 0] + y_noise
    return W, D, M, Z, X, Y

In [None]:
a = 1.0  # a*b is the indirect effect through mediator
b = 1.0
c = .5  # this is the direct effect we want to estimate
d = .6  # this can be zero; does not hurt
e = .7  # if the product of e*f is small, then we have a weak instrument
f = .5  # if the product of e*f is small, then we have a weak instrument
g = .9  # this can be zero; does not hurt

n = 100000
pw = 100
pz, px = 2, 2

In [None]:
# Seed the global numpy RNG so that this section is reproducible on a fresh
# kernel; the explanation further down quotes specific covariance values.
np.random.seed(123)
W, D, _, Z, X, Y = gen_data_with_mediator_violations(n, pw, pz, px, a, b, c, d, e, f, g)
# the generated controls are irrelevant to the rest of the data, so drop them
W = None

In [None]:
est = ProximalDE(cv=3, semi=True,
                 multitask=False, n_jobs=-1, random_state=3, verbose=3)
est.fit(W, D, Z, X, Y)
est.summary()

## Explanation

For the primal moment to hold, we essentially need that:
\begin{equation}
\text{Cov}(Y, DZ) \in \text{column-span}(\text{Cov}(DZ, DX))
\end{equation}
where $DZ, DX$ are the concatenations of $D$ with $Z$ and $X$, respectively. Roughly, this should be satisfied if:
\begin{equation}
\text{Cov}(Y, Z) \in \text{column-span}(\text{Cov}(Z, X))
\end{equation}


For the dual moment to hold, we essentially need that:
\begin{equation}
\text{Cov}(D, X) \in \text{column-span}(\text{Cov}(X, Z)) = \text{row-span}(\text{Cov}(Z, X))
\end{equation}

Let's verify that this is indeed not the case

In [None]:
# Center every variable (column-wise for Z and X) so that the plain
# cross-products divided by n in the next cell are sample covariances.
Z = Z - np.mean(Z, axis=0)
X = X - np.mean(X, axis=0)
D = D - np.mean(D, axis=0)
Y = Y - np.mean(Y, axis=0)
# Turn D and Y into explicit column vectors so that X.T @ D and Z.T @ Y
# are well-defined matrix products.
D = D.reshape(-1, 1)
Y = Y.reshape(-1, 1)

Let's calculate the three relevant covariances:

In [None]:
CovZX = Z.T @ X / n
CovXD = X.T @ D / n
CovYZ = Z.T @ Y / n

Let's investigate the condition for the existence of the dual: we need $\text{Cov}(X, D)$ to be in the row span of $\text{Cov}(Z, X)$ or, equivalently, in the column span of $\text{Cov}(X, Z)$. We perform a singular value decomposition and keep only the statistically significant non-zero singular values. This can be done by using the critical value computed by the "covariance_rank_test".

In [None]:
_, Scrit = est.covariance_rank_test(calculate_critical=True)
Scrit

In [None]:
U, S, Vh = np.linalg.svd(CovZX, full_matrices=False)

In [None]:
# row span of CovZX, with stat-sig non-zero singular values, vs CovXD.
# np.linalg.svd returns Vh with the right singular vectors as its ROWS, so
# the significant ones are selected by row and then transposed, so that each
# printed column is one basis vector (same layout as U for the column span).
print("Basis of row span of CovZX:\n", Vh[S > Scrit, :].T, "\n",
      "Vector CovXD:\n", CovXD)

We see that $CovXD\approx (0.12, 0.12)$, while the row span of CovZX is the subspace spanned by approximately the single vector $(-1, 0)$, i.e. multiples of this single vector. So obviously, the first vector is not in that subspace.

---



Let's examine the primal existence:

In [None]:
# column span of CovZX, with stat-sig non-zero singular values, vs CovYZ
# (the columns of U are the left singular vectors of CovZX)
print("Basis of column span of CovZX:\n", U[:, S > Scrit], "\n",
      "Vector CovYZ:\n", CovYZ)

We see that $CovYZ\approx (8, 7)$, while the column span of CovZX is the subspace spanned by the single vector $(-.7, -.7)$, i.e. multiples of this single vector. So obviously, the first vector is not in that subspace.



# Semi-Synthetic Data Generation (in-progress)



In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
from joblib import Parallel, delayed
from proximalde.proximal import ProximalDE, residualizeW

Suppose we are given some real-world dataset $(W, D, Z, X, Y)$

In [None]:
def gen_data_with_mediator_violations(n, pw, pz, px, a, b, c, d, e, f, g, *, sm=2, sz=1, sx=1, sy=1):
    ''' Synthetic data with two extra mediation paths that break the
    assumptions of the method:
        D -> Mp -> X   (can trigger a violation of the dual test)
        Z -> Mpp -> Y  (can trigger a violation of the primal test)
    The controls W are drawn independently of everything else.

    n: number of samples
    pw: dimension of controls
    pz: dimension of treatment proxies ("instruments")
    px: dimension of outcome proxies ("treatments")
    a : strength of D -> M edge (also used for the D -> Mp edge)
    b : strength of M -> Y edge (also used for the Mpp -> Y edge)
    c : strength of D -> Y edge
    d : strength of D -> Z edge
    e : strength of M -> Z edge
    f : strength of M -> X edge (also used for the Mp -> X edge)
    g : strength of X -> Y edge (only the first column of X enters Y)
    sm, sz, sx, sy : scale of the standard normal noise added to the
        mediators (M, Mp, Mpp), to Z, to the first column of X, and to Y

    Returns (W, D, M, Z, X, Y). Draws come from the global numpy random
    state, in the order they appear below.
    '''
    W = np.random.normal(0, 1, size=(n, pw))
    D = np.random.binomial(1, .5 * np.ones(n,))
    M = a * D + sm * np.random.normal(0, 1, (n,))
    Mp = a * D + sm * np.random.normal(0, 1, (n,))

    # every column of Z shares the same signal and gets its own noise
    z_signal = (e * M + d * D).reshape(-1, 1)
    Z = z_signal + sz * np.random.normal(0, 1, (n, pz))

    # first column of X is a noisy proxy of M; all remaining columns are
    # noiseless copies of f * Mp
    x_first = f * M + sx * np.random.normal(0, 1, (n))
    X = np.zeros((n, px))
    X[:, 0] = x_first
    X[:, 1:] = f * Mp.reshape(-1, 1)

    # second violating mediator, driven by the first column of Z
    Mpp = Z[:, 0] + sm * np.random.normal(0, 1, (n,))
    y_noise = sy * np.random.normal(0, 1, n)
    Y = b * M + b * Mpp + c * D + g * X[:, 0] + y_noise
    return W, D, M, Z, X, Y

In [None]:
a = 1.0  # a*b is the indirect effect through mediator
b = 1.0
c = .5  # this is the direct effect we want to estimate
d = .6  # this can be zero; does not hurt
e = .7  # if the product of e*f is small, then we have a weak instrument
f = .5  # if the product of e*f is small, then we have a weak instrument
g = .9  # this can be zero; does not hurt

n = 100000
pw = 100
pz, px = 5, 3

In [None]:
np.random.seed(124)
W, D, _, Z, X, Y = gen_data_with_mediator_violations(n, pw, pz, px, a, b, c, d, e, f, g)
W = None

In [None]:
# as we said this dgp violates the assumptions so the test will fail
est = ProximalDE(cv=3, semi=True,
                #  alpha_multipliers=np.array([1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1]),
                 multitask=False, n_jobs=-1, random_state=3, verbose=3)
est.fit(W, D, Z, X, Y)
est.summary()

## Semi-Synthetic Generation Process
We will create a semi-synthetic DGP as follows. We first find the top component of the covariance of $(Z - E[Z|W], X - E[X|W])$, by running a singular value decomposition.

We then have a direction $u_z$ and a direction $v_x$ that is associated with this principal component. We will assume that the mediator basically alters $Z$ and $X$ along this top component.

So if the value of the mediator is $\tilde{M}$, then we set:
\begin{align}
\tilde{Z} =& Z + \tilde{M} \cdot u_z\\
\tilde{X} =& X + \tilde{M} \cdot v_x
\end{align}
for each sample in the original dataset. Now we have samples of $\tilde{X}, \tilde{Z}$ whose principal component is driven by the synthetic mediator $\tilde{M}$.

Moreover, we generate the mediator as follows. We learn a propensity model $E[D|W]$ and we generate a treatment $D$ by sampling from the propensity. Then we  set the value of the mediator to be:
\begin{equation}
\tilde{M} = f_M(D, \epsilon_M)
\end{equation}
so that the mediator is affected by the treatment. Then we also impute outcomes:
\begin{equation}
\tilde{Y} = f_Y(\tilde{M}, D, \tilde{X}, \epsilon_Y)
\end{equation}
The reason why we want to resample $D$ is to break any violating mediation paths $D \to M_p \to X$ that might exist in the original data; if we didn't resample the treatment, such paths would create a failure in this new dataset. We could also try not resampling the treatment: if we then get a failure of the dual violation test, this hints at an auxiliary violating mediation path $D \to M_p \to X$.

Now for every sample $(W, D, Z, X, Y)$ in the original dataset, we now have a sample $(W, \tilde{D}, \tilde{Z}, \tilde{X}, \tilde{Y})$, where $W$ is real, $\tilde{D}$ is sampled from the estimated propensity, given $W$, and $\tilde{Z}, \tilde{X}$ are slight modifications of the real $X,Z$ along only a particular direction and $\tilde{Y}$ is fully synthetic.


For simplicity, we will first choose linear structural functions $f_M$ and $f_Y$:
\begin{align}
f_M(D, \epsilon_M) =& a \cdot D + N(0, \sigma_m^2)\\
f_Y(M, D, X, \epsilon_Y) =& b M + c D + g X[:, 0] + N(0, \sigma_y^2)
\end{align}



In [None]:
sm = 2.0
sy = 1.0

In [None]:
from sklearn.model_selection import train_test_split

# The train/test arrays used from here on were not defined anywhere in this
# notebook, so a fresh "Restart & Run All" failed with a NameError.
# NOTE(review): a random 50/50 split is assumed here; adjust if the intended
# split was different. `train` and `test` are row-index arrays.
train, test = train_test_split(np.arange(len(D)), test_size=.5, random_state=123)
Wtrain, Wtest = (W[train], W[test]) if W is not None else (None, None)
Dtrain, Dtest = D[train], D[test]
Ztrain, Ztest = Z[train], Z[test]
Xtrain, Xtest = X[train], X[test]
Ytrain, Ytest = Y[train], Y[test]

# residualize the training half on the controls; only the first four returned
# values (residuals of D, Z, X, Y) are kept
Dres, Zres, Xres, Yres, *_ = residualizeW(Wtrain, Dtrain, Ztrain, Xtrain, Ytrain, semi=True)

In [None]:
# np.linalg.svd factorizes the cross-covariance as U @ diag(S) @ Vh.
# The columns of U are the left singular vectors (directions in Z space) and
# the ROWS of Vh are the right singular vectors (directions in X space).
# Hence uz = U[:, 0] and vx = Vh[0, :].
# Normalize by the number of rows actually present in the residuals rather
# than the global n, which need not match the size of the training split.
U, S, Vh = np.linalg.svd(Zres.T @ Xres / Zres.shape[0], full_matrices=False)
U, S, Vh

In [None]:
np.random.seed(123)
# draw (with replacement) half as many rows as there are in the test split
nsamples = len(test) // 2
inds = np.random.choice(len(test), size=nsamples, replace=True)
Wtilde = Wtest[inds] if Wtest is not None else None
Dtilde = Dtest[inds]
# first let's not resample the treatment
# synthetic mediator: linear in the (real) treatment plus Gaussian noise
Mtilde = a * Dtilde.flatten() + sm * np.random.normal(0, 1, (nsamples,))
# shift the real proxies along the top singular directions (first column of
# U for Z, first row of Vh for X), scaled by 2 * S[0] times the mediator
Ztilde = Ztest[inds] + 2 * S[0] * Mtilde.reshape(-1, 1) * U[:, 0].reshape(1, -1)
Xtilde = Xtest[inds] + 2 * S[0] * Mtilde.reshape(-1, 1) * Vh[0, :].reshape(1, -1)
# fully synthetic outcome; only the first column of Xtilde enters
Ytilde = b * Mtilde + c * Dtilde + g * Xtilde[:, 0] + sy * np.random.normal(0, 1, (nsamples,))

In [None]:
# we find that the dual violation still exists, causing a slight bias (the true
# value we should recover is c)
est = ProximalDE(cv=3, semi=True,
                 alpha_multipliers=np.array([1e-3, 1e-2, 1e-1, 1]),
                 multitask=False, n_jobs=-1, random_state=3, verbose=3)
est.fit(Wtilde, Dtilde, Ztilde, Xtilde, Ytilde)
est.summary()

In [None]:
# imported here because this cell runs before the cell that originally
# imported LogisticRegressionCV
from sklearn.linear_model import LogisticRegressionCV

# Propensity P(D=1 | W) for the test rows. predict_proba is required here:
# predict returns hard 0/1 labels, and feeding those to np.random.binomial
# as success probabilities would make the "resampled" treatment deterministic.
if W is not None:
    propensity = LogisticRegressionCV().fit(Wtrain, Dtrain).predict_proba(Wtest)[:, 1]
else:
    # no controls: constant propensity equal to the marginal treatment rate
    propensity = np.mean(D) * np.ones(len(test),)

In [None]:
from sklearn.linear_model import LogisticRegressionCV
# np.random.seed(189)
# now let's resample the treatment

# use every test row once, in order (a with-replacement draw that used to
# precede this line was immediately overwritten, so it has been dropped)
nsamples = len(test) # // 2
inds = np.arange(nsamples)
Wtilde = Wtest[inds] if Wtest is not None else None
# treatment is re-drawn from the estimated propensity instead of reusing Dtest
Dtilde = np.random.binomial(1, propensity[inds])
Mtilde = a * Dtilde.flatten() + sm * np.random.normal(0, 1, (nsamples,))
Ztilde = Ztest[inds] + 10 * Mtilde.reshape(-1, 1) * U[:, 0].reshape(1, -1)
Xtilde = Xtest[inds] +  10 * Mtilde.reshape(-1, 1) * Vh[0, :].reshape(1, -1)
Ytilde = b * Mtilde + c * Dtilde + g * Xtilde[:, 0] + sy * np.random.normal(0, 1, (nsamples,))

# Re-fit on the data with resampled treatment (the true value we should
# recover is c). The controls passed to fit must be the rows matching the
# synthetic sample (Wtilde), not the full original W.
est = ProximalDE(cv=3, semi=True,
                #  alpha_multipliers=np.array([1e-3, 1e-2, 1e-1, 1]),
                 multitask=False, n_jobs=-1, random_state=3, verbose=3)
est.fit(Wtilde, Dtilde, Ztilde, Xtilde, Ytilde)
est.summary()

In [None]:
def gen_semisynth_data(nsamples, a, b, c, e, f, g, *, sm=2.0, sy=1.0):
    """Draw one semi-synthetic dataset from the held-out (test) rows.

    Reads the notebook-level variables `test`, `Wtest`, `Ztest`, `Xtest`,
    `propensity`, `U`, `S` and `Vh`, and uses the global numpy random state.

    nsamples : number of rows to draw (with replacement) from the test split
    a : strength of D -> M edge
    b : strength of M -> Y edge
    c : strength of D -> Y edge
    e, f : accepted for signature parity with the synthetic generators;
        not used in the body
    g : strength of X -> Y edge (only the first column of X enters Y)
    sm, sy : scale of the normal noise in the mediator and in the outcome

    Returns (Wtilde, Dtilde, Mtilde, Ztilde, Xtilde, Ytilde), where Wtilde
    is None when there are no controls.
    """
    assert nsamples < len(test), "nsamples must be smaller than n"

    inds = np.random.choice(len(test), size=nsamples, replace=True)
    Wtilde = Wtest[inds] if Wtest is not None else None
    # treatment resampled from the estimated propensity of each drawn row
    Dtilde = np.random.binomial(1, propensity[inds])
    Mtilde = a * Dtilde.flatten() + sm * np.random.normal(0, 1, (nsamples,))
    # shift the real proxies along the top singular directions of their
    # cross-covariance, proportionally to the synthetic mediator
    Ztilde = Ztest[inds] + 2 * S[0] * Mtilde.reshape(-1, 1) * U[:, 0].reshape(1, -1)
    Xtilde = Xtest[inds] + 2 * S[0] * Mtilde.reshape(-1, 1) * Vh[0, :].reshape(1, -1)
    Ytilde = b * Mtilde + c * Dtilde + g * Xtilde[:, 0] + sy * np.random.normal(0, 1, (nsamples,))
    # Return the resampled controls. The global W used to be returned here,
    # whose rows do not correspond to the drawn sample.
    return Wtilde, Dtilde, Mtilde, Ztilde, Xtilde, Ytilde


def exp_res(it, n, a, b, c, e, f, g, *, sm=2.0, sy=1.0,
            dual_type='Z', ivreg_type='adv', n_splits=3, semi=True,
            multitask=False, n_jobs=-1, verbose=0):
    """Run one semi-synthetic experiment: draw data, fit ProximalDE, collect results.

    The global numpy seed is set to `it` before the data is drawn, and `it`
    is also the estimator's random_state.

    it : experiment index / seed
    n : number of samples, forwarded to gen_semisynth_data
    a, b, c, e, f, g, sm, sy : forwarded to gen_semisynth_data
    dual_type, ivreg_type, semi, multitask, n_jobs, verbose : forwarded to ProximalDE
    n_splits : forwarded to ProximalDE as `cv`

    Returns an 18-tuple: point, stderr, the four nuisance R^2 attributes
    (D, Z, X, Y), id-strength statistic and critical value, the pre-debiasing
    point and stderr, primal statistic/critical value, dual statistic/critical
    value, weak-IV statistic/critical value, and the robust confidence
    interval bounds.
    """
    np.random.seed(it)

    # the mediator (third return value) is unobserved, so it is discarded
    W, D, _, Z, X, Y = gen_semisynth_data(n, a, b, c, e, f, g, sm=sm, sy=sy)

    penalty_grid = np.array([1e-3, 1e-2, 1e-1, 1])
    est = ProximalDE(cv=n_splits, semi=semi,
                     alpha_multipliers=penalty_grid,
                     dual_type=dual_type, ivreg_type=ivreg_type,
                     multitask=multitask, n_jobs=n_jobs,
                     random_state=it, verbose=verbose)
    est.fit(W, D, Z, X, Y)

    # Each test returns four values; only the first (statistic) and the last
    # (critical value) are kept.
    wiv_value, _, _, wiv_threshold = est.weakiv_test(alpha=0.05)
    ids_value, _, _, ids_threshold = est.idstrength_violation_test(alpha=0.05)
    primal_value, _, _, primal_threshold = est.primal_violation_test(alpha=0.05)
    dual_value, _, _, dual_threshold = est.dual_violation_test(alpha=0.05)
    robust_lb, robust_ub = est.robust_conf_int(lb=-2, ub=2)

    return (est.point_, est.stderr_, est.r2D_, est.r2Z_, est.r2X_, est.r2Y_,
            ids_value, ids_threshold, est.point_pre_, est.stderr_pre_,
            primal_value, primal_threshold, dual_value, dual_threshold,
            wiv_value, wiv_threshold,
            robust_lb, robust_ub)

In [None]:
a = 1.0  # a*b is the indirect effect through mediator
b = 1.0
c = .5  # this is the direct effect we want to estimate
e = 1.0  # if the product of e*f is small, then we have a weak instrument
f = 1.0  # if the product of e*f is small, then we have a weak instrument
g = .5  # this can be zero; does not hurt
sm = 2.0  # strength of mediator noise; needs to be non-zero for identifiability; only used when pm=1.
nsamples = 40000

results = Parallel(n_jobs=-1, verbose=3)(delayed(exp_res)(i, nsamples, a, b, c, e, f, g, sm=sm,
                                                          dual_type='Z', ivreg_type='adv',
                                                          n_splits=3, semi=True, n_jobs=1)
                                          for i in range(100))

In [None]:
# Transpose the list of per-experiment 18-tuples into one numpy array per
# statistic. NOTE: rmseD/rmseZ/rmseX/rmseY actually hold the nuisance R^2
# values returned by exp_res (est.r2D_, ...), not RMSEs.
points_base, stderrs_base, rmseD, rmseZ, rmseX, rmseY, \
    idstr, idstr_crit, points_alt, stderrs_alt, \
    pval, pval_crit, dval, dval_crit, wiv_stat, wiv_crit, \
    rlb, rub = map(np.array, zip(*results))

print("Estimation Quality")
for name, points, stderrs in [('Debiased', points_base, stderrs_base), ('Regularized', points_alt, stderrs_alt)]:
    print(f"\n{name} Estimate")
    # fraction of experiments whose 95% normal-based interval contains the truth c
    coverage = np.mean((points + 1.96 * stderrs >= c) & (points - 1.96 * stderrs <= c))
    rmse = np.sqrt(np.mean((points - c)**2))
    bias = np.abs(np.mean(points) - c)
    std = np.std(points)
    mean_stderr = np.mean(stderrs)
    mean_length = np.mean(2 * 1.96 * stderrs)
    median_length = np.median(2 * 1.96 * stderrs)
    print(f"Coverage: {coverage:.3f}")
    print(f"RMSE: {rmse:.3f}")
    print(f"Bias: {bias:.3f}")
    print(f"Std: {std:.3f}")
    print(f"Mean CI length: {mean_length:.3f}")
    # previously this line printed mean_length a second time
    print(f"Median CI length: {median_length:.3f}")
    print(f"Mean Estimated Stderr: {mean_stderr:.3f}")
    print(f"Nuisance R^2 (D, Z, X, Y): {np.mean(rmseD):.3f}, {np.mean(rmseZ):.3f}, {np.mean(rmseX):.3f}, {np.mean(rmseY):.3f}")

print("\nRobust ConfInt Coverage")
rcoverage = np.mean((rub >= c) & (rlb <= c))
print(f"Robust Coverage: {rcoverage:.3f}")

print("\nViolations")
# strength-type tests: a statistic at or below the critical value counts as a violation
for name, stat, crit in [('Id-Strenth', idstr, idstr_crit), ('WeakIV F-test', wiv_stat, wiv_crit)]:
    violation = np.mean(stat <= crit)
    print(f"% Violations of {name}: {violation:.3f}")
# existence tests: a statistic at or above the critical value counts as a violation
for name, stat, crit in [('Primal Existence', pval, pval_crit), ('Dual Existence', dval, dval_crit)]:
    violation = np.mean(stat >= crit)
    print(f"% Violations of {name}: {violation:.3f}")