In [1]:
### imports
import warnings
warnings.simplefilter('ignore')  # set before the heavy imports so their warnings are silenced too
import itertools
from collections import OrderedDict

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import torch
import torch.nn as nn
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

from mliv.dgps import get_data, get_tau_fn, fn_dict
from mliv.neuralnet import AGMMEarlyStop as AGMM
from mliv.neuralnet.moments import avg_small_diff
from mliv.neuralnet.utilities import mean_ci


def plot_results(fname, n, iv_strength, dr, tmle, ipw, direct, true):
    """Overlay histograms of four estimators and summarise them in the title.

    Each of ``dr``, ``tmle``, ``ipw`` and ``direct`` is indexed as a 2-D array:
    column 0 is used as the point estimate and columns 1 and 2 as the lower
    and upper interval bounds when computing coverage of ``true``.

    The title reports, per estimator, the coverage rate, the RMSE and the
    mean error (bias) of column 0 relative to ``true``. Everything is drawn
    on the current pyplot figure; nothing is returned.
    """
    estimators = [('dr', dr), ('tmle', tmle), ('ipw', ipw), ('direct', direct)]

    def _summary(label, est):
        # One title line: coverage, rmse and bias for a single estimator.
        err = est[:, 0] - true
        covered = (est[:, 1] <= true) & (true <= est[:, 2])
        return (f'{label}: Cov={np.mean(covered):.3f}, '
                f'rmse={np.sqrt(np.mean(err**2)):.3f}, '
                f'bias={np.mean(err):.3f}\n')

    header = f'fname={fname}, n={n}, strength={iv_strength}, true={true:.3f}\n'
    plt.title(header + ''.join(_summary(label, est) for label, est in estimators))

    for label, est in estimators:
        # The first histogram ('dr') is drawn opaque; the rest are translucent overlays.
        extra = {} if label == 'dr' else {'alpha': .4}
        plt.hist(est[:, 0], label=label, **extra)
    plt.legend()

# Experiment configuration.
n_z, n_t = 1, 1
dgp_num = 5
epsilon = 0.1  # average finite difference epsilon


def moment_fn(x, fn, device):
    """Call `avg_small_diff` with the module-level `epsilon` as its last argument."""
    return avg_small_diff(x, fn, device, epsilon)

In [None]:
# Build one summary table per `clever` setting from the saved experiment results.
# For every (function, sample size, instrument strength) the table reports, per
# method: coverage in percent ('cov'), RMSE, absolute mean error ('bias') and the
# standard deviation of the point estimates ('std').
for clever in [False, True]:
    resd = {}
    for n_t in [1]:
        resd[n_t] = {}
        for fname in ['abs', '2dpoly', 'sigmoid', 'sin']:
            resd[n_t][fname] = OrderedDict()
            for n in [500, 1000, 2000]:
                # Only used here to reconstruct the name of the results file.
                lambda_l2_h = .1/n**(.9)
                nkey = f'$n={n}$'
                resd[n_t][fname][nkey] = {}
                for iv_strength in [0.2, 0.5]:
                    # 'stregth' is spelled this way on purpose: it presumably has to match
                    # the file names written by the experiment script — verify before renaming.
                    true, results = joblib.load(f'res_fn_{fname}_n_{n}_n_t_{n_t}_stregth_{iv_strength}_eps_{0.1}_clever_{clever}_l2h_{lambda_l2_h:.4f}.jbl')
                    # The backslash must be escaped: in a non-raw string '\r' is a
                    # carriage return, which would corrupt the LaTeX label.
                    ivkey = f'$\\rho={iv_strength}$'
                    resd[n_t][fname][nkey][ivkey] = {}
                    for it, method in enumerate(['dr', 'tmle', 'ipw', 'direct']):
                        # One row per experiment; column 0 is used as the point estimate,
                        # columns 1 and 2 as the interval bounds.
                        data = np.array([r[it] for r in results])
                        if method in ['dr', 'tmle']:
                            cov = f'{100*np.mean((data[:, 1] <= true) & (true <= data[:, 2])):.0f}'
                        else:
                            # Coverage is only reported for 'dr' and 'tmle'.
                            cov = 'NA'
                        resd[n_t][fname][nkey][ivkey][method] = {
                                        'cov': cov,
                                        'rmse': f'{np.sqrt(np.mean((data[:, 0]-true)**2)):.3f}',
                                        'bias': f'{np.abs(np.mean((data[:, 0]-true))):.3f}',
                                        'std': f'{np.std(data[:, 0]):.3f}'}
                    # Wrap in an extra index level that carries the true value.
                    resd[n_t][fname][nkey][ivkey] = pd.concat({f'${true:.2f}$': pd.DataFrame(resd[n_t][fname][nkey][ivkey])})
                resd[n_t][fname][nkey] = pd.concat(resd[n_t][fname][nkey], sort=False)
            resd[n_t][fname] = pd.concat(resd[n_t][fname], sort=False)
        resd[n_t] = pd.concat(resd[n_t], sort=False)
    # Index level 5 holds the statistic names; move them into the columns.
    table = pd.concat(resd).unstack(level=5)
    display(table)
    # NOTE(review): column_format declares 3 + 12 columns, while the unstacked table
    # appears to have 5 index levels and 4 methods x 4 statistics — confirm that the
    # emitted LaTeX compiles as-is.
    print(table.to_latex(bold_rows=True, multirow=True,
                         multicolumn=True, escape=False,
                         column_format='lll||lll|lll|lll|lll|',
                         multicolumn_format='c|'))

