## Setup Only for Colab

In [None]:
# prompt: mount drive

from google.colab import drive
drive.mount('/content/drive')

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks/hidden_mediators

In [None]:
%ls

In [None]:
from IPython.display import clear_output

In [None]:
import time
!pip install -r requirements.txt
time.sleep(2)
clear_output()

In [None]:
import time
# replace `develop` with `install` if you wont make library code changes
!python setup.py develop
time.sleep(2)
clear_output()
# Restart the session after running this

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks

# Main Logic

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
from joblib import Parallel, delayed
from proximalde.gen_data import gen_data_complex, gen_data_no_controls, gen_data_no_controls_discrete_m
from proximalde.proximal import proximal_direct_effect, ProximalDE, residualizeW
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
from proximalde.crossfit import fit_predict

# Running a Single Experiment

In [None]:
a = 1.0  # a*b is the indirect effect through mediator
b = 1.0
c = .5  # this is the direct effect we want to estimate
d = .6  # this can be zero; does not hurt
e = .7  # if the product of e*f is small, then we have a weak instrument
f = .5  # if the product of e*f is small, then we have a weak instrument
g = .9  # this can be zero; does not hurt

In [None]:
n = 50000
pw = 1
pz, px = 2, 2

In [None]:
# np.random.seed(5)
W, D, _, Z, X, Y = gen_data_complex(n, pw, pz, px, a, b, c, d, e, f, g)

## for no controls un-comment this
# _, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g)
# W = None

## for multi-dimensional mediator uncomment this
# pm = 5
# full_rank = False
# while not full_rank:
#     E = np.random.normal(0, 2, (pm, pz))
#     F = np.random.normal(0, 2, (pm, px))
#     if (np.linalg.matrix_rank(E, tol=0.5) == pm) and (np.linalg.matrix_rank(F, tol=0.5) == pm):
#         full_rank = True
# W, D, _, Z, X, Y = gen_data_no_controls_discrete_m(n, pw, pz, px, a, b, c, d, e*E, f*F, g, pm=pm)

In [None]:
# It's advisable to standardize W, Z, X (in particular the non-binary ones)
# and center the binary ones
if W is not None:  # W is None when the no-controls data generator above is used
    W = StandardScaler().fit_transform(W)
X = StandardScaler().fit_transform(X)
Z = StandardScaler().fit_transform(Z)

### Using the ProximalDE Estimator Class

In [None]:
est = ProximalDE(semi=True, cv=3, random_state=4)
## or we can use default xgboost models, or interchange linear and xgboost for regression or classification
# est = ProximalDE(model_regression='xgb', model_classification='xgb', semi=True, cv=3, random_state=4)
# est = ProximalDE(model_regression='linear', model_classification='xgb', semi=True, cv=3, random_state=4)
est.fit(W, D, Z, X, Y)

In [None]:
est.summary(decimals=5)

In [None]:
# tests can also be accessed individually
display(est.weakiv_test(alpha=0.05))
display(est.idstrength_violation_test(alpha=0.05))
display(est.primal_violation_test(alpha=0.05))
display(est.dual_violation_test(alpha=0.05))

#### Covariance Rank Diagnostic for Covariance of Proxies

In [None]:
svalues, svalues_crit = est.covariance_rank_test(calculate_critical=True)

In [None]:
plt.title(f"Number of singular values above threshold: {np.sum(svalues >= svalues_crit)}. "
          f"Threshold={svalues_crit:.3f}. Top singular value={svalues[0]:.3f}")
plt.scatter(np.arange(len(svalues)), svalues)
plt.axhline(svalues_crit)
plt.show()

#### Confidence Intervals and Robust Confidence Intervals

In [None]:
est.conf_int(alpha=.05) # 95% confidence interval

In [None]:
# 95% confidence interval, robust to weak identification
est.robust_conf_int(alpha=0.05, lb=.1, ub=1.0, ngrid=1000)

#### Unusual Data Diagnostics

In [None]:
diag = est.run_diagnostics()

In [None]:
inds = est.influential_set(alpha=0.05)
len(inds)  # size of influential set that can flip the result

In [None]:
from sklearn.base import clone
# let's re-train a clone of the estimator on all the data
# except the influential set
est2 = clone(est)
est2.fit(np.delete(W, inds, axis=0), np.delete(D, inds, axis=0),
         np.delete(Z, inds, axis=0), np.delete(X, inds, axis=0),
         np.delete(Y, inds, axis=0))
est2.summary(alpha=0.05)

In [None]:
diag.cookd_plot()
plt.show()

In [None]:
diag.l2influence_plot()
plt.show()

In [None]:
diag.influence_plot(influence_measure='cook', npoints=10)
plt.show()

In [None]:
diag.influence_plot(influence_measure='l2influence', npoints=10)
plt.show()

### Subsample-Based Inference

In [None]:
inf = est.bootstrap_inference(stage=3, n_subsamples=1000, fraction=0.5, replace=False, verbose=3, random_state=123)
inf.summary()

In [None]:
plt.hist(inf.point_dist)
plt.axvline(inf.point, color='r')
plt.show()

In [None]:
inf = est.bootstrap_inference(stage=2, n_subsamples=100, fraction=0.5, replace=False, verbose=3, random_state=123)
inf.summary()

In [None]:
plt.hist(inf.point_dist)
plt.axvline(inf.point, color='r')
plt.show()

In [None]:
inf = est.bootstrap_inference(stage=1, n_subsamples=10, fraction=0.5, replace=False, verbose=3, random_state=123)
inf.summary()

In [None]:
plt.hist(inf.point_dist)
# axvline spans the current y-range; a fixed 0..300 segment would dwarf a histogram of 10 subsamples
plt.axvline(inf.point, color='r')
plt.show()

In [None]:
inf.summary(pivot=True)

# Quality of Procedure and Diagnostics Across Many Experiments

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
from sklearn.preprocessing import StandardScaler
from proximalde.gen_data import gen_data_complex, gen_data_no_controls, gen_data_no_controls_discrete_m
from proximalde.proximal import ProximalDE

