In [None]:
import numpy as np
from proximalde.gen_data import gen_data_no_controls
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
import scipy.linalg

In [None]:
def true_params(pw, pz, px, a, b, c, d, e, f, g, sm, sz=1.0, sd=0.5):
    """Closed-form population quantities for the scalar (pz == 1, px == 1) design.

    Parameters
    ----------
    pw, b, c, g
        Accepted so the signature lines up with the simulation calls; they do
        not enter any formula here.
    pz, px
        The closed form is only implemented when both equal 1.
    a, d, e, f
        Coefficients entering the moment products below.
    sm, sz, sd
        Scale parameters that enter only through their squares (presumably the
        noise standard deviations of M, Z and D -- confirm against the data
        generator).

    Returns
    -------
    tuple
        ``(true_gamma, true_strength)`` where ``true_gamma = DX / XZ`` and
        ``true_strength = sd**2 - true_gamma * DZ``.

    Raises
    ------
    AttributeError
        If ``pz`` or ``px`` differs from 1.
    """
    # Guard clause: nothing is derived for the multivariate case.
    if pz != 1 or px != 1:
        raise AttributeError("Not available")

    var_d = sd**2

    # Z*Z product; evaluated for completeness but not needed for the outputs.
    zz = (e * a + d)**2 * var_d + e**2 * sm**2 + sz**2
    mm = sm**2 + a**2 * var_d
    xz = f * (e * mm + d * a * var_d)
    dz = (a * e + d) * var_d
    dx = a * f * var_d

    # gamma = D*X / X*Z
    gamma = dx / xz
    # strength = D^2 - gamma * D*Z
    strength = var_d - gamma * dz
    return gamma, strength

In [None]:
def exp(it, n, pw, pz, px, a, b, c, d, e, f, g, sm):
    """Run one simulated replication and return estimates with standard errors.

    Seeds NumPy's global generator with ``it``, draws a dataset from
    ``gen_data_no_controls``, mean-centers ``Y``, ``X`` and ``Z``, and solves
    the regularized system

        (X'Q / n + XX / n**0.4) gamma = Q'Y / n

    with ``XX = X'X / n`` and ``Q = Z @ pinvh(Z'Z / n) @ (Z'X / n)``.

    Parameters
    ----------
    it : int
        Replication index; used as the random seed.
    n : int
        Sample size requested from the generator. After generation the row
        count of ``Z`` is used for all normalizations.
    pw, pz, px, a, b, c, d, e, f, g, sm
        Forwarded unchanged to ``gen_data_no_controls``.

    Returns
    -------
    tuple
        Flat tuple: every entry of ``gamma`` followed by the matching
        standard errors.
    """
    np.random.seed(it)
    _, Y, _, X, Z, _ = gen_data_no_controls(n, pw, pz, px, a, b, c, d, e, f, g, sm=sm)

    # Column-vector outcome, then remove column means from every block.
    y_c = Y.reshape(-1, 1)
    y_c = y_c - y_c.mean(axis=0)
    x_c = X - X.mean(axis=0)
    z_c = Z - Z.mean(axis=0)
    n_obs = z_c.shape[0]

    cov_xz = x_c.T @ z_c / n_obs
    cov_xx = x_c.T @ x_c / n_obs
    cov_zz_pinv = scipy.linalg.pinvh(z_c.T @ z_c / n_obs)
    proj = z_c @ cov_zz_pinv @ cov_xz.T

    # The X'Q / n matrix is stabilized with an additive XX / n**0.4 term before
    # pseudo-inversion (an earlier variant clipped eigenvalues below
    # 1 / sqrt(n) instead).
    j_reg = x_c.T @ proj / n_obs + cov_xx / n_obs**(0.4)
    j_pinv = scipy.linalg.pinvh(j_reg)
    gamma = j_pinv @ (proj.T @ y_c / n_obs)

    # Per-observation moment contributions, mapped through the inverse, give
    # a sandwich-style covariance for gamma.
    influence = proj * y_c - proj * (x_c @ gamma)
    influence = influence @ j_pinv.T
    cov = influence.T @ influence / n_obs
    stderr = np.sqrt(np.diag(cov) / n_obs)

    return (*gamma.flatten(), *stderr)

In [None]:
# Simulation configuration: seed, sample size, dimensions and structural
# coefficients shared by every cell below.
np.random.seed(123)
n = 10000
pw = 1
pz, px = 5, 5
n_splits = 3
# Indirect effect is a*b, direct effect is c
a, b, c = 1.0, 1.0, .5
# In order (d, e, f, g): D has direct relationship to Z, Z has direct
# relationship to M, X has direct relationship to M, X has direct
# relationship to Y
d, e, f, g = 0.0, 1.0, 1.0, 1.0
sm = 2.0
if px == 1 and pz == 1:
    # A bare expression nested in an `if` is never echoed by the notebook, so
    # the closed-form targets have to be printed explicitly to be visible.
    print("true (gamma, strength):",
          true_params(pw, pz, px, a, b, c, d, e, f, g, sm))

In [None]:
# Single replication with seed 0, as a quick check before the full batch.
exp(it=0, n=n, pw=pw, pz=pz, px=px, a=a, b=b, c=c, d=d, e=e, f=f, g=g, sm=sm)

In [None]:
# Run 100 replications (seeds 0..99); n_jobs=-1 lets joblib use every core.
res = Parallel(n_jobs=-1, verbose=3)(
    delayed(exp)(it, n, pw, pz, px, a, b, c, d, e, f, g, sm)
    for it in range(100)
)

In [None]:
# Stack the per-replication tuples into a 2-D array: one row per replication,
# columns are the gamma estimates followed by their standard errors.
res = np.array(res)

In [None]:
# Monte Carlo summary of the gamma estimates: mean, std, 5th and 95th percentile.
# exp() returns one gamma entry per column of X (presumably px of them), so the
# estimate columns are delimited by px rather than pz.
(
    np.mean(res[:, :px], axis=0),
    np.std(res[:, :px], axis=0),
    np.percentile(res[:, :px], 5, axis=0),
    np.percentile(res[:, :px], 95, axis=0),
)

In [None]:
# Monte Carlo summary of the estimated standard errors: mean, 5th and 95th percentile.
# The standard-error columns start after the gamma block, whose width is the
# number of columns of X (presumably px), hence px rather than pz.
(
    np.mean(res[:, px:], axis=0),
    np.percentile(res[:, px:], 5, axis=0),
    np.percentile(res[:, px:], 95, axis=0),
)

In [None]:
# Histogram of the gamma estimates across replications, one series per
# coordinate. The gamma block is as wide as X (presumably px columns), so it is
# sliced with px rather than pz.
plt.hist(res[:, :px])
plt.xlabel("gamma estimate")
plt.ylabel("number of replications")
plt.title("Distribution of gamma estimates across replications")
plt.show()