In [7]:
import warnings
warnings.simplefilter('ignore')
import itertools
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from mliv.dgps import get_data, get_tau_fn, fn_dict
from mliv.neuralnet.utilities import mean_ci
from mliv.neuralnet import AGMMEarlyStop as AGMM
from mliv.neuralnet.moments import avg_small_diff
from sklearn.ensemble import RandomForestRegressor
import joblib
from joblib import Parallel, delayed
from mliv.cct.mc2 import MC2
from mliv.rkhs import ApproxRKHSIVCV
from sklearn.model_selection import cross_val_predict
from sklearn.preprocessing import PolynomialFeatures, StandardScaler
from sklearn.kernel_approximation import Nystroem
from sklearn.pipeline import FeatureUnion, Pipeline
import scipy

In [8]:
# average finite difference moment
def moment_fn(x, test_fn, epsilon=0.1):
    """Central finite difference of ``test_fn`` with respect to the first column of ``x``.

    Parameters
    ----------
    x : 2-D array
        One row per sample. Only column 0 is perturbed; the remaining
        columns are passed through unchanged.
    test_fn : callable
        Function evaluated on the two perturbed copies of ``x``.
    epsilon : float, default 0.1
        Half-width of the finite-difference step. Must be non-zero.

    Returns
    -------
    ``(test_fn(x_plus) - test_fn(x_minus)) / (2 * epsilon)``, with whatever
    shape ``test_fn`` returns.

    Raises
    ------
    ValueError
        If ``epsilon`` is zero (the quotient would be undefined).
    """
    if epsilon == 0:
        raise ValueError("epsilon must be non-zero")
    rest = x[:, 1:]
    # Shift only the first column up / down by epsilon.
    t1 = np.hstack([x[:, [0]] + epsilon, rest])
    t0 = np.hstack([x[:, [0]] - epsilon, rest])
    return (test_fn(t1) - test_fn(t0)) / (2 * epsilon)

In [80]:
# Single illustrative replication: build the MC2 generator and draw dataset number `it`.
n = 5000
it = 0
mc2_gen = MC2(n, 100, None, dimension=10, corr=0.5)
# Only the first item returned by `data` is used; the rest is discarded.
npvec, *_ = mc2_gen.data(it)
Z, X, Y = (npvec[key] for key in ('instrument', 'endogenous', 'response'))
# Column counts of the instrument and endogenous-regressor matrices.
n_z, n_x = Z.shape[1], X.shape[1]

In [81]:
# Split the sample in half: the training half is used to fit the feature maps and the
# representer, the validation half to evaluate them. `random_state` pins the shuffle so
# that re-running the cell reproduces the same split for a given replication index `it`.
Z_train, Z_val, X_train, X_val, Y_train, Y_val = train_test_split(
        Z, X, Y, test_size=.5, shuffle=True, random_state=it)

# Feature maps for the instruments (Psi) and the endogenous regressors (Phi):
# degree-2 polynomial expansion followed by standardisation.
# (Nystroem features, alone or in a FeatureUnion with the polynomial ones, were tried
# here as alternatives.)
# NOTE(review): with default settings PolynomialFeatures emits a constant bias column,
# which StandardScaler centres to all zeros -- confirm this is intended.
ztrans = Pipeline([('trans', PolynomialFeatures(degree=2)), ('scale', StandardScaler())])
xtrans = Pipeline([('trans', PolynomialFeatures(degree=2)), ('scale', StandardScaler())])
Psi = ztrans.fit_transform(Z_train)
Phi = xtrans.fit_transform(X_train)
# Finite-difference moment applied to every feature of the fitted X transform.
mPhi = moment_fn(X_train, xtrans.transform)

In [82]:
# Unnormalised second-moment matrices on the training half.
CovPsi = Psi.T @ Psi
CovPhiPsi = Phi.T @ Psi
# Validation-half features and the per-feature average moment they should reproduce.
Phival = xtrans.transform(X_val)
Psival = ztrans.transform(Z_val)
mPhival = moment_fn(X_val, xtrans.transform)
moment_val = np.mean(mPhival, axis=0)

# The same penalty grid is used for alpha, beta and gamma.
reg_grid = np.logspace(-6, 1, 5)
eye_psi = np.eye(Psi.shape[1])
eye_phi = np.eye(Phi.shape[1])
moment_sum = np.sum(mPhi, axis=0)

