In [1]:
import numpy as np
import pickle
from sklearn.linear_model import ElasticNetCV, LogisticRegressionCV
from sklearn.preprocessing import PolynomialFeatures
from sklearn.model_selection import train_test_split, StratifiedKFold
from scipy.stats import norm
from data_generation import m_0, g_0, get_data
from dml_algorithm import mm_ate, dml_ate

In [2]:
rng = np.random.default_rng(seed=57)

In [3]:
N = 250
y_data, d_data, x_data = get_data(N, rng)
poly_features = PolynomialFeatures(degree=2, include_bias=False)
x_quad_data = poly_features.fit_transform(x_data)
y_train, y_test, d_train, d_test, x_train, x_test, x_quad_train, x_quad_test = train_test_split(y_data, d_data, x_data, x_quad_data, test_size=0.2, random_state=42)
x_quad_data.shape

(250, 65)

In [4]:
l1_ratio = [0.1, 0.5, 0.7, 0.9, 0.95, 0.99, 1]
#alphas = [0.01, 0.1, 1, 10]
model_g0, model_g1 = ElasticNetCV(l1_ratio=l1_ratio, n_alphas=50, max_iter=10000, n_jobs=-1), ElasticNetCV(l1_ratio=l1_ratio, n_alphas=50, max_iter=10000, n_jobs=-1)
model_m = LogisticRegressionCV(Cs=10, 
                               l1_ratios=[0, 0.1, 0.2, 0.4, 0.6, 0.8, 0.9, 1], 
                               penalty='elasticnet',
                               solver='saga',
                               max_iter=1000,
                               random_state=42,
                               scoring='neg_brier_score',
                               n_jobs=-1
                              )
model_m_quad = LogisticRegressionCV(Cs=10, 
                               l1_ratios=[1], 
                               penalty='elasticnet',
                               solver='saga',
                               max_iter=10000,
                               random_state=42,
                               scoring='neg_brier_score',
                               n_jobs=-1
                              )

In [5]:
%%time
#model_m.fit(x_train, d_train)
#print(model_m.C_, model_m.l1_ratio_)
#print(model_m.predict_proba(x_test)[:20,1])
model_m_quad.fit(x_quad_train, d_train)
print(model_m_quad.C_[0], model_m_quad.l1_ratio_[0])
print(model_m_quad.predict_proba(x_quad_test)[:20,1])
print(m_0(x_test[:20]))

0.3593813663804626 1
[0.19574359 0.54034075 0.33744524 0.47989154 0.70808248 0.19638012
 0.18013099 0.76661517 0.20473032 0.21258443 0.09609844 0.56465364
 0.3828776  0.35195612 0.31742039 0.40492419 0.30285038 0.62991276
 0.60463028 0.40884351]
[0.15132306 0.74901171 0.50626301 0.3209899  0.37019757 0.17441277
 0.11173217 0.55695414 0.30432288 0.25391687 0.10304543 0.61951328
 0.30753256 0.63686157 0.4655652  0.4713355  0.4998131  0.53410245
 0.67962409 0.68040758]
CPU times: total: 18 s
Wall time: 6.35 s


In [6]:
model_m_quad.intercept_

array([-0.04492515])

In [20]:
"""
%%time
model_g0.fit(x_train[d_train==0], y_train[d_train==0])
print(model_g0.alpha_, model_g0.l1_ratio_)
print(model_g0.predict(x_test[:20]))
print(g_0(0, x_test[:20]))
"""

0.002023169961802633 1.0
[ 9.11913773  1.86267306  3.38823435  3.92009905  3.71135644  8.01888979
  6.47477631  3.35071711 -2.48495728  3.26023665  4.01562794  2.40326794
  5.08584477  7.09953179  0.24808764 10.10317209  4.97710924  6.61279589
  4.84095057  8.4415757 ]
[ 9.67089239  2.52296456  2.95237173  3.3036634   3.46981129  7.25168064
  5.95556265  4.4787683  -2.17647036  2.9592268   3.37120465  2.49881907
  6.21870476  6.53252436 -0.14063823  9.19354424  4.95250245  6.55058744
  4.15508761  7.76348105]
CPU times: total: 391 ms
Wall time: 397 ms


In [7]:
%%time
model_g0.fit(x_quad_train[d_train==0], y_train[d_train==0])
print(model_g0.alpha_, model_g0.l1_ratio_)
print(model_g0.predict(x_quad_test[:20]))
print(g_0(0, x_test[:20]))

0.02187294180905264 1.0
[ 5.2366389   4.42103247  2.65501032  1.09158128  3.08132517  5.68571463
  1.8744845   5.70053816  0.78835418 -1.95519293  7.7522891   3.71266605
  1.46732323  3.85241095  5.88412716  2.17879501  6.44474519  6.54582153
  6.94813473  5.90336384]