In [None]:
def exp_res(it, n, pw, pm, pz, px, a, b, c, d, e, f, g, sm, *,
            dual_type='Z', ivreg_type='adv', n_splits=3, semi=True,
            n_jobs=-1, verbose=0):
    """Run one simulated experiment: generate data, fit ProximalDE, collect estimates and test statistics.

    Parameters
    ----------
    it : int
        Experiment index; used as the numpy global seed and as the estimator's `random_state`.
    n : int
        Number of samples passed to the data generator.
    pw, pm, pz, px : int
        Dimensions passed to the data generators for the controls W, mediator M and proxies Z, X.
        `pm > 1` selects the discrete multi-dimensional mediator generator; otherwise `pw > 0`
        selects the generator with controls and `pw <= 0` the one without controls.
    a, b, c, d, e, f, g : float
        Edge strengths of the data generating process (D->M, M->Y, D->Y, D->Z, M->Z, M->X, X->Y).
    sm : float
        Mediator noise strength; only forwarded to the generators used when `pm <= 1`.
    dual_type, ivreg_type, n_splits, semi, n_jobs, verbose
        Forwarded to `ProximalDE` (`n_splits` is passed as `cv`).

    Returns
    -------
    tuple
        (point, stderr, r2D, r2Z, r2X, r2Y, idstr, idstr_crit, point_pre, stderr_pre,
        pval, pval_crit, dval, dval_crit, weakiv_stat, weakiv_crit, lb, ub), where `lb, ub`
        is the robust confidence interval searched over [-2, 2].
    """
    np.random.seed(it)
    if pm > 1:
        # redraw the M->Z and M->X loading matrices until both have (numerical) rank pm
        full_rank = False
        while not full_rank:
            E = np.random.normal(0, 2, (pm, pz))
            F = np.random.normal(0, 2, (pm, px))
            if (np.linalg.matrix_rank(E, tol=0.5) == pm) and (np.linalg.matrix_rank(F, tol=0.5) == pm):
                full_rank = True
        W, D, _, Z, X, Y = gen_data_no_controls_discrete_m(n, pw, pz, px, a, b, c, d, e*E, f*F, g, pm=pm)
        if pw == 0:
            W = None
    elif pw > 0:
        # M is unobserved so we omit it from the return variables
        W, D, _, Z, X, Y = gen_data_complex(n, pw, pz, px, a, b, c, d, e, f, g, sm=sm)
    else:
        _, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g, sm=sm)
        W = None

    # W is None in the no-controls settings above; StandardScaler cannot be applied to None
    if W is not None:
        W = StandardScaler().fit_transform(W)
    X = StandardScaler().fit_transform(X)
    Z = StandardScaler().fit_transform(Z)
    est = ProximalDE(cv=n_splits, semi=semi, binary_D=True,
                     dual_type=dual_type, ivreg_type=ivreg_type,
                     n_jobs=n_jobs, random_state=it, verbose=verbose)
    est.fit(W, D, Z, X, Y)
    # each test returns a 4-tuple; only the statistic (first) and critical value (last) are kept
    weakiv_stat, _, _, weakiv_crit = est.weakiv_test(alpha=0.05)
    idstr, _, _, idstr_crit = est.idstrength_violation_test(alpha=0.05)
    pval, _, _, pval_crit = est.primal_violation_test(alpha=0.05)
    dval, _, _, dval_crit = est.dual_violation_test(alpha=0.05)
    lb, ub = est.robust_conf_int(lb=-2, ub=2)
    return est.point_, est.stderr_, est.r2D_, est.r2Z_, est.r2X_, est.r2Y_, \
        idstr, idstr_crit, est.point_pre_, est.stderr_pre_, \
        pval, pval_crit, dval, dval_crit, weakiv_stat, weakiv_crit, \
        lb, ub

```
a : strength of D -> M edge
b : strength of M -> Y edge
c : strength of D -> Y edge
d : strength of D -> Z edge
e : strength of M -> Z edge
f : strength of M -> X edge
g : strength of X -> Y edge
```

In [None]:
a = 1.0  # a*b is the indirect effect through mediator
b = 1.0
c = .5  # this is the direct effect we want to estimate
d = .5  # this can be zero; does not hurt
e = .5  # if the product of e*f is small, then we have a weak instrument
f = .5  # if the product of e*f is small, then we have a weak instrument
g = .5  # this can be zero; does not hurt
sm = 2.0  # strength of mediator noise; needs to be non-zero for identifiability; only used when pm=1.
n = 10000
pw = 1
pm = 1
pz, px = 2, 2

results = Parallel(n_jobs=-1, verbose=3)(delayed(exp_res)(i, n, pw, pm, pz, px, a, b, c, d, e, f, g, sm,
                                                          dual_type='Z', ivreg_type='adv',
                                                          n_splits=3, semi=True, n_jobs=1)
                                          for i in range(100))

#### Summarize

In [None]:
# Unpack the per-experiment tuples returned by exp_res into one numpy array per quantity.
# NOTE: rmseD/rmseZ/rmseX/rmseY hold the nuisance R^2 values (est.r2D_, ...), not RMSEs.
points_base, stderrs_base, rmseD, rmseZ, rmseX, rmseY, \
    idstr, idstr_crit, points_alt, stderrs_alt, \
    pval, pval_crit, dval, dval_crit, wiv_stat, wiv_crit, \
    rlb, rub = map(np.array, zip(*results))

print("Estimation Quality")
for name, points, stderrs in [('Debiased', points_base, stderrs_base), ('Regularized', points_alt, stderrs_alt)]:
    print(f"\n{name} Estimate")
    # coverage of the nominal 95% normal-based interval around the true direct effect c
    coverage = np.mean((points + 1.96 * stderrs >= c) & (points - 1.96 * stderrs <= c))
    rmse = np.sqrt(np.mean((points - c)**2))
    bias = np.abs(np.mean(points) - c)
    std = np.std(points)
    mean_stderr = np.mean(stderrs)
    mean_length = np.mean(2 * 1.96 * stderrs)
    median_length = np.median(2 * 1.96 * stderrs)
    print(f"Coverage: {coverage:.3f}")
    print(f"RMSE: {rmse:.3f}")
    print(f"Bias: {bias:.3f}")
    print(f"Std: {std:.3f}")
    print(f"Mean CI length: {mean_length:.3f}")
    print(f"Median CI length: {median_length:.3f}")
    print(f"Mean Estimated Stderr: {mean_stderr:.3f}")
    print(f"Nuisance R^2 (D, Z, X, Y): {np.mean(rmseD):.3f}, {np.mean(rmseZ):.3f}, {np.mean(rmseX):.3f}, {np.mean(rmseY):.3f}")

print("\nRobust ConfInt Coverage")
rcoverage = np.mean((rub >= c) & (rlb <= c))
print(f"Robust Coverage: {rcoverage:.3f}")

print("\nViolations")
# strength tests are violated when the statistic falls at or below its critical value
for name, stat, crit in [('Id-Strenth', idstr, idstr_crit), ('WeakIV F-test', wiv_stat, wiv_crit)]:
    violation = np.mean(stat <= crit)
    print(f"% Violations of {name}: {violation:.3f}")
# existence tests are violated when the statistic reaches or exceeds its critical value
for name, stat, crit in [('Primal Existence', pval, pval_crit), ('Dual Existence', dval, dval_crit)]:
    violation = np.mean(stat >= crit)
    print(f"% Violations of {name}: {violation:.3f}")

In [None]:
import scipy.stats
plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.title(f"{np.mean(dval > scipy.stats.chi2(df=px).ppf(.95))} vs 0.05, {np.mean(dval, axis=0):.3f} vs {px}, {np.var(dval, axis=0):.3f} vs {2*px}")
plt.hist(dval)
plt.axvline(scipy.stats.chi2(df=px).ppf(.95), color='r')
plt.subplot(1, 2, 2)
plt.title(f"{np.mean(pval > scipy.stats.chi2(df=pz + 1).ppf(.95))} vs 0.05, "
          f"{np.mean(pval, axis=0):.3f} vs {pz + 1}, {np.var(pval, axis=0):.3f} vs {2*(pz + 1)}")
plt.hist(pval)
plt.axvline(scipy.stats.chi2(df=pz + 1).ppf(.95), color='r')
plt.show()

In [None]:
from statsmodels.graphics.gofplots import qqplot
import scipy.stats
plt.figure(figsize=(15, 5))
ax = plt.subplot(1, 2, 1)
qqplot(np.array(dval), dist=scipy.stats.chi2(df=px), line='45', ax=ax)
ax = plt.subplot(1, 2, 2)
qqplot(np.array(pval), dist=scipy.stats.chi2(df=pz+1), line='45', ax=ax)
plt.show()

