## Setup Only for Colab

In [None]:
# Mount Google Drive at /content/drive so files stored on Drive are reachable from this
# Colab session (google.colab is only importable inside Colab).

from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Enter the project folder on Drive; requirements.txt and setup.py used by the
# following cells are resolved relative to this directory.
%cd /content/drive/MyDrive/Colab\ Notebooks/hidden_mediators

In [None]:
# List the current directory as a quick check that the %cd above landed in the right place.
%ls

In [None]:
from IPython.display import clear_output

In [None]:
import time
!pip install -r requirements.txt
time.sleep(2)
clear_output()

In [None]:
import time
# Install the library found in the current directory via its setup.py.
# Replace `develop` with `install` if you won't make library code changes.
!python setup.py develop
# Short pause, then wipe the installer log from the cell output
# (clear_output comes from IPython.display, imported in an earlier cell).
time.sleep(2)
clear_output()
# Restart the session after running this

In [None]:
# Step back up to the notebooks folder, i.e. the parent of the project directory entered earlier.
%cd /content/drive/MyDrive/Colab\ Notebooks

# Main Logic

In [None]:
# Load IPython's autoreload extension and set it to mode 2 so that edits to the
# library source are picked up without re-importing by hand.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from joblib import Parallel, delayed

# Absolute imports: a notebook kernel has no parent package, so `from ..proximalde`
# raises "attempted relative import with no known parent package".
from proximalde.gen_synthetic_data import gen_data
from proximalde.gen_synthetic_data import SemiSyntheticGenerator
# NOTE(review): ProximalDE is used by run_semi_experiments but was never imported;
# the module path below is assumed -- confirm against the installed package.
from proximalde.proximal import ProximalDE
from proximalde.utilities import covariance

# Semi-Synthetic Generation Process
We will create a semi-synthetic DGP as follows. 