[ 5.43539668  4.63960997  2.86123934  1.2903609   2.38694195  6.55957317
  1.5254314   5.47361822  0.53288856 -1.23166293  8.04359852  3.06181596
  0.63574394  2.86751766  6.28581004  2.12573501  6.87827026  6.8938472
  6.44457398  6.32056436]
CPU times: total: 1.89 s
Wall time: 775 ms


In [8]:
model_g0.coef_

array([ 0.44642041,  1.41112897,  1.88154013,  2.67229138,  0.78923947,
        0.09819284, -0.10155162, -0.        , -0.0356418 , -0.        ,
        0.2235204 ,  0.41034453,  0.05326671,  0.0147464 ,  0.13809821,
       -0.        , -0.17796699,  0.05022517, -0.07640268,  0.09657925,
        0.17851411,  0.        ,  0.        ,  0.        ,  0.04231927,
        0.        , -0.07673944,  0.        , -0.        , -0.10123373,
        0.        ,  0.37617951,  0.        , -0.23014066,  0.        ,
       -0.        ,  0.        , -0.02929725,  0.08134201, -0.        ,
       -0.        ,  0.        ,  0.32960906, -0.        ,  0.02195003,
        0.17823327, -0.00498887, -0.        , -0.21123816,  0.11021318,
        0.13456728,  0.        , -0.        ,  0.        ,  0.        ,
        0.3644809 , -0.        , -1.49671314, -0.        ,  0.12438415,
        0.14548096, -0.        , -0.        , -0.87326478, -0.        ])

In [22]:
"""
%%time
model_g1.fit(x_train[d_train==1], y_train[d_train==1])
print(model_g1.alpha_, model_g1.l1_ratio_)
print(model_g1.predict(x_test[:20]))
print(g_0(1, x_test[:20]))
"""

0.0016604646032669184 0.95
[ 8.53598317  4.02933957  5.30445082  5.87899583  3.18154968  9.43416427
  5.66363481  5.57276874 -1.60926805  2.97616642  3.14048025  4.77194485
  5.46062975  7.30029159  1.50686742  9.49802442  6.33179126  6.76632479
  3.56087634  7.92841459]
[ 8.44765297  4.82355923  4.62288453  5.72266669  2.47828193  8.59848846
  4.87169894  7.07646335 -1.75255434  2.34590245  2.36804108  5.45466763
  6.86649874  6.80355745  1.014024    8.8125616   6.12793833  6.2625267
  2.89036964  6.58935663]
CPU times: total: 250 ms
Wall time: 198 ms


In [9]:
%%time
model_g1.fit(x_quad_train[d_train==1], y_train[d_train==1])
print(model_g1.alpha_, model_g1.l1_ratio_)
print(model_g1.predict(x_quad_test[:20]))
print(g_0(1, x_test[:20]))

0.035550417344207476 1.0
[5.11957623 4.59509361 2.41527092 1.38020099 4.36185196 6.00980268
 3.08612023 5.80915018 3.93014547 2.99325677 6.76504822 3.96022368
 3.27737313 4.50278141 4.63490809 2.8432264  8.17178788 6.74628689
 7.14342205 6.48646862]
[ 5.9612774   5.02699634  2.29420618  0.75801322  3.66617622  5.12795927
  2.89765683  7.04159289  2.83702293 -1.4817367   7.39102966  2.88846975
  3.52346933  3.40335304  5.69011325  2.47692171  9.11107849  7.18952703
  8.30582078  7.10256767]
CPU times: total: 1.88 s
Wall time: 743 ms


In [10]:
model_g1.coef_

array([ 0.69195362,  0.89149201,  1.59609141,  2.46455314,  1.29095874,
       -0.        , -0.7615018 , -0.22761059, -0.        , -0.        ,
       -0.08802115,  0.11455706,  0.        ,  0.        ,  0.        ,
       -0.        , -0.04012475, -0.04255939,  0.        ,  0.        ,
        0.29196648,  0.19716197,  0.1207351 ,  0.44826184,  0.12097799,
       -0.        ,  0.        , -0.1743234 ,  0.        ,  0.        ,
        0.        ,  0.19326609, -0.10678429, -0.        , -0.        ,
        0.57669953,  0.        ,  0.01363198,  0.02142319,  0.        ,
       -0.        ,  0.23438096, -0.        ,  0.        ,  0.        ,
       -0.        , -0.        ,  0.13933789,  0.        ,  0.        ,
        0.09810267, -0.20628876,  0.        , -0.        , -0.        ,
        0.58634947, -0.        , -0.22454259,  0.        , -0.        ,
       -0.        , -0.        , -0.        , -0.        , -0.        ])

In [11]:
%%time
model_g = [model_g0, model_g1]
dml_ate(y_data, d_data, x_quad_data, model_g, model_m_quad)

CPU times: total: 2min 2s
Wall time: 43.3 s


(array([-1.74947511,  0.2421966 ,  7.43272594]),
 24.961624140095907,
 array([-4.84369324,  1.34474302]))