In [None]:
import scipy.stats
plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.title(f"{np.mean(idstr / idstr_crit > 1)} vs. 0.05, {np.mean(idstr / idstr_crit, axis=0):.3f}")
plt.hist(idstr)
plt.axvline(np.mean(idstr_crit), color='r')
plt.subplot(1, 2, 2)
plt.title(f"{np.mean(wiv_stat):.3f}, {np.mean(wiv_stat / wiv_crit, axis=0):.3f}")
plt.hist(wiv_stat)
plt.axvline(np.mean(wiv_crit), color='r')
plt.show()

In [None]:
plt.hist(points_base, label='Distribution of Estimates: debiased')
plt.hist(points_alt, label='Distribution of Estimates: original', alpha=.3)
plt.vlines([c], 0, plt.ylim()[1], color='red', label='truth')
plt.legend()
plt.show()

# Using Custom ML Models

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
from joblib import Parallel, delayed
from proximalde.gen_data import gen_data_complex, gen_data_no_controls, gen_data_no_controls_discrete_m
from proximalde.proximal import proximal_direct_effect, ProximalDE, residualizeW
from sklearn.linear_model import LinearRegression
from proximalde.crossfit import fit_predict

In [None]:
from xgboost import XGBRegressor, XGBClassifier
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.model_selection import train_test_split

class XGBRegressorWrapper(BaseEstimator, RegressorMixin):
    """scikit-learn style wrapper around XGBRegressor that early-stops on a held-out split.

    `fit` sets aside 20% of the data it receives as a validation set and passes it
    to XGBoost as `eval_set`, so that `early_stopping_rounds` can take effect.

    Parameters
    ----------
    max_depth : int
        Forwarded to XGBRegressor.
    early_stopping_rounds : int
        Forwarded to XGBRegressor.
    learning_rate : float
        Forwarded to XGBRegressor.
    random_state : int or None
        Seed used for both the train/validation split and the XGBRegressor, so that
        repeated fits on the same data are reproducible. The default (123) matches the
        model seed that was previously hard-coded.
    """

    def __init__(self, *, max_depth=3, early_stopping_rounds=50, learning_rate=.1, random_state=123):
        self.max_depth = max_depth
        self.early_stopping_rounds = early_stopping_rounds
        self.learning_rate = learning_rate
        self.random_state = random_state

    def fit(self, X, y):
        """Fit on 80% of (X, y), early-stopping on the remaining 20%; returns self."""
        # seed the split too: an unseeded split makes the fitted model differ from run to run
        Xtrain, Xval, ytrain, yval = train_test_split(X, y, test_size=.2,
                                                      random_state=self.random_state)
        self.model_ = XGBRegressor(max_depth=self.max_depth,
                                   early_stopping_rounds=self.early_stopping_rounds,
                                   learning_rate=self.learning_rate,
                                   random_state=self.random_state)
        self.model_.fit(Xtrain, ytrain, eval_set=[(Xval, yval)], verbose=False)
        return self

    def predict(self, X):
        """Return the fitted XGBRegressor's predictions for X."""
        return self.model_.predict(X)


class XGBClassifierWrapper(BaseEstimator, ClassifierMixin):
    """scikit-learn style wrapper around XGBClassifier that early-stops on a held-out split.

    `fit` sets aside 20% of the data it receives as a validation set and passes it
    to XGBoost as `eval_set` (evaluated with log-loss), so that `early_stopping_rounds`
    can take effect.

    Parameters
    ----------
    max_depth : int
        Forwarded to XGBClassifier.
    early_stopping_rounds : int
        Forwarded to XGBClassifier.
    learning_rate : float
        Forwarded to XGBClassifier.
    random_state : int or None
        Seed used for both the train/validation split and the XGBClassifier, so that
        repeated fits on the same data are reproducible. The default (123) matches the
        model seed that was previously hard-coded.
    """

    def __init__(self, *, max_depth=3, early_stopping_rounds=50, learning_rate=.1, random_state=123):
        self.max_depth = max_depth
        self.early_stopping_rounds = early_stopping_rounds
        self.learning_rate = learning_rate
        self.random_state = random_state

    def fit(self, X, y):
        """Fit on 80% of (X, y), early-stopping on the remaining 20%; returns self."""
        # seed the split too: an unseeded split makes the fitted model differ from run to run
        Xtrain, Xval, ytrain, yval = train_test_split(X, y, test_size=.2,
                                                      random_state=self.random_state)
        self.model_ = XGBClassifier(max_depth=self.max_depth,
                                    early_stopping_rounds=self.early_stopping_rounds,
                                    learning_rate=self.learning_rate, eval_metric='logloss',
                                    random_state=self.random_state)
        self.model_.fit(Xtrain, ytrain, eval_set=[(Xval, yval)], verbose=False)
        # expose the class labels on the wrapper itself, mirroring the fitted model
        self.classes_ = self.model_.classes_
        return self

    def predict(self, X):
        """Return the fitted XGBClassifier's predicted labels for X."""
        return self.model_.predict(X)

    def predict_proba(self, X):
        """Return the fitted XGBClassifier's predicted class probabilities for X."""
        return self.model_.predict_proba(X)

In [None]:
a = 1.0  # a*b is the indirect effect through mediator
b = 1.0
c = .5  # this is the direct effect we want to estimate
d = .6  # this can be zero; does not hurt
e = .7  # if the product of e*f is small, then we have a weak instrument
f = .5  # if the product of e*f is small, then we have a weak instrument
g = .9  # this can be zero; does not hurt
n = 50000
pw = 10
pz, px = 2, 2

In [None]:
np.random.seed(2)
W, D, _, Z, X, Y = gen_data_complex(n, pw, pz, px, a, b, c, d, e, f, g)

## for no controls un-comment this
# _, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g)
# W = None

## for multi-dimensional mediator uncomment this
# pm = 5
# full_rank = False
# while not full_rank:
#     E = np.random.normal(0, 2, (pm, pz))
#     F = np.random.normal(0, 2, (pm, px))
#     if (np.linalg.matrix_rank(E, tol=0.5) == pm) and (np.linalg.matrix_rank(F, tol=0.5) == pm):
#         full_rank = True
# W, D, _, Z, X, Y = gen_data_no_controls_discrete_m(n, pw, pz, px, a, b, c, d, e*E, f*F, g, pm=pm)

In [None]:
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
est = ProximalDE(model_regression=XGBRegressorWrapper(), model_classification=XGBClassifierWrapper(),
                 cv=3, semi=False, n_jobs=-1, random_state=3, verbose=3)
est.fit(W, D, Z, X, Y)

In [None]:
est.summary()

## And with HyperParam Tuning and Semi-Crossfitting

In [None]:
from sklearn.model_selection import GridSearchCV

regression = GridSearchCV(XGBRegressorWrapper(), {'learning_rate': [.01, .1, 1]}, scoring='neg_root_mean_squared_error')
classification = GridSearchCV(XGBClassifierWrapper(), {'learning_rate': [.01, .1, 1]}, scoring='neg_log_loss')

In [None]:
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
est = ProximalDE(model_regression=regression, model_classification=classification,
                 cv=3, semi=True, n_jobs=-1, random_state=3, verbose=3)
est.fit(W, D, Z, X, Y)

In [None]:
est.summary()

## And even grid-search cross-validation among many types of models

In [None]:
from proximalde.utilities import GridSearchCVList
from sklearn.linear_model import Lasso, LogisticRegression