In [None]:
# Same summary table as above, restricted to the '2dpoly' function with larger
# sample sizes and weaker instruments (strength 0.05 and 0.1).
for clever in [False]:
    resd = {}
    for n_t in [1]:
        resd[n_t] = {}
        for fname in ['2dpoly']:
            resd[n_t][fname] = OrderedDict()
            for n in [2000, 20000]:
                # Only used here to reconstruct the name of the results file.
                lambda_l2_h = .1/n**(.9)
                nkey = f'$n={n}$'
                resd[n_t][fname][nkey] = {}
                for iv_strength in [0.05, 0.1]:
                    # 'stregth' is spelled this way on purpose: it presumably has to match
                    # the file names written by the experiment script — verify before renaming.
                    true, results = joblib.load(f'res_fn_{fname}_n_{n}_n_t_{n_t}_stregth_{iv_strength}_eps_{0.1}_clever_{clever}_l2h_{lambda_l2_h:.4f}.jbl')
                    # The backslash must be escaped: in a non-raw string '\r' is a
                    # carriage return, which would corrupt the LaTeX label.
                    ivkey = f'$\\rho={iv_strength}$'
                    resd[n_t][fname][nkey][ivkey] = {}
                    for it, method in enumerate(['dr', 'tmle', 'ipw', 'direct']):
                        # One row per experiment; column 0 is used as the point estimate,
                        # columns 1 and 2 as the interval bounds.
                        data = np.array([r[it] for r in results])
                        if method in ['dr', 'tmle']:
                            cov = f'{100*np.mean((data[:, 1] <= true) & (true <= data[:, 2])):.0f}'
                        else:
                            # Coverage is only reported for 'dr' and 'tmle'.
                            cov = 'NA'
                        resd[n_t][fname][nkey][ivkey][method] = {
                                        'cov': cov,
                                        'rmse': f'{np.sqrt(np.mean((data[:, 0]-true)**2)):.3f}',
                                        'bias': f'{np.abs(np.mean((data[:, 0]-true))):.3f}',
                                        'std': f'{np.std(data[:, 0]):.3f}'}
                    # Wrap in an extra index level that carries the true value.
                    resd[n_t][fname][nkey][ivkey] = pd.concat({f'${true:.2f}$': pd.DataFrame(resd[n_t][fname][nkey][ivkey])})
                resd[n_t][fname][nkey] = pd.concat(resd[n_t][fname][nkey], sort=False)
            resd[n_t][fname] = pd.concat(resd[n_t][fname], sort=False)
        resd[n_t] = pd.concat(resd[n_t], sort=False)
    # Index level 5 holds the statistic names; move them into the columns.
    table = pd.concat(resd).unstack(level=5)
    display(table)
    # NOTE(review): column_format declares 3 + 12 columns, while the unstacked table
    # appears to have 5 index levels and 4 methods x 4 statistics — confirm that the
    # emitted LaTeX compiles as-is.
    print(table.to_latex(bold_rows=True, multirow=True,
                         multicolumn=True, escape=False,
                         column_format='lll||lll|lll|lll|lll|',
                         multicolumn_format='c|'))

