In [1]:
### Magic functions
# Reload edited modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Load the TensorBoard notebook extension.
%load_ext tensorboard
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
### imports
import warnings
warnings.simplefilter('ignore')
import itertools

import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from joblib import Parallel, delayed
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

from mliv.dgps import get_data, get_tau_fn, fn_dict
from mliv.neuralnet import AGMMEarlyStop as AGMM
from mliv.neuralnet.moments import avg_small_diff
from mliv.neuralnet.utilities import mean_ci

## $\ell_2$-Regularized AGMM with Neural Net Test Function

We solve the problem:
\begin{equation}
\min_{\theta} \max_{w} \frac{1}{n} \sum_i (y_i - h_{\theta}(x_i)) f_w(z_i) - f_w(z_i)^2
\end{equation}
where $h_{\theta}$ and $f_w$ are two neural nets.

In [3]:
def _debiased_estimates(moment_fn, Z_val, T_val, Y_val, h_predict, xi_predict, qfun):
    """Compute the four average-moment estimates for one pair of fitted models.

    Parameters
    ----------
    moment_fn : callable
        Called as ``moment_fn(T, fn, device='cpu')``; its flattened output is
        the per-sample moment of the prediction function ``fn``.
    Z_val, T_val, Y_val : torch.Tensor
        Held-out instruments, treatments and outcomes. ``Y_val`` is converted
        with ``.numpy()``, so it has to live on the CPU.
    h_predict : callable
        ``h_predict(T)`` returns the predictions of the fitted IV regression.
    xi_predict : callable
        ``xi_predict(T)`` returns the predictions of the model fitted with
        ``riesz=True``.
    qfun : fitted regressor
        Object with ``predict(Z)``, trained to predict ``xi`` from ``Z``.

    Returns
    -------
    tuple
        ``(dr, tmle, ipw, reg)``, each the output of ``mean_ci`` applied to
        the corresponding per-sample values.
    """
    direct = moment_fn(T_val, h_predict, device='cpu').flatten()
    residual = (Y_val - h_predict(T_val)).detach().numpy().flatten()
    qvalues = qfun.predict(Z_val).flatten()

    # "dr": plug-in moment plus the q-weighted residual correction.
    dr = mean_ci(direct + qvalues * residual)
    # "ipw": q-weighted outcomes only.
    ipw = mean_ci(qvalues * Y_val.detach().numpy().flatten())
    # "reg": plug-in moment only.
    reg = mean_ci(direct)

    # "tmle": shift the regression by coef * xi, with coef chosen so that the
    # q-weighted mean of the updated residual (residual - coef * xi) is zero.
    xivalues = xi_predict(T_val).flatten()
    coef = np.mean(qvalues * residual) / np.mean(qvalues * xivalues)
    pseudo_tmle = direct + coef * moment_fn(T_val, xi_predict, device='cpu').flatten()
    pseudo_tmle += qvalues * (residual - coef * xivalues)
    tmle = mean_ci(pseudo_tmle)

    return dr, tmle, ipw, reg