regression = GridSearchCVList([XGBRegressorWrapper(), Lasso()],
                              [{'learning_rate': [.01, .1, 1]},
                               {'alpha': np.logspace(-4, 2, 20)}],
                              scoring='neg_root_mean_squared_error')
classification = GridSearchCVList([XGBClassifierWrapper(),
                                   LogisticRegression(penalty='l1', solver='liblinear',
                                                      tol=1e-6, intercept_scaling=100)],
                                  [{'learning_rate': [.01, .1, 1]},
                                   {'C': np.logspace(-4, 4, 10)}],
                                  scoring='neg_log_loss')

In [None]:
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
est = ProximalDE(model_regression=regression, model_classification=classification,
                 cv=3, semi=True, n_jobs=-1, random_state=3, verbose=3)
est.fit(W, D, Z, X, Y)

In [None]:
est.summary()

# Mediations that Trigger Violations of Both Assumptions

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
from joblib import Parallel, delayed
from proximalde.proximal import ProximalDE
from proximalde.gen_data import gen_data_with_mediator_violations

In [None]:
a = 1.0  # a*b is the indirect effect through mediator
b = 1.0
c = .5  # this is the direct effect we want to estimate
d = .0  # this can be zero; does not hurt
e = .7  # if the product of e*f is small, then we have a weak instrument
f = .5  # if the product of e*f is small, then we have a weak instrument
g = .9  # this can be zero; does not hurt

n = 100000
pw = 100
pz, px = 2, 2

In [None]:
W, D, _, Z, X, Y = gen_data_with_mediator_violations(n, pw, pz, px, a, b, c, d, e, f, g,
                                                     invalidZinds=[0], invalidXinds=[1])
W = None

In [None]:
est = ProximalDE(cv=3, semi=True, n_jobs=-1, random_state=3, verbose=3)
est.fit(W, D, Z, X, Y)
est.summary()

## Explanation

For the primal moment to hold, we essentially need that:
\begin{equation}
\text{Cov}(Y, DZ) \in \text{column-span}(\text{Cov}(DZ, DX))
\end{equation}
where $DZ, DX$ are the concatenation of $D$ with $Z$ and $X$ correspondingly. Roughly this should be satisfied if:
\begin{equation}
\text{Cov}(Y, Z) \in \text{column-span}(\text{Cov}(Z, X))
\end{equation}


For the dual moment to hold, we essentially need that:
\begin{equation}
\text{Cov}(D, X) \in \text{column-span}(\text{Cov}(X, Z)) = \text{row-span}(\text{Cov}(Z, X))
\end{equation}

Let's verify that this is indeed not the case

In [None]:
Z = Z - np.mean(Z, axis=0)
X = X - np.mean(X, axis=0)
D = D - np.mean(D, axis=0)
Y = Y - np.mean(Y, axis=0)
D = D.reshape(-1, 1)
Y = Y.reshape(-1, 1)

Let's calculate the three relevant covariances:

In [None]:
CovZX = Z.T @ X / n
CovXD = X.T @ D / n
CovYZ = Z.T @ Y / n

Let's investigate the condition for the existence of the dual: we need $Cov(X,D)$ to be in the row span of $Cov(Z,X)$ or, equivalently, in the column span of $Cov(X, Z)$. We perform a singular value decomposition and keep only the statistically significant non-zero singular values. This can be done by using the critical value computed by the "covariance_rank_test".

In [None]:
_, Scrit = est.covariance_rank_test(calculate_critical=True)
Scrit

In [None]:
U, S, Vh = np.linalg.svd(CovZX, full_matrices=False)

In [None]:
# row span of CovZX, with stat-sig non-zero singular values, vs CovXD
# The right singular vectors are the ROWS of Vh, so the significant ones are selected by row;
# we transpose so that each basis vector is printed as a column, like CovXD.
print("Basis of row span of CovZX:\n", Vh[S > Scrit, :].T, "\n",
      "Vector CovXD:\n", CovXD)

We see that $CovXD\approx (0.12, 0.12)$, while the row span of CovZX is the subspace spanned by approximately the single vector $(-1, 0)$, i.e. multiples of this single vector. So obviously, the first vector is not in that subspace.

---



Let's examine the primal existence:

In [None]:
# column span of CovZX, with stat-sig non-zero singular values, vs CovYZ
print("Basis of column span of CovZX:\n", U[:, S > Scrit], "\n",
      "Vector CovYZ:\n", CovYZ)

We see that $CovYZ\approx (7.5, 6.5)$, while the column span of CovZX is the subspace spanned by the single vector $(-.7, -.7)$, i.e. multiples of this single vector. So obviously, the first vector is not in that subspace.



## Fixing the Violation by Removing Z's and X's

In the above example it was $Z[0]$ that had a violation with $Y$ and $X[1:]$ that had a violation with $D$. So potentially if we remove $Z[0]$ and if we remove $X[1:]$ we would get an unbiased estimate.

Of course, this would be ok only if $D$ does not have a direct effect on the Z's we removed (in this case $Z[0]$), as otherwise, by removing $Z[0]$, there is another mediation path through $Z[0]$ that we are not controlling for and the effect we are estimating is also the effect mediated through the path $D->Z[0]->Y$. 

So even though removing $Z[0]$ and $X[1:]$ will always lead to the violations not being flagged, the estimate will be the correct estimate only when $d=0$, i.e. the direct effect from $D->Z[0]$ is $0$.

Similarly, removing $X[1:]$ is ok, only if the direct effect of these X's to Y is zero. Otherwise, by removing these X's we are removing the mediation paths $D->X->Y$ and not controlling for them.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
from joblib import Parallel, delayed
from proximalde.proximal import ProximalDE
from proximalde.gen_data import gen_data_with_mediator_violations, gen_data_no_controls

In [None]:
a = 1.0  # a*b is the indirect effect through mediator
b = 1.0
c = .5  # this is the direct effect we want to estimate
d = .0  # this can be zero; does not hurt
e = .7  # if the product of e*f is small, then we have a weak instrument
f = .5  # if the product of e*f is small, then we have a weak instrument
g = .0  # this can be zero; does not hurt

n = 100000
pw = 100
pz, px = 50, 40
invalidZ = [0, 4, 5]
invalidX = [0, 6, 8]
validZ = np.setdiff1d(np.arange(pz), invalidZ)
validX = np.setdiff1d(np.arange(px), invalidX)

In [None]:
W, D, _, Z, X, Y = gen_data_with_mediator_violations(n, pw, pz, px, a, b, c, d, e, f, g,
                                                     invalidZinds=invalidZ, invalidXinds=invalidX)
W = None

In [None]:
est = ProximalDE(cv=3, random_state=3, verbose=3)
est.fit(W, D, Z, X, Y)
display(est.summary().tables[0], est.summary().tables[2])

In [None]:
est = ProximalDE(cv=3, random_state=3, verbose=3)
est.fit(W, D, Z[:, validZ], X[:, validX], Y)
display(est.summary().tables[0], est.summary().tables[2])

Let's see if we can identify this subset using a data-driven approach:

In [None]:
from proximalde.utilities import covariance, svd_critical_value
from proximalde.proximal import residualizeW

# Order the Z's in decreasing order of orthogonality
Dres, Zres, Xres, Yres, *_ = residualizeW(W, D, Z, X, Y)

In [None]:
covXD = covariance(Xres, Dres)
covZY = covariance(Zres, Yres)
covXZ = covariance(Xres, Zres)

