In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np

# --- Simulation configuration (read as globals by the cells below) ---
n = 50000  # first positional argument of the data generator
pw = 1     # forwarded to the data generator as `pw`
pm = 10    # row count of the matrices E and F; also forwarded as `pm`
pz = 20    # column count of E
px = 15    # column count of F; also the chi-squared degrees of freedom used in the plot
n_splits = 3  # NOTE(review): not referenced in the visible cells -- confirm whether still needed

# Indirect effect is a*b, direct effect is c
a = 1.0
b = 1.0
c = 0.5

# D has direct relationship to Z, Z has no relationship to M,
# X has direct relationship to M, X has no direct relationship to Y
# NOTE(review): the wording above does not obviously match the values below
# (e.g. d = 0.0 next to "D has direct relationship to Z") -- confirm against
# the data generator's parameter documentation.
# NOTE(review): the scalars e and f are not passed to the generator in the
# visible cells; the matrices E and F are passed in their place.
d = 0.0
e = 1.0
f = 1.0
g = 0.0

sm = 2.0  # forwarded to the data generator as the `sm` keyword

In [None]:
# Draw the (pm, pz) matrix E and the (pm, px) matrix F, redrawing both until
# each has numerical rank pm.
if pm > min(pz, px):
    # A (pm, k) matrix has rank at most min(pm, k), so the rank condition
    # below could never hold and the loop would spin forever.
    raise ValueError(
        f"pm={pm} must not exceed min(pz, px)={min(pz, px)} "
        "for E and F to have rank pm"
    )

# Dedicated seeded generator: E and F are identical after every kernel restart,
# and the global NumPy random state is left untouched.
design_rng = np.random.RandomState(0)

full_rank = False
while not full_rank:
    E = design_rng.normal(0, 2, (pm, pz))
    F = design_rng.normal(0, 2, (pm, px))
    # tol=0.5: singular values at or below 0.5 count as zero, so accepted
    # matrices are comfortably (not just barely) of rank pm.
    full_rank = (np.linalg.matrix_rank(E, tol=0.5) == pm) and (
        np.linalg.matrix_rank(F, tol=0.5) == pm
    )

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.base import clone
from proximalde.ivreg import AdvIV
import scipy.linalg
from proximalde.gen_data import gen_data_no_controls_discrete_m

def exp(it):
    """Run one simulated replication and return its dual violation statistic.

    Reads the notebook-level settings ``n, pw, pz, px, a, b, c, d, E, F, g,
    sm, pm`` as globals.

    Parameters
    ----------
    it : int
        Replication index. Used to seed the global NumPy RNG and as the
        ``random_state`` of both the train/test split and the ``AdvIV`` fit.

    Returns
    -------
    The quadratic form ``m' pinvh(C) m``, where ``m`` is the test-fold mean of
    the dual moments and ``C`` is the covariance estimate assembled below.
    """
    np.random.seed(it)

    _, D, _, Z, X, _ = gen_data_no_controls_discrete_m(
        n, pw, pz, px, a, b, c, d, E, F, g, sm=sm, pm=pm
    )
    # Make D a column, then remove the column means of D, X and Z.
    D = D.reshape(-1, 1)
    D, X, Z = (arr - arr.mean(axis=0) for arr in (D, X, Z))

    # 70/30 train/test split over row indices.
    nobs = X.shape[0]
    train, test = train_test_split(
        np.arange(nobs), test_size=.3, shuffle=True, random_state=it
    )
    ntrain, ntest = len(train), len(test)

    # Fit on the training fold only, with a single penalty level nobs**0.3.
    iv_model = AdvIV(alphas=[1.0 * nobs**(0.3)], cv=5, random_state=it)
    coef = iv_model.fit(X[train], Z[train], D[train]).coef_

    # Estimate of projection matrix E[XZ] E[ZX]^+
    # using a regularized SVD decomposition
    cross_moment = (X[train].T @ Z[train]) / ntrain
    U, S, _ = scipy.linalg.svd(cross_moment, full_matrices=False)
    shrinkage = S / (S + 1 / ntrain**(0.2))
    P = U @ np.diag(shrinkage) @ U.T

    # Per-row moments X * (D - Z @ coef); training rows are mapped through P,
    # test rows are centred at their own mean.
    residual = D - Z @ coef.reshape(-1, 1)
    phi = X * residual
    phi[train] = phi[train] @ P.T
    moments = phi[test].mean(axis=0)
    phi[test] = phi[test] - moments

    # Covariance estimate: test-fold term plus training-fold term.
    cov = (phi[test].T @ phi[test]) / ntest**2
    cov += (phi[train].T @ phi[train]) / ntrain**2

    return moments.T @ scipy.linalg.pinvh(cov) @ moments

In [None]:
# Single replication with it=0; the returned statistic is shown as the cell output.
exp(0)

In [None]:
from joblib import Parallel, delayed

# Run 1000 replications of `exp`, one per seed 0..999, on all available cores
# (n_jobs=-1); `res` collects the returned statistics in seed order.
replication_seeds = range(1000)
res = Parallel(n_jobs=-1)(delayed(exp)(seed) for seed in replication_seeds)

In [None]:
import matplotlib.pyplot as plt

# Explicit import: the visible cells only import `scipy.linalg`, so reaching
# `scipy.stats` through the bare `scipy` name would depend on some other
# package having loaded that submodule first.
from scipy.stats import chi2

# `res` is a plain Python list of per-replication statistics; convert it once
# so the threshold comparison is an explicit elementwise NumPy operation.
violation_stats = np.asarray(res)

# 95% quantile of a chi-squared distribution with px degrees of freedom,
# computed once and reused for both the rejection rate and the marker line.
critical_value = chi2(df=px).ppf(.95)
rejection_rate = np.mean(violation_stats > critical_value)

plt.title(f"Rejection rate at the 5% level: {rejection_rate:.3f}")
plt.hist(violation_stats)
plt.axvline(critical_value)
plt.xlabel("dual violation statistic")
plt.show()