In [None]:
# Reload imported project modules automatically before each cell runs, so edits
# to the local `utils` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# IHDP: RieszNet

## Library Imports

In [None]:
from pathlib import Path
import os
import glob
from joblib import dump, load
import pandas as pd
import scipy
import scipy.stats
import scipy.special
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from utils.riesznet import RieszNet
from utils.moments import ate_moment_fn
from utils.ihdp_data import *

## Moment Definition

In [None]:
# Moment functional used for every RieszNet built in this notebook;
# `ate_moment_fn` is imported from utils.moments.
moment_fn = ate_moment_fn

## MAE Experiment

In [None]:
# Folder holding the IHDP simulation CSVs used by the MAE experiment.
data_base_dir = "./data/IHDP/sim_data"
# Collect every CSV in that folder and order the paths lexicographically so
# that a positional index always refers to the same file.
simulation_files = glob.glob(data_base_dir + "/*.csv")
simulation_files.sort()

### Estimator Settings

In [None]:
# --- Network architecture (used by every Learner built in this notebook) ---
n_hidden = 100   # width of the hidden layers
drop_prob = 0.0  # probability used by the Dropout layers

# --- Optimisation settings forwarded to RieszNet.fit ---
n_epochs = 600
bs = 64
learner_lr = 1e-5
learner_l1 = 0.0
learner_l2 = 1e-3
earlystop_delta = 1e-4
earlystop_rounds = 40  # how many epochs to wait for an out-of-sample improvement

# --- Loss weights forwarded to RieszNet.fit ---
riesz_weight = 0.1
target_reg = 1.0

# Use the current CUDA device when one is available; otherwise leave it unset.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = None
print("GPU:", torch.cuda.is_available())

from itertools import chain, combinations
from itertools import combinations_with_replacement as combinations_w_r

def _combinations(n_features, degree, interaction_only):
        comb = (combinations if interaction_only else combinations_w_r)
        return chain.from_iterable(comb(range(n_features), i)
                                   for i in range(0, degree + 1))

class Learner(nn.Module):
    """Shared-trunk network that outputs a regression and a Riesz prediction.

    A common three-layer ELU trunk feeds a linear Riesz head and two outcome
    heads. Column 0 of the input selects between the outcome heads
    (``reg_nn0`` is weighted by ``1 - x[:, 0]``, ``reg_nn1`` by ``x[:, 0]``).
    Both outputs additionally receive a linear term in polynomial features of
    the raw input, built from the index tuples returned by ``_combinations``.

    Parameters
    ----------
    n_t : int
        Number of input columns.
    n_hidden : int
        Width of the hidden layers in the two outcome heads.
    p : float
        Dropout probability used by every ``nn.Dropout`` layer.
    degree : int
        Highest monomial order of the polynomial features.
    interaction_only : bool, default False
        Passed through to ``_combinations``.
    """

    def __init__(self, n_t, n_hidden, p, degree, interaction_only=False):
        super().__init__()
        n_common = 200
        self.monomials = list(_combinations(n_t, degree, interaction_only))
        n_poly = len(self.monomials)

        # Sub-modules are created in a fixed order (trunk, Riesz head, Riesz
        # polynomial, outcome head 0, outcome head 1, outcome polynomial) so
        # that parameter ordering and initialisation draws stay stable.
        self.common = nn.Sequential(
            nn.Dropout(p=p), nn.Linear(n_t, n_common), nn.ELU(),
            nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU(),
            nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU(),
        )
        self.riesz_nn = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, 1))
        self.riesz_poly = nn.Sequential(nn.Linear(n_poly, 1))
        self.reg_nn0 = self._outcome_head(n_common, n_hidden, p)
        self.reg_nn1 = self._outcome_head(n_common, n_hidden, p)
        self.reg_poly = nn.Sequential(nn.Linear(n_poly, 1))

    @staticmethod
    def _outcome_head(n_in, n_hidden, p):
        """Build one outcome head: two hidden ELU layers and a scalar output."""
        return nn.Sequential(
            nn.Dropout(p=p), nn.Linear(n_in, n_hidden), nn.ELU(),
            nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
            nn.Dropout(p=p), nn.Linear(n_hidden, 1),
        )

    def forward(self, x):
        """Return a two-column tensor: regression in column 0, Riesz in column 1."""
        first_col = x[:, [0]]  # kept 2-D so it broadcasts against the head outputs

        # One column per monomial: the product of the selected input columns.
        # NOTE(review): an empty index tuple presumably gives a column of ones
        # (empty product), i.e. an intercept -- confirm against torch.prod.
        monomial_cols = [torch.prod(x[:, idx], dim=1, keepdim=True)
                         for idx in self.monomials]
        poly = torch.cat(monomial_cols, dim=1)

        shared = self.common(x)
        riesz = self.riesz_nn(shared) + self.riesz_poly(poly)

        head0 = self.reg_nn0(shared)
        head1 = self.reg_nn1(shared)
        reg = head0 * (1 - first_col) + head1 * first_col + self.reg_poly(poly)

        return torch.cat([reg, riesz], dim=1)