# replacing covariance with low rank component, cleaning up the noisy eigenvalues
U, S, Vh = scipy.linalg.svd(covXZ, full_matrices=False)
Scrit = svd_critical_value(Xres, Zres)
covXZ = U[:, S > Scrit] @ np.diag(S[S > Scrit]) @ Vh[S > Scrit, :]

# cleaning up cov(X,D) and cov(Z,Y) to zero-out the statistical zeros
stderr_covXD = np.sqrt(np.var((Xres - Xres.mean(axis=0)) * (Dres - Dres.mean(axis=0)), axis=0) / Xres.shape[0])
covXD[np.abs(covXD).flatten() < 1.96 * stderr_covXD] = 0
stderr_covZY = np.sqrt(np.var((Zres - Zres.mean(axis=0)) * (Yres - Yres.mean(axis=0)), axis=0) / Zres.shape[0])
covZY[np.abs(covZY).flatten() < 1.96 * stderr_covZY] = 0

In [None]:
covDZY = np.zeros((1 + Z.shape[1], 1))
covDZY[1:, :] = covZY
covDZY[0, :] = covariance(Dres, Yres).flatten()

In [None]:
DZ = np.hstack([Dres, Zres])
# use the residualized treatment, consistently with every other covariance in this analysis
covDDZ = covariance(Dres, DZ)

# covariance of (D, X) with (D, Z): row/column 0 corresponds to D, the remaining ones to X / Z
covDXDZ = np.zeros((1 + Xres.shape[1], 1 + Zres.shape[1]))
covDXDZ[1:, 1:] = covXZ
covDXDZ[0, :] = covDDZ.flatten()
covDXDZ[1:, 0] = covXD.flatten()

In [None]:
def violation(remnantX, remnantZ):
    """Measure the dual and primal violations for a subset of X's and Z's.

    The cleaned covariances `covXZ`, `covXD` and `covZY` are restricted to the
    given index sets. Each violation is the sup-norm of what is left of a vector
    after projecting it onto the column span of the corresponding matrix:
    cov(X, D) against cov(X, Z) for the dual, cov(Z, Y) against cov(Z, X) for
    the primal.

    Returns the pair `(dual_violation, primal_violation)`.
    """
    sub_covXZ = covXZ[remnantX, :][:, remnantZ]
    sub_covZX = sub_covXZ.T
    sub_covXD = covXD[remnantX]
    sub_covZY = covZY[remnantZ]

    # residual of cov(X, D) after projection onto the column span of cov(X, Z)
    dual_resid = sub_covXD - sub_covXZ @ scipy.linalg.pinv(sub_covXZ) @ sub_covXD
    # residual of cov(Z, Y) after projection onto the column span of cov(Z, X)
    primal_resid = sub_covZY - sub_covZX @ scipy.linalg.pinv(sub_covZX) @ sub_covZY

    ## more accurate primal violation
    # covDXDZ_tmp = covDXDZ[[0] + [i + 1 for i in remnantX], :][:, [0] + [i + 1 for i in remnantZ]]
    # covDZDX_tmp = covDXDZ_tmp.T
    # covDZY_tmp = covDZY[[0] + [i + 1 for i in remnantZ]]
    # primal_violation = np.linalg.norm(covDZY_tmp - covDZDX_tmp @ scipy.linalg.pinv(covDZDX_tmp) @ covDZY_tmp, ord=np.inf)

    return np.linalg.norm(dual_resid, ord=np.inf), np.linalg.norm(primal_resid, ord=np.inf)

In [None]:
violation(list(validX), list(validZ))

In [None]:
dv_bench, pv_bench = violation(np.arange(Xres.shape[1]), np.arange(Zres.shape[1]))
dv_bench, pv_bench

In [None]:
from joblib import Parallel, delayed

def xset_trial(it, remnantZ, verbose):
    ''' Greedily build a set of X's, visiting them in a random order seeded by `it`.
    An X is kept only if, together with the X's kept so far and the given Z's
    (`remnantZ`), the dual violation stays below 10% of the benchmark `dv_bench`.

    Returns a 0/1 indicator vector over all X's, or None if no X was kept.
    '''
    np.random.seed(it)
    pool = np.arange(Xres.shape[1])
    chosen = []
    while len(pool) > 0:
        pos = np.random.choice(len(pool), size=1)[0]
        trial_set = chosen + [pool[pos]]
        dv, _ = violation(trial_set, remnantZ)
        if dv < 0.1 * dv_bench:
            chosen = trial_set
        # each X is considered exactly once
        pool = np.delete(pool, pos)

    if not chosen:
        return None

    chosen = np.sort(chosen)
    if verbose:
        print(chosen, violation(chosen, remnantZ))

    ohe = np.zeros(Xres.shape[1]).astype(int)
    ohe[chosen] = 1
    return ohe

def zset_trial(it, remnantX, verbose):
    ''' Given a candidate X set, greedily build a set of Z's, visiting them in a
    random order seeded by `it`. A Z is kept only if, together with the Z's kept
    so far, the primal violation stays below 10% of the benchmark `pv_bench`.

    Returns a 0/1 indicator vector over all X's followed by all Z's, or None if
    no Z was kept.
    '''
    np.random.seed(it)

    pool = np.arange(Zres.shape[1])
    chosen = []
    while len(pool) > 0:
        pos = np.random.choice(len(pool), size=1)[0]
        trial_set = chosen + [pool[pos]]
        _, pv = violation(remnantX, trial_set)
        if pv < 0.1 * pv_bench:
            chosen = trial_set
        # each Z is considered exactly once
        pool = np.delete(pool, pos)

    if not chosen:
        return None

    chosen = np.sort(chosen)
    dv, pv = violation(remnantX, chosen)
    if verbose:
        print(remnantX, chosen, dv, pv)

    # the X indicators come first, the Z indicators are offset by the number of X's
    n_x = Xres.shape[1]
    ohe = np.zeros(n_x + Zres.shape[1]).astype(int)
    ohe[remnantX] = 1
    ohe[n_x + chosen] = 1
    return ohe