# pinv(CovPsi + lam * n * I) depends only on the penalty value, so compute it once per
# grid point and share it between the alpha step (regCov) and the gamma step, instead
# of recomputing it inside the triple loop.
ridge_pinvs = [scipy.linalg.pinv(CovPsi + lam * n * eye_psi) for lam in reg_grid]
gamma_maps = [ridge_pinv @ CovPhiPsi.T for ridge_pinv in ridge_pinvs]

best_violation = np.inf
for alpha, regCov in zip(reg_grid, ridge_pinvs):
    Sigma = CovPhiPsi @ regCov @ CovPsi @ regCov @ CovPhiPsi.T
    for beta in reg_grid:
        xi = scipy.linalg.pinv(Sigma + beta * n * eye_phi) @ moment_sum
        for gamma, gamma_map in zip(reg_grid, gamma_maps):
            qparam = gamma_map @ xi

            # Violation of the representation property on the validation half:
            #   E[m(W; phi)] = E[q(Z) * phi(X)]   for every feature phi.
            representer_val = np.mean((Psival @ qparam).reshape(-1, 1) * Phival, axis=0)
            violation = np.linalg.norm(moment_val - representer_val, ord=2)
            # `<=` means that ties are resolved in favour of the later grid point.
            if violation <= best_violation:
                best_alpha = alpha
                best_beta = beta
                best_gamma = gamma
                best_violation = violation

In [83]:
# Rebuild xi and qparam at the penalties selected by the grid search.
alpha, beta, gamma = best_alpha, best_beta, best_gamma
regCov = scipy.linalg.pinv(CovPsi + (alpha * n) * np.identity(Psi.shape[1]))
Sigma = CovPhiPsi @ regCov @ CovPsi @ regCov @ CovPhiPsi.T
xi = scipy.linalg.pinv(Sigma + (beta * n) * np.identity(Phi.shape[1])) @ mPhi.sum(axis=0)
qparam = (scipy.linalg.pinv(CovPsi + (gamma * n) * np.identity(Psi.shape[1])) @ CovPhiPsi.T) @ xi

In [84]:
# Show the selected (alpha, beta, gamma) penalties and the violation they achieved.
best_alpha, best_beta, best_gamma, best_violation

(1e-06, 0.0031622776601683794, 0.0031622776601683794, 1.271478323909948)

In [85]:
# Fit the outcome model on the training half (instruments, endogenous regressors, responses).
# NOTE(review): the variable is named `agmm` but holds an ApproxRKHSIVCV estimator,
# not the AGMM class imported at the top of the notebook.
agmm = ApproxRKHSIVCV(n_components=200)
agmm.fit(Z_train, X_train, Y_train)

<mliv.rkhs.rkhsiv.ApproxRKHSIVCV at 0x17271cca960>

In [86]:
# Representer values q(Z) on the validation half.
qvalues = Psival @ qparam
# Plug-in term: finite-difference moment of the fitted outcome model.
direct = moment_fn(X_val, agmm.predict).flatten()
# Outcome-model residuals on the validation half.
residual = (Y_val - agmm.predict(X_val)).flatten()
# Plug-in term plus the representer-weighted residual correction.
pseudo = direct + qvalues * residual

# Summarise each per-sample series with mean_ci, in the order reg, dr, ipw.
reg, dr, ipw = (mean_ci(series)
                for series in (direct, pseudo, qvalues * Y_val.flatten()))
reg, ipw, dr

((1.0399470362147252, 0.918951053053702, 1.1609430193757484),
 (1.4583955377281368, 0.6870931252779763, 2.229697950178297),
 (1.0471596379471688, 0.7875847212933229, 1.3067345546010147))

In [87]:
# Phival already holds xtrans.transform(X_val); reuse it rather than transforming again
# (this also matches how `exp` computes the same quantity).
xivalues = Phival @ xi
# Scalar chosen so that mean(qvalues * (residual - coef * xivalues)) is exactly zero.
coef = np.mean(qvalues * residual) / np.mean(qvalues * xivalues)
# Shift the plug-in term by coef * m(xi) and add the correction on the adjusted residual.
pseudo_tmle = direct + coef * (mPhival @ xi)
pseudo_tmle += qvalues * (residual - coef * xivalues)
tmle = mean_ci(pseudo_tmle)
tmle

(1.0489932555527517, 0.7893646343444162, 1.3086218767610873)

