## Setup Only for Colab

In [None]:
# Mount Google Drive at /content/drive so the project folder under MyDrive is
# reachable from this runtime.

from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Change into the project folder on Drive (the backslash escapes the space in "Colab Notebooks").
%cd /content/drive/MyDrive/Colab\ Notebooks/hidden_mediators

In [None]:
# List the folder contents to confirm we are in the right place
# (the next cells expect requirements.txt and setup.py here).
%ls

In [None]:
from IPython.display import clear_output

In [None]:
import time
!pip install -r requirements.txt
time.sleep(2)
clear_output()

In [None]:
import time
# replace `develop` with `install` if you wont make library code changes
!python setup.py develop
time.sleep(2)
clear_output()
# Restart the session after running this

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks

# Main Logic

In [None]:
# Reload imported modules (e.g. the editable-installed `proximalde` package) before each
# cell executes, so library edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
from joblib import Parallel, delayed
from proximalde.gen_synthetic_data import gen_data
from proximalde.proximal import proximal_direct_effect, ProximalDE, residualizeW
from sklearn.linear_model import LinearRegression
from proximalde.crossfit import fit_predict
from proximalde.proxy_selection_alg import *
from proximalde.gen_synthetic_data import gen_data_no_controls_mediator_violations

# Simulating Mediators that Trigger Violations of Both Assumptions
We want to use synthetic data to explore what causes both the primal and the dual moment conditions to be violated.

In [None]:
# Fix numpy's global RNG so the simulated data below are reproducible
# (assuming the data generator draws from numpy's global RNG).
np.random.seed(2)

# Structural coefficients of the simulation.
a, b = 1.0, 1.0  # a*b is the indirect effect through mediator
c = 0.5          # this is the direct effect we want to estimate
d = 0.0          # this can be zero; does not hurt
e, f = 0.7, 0.5  # if the product of e*f is small, then we have a weak instrument
g = 0.9          # this can be zero; does not hurt

# Sample size and dimensions (W: controls; Z, X: proxies).
n = 100_000
pw = 100
pz = px = 2

In [None]:
# Simulate data with violating mediator paths. The keyword arguments choose which
# columns of Z (invalidZinds) and X (invalidXinds) receive such paths.
# The third returned value is not needed here.
W, D, _, Z, X, Y = gen_data_no_controls_mediator_violations(
    n, pz, px, a, b, c, d, e, f, g,
    invalidZinds=[0],
    invalidXinds=[1],
)

In [None]:
# Note: both the primal and the dual fail
# Fit the proximal direct-effect estimator on the simulated data (cv=3, fixed random_state
# for reproducibility). summary() presumably reports the estimate together with the
# diagnostic tests that flag these violations.
est = ProximalDE(cv=3, semi=True, n_jobs=-1, random_state=3, verbose=3)
est.fit(W, D, Z, X, Y)
est.summary()

## Explanation: What could cause these violations? 

Throughout, we drop the tildes for readability (i.e. we write $V$ for $\tilde{V}$, etc.).


For the __dual moment__ to hold, we essentially need that:
\begin{equation}
\text{Cov}(D, X) \in \text{column-span}(\text{Cov}(X, Z)) = \text{row-span}(\text{Cov}(Z, X))
\end{equation}

For the __primal moment__ to hold, we essentially need that:
\begin{equation}
\text{Cov}(Y, DZ) \in \text{column-span}(\text{Cov}(DZ, DX))
\end{equation}
where $DZ$ and $DX$ denote the concatenation of $D$ with $Z$ and with $X$, respectively. Roughly, this is satisfied if:
\begin{equation}
\text{Cov}(Y, Z) \in \text{column-span}(\text{Cov}(Z, X))
\end{equation}


Using the above data, where both the primal and the dual fail, *let's verify that these conditions indeed do not hold.*

### 1. Calculate the three relevant covariances:

In [None]:
# Center every variable and turn D and Y into explicit column vectors, so that the
# cross-products below are empirical (1/n-normalized) covariance matrices.
Z = Z - Z.mean(axis=0)
X = X - X.mean(axis=0)
D = (D - D.mean(axis=0)).reshape(-1, 1)
Y = (Y - Y.mean(axis=0)).reshape(-1, 1)

# The three covariances used in the span checks below.
CovZX = (Z.T @ X) / n
CovXD = (X.T @ D) / n
CovYZ = (Z.T @ Y) / n  # this is Cov(Z, Y): a (pz, 1) column vector

### 2. Let's first investigate the condition for the existence of the dual
For the dual to hold, we need $\text{Cov}(X,D)$ to be in the row span of $\text{Cov}(Z,X)$, or equivalently, the column span of $\text{Cov}(X, Z)$. 

To find a basis of the column span of $\text{Cov}(X,Z)$, we perform a singular value decomposition and keep only the singular vectors whose singular values are significantly non-zero. Significance is judged against the critical value computed by `covariance_rank_test`.

In [None]:
# Critical value for the singular values of Cov(Z, X), taken from the rank test of the
# estimator fitted above; singular values above it are treated as significantly non-zero.
# NOTE(review): Scrit comes from `est`'s internally computed covariance, while S below comes
# from the CovZX computed in this notebook — confirm both use the same scaling.
_, Scrit = est.covariance_rank_test(calculate_critical=True)
# Thin SVD, CovZX = U @ diag(S) @ Vh: the columns of U (resp. rows of Vh) paired with
# non-zero singular values span the column (resp. row) space of CovZX.
U, S, Vh = np.linalg.svd(CovZX, full_matrices=False)
Scrit

In [None]:
# Row span of CovZX (= column span of CovXZ), restricted to the statistically non-zero
# singular values, vs CovXD.
# np.linalg.svd returns Vh whose *rows* are the right singular vectors, so a basis of the
# row span is given by the selected rows of Vh (transposed so each basis vector is a
# column, matching the U-based basis used for the primal check).
row_basis = Vh[S > Scrit].T
# Component of CovXD orthogonal to that span (the basis is orthonormal); it is zero
# exactly when CovXD lies in the span, i.e. when the dual condition holds.
dual_residual = CovXD - row_basis @ (row_basis.T @ CovXD)
print("Basis of row span of CovZX:\n", row_basis, "\n",
      "Vector CovXD:\n", CovXD, "\n",
      "Relative norm of CovXD outside the span:",
      np.linalg.norm(dual_residual) / np.linalg.norm(CovXD))

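We see that $\text{Cov}(X,D) \approx (0.13, 0.15)$, while the row span of $\text{Cov}(Z,X)$ is the subspace spanned by approximately the single vector $(-0.7, 0.7)$, i.e. all multiples of this single vector. Since $\text{Cov}(X,D)$ is not a multiple of that vector, it does not lie in that subspace. (Recall that a basis of the row span is formed by the *rows* of $V^H$ returned by `np.linalg.svd`, i.e. the right singular vectors, not by its columns.)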

### 3. Next, let's examine the condition for the existence of the primal
Similarly, for the primal to hold, we (approximately) need $\text{Cov}(Z, Y)$ to be in the column span of $\text{Cov}(Z,X)$.

In [None]:
# Column span of CovZX, restricted to the statistically non-zero singular values, vs CovYZ.
# The columns of U are the left singular vectors, so the selected columns form an
# orthonormal basis of the column span.
col_basis = U[:, S > Scrit]
# Component of CovYZ (= Cov(Z, Y)) orthogonal to that span; it is zero exactly when
# CovYZ lies in the span, i.e. when the (approximate) primal condition holds.
primal_residual = CovYZ - col_basis @ (col_basis.T @ CovYZ)
print("Basis of column span of CovZX:\n", col_basis, "\n",
      "Vector CovYZ:\n", CovYZ, "\n",
      "Relative norm of CovYZ outside the span:",
      np.linalg.norm(primal_residual) / np.linalg.norm(CovYZ))

We see that $\text{Cov}(Z,Y) \approx (7.5, 6.5)$, while the column span of $\text{Cov}(Z,X)$ is the subspace spanned by the single vector $(-0.7, -0.7)$, i.e. all multiples of this single vector. Since $(7.5, 6.5)$ is not a multiple of $(-0.7, -0.7)$, $\text{Cov}(Z,Y)$ does not lie in that subspace.



# Fixing these violations by applying our proposed proxy (=X,Z) selection algorithm
In the above example it was feature $Z[:,0]$ that had a violating path with $Y$ and feature $X[:, 1]$ that had a violating path with $D$. *If we remove feature $Z[:,0]$ from $Z$ and $X[:, 1]$ from $X$, would this solve these violations and produce an unbiased estimate?*

Potentially. We need to consider two cases: 
1. Say $D$ has a direct effect on the $Z$'s we removed (in this case $Z[:,0]$). Then, by removing $Z[:,0]$ we no longer control for the mediation path $D \to Z[:,0] \to Y$, so the implicit bias effect we estimate erroneously includes this violating mediation path.

2. Similarly, removing $X[:,1]$ is OK only if the direct effect of the removed $X$'s on $Y$ is zero. Otherwise, by removing these $X$'s we drop the mediation path $D \to X[:,1] \to Y$ and no longer control for it.

So even though removing the violating proxy features (in this case, $Z[:, 0]$ and $X[:, 1]$) will always lead to the primal and dual violations not being flagged, the implicit bias effect estimate will be correct only when $d=0$, i.e. the direct effect $D \to Z[:,0]$ is $0$, and the removed $X[:,1]$ has no direct effect on $Y$.

### Proxy selection algorithm 
We propose a proxy selection algorithm that selects non-violating $X$ and $Z$ proxies in a data-driven manner. The class can be found in `proximalde/proxy_selection_alg.py`.


In [None]:
# Simulation coefficients (same roles as in the first experiment).
a, b = 1.0, 1.0  # a*b is the indirect effect through mediator
c = 0.5          # this is the direct effect we want to estimate
d = 0.0          # this can be zero; does not hurt
e, f = 0.7, 0.5  # if the product of e*f is small, then we have a weak instrument
g = 0.0          # this can be zero; does not hurt

n = 100_000
pw = 0  # doesn't matter, we don't use W here
pz, px = 50, 40

# Ground-truth proxy columns that carry violating mediator paths, and their complements.
invalidZ = [0, 4, 5]
invalidX = [0, 6, 8]
validZ, validX = (np.setdiff1d(np.arange(width), bad)
                  for width, bad in ((pz, invalidZ), (px, invalidX)))

In [None]:
# Seed numpy's global RNG *before* generating, so the simulated dataset is reproducible on
# its own instead of depending on RNG state left behind by earlier cells (presumably the
# generator draws from the global RNG — the first experiment is seeded the same way).
np.random.seed(0)
W, D, _, Z, X, Y = gen_data_no_controls_mediator_violations(
    n, pz, px, a, b, c, d, e, f, g,
    invalidZinds=invalidZ, invalidXinds=invalidX)
# Re-seed (as the original cell did) so any later use of the global RNG starts from a fixed state.
np.random.seed(0)
Dres, Zres, Xres, Yres, *_ = residualizeW(W, D, Z, X, Y)  # no controls, so just zero-means the data

In [None]:
# Confirm that this data, with its violating paths, fails the primal and dual
# violation tests if we naively use all the proxy variables available to us.
# W is passed as None, presumably because the variables were already residualized in
# the previous cell (no controls, so residualization only centered them).
est = ProximalDE(cv=3, semi=True, n_jobs=-1, random_state=3, verbose=3)
est.fit(None, Dres, Zres, Xres, Yres)
est.summary()

In [None]:
# Initiate the proxy selection algorithm on the residualized data.
#  - Candidate (X, Z) pairs are proposed by searching for proxy subsets using the
#    violation *estimate*; est_thresh sets how tight those estimates must be
#    relative to the all-proxy baseline.
#  - violation_type chooses whether candidate pairs are confirmed at the end of a
#    proposing round with the estimate or with the full violation computation.
#  - primal_type: NOTE(review) presumably the analogous choice for the primal;
#    confirm in proximalde/proxy_selection_alg.py.
prm = ProxySelection(
    Xres, Zres, Dres, Yres,
    primal_type='full',
    violation_type='full',
    est_thresh=.05,
)

In [None]:
# Baseline primal and dual violation *estimate* when every column of Xres and Zres is used.
# NOTE(review): presumably the first two returned entries are the dual/primal violations —
# confirm in proxy_selection_alg.py.
prm.violation_est(np.arange(Xres.shape[1]), np.arange(Zres.shape[1]))[:2]

In [None]:
# Baseline primal and dual violation (full computation, not the estimate) when every
# column is used; compare with the estimate above.
prm.violation_full(np.arange(Xres.shape[1]), np.arange(Zres.shape[1]))[:2]

In [None]:
# Number of admissible Xset subsets to look for per Zset (or vice versa);
# the paper calls this quantity K.
ntrials = 100
# NOTE(review): niters presumably sets the number of proposing rounds — confirm.
candidates = prm.find_candidate_sets(
    ntrials,
    niters=2,
)

In [None]:
# Compute the (unbiased) point estimate and test scores, which should pass,
# for all proposed proxy candidate pairs. Because the data are simulated, we also
# know which proxies truly carry violating paths (invalidX / invalidZ), so report
# whether each candidate removed all of them.
all_X_idx = np.arange(Xres.shape[1])
all_Z_idx = np.arange(Zres.shape[1])
n_candidates = 0
for i, (Xset, Zset) in enumerate(candidates):
    n_candidates += 1
    print(f"=== Candidate pair {i} ===")
    point, test = prm.violation_full(Xset, Zset, return_dual_primal_only=False)
    display(point, test)

    rmXset = np.setdiff1d(all_X_idx, Xset)
    print("Kept Xs = ", Xset)
    print("Deleted Xs =", rmXset)
    print("All truly invalid Xs deleted:", bool(np.isin(invalidX, rmXset).all()))

    rmZset = np.setdiff1d(all_Z_idx, Zset)
    print("Kept Zs = ", Zset)
    print("Deleted Zs =", rmZset)
    print("All truly invalid Zs deleted:", bool(np.isin(invalidZ, rmZset).all()))

if n_candidates == 0:
    print("No candidate proxy sets were found; consider increasing ntrials or niters.")