In [1]:
# Reload imported modules automatically before each cell runs, so edits to
# library source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
## Ignore warnings
# Installs a blanket 'ignore' filter for all warning categories.
# NOTE(review): some later recorded outputs still show `warn(...)` lines, so this
# filter does not appear to silence everything — confirm whether that is intended.
import warnings
warnings.filterwarnings('ignore') 

In [3]:
# Main imports
from econml.dml import DMLCateEstimator, LinearDMLCateEstimator,SparseLinearDMLCateEstimator, ForestDMLCateEstimator

# Helper imports
import numpy as np
from itertools import product
from sklearn.linear_model import Lasso, LassoCV, LogisticRegression, LogisticRegressionCV,LinearRegression,MultiTaskElasticNet,MultiTaskElasticNetCV
from sklearn.ensemble import RandomForestRegressor,RandomForestClassifier
from sklearn.preprocessing import PolynomialFeatures
import matplotlib.pyplot as plt
import matplotlib
from sklearn.model_selection import train_test_split

%matplotlib inline

In [4]:
import statsmodels.api as sm

# Tiny OLS fit (seven points, regressor 1..7 plus an intercept column) used as a
# reference for the statsmodels summary / prediction tables in the next cells.
Y = np.array([1, 3, 4, 5, 2, 3, 4])
X = sm.add_constant(range(1, 8))
model = sm.OLS(Y, X)
# use_t=False: the recorded summary reports z and P>|z| columns.
results = model.fit(use_t=False)

In [5]:
# Render the fitted OLS summary (fit statistics, coefficient table, diagnostics).
results.summary()



0,1,2,3
Dep. Variable:,y,R-squared:,0.161
Model:,OLS,Adj. R-squared:,-0.007
Method:,Least Squares,F-statistic:,0.9608
Date:,"Thu, 19 Dec 2019",Prob (F-statistic):,0.372
Time:,17:42:05,Log-Likelihood:,-10.854
No. Observations:,7,AIC:,25.71
Df Residuals:,5,BIC:,25.6
Df Model:,1,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,z,P>|z|,[0.025,0.975]
const,2.1429,1.141,1.879,0.060,-0.093,4.378
x1,0.2500,0.255,0.980,0.327,-0.250,0.750

0,1,2,3
Omnibus:,,Durbin-Watson:,1.743
Prob(Omnibus):,,Jarque-Bera (JB):,0.482
Skew:,0.206,Prob(JB):,0.786
Kurtosis:,1.782,Cond. No.,10.4


In [6]:
from statsmodels.regression.linear_model import RegressionResults,PredictionResults

In [7]:
# Per-observation prediction table: mean, mean_se, and the mean_ci_* / obs_ci_*
# lower and upper bounds for each of the seven training points.
results.get_prediction().summary_frame()

Unnamed: 0,mean,mean_se,mean_ci_lower,mean_ci_upper,obs_ci_lower,obs_ci_upper
0,2.392857,0.9196,0.590475,4.195239,-0.808006,5.59372
1,2.642857,0.721393,1.228952,4.056762,-0.356488,5.642202
2,2.892857,0.570311,1.775067,4.010647,0.021203,5.764511
3,3.142857,0.510102,2.143076,4.142639,0.315048,5.970667
4,3.392857,0.570311,2.275067,4.510647,0.521203,6.264511
5,3.642857,0.721393,2.228952,5.056762,0.643512,6.642202
6,3.892857,0.9196,2.090475,5.695239,0.691994,7.09372


### DML inferences
1. LinearDML
2. SparseLinearDML
3. KernelDML (no inference)
4. ForestDML


## Comprehensive DML data-generating processes (DGPs)

The DGPs cover the following cases:
1. single Y, continuous single T
2. single Y, continuous multi T
3. single Y, discrete binary T
4. single Y, discrete multi T
5. multi Y, continuous single T
6. multi Y, continuous multi T
7. multi Y, discrete binary T
8. multi Y, discrete multi T

We also test the cases where `X` is None and where `W` is None.

### 1. single Y, continuous single T

In [8]:
# DGP 1: single outcome Y, single continuous treatment T.
# Seed the global NumPy RNG so this cell regenerates the same data on every run
# (the discrete-treatment DGP cells in this notebook already seed it).
np.random.seed(1234)
n = 1000          # number of samples
n_w = 30          # number of controls W
support_size = 5  # number of controls that actually enter T and Y
n_x = 1           # number of features X (the effect below reads column 0)
# Outcome support
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=support_size)


def epsilon_sample(n):
    """Return `n` outcome-noise draws from np.random.uniform(-1, 1)."""
    return np.random.uniform(-1, 1, size=n)


# Treatment support (same controls as the outcome, so W confounds T and Y)
support_T = support_Y
coefs_T = np.random.uniform(0, 1, size=support_size)


def eta_sample(n):
    """Return `n` treatment-noise draws from np.random.uniform(-1, 1)."""
    return np.random.uniform(-1, 1, size=n)


# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effect exp(2 * x), vectorised over the first X column
TE = np.exp(2 * X[:, 0])
T = np.dot(W[:, support_T], coefs_T) + eta_sample(n)
Y = TE * T + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)
# Evaluation grid: every combination of 0.0, 0.1, ..., 0.9 across the n_x features
X_test = np.array(list(product(np.arange(0, 1, 0.1), repeat=n_x)))

In [9]:
# Sanity check: shapes of the generated arrays, in the order Y, T, X, W, X_test.
tuple(arr.shape for arr in (Y, T, X, W, X_test))

((1000,), (1000,), (1000, 1), (1000, 30), (10, 1))

In [10]:
# Linear DML with random-forest regressors as model_y / model_t and
# 'statsmodels' inference. The following fit cells in this section rebind `est`.
est = LinearDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestRegressor(),
    random_state=123,
)
est.fit(Y, T, X, W, inference='statsmodels')

<econml.dml.LinearDMLCateEstimator at 0x26c3d5b2908>

In [20]:

## sparse linear
# Sparse linear DML with a degree-3 polynomial featurizer and 'debiasedlasso'
# inference. The forest fit cell that follows rebinds `est`.
est = SparseLinearDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestRegressor(),
    featurizer=PolynomialFeatures(degree=3),
    random_state=123,
)
est.fit(Y, T, X, W, inference='debiasedlasso')



<econml.dml.SparseLinearDMLCateEstimator at 0x26c3eb14f28>

In [29]:
## forest dml
# Forest DML with 'blb' inference. This is the last fit in the section, so the
# query cells that follow operate on this estimator.
est = ForestDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestRegressor(),
)
est.fit(Y, T, X, W, inference='blb')


<econml.dml.ForestDMLCateEstimator at 0x26c3f4a3630>

In [30]:
# constant marginal effect
# Point estimates at the 10 X_test rows (a length-10 array in the recorded output).
est.const_marginal_effect(X_test)

array([1.05073979, 1.05080607, 1.62132831, 1.46529216, 2.15953262,
       2.65365426, 3.22862293, 4.03628468, 5.10813295, 4.89213012])

In [31]:
# Interval for the estimates above, returned as a (lower, upper) pair of arrays.
est.const_marginal_effect_interval(X_test)

(array([0.32987887, 0.46116107, 0.98427852, 0.6285489 , 1.34719862,
        2.31835584, 2.65509201, 3.37478754, 3.87920528, 3.38655527]),
 array([1.77160072, 1.64045108, 2.2583781 , 2.30203542, 2.97186662,
        2.98895268, 3.80215384, 4.69778181, 6.33706062, 6.39770497]))

In [32]:
# Tabular view per X_test row: point_estimate, stderr, zstat, pvalue, ci_lower, ci_upper.
est.const_marginal_effect_inference(X_test).summary_frame()

Unnamed: 0,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,1.051,0.438,2.398,0.017,0.33,1.772
1,1.051,0.358,2.931,0.003,0.461,1.64
2,1.621,0.387,4.186,0.0,0.984,2.258
3,1.465,0.509,2.88,0.004,0.629,2.302
4,2.16,0.494,4.373,0.0,1.347,2.972
5,2.654,0.204,13.018,0.0,2.318,2.989
6,3.229,0.349,9.26,0.0,2.655,3.802
7,4.036,0.402,10.036,0.0,3.375,4.698
8,5.108,0.747,6.837,0.0,3.879,6.337
9,4.892,0.915,5.345,0.0,3.387,6.398


In [33]:
# Summary aggregated over the X_test rows: mean point estimate with its
# stderr / z-stat / p-value / CI, plus the spread of the point estimates.
est.const_marginal_effect_inference(X_test).population_summary().print()

0,1
,mean_point
,T0
Y0,2.727

0,1,2,3,4,5
,stderr_mean,zstat,pvalue,ci_mean_lower,ci_mean_upper
,T0,T0,T0,T0,T0
Y0,0.519,5.253,0.0,1.873,3.58

0,1,2,3
,std_point,pct_point_lower,pct_point_upper
,T0,T0,T0
Y0,1.452,1.051,5.011

0,1,2,3
,std_point,ci_point_lower,ci_point_upper
,T0,T0,T0
Y0,1.542,0.732,5.58


In [34]:
# effect
# With a single treatment the recorded values are identical to
# const_marginal_effect(X_test).
est.effect(X_test)

array([1.05073979, 1.05080607, 1.62132831, 1.46529216, 2.15953262,
       2.65365426, 3.22862293, 4.03628468, 5.10813295, 4.89213012])

In [35]:
# (lower, upper) arrays for est.effect(X_test); the recorded numbers match
# const_marginal_effect_interval(X_test).
est.effect_interval(X_test)

(array([0.32987887, 0.46116107, 0.98427852, 0.6285489 , 1.34719862,
        2.31835584, 2.65509201, 3.37478754, 3.87920528, 3.38655527]),
 array([1.77160072, 1.64045108, 2.2583781 , 2.30203542, 2.97186662,
        2.98895268, 3.80215384, 4.69778181, 6.33706062, 6.39770497]))

In [36]:
# Inference table for est.effect(X_test): point_estimate, stderr, zstat, pvalue, ci_lower, ci_upper.
est.effect_inference(X_test).summary_frame()

Unnamed: 0,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,1.051,0.438,2.398,0.017,0.33,1.772
1,1.051,0.358,2.931,0.003,0.461,1.64
2,1.621,0.387,4.186,0.0,0.984,2.258
3,1.465,0.509,2.88,0.004,0.629,2.302
4,2.16,0.494,4.373,0.0,1.347,2.972
5,2.654,0.204,13.018,0.0,2.318,2.989
6,3.229,0.349,9.26,0.0,2.655,3.802
7,4.036,0.402,10.036,0.0,3.375,4.698
8,5.108,0.747,6.837,0.0,3.879,6.337
9,4.892,0.915,5.345,0.0,3.387,6.398


In [37]:
# Summary of est.effect over the X_test rows (mean, its stderr / CI, and the
# spread of the point estimates), printed as tables.
est.effect_inference(X_test).population_summary().print()