In [None]:
# MAE experiment: fit RieszNet on randomly drawn IHDP simulations and compare
# each estimator's ATE point estimate against the simulation's true ATE.
nsims = 1000
# Seed numpy (simulation draw; train_test_split is called without a
# random_state) and torch (network weight initialisation) so that re-running
# this cell gives repeatable results.
np.random.seed(123)
torch.manual_seed(123)
sim_ids = np.random.choice(len(simulation_files), nsims, replace=False)
methods = ['dr', 'direct', 'ips']
srr = {'dr' : True, 'direct' : False, 'ips' : True}

true_ATEs = []
results = []

for sim in sim_ids:
    simulation_file = simulation_files[sim]
    x = load_and_format_covariates(simulation_file, delimiter=' ')
    t, y, y_cf, mu_0, mu_1 = load_other_stuff(simulation_file, delimiter=' ')
    # Treatment goes in column 0, which is the column Learner.forward reads
    # to switch between its two outcome heads.
    X = np.c_[t, x]
    true_ATE = np.mean(mu_1 - mu_0)
    true_ATEs.append(true_ATE)

    # Standardise the outcome for training; estimates are rescaled below.
    y_scaler = StandardScaler(with_mean=True).fit(y)
    y = y_scaler.transform(y)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)

    torch.cuda.empty_cache()
    learner = Learner(X_train.shape[1], n_hidden, drop_prob, 0, interaction_only=True)
    agmm = RieszNet(learner, moment_fn)
    # Fast training: larger learning rate, short patience, few epochs.
    agmm.fit(X_train, y_train, Xval=X_test, yval=y_test,
             earlystop_rounds=2, earlystop_delta=earlystop_delta,
             learner_lr=1e-4, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=100, bs=bs, target_reg=target_reg,
             riesz_weight=riesz_weight, optimizer='adam',
             model_dir=str(Path.home()), device=device, verbose=0)
    # Fine tune with the notebook-level settings (warm_start=True presumably
    # continues from the weights of the fast stage -- confirm in RieszNet).
    agmm.fit(X_train, y_train, Xval=X_test, yval=y_test,
             earlystop_rounds=earlystop_rounds, earlystop_delta=earlystop_delta,
             learner_lr=learner_lr, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=n_epochs, bs=bs, target_reg=target_reg,
             riesz_weight=riesz_weight, optimizer='adam', warm_start=True,
             model_dir=str(Path.home()), device=device, verbose=0)

    # Three values per method (unpacked later as point, lb, ub), multiplied by
    # the outcome scale to undo the standardisation, followed by the true ATE.
    params = tuple(value * y_scaler.scale_[0] for method in methods
                   for value in agmm.predict_avg_moment(X, y, model='earlystop',
                                                        method=method, srr=srr[method])) + (true_ATE, )

    results.append(params)