def exp(it, n, n_z, n_t, iv_strength, fname, dgp_num, moment_fn):
    """Run one replication of the experiment and return all estimates.

    Draws a data set with ``get_data``, splits it in half, fits an AGMM IV
    regression and a second AGMM model with ``riesz=True`` on the first half,
    and evaluates the average-moment estimates on the second half, once with
    the default ``predict`` and once with ``model='avg'``.

    Parameters
    ----------
    it : int
        Replication index. Seeds the numpy and torch random generators and
        is used in the model directory names.
    n : int
        First argument passed to ``get_data``.
    n_z, n_t : int
        Input widths of the adversary and learner networks respectively.
    iv_strength, dgp_num
        Forwarded unchanged to ``get_data``.
    fname : str
        Key into ``fn_dict`` selecting the function passed to ``get_tau_fn``.
    moment_fn : callable
        Called as ``moment_fn(T, fn, device='cpu')``; also forwarded to the
        ``riesz=True`` fit.

    Returns
    -------
    tuple
        ``(dr, tmle, ipw, reg, dr_avg, tmle_avg, ipw_avg, reg_avg)``; the
        first four come from the default ``predict``, the last four from
        ``predict(..., model='avg', burn_in=burn_in)``.
    """
    # Seed torch as well as numpy: network initialisation, dropout masks and
    # mini-batch shuffling draw from torch's generator, not numpy's.
    np.random.seed(it)
    torch.manual_seed(it)

    #####
    # Neural network parameters
    ####
    p = 0.1  # dropout prob of dropout layers throughout notebook
    n_hidden = 100  # width of hidden layers throughout notebook

    def make_net(n_inputs):
        # One hidden layer with dropout before each linear layer.
        return nn.Sequential(nn.Dropout(p=p), nn.Linear(n_inputs, n_hidden), nn.LeakyReLU(),
                             nn.Dropout(p=p), nn.Linear(n_hidden, 1))

    burn_in = 100
    device = None
    fit_kwargs = dict(learner_lr=1e-4, adversary_lr=1e-4,
                      learner_l2=1e-3, adversary_l2=1e-3,
                      n_epochs=200, bs=100, logger=None, device=device)

    ######
    # Train test split
    ######
    Z, T, Y, true_fn = get_data(n, n_z, iv_strength, get_tau_fn(fn_dict[fname]), dgp_num)
    Z_train, Z_val, T_train, T_val, Y_train, Y_val = train_test_split(Z, T, Y, test_size=.5, shuffle=True)
    Z_train, T_train, Y_train = (torch.Tensor(a) for a in (Z_train, T_train, Y_train))
    Z_val, T_val, Y_val = (torch.Tensor(a).to(device) for a in (Z_val, T_val, Y_val))

    #####
    # Train IV function and "riesz" representer xi
    #####
    np.random.seed(12356)
    # Each estimator gets its own freshly built learner (input n_t) and
    # adversary (input n_z), so the two fits never share parameter tensors.
    agmm = AGMM(make_net(n_t), make_net(n_z)).fit(Z_train, T_train, Y_train, Z_val, T_val, Y_val,
                                                  model_dir=f'agmm_model_{it}', **fit_kwargs)
    riesz = AGMM(make_net(n_t), make_net(n_z)).fit(Z_train, T_train, Y_train, Z_val, T_val, Y_val,
                                                   model_dir=f'riesz_model_{it}',
                                                   riesz=True, moment_fn=moment_fn, **fit_kwargs)

    def agmm_avg(x):
        return agmm.predict(x, model='avg', burn_in=burn_in)

    def riesz_avg(x):
        return riesz.predict(x, model='avg', burn_in=burn_in)

    # Regress the xi predictions on the instruments, once per prediction mode.
    qfun = RandomForestRegressor(min_samples_leaf=20).fit(Z_train, riesz.predict(T_train))
    qfun_avg = RandomForestRegressor(min_samples_leaf=20).fit(Z_train, riesz_avg(T_train))

    #####
    # Average moment calculation
    #####
    default_estimates = _debiased_estimates(moment_fn, Z_val, T_val, Y_val,
                                            agmm.predict, riesz.predict, qfun)
    avg_estimates = _debiased_estimates(moment_fn, Z_val, T_val, Y_val,
                                        agmm_avg, riesz_avg, qfun_avg)

    return default_estimates + avg_estimates

In [4]:
# Settings for the single illustrative run.
n = 2000
n_z, n_t = 1, 1
iv_strength = .7
fname = 'sigmoid'
dgp_num = 5
epsilon = 0.1 # average finite difference epsilon


def moment_fn(x, fn, device):
    # avg_small_diff with the notebook-level epsilon filled in.
    return avg_small_diff(x, fn, device, epsilon)


# Evaluate the moment of the true function on a sample of 1,000,000 draws.
Z, T, Y, true_fn = get_data(1000000, n_z, iv_strength, get_tau_fn(fn_dict[fname]), dgp_num)
true = np.mean(moment_fn(T, true_fn, device='cpu'))
print(f'True: {true:.4f}')

True: 0.1904


In [None]:
# Single replication (it=1) with the settings defined above.
exp(it=1, n=n, n_z=n_z, n_t=n_t, iv_strength=iv_strength,
    fname=fname, dgp_num=dgp_num, moment_fn=moment_fn)

In [None]:
# Full experiment grid: for every (function, sample size, instrument strength)
# combination, run `n_replications` replications of `exp` in parallel and dump
# `(true, results)` to disk for the plotting cells below.
# joblib / Parallel / delayed are imported in the notebook's import cell.
n_z = 1
n_t = 1
dgp_num = 5
epsilon = 0.1 # average finite difference epsilon
moment_fn = lambda x, fn, device: avg_small_diff(x, fn, device, epsilon)
n_replications = 100

