In [None]:
import sys
import keras
from econml.iv.nnet import DeepIV
sys.path.append('DeepGMM/')
from DeepGMM.methods.toy_model_selection_method import ToyModelSelectionMethod

sys.path.append("../")
from helpers.trainer import train_mse, train_HSIC_IV
from models.kernel import CategoryKernel, RBFKernel
from models.hsicx import NNHSICX
import pandas as pd
from helpers.utils import med_sigma, to_torch, gen_data, gen_radial_fn
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
import functools

In [None]:
# Seed NumPy's global RNG in this (notebook) process before the true function
# is generated. NOTE: this seed does not propagate to joblib worker processes.
np.random.seed(1)

# Experiment size: number of replicates per alpha and samples per data set.
n_rep = 8
n = 1000

# Ground-truth function f (built by gen_radial_fn); later evaluated as f_x.
num_basis = 10
data_limits = (-7, 7)
f = gen_radial_fn(num_basis=num_basis, data_limits=data_limits)

# Training settings for the HSIC-X networks and for the plain MSE networks.
config_hsic = dict(batch_size=256, lr=1e-2, max_epoch=700, num_restart=4)
config_mse = dict(batch_size=256, lr=5e-2, max_epoch=300)

In [None]:
# Experiment: for each instrument type and each alpha in [0, 1] (passed to
# gen_data), fit every estimator on n_rep independently simulated data sets in
# parallel, then store the per-run mean squared error against the true f.


def _fit_mse_net(X, Y, Z, X_pred):
    """Train a *fresh* purely-predictive (MSE) NNHSICX on (X, Y, Z).

    Returns the trained network and its predictions at ``X_pred`` (numpy).
    A new network is created on every call, so one fit can never continue
    training (or, if train_mse works in place, overwrite) another fit's net.
    """
    net = NNHSICX(input_dim=1,
                  lr=config_mse['lr'],
                  lmd=-99)  # same lmd as the original MSE setup
    net = train_mse(net, config_mse, X, Y, Z)
    return net, net(to_torch(X_pred)).detach().numpy()


def _fit_hsic_x(X, Y, Z, X_test, kernel_e, kernel_z, init_net, lmd):
    """Train an HSIC-X network warm-started from ``init_net``; predict at X_test.

    Parameters
    ----------
    X, Y, Z : training data from ``gen_data``.
    X_test : points at which predictions are returned.
    kernel_e, kernel_z : kernels passed to NNHSICX (residual / instrument).
    init_net : trained network whose weights initialise the HSIC-X net.
    lmd : NNHSICX ``lmd`` (0 = unregularised, >0 = regularised variant).

    Returns
    -------
    numpy array of intercept-adjusted predictions at X_test.
    """
    hsic_net = NNHSICX(input_dim=1,
                       lr=config_hsic['lr'],
                       kernel_e=kernel_e,
                       kernel_z=kernel_z,
                       lmd=lmd)
    # NOTE(review): a network (not a state dict) is passed here, as in the
    # original code; presumably NNHSICX overrides load_state_dict -- confirm.
    hsic_net.load_state_dict(init_net)
    hsic_net = train_HSIC_IV(hsic_net, config_hsic, X, Y, Z, verbose=True)

    # The HSIC objective (RBF kernel on residuals) does not fix the intercept,
    # so re-centre predictions on the training mean of Y.
    intercept_adjust = Y.mean() - hsic_net(to_torch(X)).mean()
    y_hat = intercept_adjust + hsic_net(to_torch(X_test))
    return y_hat.detach().numpy().copy()


def _fit_deepgmm(X, Y, Z, X_dev, Y_dev, Z_dev, X_test):
    """Fit DeepGMM's ToyModelSelectionMethod and return flat predictions at X_test."""
    deep_gmm = ToyModelSelectionMethod()
    # g_dev = f(X_dev) is handed to DeepGMM exactly as in the original code.
    deep_gmm.fit(to_torch(X).double(), to_torch(Z).double(), to_torch(Y).double(),
                 to_torch(X_dev).double(), to_torch(Z_dev).double(), to_torch(Y_dev).double(),
                 g_dev=to_torch(f(X_dev)).double(), verbose=True)
    return deep_gmm.predict(to_torch(X_test).double()).flatten().detach().numpy()