0,1
,mean_point
,T0
Y0,2.727

0,1,2,3,4,5
,stderr_mean,zstat,pvalue,ci_mean_lower,ci_mean_upper
,T0,T0,T0,T0,T0
Y0,0.519,5.253,0.0,1.873,3.58

0,1,2,3
,std_point,pct_point_lower,pct_point_upper
,T0,T0,T0
Y0,1.452,1.051,5.011

0,1,2,3
,std_point,ci_point_lower,ci_point_upper
,T0,T0,T0
Y0,1.542,0.732,5.58


In [38]:
# marginal effect
# In the recorded run this call raised AttributeError ("The treatment effect is
# linear, please call const_marginal_effect_inference!"). Catch that error and
# print it so the behaviour stays visible while Restart & Run All can continue.
try:
    marginal_effect_result = est.marginal_effect_inference(0, X_test)
except AttributeError as err:
    marginal_effect_result = None
    print("AttributeError:", err)
# Displays the result if the call succeeded; None renders nothing.
marginal_effect_result

AttributeError: The treatment effect is linear, please call const_marginal_effect_inference!

### 2. single Y, continuous multi T

In [39]:
# DGP 2: single outcome Y, three continuous treatments T.
# Seed the global NumPy RNG so this cell regenerates the same data on every run
# (the discrete-treatment DGP cells in this notebook already seed it).
np.random.seed(1234)
n = 1000          # number of samples
n_w = 30          # number of controls W
support_size = 5  # number of controls that actually enter T and Y
n_x = 1           # number of features X (the effects below read column 0)
n_t = 3           # number of treatments; Y below is written out for exactly three
# Outcome support
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=support_size)


def epsilon_sample(n):
    """Return `n` outcome-noise draws from np.random.uniform(-1, 1)."""
    return np.random.uniform(-1, 1, size=n)


# Treatment support (same controls as the outcome)
support_T = support_Y
coefs_T = np.random.uniform(0, 1, size=(support_size, n_t))


def eta_sample(n):
    """Return an (n, n_t) array of treatment-noise draws from np.random.uniform(-1, 1)."""
    return np.random.uniform(-1, 1, size=(n, n_t))


# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effects, one per treatment column, vectorised over X[:, 0]
TE1 = np.exp(2 * X[:, 0])
TE2 = 2 * X[:, 0]
TE3 = X[:, 0] ** 2
T = np.dot(W[:, support_T], coefs_T) + eta_sample(n)
Y = (TE1 * T[:, 0] + TE2 * T[:, 1] + TE3 * T[:, 2]
     + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n))
# Evaluation grid: every combination of 0.0, 0.1, ..., 0.9 across the n_x features
X_test = np.array(list(product(np.arange(0, 1, 0.1), repeat=n_x)))

In [40]:
# Sanity check: shapes of the generated arrays, in the order Y, T, X, W, X_test.
tuple(arr.shape for arr in (Y, T, X, W, X_test))

((1000,), (1000, 3), (1000, 1), (1000, 30), (10, 1))

In [41]:
# Linear DML on the three-treatment data with 'statsmodels' inference.
# The sparse-linear fit cell that follows rebinds `est`.
est = LinearDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestRegressor(),
    random_state=123,
)
est.fit(Y, T, X, W, inference='statsmodels')

<econml.dml.LinearDMLCateEstimator at 0x26c3f4e1f28>

In [50]:

## sparse linear
# Sparse linear DML with a degree-3 polynomial featurizer and 'debiasedlasso'
# inference; the query cells in this section operate on this estimator.
est = SparseLinearDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestRegressor(),
    featurizer=PolynomialFeatures(degree=3),
    random_state=123,
)
est.fit(Y, T, X, W, inference='debiasedlasso')

<econml.dml.SparseLinearDMLCateEstimator at 0x26c3fa66518>

In [51]:
# Point estimates per X_test row; the recorded output is 10 x 3, one column per
# treatment (labelled T0, T1, T2 in the summary tables).
est.const_marginal_effect(X_test)

array([[ 1.36123577,  0.21319071, -0.0090428 ],
       [ 1.37352282,  0.37700295,  0.01431691],
       [ 1.44571159,  0.53777285,  0.05713314],
       [ 1.60514465,  0.69360154,  0.12002598],
       [ 1.87916459,  0.84259012,  0.20361554],
       [ 2.295114  ,  0.98283969,  0.30852189],
       [ 2.88033546,  1.11245138,  0.43536513],
       [ 3.66217155,  1.22952629,  0.58476537],
       [ 4.66796486,  1.33216553,  0.75734269],
       [ 5.92505797,  1.41847021,  0.95371719]])

In [52]:
# (lower, upper) pair of arrays, each with the same 10 x 3 layout as the point estimates.
est.const_marginal_effect_interval(X_test)

(array([[ 1.0214989 , -0.213944  , -0.35321128],
        [ 1.01572118, -0.07130063, -0.34810958],
        [ 1.05602938,  0.0553309 , -0.33742535],
        [ 1.17125326,  0.16244782, -0.31912365],
        [ 1.38819528,  0.24470933, -0.29311695],
        [ 1.73068978,  0.29474109, -0.26219737],
        [ 2.21976111,  0.30369409, -0.23190401],
        [ 2.87462852,  0.26220435, -0.2095703 ],
        [ 3.71406376,  0.16122571, -0.20302088],
        [ 4.75745069, -0.00758291, -0.21953405]]),
 array([[1.70097264, 0.64032543, 0.33512567],
        [1.73132446, 0.82530653, 0.3767434 ],
        [1.83539379, 1.0202148 , 0.45169163],
        [2.03903604, 1.22475526, 0.55917561],
        [2.37013391, 1.4404709 , 0.70034802],
        [2.85953822, 1.6709383 , 0.87924115],
        [3.54090981, 1.92120867, 1.10263428],
        [4.44971459, 2.19684823, 1.37910104],
        [5.62186596, 2.50310535, 1.71770627],
        [7.09266524, 2.84452333, 2.12696843]]))

In [53]:
# Inference table with two-level columns: statistic (point_estimate, stderr,
# zstat, pvalue, ci_lower, ci_upper) by treatment (T0, T1, T2).
est.const_marginal_effect_inference(X_test).summary_frame()

Unnamed: 0_level_0,point_estimate,point_estimate,point_estimate,stderr,stderr,stderr,zstat,zstat,zstat,pvalue,pvalue,pvalue,ci_lower,ci_lower,ci_lower,ci_upper,ci_upper,ci_upper
Unnamed: 0_level_1,T0,T1,T2,T0,T1,T2,T0,T1,T2,T0,T1,T2,T0,T1,T2,T0,T1,T2
0,1.361,0.213,-0.009,0.207,0.26,0.209,6.59,0.821,-0.043,0.0,0.412,0.966,1.021,-0.214,-0.353,1.701,0.64,0.335
1,1.374,0.377,0.014,0.218,0.273,0.22,6.314,1.383,0.065,0.0,0.167,0.948,1.016,-0.071,-0.348,1.731,0.825,0.377
2,1.446,0.538,0.057,0.237,0.293,0.24,6.102,1.834,0.238,0.0,0.067,0.812,1.056,0.055,-0.337,1.835,1.02,0.452
3,1.605,0.694,0.12,0.264,0.323,0.267,6.085,2.148,0.45,0.0,0.032,0.653,1.171,0.162,-0.319,2.039,1.225,0.559
4,1.879,0.843,0.204,0.298,0.363,0.302,6.296,2.318,0.674,0.0,0.02,0.5,1.388,0.245,-0.293,2.37,1.44,0.7
5,2.295,0.983,0.309,0.343,0.418,0.347,6.688,2.349,0.889,0.0,0.019,0.374,1.731,0.295,-0.262,2.86,1.671,0.879
6,2.88,1.112,0.435,0.402,0.492,0.406,7.172,2.263,1.073,0.0,0.024,0.283,2.22,0.304,-0.232,3.541,1.921,1.103
7,3.662,1.23,0.585,0.479,0.588,0.483,7.649,2.091,1.211,0.0,0.037,0.226,2.875,0.262,-0.21,4.45,2.197,1.379
8,4.668,1.332,0.757,0.58,0.712,0.584,8.049,1.871,1.297,0.0,0.061,0.195,3.714,0.161,-0.203,5.622,2.503,1.718
9,5.925,1.418,0.954,0.71,0.867,0.713,8.347,1.636,1.337,0.0,0.102,0.181,4.757,-0.008,-0.22,7.093,2.845,2.127


In [54]:
# Summary aggregated over the X_test rows, reported separately for T0, T1 and T2.
est.const_marginal_effect_inference(X_test).population_summary().print()

0,1,2,3
,mean_point,mean_point,mean_point
,T0,T1,T2
Y0,2.71,0.874,0.343

0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
,stderr_mean,stderr_mean,stderr_mean,zstat,zstat,zstat,pvalue,pvalue,pvalue,ci_mean_lower,ci_mean_lower,ci_mean_lower,ci_mean_upper,ci_mean_upper,ci_mean_upper
,T0,T1,T2,T0,T1,T2,T0,T1,T2,T0,T1,T2,T0,T1,T2
Y0,0.407,0.499,0.41,6.664,1.753,0.836,0.0,0.08,0.403,2.041,0.054,-0.332,3.378,1.694,1.017

0,1,2,3,4,5,6,7,8,9
,std_point,std_point,std_point,pct_point_lower,pct_point_lower,pct_point_lower,pct_point_upper,pct_point_upper,pct_point_upper
,T0,T1,T2,T0,T1,T2,T0,T1,T2
Y0,1.496,0.391,0.316,1.367,0.287,0.001,5.359,1.38,0.865

0,1,2,3,4,5,6,7,8,9
,std_point,std_point,std_point,ci_point_lower,ci_point_lower,ci_point_lower,ci_point_upper,ci_point_upper,ci_point_upper
,T0,T1,T2,T0,T1,T2,T0,T1,T2
Y0,1.551,0.634,0.517,1.162,0.008,-0.305,5.934,2.061,1.364