def find_candidate_sets(ntrials, verbose=0, n_jobs=-1):
    ''' Search for (Xset, Zset) pairs for which neither the dual nor the primal
    violation occurs (each below 10% of its benchmark).

    Two rounds are run. In each round, `ntrials` randomized `xset_trial` runs are
    made for every current Z set, and then `ntrials` randomized `zset_trial` runs
    for every unique X set found. The first round starts from the set of all Z's.

    Returns a list of `(Xset, Zset)` index-array pairs; empty if a stage produced
    no candidate.
    '''
    n_x = Xres.shape[1]
    # indicator encoding of the Z sets to explore; initially the single set of all Z's
    unique_Zsets = np.array([np.ones(Zres.shape[1])]).astype(int)

    for _ in range(2):
        # we generate a set of candidate of maximal X sets such that the dual violation does not
        # occur, given the current Z's. Note that more Z's can only help the dual.
        candidateX = []
        for z_mask in unique_Zsets:
            z_idx = np.argwhere(z_mask).flatten()
            candidateX.extend(Parallel(n_jobs=n_jobs, verbose=3)(
                delayed(xset_trial)(it, z_idx, verbose) for it in range(ntrials)))
        candidateX = [ohe for ohe in candidateX if ohe is not None]
        if len(candidateX) == 0:
            return []

        # we clean up to keep only the unique solutions
        unique_Xsets = np.unique(np.array(candidateX).astype(int), axis=0)

        # for each unique candidate solution of X's we try to construct maximal sets
        # of Z's, such that the primal violation does not occur. Note that more X's
        # can only help the primal, which is why we tried to build maximal X's in
        # the first place.
        candidateXZ = []
        for x_mask in unique_Xsets:
            x_idx = np.argwhere(x_mask).flatten()
            candidateXZ.extend(Parallel(n_jobs=n_jobs, verbose=3)(
                delayed(zset_trial)(it, x_idx, verbose) for it in range(ntrials)))
        candidateXZ = [ohe for ohe in candidateXZ if ohe is not None]
        if len(candidateXZ) == 0:
            return []

        # this array now contains the one-hot-encodings of the Xset and the Zset (concatenated)
        candidateXZ = np.array(candidateXZ).astype(int)
        # we clean up to keep only unique Zset solutions
        unique_Zsets = np.unique(candidateXZ[:, n_x:], axis=0)

    # keep only unique pairs of solutions, decode them back to member sets and
    # retain the pairs for which both violations are below their thresholds
    final_candidates = []
    for pair_mask in np.unique(candidateXZ, axis=0):
        Xset = np.argwhere(pair_mask[:n_x]).flatten()
        Zset = np.argwhere(pair_mask[n_x:]).flatten()
        dv, pv = violation(Xset, Zset)
        if verbose:
            print(Xset, Zset, dv, pv)
        if pv < 0.1 * pv_bench and dv < 0.1 * dv_bench:
            final_candidates.append((Xset, Zset))
    return final_candidates

In [None]:
candidates = find_candidate_sets(20)

In [None]:
for Xset, Zset in candidates:
    print("Xset =", Xset)
    print("Deleted Xs =", np.setdiff1d(np.arange(Xres.shape[1]), Xset))
    print("Zset =", Zset)
    print("Deleted Zs =", np.setdiff1d(np.arange(Zres.shape[1]), Zset))
    est = ProximalDE(random_state=3)
    est.fit(None, Dres, Zres[:, Zset], Xres[:, Xset], Yres)
    display(est.summary().tables[0], est.summary().tables[2])

# Semi-Synthetic Data Generation



In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
from joblib import Parallel, delayed
from proximalde.proximal import ProximalDE, residualizeW
from proximalde.gen_data import gen_data_with_mediator_violations, gen_data_no_controls_discrete_m, gen_data_no_controls, gen_data_complex

Suppose we are given some real-world dataset $(W, D, Z, X, Y)$

In [None]:
# Structural coefficients of the synthetic data-generating process.
a, b = 1.0, 1.0  # the product a*b is the indirect effect through the mediator
c = 0.5  # direct effect: this is the quantity we want to estimate
d = 0.6  # may be set to zero without harm
e, f = 0.7, 0.5  # a small product e*f means we have a weak instrument
g = 0.9  # may be set to zero without harm

# Sample size and the sizes (pw, pz, px) handed to the data generator.
n = 100_000
pw = 100
pz = 10
px = 5

In [None]:
# Fix NumPy's global seed so the generated dataset is reproducible.
np.random.seed(124)
# Active DGP: data with mediator violations. The third return value (presumably the
# mediator) is discarded.
W, D, _, Z, X, Y = gen_data_with_mediator_violations(n, pw, pz, px, a, b, c, d, e, f, g)
# The generated controls are dropped, so everything downstream runs with W=None.
W = None

## Alternative DGPs kept for reference -- un-comment one block to replace the data above.

## for the complex DGP (keeps the controls W) un-comment this
# W, D, _, Z, X, Y = gen_data_complex(n, pw, pz, px, a, b, c, d, e, f, g)

## for no controls un-comment this
# _, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g)
# W = None

## for multi-dimensional mediator un-comment this
## (E and F are redrawn until both have rank pm at tolerance 0.5)
# pm = 5
# full_rank = False
# while not full_rank:
#     E = np.random.normal(0, 2, (pm, pz))
#     F = np.random.normal(0, 2, (pm, px))
#     if (np.linalg.matrix_rank(E, tol=0.5) == pm) and (np.linalg.matrix_rank(F, tol=0.5) == pm):
#         full_rank = True
# W, D, _, Z, X, Y = gen_data_no_controls_discrete_m(n, pw, pz, px, a, b, c, d, e*E, f*F, g, pm=pm)
# W = None

In [None]:
# The DGP used above (gen_data_with_mediator_violations) is meant to violate the
# method's assumptions, so the violation tests reported in the summary are expected to fail.
est = ProximalDE(cv=3, semi=True, n_jobs=-1, random_state=3, verbose=3)
est.fit(W, D, Z, X, Y)
# Last expression of the cell: the summary is rendered as the cell output.
est.summary()

## Semi-Synthetic Generation Process
We will create a semi-synthetic DGP as follows. We first find the top component of the covariance of $(Z - E[Z|W], X - E[X|W])$, by running a singular value decomposition.