In [96]:
def exp(it, n, dim, corr):
    """Run one Monte Carlo replication and return the four estimates.

    Steps: draw dataset ``it`` from the MC2 generator, split it in half, fit
    standardised degree-2 polynomial feature maps on the training half, pick the
    ridge penalties ``(alpha, beta, gamma)`` that minimise the validation-half
    violation of ``E[m(W; phi)] = E[q(Z) * phi(X)]``, fit an ``ApproxRKHSIVCV``
    outcome model, and evaluate the estimators on the validation half.

    Parameters
    ----------
    it : int
        Replication index. Passed to ``mc2_gen.data`` and used as the seed of the
        train/validation split so that the split is reproducible per replication.
    n : int
        Sample size passed to ``MC2``; also scales the ridge penalties.
    dim : int
        Forwarded to ``MC2`` as ``dimension``.
    corr : float
        Forwarded to ``MC2`` as ``corr``.

    Returns
    -------
    tuple
        ``(dr, tmle, ipw, reg)``, each the output of ``mean_ci`` on the
        corresponding per-sample series.

    Raises
    ------
    ValueError
        If no penalty combination yields a comparable (non-NaN) violation.
    """
    mc2_gen = MC2(n, 100, None, dimension=dim, corr=corr)
    npvec, *_ = mc2_gen.data(it)
    Z, X, Y = npvec['instrument'], npvec['endogenous'], npvec['response']

    # Seeding the split with the replication index keeps every replication
    # reproducible, including inside joblib worker processes.
    Z_train, Z_val, X_train, X_val, Y_train, Y_val = train_test_split(
        Z, X, Y, test_size=.5, shuffle=True, random_state=it)

    # Feature maps: degree-2 polynomial expansion followed by standardisation.
    ztrans = Pipeline([('trans', PolynomialFeatures(degree=2)), ('scale', StandardScaler())])
    xtrans = Pipeline([('trans', PolynomialFeatures(degree=2)), ('scale', StandardScaler())])

    Psi = ztrans.fit_transform(Z_train)
    Phi = xtrans.fit_transform(X_train)
    mPhi = moment_fn(X_train, xtrans.transform)

    # Unnormalised second-moment matrices on the training half.
    CovPsi = Psi.T @ Psi
    CovPhiPsi = Phi.T @ Psi
    Phival = xtrans.transform(X_val)
    Psival = ztrans.transform(Z_val)
    mPhival = moment_fn(X_val, xtrans.transform)
    moment_val = np.mean(mPhival, axis=0)

    # The same penalty grid is used for alpha, beta and gamma.
    # NOTE(review): penalties are scaled by the full `n` although Psi / Phi are built
    # from the training half only -- confirm this scaling is intended.
    reg_grid = np.logspace(-6, 1, 5)
    eye_psi = np.eye(Psi.shape[1])
    eye_phi = np.eye(Phi.shape[1])
    moment_sum = np.sum(mPhi, axis=0)

    # inv(CovPsi + lam * n * I) depends only on the penalty value: compute it once per
    # grid point and share it between the alpha step and the gamma step.
    ridge_invs = [scipy.linalg.inv(CovPsi + lam * n * eye_psi) for lam in reg_grid]
    gamma_maps = [ridge_inv @ CovPhiPsi.T for ridge_inv in ridge_invs]

    best = None
    best_violation = np.inf
    for regCov in ridge_invs:  # alpha
        Sigma = CovPhiPsi @ regCov @ CovPsi @ regCov @ CovPhiPsi.T
        for beta in reg_grid:
            xi_cand = scipy.linalg.inv(Sigma + beta * n * eye_phi) @ moment_sum
            for gamma_map in gamma_maps:  # gamma
                q_cand = gamma_map @ xi_cand

                # calculating the violation in the riesz representation property for each feature
                #  E[m(W; phi)] = E[q(Z) * phi(X)]
                # for every feature phi.
                representer_val = np.mean((Psival @ q_cand).reshape(-1, 1) * Phival, axis=0)
                violation = np.linalg.norm(moment_val - representer_val, ord=2)
                # `<=` means that ties are resolved in favour of the later grid point.
                if violation <= best_violation:
                    best_violation = violation
                    # Keep the winning parameters directly; no refit is needed afterwards.
                    best = (xi_cand, q_cand)

    if best is None:
        raise ValueError("no penalty combination produced a finite violation")
    xi, qparam = best

    # NOTE(review): the variable is named `agmm` but holds an ApproxRKHSIVCV estimator.
    agmm = ApproxRKHSIVCV(n_components=200)
    agmm.fit(Z_train, X_train, Y_train)

    # Plug-in term, outcome residuals and representer values on the validation half.
    direct = moment_fn(X_val, agmm.predict).flatten()
    residual = (Y_val - agmm.predict(X_val)).flatten()
    qvalues = Psival @ qparam
    pseudo = direct + qvalues * residual

    reg = mean_ci(direct)
    dr = mean_ci(pseudo)
    ipw = mean_ci(qvalues * Y_val.flatten())

    # Scalar `coef` makes mean(qvalues * (residual - coef * xivalues)) exactly zero.
    xivalues = Phival @ xi
    coef = np.mean(qvalues * residual) / np.mean(qvalues * xivalues)
    pseudo_tmle = direct + coef * (mPhival @ xi)
    pseudo_tmle += qvalues * (residual - coef * xivalues)
    tmle = mean_ci(pseudo_tmle)

    return dr, tmle, ipw, reg