# Transpose: one array of length nsims per reported quantity.
res = tuple(np.array(column) for column in zip(*results))
# Last column holds the true ATEs; take the array itself rather than a 1-tuple
# so the error arrays below keep shape (nsims,).
truth = res[-1]
res_dict = {}
for it, method in enumerate(methods):
    point, lb, ub = res[it * 3: (it + 1)*3]
    abs_err = np.abs(point - truth)
    res_dict[method] = {'point': point, 'lb': lb, 'ub': ub,
                        'MAE': np.mean(abs_err),
                        'std. err.': np.std(abs_err) / np.sqrt(nsims),
                        }
    print("{} : MAE = {:.3f} +/- {:.3f}".format(method, res_dict[method]['MAE'], res_dict[method]['std. err.']))

In [None]:
# Persist the MAE results dictionary next to the other RieszNet outputs.
path = './results/IHDP/RieszNet/MAE'

# exist_ok=True creates the folder if needed without a separate existence
# check (which could race with another process creating it).
os.makedirs(path, exist_ok=True)

dump(res_dict, path + '/IHDP_MAE_NN.joblib')

### Table

In [None]:
# Write the MAE summary as a two-column LaTeX table (method, "MAE +/- s.e.").
path = './results/IHDP/RieszNet/MAE'

os.makedirs(path, exist_ok=True)

# Display labels, in the same order as `methods`.
methods_str = ["DR", "Direct", "IPS"]

with open(path + '/IHDP_MAE_NN.tex', "w") as f:
    f.write("\\begin{tabular}{lc} \n" +
            "\\toprule \n" +
            "& MAE $\\pm$ std. err. \\\\ \n" +
            "\\midrule \n" +
            "\\multicolumn{2}{l}{\\textbf{Auto-DML:}} \\\\ \n")

    for label, method in zip(methods_str, methods):
        f.write("{} & {:.3f} $\\pm$ {:.3f} \\\\ \n".format(label,
                                                          res_dict[method]['MAE'],
                                                          res_dict[method]['std. err.']))

    # The tabular has two columns ({lc}), so the benchmark row must also be
    # "name & value $\pm$ s.e." -- a second `&` would be an extra alignment tab.
    f.write("\\multicolumn{2}{l}{\\textbf{Benchmark:}} \\\\"
            + "\n Dragonnet & 0.146 $\\pm$ 0.010 \\\\ \n \\bottomrule \n \\end{tabular}")

## Coverage Experiment

In [None]:
# Folder holding the IHDP simulation CSVs used by the coverage experiment.
data_base_dir = "./data/IHDP/sim_data_redraw_T"
# Gather the CSV paths and sort them so positional indices are stable.
simulation_files = glob.glob("%s/*.csv" % data_base_dir)
simulation_files.sort()

In [None]:
def rmse_fn(y_pred, y_true):
    """Root-mean-squared error between predictions and targets.

    Parameters
    ----------
    y_pred, y_true : array-like
        Any inputs NumPy can convert and broadcast against each other
        (arrays, lists, tuples, scalars).

    Returns
    -------
    numpy scalar
        ``sqrt(mean((y_pred - y_true) ** 2))`` taken over all elements.
    """
    # Convert first so plain Python sequences work too (list - list is a TypeError).
    residual = np.asarray(y_pred) - np.asarray(y_true)
    return np.sqrt(np.mean(residual ** 2))

# Coverage experiment: fit RieszNet on randomly drawn simulations and record
# bias, RMSE and interval coverage of each estimator against the true ATE.
nsims = 100
# Seed numpy (simulation draw; train_test_split is called without a
# random_state) and torch (network weight initialisation) so that re-running
# this cell gives repeatable results.
np.random.seed(123)
torch.manual_seed(123)
sim_ids = np.random.choice(len(simulation_files), nsims, replace=False)
methods = ['dr', 'direct', 'ips']
srr = {'dr' : True, 'direct' : False, 'ips' : True}

true_ATEs = []
results = []