In [None]:
# Summary table for the 'cct' results, for several values of n_t.
# The stored intervals are treated as symmetric 95% intervals and are re-scaled
# to 99% intervals before coverage is computed.
# Requires `scipy.stats` to be imported (see the import cell at the top).
fname = 'cct'
for clever in [False]:
    resd = {}
    for n_t in [0, 5, 10]:
        resd[n_t] = {}
        for n in [1000, 5000]:
            # Only used here to reconstruct the name of the results file.
            lambda_l2_h = .1/n**(.9)
            nkey = f'$n={n}$'
            resd[n_t][nkey] = {}
            # Student-t quantiles with n - 1 degrees of freedom; they depend only on n,
            # so compute them once per sample size.
            t_quantile_95 = scipy.stats.t.ppf((1 + .95) / 2., n - 1)
            t_quantile_99 = scipy.stats.t.ppf((1 + .99) / 2., n - 1)
            for iv_strength in [0.0, 0.5]:
                # The (n_t == 0, strength 0.5) combination is skipped.
                if n_t == 0 and iv_strength == 0.5:
                    continue

                # 'stregth' is spelled this way on purpose: it presumably has to match
                # the file names written by the experiment script — verify before renaming.
                true, results = joblib.load(f'res_fn_{fname}_n_{n}_n_t_{n_t}_stregth_{iv_strength}_eps_{0.1}_clever_{clever}_l2h_{lambda_l2_h:.4f}.jbl')
                ivkey = f'$\\rho={iv_strength}$'
                resd[n_t][nkey][ivkey] = {}
                for it, method in enumerate(['dr', 'tmle', 'direct']):
                    # One row per experiment; column 0 is used as the point estimate,
                    # columns 1 and 2 as the interval bounds.
                    data = np.array([r[it] for r in results])
                    # Back out the standard error from the upper half-width of the stored
                    # interval (must happen before column 2 is overwritten below) ...
                    se = (data[:, 2] - data[:, 0]) / t_quantile_95
                    # ... and rebuild the bounds at the 99% level.
                    data[:, 1] = data[:, 0] - se * t_quantile_99
                    data[:, 2] = data[:, 0] + se * t_quantile_99
                    if method in ['dr', 'tmle']:
                        cov = f'{100*np.mean((data[:, 1] <= true) & (true <= data[:, 2])):.0f}'
                    else:
                        # Coverage is only reported for 'dr' and 'tmle'.
                        cov = 'NA'
                    resd[n_t][nkey][ivkey][method] = {
                                    'cov': cov,
                                    'rmse': f'{np.sqrt(np.mean((data[:, 0]-true)**2)):.3f}',
                                    'bias': f'{np.abs(np.mean((data[:, 0]-true))):.3f}',
                                    'std': f'{np.std(data[:, 0]):.3f}'}
                resd[n_t][nkey][ivkey] = pd.DataFrame(resd[n_t][nkey][ivkey])
            resd[n_t][nkey] = pd.concat(resd[n_t][nkey], sort=False)
        resd[n_t] = pd.concat(resd[n_t], sort=False)
    # Index level 3 holds the statistic names; move them into the columns.
    table = pd.concat(resd).unstack(level=3)
    display(table)
    print(table.to_latex(bold_rows=True, multirow=True,
                         multicolumn=True, escape=False,
                         column_format='lll||llll|llll|llll|',
                         multicolumn_format='c|'))