In [None]:
import sys
import functools
import torch
from joblib import Parallel, delayed
from torch.utils.data import TensorDataset
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

sys.path.append("../")
from helpers.trainer import train_HSIC_IV
from models.hsicx import LinearHSICX
from models.kernel import RBFKernel, CategoryKernel
from models.baselines import PredPolyRidge, Poly2SLS
from helpers.utils import gen_data_multi, to_torch, med_sigma

In [None]:
# Experiment configuration: replicate count, sample size, optimiser settings
# and the fixed ground-truth weight vectors used by the experiment cell.

n_rep = 4      # number of replicates run for each (instrument, x_dim, z_dim)
n = 1000       # sample size passed to gen_data_multi for the training draws
alpha = 1.0    # forwarded unchanged to gen_data_multi as `alpha`

# Handed to train_HSIC_IV; 'lr' is also read when building LinearHSICX.
config = {'batch_size': 256,
          'lr': 1e-2,
          'max_epoch': 600,
          'num_restart': 4}

# Seed NumPy's global generator exactly once. The cell used to call
# np.random.seed(1) first and then np.random.seed(0) with no random draw in
# between, so the first call had no effect; 0 is the seed that was in force.
# NOTE(review): replicates are executed through joblib.Parallel; whether this
# global seed reaches the worker processes depends on the backend - confirm.
np.random.seed(0)
rand_state = np.random.RandomState(0)

x_dims = [3, 5, 10]
# Instrument dimensions swept for each treatment dimension.
z_dims = {3: range(1, 6), 5: range(1, 6), 10: range(6, 11)}
# Every weight vector is drawn from its own RandomState(1), so the weights do
# not depend on the global seed above.
weights = {x_dim: np.random.RandomState(1).normal(0, 2, size=(x_dim, 1)) for x_dim in x_dims}

In [None]:
# Benchmark sweep: for every instrument type and treatment dimension, sweep the
# instrument dimension, fit all estimators on `n_rep` independent draws, and
# write the per-run mean squared error of each method to a CSV file.
for instrument in ['Binary', 'Gaussian']:
    iv_type = 'mix_{}'.format(instrument)
    for x_dim in x_dims:
        w = weights[x_dim]
        f = lambda x: (x @ w).flatten()
        z_dims_ = z_dims[x_dim]
        # One concatenated frame per z_dim; joined once after the sweep instead
        # of growing a DataFrame with the removed DataFrame.append.
        z_dim_frames = []
        for z_dim in list(z_dims_):

            def rep_function(i):
                """Run one replicate and return the test-set predictions of every method.

                Draws a training sample, a sample generated with ``oracle=True`` and a
                test sample from ``gen_data_multi``, fits the ridge baselines, two
                HSIC models (started from the ridge coefficients and from the true
                weights ``w``) and 2SLS, and evaluates them on the test inputs.

                Args:
                    i: Replicate index; only stored in the ``run_id`` column.

                Returns:
                    A DataFrame with one row per test point and the columns
                    ``f_x``, ``Pred``, ``HSIC-IV``, ``HSIC-Oracle``, ``2SLS``,
                    ``Oracle``, ``alpha``, ``z_dim`` and ``run_id``.
                """
                X, Y, Z = gen_data_multi(f, n, x_dim, z_dim, iv_type, alpha=alpha, oracle=False)
                X_o, Y_o, _ = gen_data_multi(f, n, x_dim, z_dim, iv_type, alpha=alpha, oracle=True)
                # int(10e4) == 100000 test points.
                X_test, _, _ = gen_data_multi(f, int(10e4), x_dim, z_dim, iv_type, alpha=alpha, oracle=False)

                # The DataLoader that used to be built here was never passed to
                # anything (train_HSIC_IV receives X, Y, Z directly), so it is gone.

                # get y_hat for MSE loss
                mse_reg = PredPolyRidge(degree=1, bias=False)
                oracle_reg = PredPolyRidge(degree=1, bias=False)
                mse_reg.fit(X, Y)
                y_hat_mse = mse_reg.predict(X_test)
                mse_coef = mse_reg.reg.coef_
                oracle_reg.fit(X_o, Y_o)
                y_hat_oracle = oracle_reg.predict(X_test)

                s_z = med_sigma(Z)

                kernel_e = RBFKernel(sigma=1)
                if instrument == 'Binary':
                    kernel_z = CategoryKernel(one_hot=False)
                else:
                    kernel_z = RBFKernel(sigma=s_z)

                def fit_hsic(init_coef):
                    """Train a LinearHSICX model from `init_coef` and predict on X_test."""
                    hsic_net = LinearHSICX(input_dim=x_dim,
                                           lr=config['lr'],
                                           lmd=0.0,
                                           kernel_e=kernel_e,
                                           kernel_z=kernel_z,
                                           bias=False)
                    hsic_net.load_state_dict(init_coef)
                    hsic_net = train_HSIC_IV(hsic_net, config, X, Y, Z)

                    # The model is built with bias=False, so the prediction is
                    # shifted to match the mean of Y on the training sample.
                    intercept_adjust = Y.mean() - hsic_net(to_torch(X)).mean()
                    y_hat = intercept_adjust + hsic_net(to_torch(X_test))
                    return y_hat.detach().numpy()

                # Same model and kernels, two starting points: the ridge
                # coefficients and the true data-generating weights.
                y_hat_hsic = fit_hsic(mse_coef)
                y_hat_hsic_oracle = fit_hsic(w.flatten())

                # 2SLS
                poly2SLS = Poly2SLS(degree=1, bias=False)
                poly2SLS.fit(X, Y, Z)

                # Columns are assigned one by one (rather than via a dict
                # constructor) so prediction arrays of any accepted shape
                # keep working exactly as before.
                inner_df = pd.DataFrame()

                inner_df['f_x'] = f(X_test)
                inner_df['Pred'] = y_hat_mse
                inner_df['HSIC-IV'] = y_hat_hsic
                inner_df['HSIC-Oracle'] = y_hat_hsic_oracle
                inner_df['2SLS'] = poly2SLS.predict(X_test)
                inner_df['Oracle'] = y_hat_oracle
                inner_df['alpha'] = alpha
                inner_df['z_dim'] = z_dim
                inner_df['run_id'] = i

                return inner_df


            rep_frames = Parallel(n_jobs=4)(
                delayed(rep_function)(i=i) for i in range(n_rep))

            # pd.concat replaces the reduce/append chain: DataFrame.append was
            # deprecated in pandas 1.4 and removed in pandas 2.0.
            ret_df = pd.concat(rep_frames, ignore_index=True)
            z_dim_frames.append(ret_df)

        res_df = pd.concat(z_dim_frames, ignore_index=True)

        # Long format: one row per (test point, method), then the squared error
        # against the true f(x) averaged within each run.
        melt_res_df = res_df.melt(id_vars=['f_x', 'z_dim', 'alpha', 'run_id'], var_name='Method', value_name='y_pred')
        melt_res_df['MISE'] = (melt_res_df['f_x'] - melt_res_df['y_pred']) ** 2
        final_df = melt_res_df.groupby(['Method', 'z_dim', 'alpha', 'run_id'])['MISE'].mean().reset_index()
        final_df['alpha'] = np.round(final_df.alpha, 2)
        final_df.to_csv("../results/compare_df_multidim_ins_{}_xdim_{}.csv".format(instrument, x_dim),
                        index=False)