In [55]:
#effect
# The recorded run warns "A scalar was specified but there are multiple
# treatments". The returned length-10 array matches the row-sum of
# const_marginal_effect(X_test) (first row: 1.361 + 0.213 - 0.009 = 1.565).
est.effect(X_test)

  warn("A scalar was specified but there are multiple treatments; "
  warn("A scalar was specified but there are multiple treatments; "


array([1.56538368, 1.76484268, 2.04061758, 2.41877217, 2.92537024,
       3.58647558, 4.42815198, 5.47646322, 6.75747308, 8.29724537])

In [56]:
# (lower, upper) arrays for est.effect(X_test); the same scalar-treatment
# warning is emitted in the recorded run.
est.effect_interval(X_test)

  warn("A scalar was specified but there are multiple treatments; "
  warn("A scalar was specified but there are multiple treatments; "


(array([0.69735957, 0.85658879, 1.05997192, 1.33414509, 1.7012763 ,
        2.17734109, 2.77281943, 3.49513872, 4.35148592, 5.35065055]),
 array([ 2.43340779,  2.67309656,  3.02126323,  3.50339926,  4.14946419,
         4.99561008,  6.08348453,  7.45778772,  9.16346025, 11.24384019]))

In [57]:
# Inference table for est.effect(X_test); a single set of columns, because the
# effect output here is one value per X_test row.
est.effect_inference(X_test).summary_frame()

  warn("A scalar was specified but there are multiple treatments; "
  warn("A scalar was specified but there are multiple treatments; "


Unnamed: 0,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,1.565,0.528,2.966,0.003,0.697,2.433
1,1.765,0.552,3.196,0.001,0.857,2.673
2,2.041,0.596,3.423,0.001,1.06,3.021
3,2.419,0.659,3.668,0.0,1.334,3.503
4,2.925,0.744,3.931,0.0,1.701,4.149
5,3.586,0.857,4.186,0.0,2.177,4.996
6,4.428,1.006,4.4,0.0,2.773,6.083
7,5.476,1.205,4.546,0.0,3.495,7.458
8,6.757,1.463,4.62,0.0,4.351,9.163
9,8.297,1.791,4.632,0.0,5.351,11.244


In [58]:
# Summary of est.effect over the X_test rows; reported under a single
# T0 label in the recorded output.
est.effect_inference(X_test).population_summary().print()

  warn("A scalar was specified but there are multiple treatments; "
  warn("A scalar was specified but there are multiple treatments; "


0,1
,mean_point
,T0
Y0,3.926

0,1,2,3,4,5
,stderr_mean,zstat,pvalue,ci_mean_lower,ci_mean_upper
,T0,T0,T0,T0,T0
Y0,1.023,3.836,0.0,2.243,5.609

0,1,2,3
,std_point,pct_point_lower,pct_point_upper
,T0,T0,T0
Y0,2.167,1.655,7.604

0,1,2,3
,std_point,ci_point_lower,ci_point_upper
,T0,T0,T0
Y0,2.397,1.188,8.697


### 3. single Y, discrete binary T

In [59]:
def exp_te(x):
    """Treatment effect function: exp(2 * x[0]), driven by the first feature only."""
    return np.exp(2 * x[0])


# DGP constants
np.random.seed(1234)
n, n_w, support_size, n_x = 1000, 30, 5, 4

# Outcome support
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=support_size)


def epsilon_sample(n):
    """Outcome noise: `n` draws from np.random.uniform(-1, 1)."""
    return np.random.uniform(-1, 1, size=n)


# Treatment support
support_T = support_Y
coefs_T = np.random.uniform(0, 1, size=support_size)


def eta_sample(n):
    """Treatment noise: `n` draws from np.random.uniform(-1, 1)."""
    return np.random.uniform(-1, 1, size=n)


# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effects
TE = np.array([exp_te(row) for row in X])
# Define treatment: Bernoulli draw per sample with sigmoid(log_odds) probability
log_odds = np.dot(W[:, support_T], coefs_T) + eta_sample(n)
T_sigmoid = 1 / (1 + np.exp(-log_odds))
T = np.array([np.random.binomial(1, prob) for prob in T_sigmoid])
# Define the outcome
Y = TE * T + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)

# get testing data: random features, with the first column replaced by an even grid on [0, 1]
X_test = np.random.uniform(0, 1, size=(10, n_x))
X_test[:, 0] = np.linspace(0, 1, 10)

In [60]:
# Sanity check: shapes of the generated arrays, in the order Y, T, X, W, X_test.
tuple(arr.shape for arr in (Y, T, X, W, X_test))

((1000,), (1000,), (1000, 4), (1000, 30), (10, 4))

In [63]:
# Linear DML for the binary treatment: classifier as model_t,
# discrete_treatment=True, 'statsmodels' inference.
est = LinearDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestClassifier(),
    random_state=123,
    discrete_treatment=True,
)
est.fit(Y, T, X, W, inference='statsmodels')
# Keep the point predictions on the test grid from this linear fit.
te_pred = est.effect(X_test)

In [75]:

## sparse linear
# Sparse linear DML for the binary treatment with a degree-3 polynomial
# featurizer and 'debiasedlasso' inference. The forest fit cell that follows
# rebinds `est`.
est = SparseLinearDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestClassifier(),
    featurizer=PolynomialFeatures(degree=3),
    random_state=123,
    discrete_treatment=True,
)
est.fit(Y, T, X, W, inference='debiasedlasso')

<econml.dml.SparseLinearDMLCateEstimator at 0x26c40baf518>

In [84]:

## forest dml
# Forest DML for the binary treatment with 'blb' inference. This is the last
# fit in the section, so the query cells that follow operate on this estimator.
est = ForestDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestClassifier(min_samples_leaf=10),
    discrete_treatment=True,
)
est.fit(Y, T, X, W, inference='blb')


<econml.dml.ForestDMLCateEstimator at 0x26c40bd1128>

In [85]:
# constant marginal effect
# Point estimates per X_test row; the recorded output is a 10 x 1 column.
est.const_marginal_effect(X_test)

array([[1.01934029],
       [1.3455195 ],
       [1.56239304],
       [1.51756366],
       [2.77382572],
       [3.03764577],
       [3.36999007],
       [5.30551147],
       [5.72400929],
       [6.64080735]])

In [86]:
# (lower, upper) pair of 10 x 1 arrays for the estimates above.
est.const_marginal_effect_interval(X_test)

(array([[-0.05416631],
        [ 1.03440767],
        [ 0.94834961],
        [ 0.86743511],
        [ 2.06604063],
        [ 2.51379517],
        [ 2.68541775],
        [ 4.19508649],
        [ 4.81809514],
        [ 5.43886072]]), array([[2.09284688],
        [1.65663132],
        [2.17643647],
        [2.16769221],
        [3.48161082],
        [3.56149637],
        [4.05456239],
        [6.41593645],
        [6.62992344],
        [7.84275398]]))

In [87]:
# Tabular view per X_test row: point_estimate, stderr, zstat, pvalue, ci_lower, ci_upper.
est.const_marginal_effect_inference(X_test).summary_frame()

Unnamed: 0,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,1.019,0.653,1.562,0.118,-0.054,2.093
1,1.346,0.189,7.114,0.0,1.034,1.657
2,1.562,0.373,4.185,0.0,0.948,2.176
3,1.518,0.395,3.84,0.0,0.867,2.168
4,2.774,0.43,6.446,0.0,2.066,3.482
5,3.038,0.318,9.538,0.0,2.514,3.561
6,3.37,0.416,8.097,0.0,2.685,4.055
7,5.306,0.675,7.859,0.0,4.195,6.416
8,5.724,0.551,10.393,0.0,4.818,6.63
9,6.641,0.731,9.088,0.0,5.439,7.843


In [88]:
# Summary aggregated over the X_test rows: mean point estimate with its
# stderr / z-stat / p-value / CI, plus the spread of the point estimates.
est.const_marginal_effect_inference(X_test).population_summary().print()

0,1
,mean_point
,T0
Y0,3.23

0,1,2,3,4,5
,stderr_mean,zstat,pvalue,ci_mean_lower,ci_mean_upper
,T0,T0,T0,T0,T0
Y0,0.501,6.446,0.0,2.406,4.054

0,1,2,3
,std_point,pct_point_lower,pct_point_upper
,T0,T0,T0
Y0,1.913,1.166,6.228

0,1,2,3
,std_point,ci_point_lower,ci_point_upper
,T0,T0,T0
Y0,1.977,0.868,6.731


In [89]:
# effect
# Same values as const_marginal_effect(X_test) in the recorded output, returned
# as a flat length-10 array instead of a 10 x 1 column.
est.effect(X_test)

array([1.01934029, 1.3455195 , 1.56239304, 1.51756366, 2.77382572,
       3.03764577, 3.36999007, 5.30551147, 5.72400929, 6.64080735])

In [90]:
# (lower, upper) pair of flat length-10 arrays for est.effect(X_test).
est.effect_interval(X_test)

(array([-0.05416631,  1.03440767,  0.94834961,  0.86743511,  2.06604063,
         2.51379517,  2.68541775,  4.19508649,  4.81809514,  5.43886072]),
 array([2.09284688, 1.65663132, 2.17643647, 2.16769221, 3.48161082,
        3.56149637, 4.05456239, 6.41593645, 6.62992344, 7.84275398]))

In [91]:
# Inference table for est.effect(X_test): point_estimate, stderr, zstat, pvalue, ci_lower, ci_upper.
est.effect_inference(X_test).summary_frame()

Unnamed: 0,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,1.019,0.653,1.562,0.118,-0.054,2.093
1,1.346,0.189,7.114,0.0,1.034,1.657
2,1.562,0.373,4.185,0.0,0.948,2.176
3,1.518,0.395,3.84,0.0,0.867,2.168
4,2.774,0.43,6.446,0.0,2.066,3.482
5,3.038,0.318,9.538,0.0,2.514,3.561
6,3.37,0.416,8.097,0.0,2.685,4.055
7,5.306,0.675,7.859,0.0,4.195,6.416
8,5.724,0.551,10.393,0.0,4.818,6.63
9,6.641,0.731,9.088,0.0,5.439,7.843


In [92]:
# Summary of est.effect over the X_test rows (mean, its stderr / CI, and the
# spread of the point estimates), printed as tables.
est.effect_inference(X_test).population_summary().print()

0,1
,mean_point
,T0
Y0,3.23

0,1,2,3,4,5
,stderr_mean,zstat,pvalue,ci_mean_lower,ci_mean_upper
,T0,T0,T0,T0,T0
Y0,0.501,6.446,0.0,2.406,4.054

0,1,2,3
,std_point,pct_point_lower,pct_point_upper
,T0,T0,T0
Y0,1.913,1.166,6.228

0,1,2,3
,std_point,ci_point_lower,ci_point_upper
,T0,T0,T0
Y0,1.977,0.868,6.731


### 4. single Y, discrete multi T

In [93]:
# DGP 4: single outcome Y, discrete treatment T with three levels (0, 1, 2).
np.random.seed(1234)
n, n_w, support_size, n_x = 1000, 30, 5, 4

# Outcome support
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=support_size)


def epsilon_sample(n):
    """Outcome noise: `n` draws from np.random.uniform(-1, 1)."""
    return np.random.uniform(-1, 1, size=n)


# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effects: 2 * (first feature)
TE = 2 * X[:, 0]
# Define treatment: drawn with np.random.choice(3, n), independently of W and X
T = np.random.choice(3, n)
# Define the outcome
Y = TE * T + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)

# get testing data: random features, with the first column replaced by an even grid on [0, 1]
X_test = np.random.uniform(0, 1, size=(10, n_x))
X_test[:, 0] = np.linspace(0, 1, 10)

In [94]:
# Sanity check: shapes of the generated arrays, in the order Y, T, X, W, X_test.
tuple(arr.shape for arr in (Y, T, X, W, X_test))

((1000,), (1000,), (1000, 4), (1000, 30), (10, 4))

In [95]:
# Linear DML for the three-level discrete treatment with 'statsmodels'
# inference. The sparse-linear fit cell that follows rebinds `est`.
est = LinearDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestClassifier(),
    random_state=123,
    discrete_treatment=True,
)
est.fit(Y, T, X, W, inference='statsmodels')