We first find the top component of the covariance of $(Z - E[Z|W], X - E[X|W])$ (i.e. Cov$(Z,X)$) by running a singular value decomposition (SVD). We can think of the statistically non-zero singular values of this covariance matrix as the latent factor model. Let the SVD of Cov$(Z, X)$ be $G \cdot S \cdot F'$. If we assume the data is generated using the following structural equations:
\begin{align}
Z =& G M + \epsilon_Z \\
X =& F M + \epsilon_X
\end{align}
where $\epsilon_Z$ and $\epsilon_X$ are independent of each other and of $M$, and $G$ and $F$ are the
left and right singular vectors found by the SVD, then
\begin{align}
\text{Cov}(Z, X) = G E[MM'] F' 
\end{align}
Thus, if $E[MM'] = S = diagonal(s_1, ..., s_K)$, then the covariance of the $Z,X$ generated by the above structural model is the same as the covariance we calculated from the SVD. 

Note that to generate $Z,X$ under the aforementioned structural equations, we need the distributions of $\epsilon_Z$ and $\epsilon_X$. As a proxy, we can use the marginal distributions of $Z$ and $X$ from the data (i.e. marginalizing the empirical distribution).

To generate $D$ such that it depends only on $W$, we learn a propensity model $E[D|W]$ and use it to generate a treatment. (The reason we resample $D$ is to break any violating mediation paths $D \to Mp \to X$ that might exist in the original data; if we did not resample the treatment, such paths would create a failure in the new dataset.)

From this, we can set the value of the mediator to be 
\begin{equation}
\hat{M} = a \cdot E[D|W] + \epsilon_M
\end{equation}
where $\epsilon_M$ is a multi-variate gaussian with covariance $E[MM'] = S = diagonal(s_1, ..., s_K)$, as per above.


Thus the DGP is as follows: 
1. We first learn a propensity model over the observed data to get a treatment $\hat{D} = E[D | W]$.
2. We then compute the SVD of the covariance matrix over the observed data to get Cov$(Z, X)$ = $G \cdot S \cdot F'$.
3. We then generate a mediator $M$ based on a normal random variable with covariance $S = diagonal(s_1, ..., s_K)$, and using the equation above. 
4. We generate $Z$ and $X$ based on the structural equations above using the known $G,F$ and computed $M$ and $\epsilon$'s. 
5. For the outcome, we choose a simple linear structural function:
\begin{align}
f_Y(M, D, X, \epsilon_Y) =& b M + c D + g X[:, 0] + \sigma_Y F_n(Y)
\end{align}
where $F_n(Y)$ is the empirical distribution of $Y$ in the original data.

For every sample $(W, D, Z, X, Y)$ in the original dataset, we now have a sample $(W, \hat{D}, \hat{Z}, \hat{X}, \hat{Y})$, where $W$ is real, $\hat{D}$ is sampled from the estimated propensity given $W$, $\hat{Z}$ and $\hat{X}$ are slight modifications of the real $X,Z$ along only a particular direction, and $\hat{Y}$ is fully synthetic.

In [None]:
# --- Coefficients of the synthetic data-generating process ---
a, b = 1.0, 1.0  # a*b is the indirect effect through the mediator
c = 0.5  # the direct effect we want to estimate
d = g = 0  # either can be zero; does not hurt
e = f = 0.5  # if the product e*f is small, then we have a weak instrument

# --- Sample size and variable dimensions ---
n = 50000  # number of samples
pw, pz, px = 10, 5, 5  # dimensions of the controls / confounders, of Z, and of X
pm = 1  # dimension of the mediator M; should not be more than max(pz, px)
sm = 1.0  # strength of mediator noise; must be non-zero for identifiability; only used when pm=1

### Single experiment

In [None]:
# Generate the base dataset from the parameters configured above, with a fixed seed.
# Note: gen_data's return order here is (W, X, Z, D, Y), which differs from the
# (W, D, Z, X, Y) argument order used by generator.fit in the next cell.
W, X, Z, D, Y = gen_data(a, b, c, d, e, f, g, pm, pz, px, pw, n, sm=sm, seed=42)

In [None]:
# Build the semi-synthetic generator and fit it on the dataset generated above.
generator = SemiSyntheticGenerator(random_state=0,split=True)
# if you already computed ZXYres, i.e. if it is expensive, you can pass in as ZXYres = [Zres,Xres,Yres]
generator.fit(W, D, Z, X, Y, ZXYres=None) 

In [None]:
# Sample from DGP
# The third returned value is discarded; per the note in run_semi_experiments it is
# the mediator M, which is treated as unobserved.
nsamples = 10000
What, Dhat, _, Zhat, Xhat, Yhat = generator.sample(nsamples, a, b, c, g, replace=True)

In [None]:
# Basic semisynthetic data comparison to real:
# show covariance(Z, X) of the original data next to the same quantity for the sampled
# data; the two are expected to be close if the generator reproduces the Z-X dependence.
covariance(Z,X), covariance(Zhat, Xhat)

In [None]:
# Basic semisynthetic data comparison to real:
# overlay the histogram of each sampled Z column on the histogram of the original column.
# At most the first five columns are shown; bounding by the actual widths avoids an
# IndexError when Z (or Zhat) has fewer than five columns.
for i in range(min(5, Z.shape[1], Zhat.shape[1])):
    fig, ax = plt.subplots()
    ax.hist(Zhat[:, i], label='sampled', bins=20, alpha=.2, density=True)
    ax.hist(Z[:, i], label='true', bins=20, alpha=.2, density=True)
    ax.set(title=f"Z[:, {i}]: sampled vs. true", xlabel="value", ylabel="density")
    ax.legend()
    plt.show()

In [None]:
# Basic semisynthetic data comparison to real:
# overlay the histogram of each sampled X column on the histogram of the original column.
# At most the first five columns are shown; bounding by the actual widths avoids an
# IndexError when X (or Xhat) has fewer than five columns.
for i in range(min(5, X.shape[1], Xhat.shape[1])):
    fig, ax = plt.subplots()
    ax.hist(Xhat[:, i], label='sampled', bins=20, alpha=.2, density=True)
    ax.hist(X[:, i], label='true', bins=20, alpha=.2, density=True)
    ax.set(title=f"X[:, {i}]: sampled vs. true", xlabel="value", ylabel="density")
    ax.legend()
    plt.show()

### Repetition over many experiments

In [None]:
def run_semi_experiments(it, generator, n, a, b, c, g, *, sy=1.0, n_splits=3, semi=True,
            n_jobs=-1, verbose=0):
    """Run a single semi-synthetic replication and collect its estimates and diagnostics.

    Parameters
    ----------
    it : int
        Replication index. Seeds NumPy's global RNG and is also passed to the
        estimator as ``random_state``.
    generator : object
        Fitted generator exposing ``sample(n, a, b, c, g, sy=..., replace=True)``
        and returning six values, the third of which is the mediator.
    n : int
        Number of samples to draw.
    a, b, c, g :
        Coefficients forwarded unchanged to ``generator.sample``.
    sy :
        Forwarded to ``generator.sample`` as ``sy``.
    n_splits :
        Passed to the estimator as ``cv``.
    semi, n_jobs, verbose :
        Forwarded unchanged to the estimator.

    Returns
    -------
    tuple
        18 entries, in order: point, stderr, r2D, r2Z, r2X, r2Y,
        id-strength statistic and critical value, pre-estimate point and stderr,
        primal statistic and critical value, dual statistic and critical value,
        weak-IV statistic and critical value, robust interval lower and upper bound.
    """
    np.random.seed(it)

    # The mediator (third entry) is unobserved, so it is deliberately not used.
    W, D, _mediator, Z, X, Y = generator.sample(n, a, b, c, g, sy=sy, replace=True)

    estimator_options = dict(cv=n_splits, semi=semi, binary_D=True,
                             n_jobs=n_jobs, random_state=it, verbose=verbose)
    est = ProximalDE(**estimator_options)
    est.fit(W, D, Z, X, Y)

    # Every diagnostic returns four values; only the first (statistic) and the
    # last (critical value) are kept.
    weakiv_stat, _, _, weakiv_crit = est.weakiv_test(alpha=0.05)
    idstr, _, _, idstr_crit = est.idstrength_violation_test(alpha=0.05)
    pval, _, _, pval_crit = est.primal_violation_test(alpha=0.05)
    dval, _, _, dval_crit = est.dual_violation_test(alpha=0.05)
    lb, ub = est.robust_conf_int(lb=-2, ub=2)

    summary = (
        est.point_, est.stderr_,
        est.r2D_, est.r2Z_, est.r2X_, est.r2Y_,
        idstr, idstr_crit,
        est.point_pre_, est.stderr_pre_,
        pval, pval_crit,
        dval, dval_crit,
        weakiv_stat, weakiv_crit,
        lb, ub,
    )
    return summary

In [None]:
# Run 100 independent replications in parallel; the replication index doubles as the seed.
# Each worker runs its estimator with n_jobs=1 while joblib uses all cores (n_jobs=-1)
# across replications.
experiment = delayed(run_semi_experiments)
pool = Parallel(n_jobs=-1, verbose=3)
results = pool(
    experiment(i, generator, n, a, b, c, g, n_splits=3, semi=True, n_jobs=1)
    for i in range(100)
)

In [None]:
# `results` holds one 18-tuple per replication; transpose it into 18 arrays with one
# entry per replication. map(np.array, ...) already yields arrays, so no further
# conversion is needed.
# NOTE: despite their names, rmseD/rmseZ/rmseX/rmseY hold the estimator's r2D_/r2Z_/r2X_/r2Y_
# values (nuisance R^2), as returned by run_semi_experiments.
points_base, stderrs_base, rmseD, rmseZ, rmseX, rmseY, \
    idstr, idstr_crit, points_alt, stderrs_alt, \
    pval, pval_crit, dval, dval_crit, wiv_stat, wiv_crit, \
    rlb, rub = map(np.array, zip(*results))

Z_CRIT = 1.96  # normal quantile used for the two-sided 95% intervals below

print("Estimation Quality")
for name, points, stderrs in [('Debiased', points_base, stderrs_base), ('Regularized', points_alt, stderrs_alt)]:
    print(f"\n{name} Estimate")
    half_width = Z_CRIT * stderrs
    ci_length = 2 * half_width
    # Fraction of replications whose interval contains the true direct effect c.
    coverage = np.mean((points + half_width >= c) & (points - half_width <= c))
    rmse = np.sqrt(np.mean((points - c)**2))
    bias = np.abs(np.mean(points) - c)
    std = np.std(points)
    mean_stderr = np.mean(stderrs)
    mean_length = np.mean(ci_length)
    median_length = np.median(ci_length)
    print(f"Coverage: {coverage:.3f}")
    print(f"RMSE: {rmse:.3f}")
    print(f"Bias: {bias:.3f}")
    print(f"Std: {std:.3f}")
    print(f"Mean CI length: {mean_length:.3f}")
    # Fixed: this line used to print mean_length under the "Median" label.
    print(f"Median CI length: {median_length:.3f}")
    print(f"Mean Estimated Stderr: {mean_stderr:.3f}")
    print(f"Nuisance R^2 (D, Z, X, Y): {np.mean(rmseD):.3f}, {np.mean(rmseZ):.3f}, {np.mean(rmseX):.3f}, {np.mean(rmseY):.3f}")

print("\nRobust ConfInt Coverage")
rcoverage = np.mean((rub >= c) & (rlb <= c))
print(f"Robust Coverage: {rcoverage:.3f}")

print("\nViolations")
# For these two tests a violation is a statistic at or below its critical value ...
for name, stat, crit in [('Id-Strenth', idstr, idstr_crit), ('WeakIV F-test', wiv_stat, wiv_crit)]:
    violation = np.mean(stat <= crit)
    print(f"% Violations of {name}: {violation:.3f}")
# ... while for the existence tests it is a statistic at or above its critical value.
for name, stat, crit in [('Primal Existence', pval, pval_crit), ('Dual Existence', dval, dval_crit)]:
    violation = np.mean(stat >= crit)
    print(f"% Violations of {name}: {violation:.3f}")