for sim in sim_ids:
    simulation_file = simulation_files[sim]
    x = load_and_format_covariates(simulation_file, delimiter=' ')
    t, y, y_cf, mu_0, mu_1 = load_other_stuff(simulation_file, delimiter=' ')
    # Treatment goes in column 0, which is the column Learner.forward reads
    # to switch between its two outcome heads.
    X = np.c_[t, x]
    true_ATE = np.mean(mu_1 - mu_0)
    true_ATEs.append(true_ATE)

    # Standardise the outcome for training; estimates are rescaled below.
    y_scaler = StandardScaler(with_mean=True).fit(y)
    y = y_scaler.transform(y)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)

    torch.cuda.empty_cache()
    learner = Learner(X_train.shape[1], n_hidden, drop_prob, 0, interaction_only=True)
    agmm = RieszNet(learner, moment_fn)
    # Fast training: larger learning rate, short patience, few epochs.
    agmm.fit(X_train, y_train, Xval=X_test, yval=y_test,
             earlystop_rounds=2, earlystop_delta=earlystop_delta,
             learner_lr=1e-4, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=100, bs=bs, target_reg=target_reg,
             riesz_weight=riesz_weight, optimizer='adam',
             model_dir=str(Path.home()), device=device, verbose=0)
    # Fine tune with the notebook-level settings (warm_start=True presumably
    # continues from the weights of the fast stage -- confirm in RieszNet).
    agmm.fit(X_train, y_train, Xval=X_test, yval=y_test,
             earlystop_rounds=earlystop_rounds, earlystop_delta=earlystop_delta,
             learner_lr=learner_lr, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=n_epochs, bs=bs, target_reg=target_reg,
             riesz_weight=riesz_weight, optimizer='adam', warm_start=True,
             model_dir=str(Path.home()), device=device, verbose=0)

    # Three values per method (unpacked later as point, lb, ub), multiplied by
    # the outcome scale to undo the standardisation, followed by the true ATE.
    params = tuple(value * y_scaler.scale_[0] for method in methods
                   for value in agmm.predict_avg_moment(X, y, model='earlystop',
                                                        method=method, srr=srr[method])) + (true_ATE, )

    results.append(params)

# Transpose: one array of length nsims per reported quantity.
res = tuple(np.array(column) for column in zip(*results))
# Last column holds the true ATEs; take the array itself rather than a 1-tuple
# so comparisons and differences below keep shape (nsims,).
truth = res[-1]
res_dict = {}
for it, method in enumerate(methods):
    point, lb, ub = res[it * 3: (it + 1)*3]
    res_dict[method] = {'point': point, 'lb': lb, 'ub': ub,
                        # Share of simulations whose interval contains the truth.
                        'cov': np.mean(np.logical_and(truth >= lb, truth <= ub)),
                        'bias': np.mean(point - truth),
                        'rmse': rmse_fn(point, truth)
                        }
    print("{} : bias = {:.3f}, rmse = {:.3f}, cov = {:.3f}".format(method, res_dict[method]['bias'], res_dict[method]['rmse'], res_dict[method]['cov']))

In [None]:
# Persist the coverage results dictionary next to the other RieszNet outputs.
path = './results/IHDP/RieszNet/coverage'

# exist_ok=True creates the folder if needed without a separate existence
# check (which could race with another process creating it).
os.makedirs(path, exist_ok=True)

dump(res_dict, path + '/IHDP_coverage_NN.joblib')

### Histogram

In [None]:
# Overlay the per-method histograms of ATE point estimates and mark the
# average true ATE; the figure is saved as a PDF and shown inline.
path = './results/IHDP/RieszNet/coverage'

os.makedirs(path, exist_ok=True)

# One summary line per method, used as a multi-line title.
method_strs = ["{}. Bias: {:.3f}, RMSE: {:.3f}, Coverage: {:.3f}".format(method, d['bias'], d['rmse'], d['cov'])
               for method, d in res_dict.items()]

# Explicit Figure/Axes instead of the implicit pyplot "current figure".
fig, ax = plt.subplots()
ax.set_title("\n".join(method_strs))
for method, d in res_dict.items():
    ax.hist(np.array(d['point']), alpha=.5, label=method)
# Vertical reference line at the mean of the true ATEs across simulations.
ax.axvline(x=np.mean(truth), label='true', color='red')
ax.set_xlabel("ATE point estimate")
ax.set_ylabel("Number of simulations")
ax.legend()
fig.savefig(path + '/IHDP_coverage_NN.pdf', bbox_inches='tight')
plt.show()