<econml.dml.LinearDMLCateEstimator at 0x26c40be6240>

In [104]:

## sparse linear
# Sparse linear DML for the three-level discrete treatment (degree-3 polynomial
# featurizer, 'debiasedlasso' inference); the query cells in this section
# operate on this estimator.
est = SparseLinearDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestClassifier(),
    featurizer=PolynomialFeatures(degree=3),
    random_state=123,
    discrete_treatment=True,
)
est.fit(Y, T, X, W, inference='debiasedlasso')

<econml.dml.SparseLinearDMLCateEstimator at 0x26c4a2f6fd0>

In [105]:
# Point estimates per X_test row; the recorded output is 10 x 2, with columns
# labelled T0 and T1 in the summary tables (the treatment has three levels).
est.const_marginal_effect(X_test)

array([[0.99576967, 0.43244904],
       [2.07037823, 0.25882269],
       [1.70434155, 0.79053351],
       [2.34742648, 2.17777521],
       [5.30550283, 2.71539034],
       [2.69726967, 1.85163572],
       [3.39866207, 3.25706621],
       [2.50247955, 3.49392521],
       [3.55973865, 3.73205964],
       [9.75094348, 5.97539835]])

In [106]:
# (lower, upper) pair of arrays, each with the same 10 x 2 layout as the point estimates.
est.const_marginal_effect_interval(X_test)

(array([[-0.35404333, -1.00201643],
        [ 0.26398747, -1.61653287],
        [ 0.58548775, -0.37924153],
        [-0.17890284, -0.51416618],
        [ 0.82498543, -2.0292801 ],
        [ 1.3047792 ,  0.39833826],
        [ 1.3182138 ,  1.02664563],
        [ 1.24020748,  2.12927577],
        [ 1.79357977,  1.85283389],
        [ 2.99861885, -1.19019372]]), array([[ 2.34558267,  1.86691451],
        [ 3.876769  ,  2.13417825],
        [ 2.82319535,  1.96030854],
        [ 4.8737558 ,  4.8697166 ],
        [ 9.78602024,  7.46006077],
        [ 4.08976014,  3.30493318],
        [ 5.47911035,  5.48748679],
        [ 3.76475161,  4.85857465],
        [ 5.32589753,  5.61128539],
        [16.50326811, 13.14099042]]))

In [107]:
# Inference table with two-level columns: statistic by treatment column (T0, T1).
est.const_marginal_effect_inference(X_test).summary_frame()

Unnamed: 0_level_0,point_estimate,point_estimate,stderr,stderr,zstat,zstat,pvalue,pvalue,ci_lower,ci_lower,ci_upper,ci_upper
Unnamed: 0_level_1,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1
0,0.996,0.432,0.821,0.872,1.213,0.496,0.225,0.62,-0.354,-1.002,2.346,1.867
1,2.07,0.259,1.098,1.14,1.885,0.227,0.059,0.82,0.264,-1.617,3.877,2.134
2,1.704,0.791,0.68,0.711,2.506,1.112,0.012,0.266,0.585,-0.379,2.823,1.96
3,2.347,2.178,1.536,1.637,1.528,1.331,0.126,0.183,-0.179,-0.514,4.874,4.87
4,5.306,2.715,2.724,2.885,1.948,0.941,0.051,0.347,0.825,-2.029,9.786,7.46
5,2.697,1.852,0.847,0.884,3.186,2.096,0.001,0.036,1.305,0.398,4.09,3.305
6,3.399,3.257,1.265,1.356,2.687,2.402,0.007,0.016,1.318,1.027,5.479,5.487
7,2.502,3.494,0.767,0.83,3.261,4.211,0.001,0.0,1.24,2.129,3.765,4.859
8,3.56,3.732,1.074,1.142,3.315,3.267,0.001,0.001,1.794,1.853,5.326,5.611
9,9.751,5.975,4.105,4.356,2.375,1.372,0.018,0.17,2.999,-1.19,16.503,13.141


In [108]:
# Summary aggregated over the X_test rows, reported separately for T0 and T1.
est.const_marginal_effect_inference(X_test).population_summary().print()

0,1,2
,mean_point,mean_point
,T0,T1
Y0,3.433,2.469

0,1,2,3,4,5,6,7,8,9,10
,stderr_mean,stderr_mean,zstat,zstat,pvalue,pvalue,ci_mean_lower,ci_mean_lower,ci_mean_upper,ci_mean_upper
,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1
Y0,1.817,1.928,1.889,1.281,0.059,0.2,0.444,-0.702,6.423,5.639

0,1,2,3,4,5,6
,std_point,std_point,pct_point_lower,pct_point_lower,pct_point_upper,pct_point_upper
,T0,T1,T0,T1,T0,T1
Y0,2.383,1.671,1.315,0.337,7.75,4.966

0,1,2,3,4,5,6
,std_point,std_point,ci_point_lower,ci_point_lower,ci_point_upper,ci_point_upper
,T0,T1,T0,T1,T0,T1
Y0,2.997,2.551,0.406,-0.762,10.199,6.887


In [109]:
# effect
# In the recorded output the values equal the first column (T0) of
# const_marginal_effect(X_test).
est.effect(X_test)

array([0.99576967, 2.07037823, 1.70434155, 2.34742648, 5.30550283,
       2.69726967, 3.39866207, 2.50247955, 3.55973865, 9.75094348])

In [110]:
# (lower, upper) pair of length-10 arrays for est.effect(X_test); they match the
# T0 column of const_marginal_effect_interval(X_test) in the recorded output.
est.effect_interval(X_test)

(array([-0.35404333,  0.26398747,  0.58548775, -0.17890284,  0.82498543,
         1.3047792 ,  1.3182138 ,  1.24020748,  1.79357977,  2.99861885]),
 array([ 2.34558267,  3.876769  ,  2.82319535,  4.8737558 ,  9.78602024,
         4.08976014,  5.47911035,  3.76475161,  5.32589753, 16.50326811]))

In [111]:
# Inference table for est.effect(X_test): point_estimate, stderr, zstat, pvalue, ci_lower, ci_upper.
est.effect_inference(X_test).summary_frame()

Unnamed: 0,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,0.996,0.821,1.213,0.225,-0.354,2.346
1,2.07,1.098,1.885,0.059,0.264,3.877
2,1.704,0.68,2.506,0.012,0.585,2.823
3,2.347,1.536,1.528,0.126,-0.179,4.874
4,5.306,2.724,1.948,0.051,0.825,9.786
5,2.697,0.847,3.186,0.001,1.305,4.09
6,3.399,1.265,2.687,0.007,1.318,5.479
7,2.502,0.767,3.261,0.001,1.24,3.765
8,3.56,1.074,3.315,0.001,1.794,5.326
9,9.751,4.105,2.375,0.018,2.999,16.503


In [112]:
# Summary of est.effect over the X_test rows (mean, its stderr / CI, and the
# spread of the point estimates), printed as tables.
est.effect_inference(X_test).population_summary().print()

0,1
,mean_point
,T0
Y0,3.433

0,1,2,3,4,5
,stderr_mean,zstat,pvalue,ci_mean_lower,ci_mean_upper
,T0,T0,T0,T0,T0
Y0,1.817,1.889,0.059,0.444,6.423

0,1,2,3
,std_point,pct_point_lower,pct_point_upper
,T0,T0,T0
Y0,2.383,1.315,7.75

0,1,2,3
,std_point,ci_point_lower,ci_point_upper
,T0,T0,T0
Y0,2.997,0.406,10.088


### 5. multi Y, continuous single T

In [113]:
# DGP 5: three outcomes Y, single continuous treatment T.
# Seed the global NumPy RNG so this cell regenerates the same data on every run
# (the discrete-treatment DGP cells in this notebook already seed it).
np.random.seed(1234)
n = 1000          # number of samples
n_w = 30          # number of controls W
support_size = 5  # number of controls that actually enter T and Y
n_x = 1           # number of features X (the effect below reads column 0)
n_y = 3           # number of outcomes (replaces the literal 3 used for each shape)
# Outcome support: one coefficient column per outcome
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=(support_size, n_y))


def epsilon_sample(n):
    """Return an (n, n_y) array of outcome-noise draws from np.random.uniform(-1, 1)."""
    return np.random.uniform(-1, 1, size=(n, n_y))


# Treatment support (same controls as the outcome)
support_T = support_Y
coefs_T = np.random.uniform(0, 1, size=support_size)


def eta_sample(n):
    """Return `n` treatment-noise draws from np.random.uniform(-1, 1)."""
    return np.random.uniform(-1, 1, size=n)


# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effect exp(2 * x), vectorised over the first X column
TE = np.exp(2 * X[:, 0])
T = np.dot(W[:, support_T], coefs_T) + eta_sample(n)
# The same treatment term TE * T is broadcast into every outcome column.
Y = (TE * T).reshape(-1, 1) + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)
# Evaluation grid: every combination of 0.0, 0.1, ..., 0.9 across the n_x features
X_test = np.array(list(product(np.arange(0, 1, 0.1), repeat=n_x)))

In [114]:
# Sanity check: shapes of the generated arrays, in the order Y, T, X, W, X_test.
tuple(arr.shape for arr in (Y, T, X, W, X_test))

((1000, 3), (1000,), (1000, 1), (1000, 30), (10, 1))

In [115]:
# Linear DML on the three-outcome data with 'statsmodels' inference.
# The following fit cells in this section rebind `est`.
est = LinearDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestRegressor(),
    random_state=123,
)
est.fit(Y, T, X, W, inference='statsmodels')

<econml.dml.LinearDMLCateEstimator at 0x26c4f4253c8>

In [133]:

## sparse linear
# Sparse linear DML on the three-outcome data (degree-3 polynomial featurizer,
# 'debiasedlasso' inference). Its recorded execution count (In [133]) is the
# one immediately preceding the query cells of this section.
est = SparseLinearDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestRegressor(),
    featurizer=PolynomialFeatures(degree=3),
    random_state=123,
)
est.fit(Y, T, X, W, inference='debiasedlasso')

<econml.dml.SparseLinearDMLCateEstimator at 0x26c4f49d710>

In [124]:

## forest dml
# Forest DML on the three-outcome data with 'blb' inference.
# NOTE(review): this cell was recorded as In [124], i.e. it ran BEFORE the
# sparse-linear cell above (In [133]). The recorded outputs of the query cells
# below therefore come from the sparse-linear fit, whereas a top-to-bottom run
# would leave this forest estimator bound to `est`.
est = ForestDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestRegressor(),
)
est.fit(Y, T, X, W, inference='blb')

