In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import functools
import pandas as pd
import sys

sys.path.append("../")
from helpers.hsic import *
from joblib import Parallel, delayed

from helpers.trainer import train_HSIC_IV
from models.baselines import Poly2SLS, Radial2SLSRidge, PredPolyRidge, PredRadialRidge
from models.kernel import CategoryKernel, RBFKernel
from models.hsicx import LinearHSICX, RadialHSICX
from helpers.utils import med_sigma, to_torch, gen_data, gen_radial_fn

# Seed NumPy's global RNG (in this process) before any data is generated.
np.random.seed(1)

# Experiment dimensions: replicates per setting, draws per dataset, and the
# radial-basis settings shared by the data generator and the models.
n_rep, n = 10, 1000
num_basis = 10
data_limits = (-7, 7)

# Output switches read by the experiment cell.
vis_mode, compare_mode = False, True

# Training settings handed to the HSIC trainer; the two function families
# differ only in `max_epoch`.
config_radial = dict(batch_size=256, lr=1e-2, num_restart=4, max_epoch=600)
config_linear = dict(batch_size=256, lr=1e-2, num_restart=4, max_epoch=200)

# Lookup by function-family name.
config_all = dict(radial=config_radial, linear=config_linear)

In [None]:
# Run the simulation for every (function family, instrument type) pair.
# For each pair, every alpha is evaluated over `n_rep` parallel replicates;
# the pooled predictions are then written out as a comparison CSV
# (`compare_mode`) and/or as visualisation CSVs and PDF plots (`vis_mode`).
for fn in ['radial', 'linear']:
    config = config_all[fn]

    # Ground-truth function f and its weights w; w is later used to
    # warm-start the "HSIC-Oracle" fit.
    if fn == 'linear':
        f = lambda x: -2 * x
        w = np.array([-2])
    else:
        f, w = gen_radial_fn(num_basis=num_basis, data_limits=data_limits, ret_w=True)

    for instrument in ['Binary', 'Gaussian']:
        iv_type = 'mix_{}'.format(instrument)

        # Draw one fixed X_vis so every replicate is evaluated on the same
        # points for the visualisation curves.
        _, _, _, X_vis = gen_data(f, n, iv_type)

        if fn == 'linear':
            alphas = [0, 0.1, 0.2, 0.3, 0.4]
        else:
            alphas = np.linspace(0, 1, 5)

        # Per-alpha result frames; concatenated once after the alpha loop
        # instead of growing a DataFrame with the removed DataFrame.append.
        compare_frames = []
        vis_frames = []

        for alpha in alphas:

            def rep_function(i):
                """Run replicate `i` for the current fn / instrument / alpha.

                Draws fresh training, oracle and test data, fits the ridge
                baseline, the oracle baseline, HSIC (warm-started from the
                ridge coefficients and from the true weights `w`) and 2SLS,
                and returns ``(inner_df, inner_df_vis)``: predictions on the
                test draw and on the fixed `X_vis` points respectively.

                NOTE(review): the parent's np.random.seed presumably does not
                carry over to joblib worker processes -- confirm whether
                replicates need their own seeding for reproducibility.
                """
                X, Y, Z, _ = gen_data(f, n, iv_type, alpha=alpha)
                # Data generated with oracle=True (semantics defined by gen_data).
                X_o, Y_o, _, _ = gen_data(f, n, iv_type, alpha=alpha, oracle=True)
                X_test, _, _, _ = gen_data(f, X_vis.shape[0], iv_type, alpha=alpha)

                # get y_hat for MSE loss
                if fn == 'linear':
                    mse_reg = PredPolyRidge(degree=1, bias=False)
                    oracle_reg = PredPolyRidge(degree=1, bias=False)
                else:
                    mse_reg = PredRadialRidge(num_basis=num_basis, data_limits=data_limits, bias=False)
                    oracle_reg = PredRadialRidge(num_basis=num_basis, data_limits=data_limits, bias=False)

                mse_reg.fit(X, Y)
                y_hat_mse = mse_reg.predict(X_test)
                y_hat_mse_vis = mse_reg.predict(X_vis)
                mse_coef = mse_reg.reg.coef_
                oracle_reg.fit(X_o, Y_o)
                y_hat_oracle = oracle_reg.predict(X_test)

                # Kernels for the HSIC objective: fixed-width RBF on the
                # residual, and a kernel on Z chosen by instrument type.
                s_z = med_sigma(Z)
                kernel_e = RBFKernel(sigma=1)

                if instrument == 'Binary':
                    kernel_z = CategoryKernel()
                else:
                    kernel_z = RBFKernel(sigma=s_z)

                if fn == 'linear':
                    hsic_net = LinearHSICX(input_dim=1,
                                           lr=config['lr'],
                                           lmd=0,
                                           kernel_e=kernel_e,
                                           kernel_z=kernel_z,
                                           bias=False)
                else:
                    hsic_net = RadialHSICX(input_dim=1,
                                           num_basis=num_basis,
                                           data_limits=data_limits,
                                           lr=config['lr'],
                                           lmd=0,
                                           kernel_e=kernel_e,
                                           kernel_z=kernel_z,
                                           bias=False)

                # HSIC-IV: warm-start from the ridge baseline's coefficients.
                hsic_net.load_state_dict(mse_coef)
                hsic_net = train_HSIC_IV(hsic_net, config, X, Y, Z)

                # The nets are built with bias=False, so shift predictions to
                # match the training mean of Y.
                intercept_adjust = Y.mean() - hsic_net(to_torch(X)).mean()
                y_hat_hsic = intercept_adjust + hsic_net(to_torch(X_test))
                y_hat_hsic_vis = intercept_adjust + hsic_net(to_torch(X_vis))

                # hsic_oracle: same training, warm-started from the true weights.
                hsic_net.load_state_dict(w.flatten())

                hsic_net = train_HSIC_IV(hsic_net, config, X, Y, Z)

                intercept_adjust = Y.mean() - hsic_net(to_torch(X)).mean()
                y_hat_hsic_oracle = intercept_adjust + hsic_net(to_torch(X_test))

                # 2SLS
                if fn == 'linear':
                    poly2SLS = Poly2SLS(degree=1, bias=False)
                else:
                    poly2SLS = Radial2SLSRidge(num_basis=num_basis, data_limits=data_limits, bias=False)

                poly2SLS.fit(X, Y, Z)
                y_hat_2sls = poly2SLS.predict(X_test)
                y_hat_2sls_vis = poly2SLS.predict(X_vis)

                inner_df = pd.DataFrame()
                inner_df_vis = pd.DataFrame()

                inner_df['f_x'] = f(X_test)
                inner_df['Pred'] = y_hat_mse
                inner_df['HSIC-IV'] = y_hat_hsic.detach().numpy()
                inner_df['HSIC-Oracle'] = y_hat_hsic_oracle.detach().numpy()
                inner_df['2SLS'] = y_hat_2sls
                inner_df['Oracle'] = y_hat_oracle
                inner_df['alpha'] = alpha
                inner_df['run_id'] = i

                inner_df_vis['x_vis'] = X_vis
                inner_df_vis['f_x'] = f(X_vis)
                inner_df_vis['Pred'] = y_hat_mse_vis
                inner_df_vis['HSIC-IV'] = y_hat_hsic_vis.detach().numpy()
                inner_df_vis['2SLS'] = y_hat_2sls_vis
                inner_df_vis['alpha'] = alpha
                inner_df_vis['run_id'] = i

                return inner_df, inner_df_vis


            # n_jobs=-2: joblib convention for "all CPUs but one".
            ret_df, ret_df_vis = zip(*Parallel(n_jobs=-2)(
                delayed(rep_function)(i=i) for i in range(n_rep)))

            # Pool the replicates for this alpha (pd.concat replaces the
            # DataFrame.append chain, which no longer exists in pandas 2.x).
            ret_df = pd.concat(ret_df, ignore_index=True)
            ret_df_vis = pd.concat(ret_df_vis, ignore_index=True)

            compare_frames.append(ret_df)
            vis_frames.append(ret_df_vis)

        res_df = pd.concat(compare_frames, ignore_index=True)
        res_df_vis = pd.concat(vis_frames, ignore_index=True)

        if compare_mode:
            # Long format: one row per (test point, method); squared error
            # against f_x averaged per (Method, alpha, run_id).
            melt_res_df = res_df.melt(id_vars=['f_x', 'alpha', 'run_id'], var_name='Method',
                                      value_name='y_pred')
            melt_res_df['MISE'] = (melt_res_df['f_x'] - melt_res_df['y_pred']) ** 2
            final_df = melt_res_df.groupby(['Method', 'alpha', 'run_id'])['MISE'].mean().reset_index()
            final_df['alpha'] = np.round(final_df.alpha, 2)
            final_df.to_csv("../results/compare_df_fn_{}_ins_{}.csv".format(fn, instrument),
                            index=False)

        if vis_mode:
            # Written next to the comparison CSVs (same "../results" directory).
            res_df_vis.to_csv("../results/vis_df_fn_{}_ins_{}.csv".format(fn, instrument),
                              index=False)
            for alpha in alphas:
                res_alpha = res_df_vis[res_df_vis.alpha == alpha]
                melt_res_df = res_alpha.melt(id_vars=['x_vis', 'alpha', 'run_id'], var_name='Method',
                                             value_name='y')
                # Fresh sample used only as the grey background scatter.
                X, Y, _, _ = gen_data(f, n, iv_type, alpha=alpha)
                # x/y must be passed by keyword in current seaborn releases.
                sns.scatterplot(x=X, y=Y, color='.5', linewidth=0, alpha=0.5)
                sns.lineplot(data=melt_res_df.query("Method!='f_x'"), x="x_vis", y="y", units='run_id',
                             hue='Method', estimator=None, lw=1, alpha=0.5,
                             palette=["blue", "red", "green"])
                # f_x is identical across replicates, so draw it once from run 0.
                df_0 = res_alpha.query("run_id == 0.0")
                plt.plot(df_0['x_vis'].values, df_0['f_x'].values, '--',
                         label='f_x', c='black', lw=1.8)
                plt.legend()
                plt.ylabel('Y')
                plt.xlabel('X')
                plt.ylim(Y.min() * 1.2, Y.max() * 1.2)
                xlimL, xlimR = np.quantile(X, 0.01), np.quantile(X, 0.99)
                plt.xlim(xlimL, xlimR)
                plt.savefig(
                    '../results/vis_plot_radial_fn_{}_ins_{}_alpha_{}.pdf'.format(fn, instrument, str(alpha)),
                    bbox_inches="tight")
                plt.close()