def _fit_deepiv(X, Y, Z, X_test):
    """Fit econml's DeepIV with a constant dummy context; return predictions at X_test."""
    treatment_model = keras.Sequential([keras.layers.Dense(64, activation='sigmoid', input_shape=(2,)),
                                        keras.layers.Dropout(0.17)])
    response_model = keras.Sequential([keras.layers.Dense(64, activation='sigmoid', input_shape=(2,)),
                                       keras.layers.Dropout(0.17),
                                       keras.layers.Dense(1)])
    est = DeepIV(n_components=10,
                 m=lambda z, x: treatment_model(keras.layers.concatenate([z, x])),
                 h=lambda t, x: response_model(keras.layers.concatenate([t, x])),
                 n_samples=10)

    # There are no covariates, so an all-zero column stands in for DeepIV's X
    # (input_shape=(2,) presumably = 1-d Z or T plus this dummy column).
    context = np.zeros((X.shape[0], 1))
    context_test = np.zeros((X_test.shape[0], 1))

    est.fit(Y=Y, T=X, X=context, Z=Z)
    return est.predict(T=X_test, X=context_test)


alphas = np.linspace(0, 1, 5)

for inst_idx, instrument in enumerate(['Gaussian', 'Binary']):
    iv_type = 'mix_{}'.format(instrument)
    # Drawn once per instrument; only X_vis.shape[0] is used (test-set size).
    _, _, _, X_vis = gen_data(f, n, iv_type, var_effect=True)

    alpha_dfs = []
    for j, alpha in enumerate(alphas):

        def rep_function(i):
            """Run replicate ``i``: simulate data, fit all estimators, return predictions."""
            import torch

            # np.random.seed in the notebook process does not reach joblib's
            # worker processes: seed every replicate explicitly so results are
            # reproducible and replicates are distinct.
            rep_seed = int(np.random.SeedSequence([1, inst_idx, j, i]).generate_state(1)[0])
            np.random.seed(rep_seed)
            torch.manual_seed(rep_seed)
            # TODO: the keras/TensorFlow RNG used by DeepIV is not seeded.

            X, Y, Z, _ = gen_data(f, n, iv_type, alpha=alpha)
            # dev set for DeepGMM
            X_dev, Y_dev, Z_dev, _ = gen_data(f, n, iv_type, alpha=alpha)
            # oracle set
            X_o, Y_o, Z_o, _ = gen_data(f, n, iv_type, alpha=alpha, oracle=True)
            X_test, _, _, _ = gen_data(f, X_vis.shape[0], iv_type, alpha=alpha)

            # Pure predictive fit, plus an *independent* fit on the oracle data.
            # (Previously the oracle fit re-trained mse_net, which then leaked
            # oracle-trained weights into the HSIC-X warm start below.)
            mse_net, y_hat_mse = _fit_mse_net(X, Y, Z, X_test)
            _, y_hat_oracle = _fit_mse_net(X_o, Y_o, Z_o, X_test)

            # HSIC-X kernels: RBF on residuals; instrument kernel by IV type.
            s_z = med_sigma(Z)
            kernel_e = RBFKernel(sigma=1)
            if instrument == 'Binary':
                kernel_z = CategoryKernel()
            else:
                kernel_z = RBFKernel(sigma=s_z)

            # Non-regularised and regularised HSIC-X, both warm-started from
            # the purely predictive network.
            y_hat_hsic = _fit_hsic_x(X, Y, Z, X_test, kernel_e, kernel_z, mse_net, lmd=0)
            y_hat_hsic_pen = _fit_hsic_x(X, Y, Z, X_test, kernel_e, kernel_z, mse_net, lmd=5e-5)

            y_hat_deepGMM = _fit_deepgmm(X, Y, Z, X_dev, Y_dev, Z_dev, X_test)
            y_hat_DeepIV = _fit_deepiv(X, Y, Z, X_test)

            inner_df = pd.DataFrame()
            inner_df['f_x'] = f(X_test)
            inner_df['Pred'] = y_hat_mse
            inner_df['HSIC-IV'] = y_hat_hsic
            inner_df['HSIC-X-pen'] = y_hat_hsic_pen
            inner_df['DeepGMM'] = y_hat_deepGMM
            inner_df['DeepIV'] = y_hat_DeepIV
            inner_df['Oracle'] = y_hat_oracle
            inner_df['alpha'] = alpha
            inner_df['run_id'] = i

            return inner_df

        # DataFrame.append was removed in pandas 2.0; concatenate once instead.
        ret_df = pd.concat(
            Parallel(n_jobs=4)(delayed(rep_function)(i=i) for i in range(n_rep)),
            ignore_index=True)
        alpha_dfs.append(ret_df)

    res_df = pd.concat(alpha_dfs, ignore_index=True)

    # Per (method, alpha, run): mean squared error against f over the test set.
    melt_res_df = res_df.melt(id_vars=['f_x', 'alpha', 'run_id'], var_name='Method',
                              value_name='y_pred')
    melt_res_df['MISE'] = (melt_res_df['f_x'] - melt_res_df['y_pred']) ** 2
    final_df = melt_res_df.groupby(['Method', 'alpha', 'run_id'])['MISE'].mean().reset_index()
    final_df['alpha'] = np.round(final_df.alpha, 2)
    final_df.to_csv("../results/compare_df_NN_ins_{}.csv".format(instrument),
                    index=False)