We can think of the statistically non-zero singular values of this covariance matrix as the latent factor model. If the svd decomposition of $Cov(Z, X)$ is $G \cdot S \cdot F'$, then such an SVD can be generated under the following latent factor model:
\begin{align}
Z &= G M + \epsilon_Z \\
X &= F M + \epsilon_X
\end{align}
where $\epsilon_Z$ and $\epsilon_X$ are independent of each other (and of $M$), and the columns of $G$ and $F$ are the
left and right singular vectors found by the SVD. Note that under this structural model:
\begin{align}
Cov(Z, X) = G E[MM'] F'
\end{align}
Hence, if $E[MM'] = \mathrm{diag}(s_1, \ldots, s_K)$, then the covariance of the $Z, X$ generated by the above structural model is the same as the covariance we calculated. Thus we can generate $Z, X$ that match this covariance by first generating a mediator $M$ from a normal r.v. with covariance $\mathrm{diag}(s_1, \ldots, s_K)$ and then generating $Z$ and $X$ from the structural equations above. For this we also need the distributions of $\epsilon_Z$ and $\epsilon_X$. As a proxy we can use the marginal distributions of $Z$ and $X$ from the data (i.e. marginalizing the empirical distribution). Alternatively, we can use the marginal distributions of the projected $Z$ and $X$, after projecting out the components that are not orthogonal to the singular vectors $G$ and $F$, respectively. The latter has the guarantee that it only contains the epsilon parts, but it does not preserve the variance of the original variables. So we opt for the first approach.

Moreover, we generate the mediator as follows. We learn a propensity model $E[D|W]$ and we generate a treatment $D$ by sampling from the propensity. Then we  set the value of the mediator to be:
\begin{equation}
\tilde{M} = a \cdot D + \epsilon_M
\end{equation}
so that the mediator is affected by the treatment, where $\epsilon_M$ is a multi-variate gaussian with diagonals as described above. 

Then we also impute outcomes:
\begin{equation}
\tilde{Y} = f_Y(\tilde{M}, D, \tilde{X}, \epsilon_Y)
\end{equation}
The reason why we want to resample D is to break any violating mediation paths $D->Mp->X$ that might exist in the original data, which would create a failure in this new dataset, if we didn't resample the treatment. We could try not resampling the treatment. If then we get a failure of the dual violation, then this hints at an auxiliary violating mediation path $D->Mp->X$.

Now for every sample $(W, D, Z, X, Y)$ in the original dataset, we now have a sample $(W, \tilde{D}, \tilde{Z}, \tilde{X}, \tilde{Y})$, where $W$ is real, $\tilde{D}$ is sampled from the estimated propensity, given $W$, and $\tilde{Z}, \tilde{X}$ are slight modifications of the real $X,Z$ along only a particular direction and $\tilde{Y}$ is fully synthetic.


For simplicity, we will first choose linear structural function $f_Y$:
\begin{align}
f_Y(M, D, X, \epsilon_Y) =& b M + c D + g X[:, 0] + \sigma_Y F_n(Y)
\end{align}
where $F_n(Y)$ is the empirical distribution of $Y$ in the original data.



## Packaging the Semi-Synthetic Generation

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats
from joblib import Parallel, delayed
from proximalde.proximal import ProximalDE, residualizeW, svd_critical_value
from proximalde.utilities import covariance
from proximalde.gen_data import gen_data_with_mediator_violations, gen_data_no_controls_discrete_m, gen_data_no_controls, gen_data_complex

In [None]:
# Structural coefficients of the synthetic data-generating process.
a, b = 1.0, 1.0  # the product a*b is the indirect effect through the mediator
c = 0.5  # direct effect: this is the quantity we want to estimate
d = 0.0  # may be set to zero without harm
e, f = 1.0, 1.0  # a small product e*f means we have a weak instrument
g = 0.0  # may be set to zero without harm

# Sample size and the sizes (pw, pz, px) handed to the data generator.
n = 100_000
pw = 10
pz = 4
px = 4

In [None]:
np.random.seed(124)

## Alternative DGPs kept for reference (un-comment to use instead of the active one below)
# W, D, _, Z, X, Y = gen_data_with_mediator_violations(n, pw, pz, px, a, b, c, d, e, f, g)
# W = None

# W, D, _, Z, X, Y = gen_data_complex(n, pw, pz, px, a, b, c, d, e, f, g)

## no controls
# _, D, _, Z, X, Y = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g)
# W = None

## Active DGP: multi-dimensional mediator
pm = 2
# Keep redrawing the two random matrices until both have rank pm at tolerance 0.5.
full_rank = False
while not full_rank:
    E = np.random.normal(0, 2, (pm, pz))
    F = np.random.normal(0, 2, (pm, px))
    full_rank = all(np.linalg.matrix_rank(mat, tol=0.5) == pm for mat in (E, F))
# E and F are scaled by e and f before being handed to the generator.
W, D, _, Z, X, Y = gen_data_no_controls_discrete_m(
    n, pw, pz, px, a, b, c, d, e * E, f * F, g, pm=pm
)
# Downstream cells run without controls.
W = None

In [None]:
from proximalde.gen_data import SemiSyntheticGenerator

# Coefficients handed to `generator.sample` in the cells below.
a, b = 1.0, 1.0  # the product a*b is the indirect effect through the mediator
c = 0.5  # direct effect: this is the quantity we want to estimate
g = 0.0  # may be set to zero without harm
sm = 2.0  # strength of mediator noise; needs to be non-zero for identifiability; only used when pm=1.
nsamples = 100_000

# Fit the semi-synthetic generator on the dataset (W, D, Z, X, Y) created above.
generator = SemiSyntheticGenerator(split=True)
generator.fit(W, D, Z, X, Y)

In [None]:
# Draw `nsamples` semi-synthetic observations from the fitted generator.
# The third return value (the mediator) is discarded.
Wtilde, Dtilde, _, Ztilde, Xtilde, Ytilde = generator.sample(nsamples, a, b, c, g, replace=True)

In [None]:
# Original data: covariance between Z and X (`covariance` is imported from
# proximalde.utilities; presumably the empirical cross-covariance -- TODO confirm).
covariance(Z, X)

In [None]:
# Original data: covariance of Z with itself and of X with itself, shown as a pair.
covariance(Z, Z), covariance(X, X)

In [None]:
# Sampled data: covariance between Ztilde and Xtilde, to compare with the original Z, X.
covariance(Ztilde, Xtilde)

In [None]:
# Sampled data: covariance of Ztilde with itself and of Xtilde with itself, shown as a pair.
covariance(Ztilde, Ztilde), covariance(Xtilde, Xtilde)

In [None]:
# Compare the distribution of the first Z coordinate in the sampled vs. the original data.
# alpha keeps both histograms visible where they overlap, and plt.legend() is needed
# for the `label=` arguments to be rendered at all.
plt.hist(Ztilde[:, 0], alpha=0.5, label='sampled')
plt.hist(Z[:, 0], alpha=0.5, label='true')
plt.legend()
plt.show()

In [None]:
import statsmodels.api as stm


def exp_res(it, generator, n, a, b, c, g, *, sy=1.0, n_jobs=-1, verbose=0):
    """Run one replication of the OLS benchmark that observes the mediator.

    Seeds NumPy's global RNG with ``it``, draws ``n`` samples from ``generator``
    and regresses Y on the columns [D, M, X, intercept] with HC1 covariance.

    Note: a later cell of this notebook defines another function with the same
    name, which replaces this one once that cell has been run.

    Parameters
    ----------
    it : int
        Replication index; used as the random seed.
    generator : object
        Object exposing ``sample(n, a, b, c, g, sy=..., replace=True)`` and
        returning ``(W, D, M, Z, X, Y)``.
    n : int
        Number of samples to draw.
    a, b, c, g : float
        Coefficients forwarded to ``generator.sample``.
    sy : float, default 1.0
        Forwarded to ``generator.sample``.
    n_jobs, verbose :
        Accepted but not used by this function.

    Returns
    -------
    tuple
        The fitted coefficient on D and the square root of its HC1 variance.
    """
    np.random.seed(it)

    # Unlike the proximal estimator, this benchmark uses the sampled mediator M.
    _, D, M, _, X, Y = generator.sample(n, a, b, c, g, sy=sy, replace=True)

    intercept = np.ones((D.shape[0], 1))
    design = np.hstack([D.reshape(-1, 1), M, X, intercept])  # D is column 0
    ols_fit = stm.OLS(Y, design).fit(cov_type='HC1')
    d_variance = ols_fit.cov_params()[0, 0]
    return ols_fit.params[0], np.sqrt(d_variance)

In [None]:
# Try a single replication (seed 5) before launching the parallel run in the next cell.
exp_res(5, generator, nsamples, a, b, c, g, n_jobs=1)

In [None]:
# 100 replications (seeds 0..99) dispatched through joblib; n_jobs=1 is forwarded
# to each exp_res call.
results = Parallel(n_jobs=-1, verbose=3)(
    delayed(exp_res)(seed, generator, nsamples, a, b, c, g, n_jobs=1)
    for seed in range(100)
)

In [None]:
# Transpose the list of (point estimate, standard error) pairs into two arrays.
points, stderrs = map(np.array, zip(*results))

print("Estimation Quality")
# The 95% normal-approximation interval covers the truth c when lower <= c <= upper.
coverage = np.mean((points + 1.96 * stderrs >= c) & (points - 1.96 * stderrs <= c))
rmse = np.sqrt(np.mean((points - c)**2))
bias = np.abs(np.mean(points) - c)
std = np.std(points)
mean_stderr = np.mean(stderrs)
mean_length = np.mean(2 * 1.96 * stderrs)
median_length = np.median(2 * 1.96 * stderrs)
print(f"Coverage: {coverage:.3f}")
print(f"RMSE: {rmse:.3f}")
print(f"Bias: {bias:.3f}")
print(f"Std: {std:.3f}")
print(f"Mean CI length: {mean_length:.3f}")
# Fixed: this line previously printed mean_length under the "Median" label.
print(f"Median CI length: {median_length:.3f}")
print(f"Mean Estimated Stderr: {mean_stderr:.3f}")

In [None]:
# Draw a fresh semi-synthetic dataset (the mediator, third return value, is discarded).
Wtilde, Dtilde, _, Ztilde, Xtilde, Ytilde = generator.sample(nsamples, a, b, c, g, replace=True)

# Author's observation: the dual violation still shows up on the semi-synthetic
# data, causing a slight bias (the true value we should recover is c).
est = ProximalDE(cv=3, semi=True, n_jobs=-1, random_state=3, verbose=3)
est.fit(Wtilde, Dtilde, Ztilde, Xtilde, Ytilde)
# Last expression of the cell: the summary is rendered as the cell output.
est.summary()

In [None]:
def exp_res(it, generator, n, a, b, c, g, *, sy=1.0,
            dual_type='Z', ivreg_type='adv', n_splits=3, semi=True,
            n_jobs=-1, verbose=0):
    """Run one replication of the ProximalDE estimator on semi-synthetic data.

    Seeds NumPy's global RNG with ``it``, draws ``n`` samples from ``generator``,
    fits a ``ProximalDE`` estimator and collects its estimates and test statistics.

    Note: this definition replaces the earlier ``exp_res`` (the OLS benchmark)
    defined higher up in the notebook.

    Parameters
    ----------
    it : int
        Replication index; used both as the NumPy seed and as ``random_state``.
    generator : object
        Object exposing ``sample(n, a, b, c, g, sy=..., replace=True)``.
    n : int
        Number of samples to draw.
    a, b, c, g : float
        Coefficients forwarded to ``generator.sample``.
    sy : float, default 1.0
        Forwarded to ``generator.sample``.
    dual_type, ivreg_type, semi, n_jobs, verbose :
        Forwarded unchanged to the ``ProximalDE`` constructor.
    n_splits : int, default 3
        Passed to ``ProximalDE`` as ``cv``.

    Returns
    -------
    tuple
        18 values, in this order: ``point_``, ``stderr_``, ``r2D_``, ``r2Z_``,
        ``r2X_``, ``r2Y_``, id-strength statistic and critical value,
        ``point_pre_``, ``stderr_pre_``, primal-violation statistic and critical
        value, dual-violation statistic and critical value, weak-IV statistic
        and critical value, and the lower/upper ends of the robust interval.
    """
    np.random.seed(it)

    # The mediator (third return value) is treated as unobserved and dropped.
    W, D, _, Z, X, Y = generator.sample(n, a, b, c, g, sy=sy, replace=True)

    estimator = ProximalDE(
        cv=n_splits,
        semi=semi,
        dual_type=dual_type,
        ivreg_type=ivreg_type,
        n_jobs=n_jobs,
        random_state=it,
        verbose=verbose,
    )
    estimator.fit(W, D, Z, X, Y)

    # Each test returns four values; only the first and the last are kept.
    wiv_stat, _, _, wiv_crit = estimator.weakiv_test(alpha=0.05)
    ids_stat, _, _, ids_crit = estimator.idstrength_violation_test(alpha=0.05)
    primal_stat, _, _, primal_crit = estimator.primal_violation_test(alpha=0.05)
    dual_stat, _, _, dual_crit = estimator.dual_violation_test(alpha=0.05)
    robust_lb, robust_ub = estimator.robust_conf_int(lb=-2, ub=2)

    return (
        estimator.point_, estimator.stderr_,
        estimator.r2D_, estimator.r2Z_, estimator.r2X_, estimator.r2Y_,
        ids_stat, ids_crit,
        estimator.point_pre_, estimator.stderr_pre_,
        primal_stat, primal_crit,
        dual_stat, dual_crit,
        wiv_stat, wiv_crit,
        robust_lb, robust_ub,
    )

In [None]:
# 100 replications (seeds 0..99) of the ProximalDE experiment dispatched through
# joblib; n_jobs=1 is forwarded to each exp_res call.
results = Parallel(n_jobs=-1, verbose=3)(
    delayed(exp_res)(
        seed, generator, nsamples, a, b, c, g,
        dual_type='Z', ivreg_type='adv', n_splits=3, semi=True, n_jobs=1,
    )
    for seed in range(100)
)

In [None]:
# Transpose the list of per-replication result tuples into one array per returned quantity.
# NOTE: the rmseD/rmseZ/rmseX/rmseY names hold the estimator's r2D_/r2Z_/r2X_/r2Y_ values
# (R^2 scores, as labelled in the printout below), not RMSEs.
# map(np.array, ...) already yields arrays, so no further conversion is needed.
points_base, stderrs_base, rmseD, rmseZ, rmseX, rmseY, \
    idstr, idstr_crit, points_alt, stderrs_alt, \
    pval, pval_crit, dval, dval_crit, wiv_stat, wiv_crit, \
    rlb, rub = map(np.array, zip(*results))

print("Estimation Quality")
for name, points, stderrs in [('Debiased', points_base, stderrs_base), ('Regularized', points_alt, stderrs_alt)]:
    print(f"\n{name} Estimate")
    # The 95% normal-approximation interval covers the truth c when lower <= c <= upper.
    coverage = np.mean((points + 1.96 * stderrs >= c) & (points - 1.96 * stderrs <= c))
    rmse = np.sqrt(np.mean((points - c)**2))
    bias = np.abs(np.mean(points) - c)
    std = np.std(points)
    mean_stderr = np.mean(stderrs)
    mean_length = np.mean(2 * 1.96 * stderrs)
    median_length = np.median(2 * 1.96 * stderrs)
    print(f"Coverage: {coverage:.3f}")
    print(f"RMSE: {rmse:.3f}")
    print(f"Bias: {bias:.3f}")
    print(f"Std: {std:.3f}")
    print(f"Mean CI length: {mean_length:.3f}")
    # Fixed: this line previously printed mean_length under the "Median" label.
    print(f"Median CI length: {median_length:.3f}")
    print(f"Mean Estimated Stderr: {mean_stderr:.3f}")
    print(f"Nuisance R^2 (D, Z, X, Y): {np.mean(rmseD):.3f}, {np.mean(rmseZ):.3f}, {np.mean(rmseX):.3f}, {np.mean(rmseY):.3f}")

print("\nRobust ConfInt Coverage")
rcoverage = np.mean((rub >= c) & (rlb <= c))
print(f"Robust Coverage: {rcoverage:.3f}")

print("\nViolations")
# The rate is stored in `violation_rate` rather than `violation`, so this cell does not
# rebind the name `violation`, which find_candidate_sets calls as a function.
# For these two tests a statistic at or below the critical value counts as a violation.
for name, stat, crit in [('Id-Strenth', idstr, idstr_crit), ('WeakIV F-test', wiv_stat, wiv_crit)]:
    violation_rate = np.mean(stat <= crit)
    print(f"% Violations of {name}: {violation_rate:.3f}")
# For these two tests a statistic at or above the critical value counts as a violation.
for name, stat, crit in [('Primal Existence', pval, pval_crit), ('Dual Existence', dval, dval_crit)]:
    violation_rate = np.mean(stat >= crit)
    print(f"% Violations of {name}: {violation_rate:.3f}")