<econml.dml.ForestDMLCateEstimator at 0x26c4f453cc0>

In [134]:
# Point estimates per X_test row; the recorded output is 10 x 3, one column per
# outcome (Y0, Y1, Y2 in the summary tables).
# NOTE(review): by execution count (In [133] sparse fit ran after In [124] forest
# fit) the recorded outputs from here on belong to the sparse-linear estimator,
# even though the forest fit cell is the last one above in file order.
est.const_marginal_effect(X_test)

array([[1.27703862, 1.34162442, 1.31719037],
       [1.40024972, 1.45796759, 1.45434486],
       [1.54713172, 1.60493641, 1.61477253],
       [1.74232825, 1.80487407, 1.82168682],
       [2.01048297, 2.08012375, 2.0983012 ],
       [2.37623951, 2.45302865, 2.46782913],
       [2.86424152, 2.94593196, 2.95348408],
       [3.49913265, 3.58117686, 3.57847951],
       [4.30555653, 4.38110654, 4.36602887],
       [5.30815681, 5.3680642 , 5.33934564]])

In [135]:
# (lower, upper) pair of arrays, each with the same 10 x 3 layout as the point estimates.
est.const_marginal_effect_interval(X_test)

(array([[1.01470173, 1.06846049, 1.04447512],
        [1.12291866, 1.16919066, 1.16604227],
        [1.241737  , 1.28693759, 1.29729604],
        [1.39841434, 1.4467663 , 1.46416727],
        [1.61830367, 1.67175861, 1.69060683],
        [1.92376189, 1.98187658, 1.99745096],
        [2.33419315, 2.39400767, 2.40246637],
        [2.86691872, 2.92287049, 2.92125446],
        [3.53851903, 3.58241223, 3.56864648],
        [4.36603046, 4.38705485, 4.35994767]]),
 array([[1.53937551, 1.61478836, 1.58990561],
        [1.67758078, 1.74674453, 1.74264746],
        [1.85252644, 1.92293524, 1.93224902],
        [2.08624217, 2.16298184, 2.17920637],
        [2.40266227, 2.4884889 , 2.50599557],
        [2.82871714, 2.92418072, 2.9382073 ],
        [3.3942899 , 3.49785624, 3.50450179],
        [4.13134657, 4.23948322, 4.23570456],
        [5.07259402, 5.17980085, 5.16341127],
        [6.25028316, 6.34907355, 6.31874361]]))

In [136]:
# Inference table indexed by (X_test row, outcome Y0/Y1/Y2) with columns
# point_estimate, stderr, zstat, pvalue, ci_lower, ci_upper.
est.const_marginal_effect_inference(X_test).summary_frame()

Unnamed: 0,Unnamed: 1,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,Y0,1.277,0.159,8.007,0.0,1.015,1.539
0,Y1,1.342,0.166,8.079,0.0,1.068,1.615
0,Y2,1.317,0.166,7.944,0.0,1.044,1.59
1,Y0,1.4,0.169,8.305,0.0,1.123,1.678
1,Y1,1.458,0.176,8.304,0.0,1.169,1.747
1,Y2,1.454,0.175,8.297,0.0,1.166,1.743
2,Y0,1.547,0.186,8.333,0.0,1.242,1.853
2,Y1,1.605,0.193,8.302,0.0,1.287,1.923
2,Y2,1.615,0.193,8.366,0.0,1.297,1.932
3,Y0,1.742,0.209,8.333,0.0,1.398,2.086


In [137]:
# Summary aggregated over the X_test rows, with one row per outcome (Y0, Y1, Y2).
est.const_marginal_effect_inference(X_test).population_summary().print()

0,1
,mean_point
,T0
Y0,2.633
Y1,2.702
Y2,2.701

0,1,2,3,4,5
,stderr_mean,zstat,pvalue,ci_mean_lower,ci_mean_upper
,T0,T0,T0,T0,T0
Y0,0.326,8.081,0.0,2.097,3.169
Y1,0.339,7.964,0.0,2.144,3.26
Y2,0.339,7.974,0.0,2.144,3.258

0,1,2,3
,std_point,pct_point_lower,pct_point_upper
,T0,T0,T0
Y0,1.286,1.332,4.857
Y1,1.289,1.394,4.924
Y2,1.281,1.379,4.901

0,1,2,3
,std_point,ci_point_lower,ci_point_upper
,T0,T0,T0
Y0,1.327,1.208,5.321
Y1,1.333,1.266,5.397
Y2,1.325,1.248,5.367


In [138]:
# effect
# With a single treatment the recorded 10 x 3 values are identical to
# const_marginal_effect(X_test).
est.effect(X_test)

array([[1.27703862, 1.34162442, 1.31719037],
       [1.40024972, 1.45796759, 1.45434486],
       [1.54713172, 1.60493641, 1.61477253],
       [1.74232825, 1.80487407, 1.82168682],
       [2.01048297, 2.08012375, 2.0983012 ],
       [2.37623951, 2.45302865, 2.46782913],
       [2.86424152, 2.94593196, 2.95348408],
       [3.49913265, 3.58117686, 3.57847951],
       [4.30555653, 4.38110654, 4.36602887],
       [5.30815681, 5.3680642 , 5.33934564]])

In [139]:
# (lower, upper) pair of 10 x 3 arrays for est.effect(X_test); the recorded
# numbers match const_marginal_effect_interval(X_test).
est.effect_interval(X_test)

(array([[1.01470173, 1.06846049, 1.04447512],
        [1.12291866, 1.16919066, 1.16604227],
        [1.241737  , 1.28693759, 1.29729604],
        [1.39841434, 1.4467663 , 1.46416727],
        [1.61830367, 1.67175861, 1.69060683],
        [1.92376189, 1.98187658, 1.99745096],
        [2.33419315, 2.39400767, 2.40246637],
        [2.86691872, 2.92287049, 2.92125446],
        [3.53851903, 3.58241223, 3.56864648],
        [4.36603046, 4.38705485, 4.35994767]]),
 array([[1.53937551, 1.61478836, 1.58990561],
        [1.67758078, 1.74674453, 1.74264746],
        [1.85252644, 1.92293524, 1.93224902],
        [2.08624217, 2.16298184, 2.17920637],
        [2.40266227, 2.4884889 , 2.50599557],
        [2.82871714, 2.92418072, 2.9382073 ],
        [3.3942899 , 3.49785624, 3.50450179],
        [4.13134657, 4.23948322, 4.23570456],
        [5.07259402, 5.17980085, 5.16341127],
        [6.25028316, 6.34907355, 6.31874361]]))

In [140]:
# Inference table for est.effect(X_test), indexed by (X_test row, outcome).
est.effect_inference(X_test).summary_frame()

Unnamed: 0,Unnamed: 1,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,Y0,1.277,0.159,8.007,0.0,1.015,1.539
0,Y1,1.342,0.166,8.079,0.0,1.068,1.615
0,Y2,1.317,0.166,7.944,0.0,1.044,1.59
1,Y0,1.4,0.169,8.305,0.0,1.123,1.678
1,Y1,1.458,0.176,8.304,0.0,1.169,1.747
1,Y2,1.454,0.175,8.297,0.0,1.166,1.743
2,Y0,1.547,0.186,8.333,0.0,1.242,1.853
2,Y1,1.605,0.193,8.302,0.0,1.287,1.923
2,Y2,1.615,0.193,8.366,0.0,1.297,1.932
3,Y0,1.742,0.209,8.333,0.0,1.398,2.086


In [141]:
# Summary of est.effect over the X_test rows, with one row per outcome (Y0, Y1, Y2).
est.effect_inference(X_test).population_summary().print()

0,1
,mean_point
,T0
Y0,2.633
Y1,2.702
Y2,2.701

0,1,2,3,4,5
,stderr_mean,zstat,pvalue,ci_mean_lower,ci_mean_upper
,T0,T0,T0,T0,T0
Y0,0.326,8.081,0.0,2.097,3.169
Y1,0.339,7.964,0.0,2.144,3.26
Y2,0.339,7.974,0.0,2.144,3.258

0,1,2,3
,std_point,pct_point_lower,pct_point_upper
,T0,T0,T0
Y0,1.286,1.332,4.857
Y1,1.289,1.394,4.924
Y2,1.281,1.379,4.901

0,1,2,3
,std_point,ci_point_lower,ci_point_upper
,T0,T0,T0
Y0,1.327,1.208,5.321
Y1,1.333,1.266,5.397
Y2,1.325,1.248,5.367


### 6. multi Y, continuous multi T

In [142]:
# DGP 6: three outcomes Y, two continuous treatments T.
# Seed the global NumPy RNG so this cell regenerates the same data on every run
# (the discrete-treatment DGP cells in this notebook already seed it).
np.random.seed(1234)
n = 1000          # number of samples
n_w = 30          # number of controls W
support_size = 5  # number of controls that actually enter T and Y
n_x = 1           # number of features X (the effects below read column 0)
n_t = 2           # number of treatments; Y below is written out for exactly two
n_y = 3           # number of outcomes (replaces the literal 3 used for each shape)
# Outcome support: one coefficient column per outcome
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=(support_size, n_y))


def epsilon_sample(n):
    """Return an (n, n_y) array of outcome-noise draws from np.random.uniform(-1, 1)."""
    return np.random.uniform(-1, 1, size=(n, n_y))


# Treatment support (same controls as the outcome)
support_T = support_Y
coefs_T = np.random.uniform(0, 1, size=(support_size, n_t))


def eta_sample(n):
    """Return an (n, n_t) array of treatment-noise draws from np.random.uniform(-1, 1)."""
    return np.random.uniform(-1, 1, size=(n, n_t))


# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effects, one per treatment column, vectorised over X[:, 0]
TE1 = np.exp(2 * X[:, 0])
TE2 = 2 * X[:, 0]
T = np.dot(W[:, support_T], coefs_T) + eta_sample(n)
# The combined treatment term is broadcast into every outcome column.
Y = ((TE1 * T[:, 0] + TE2 * T[:, 1]).reshape(-1, 1)
     + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n))
# Evaluation grid: every combination of 0.0, 0.1, ..., 0.9 across the n_x features
X_test = np.array(list(product(np.arange(0, 1, 0.1), repeat=n_x)))

In [143]:
Y.shape,T.shape,X.shape,W.shape,X_test.shape

((1000, 3), (1000, 2), (1000, 1), (1000, 30), (10, 1))

In [144]:
# Linear DML on the multi-outcome / two-treatment data, with random forests
# as model_y and model_t and statsmodels-based inference.
est = LinearDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestRegressor(),
    random_state=123,
)
est.fit(Y, T, X, W, inference='statsmodels')

<econml.dml.LinearDMLCateEstimator at 0x26c4f4d0908>

In [153]:

