In [1]:
### Magic functions
# Re-import edited modules (e.g. the mliv package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# NOTE(review): the tensorboard extension is loaded, but every fit in this notebook passes
# logger=None — confirm the extension is still needed.
%load_ext tensorboard
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [2]:
### imports
import warnings
warnings.simplefilter('ignore')
import itertools
import numpy as np
import matplotlib.pyplot as plt 
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from mliv.dgps import get_data, get_tau_fn, fn_dict
from mliv.neuralnet.utilities import mean_ci
from mliv.neuralnet import AGMMEarlyStop as AGMM
from mliv.neuralnet.moments import avg_small_diff
from sklearn.ensemble import RandomForestRegressor

## $\ell_2$-Regularized AGMM with Neural Net Test Function

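We solve the min–max problem:
\begin{equation}
\min_{\theta} \max_{w} \frac{1}{n} \sum_i \left[ (y_i - h_{\theta}(x_i))\, f_w(z_i) - f_w(z_i)^2 \right]
\end{equation}
where $h_{\theta}$ (the learner, a function of the treatment $x$) and $f_w$ (the adversarial test function, a function of the instrument $z$) are two neural nets.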

In [3]:
def _estimates(direct, residual, qvalues, y, xivalues, direct_xi):
    """Turn validation-sample nuisance predictions into four moment estimates.

    All inputs are flattened arrays aligned over the validation sample:
    ``direct`` = moment_fn applied to h, ``residual`` = y - h(T),
    ``qvalues`` = q(Z), ``y`` = outcomes, ``xivalues`` = xi(T), and
    ``direct_xi`` = moment_fn applied to xi.

    Returns ``(dr, tmle, ipw, reg)``, each the output of ``mean_ci`` on the
    corresponding per-sample pseudo-outcomes.
    """
    dr = mean_ci(direct + qvalues * residual)
    ipw = mean_ci(qvalues * y)
    reg = mean_ci(direct)

    # Targeted update: shift h by coef * xi, with coef chosen so that the
    # q-weighted residual of the updated h averages to zero
    # (mean(q * (residual - coef * xi)) == 0). The direct term of the updated h
    # is direct + coef * direct_xi, assuming moment_fn is linear in its function argument.
    coef = np.mean(qvalues * residual) / np.mean(qvalues * xivalues)
    pseudo_tmle = direct + coef * direct_xi
    pseudo_tmle += qvalues * (residual - coef * xivalues)
    tmle = mean_ci(pseudo_tmle)
    return dr, tmle, ipw, reg


def exp(it, n, n_z, n_t, iv_strength, fname, dgp_num, moment_fn, special_test=True):
    """Run one Monte Carlo replication of the debiased AGMM estimator.

    Draws a fresh sample from the simulated DGP and splits it 50/50 into train and
    validation halves. It then:

    1. fits a riesz-representer network xi(T) with AGMM (``riesz=True``);
    2. regresses xi(T) on Z with a random forest to obtain q(Z);
    3. fits the IV function h with AGMM, optionally adding q(Z) as a
       "clever instrument";
    4. evaluates the averaged moment on the validation half.

    Parameters
    ----------
    it : int
        Replication index; seeds numpy and torch and tags the model directories.
    n : int
        Number of samples passed to ``get_data`` (half are used for training).
    n_z, n_t : int
        Input widths of the instrument Z and treatment T networks.
    iv_strength : float
        Instrument strength passed to ``get_data``.
    fname : str
        Key into ``fn_dict`` selecting the structural function.
    dgp_num : int
        DGP identifier passed to ``get_data``.
    moment_fn : callable
        ``moment_fn(x, fn, device)``; the functional whose average is estimated.
    special_test : bool, default True
        If True, prepend q(Z) to the instrument vector and build AGMM with
        ``special_test=True``.

    Returns
    -------
    tuple
        ``(dr, tmle, ipw, reg, dr_avg, tmle_avg, ipw_avg, reg_avg)``, each the
        output of ``mean_ci``. The ``*_avg`` entries use ``model='avg'``
        predictions with ``burn_in``.
    """
    # Seed both RNGs. numpy drives the train/validation split; torch drives
    # weight initialisation and dropout. Seeding numpy alone left torch
    # unseeded, and loky may reuse worker processes across tasks, so
    # replications were not reproducible.
    np.random.seed(it)
    torch.manual_seed(it)

    #####
    # Neural network parameters
    #####
    p = 0.1  # dropout prob of dropout layers throughout notebook
    n_hidden = 100  # width of hidden layers throughout notebook
    learner_lr = 1e-4
    adversary_lr = 1e-4
    learner_l2 = 1e-3
    adversary_l2 = 1e-3
    n_epochs = 200
    bs = 100
    burn_in = 100  # passed as burn_in to every model='avg' prediction
    device = None

    ######
    # Train test split
    ######
    Z, T, Y, _ = get_data(n, n_z, iv_strength, get_tau_fn(fn_dict[fname]), dgp_num)
    Z_train, Z_val, T_train, T_val, Y_train, Y_val = train_test_split(Z, T, Y, test_size=.5, shuffle=True)
    Z_train, T_train, Y_train = map(lambda x: torch.Tensor(x), (Z_train, T_train, Y_train))
    Z_val, T_val, Y_val = map(lambda x: torch.Tensor(x).to(device), (Z_val, T_val, Y_val))

    #####
    # Train "riesz" representer xi
    #####
    # NOTE(review): this reseeds numpy only, so numpy-side randomness from here on
    # (e.g. the random forests below) is identical across replications; torch
    # weight init is governed by torch.manual_seed(it) above.
    np.random.seed(12356)
    learner = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_t, n_hidden), nn.LeakyReLU(),
                            nn.Dropout(p=p), nn.Linear(n_hidden, 1))
    adversary_fn = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_z, n_hidden), nn.LeakyReLU(),
                                 nn.Dropout(p=p), nn.Linear(n_hidden, 1))
    riesz = AGMM(learner, adversary_fn).fit(Z_train, T_train, Y_train, Z_val, T_val, Y_val,
                                            learner_lr=learner_lr, adversary_lr=adversary_lr,
                                            learner_l2=learner_l2, adversary_l2=adversary_l2,
                                            n_epochs=n_epochs, bs=bs, logger=None,
                                            model_dir=f'riesz_model_{it}', device=device,
                                            riesz=True, moment_fn=moment_fn)

    def xi_avg(x):
        # Averaged-model prediction of the riesz representer.
        return riesz.predict(x, model='avg', burn_in=burn_in)

    ######
    # Fit q(Z) by regressing xi(T) on the instruments with a random forest
    ######
    qfun = RandomForestRegressor(min_samples_leaf=20).fit(Z_train, riesz.predict(T_train))
    qfun_avg = RandomForestRegressor(min_samples_leaf=20).fit(Z_train, xi_avg(T_train))

    ######
    # Train IV function h
    ######
    # Optionally add the "clever instrument" q(Z) as an extra instrument column.
    augZ_train, augZ_val = Z_train, Z_val
    if special_test:
        qtrain = torch.tensor(qfun.predict(Z_train).reshape(-1, 1)).float()
        augZ_train = torch.cat([qtrain, Z_train], dim=1)
        qval = torch.tensor(qfun.predict(Z_val).reshape(-1, 1)).float()
        augZ_val = torch.cat([qval, Z_val], dim=1)

    # Construction order (adversary, then learner) is kept because it determines
    # which torch random draws each network's initial weights receive.
    adversary_fn = nn.Sequential(nn.Dropout(p=p), nn.Linear(augZ_train.shape[1], n_hidden), nn.LeakyReLU(),
                                 nn.Dropout(p=p), nn.Linear(n_hidden, 1))
    learner = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_t, n_hidden), nn.LeakyReLU(),
                            nn.Dropout(p=p), nn.Linear(n_hidden, 1))

    agmm = AGMM(learner, adversary_fn, special_test=special_test)
    agmm.fit(augZ_train, T_train, Y_train, augZ_val, T_val, Y_val,
             learner_lr=learner_lr, adversary_lr=adversary_lr,
             learner_l2=learner_l2, adversary_l2=adversary_l2,
             n_epochs=n_epochs, bs=bs, logger=None,
             model_dir=f'agmm_model_{it}', device=device)

    def h_avg(x):
        # Averaged-model prediction of the IV function.
        return agmm.predict(x, model='avg', burn_in=burn_in)

    #####
    # Average moment calculation
    #####
    y_val = Y_val.detach().numpy().flatten()

    # Default (early-stopped) models.
    direct = moment_fn(T_val, agmm.predict, device='cpu').flatten()
    residual = (Y_val - agmm.predict(T_val)).detach().numpy().flatten()
    qvalues = qfun.predict(Z_val).flatten()
    xivalues = riesz.predict(T_val).flatten()
    direct_xi = moment_fn(T_val, riesz.predict, device='cpu').flatten()
    dr, tmle, ipw, reg = _estimates(direct, residual, qvalues, y_val, xivalues, direct_xi)

    # Averaged models.
    direct_avg = moment_fn(T_val, h_avg, device='cpu').flatten()
    residual_avg = (Y_val - h_avg(T_val)).detach().numpy().flatten()
    qvalues_avg = qfun_avg.predict(Z_val).flatten()
    xivalues_avg = xi_avg(T_val).flatten()
    direct_xi_avg = moment_fn(T_val, xi_avg, device='cpu').flatten()
    dr_avg, tmle_avg, ipw_avg, reg_avg = _estimates(direct_avg, residual_avg, qvalues_avg,
                                                    y_val, xivalues_avg, direct_xi_avg)

    return dr, tmle, ipw, reg, dr_avg, tmle_avg, ipw_avg, reg_avg

In [4]:
# Settings for the single-replication demo below.
n = 500           # samples drawn by get_data per replication
n_z = 1           # input width of the instrument (Z) network
n_t = 1           # input width of the treatment (T) network
iv_strength = .7
fname = 'sigmoid'
dgp_num = 5
epsilon = 0.1 # average finite difference epsilon


def moment_fn(x, fn, device):
    """Moment functional: ``avg_small_diff`` of ``fn`` at ``x`` with step ``epsilon``."""
    return avg_small_diff(x, fn, device, epsilon)


# Approximate the true target with a very large simulated sample.
Z, T, Y, true_fn = get_data(1_000_000, n_z, iv_strength, get_tau_fn(fn_dict[fname]), dgp_num)
true = np.mean(moment_fn(T, true_fn, device='cpu'))
print(f'True: {true:.4f}')

True: 0.1902


In [5]:
# One replication (seed 1) with the default special_test=True; the cell displays the returned estimate tuples.
exp(it=1, n=n, n_z=n_z, n_t=n_t, iv_strength=iv_strength, fname=fname,
    dgp_num=dgp_num, moment_fn=moment_fn)

f(z_dev) collection prepared.
Epoch #0
Current moment approx: 0.1438421607017517
Epoch #1
Current moment approx: 0.14461368322372437
Epoch #2
Current moment approx: 0.1448717564344406
Epoch #3
Current moment approx: 0.1455904096364975
Epoch #4
Current moment approx: 0.14687763154506683
Epoch #5
Current moment approx: 0.14842307567596436
Epoch #6
Current moment approx: 0.14982415735721588
Epoch #7
Current moment approx: 0.15057368576526642
Epoch #8
Current moment approx: 0.1521255522966385
Epoch #9
Current moment approx: 0.1535838097333908
Epoch #10
Current moment approx: 0.15497159957885742
Epoch #11
Current moment approx: 0.1560703068971634
Epoch #12
Current moment approx: 0.15757490694522858
Epoch #13
Current moment approx: 0.15945811569690704
Epoch #14
Current moment approx: 0.15965843200683594
Epoch #15
Current moment approx: 0.16091185808181763
Epoch #16
Current moment approx: 0.16251015663146973
Epoch #17
Current moment approx: 0.16270850598812103
Epoch #18
Current moment approx:

Current moment approx: -0.05695759505033493
Epoch #158
Current moment approx: -0.060375526547431946
Epoch #159
Current moment approx: -0.06289416551589966
Epoch #160
Current moment approx: -0.06509587168693542
Epoch #161
Current moment approx: -0.0674339011311531
Epoch #162
Current moment approx: -0.07101815938949585
Epoch #163
Current moment approx: -0.07390706241130829
Epoch #164
Current moment approx: -0.07624417543411255
Epoch #165
Current moment approx: -0.07972867786884308
Epoch #166
Current moment approx: -0.08249110728502274
Epoch #167
Current moment approx: -0.08419572561979294
Epoch #168
Current moment approx: -0.08531949669122696
Epoch #169
Current moment approx: -0.08689413964748383
Epoch #170
Current moment approx: -0.08893013000488281
Epoch #171
Current moment approx: -0.09061021357774734
Epoch #172
Current moment approx: -0.09182599186897278
Epoch #173
Current moment approx: -0.09394507855176926
Epoch #174
Current moment approx: -0.09437201917171478
Epoch #175
Current mo

Current moment approx: 0.009829754941165447
Epoch #110
Current moment approx: 0.009417101740837097
Epoch #111
Current moment approx: 0.008682372979819775
Epoch #112
Current moment approx: 0.009371362626552582
Epoch #113
Current moment approx: 0.008839530870318413
Epoch #114
Current moment approx: 0.008483144454658031
Epoch #115
Current moment approx: 0.008206977508962154
Epoch #116
Current moment approx: 0.008183788508176804
Epoch #117
Current moment approx: 0.009763234294950962
Epoch #118
Current moment approx: 0.013240438885986805
Epoch #119
Current moment approx: 0.014833049848675728
Epoch #120
Current moment approx: 0.015984900295734406
Epoch #121
Current moment approx: 0.015742294490337372
Epoch #122
Current moment approx: 0.01475923415273428
Epoch #123
Current moment approx: 0.013199122622609138
Epoch #124
Current moment approx: 0.011979503557085991
Epoch #125
Current moment approx: 0.010807814076542854
Epoch #126
Current moment approx: 0.011755027808248997
Epoch #127
Current mom

((0.18245209018598257, 0.1548967579173154, 0.21000742245464976),
 (0.1809173, 0.15332752289422813, 0.20850706277242836),
 (0.0711083701589425, 0.035721441188335694, 0.1064952991295493),
 (0.18350998, 0.17597776996448028, 0.1910421813790561),
 (0.21581645753035558, 0.18564753138845902, 0.24598538367225214),
 (0.21606198, 0.18592929218852625, 0.2461946668759574),
 (0.035931604760757414, -0.00045093151787378716, 0.07231414103938862),
 (0.21448505, 0.20941350404723802, 0.21955659444824538))

In [None]:
import joblib
from joblib import Parallel, delayed

n_z = 1
n_t = 1
dgp_num = 5
epsilon = 0.1  # average finite difference epsilon
moment_fn = lambda x, fn, device: avg_small_diff(x, fn, device, epsilon)
clever = False  # passed to exp as special_test: add q(Z) as a "clever instrument" or not

for fname in ['abs', '2dpoly', 'sigmoid', 'sin']:
    for n in [500]:
        for iv_strength in [.5, .7, .9]:
            # Approximate the true target with a very large simulated sample.
            Z, T, Y, true_fn = get_data(1000000, n_z, iv_strength, get_tau_fn(fn_dict[fname]), dgp_num)
            true = np.mean(moment_fn(T, true_fn, device='cpu'))
            print(f'True: {true:.4f}')
            # 100 replications (seeds 0..99) spread over all available cores.
            results = Parallel(n_jobs=-1, verbose=3)(
                delayed(exp)(it, n, n_z, n_t, iv_strength, fname, dgp_num, moment_fn,
                             special_test=clever)
                for it in range(100))
            # The clever flag is part of the filename so runs with and without the
            # clever instrument don't overwrite each other. It also matches the
            # `_clever_{clever}` naming used when loading the early-stopping results.
            # ("stregth" is kept as-is for compatibility with existing result files.)
            joblib.dump((true, results),
                        f'res_fn_{fname}_n_{n}_stregth_{iv_strength}_eps_{epsilon}_clever_{clever}.jbl')

True: -0.0000


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  16 tasks      | elapsed:   50.8s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:  4.8min finished


True: -0.0003


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  16 tasks      | elapsed:   43.8s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:  4.9min finished


True: -0.0001


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  16 tasks      | elapsed:   50.1s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:  5.2min finished


True: -0.4967


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  16 tasks      | elapsed:   48.7s


In [None]:
def _summary_line(label, est, true):
    """One title line: CI coverage, RMSE and bias of an estimator across replications.

    ``est`` is indexed as ``est[:, 0]`` (point estimate) and ``est[:, 1]``,
    ``est[:, 2]`` (interval lower/upper bound), one row per replication.
    """
    coverage = np.mean((est[:, 1] <= true) & (true <= est[:, 2]))
    err = est[:, 0] - true
    rmse = np.sqrt(np.mean(err**2))
    bias = np.mean(err)
    return f'{label}: Cov={coverage:.3f}, rmse={rmse:.3f}, bias={bias:.3f}\n'


def plot_results(fname, n, iv_strength, dr, tmle, ipw, direct, true):
    """Overlay histograms of four estimators on the current axes, with summary stats in the title.

    Each estimator array has one row per replication with columns
    (point estimate, lower bound, upper bound); ``true`` is the target value.
    """
    estimators = [('dr', dr), ('tmle', tmle), ('ipw', ipw), ('direct', direct)]
    header = f'fname={fname}, n={n}, strength={iv_strength}, true={true:.3f}\n'
    plt.title(header + ''.join(_summary_line(label, est, true) for label, est in estimators))
    for label, est in estimators:
        # 'dr' is drawn opaque; the remaining estimators are translucent overlays.
        style = {} if label == 'dr' else {'alpha': .4}
        plt.hist(est[:, 0], label=label, **style)
    plt.legend()

### Results from early stopping models

In [None]:
# One figure per structural function, one panel per (sample size, strength) pair:
# rows are sample sizes and columns are strengths (previously a 3x3 grid that left a row empty).
# NOTE(review): the sweep above only produces n=500; the n=1000 files must come from another run.
sample_sizes = [500, 1000]
strengths = [.5, .7, .9]
for fname in ['abs', '2dpoly', 'sigmoid', 'sin']:
    plt.figure(figsize=(20, 20))
    for it, (n, iv_strength) in enumerate(itertools.product(sample_sizes, strengths), start=1):
        plt.subplot(len(sample_sizes), len(strengths), it)
        true, results = joblib.load(f'res_fn_{fname}_n_{n}_stregth_{iv_strength}_eps_{epsilon}_clever_{clever}.jbl')
        # Entries 0-3 of each replication's tuple are the early-stopped estimates.
        dr, tmle, ipw, direct = (np.array([r[k] for r in results]) for k in range(4))
        plot_results(fname, n, iv_strength, dr, tmle, ipw, direct, true)
    plt.tight_layout()
    plt.show()

### Results from avg models

In [None]:
# One figure per structural function, one panel per (sample size, strength) pair:
# rows are sample sizes and columns are strengths (previously a 3x3 grid that left a row empty).
# NOTE(review): these filenames carry no `_clever_{clever}` suffix, and these
# (fname, n) combinations are not produced by the sweep above — presumably
# results from an earlier run; confirm before relying on them.
sample_sizes = [1000, 2000]
strengths = [.5, .7, .9]
for fname in ['abs', '2dpoly', 'sigmoid', 'sin', '3dpoly', 'abspos']:
    plt.figure(figsize=(20, 20))
    for it, (n, iv_strength) in enumerate(itertools.product(sample_sizes, strengths), start=1):
        plt.subplot(len(sample_sizes), len(strengths), it)
        true, results = joblib.load(f'res_fn_{fname}_n_{n}_stregth_{iv_strength}_eps_{epsilon}.jbl')
        # Entries 4-7 of each replication's tuple are the model='avg' estimates.
        dr, tmle, ipw, direct = (np.array([r[k] for r in results]) for k in range(4, 8))
        plot_results(fname, n, iv_strength, dr, tmle, ipw, direct, true)

    plt.tight_layout()
    plt.show()