In [None]:
# Load the per-setting comparison CSVs written by the experiment cell and
# tag each with its instrument type and function family (used as facets).
frames = []
for fn in ['linear', 'radial']:
    for instrument in ['Gaussian', 'Binary']:
        final_df = pd.read_csv("../results/compare_df_fn_{}_ins_{}.csv".format(fn, instrument))
        final_df[r'$Z$'] = instrument
        final_df[r'$f$'] = fn

        frames.append(final_df)

# One concat instead of DataFrame.append in a loop (removed in pandas 2.0).
all_df = pd.concat(frames, ignore_index=True)

# Rename methods / function families to the labels used in the figure.
all_df = all_df.replace({"Pred": "OLS", "HSIC-IV": "HSIC-X",
                         "linear": "Linear", "radial": "Non-linear"})

# Pass the palette name directly; the original passed the (None) return value
# of sns.set_palette and only worked through that call's side effect.
sns.set(font_scale=1.7, style='white', palette="tab10")

# `log=True` is not a point-plot option; the log scale is applied per axis below.
g = sns.catplot(data=all_df, kind="point",
                x='alpha', y='MISE', hue='Method', col=r'$Z$', row=r'$f$',
                hue_order=['2SLS', 'OLS', 'Oracle', 'HSIC-X', 'HSIC-Oracle'],
                markers=["o", "x", "d", "s", "v"], linestyles=[':', '-', '--', '-.', ':'],
                capsize=.07, aspect=1.5, height=3.2, ci=95,
                sharex=False, sharey=False)

# Drop the grid-level legend; a single shared legend is placed manually below.
g._legend.remove()
g.set_xlabels(r'$\alpha$')
g.set_ylabels('MSE')
for ax in g.axes.flat:
    ax.set_yscale('log')

# plt.legend targets the current (last-drawn) axes; the anchor offsets place
# the legend above the whole grid.
plt.legend(loc='upper center', bbox_to_anchor=(-.3, 3.1),
           ncol=5, fancybox=True, shadow=True, prop={'size': 17})
plt.savefig('../results/compare_alpha.pdf',
            bbox_inches="tight")
plt.close()