In [None]:
# Combine the per-instrument result tables and map the stored method labels
# to the names shown in the figure legend.
method_labels = {"HSIC-AR": "HSIC-X-pen",
                 "HSIC-IV": "HSIC-X",
                 "Pred": "OLS",
                 "D-GMM": "DeepGMM"}
per_instrument = []
for instrument in ['Binary', 'Gaussian']:
    final_df = pd.read_csv("../results/compare_df_NN_ins_{}.csv".format(instrument))
    final_df = final_df.replace(method_labels)
    final_df[r'$Z$'] = instrument
    per_instrument.append(final_df)
# DataFrame.append was removed in pandas 2.0; concatenate once instead.
all_df = pd.concat(per_instrument, ignore_index=True)

# sns.set_palette returns None, so pass the palette name to sns.set directly.
sns.set(font_scale=1.8, style='white', palette="tab10")

hue_order = ['DeepGMM', 'DeepIV', 'OLS', 'Oracle', 'HSIC-X', 'HSIC-X-pen']
# Fixed colour per method: tab10 indices 0, 5, 1, 2, 3, 4 in hue_order order
# (the same colours the previous in-place shuffle of the tab10 array produced).
tab10 = sns.color_palette("tab10")
palette = dict(zip(hue_order, (tab10[k] for k in (0, 5, 1, 2, 3, 4))))

# `log=True` dropped: it is not a pointplot parameter (older seaborn ignores it,
# newer seaborn forwards it to matplotlib). Log scale is set per axis below.
g = sns.catplot(data=all_df, kind="point",
                x='alpha', y='MISE', hue='Method',
                hue_order=hue_order,
                markers=["o", "+", "x", "d", "s", "v"], linestyles=[':', '-', '--', '-.', ':', '-'],
                palette=palette,
                capsize=.07, aspect=1.2, height=4, ci=95,
                col=r'$Z$', sharey=False)

# Replace the FacetGrid legend with one placed to the right of the last facet.
g._legend.remove()

plt.legend(loc='center right', bbox_to_anchor=(1.57, 0.5),
           ncol=1, fancybox=True, shadow=True, prop={'size': 15.5})

for ax in g.axes.flat:
    ax.set_yscale('log')

g.set_xlabels(r'$\alpha$')
g.set_ylabels('MSE')

plt.savefig('../results/compare_NN_update.pdf',
            bbox_inches="tight")
plt.close()