for fname, n, iv_strength in itertools.product(['abs', '2dpoly', 'sigmoid', 'sin', '3dpoly', 'abspos'],
                                               [1000, 2000],
                                               [.5, .7, .9]):
    # Moment of the true function, evaluated on a sample of 1,000,000 draws.
    Z, T, Y, true_fn = get_data(1000000, n_z, iv_strength, get_tau_fn(fn_dict[fname]), dgp_num)
    true = np.mean(moment_fn(T, true_fn, device='cpu'))
    print(f'True: {true:.4f}')
    results = Parallel(n_jobs=-1, verbose=3)(
        delayed(exp)(it, n, n_z, n_t, iv_strength, fname, dgp_num, moment_fn)
        for it in range(n_replications))
    # The file name (including its spelling) is what the plotting cells load.
    joblib.dump((true, results), f'res_fn_{fname}_n_{n}_stregth_{iv_strength}_eps_{epsilon}.jbl')

In [None]:
def plot_results(fname, n, iv_strength, dr, tmle, ipw, direct, true):
    """Draw overlaid histograms of four estimators on the current axes.

    Each of ``dr``, ``tmle``, ``ipw`` and ``direct`` is a 2-D array with one
    row per replication: column 0 is the point estimate, columns 1 and 2 the
    lower and upper interval bounds. The title reports, per estimator, the
    fraction of intervals containing ``true`` (Cov), the RMSE and the bias of
    the point estimates. ``fname``, ``n`` and ``iv_strength`` only appear in
    the title.
    """
    estimators = (('dr', dr), ('tmle', tmle), ('ipw', ipw), ('direct', direct))

    title_lines = [f'fname={fname}, n={n}, strength={iv_strength}, true={true:.3f}']
    for label, est in estimators:
        point, lower, upper = est[:, 0], est[:, 1], est[:, 2]
        coverage = np.mean((lower <= true) & (true <= upper))
        rmse = np.sqrt(np.mean((point - true)**2))
        bias = np.mean((point - true))
        title_lines.append(f'{label}: Cov={coverage:.3f}, rmse={rmse:.3f}, bias={bias:.3f}')
    # Every line, including the last one, is newline-terminated.
    plt.title('\n'.join(title_lines) + '\n')

    for position, (label, est) in enumerate(estimators):
        # The first histogram is opaque; the ones drawn over it are translucent.
        extra = {} if position == 0 else {'alpha': .4}
        plt.hist(est[:, 0], label=label, **extra)
    plt.legend()

### Results from early stopping models

In [None]:
# One figure per function: rows index the sample size, columns the instrument
# strength. The figure and the subplot counter are created inside the function
# loop, so the subplot index always stays within the rows * cols grid
# (plt.subplot raises for an index outside that range).
sample_sizes = [1000, 2000]
iv_strengths = [.5, .7, .9]
for fname in ['abs', '2dpoly', 'sigmoid', 'sin', '3dpoly', 'abspos']:
    plt.figure(figsize=(20, 20))
    it = 1
    for n in sample_sizes:
        for iv_strength in iv_strengths:
            plt.subplot(len(sample_sizes), len(iv_strengths), it)
            # Reads the notebook-level `epsilon`; the file name matches what
            # the experiment cell wrote.
            true, results = joblib.load(f'res_fn_{fname}_n_{n}_stregth_{iv_strength}_eps_{epsilon}.jbl')
            # Entries 0-3 of each result tuple: dr, tmle, ipw, reg from the
            # default `predict`.
            dr, tmle, ipw, direct = (np.array([r[k] for r in results]) for k in range(4))
            plot_results(fname, n, iv_strength, dr, tmle, ipw, direct, true)
            it += 1
    plt.tight_layout()
    plt.show()

### Results from avg models

In [None]:
# One figure per function: rows index the sample size, columns the instrument
# strength. The figure and the subplot counter are created inside the function
# loop, so the subplot index always stays within the rows * cols grid
# (plt.subplot raises for an index outside that range).
sample_sizes = [1000, 2000]
iv_strengths = [.5, .7, .9]
for fname in ['abs', '2dpoly', 'sigmoid', 'sin', '3dpoly', 'abspos']:
    plt.figure(figsize=(20, 20))
    it = 1
    for n in sample_sizes:
        for iv_strength in iv_strengths:
            plt.subplot(len(sample_sizes), len(iv_strengths), it)
            # Reads the notebook-level `epsilon`; the file name matches what
            # the experiment cell wrote.
            true, results = joblib.load(f'res_fn_{fname}_n_{n}_stregth_{iv_strength}_eps_{epsilon}.jbl')
            # Entries 4-7 of each result tuple: dr, tmle, ipw, reg from
            # `predict(..., model='avg')`.
            dr, tmle, ipw, direct = (np.array([r[k] for r in results]) for k in range(4, 8))
            plot_results(fname, n, iv_strength, dr, tmle, ipw, direct, true)
            it += 1

    plt.tight_layout()
    plt.show()