In [None]:
# Load the per-configuration result tables written by the experiment cell,
# harmonise the method names, and draw one panel per treatment dimension.
result_frames = []
for instrument in ['Binary', 'Gaussian']:
    for x_dim in [3, 5, 10]:
        final_df = pd.read_csv("../results/compare_df_multidim_ins_{}_xdim_{}.csv".format(instrument, x_dim))
        final_df = (
            final_df
            .replace({"Pred": "OLS",
                      "HSIC-IV": "HSIC-X"})
            .query("Method != 'HSIC-Oracle'")
            # assign() returns a new frame, avoiding the SettingWithCopyWarning
            # raised by writing a column into the result of query().
            .assign(**{r'$d_X$': x_dim})
        )
        result_frames.append(final_df)

# Single concat instead of DataFrame.append (removed in pandas 2.0).
all_df = pd.concat(result_frames, ignore_index=True)

# set_palette() returns None, so nesting it inside sns.set() only worked through
# its side effect; pass the palette name directly instead.
sns.set(font_scale=1.8, style='white', palette="tab10")

# `log=True` was dropped: it does not appear to be a catplot/pointplot
# parameter, and the log scale is applied explicitly on every axis below.
g = sns.catplot(data=all_df, kind="point",
                x='z_dim', y='MISE', hue='Method', alpha=.5,
                hue_order=['2SLS', 'OLS', 'Oracle', 'HSIC-X'],
                markers=["o", "x", "d", "s"], linestyles=[':', '-', '--', '-.'],
                capsize=.07, aspect=1.2, height=4, ci=95,
                col=r'$d_X$', sharey=False)
# Replace the default figure legend with a single horizontal legend drawn by
# plt.legend on the current axes and positioned via bbox_to_anchor.
g._legend.remove()

plt.legend(loc='upper center', bbox_to_anchor=(-.2, 1.45),
           ncol=4, fancybox=True, shadow=True, prop={'size': 18})

g.set_xlabels(r'$d_Z$')
g.set_ylabels('MSE')
for ax in g.axes.flat:
    ax.set_yscale('log')
plt.savefig(
    '../results/compare_multidim.pdf',
    bbox_inches="tight")
plt.close()