In [1]:
### Magic functions
# autoreload 2: re-import edited modules (e.g. mliv) before each cell executes.
%load_ext autoreload
%autoreload 2
# NOTE(review): the tensorboard extension appears unused here (every AGMM fit passes logger=None).
%load_ext tensorboard
%matplotlib inline

In [2]:
### imports
import warnings
warnings.simplefilter('ignore')
import itertools
import numpy as np
import matplotlib.pyplot as plt 
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from mliv.dgps import get_data, get_tau_fn, fn_dict
from mliv.neuralnet.utilities import mean_ci
from mliv.neuralnet import AGMMEarlyStop as AGMM
from mliv.neuralnet.moments import avg_small_diff
from sklearn.ensemble import RandomForestRegressor

## $\ell_2$-Regularized AGMM with Neural Net Test Function

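We solve the problem:
\begin{equation}
\min_{\theta} \max_{w} \frac{1}{n} \sum_{i=1}^{n} \left[ (y_i - h_{\theta}(x_i))\, f_w(z_i) - f_w(z_i)^2 \right]
\end{equation}
where $h_{\theta}$ (the learner, a function of the treatment $x_i$) and $f_w$ (the adversarial test function, a function of the instrument $z_i$) are two neural nets.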

In [64]:
def _moment_estimates(moment_fn, h_fn, xi_fn, qvalues, T_val, Y_val):
    """Compute the DR, TMLE, IPW and direct estimates of the average moment.

    Parameters
    ----------
    moment_fn : callable ``(x, fn, device) -> array``, the moment functional m(x; fn).
    h_fn : callable mapping treatments to predictions of the structural function h.
        Its output must support ``.detach()`` (e.g. an AGMM model's ``predict``).
    xi_fn : callable mapping treatments to values of the Riesz representer xi.
    qvalues : 1-d array, q(z) evaluated on the validation instruments.
    T_val, Y_val : validation treatments and outcomes (torch tensors).

    Returns
    -------
    (dr, tmle, ipw, reg) : each the output of ``mean_ci`` on the corresponding
        per-sample pseudo-outcomes.
    """
    direct = moment_fn(T_val, h_fn, device='cpu').flatten()
    residual = (Y_val - h_fn(T_val)).detach().numpy().flatten()
    # Doubly robust: plug-in moment plus a q-weighted residual correction.
    dr = mean_ci(direct + qvalues * residual)
    ipw = mean_ci(qvalues * Y_val.detach().numpy().flatten())
    reg = mean_ci(direct)

    # TMLE: update h -> h + coef * xi, choosing coef so that the q-weighted residual
    # of the updated model averages to zero, then evaluate the DR pseudo-outcome
    # at the updated model.
    xivalues = xi_fn(T_val).flatten()
    coef = np.mean(qvalues * residual) / np.mean(qvalues * xivalues)
    pseudo_tmle = direct + coef * moment_fn(T_val, xi_fn, device='cpu').flatten()
    pseudo_tmle += qvalues * (residual - coef * xivalues)
    tmle = mean_ci(pseudo_tmle)
    return dr, tmle, ipw, reg


def exp(it, n, n_z, n_t, iv_strength, fname, dgp_num, moment_fn, special_test=True, lambda_l2_h=0):
    """Run one Monte Carlo replication of debiased AGMM moment estimation.

    Steps:
    (1) Draw data from the DGP and split it 50/50 into train/validation.
    (2) Fit the Riesz representer xi(t) adversarially with AGMM (``riesz=True``).
    (3) Regress xi(T) on Z with random forests, giving q(z).
    (4) Fit the IV function h with AGMM, optionally appending q(z) to the
        instruments (``special_test``).
    (5) Form DR/TMLE/IPW/direct estimates of the average moment on the
        validation fold, for both the early-stopped and the averaged ('avg')
        networks.

    Parameters
    ----------
    it : int
        Replication index; seeds the RNGs and names the checkpoint directories.
    n : int
        Total sample size (half of it is used for training).
    n_z, n_t : int
        Instrument and treatment dimensions (input widths of the networks).
    iv_strength, fname, dgp_num
        Forwarded to ``get_data`` / ``fn_dict``.
    moment_fn : callable ``(x, fn, device)``
        The moment functional.
    special_test : bool
        If True, prepend q(z) as an extra instrument column.
    lambda_l2_h : float
        Tikhonov penalty on h (``learner_tikhonov``).

    Returns
    -------
    tuple
        ``(dr, tmle, ipw, reg, dr_avg, tmle_avg, ipw_avg, reg_avg)``, each the
        output of ``mean_ci``.
    """
    # Seed torch as well as numpy: network initialisation and dropout draw from
    # torch's RNG, so seeding numpy alone does not make a replication reproducible.
    np.random.seed(it)
    torch.manual_seed(it)

    #####
    # Neural network parameters
    #####
    p = 0.1  # dropout probability of every dropout layer
    n_hidden = 100  # width of every hidden layer
    learner_lr = 1e-4
    adversary_lr = 1e-4
    learner_l2 = 1e-3
    adversary_l2 = 1e-3
    n_epochs = 200
    bs = 100  # batch size passed to AGMM.fit
    burn_in = 100  # passed to predict(model='avg'); presumably epochs excluded from the average
    device = None

    def make_net(n_in):
        # One-hidden-layer MLP with dropout; used for all four networks below.
        return nn.Sequential(nn.Dropout(p=p), nn.Linear(n_in, n_hidden), nn.LeakyReLU(),
                             nn.Dropout(p=p), nn.Linear(n_hidden, 1))

    ######
    # Train/validation split (50/50)
    ######
    Z, T, Y, _ = get_data(n, n_z, iv_strength, get_tau_fn(fn_dict[fname]), dgp_num)
    Z_train, Z_val, T_train, T_val, Y_train, Y_val = train_test_split(Z, T, Y, test_size=.5, shuffle=True)
    Z_train, T_train, Y_train = map(lambda x: torch.Tensor(x), (Z_train, T_train, Y_train))
    Z_val, T_val, Y_val = map(lambda x: torch.Tensor(x).to(device), (Z_val, T_val, Y_val))

    #####
    # Riesz representer xi(t), fitted adversarially against test functions of z
    #####
    # Fixed seed shared by all replications; torch is seeded too, so the Riesz
    # networks start from the same weights in every replication.
    # NOTE(review): confirm that a replication-independent seed here is intended.
    np.random.seed(12356)
    torch.torch.manual_seed(12356) if False else torch.manual_seed(12356)
    # Learner (input: treatment) is built before the adversary (input: instrument).
    riesz = AGMM(make_net(n_t), make_net(n_z)).fit(Z_train, T_train, Y_train, Z_val, T_val, Y_val,
                                                   learner_lr=learner_lr, adversary_lr=adversary_lr,
                                                   learner_l2=learner_l2, adversary_l2=adversary_l2,
                                                   n_epochs=n_epochs, bs=bs, logger=None,
                                                   model_dir=f'riesz_model_{it}', device=device,
                                                   riesz=True, moment_fn=moment_fn, verbose=1)

    ######
    # Project xi onto the instruments: q(z) ~ E[xi(T) | Z = z] via random forests,
    # once for the early-stopped and once for the averaged Riesz network.
    ######
    qfun = RandomForestRegressor(min_samples_leaf=20).fit(Z_train, riesz.predict(T_train).ravel())
    qfun_avg = RandomForestRegressor(min_samples_leaf=20).fit(
        Z_train, riesz.predict(T_train, model='avg', burn_in=burn_in).ravel())

    ######
    # Train IV function h
    ######
    # Optionally add the "clever instrument" q(z) as the first instrument column.
    augZ_val = Z_val
    augZ_train = Z_train
    if special_test:
        qtrain = torch.tensor(qfun.predict(Z_train).reshape(-1, 1)).float()
        augZ_train = torch.cat([qtrain, Z_train], dim=1)
        qval = torch.tensor(qfun.predict(Z_val).reshape(-1, 1)).float()
        augZ_val = torch.cat([qval, Z_val], dim=1)

    # Adversary is built before the learner (keeps torch RNG consumption order).
    adversary_fn = make_net(augZ_train.shape[1])
    learner = make_net(n_t)

    agmm = AGMM(learner, adversary_fn, special_test=special_test)
    agmm.fit(augZ_train, T_train, Y_train, augZ_val, T_val, Y_val,
             learner_lr=learner_lr, adversary_lr=adversary_lr,
             learner_l2=learner_l2, adversary_l2=adversary_l2,
             learner_tikhonov=lambda_l2_h,
             n_epochs=n_epochs, bs=bs, logger=None,
             model_dir=f'agmm_model_{it}', device=device, verbose=1)

    #####
    # Average moment estimates on the validation fold
    #####
    # Early-stopped networks.
    dr, tmle, ipw, reg = _moment_estimates(
        moment_fn, agmm.predict, riesz.predict,
        qfun.predict(Z_val).flatten(), T_val, Y_val)
    # "Average neural net" models as opposed to early-stopped models.
    dr_avg, tmle_avg, ipw_avg, reg_avg = _moment_estimates(
        moment_fn,
        lambda x: agmm.predict(x, model='avg', burn_in=burn_in),
        lambda x: riesz.predict(x, model='avg', burn_in=burn_in),
        qfun_avg.predict(Z_val).flatten(), T_val, Y_val)

    return dr, tmle, ipw, reg, dr_avg, tmle_avg, ipw_avg, reg_avg

In [68]:
# Settings for a single replication
n = 1000      # sample size
n_z = 1       # instrument dimension
n_t = 1       # treatment dimension
iv_strength = .5
fname = 'sigmoid'
dgp_num = 5
epsilon = .1  # average finite difference epsilon


def moment_fn(x, fn, device):
    # Reads the global `epsilon` at call time, exactly as the equivalent lambda would.
    return avg_small_diff(x, fn, device, epsilon)


lambda_l2_h = 0.1 / n ** 0.9  # Tikhonov penalty on h, shrinking with n
clever = True                 # prepend q(z) to the instruments (exp's special_test)

# Reference value: average moment of the true structural function on a very large sample.
Z, T, Y, true_fn = get_data(1_000_000, n_z, iv_strength, get_tau_fn(fn_dict[fname]), dgp_num)
true = np.mean(moment_fn(T, true_fn, device='cpu'))
print(f'True: {true:.4f}')

True: 0.1887


In [69]:
# One replication (seed 1); displays the eight mean_ci tuples returned by `exp`.
exp(it=1, n=n, n_z=n_z, n_t=n_t, iv_strength=iv_strength, fname=fname,
    dgp_num=dgp_num, moment_fn=moment_fn,
    special_test=clever, lambda_l2_h=lambda_l2_h)

f(z_dev) collection prepared.
Epoch #0
tensor(0.7971)
Current moment approx: 0.10657162219285965
Epoch #1
tensor(0.7971)
Current moment approx: 0.10264182090759277
Epoch #2
tensor(0.7971)
Current moment approx: 0.10054069012403488
Epoch #3
tensor(0.7971)
Current moment approx: 0.09702461957931519
Epoch #4
tensor(0.7971)
Current moment approx: 0.09602110832929611
Epoch #5
tensor(0.7971)
Current moment approx: 0.0926128625869751
Epoch #6
tensor(0.7971)
Current moment approx: 0.09030835330486298
Epoch #7
tensor(0.7971)
Current moment approx: 0.08588869869709015
Epoch #8
tensor(0.7971)
Current moment approx: 0.08188210427761078
Epoch #9
tensor(0.7971)
Current moment approx: 0.08070887625217438
Epoch #10
tensor(0.7971)
Current moment approx: 0.07715708762407303
Epoch #11
tensor(0.7971)
Current moment approx: 0.07351870834827423
Epoch #12
tensor(0.7971)
Current moment approx: 0.07105858623981476
Epoch #13
tensor(0.7971)
Current moment approx: 0.06848948448896408
Epoch #14
tensor(0.7971)
Curr

Current moment approx: -0.35222646594047546
Epoch #121
tensor(0.7971)
Current moment approx: -0.35334181785583496
Epoch #122
tensor(0.7971)
Current moment approx: -0.3551754951477051
Epoch #123
tensor(0.7971)
Current moment approx: -0.35659170150756836
Epoch #124
tensor(0.7971)
Current moment approx: -0.35869932174682617
Epoch #125
tensor(0.7971)
Current moment approx: -0.3610484004020691
Epoch #126
tensor(0.7971)
Current moment approx: -0.36307764053344727
Epoch #127
tensor(0.7971)
Current moment approx: -0.3650643527507782
Epoch #128
tensor(0.7971)
Current moment approx: -0.3668758273124695
Epoch #129
tensor(0.7971)
Current moment approx: -0.3682031035423279
Epoch #130
tensor(0.7971)
Current moment approx: -0.36986440420150757
Epoch #131
tensor(0.7971)
Current moment approx: -0.37120169401168823
Epoch #132
tensor(0.7971)
Current moment approx: -0.3730553090572357
Epoch #133
tensor(0.7971)
Current moment approx: -0.375137597322464
Epoch #134
tensor(0.7971)
Current moment approx: -0.37

tensor(0.7971)
Current moment approx: -0.4565139412879944
Epoch #242
tensor(0.7971)
Current moment approx: -0.45683184266090393
Epoch #243
tensor(0.7971)
Current moment approx: -0.4568014144897461
Epoch #244
tensor(0.7971)
Current moment approx: -0.45688551664352417
Epoch #245
tensor(0.7971)
Current moment approx: -0.45698535442352295
Epoch #246
tensor(0.7971)
Current moment approx: -0.45696231722831726
Epoch #247
tensor(0.7971)
Current moment approx: -0.4570048749446869
Epoch #248
tensor(0.7971)
Current moment approx: -0.4570567011833191
Epoch #249
tensor(0.7971)
Current moment approx: -0.45736566185951233
Epoch #250
tensor(0.7971)
Current moment approx: -0.4576217830181122
Epoch #251
tensor(0.7971)
Current moment approx: -0.4575407803058624
Epoch #252
tensor(0.7971)
Current moment approx: -0.45773211121559143
Epoch #253
tensor(0.7971)
Current moment approx: -0.45786425471305847
Epoch #254
tensor(0.7971)
Current moment approx: -0.4579789638519287
Epoch #255
tensor(0.7971)
Current mome

Current moment approx: -0.4465891122817993
Epoch #360
tensor(0.7971)
Current moment approx: -0.44705426692962646
Epoch #361
tensor(0.7971)
Current moment approx: -0.4464206099510193
Epoch #362
tensor(0.7971)
Current moment approx: -0.4464947581291199
Epoch #363
tensor(0.7971)
Current moment approx: -0.4462706446647644
Epoch #364
tensor(0.7971)
Current moment approx: -0.44575875997543335
Epoch #365
tensor(0.7971)
Current moment approx: -0.4454459547996521
Epoch #366
tensor(0.7971)
Current moment approx: -0.4454331398010254
Epoch #367
tensor(0.7971)
Current moment approx: -0.44479578733444214
Epoch #368
tensor(0.7971)
Current moment approx: -0.44463878870010376
Epoch #369
tensor(0.7971)
Current moment approx: -0.44415831565856934
Epoch #370
tensor(0.7971)
Current moment approx: -0.4438353180885315
Epoch #371
tensor(0.7971)
Current moment approx: -0.4437740445137024
Epoch #372
tensor(0.7971)
Current moment approx: -0.4434632658958435
Epoch #373
tensor(0.7971)
Current moment approx: -0.442

tensor(0.7971)
Current moment approx: -0.42504262924194336
Epoch #481
tensor(0.7971)
Current moment approx: -0.42449748516082764
Epoch #482
tensor(0.7971)
Current moment approx: -0.4250885248184204
Epoch #483
tensor(0.7971)
Current moment approx: -0.4247046113014221
Epoch #484
tensor(0.7971)
Current moment approx: -0.42473411560058594
Epoch #485
tensor(0.7971)
Current moment approx: -0.42448389530181885
Epoch #486
tensor(0.7971)
Current moment approx: -0.4247627258300781
Epoch #487
tensor(0.7971)
Current moment approx: -0.4250526428222656
Epoch #488
tensor(0.7971)
Current moment approx: -0.424757182598114
Epoch #489
tensor(0.7971)
Current moment approx: -0.4244210124015808
Epoch #490
tensor(0.7971)
Current moment approx: -0.4236651062965393
Epoch #491
tensor(0.7971)
Current moment approx: -0.42375606298446655
Epoch #492
tensor(0.7971)
Current moment approx: -0.4235643148422241
Epoch #493
tensor(0.7971)
Current moment approx: -0.42372894287109375
Epoch #494
tensor(0.7971)
Current moment

tensor(0.7971)
Current moment approx: -0.41735827922821045
Epoch #599
tensor(0.7971)
Current moment approx: -0.41749048233032227
Epoch #600
tensor(0.7971)
Current moment approx: -0.41788560152053833
Epoch #601
tensor(0.7971)
Current moment approx: -0.41780292987823486
Epoch #602
tensor(0.7971)
Current moment approx: -0.41817760467529297
Epoch #603
tensor(0.7971)
Current moment approx: -0.4176345467567444
Epoch #604
tensor(0.7971)
Current moment approx: -0.41795098781585693
Epoch #605
tensor(0.7971)
Current moment approx: -0.4177771806716919
Epoch #606
tensor(0.7971)
Current moment approx: -0.41816186904907227
Epoch #607
tensor(0.7971)
Current moment approx: -0.4179610013961792
Epoch #608
tensor(0.7971)
Current moment approx: -0.41761112213134766
Epoch #609
tensor(0.7971)
Current moment approx: -0.4183046221733093
Epoch #610
tensor(0.7971)
Current moment approx: -0.4181746244430542
Epoch #611
tensor(0.7971)
Current moment approx: -0.4179545044898987
Epoch #612
tensor(0.7971)
Current mom

tensor(0.7971)
Current moment approx: -0.4108625650405884
Epoch #720
tensor(0.7971)
Current moment approx: -0.4116275906562805
Epoch #721
tensor(0.7971)
Current moment approx: -0.41144126653671265
Epoch #722
tensor(0.7971)
Current moment approx: -0.4123237729072571
Epoch #723
tensor(0.7971)
Current moment approx: -0.41308045387268066
Epoch #724
tensor(0.7971)
Current moment approx: -0.41313308477401733
Epoch #725
tensor(0.7971)
Current moment approx: -0.41324687004089355
Epoch #726
tensor(0.7971)
Current moment approx: -0.4133789539337158
Epoch #727
tensor(0.7971)
Current moment approx: -0.41385769844055176
Epoch #728
tensor(0.7971)
Current moment approx: -0.4135264754295349
Epoch #729
tensor(0.7971)
Current moment approx: -0.4135323762893677
Epoch #730
tensor(0.7971)
Current moment approx: -0.4129561185836792
Epoch #731
tensor(0.7971)
Current moment approx: -0.41284775733947754
Epoch #732
tensor(0.7971)
Current moment approx: -0.4134480357170105
Epoch #733
tensor(0.7971)
Current momen

Current moment approx: -0.4122048020362854
Epoch #841
tensor(0.7971)
Current moment approx: -0.4125029444694519
Epoch #842
tensor(0.7971)
Current moment approx: -0.4126133322715759
Epoch #843
tensor(0.7971)
Current moment approx: -0.41216468811035156
Epoch #844
tensor(0.7971)
Current moment approx: -0.4125332236289978
Epoch #845
tensor(0.7971)
Current moment approx: -0.4126805067062378
Epoch #846
tensor(0.7971)
Current moment approx: -0.4127100110054016
Epoch #847
tensor(0.7971)
Current moment approx: -0.41247570514678955
Epoch #848
tensor(0.7971)
Current moment approx: -0.4124182462692261
Epoch #849
tensor(0.7971)
Current moment approx: -0.4132443070411682
Epoch #850
tensor(0.7971)
Current moment approx: -0.4131144881248474
Epoch #851
tensor(0.7971)
Current moment approx: -0.4138113260269165
Epoch #852
tensor(0.7971)
Current moment approx: -0.4134240746498108
Epoch #853
tensor(0.7971)
Current moment approx: -0.4133536219596863
Epoch #854
tensor(0.7971)
Current moment approx: -0.413197

tensor(0.7971)
Current moment approx: -0.4115619659423828
Epoch #962
tensor(0.7971)
Current moment approx: -0.4119988679885864
Epoch #963
tensor(0.7971)
Current moment approx: -0.4126906394958496
Epoch #964
tensor(0.7971)
Current moment approx: -0.4125102162361145
Epoch #965
tensor(0.7971)
Current moment approx: -0.41235101222991943
Epoch #966
tensor(0.7971)
Current moment approx: -0.4132794141769409
Epoch #967
tensor(0.7971)
Current moment approx: -0.41367781162261963
Epoch #968
tensor(0.7971)
Current moment approx: -0.41282105445861816
Epoch #969
tensor(0.7971)
Current moment approx: -0.41258901357650757
Epoch #970
tensor(0.7971)
Current moment approx: -0.4129922389984131
Epoch #971
tensor(0.7971)
Current moment approx: -0.41230589151382446
Epoch #972
tensor(0.7971)
Current moment approx: -0.4122350215911865
Epoch #973
tensor(0.7971)
Current moment approx: -0.41312116384506226
Epoch #974
tensor(0.7971)
Current moment approx: -0.41231095790863037
Epoch #975
tensor(0.7971)
Current mome

Current moment approx: 0.0283229798078537
Epoch #107
Current moment approx: 0.026581624522805214
Epoch #108
Current moment approx: 0.02573983743786812
Epoch #109
Current moment approx: 0.02387470379471779
Epoch #110
Current moment approx: 0.022758007049560547
Epoch #111
Current moment approx: 0.02169010601937771
Epoch #112
Current moment approx: 0.019960777834057808
Epoch #113
Current moment approx: 0.019064808264374733
Epoch #114
Current moment approx: 0.018171323463320732
Epoch #115
Current moment approx: 0.019366402179002762
Epoch #116
Current moment approx: 0.01882508024573326
Epoch #117
Current moment approx: 0.01875138096511364
Epoch #118
Current moment approx: 0.020262103527784348
Epoch #119
Current moment approx: 0.023184239864349365
Epoch #120
Current moment approx: 0.02506018988788128
Epoch #121
Current moment approx: 0.026559052988886833
Epoch #122
Current moment approx: 0.026197116822004318
Epoch #123
Current moment approx: 0.027889804914593697
Epoch #124
Current moment app

Current moment approx: 0.017838958650827408
Epoch #260
Current moment approx: 0.0183984637260437
Epoch #261
Current moment approx: 0.018264541402459145
Epoch #262
Current moment approx: 0.016885902732610703
Epoch #263
Current moment approx: 0.01900097168982029
Epoch #264
Current moment approx: 0.018221471458673477
Epoch #265
Current moment approx: 0.02009608782827854
Epoch #266
Current moment approx: 0.020188823342323303
Epoch #267
Current moment approx: 0.018320446833968163
Epoch #268
Current moment approx: 0.017359308898448944
Epoch #269
Current moment approx: 0.01649375446140766
Epoch #270
Current moment approx: 0.01743997633457184
Epoch #271
Current moment approx: 0.01550986710935831
Epoch #272
Current moment approx: 0.014372591860592365
Epoch #273
Current moment approx: 0.016425959765911102
Epoch #274
Current moment approx: 0.016426879912614822
Epoch #275
Current moment approx: 0.019274551421403885
Epoch #276
Current moment approx: 0.02090573124587536
Epoch #277
Current moment app

Current moment approx: 0.019195053726434708
Epoch #415
Current moment approx: 0.020597122609615326
Epoch #416
Current moment approx: 0.020698703825473785
Epoch #417
Current moment approx: 0.022290820255875587
Epoch #418
Current moment approx: 0.02245410531759262
Epoch #419
Current moment approx: 0.019514093175530434
Epoch #420
Current moment approx: 0.018693700432777405
Epoch #421
Current moment approx: 0.017712842673063278
Epoch #422
Current moment approx: 0.01607716642320156
Epoch #423
Current moment approx: 0.01604137383401394
Epoch #424
Current moment approx: 0.014700223691761494
Epoch #425
Current moment approx: 0.015362056903541088
Epoch #426
Current moment approx: 0.017009926959872246
Epoch #427
Current moment approx: 0.017274223268032074
Epoch #428
Current moment approx: 0.016769012436270714
Epoch #429
Current moment approx: 0.015728946775197983
Epoch #430
Current moment approx: 0.01716725528240204
Epoch #431
Current moment approx: 0.01640060544013977
Epoch #432
Current moment 

Current moment approx: 0.022565655410289764
Epoch #569
Current moment approx: 0.020664233714342117
Epoch #570
Current moment approx: 0.020360320806503296
Epoch #571
Current moment approx: 0.0206313319504261
Epoch #572
Current moment approx: 0.02105695568025112
Epoch #573
Current moment approx: 0.01858864165842533
Epoch #574
Current moment approx: 0.01897842437028885
Epoch #575
Current moment approx: 0.017371265217661858
Epoch #576
Current moment approx: 0.015159807167947292
Epoch #577
Current moment approx: 0.014840287156403065
Epoch #578
Current moment approx: 0.014282623305916786
Epoch #579
Current moment approx: 0.011785612441599369
Epoch #580
Current moment approx: 0.011560489423573017
Epoch #581
Current moment approx: 0.009921858087182045
Epoch #582
Current moment approx: 0.011076144874095917
Epoch #583
Current moment approx: 0.009199552237987518
Epoch #584
Current moment approx: 0.01044461503624916
Epoch #585
Current moment approx: 0.008402297273278236
Epoch #586
Current moment a

Current moment approx: 0.017448414117097855
Epoch #722
Current moment approx: 0.016795501112937927
Epoch #723
Current moment approx: 0.016744229942560196
Epoch #724
Current moment approx: 0.01954840123653412
Epoch #725
Current moment approx: 0.01784495636820793
Epoch #726
Current moment approx: 0.02000553160905838
Epoch #727
Current moment approx: 0.01901681162416935
Epoch #728
Current moment approx: 0.020503101870417595
Epoch #729
Current moment approx: 0.01884288527071476
Epoch #730
Current moment approx: 0.01856655813753605
Epoch #731
Current moment approx: 0.01941017620265484
Epoch #732
Current moment approx: 0.019860316067934036
Epoch #733
Current moment approx: 0.019457925111055374
Epoch #734
Current moment approx: 0.01730385608971119
Epoch #735
Current moment approx: 0.01557551883161068
Epoch #736
Current moment approx: 0.015873203054070473
Epoch #737
Current moment approx: 0.01490798033773899
Epoch #738
Current moment approx: 0.014193467795848846
Epoch #739
Current moment appro

Current moment approx: 0.016477065160870552
Epoch #873
Current moment approx: 0.015573427081108093
Epoch #874
Current moment approx: 0.0172015018761158
Epoch #875
Current moment approx: 0.018045151606202126
Epoch #876
Current moment approx: 0.016100440174341202
Epoch #877
Current moment approx: 0.016315922141075134
Epoch #878
Current moment approx: 0.016696903854608536
Epoch #879
Current moment approx: 0.01675577647984028
Epoch #880
Current moment approx: 0.0168906319886446
Epoch #881
Current moment approx: 0.017058715224266052
Epoch #882
Current moment approx: 0.017925379797816277
Epoch #883
Current moment approx: 0.017602665349841118
Epoch #884
Current moment approx: 0.017243312671780586
Epoch #885
Current moment approx: 0.01716512441635132
Epoch #886
Current moment approx: 0.016202114522457123
Epoch #887
Current moment approx: 0.01593197137117386
Epoch #888
Current moment approx: 0.01692267693579197
Epoch #889
Current moment approx: 0.01672334223985672
Epoch #890
Current moment appr

((0.2404179236695896, 0.18308379156423954, 0.2977520557749397),
 (0.24073517, 0.18385994005276787, 0.29761040639803776),
 (0.2276113889311236, 0.1492584141301336, 0.30596436373211355),
 (0.23492679, 0.2313089765836171, 0.23854460341439926),
 (0.24359147943555798, 0.1757247946533761, 0.31145816421773986),
 (0.24113794, 0.1752171909740889, 0.3070586824485338),
 (0.2688657233462798, 0.1765424903645723, 0.36118895632798725),
 (0.22023343, 0.2130868681125073, 0.22737998288351016))

In [None]:
def plot_results(fname, n, iv_strength, dr, tmle, ipw, direct, true):
    """Overlay histograms of each estimator's point estimates across replications.

    Each estimator array has one row per replication and columns
    ``(point_estimate, ci_lower, ci_upper)``, as returned by ``mean_ci`` in ``exp``.
    The title reports, per estimator, the empirical coverage of the interval,
    the RMSE and the bias of the point estimate relative to ``true``.

    Parameters
    ----------
    fname, n, iv_strength
        Shown in the title to identify the experiment.
    dr, tmle, ipw, direct : array-like of shape (n_reps, 3)
        Per-replication results of each estimator.
    true : float
        Target value the estimators should recover.
    """
    estimators = {'dr': np.asarray(dr), 'tmle': np.asarray(tmle),
                  'ipw': np.asarray(ipw), 'direct': np.asarray(direct)}

    lines = [f'fname={fname}, n={n}, strength={iv_strength}, true={true:.3f}']
    for name, est in estimators.items():
        point, lower, upper = est[:, 0], est[:, 1], est[:, 2]
        coverage = np.mean((lower <= true) & (true <= upper))
        rmse = np.sqrt(np.mean((point - true) ** 2))
        bias = np.mean(point - true)
        lines.append(f'{name}: Cov={coverage:.3f}, rmse={rmse:.3f}, bias={bias:.3f}')
    plt.title('\n'.join(lines) + '\n')

    # Shared bin edges so the overlaid histograms are directly comparable; by
    # default each plt.hist call would choose its own bins from its own data.
    bins = np.histogram_bin_edges(
        np.concatenate([est[:, 0] for est in estimators.values()]), bins=10)
    for name, est in estimators.items():
        plt.hist(est[:, 0], bins=bins, label=name, alpha=None if name == 'dr' else .4)
    plt.legend()

In [None]:
import joblib
from joblib import Parallel, delayed

n_z = 1
n_t = 1
dgp_num = 5
epsilon = 0.1 # average finite difference epsilon
moment_fn = lambda x, fn, device: avg_small_diff(x, fn, device, epsilon)
n_reps = 8  # Monte Carlo replications per configuration

# `clever` toggles the q(z)-augmented instrument (exp's `special_test`).
for clever in [False, True]:
    for fname in ['2dpoly']:
        for n in [2000]:
            for iv_strength in [.5]:
                lambda_l2_h = .1/n**(.9)
                Z, T, Y, true_fn = get_data(1000000, n_z, iv_strength, get_tau_fn(fn_dict[fname]), dgp_num)
                true = np.mean(moment_fn(T, true_fn, device='cpu'))
                print(f'True: {true:.4f}')
                results = Parallel(n_jobs=-1, verbose=3)(
                    delayed(exp)(it, n, n_z, n_t, iv_strength, fname, dgp_num, moment_fn,
                                 special_test=clever, lambda_l2_h=lambda_l2_h)
                    for it in range(n_reps))
                # exp returns (dr, tmle, ipw, reg) for the early-stopped networks,
                # followed by the same four for the averaged ('avg') networks.
                dr, tmle, ipw, direct = (np.array([r[k] for r in results]) for k in range(4))
                dr_avg, tmle_avg, ipw_avg, direct_avg = (np.array([r[k] for r in results])
                                                         for k in range(4, 8))
                # Tag the instrument setting in the title; otherwise the figures for
                # clever=False and clever=True are indistinguishable.
                label = f'{fname} (special_test={clever})'
                plot_results(label, n, iv_strength, dr, tmle, ipw, direct, true)
                plt.show()
                plot_results(label + ', avg nets', n, iv_strength,
                             dr_avg, tmle_avg, ipw_avg, direct_avg, true)
                plt.show()