## sparse linear
# Rebinds `est`: sparse-linear DML with a degree-3 PolynomialFeatures
# featurizer and debiased-lasso inference.
est = SparseLinearDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestRegressor(),
    featurizer=PolynomialFeatures(degree=3),
    random_state=123,
)
est.fit(Y, T, X, W, inference='debiasedlasso')


<econml.dml.SparseLinearDMLCateEstimator at 0x26c4f503668>

In [154]:
est.const_marginal_effect(X_test)

array([[[0.93916984, 0.40085738],
        [1.18996186, 0.56856209],
        [0.96394706, 0.72166158]],

       [[1.05668369, 0.52315064],
        [1.25540552, 0.74208671],
        [1.074231  , 0.9011107 ]],

       [[1.28159455, 0.66263806],
        [1.44029605, 0.92238436],
        [1.29869006, 1.0875689 ]],

       [[1.61535523, 0.82045005],
        [1.74711221, 1.11112711],
        [1.63937003, 1.28263678]],

       [[2.05941858, 0.99771703],
        [2.17833276, 1.30998705],
        [2.09831671, 1.48791491]],

       [[2.61523741, 1.19556941],
        [2.73643645, 1.52063623],
        [2.67757587, 1.70500389]],

       [[3.28426455, 1.4151376 ],
        [3.42390204, 1.74474674],
        [3.37919333, 1.93550431]],

       [[4.06795283, 1.65755202],
        [4.24320828, 1.98399065],
        [4.20521485, 2.18101675]],

       [[4.96775508, 1.92394308],
        [5.19683393, 2.24004003],
        [5.15768624, 2.4431418 ]],

       [[5.98512411, 2.2154412 ],
        [6.28725774, 2.5145669

In [155]:
est.const_marginal_effect_interval(X_test)

(array([[[0.65328407, 0.1509328 ],
         [0.88467822, 0.30167969],
         [0.65143504, 0.44846004]],
 
        [[0.75669995, 0.2607057 ],
         [0.93483794, 0.46178776],
         [0.74607315, 0.61417495]],
 
        [[0.96015528, 0.38046222],
         [1.0969014 , 0.62144119],
         [0.94716464, 0.77950013]],
 
        [[1.26263732, 0.50946594],
         [1.37016607, 0.7794697 ],
         [1.25349871, 0.94312652]],
 
        [[1.66235343, 0.64612597],
         [1.75321822, 0.93407326],
         [1.66313648, 1.10310039]],
 
        [[2.15702177, 0.78815894],
         [2.24439396, 1.08323226],
         [2.173883  , 1.25724325]],
 
        [[2.74430648, 0.9329887 ],
         [2.84227311, 1.22513692],
         [2.78379283, 1.40359137]],
 
        [[3.42215539, 1.07814519],
         [3.54593976, 1.35840395],
         [3.4914367 , 1.54061766]],
 
        [[4.18891164, 1.22147512],
         [4.35496194, 1.48205207],
         [4.29588075, 1.66720651]],
 
        [[5.04324957, 1.3611

In [156]:
est.const_marginal_effect_inference(X_test).summary_frame()

Unnamed: 0_level_0,Unnamed: 1_level_0,point_estimate,point_estimate,stderr,stderr,zstat,zstat,pvalue,pvalue,ci_lower,ci_lower,ci_upper,ci_upper
Unnamed: 0_level_1,Unnamed: 1_level_1,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1
0,Y0,0.939,0.401,0.174,0.152,5.404,2.638,0.0,0.008,0.653,0.151,1.225,0.651
0,Y1,1.19,0.569,0.186,0.162,6.411,3.504,0.0,0.0,0.885,0.302,1.495,0.835
0,Y2,0.964,0.722,0.19,0.166,5.074,4.345,0.0,0.0,0.651,0.448,1.276,0.995
1,Y0,1.057,0.523,0.182,0.16,5.794,3.279,0.0,0.001,0.757,0.261,1.357,0.786
1,Y1,1.255,0.742,0.195,0.17,6.442,4.355,0.0,0.0,0.935,0.462,1.576,1.022
1,Y2,1.074,0.901,0.2,0.174,5.384,5.166,0.0,0.0,0.746,0.614,1.402,1.188
2,Y0,1.282,0.663,0.195,0.172,6.558,3.863,0.0,0.0,0.96,0.38,1.603,0.945
2,Y1,1.44,0.922,0.209,0.183,6.899,5.041,0.0,0.0,1.097,0.621,1.784,1.223
2,Y2,1.299,1.088,0.214,0.187,6.077,5.807,0.0,0.0,0.947,0.78,1.65,1.396
3,Y0,1.615,0.82,0.214,0.189,7.533,4.34,0.0,0.0,1.263,0.509,1.968,1.131


In [157]:
est.const_marginal_effect_inference(X_test).population_summary().print()

0,1,2
,mean_point,mean_point
,T0,T1
Y0,2.787,1.181
Y1,2.97,1.466
Y2,2.873,1.647

0,1,2,3,4,5,6,7,8,9,10
,stderr_mean,stderr_mean,zstat,zstat,pvalue,pvalue,ci_mean_lower,ci_mean_lower,ci_mean_upper,ci_mean_upper
,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1
Y0,0.331,0.297,8.413,3.975,0.0,0.0,2.242,0.693,3.332,1.67
Y1,0.357,0.32,8.319,4.586,0.0,0.0,2.383,0.94,3.557,1.992
Y2,0.365,0.327,7.862,5.034,0.0,0.0,2.272,1.109,3.474,2.185

0,1,2,3,4,5,6
,std_point,std_point,pct_point_lower,pct_point_lower,pct_point_upper,pct_point_upper
,T0,T1,T0,T1,T0,T1
Y0,1.658,0.581,0.992,0.456,5.527,2.084
Y1,1.686,0.618,1.219,0.647,5.797,2.391
Y2,1.735,0.636,1.014,0.802,5.752,2.597

0,1,2,3,4,5,6
,std_point,std_point,ci_point_lower,ci_point_lower,ci_point_upper,ci_point_upper
,T0,T1,T0,T1,T0,T1
Y0,1.69,0.653,0.865,0.338,6.003,2.413
Y1,1.723,0.696,1.079,0.521,6.318,2.749
Y2,1.773,0.715,0.876,0.674,6.261,2.961


In [158]:
# effect
# No T0/T1 are passed here, so the call relies on the estimator's defaults
# (presumably scalars -- confirm against the econml signature). Because T has
# two columns in this section, this emits the "A scalar was specified but
# there are multiple treatments" warning visible in the output below.
est.effect(X_test)

  warn("A scalar was specified but there are multiple treatments; "
  warn("A scalar was specified but there are multiple treatments; "


array([[1.34002722, 1.75852395, 1.68560863],
       [1.57983433, 1.99749222, 1.9753417 ],
       [1.9442326 , 2.36268041, 2.38625896],
       [2.43580528, 2.85823932, 2.92200681],
       [3.0571356 , 3.48831981, 3.58623162],
       [3.81080681, 4.25707268, 4.38257977],
       [4.69940215, 5.16864878, 5.31469763],
       [5.72550485, 6.22719893, 6.3862316 ],
       [6.89169816, 7.43687395, 7.60082804],
       [8.20056531, 8.80182469, 8.96213334]])

In [159]:
est.effect_interval(X_test)

  warn("A scalar was specified but there are multiple treatments; "
  warn("A scalar was specified but there are multiple treatments; "


(array([[0.94870706, 1.34065202, 1.25784251],
        [1.16747985, 1.5568173 , 1.52423265],
        [1.49963832, 1.88780635, 1.90014104],
        [1.9445707 , 2.3333512 , 2.3846906 ],
        [2.50054276, 2.89227884, 2.97607784],
        [3.16501144, 3.56311535, 3.6721912 ],
        [3.93518516, 4.34469818, 4.47123786],
        [4.80851081, 5.23648001, 5.37205484],
        [5.78288954, 6.23849403, 6.37407336],
        [6.85665083, 7.35115741, 7.47711775]]),
 array([[ 1.73134738,  2.17639588,  2.11337476],
        [ 1.99218881,  2.43816715,  2.42645075],
        [ 2.38882688,  2.83755446,  2.87237689],
        [ 2.92703987,  3.38312745,  3.45932302],
        [ 3.61372845,  4.08436077,  4.1963854 ],
        [ 4.45660218,  4.95103001,  5.09296833],
        [ 5.46361913,  5.99259938,  6.1581574 ],
        [ 6.64249888,  7.21791785,  7.40040836],
        [ 8.00050678,  8.63525388,  8.82758272],
        [ 9.5444798 , 10.25249196, 10.44714892]]))

In [160]:
est.effect_inference(X_test).summary_frame()

  warn("A scalar was specified but there are multiple treatments; "
  warn("A scalar was specified but there are multiple treatments; "


Unnamed: 0,Unnamed: 1,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,Y0,1.34,0.238,5.633,0.0,0.949,1.731
0,Y1,1.759,0.254,6.922,0.0,1.341,2.176
0,Y2,1.686,0.26,6.482,0.0,1.258,2.113
1,Y0,1.58,0.251,6.302,0.0,1.167,1.992
1,Y1,1.997,0.268,7.456,0.0,1.557,2.438
1,Y2,1.975,0.274,7.203,0.0,1.524,2.426
2,Y0,1.944,0.27,7.193,0.0,1.5,2.389
2,Y1,2.363,0.289,8.184,0.0,1.888,2.838
2,Y2,2.386,0.296,8.074,0.0,1.9,2.872
3,Y0,2.436,0.299,8.156,0.0,1.945,2.927


In [161]:
est.effect_inference(X_test).population_summary().print()

  warn("A scalar was specified but there are multiple treatments; "
  warn("A scalar was specified but there are multiple treatments; "


0,1
,mean_point
,T0
Y0,3.969
Y1,4.436
Y2,4.52

0,1,2,3,4,5
,stderr_mean,zstat,pvalue,ci_mean_lower,ci_mean_upper
,T0,T0,T0,T0,T0
Y0,0.469,8.463,0.0,3.197,4.74
Y1,0.505,8.779,0.0,3.605,5.267
Y2,0.517,8.739,0.0,3.669,5.371

0,1,2,3
,std_point,pct_point_lower,pct_point_upper
,T0,T0,T0
Y0,2.236,1.448,7.612
Y1,2.294,1.866,8.188
Y2,2.363,1.816,8.35

0,1,2,3
,std_point,ci_point_lower,ci_point_upper
,T0,T0,T0
Y0,2.285,1.269,8.247
Y1,2.349,1.678,8.849
Y2,2.419,1.619,9.031


### 7. multi Y, discrete binary T

In [162]:
# Treatment effect function
def exp_te(x):
    """Return the treatment effect exp(2 * x[0]) for one covariate row `x`."""
    return np.exp(2 * x[0])


def epsilon_sample(n):
    """Outcome noise: an (n, 3) array of Uniform(-1, 1) draws."""
    return np.random.uniform(-1, 1, size=(n, 3))


def eta_sample(n):
    """Treatment noise: a length-n vector of Uniform(-1, 1) draws."""
    return np.random.uniform(-1, 1, size=n)


# DGP constants
np.random.seed(1234)
n, n_w, support_size, n_x = 1000, 30, 5, 4

# Outcome support
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=(support_size, 3))
# Treatment support (same W columns as the outcome)
support_T = support_Y
coefs_T = np.random.uniform(0, 1, size=support_size)

# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effects (one value per row of X)
TE = np.array(list(map(exp_te, X)))
# Define treatment: Bernoulli draws whose log-odds depend on W plus noise
log_odds = W[:, support_T].dot(coefs_T) + eta_sample(n)
T_sigmoid = 1 / (1 + np.exp(-log_odds))
T = np.array([np.random.binomial(1, prob) for prob in T_sigmoid])
# Define the outcome: the (n, 1) effect term broadcasts over the 3 outcome columns
Y = (
    (TE * T).reshape(-1, 1)
    + W[:, support_Y].dot(coefs_Y)
    + epsilon_sample(n)
)

# get testing data: random covariates, with the first column swept over [0, 1]
X_test = np.random.uniform(0, 1, size=(10, n_x))
X_test[:, 0] = np.linspace(0, 1, 10)

In [163]:
Y.shape,T.shape,X.shape,W.shape,X_test.shape

((1000, 3), (1000,), (1000, 4), (1000, 30), (10, 4))

In [164]:
# Linear DML for the binary treatment: model_t is a classifier and
# discrete_treatment=True is set; statsmodels-based inference.
est = LinearDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestClassifier(),
    random_state=123,
    discrete_treatment=True,
)
est.fit(Y, T, X, W, inference='statsmodels')
# Point estimates of the effect at the test covariates
te_pred = est.effect(X_test)

In [173]:

## sparse linear
# Rebinds `est`: sparse-linear DML for the binary treatment, with a degree-3
# PolynomialFeatures featurizer and debiased-lasso inference.
est = SparseLinearDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestClassifier(),
    featurizer=PolynomialFeatures(degree=3),
    random_state=123,
    discrete_treatment=True,
)
est.fit(Y, T, X, W, inference='debiasedlasso')


<econml.dml.SparseLinearDMLCateEstimator at 0x26c50a5fbe0>

In [182]:

## forest dml
# Rebinds `est` again: forest DML for the binary treatment with 'blb'
# inference; the cells that follow in this section query this estimator.
est = ForestDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestClassifier(min_samples_leaf=10),
    discrete_treatment=True,
)
est.fit(Y, T, X, W, inference='blb')

<econml.dml.ForestDMLCateEstimator at 0x26c50a8ac18>

In [183]:
est.const_marginal_effect(X_test)

array([[[1.46246037],
        [1.42876812],
        [1.48112907]],

       [[1.36724009],
        [1.67099648],
        [2.0251    ]],

       [[1.67713647],
        [1.87698555],
        [1.98563145]],

       [[1.95998833],
        [2.11462079],
        [2.5307241 ]],

       [[2.45489066],
        [2.58297718],
        [2.76517442]],

       [[3.26927448],
        [3.56805003],
        [4.00636865]],

       [[3.81096307],
        [3.9061488 ],
        [4.32115471]],

       [[4.38316541],
        [4.69323211],
        [5.02154711]],

       [[5.14986355],
        [5.22483241],
        [5.78317602]],

       [[5.56802902],
        [5.48482809],
        [5.80372215]]])

In [184]:
est.const_marginal_effect_interval(X_test)

(array([[[0.80450506],
         [0.1744264 ],
         [0.66143112]],
 
        [[0.51884929],
         [0.24158093],
         [1.1584269 ]],
 
        [[0.98448872],
         [1.3641186 ],
         [1.39223364]],
 
        [[1.44706822],
         [1.55785611],
         [1.89091893]],
 
        [[2.02183411],
         [1.87929015],
         [2.03730237]],
 
        [[2.55651823],
         [2.81142691],
         [3.50904717]],
 
        [[3.23904269],
         [3.18064177],
         [3.68827915]],
 
        [[3.7167488 ],
         [3.85922276],
         [4.2327593 ]],
 
        [[4.03807403],
         [3.85587816],
         [4.46090369]],
 
        [[4.76701517],
         [4.52877089],
         [4.7920904 ]]]), array([[[2.12041569],
         [2.68310983],
         [2.30082703]],
 
        [[2.2156309 ],
         [3.10041204],
         [2.8917731 ]],
 
        [[2.36978423],
         [2.38985251],
         [2.57902926]],
 
        [[2.47290843],
         [2.67138548],
         [3.1705292

In [185]:
est.const_marginal_effect_inference(X_test).summary_frame()

Unnamed: 0,Unnamed: 1,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,Y0,1.462,0.4,3.656,0.0,0.805,2.12
0,Y1,1.429,0.763,1.874,0.061,0.174,2.683
0,Y2,1.481,0.498,2.972,0.003,0.661,2.301
1,Y0,1.367,0.516,2.651,0.008,0.519,2.216
1,Y1,1.671,0.869,1.923,0.054,0.242,3.1
1,Y2,2.025,0.527,3.843,0.0,1.158,2.892
2,Y0,1.677,0.421,3.983,0.0,0.984,2.37
2,Y1,1.877,0.312,6.02,0.0,1.364,2.39
2,Y2,1.986,0.361,5.504,0.0,1.392,2.579
3,Y0,1.96,0.312,6.285,0.0,1.447,2.473


In [186]:
est.const_marginal_effect_inference(X_test).population_summary().print()

0,1
,mean_point
,T0
Y0,3.11
Y1,3.255
Y2,3.572

0,1,2,3,4,5
,stderr_mean,zstat,pvalue,ci_mean_lower,ci_mean_upper
,T0,T0,T0,T0,T0
Y0,0.44,7.067,0.0,2.386,3.834
Y1,0.585,5.562,0.0,2.292,4.218
Y2,0.5,7.149,0.0,2.75,4.394

0,1,2,3
,std_point,pct_point_lower,pct_point_upper
,T0,T0,T0
Y0,1.48,1.41,5.38
Y1,1.447,1.538,5.368
Y2,1.541,1.708,5.794

0,1,2,3
,std_point,ci_point_lower,ci_point_upper
,T0,T0,T0
Y0,1.544,1.058,5.788
Y1,1.561,0.982,5.855
Y2,1.62,1.315,6.261


In [187]:
# effect
est.effect(X_test)

array([[1.46246037, 1.42876812, 1.48112907],
       [1.36724009, 1.67099648, 2.0251    ],
       [1.67713647, 1.87698555, 1.98563145],
       [1.95998833, 2.11462079, 2.5307241 ],
       [2.45489066, 2.58297718, 2.76517442],
       [3.26927448, 3.56805003, 4.00636865],
       [3.81096307, 3.9061488 , 4.32115471],
       [4.38316541, 4.69323211, 5.02154711],
       [5.14986355, 5.22483241, 5.78317602],
       [5.56802902, 5.48482809, 5.80372215]])

In [188]:
est.effect_interval(X_test)

(array([[0.80450506, 0.1744264 , 0.66143112],
        [0.51884929, 0.24158093, 1.1584269 ],
        [0.98448872, 1.3641186 , 1.39223364],
        [1.44706822, 1.55785611, 1.89091893],
        [2.02183411, 1.87929015, 2.03730237],
        [2.55651823, 2.81142691, 3.50904717],
        [3.23904269, 3.18064177, 3.68827915],
        [3.7167488 , 3.85922276, 4.2327593 ],
        [4.03807403, 3.85587816, 4.46090369],
        [4.76701517, 4.52877089, 4.7920904 ]]),
 array([[2.12041569, 2.68310983, 2.30082703],
        [2.2156309 , 3.10041204, 2.8917731 ],
        [2.36978423, 2.38985251, 2.57902926],
        [2.47290843, 2.67138548, 3.17052928],
        [2.88794722, 3.2866642 , 3.49304646],
        [3.98203072, 4.32467315, 4.50369012],
        [4.38288344, 4.63165584, 4.95403027],
        [5.04958203, 5.52724145, 5.81033492],
        [6.26165307, 6.59378666, 7.10544835],
        [6.36904286, 6.4408853 , 6.81535389]]))

In [189]:
est.effect_inference(X_test).summary_frame()

Unnamed: 0,Unnamed: 1,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,Y0,1.462,0.4,3.656,0.0,0.805,2.12
0,Y1,1.429,0.763,1.874,0.061,0.174,2.683
0,Y2,1.481,0.498,2.972,0.003,0.661,2.301
1,Y0,1.367,0.516,2.651,0.008,0.519,2.216
1,Y1,1.671,0.869,1.923,0.054,0.242,3.1
1,Y2,2.025,0.527,3.843,0.0,1.158,2.892
2,Y0,1.677,0.421,3.983,0.0,0.984,2.37
2,Y1,1.877,0.312,6.02,0.0,1.364,2.39
2,Y2,1.986,0.361,5.504,0.0,1.392,2.579
3,Y0,1.96,0.312,6.285,0.0,1.447,2.473


In [190]:
est.effect_inference(X_test).population_summary().print()

0,1
,mean_point
,T0
Y0,3.11
Y1,3.255
Y2,3.572

0,1,2,3,4,5
,stderr_mean,zstat,pvalue,ci_mean_lower,ci_mean_upper
,T0,T0,T0,T0,T0
Y0,0.44,7.067,0.0,2.386,3.834
Y1,0.585,5.562,0.0,2.292,4.218
Y2,0.5,7.149,0.0,2.75,4.394

0,1,2,3
,std_point,pct_point_lower,pct_point_upper
,T0,T0,T0
Y0,1.48,1.41,5.38
Y1,1.447,1.538,5.368
Y2,1.541,1.708,5.794

0,1,2,3
,std_point,ci_point_lower,ci_point_upper
,T0,T0,T0
Y0,1.544,1.058,5.788
Y1,1.561,0.982,5.855
Y2,1.62,1.315,6.261


### 8. multi Y, discrete multi T

In [191]:
def epsilon_sample(n):
    """Outcome noise: an (n, 3) array of Uniform(-1, 1) draws."""
    return np.random.uniform(-1, 1, size=(n, 3))


# DGP constants
np.random.seed(1234)
n, n_w, support_size, n_x = 1000, 30, 5, 4

# Outcome support
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=(support_size, 3))

# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effects: twice the first covariate, one value per sample
TE = 2 * X[:, 0]
# Define treatment: integer levels drawn from {0, 1, 2}, not a function of W
T = np.random.choice(3, n)
# Define the outcome: the (n, 1) effect term broadcasts over the 3 outcome columns
Y = (
    (TE * T).reshape(-1, 1)
    + W[:, support_Y].dot(coefs_Y)
    + epsilon_sample(n)
)

# get testing data: random covariates, with the first column swept over [0, 1]
X_test = np.random.uniform(0, 1, size=(10, n_x))
X_test[:, 0] = np.linspace(0, 1, 10)

In [192]:
# Linear DML for the three-level discrete treatment: model_t is a classifier
# and discrete_treatment=True is set; statsmodels-based inference.
est = LinearDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestClassifier(),
    random_state=123,
    discrete_treatment=True,
)
est.fit(Y, T, X, W, inference='statsmodels')

<econml.dml.LinearDMLCateEstimator at 0x26c50acd860>

In [203]:

## sparse linear
# Rebinds `est`: sparse-linear DML for the three-level discrete treatment,
# with a degree-3 PolynomialFeatures featurizer and debiased-lasso inference;
# the cells that follow in this section query this estimator.
est = SparseLinearDMLCateEstimator(
    model_y=RandomForestRegressor(),
    model_t=RandomForestClassifier(),
    featurizer=PolynomialFeatures(degree=3),
    random_state=123,
    discrete_treatment=True,
)
est.fit(Y, T, X, W, inference='debiasedlasso')


<econml.dml.SparseLinearDMLCateEstimator at 0x26c50b1ff98>

In [204]:
est.const_marginal_effect(X_test)

array([[[-0.05255946,  0.46349196],
        [ 0.11191726,  0.69459418],
        [ 0.19139689,  0.41428258]],

       [[ 0.29815516,  0.68978678],
        [ 0.41575269,  0.84822045],
        [ 0.31700538,  0.86529086]],

       [[ 0.30349351,  1.32576593],
        [ 0.50579208,  1.48254337],
        [ 0.76413994,  0.83121615]],

       [[-0.39019954,  3.86213411],
        [-0.20697011,  3.19694953],
        [ 1.01174308,  1.52628075]],

       [[-0.32665243,  4.34239432],
        [ 0.05759803,  3.29408013],
        [ 0.85135042,  2.67148429]],

       [[ 0.79754271,  3.15226453],
        [ 1.08327032,  3.09713997],
        [ 1.84181676,  1.5504902 ]],

       [[ 0.77303061,  4.46750534],
        [ 0.87544772,  3.60977336],
        [ 1.16206767,  3.5965658 ]],

       [[ 0.14075585,  7.40776429],
        [ 0.59696974,  5.9328942 ],
        [ 2.58072076,  3.23303214]],

       [[ 1.36307697,  5.26382597],
        [ 1.77185327,  4.71414422],
        [ 2.59507551,  3.21027159]],

       [[ 

In [205]:
est.const_marginal_effect_interval(X_test)

(array([[[-0.43516324,  0.03874764],
         [-0.36682777,  0.16311947],
         [-0.31407796, -0.146866  ]],
 
        [[ 0.07126749,  0.44336584],
         [ 0.13185235,  0.53987848],
         [ 0.01725401,  0.5397332 ]],
 
        [[-0.23924279,  0.7282364 ],
         [-0.17332382,  0.73486572],
         [ 0.04710691,  0.04179335]],
 
        [[-2.26495065,  1.79851808],
         [-2.55281152,  0.61478493],
         [-1.46507389, -1.20005406]],
 
        [[-2.25757316,  2.21205368],
         [-2.35852739,  0.62842426],
         [-1.69967474, -0.14300337]],
 
        [[-0.46426495,  1.76508476],
         [-0.49560629,  1.36138757],
         [ 0.1747866 , -0.28217461]],
 
        [[-0.67617062,  2.88927822],
         [-0.937911  ,  1.63496693],
         [-0.75253646,  1.51149984]],
 
        [[-3.14304199,  3.78354175],
         [-3.5119858 ,  1.39797152],
         [-1.75765045, -1.55508927]],
 
        [[-0.58902638,  3.11161543],
         [-0.67077758,  2.02112295],
         [ 0.0

In [206]:
est.const_marginal_effect_inference(X_test).summary_frame()

Unnamed: 0_level_0,Unnamed: 1_level_0,point_estimate,point_estimate,stderr,stderr,zstat,zstat,pvalue,pvalue,ci_lower,ci_lower,ci_upper,ci_upper
Unnamed: 0_level_1,Unnamed: 1_level_1,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1
0,Y0,-0.053,0.463,0.233,0.258,-0.226,1.795,0.821,0.073,-0.435,0.039,0.33,0.888
0,Y1,0.112,0.695,0.291,0.323,0.385,2.15,0.701,0.032,-0.367,0.163,0.591,1.226
0,Y2,0.191,0.414,0.307,0.341,0.623,1.214,0.533,0.225,-0.314,-0.147,0.697,0.975
1,Y0,0.298,0.69,0.138,0.15,2.162,4.604,0.031,0.0,0.071,0.443,0.525,0.936
1,Y1,0.416,0.848,0.173,0.187,2.409,4.525,0.016,0.0,0.132,0.54,0.7,1.157
1,Y2,0.317,0.865,0.182,0.198,1.74,4.372,0.082,0.0,0.017,0.54,0.617,1.191
2,Y0,0.303,1.326,0.33,0.363,0.92,3.65,0.358,0.0,-0.239,0.728,0.846,1.923
2,Y1,0.506,1.483,0.413,0.455,1.225,3.262,0.221,0.001,-0.173,0.735,1.185,2.23
2,Y2,0.764,0.831,0.436,0.48,1.753,1.732,0.08,0.083,0.047,0.042,1.481,1.621
3,Y0,-0.39,3.862,1.14,1.255,-0.342,3.078,0.732,0.002,-2.265,1.799,1.485,5.926


In [207]:
est.const_marginal_effect_inference(X_test).population_summary().print()

0,1,2
,mean_point,mean_point
,T0,T1
Y0,0.385,3.975
Y1,0.68,3.443
Y2,1.567,2.051

0,1,2,3,4,5,6,7,8,9,10
,stderr_mean,stderr_mean,zstat,zstat,pvalue,pvalue,ci_mean_lower,ci_mean_lower,ci_mean_upper,ci_mean_upper
,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1
Y0,1.306,1.438,0.295,2.764,0.768,0.006,-1.762,1.61,2.533,6.341
Y1,1.634,1.799,0.416,1.914,0.677,0.056,-2.007,0.484,3.367,6.403
Y2,1.725,1.9,0.908,1.08,0.364,0.28,-1.27,-1.074,4.404,5.176

0,1,2,3,4,5,6
,std_point,std_point,pct_point_lower,pct_point_lower,pct_point_upper,pct_point_upper
,T0,T1,T0,T1,T0,T1
Y0,0.545,2.602,-0.362,0.565,1.176,8.162
Y1,0.618,2.077,-0.088,0.764,1.688,6.83
Y2,1.225,1.092,0.248,0.602,3.562,3.433

0,1,2,3,4,5,6
,std_point,std_point,ci_point_lower,ci_point_lower,ci_point_upper,ci_point_upper
,T0,T1,T0,T1,T0,T1
Y0,1.415,2.973,-1.877,0.429,2.74,9.858
Y1,1.747,2.748,-1.982,0.513,3.661,8.973
Y2,2.116,2.191,-0.801,-0.538,5.71,5.916


In [208]:
# effect
est.effect(X_test)

array([[-0.05255946,  0.11191726,  0.19139689],
       [ 0.29815516,  0.41575269,  0.31700538],
       [ 0.30349351,  0.50579208,  0.76413994],
       [-0.39019954, -0.20697011,  1.01174308],
       [-0.32665243,  0.05759803,  0.85135042],
       [ 0.79754271,  1.08327032,  1.84181676],
       [ 0.77303061,  0.87544772,  1.16206767],
       [ 0.14075585,  0.59696974,  2.58072076],
       [ 1.36307697,  1.77185327,  2.59507551],
       [ 0.94667195,  1.58624366,  4.35277888]])

In [209]:
est.effect_interval(X_test)

(array([[-0.43516324, -0.36682777, -0.31407796],
        [ 0.07126749,  0.13185235,  0.01725401],
        [-0.23924279, -0.17332382,  0.04710691],
        [-2.26495065, -2.55281152, -1.46507389],
        [-2.25757316, -2.35852739, -1.69967474],
        [-0.46426495, -0.49560629,  0.1747866 ],
        [-0.67617062, -0.937911  , -0.75253646],
        [-3.14304199, -3.5119858 , -1.75765045],
        [-0.58902638, -0.67077758,  0.01606504],
        [-3.53591301, -4.02273205, -1.56936325]]),
 array([[ 0.33004432,  0.5906623 ,  0.69687174],
        [ 0.52504282,  0.69965303,  0.61675675],
        [ 0.84622981,  1.18490798,  1.48117298],
        [ 1.48455156,  2.1388713 ,  3.48856004],
        [ 1.6042683 ,  2.47372345,  3.40237557],
        [ 2.05935038,  2.66214692,  3.50884691],
        [ 2.22223184,  2.68880644,  3.0766718 ],
        [ 3.42455369,  4.70592529,  6.91909196],
        [ 3.31518033,  4.21448412,  5.17408597],
        [ 5.42925691,  7.19521937, 10.274921  ]]))

In [210]:
est.effect_inference(X_test).summary_frame()

Unnamed: 0,Unnamed: 1,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,Y0,-0.053,0.233,-0.226,0.821,-0.435,0.33
0,Y1,0.112,0.291,0.385,0.701,-0.367,0.591
0,Y2,0.191,0.307,0.623,0.533,-0.314,0.697
1,Y0,0.298,0.138,2.162,0.031,0.071,0.525
1,Y1,0.416,0.173,2.409,0.016,0.132,0.7
1,Y2,0.317,0.182,1.74,0.082,0.017,0.617
2,Y0,0.303,0.33,0.92,0.358,-0.239,0.846
2,Y1,0.506,0.413,1.225,0.221,-0.173,1.185
2,Y2,0.764,0.436,1.753,0.08,0.047,1.481
3,Y0,-0.39,1.14,-0.342,0.732,-2.265,1.485


In [211]:
est.effect_inference(X_test).population_summary().print()

0,1
,mean_point
,T0
Y0,0.385
Y1,0.68
Y2,1.567

0,1,2,3,4,5
,stderr_mean,zstat,pvalue,ci_mean_lower,ci_mean_upper
,T0,T0,T0,T0,T0
Y0,1.306,0.295,0.768,-1.762,2.533
Y1,1.634,0.416,0.677,-2.007,3.367
Y2,1.725,0.908,0.364,-1.27,4.404

0,1,2,3
,std_point,pct_point_lower,pct_point_upper
,T0,T0,T0
Y0,0.545,-0.362,1.176
Y1,0.618,-0.088,1.688
Y2,1.225,0.248,3.562

0,1,2,3
,std_point,ci_point_lower,ci_point_upper
,T0,T0,T0
Y0,1.415,-1.873,2.74
Y1,1.747,-1.978,3.661
Y2,2.116,-0.807,5.71