In [97]:
import pandas as pd 

true = 1.0  # reference value the estimates are scored against
methods = ['dr', 'tmle', 'ipw', 'direct']  # same order as the tuple returned by exp

for n in [5000]:
    # Two-sided t critical values depend only on n, so compute them once per n.
    # NOTE(review): the degrees of freedom use the full n, while exp evaluates
    # mean_ci on the validation half only -- confirm against mean_ci.
    crit_95 = scipy.stats.t.ppf((1 + .95) / 2., n - 1)
    crit_99 = scipy.stats.t.ppf((1 + .99) / 2., n - 1)
    for n_x in [5, 10]:
        for corr in [0.0, 0.5]:
            print(n, n_x, corr)
            results = Parallel(n_jobs=-1, verbose=3)(delayed(exp)(it, n, n_x, corr)
                                                     for it in range(100))
            df = {}
            # `pos` (not `it`) indexes the method so the notebook-level `it` is left alone.
            for pos, method in enumerate(methods):
                # One row per replication; columns are treated as (point, lower, upper).
                data = np.array([r[pos] for r in results])
                # Recover the standard error from the upper bound, treated as a 95%
                # bound, then widen the interval to the 99% level.
                se = (data[:, 2] - data[:, 0]) / crit_95
                data[:, 1] = data[:, 0] - se * crit_99
                data[:, 2] = data[:, 0] + se * crit_99
                if method in ['dr', 'tmle']:
                    # Percentage of replications whose interval contains `true`.
                    cov = f'{100*np.mean((data[:, 1] <= true) & (true <= data[:, 2])):.0f}'
                else:
                    cov = 'NA'
                df[method] = {'cov': cov,
                              'rmse': f'{np.sqrt(np.mean((data[:, 0] - true)**2)):.3f}',
                              'bias': f'{np.abs(np.mean((data[:, 0] - true))):.3f}',
                              'std': f'{np.std(data[:, 0]):.3f}'}

            display(pd.DataFrame(df))

5000 5 0.0


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 20 concurrent workers.
[Parallel(n_jobs=-1)]: Done  95 out of 100 | elapsed:  1.1min remaining:    3.4s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:  1.1min finished


Unnamed: 0,dr,tmle,ipw,direct
cov,100.0,96.0,,
rmse,0.099,0.128,0.342,0.101
bias,0.019,0.034,0.114,0.03
std,0.097,0.124,0.322,0.096


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 20 concurrent workers.


5000 5 0.5


[Parallel(n_jobs=-1)]: Done  95 out of 100 | elapsed:   54.5s remaining:    2.8s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:   54.7s finished


Unnamed: 0,dr,tmle,ipw,direct
cov,95.0,91.0,,
rmse,0.126,0.144,0.487,0.09
bias,0.052,0.065,0.039,0.027
std,0.114,0.129,0.486,0.085


5000 10 0.0


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 20 concurrent workers.
[Parallel(n_jobs=-1)]: Done  95 out of 100 | elapsed:  1.0min remaining:    3.1s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:  1.0min finished


Unnamed: 0,dr,tmle,ipw,direct
cov,94.0,86.0,,
rmse,0.161,0.225,0.374,0.176
bias,0.078,0.133,0.104,0.131
std,0.14,0.182,0.359,0.118


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 20 concurrent workers.


5000 10 0.5


[Parallel(n_jobs=-1)]: Done  95 out of 100 | elapsed:  1.3min remaining:    4.1s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:  1.3min finished


Unnamed: 0,dr,tmle,ipw,direct
cov,65.0,51.0,,
rmse,0.314,0.405,0.501,0.157
bias,0.257,0.326,0.021,0.078
std,0.182,0.24,0.501,0.137
