In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
## Ignore warnings
import warnings
warnings.filterwarnings('ignore') 

In [3]:
# Main imports
from econml.dml import DMLCateEstimator, LinearDMLCateEstimator,SparseLinearDMLCateEstimator, ForestDMLCateEstimator

# Helper imports
import numpy as np
from itertools import product
from sklearn.linear_model import Lasso, LassoCV, LogisticRegression, LogisticRegressionCV,LinearRegression,MultiTaskElasticNet,MultiTaskElasticNetCV
from sklearn.ensemble import RandomForestRegressor,RandomForestClassifier
from sklearn.preprocessing import PolynomialFeatures
import matplotlib.pyplot as plt
import matplotlib
from sklearn.model_selection import train_test_split

%matplotlib inline

In [4]:
import statsmodels.api as sm
Y = np.array([1,3,4,5,2,3,4])
X = range(1,8)
X = sm.add_constant(X)
model = sm.OLS(Y,X)
results = model.fit(use_t=False)

In [5]:
results.summary()



0,1,2,3
Dep. Variable:,y,R-squared:,0.161
Model:,OLS,Adj. R-squared:,-0.007
Method:,Least Squares,F-statistic:,0.9608
Date:,"Tue, 17 Dec 2019",Prob (F-statistic):,0.372
Time:,14:05:35,Log-Likelihood:,-10.854
No. Observations:,7,AIC:,25.71
Df Residuals:,5,BIC:,25.6
Df Model:,1,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,z,P>|z|,[0.025,0.975]
const,2.1429,1.141,1.879,0.060,-0.093,4.378
x1,0.2500,0.255,0.980,0.327,-0.250,0.750

0,1,2,3
Omnibus:,,Durbin-Watson:,1.743
Prob(Omnibus):,,Jarque-Bera (JB):,0.482
Skew:,0.206,Prob(JB):,0.786
Kurtosis:,1.782,Cond. No.,10.4


In [6]:
from statsmodels.regression.linear_model import RegressionResults,PredictionResults

In [7]:
results.get_prediction().summary_frame()

Unnamed: 0,mean,mean_se,mean_ci_lower,mean_ci_upper,obs_ci_lower,obs_ci_upper
0,2.392857,0.9196,0.590475,4.195239,-0.808006,5.59372
1,2.642857,0.721393,1.228952,4.056762,-0.356488,5.642202
2,2.892857,0.570311,1.775067,4.010647,0.021203,5.764511
3,3.142857,0.510102,2.143076,4.142639,0.315048,5.970667
4,3.392857,0.570311,2.275067,4.510647,0.521203,6.264511
5,3.642857,0.721393,2.228952,5.056762,0.643512,6.642202
6,3.892857,0.9196,2.090475,5.695239,0.691994,7.09372


### DML inferences
1. LinearDML
2. SparseLinearDML
3. KernelDML (no inference)
4. ForestDML


##  write comprehensive DML dgps

dgp includes:
1. single Y, continuous single T
2. single Y, continuous multi T
3. single Y, discrete binary T
4. single Y, discrete multi T
5. multi Y, continuous single T
6. multi Y, continuous multi T
7. multi Y, discrete binary T
8. multi Y, discrete multi T

and also test X is None, W is None

### 1. single Y, continuous single T

In [8]:
n = 1000
n_w = 30
support_size = 5
n_x = 1
# Outcome support
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=support_size)
def epsilon_sample(n): return np.random.uniform(-1, 1, size=n)
# Treatment support
support_T = support_Y
coefs_T = np.random.uniform(0, 1, size=support_size)
def eta_sample(n): return np.random.uniform(-1, 1, size=n)
# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effects
TE = np.array([np.exp(2 * x_i) for x_i in X]).flatten()
T = np.dot(W[:, support_T], coefs_T) + eta_sample(n)
Y = TE * T + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)
X_test = np.array(list(product(np.arange(0, 1, 0.1), repeat=n_x)))

In [9]:
Y.shape,T.shape,X.shape,W.shape,X_test.shape

((1000,), (1000,), (1000, 1), (1000, 30), (10, 1))

In [10]:
est = LinearDMLCateEstimator(model_y=RandomForestRegressor(),
                             model_t=RandomForestRegressor(),
                             random_state=123)
est.fit(Y, T, X, W,inference='statsmodels')

<econml.dml.LinearDMLCateEstimator at 0x1f9256b84a8>

In [11]:
"""
##sparse linear
est = SparseLinearDMLCateEstimator(model_y=RandomForestRegressor(),
                              model_t=RandomForestRegressor(),
                              featurizer=PolynomialFeatures(degree=3),
                              random_state=123)
est.fit(Y, T, X, W,inference='debiasedlasso')
"""

"\n##sparse linear\nest = SparseLinearDMLCateEstimator(model_y=RandomForestRegressor(),\n                              model_t=RandomForestRegressor(),\n                              featurizer=PolynomialFeatures(degree=3),\n                              random_state=123)\nest.fit(Y, T, X, W,inference='debiasedlasso')\n"

In [12]:
"""
##forest dml
est = ForestDMLCateEstimator(model_y=RandomForestRegressor(),
                              model_t=RandomForestRegressor())
est.fit(Y, T, X, W,inference='blb')
"""

"\n##forest dml\nest = ForestDMLCateEstimator(model_y=RandomForestRegressor(),\n                              model_t=RandomForestRegressor())\nest.fit(Y, T, X, W,inference='blb')\n"

In [13]:
# constant marginal effect
est.const_marginal_effect(X_test)

array([0.5152919 , 1.04677487, 1.57825784, 2.10974081, 2.64122379,
       3.17270676, 3.70418973, 4.23567271, 4.76715568, 5.29863865])

In [14]:
est.const_marginal_effect_interval(X_test)

(array([0.24859722, 0.82643007, 1.39771144, 1.95722338, 2.4978867 ,
        3.01634542, 3.51718277, 4.00738838, 4.49170104, 4.97266907]),
 array([0.78198658, 1.26711967, 1.75880425, 2.26225825, 2.78456088,
        3.3290681 , 3.8911967 , 4.46395703, 5.04261031, 5.62460823]))

In [15]:
est.const_marginal_effect_inference(X_test).summary_frame()

Unnamed: 0,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,0.515,0.162,3.178,0.001,0.249,0.782
1,1.047,0.134,7.814,0.0,0.826,1.267
2,1.578,0.11,14.379,0.0,1.398,1.759
3,2.11,0.093,22.753,0.0,1.957,2.262
4,2.641,0.087,30.309,0.0,2.498,2.785
5,3.173,0.095,33.376,0.0,3.016,3.329
6,3.704,0.114,32.581,0.0,3.517,3.891
7,4.236,0.139,30.519,0.0,4.007,4.464
8,4.767,0.167,28.467,0.0,4.492,5.043
9,5.299,0.198,26.737,0.0,4.973,5.625


In [16]:
#effect
est.effect(X_test)

array([0.5152919 , 1.04677487, 1.57825784, 2.10974081, 2.64122379,
       3.17270676, 3.70418973, 4.23567271, 4.76715568, 5.29863865])

In [17]:
est.effect_interval(X_test)

(array([0.24859722, 0.82643007, 1.39771144, 1.95722338, 2.4978867 ,
        3.01634542, 3.51718277, 4.00738838, 4.49170104, 4.97266907]),
 array([0.78198658, 1.26711967, 1.75880425, 2.26225825, 2.78456088,
        3.3290681 , 3.8911967 , 4.46395703, 5.04261031, 5.62460823]))

In [18]:
est.effect_inference(X_test).summary_frame()

Unnamed: 0,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,0.515,0.162,3.178,0.001,0.249,0.782
1,1.047,0.134,7.814,0.0,0.826,1.267
2,1.578,0.11,14.379,0.0,1.398,1.759
3,2.11,0.093,22.753,0.0,1.957,2.262
4,2.641,0.087,30.309,0.0,2.498,2.785
5,3.173,0.095,33.376,0.0,3.016,3.329
6,3.704,0.114,32.581,0.0,3.517,3.891
7,4.236,0.139,30.519,0.0,4.007,4.464
8,4.767,0.167,28.467,0.0,4.492,5.043
9,5.299,0.198,26.737,0.0,4.973,5.625


In [19]:
# marginal effect
est.marginal_effect_inference(0,X_test)

AttributeError: The treatment effect is linear, please call const_marginal_effect_inference!

In [20]:
est.const_marginal_effect_inference(X_test).population_summary().print()

[<class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>]


0,1
,mean_point
,T0
Y0,2.907

0,1,2,3,4,5
,stderr_mean,zstat,pvalue,ci_mean_lower,ci_mean_upper
,T0,T0,T0,T0,T0
Y0,0.135,21.606,0.0,2.686,3.128

0,1,2,3
,std_point,pct_point_lower,pct_point_upper
,T0,T0,T0
Y0,1.527,0.754,5.059

0,1,2,3
,std_point,ci_point_lower,ci_point_upper
,T0,T0,T0
Y0,1.532,0.516,5.303


### 2. single Y, continuous multi T

In [21]:
n = 1000
n_w = 30
support_size = 5
n_x = 1
n_t=3
# Outcome support
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=support_size)
def epsilon_sample(n): return np.random.uniform(-1, 1, size=n)
# Treatment support
support_T = support_Y
coefs_T = np.random.uniform(0, 1, size=(support_size,n_t))
def eta_sample(n): return np.random.uniform(-1, 1, size=(n,n_t))
# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effects
TE1 = np.array([np.exp(2 * x_i) for x_i in X]).flatten()
TE2 = np.array([2*x_i for x_i in X]).flatten()
TE3 = np.array([x_i**2 for x_i in X]).flatten()
T = np.dot(W[:, support_T], coefs_T) + eta_sample(n)
Y = TE1 * T[:,0]+TE2 * T[:,1]+TE3 * T[:,2] + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)
X_test = np.array(list(product(np.arange(0, 1, 0.1), repeat=n_x)))

In [22]:
Y.shape,T.shape,X.shape,W.shape,X_test.shape

((1000,), (1000, 3), (1000, 1), (1000, 30), (10, 1))

In [23]:
est = LinearDMLCateEstimator(model_y=RandomForestRegressor(),
                             model_t=RandomForestRegressor(),
                             random_state=123)
est.fit(Y, T, X, W,inference='statsmodels')

<econml.dml.LinearDMLCateEstimator at 0x1f92dd22d30>

In [24]:
"""
##sparse linear
est = SparseLinearDMLCateEstimator(model_y=RandomForestRegressor(),
                              model_t=RandomForestRegressor(),
                              featurizer=PolynomialFeatures(degree=3),
                              random_state=123)
est.fit(Y, T, X, W,inference='debiasedlasso')
"""

"\n##sparse linear\nest = SparseLinearDMLCateEstimator(model_y=RandomForestRegressor(),\n                              model_t=RandomForestRegressor(),\n                              featurizer=PolynomialFeatures(degree=3),\n                              random_state=123)\nest.fit(Y, T, X, W,inference='debiasedlasso')\n"

In [25]:
est.const_marginal_effect(X_test)

array([[ 2.20634957e-01,  1.52120923e-01, -1.53346543e-01],
       [ 7.63344177e-01,  3.48416649e-01,  5.03432862e-03],
       [ 1.30605340e+00,  5.44712374e-01,  1.63415201e-01],
       [ 1.84876262e+00,  7.41008100e-01,  3.21796073e-01],
       [ 2.39147184e+00,  9.37303825e-01,  4.80176945e-01],
       [ 2.93418106e+00,  1.13359955e+00,  6.38557817e-01],
       [ 3.47689028e+00,  1.32989528e+00,  7.96938689e-01],
       [ 4.01959950e+00,  1.52619100e+00,  9.55319561e-01],
       [ 4.56230872e+00,  1.72248673e+00,  1.11370043e+00],
       [ 5.10501794e+00,  1.91878245e+00,  1.27208131e+00]])

In [26]:
est.const_marginal_effect_interval(X_test)

(array([[-0.08724358, -0.17763622, -0.45403169],
        [ 0.51876479,  0.08919592, -0.23163206],
        [ 1.1133223 ,  0.34391112, -0.01878256],
        [ 1.68515395,  0.57332669,  0.17358441],
        [ 2.22210055,  0.76242354,  0.33076813],
        [ 2.72705358,  0.91515202,  0.45344996],
        [ 3.21342629,  1.0478923 ,  0.55653804],
        [ 3.69062979,  1.17122313,  0.65051528],
        [ 4.16315313,  1.28987969,  0.74006364],
        [ 4.63307978,  1.40598069,  0.82723427]]),
 array([[0.52851349, 0.48187807, 0.1473386 ],
        [1.00792356, 0.60763737, 0.24170071],
        [1.49878449, 0.74551363, 0.34561296],
        [2.01237128, 0.90868951, 0.47000773],
        [2.56084313, 1.11218411, 0.62958576],
        [3.14130854, 1.35204708, 0.82366568],
        [3.74035426, 1.61189826, 1.03733934],
        [4.34856921, 1.88115887, 1.26012384],
        [4.96146431, 2.15509376, 1.48733723],
        [5.5769561 , 2.43158422, 1.71692834]]))

In [27]:
est.const_marginal_effect_inference(X_test).summary_frame()

Unnamed: 0_level_0,point_estimate,point_estimate,point_estimate,stderr,stderr,stderr,zstat,zstat,zstat,pvalue,pvalue,pvalue,ci_lower,ci_lower,ci_lower,ci_upper,ci_upper,ci_upper
Unnamed: 0_level_1,T0,T1,T2,T0,T1,T2,T0,T1,T2,T0,T1,T2,T0,T1,T2,T0,T1,T2
0,0.221,0.152,-0.153,0.187,0.2,0.183,1.179,0.759,-0.839,0.238,0.448,0.402,-0.087,-0.178,-0.454,0.529,0.482,0.147
1,0.763,0.348,0.005,0.149,0.158,0.144,5.134,2.211,0.035,0.0,0.027,0.972,0.519,0.089,-0.232,1.008,0.608,0.242
2,1.306,0.545,0.163,0.117,0.122,0.111,11.146,4.462,1.475,0.0,0.0,0.14,1.113,0.344,-0.019,1.499,0.746,0.346
3,1.849,0.741,0.322,0.099,0.102,0.09,18.587,7.269,3.571,0.0,0.0,0.0,1.685,0.573,0.174,2.012,0.909,0.47
4,2.391,0.937,0.48,0.103,0.106,0.091,23.225,8.816,5.286,0.0,0.0,0.0,2.222,0.762,0.331,2.561,1.112,0.63
5,2.934,1.134,0.639,0.126,0.133,0.113,23.301,8.536,5.674,0.0,0.0,0.0,2.727,0.915,0.453,3.141,1.352,0.824
6,3.477,1.33,0.797,0.16,0.171,0.146,21.707,7.757,5.453,0.0,0.0,0.0,3.213,1.048,0.557,3.74,1.612,1.037
7,4.02,1.526,0.955,0.2,0.216,0.185,20.098,7.072,5.155,0.0,0.0,0.0,3.691,1.171,0.651,4.349,1.881,1.26
8,4.562,1.722,1.114,0.243,0.263,0.227,18.801,6.549,4.903,0.0,0.0,0.0,4.163,1.29,0.74,4.961,2.155,1.487
9,5.105,1.919,1.272,0.287,0.312,0.27,17.793,6.155,4.704,0.0,0.0,0.0,4.633,1.406,0.827,5.577,2.432,1.717


In [28]:
#effect
est.effect(X_test)

  warn("A scalar was specified but there are multiple treatments; "
  warn("A scalar was specified but there are multiple treatments; "


array([0.21940934, 1.11679515, 2.01418097, 2.91156679, 3.80895261,
       4.70633843, 5.60372424, 6.50111006, 7.39849588, 8.2958817 ])

In [29]:
est.effect_interval(X_test)

  warn("A scalar was specified but there are multiple treatments; "
  warn("A scalar was specified but there are multiple treatments; "


(array([-0.24638718,  0.74542395,  1.72343979,  2.67322255,  3.57496712,
         4.42642035,  5.24648101,  6.05104893,  6.84790439,  7.64058737]),
 array([0.68520586, 1.48816636, 2.30492215, 3.14991103, 4.0429381 ,
        4.9862565 , 5.96096748, 6.95117119, 7.94908738, 8.95117603]))

In [30]:
est.effect_inference(X_test).summary_frame()

  warn("A scalar was specified but there are multiple treatments; "
  warn("A scalar was specified but there are multiple treatments; "


Unnamed: 0,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,0.219,0.283,0.775,0.438,-0.246,0.685
1,1.117,0.226,4.946,0.0,0.745,1.488
2,2.014,0.177,11.395,0.0,1.723,2.305
3,2.912,0.145,20.093,0.0,2.673,3.15
4,3.809,0.142,26.776,0.0,3.575,4.043
5,4.706,0.17,27.655,0.0,4.426,4.986
6,5.604,0.217,25.801,0.0,5.246,5.961
7,6.501,0.274,23.76,0.0,6.051,6.951
8,7.398,0.335,22.102,0.0,6.848,7.949
9,8.296,0.398,20.823,0.0,7.641,8.951


In [31]:
est.const_marginal_effect_inference(X_test).population_summary().print()

[<class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>]


0,1,2,3
,mean_point,mean_point,mean_point
,T0,T1,T2
Y0,2.663,1.035,0.559

0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
,stderr_mean,stderr_mean,stderr_mean,zstat,zstat,zstat,pvalue,pvalue,pvalue,ci_mean_lower,ci_mean_lower,ci_mean_lower,ci_mean_upper,ci_mean_upper,ci_mean_upper
,T0,T1,T2,T0,T1,T2,T0,T1,T2,T0,T1,T2,T0,T1,T2
Y0,0.177,0.19,0.166,15.025,5.445,3.368,0.0,0.0,0.001,2.371,0.723,0.286,2.954,1.348,0.833

0,1,2,3,4,5,6,7,8,9
,std_point,std_point,std_point,pct_point_lower,pct_point_lower,pct_point_lower,pct_point_upper,pct_point_upper,pct_point_upper
,T0,T1,T2,T0,T1,T2,T0,T1,T2
Y0,1.559,0.564,0.455,0.465,0.24,-0.082,4.861,1.83,1.201

0,1,2,3,4,5,6,7,8,9
,std_point,std_point,std_point,ci_point_lower,ci_point_lower,ci_point_lower,ci_point_upper,ci_point_upper,ci_point_upper
,T0,T1,T2,T0,T1,T2,T0,T1,T2
Y0,1.569,0.595,0.484,0.217,0.116,-0.191,5.114,2.024,1.371


### 3. single Y, discrete binary T

In [32]:
# Treatment effect function
def exp_te(x):
    return np.exp(2 * x[0])# DGP constants

np.random.seed(1234)
n = 1000
n_w = 30
support_size = 5
n_x = 4
# Outcome support
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=support_size)
epsilon_sample = lambda n:np.random.uniform(-1, 1, size=n)
# Treatment support
support_T = support_Y
coefs_T = np.random.uniform(0, 1, size=support_size)
eta_sample = lambda n: np.random.uniform(-1, 1, size=n) 

# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effects
TE = np.array([exp_te(x_i) for x_i in X])
# Define treatment
log_odds = np.dot(W[:, support_T], coefs_T) + eta_sample(n)
T_sigmoid = 1/(1 + np.exp(-log_odds))
T = np.array([np.random.binomial(1, p) for p in T_sigmoid])
# Define the outcome
Y = TE * T + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)

# get testing data
X_test = np.random.uniform(0, 1, size=(10, n_x))
X_test[:, 0] = np.linspace(0, 1, 10)

In [33]:
Y.shape,T.shape,X.shape,W.shape,X_test.shape

((1000,), (1000,), (1000, 4), (1000, 30), (10, 4))

In [34]:
est = LinearDMLCateEstimator(model_y=RandomForestRegressor(),
                             model_t=RandomForestClassifier(),
                             random_state=123,
                            discrete_treatment=True)
est.fit(Y, T, X, W,inference='statsmodels')
te_pred = est.effect(X_test)

In [35]:
"""
##sparse linear
est = SparseLinearDMLCateEstimator(model_y=RandomForestRegressor(),
                              model_t=RandomForestClassifier(),
                              featurizer=PolynomialFeatures(degree=3),
                              random_state=123,discrete_treatment=True)
est.fit(Y, T, X, W,inference='debiasedlasso')
"""

"\n##sparse linear\nest = SparseLinearDMLCateEstimator(model_y=RandomForestRegressor(),\n                              model_t=RandomForestClassifier(),\n                              featurizer=PolynomialFeatures(degree=3),\n                              random_state=123,discrete_treatment=True)\nest.fit(Y, T, X, W,inference='debiasedlasso')\n"

In [36]:
"""
##forest dml
est = ForestDMLCateEstimator(model_y=RandomForestRegressor(),
                              model_t=RandomForestClassifier(min_samples_leaf=10),
                              discrete_treatment=True)
est.fit(Y, T, X, W,inference='blb')
"""

"\n##forest dml\nest = ForestDMLCateEstimator(model_y=RandomForestRegressor(),\n                              model_t=RandomForestClassifier(min_samples_leaf=10),\n                              discrete_treatment=True)\nest.fit(Y, T, X, W,inference='blb')\n"

In [37]:
#constant marginal effect
est.const_marginal_effect(X_test)

array([[0.48311844],
       [1.12617478],
       [1.61578661],
       [2.1684037 ],
       [2.72794059],
       [3.11286489],
       [3.93521029],
       [4.28283199],
       [4.91106624],
       [5.54323147]])

In [38]:
est.const_marginal_effect_interval(X_test)

(array([[0.18358764],
        [0.8225674 ],
        [1.23970621],
        [1.7781744 ],
        [2.4034647 ],
        [2.79738461],
        [3.62213802],
        [3.91454028],
        [4.56661374],
        [5.15550259]]), array([[0.78264924],
        [1.42978217],
        [1.991867  ],
        [2.558633  ],
        [3.05241649],
        [3.42834516],
        [4.24828256],
        [4.65112371],
        [5.25551874],
        [5.93096036]]))

In [39]:
est.const_marginal_effect_inference(X_test).summary_frame()

Unnamed: 0,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,0.483,0.182,2.653,0.008,0.184,0.783
1,1.126,0.185,6.101,0.0,0.823,1.43
2,1.616,0.229,7.067,0.0,1.24,1.992
3,2.168,0.237,9.14,0.0,1.778,2.559
4,2.728,0.197,13.829,0.0,2.403,3.052
5,3.113,0.192,16.23,0.0,2.797,3.428
6,3.935,0.19,20.675,0.0,3.622,4.248
7,4.283,0.224,19.128,0.0,3.915,4.651
8,4.911,0.209,23.452,0.0,4.567,5.256
9,5.543,0.236,23.516,0.0,5.156,5.931


In [40]:
est.const_marginal_effect_inference(X_test).population_summary()

<econml.inference.PopulationSummaryResults at 0x1f92dd60e48>

In [41]:
# effect
est.effect(X_test)

array([0.48311844, 1.12617478, 1.61578661, 2.1684037 , 2.72794059,
       3.11286489, 3.93521029, 4.28283199, 4.91106624, 5.54323147])

In [42]:
est.effect_interval(X_test)

(array([0.18358764, 0.8225674 , 1.23970621, 1.7781744 , 2.4034647 ,
        2.79738461, 3.62213802, 3.91454028, 4.56661374, 5.15550259]),
 array([0.78264924, 1.42978217, 1.991867  , 2.558633  , 3.05241649,
        3.42834516, 4.24828256, 4.65112371, 5.25551874, 5.93096036]))

In [43]:
est.effect_inference(X_test).summary_frame()

Unnamed: 0,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,0.483,0.182,2.653,0.008,0.184,0.783
1,1.126,0.185,6.101,0.0,0.823,1.43
2,1.616,0.229,7.067,0.0,1.24,1.992
3,2.168,0.237,9.14,0.0,1.778,2.559
4,2.728,0.197,13.829,0.0,2.403,3.052
5,3.113,0.192,16.23,0.0,2.797,3.428
6,3.935,0.19,20.675,0.0,3.622,4.248
7,4.283,0.224,19.128,0.0,3.915,4.651
8,4.911,0.209,23.452,0.0,4.567,5.256
9,5.543,0.236,23.516,0.0,5.156,5.931


In [67]:
est.const_marginal_effect_inference(X_test).population_summary().print()

[<class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>]


0,1,2
,mean_point,mean_point
,T0,T1
Y0,0.967,1.7

0,1,2,3,4,5,6,7,8,9,10
,stderr_mean,stderr_mean,zstat,zstat,pvalue,pvalue,ci_mean_lower,ci_mean_lower,ci_mean_upper,ci_mean_upper
,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1
Y0,0.165,0.171,5.869,9.948,0.0,0.0,0.696,1.419,1.238,1.981

0,1,2,3,4,5,6
,std_point,std_point,pct_point_lower,pct_point_lower,pct_point_upper,pct_point_upper
,T0,T1,T0,T1,T0,T1
Y0,0.501,1.064,0.277,0.193,1.705,3.161

0,1,2,3,4,5,6
,std_point,std_point,ci_point_lower,ci_point_lower,ci_point_upper,ci_point_upper
,T0,T1,T0,T1,T0,T1
Y0,0.528,1.078,0.149,0.07,1.886,3.376


### 4. single Y, discrete multi T

In [58]:
np.random.seed(1234)
n = 1000
n_w = 30
support_size = 5
n_x = 4
# Outcome support
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=support_size)
epsilon_sample = lambda n:np.random.uniform(-1, 1, size=n)

# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effects
TE = np.array([2*x_i[0] for x_i in X])
# Define treatment
T =np.random.choice(3, n)
# Define the outcome
Y =TE*T + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)

# get testing data
X_test = np.random.uniform(0, 1, size=(10, n_x))
X_test[:, 0] = np.linspace(0, 1, 10)

In [59]:
Y.shape,T.shape,X.shape,W.shape,X_test.shape

((1000,), (1000,), (1000, 4), (1000, 30), (10, 4))

In [60]:
est = LinearDMLCateEstimator(model_y=RandomForestRegressor(),
                             model_t=RandomForestClassifier(),
                             random_state=123,
                            discrete_treatment=True)
est.fit(Y, T, X, W,inference='statsmodels')

<econml.dml.LinearDMLCateEstimator at 0x1f92dd222e8>

In [61]:
"""
##sparse linear
est = SparseLinearDMLCateEstimator(model_y=RandomForestRegressor(),
                              model_t=RandomForestClassifier(),
                              featurizer=PolynomialFeatures(degree=3),
                              random_state=123,discrete_treatment=True)
est.fit(Y, T, X, W,inference='debiasedlasso')
"""

"\n##sparse linear\nest = SparseLinearDMLCateEstimator(model_y=RandomForestRegressor(),\n                              model_t=RandomForestClassifier(),\n                              featurizer=PolynomialFeatures(degree=3),\n                              random_state=123,discrete_treatment=True)\nest.fit(Y, T, X, W,inference='debiasedlasso')\n"

In [62]:
est.const_marginal_effect(X_test)

array([[0.1885573 , 0.11515519],
       [0.38461869, 0.28878255],
       [0.46915619, 0.67774055],
       [0.80683188, 1.31906974],
       [0.99504813, 1.57615415],
       [0.92202081, 1.73089308],
       [1.22284881, 2.31128693],
       [1.31207556, 2.7030477 ],
       [1.49136171, 2.90147911],
       [1.87912734, 3.37264014]])

In [63]:
est.const_marginal_effect_interval(X_test)

(array([[-0.04827697, -0.13059248],
        [ 0.07588062, -0.01784021],
        [ 0.24107645,  0.46111956],
        [ 0.54933769,  1.04017291],
        [ 0.69481886,  1.26343529],
        [ 0.63320734,  1.43721169],
        [ 1.06913069,  2.14941516],
        [ 1.06278466,  2.44047665],
        [ 1.22943722,  2.61736893],
        [ 1.50900641,  2.98445474]]), array([[0.42539156, 0.36090287],
        [0.69335675, 0.5954053 ],
        [0.69723594, 0.89436154],
        [1.06432608, 1.59796656],
        [1.29527741, 1.88887301],
        [1.21083428, 2.02457447],
        [1.37656693, 2.47315871],
        [1.56136646, 2.96561874],
        [1.75328621, 3.18558929],
        [2.24924827, 3.76082554]]))

In [51]:
est.const_marginal_effect_inference(X_test).summary_frame()

Unnamed: 0_level_0,point_estimate,point_estimate,stderr,stderr,zstat,zstat,pvalue,pvalue,ci_lower,ci_lower,ci_upper,ci_upper
Unnamed: 0_level_1,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1
0,0.189,0.115,0.144,0.149,1.31,0.771,0.19,0.441,-0.048,-0.131,0.425,0.361
1,0.385,0.289,0.188,0.186,2.049,1.549,0.04,0.121,0.076,-0.018,0.693,0.595
2,0.469,0.678,0.139,0.132,3.383,5.146,0.001,0.0,0.241,0.461,0.697,0.894
3,0.807,1.319,0.157,0.17,5.154,7.779,0.0,0.0,0.549,1.04,1.064,1.598
4,0.995,1.576,0.183,0.19,5.452,8.29,0.0,0.0,0.695,1.263,1.295,1.889
5,0.922,1.731,0.176,0.179,5.251,9.694,0.0,0.0,0.633,1.437,1.211,2.025
6,1.223,2.311,0.093,0.098,13.085,23.486,0.0,0.0,1.069,2.149,1.377,2.473
7,1.312,2.703,0.152,0.16,8.657,16.933,0.0,0.0,1.063,2.44,1.561,2.966
8,1.491,2.901,0.159,0.173,9.366,16.798,0.0,0.0,1.229,2.617,1.753,3.186
9,1.879,3.373,0.225,0.236,8.351,14.291,0.0,0.0,1.509,2.984,2.249,3.761


In [52]:
est.const_marginal_effect_inference(X_test).population_summary()

<econml.inference.PopulationSummaryResults at 0x1f92e5e4f28>

In [53]:
# effect
est.effect(X_test)

array([0.1885573 , 0.38461869, 0.46915619, 0.80683188, 0.99504813,
       0.92202081, 1.22284881, 1.31207556, 1.49136171, 1.87912734])

In [54]:
est.effect_interval(X_test)

(array([-0.04827697,  0.07588062,  0.24107645,  0.54933769,  0.69481886,
         0.63320734,  1.06913069,  1.06278466,  1.22943722,  1.50900641]),
 array([0.42539156, 0.69335675, 0.69723594, 1.06432608, 1.29527741,
        1.21083428, 1.37656693, 1.56136646, 1.75328621, 2.24924827]))

In [65]:
est.effect_inference(X_test).summary_frame()

Unnamed: 0,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,0.189,0.144,1.31,0.19,-0.048,0.425
1,0.385,0.188,2.049,0.04,0.076,0.693
2,0.469,0.139,3.383,0.001,0.241,0.697
3,0.807,0.157,5.154,0.0,0.549,1.064
4,0.995,0.183,5.452,0.0,0.695,1.295
5,0.922,0.176,5.251,0.0,0.633,1.211
6,1.223,0.093,13.085,0.0,1.069,1.377
7,1.312,0.152,8.657,0.0,1.063,1.561
8,1.491,0.159,9.366,0.0,1.229,1.753
9,1.879,0.225,8.351,0.0,1.509,2.249


In [66]:
est.const_marginal_effect_inference(X_test).population_summary().print()

[<class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>]


0,1,2
,mean_point,mean_point
,T0,T1
Y0,0.967,1.7

0,1,2,3,4,5,6,7,8,9,10
,stderr_mean,stderr_mean,zstat,zstat,pvalue,pvalue,ci_mean_lower,ci_mean_lower,ci_mean_upper,ci_mean_upper
,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1
Y0,0.165,0.171,5.869,9.948,0.0,0.0,0.696,1.419,1.238,1.981

0,1,2,3,4,5,6
,std_point,std_point,pct_point_lower,pct_point_lower,pct_point_upper,pct_point_upper
,T0,T1,T0,T1,T0,T1
Y0,0.501,1.064,0.277,0.193,1.705,3.161

0,1,2,3,4,5,6
,std_point,std_point,ci_point_lower,ci_point_lower,ci_point_upper,ci_point_upper
,T0,T1,T0,T1,T0,T1
Y0,0.528,1.078,0.149,0.07,1.886,3.376


### 5. multi Y, continuous single T

In [68]:
n = 1000
n_w = 30
support_size = 5
n_x = 1
# Outcome support
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=(support_size,3))
def epsilon_sample(n): return np.random.uniform(-1, 1, size=(n,3))
# Treatment support
support_T = support_Y
coefs_T = np.random.uniform(0, 1, size=support_size)
def eta_sample(n): return np.random.uniform(-1, 1, size=n)
# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effects
TE = np.array([np.exp(2 * x_i) for x_i in X]).flatten()
T = np.dot(W[:, support_T], coefs_T) + eta_sample(n)
Y = (TE * T).reshape(-1,1) + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)
X_test = np.array(list(product(np.arange(0, 1, 0.1), repeat=n_x)))

In [69]:
Y.shape,T.shape,X.shape,W.shape,X_test.shape

((1000, 3), (1000,), (1000, 1), (1000, 30), (10, 1))

In [70]:
est = LinearDMLCateEstimator(model_y=RandomForestRegressor(),
                             model_t=RandomForestRegressor(),
                             random_state=123)
est.fit(Y, T, X, W,inference='statsmodels')

<econml.dml.LinearDMLCateEstimator at 0x1f92ee8a400>

In [71]:
"""
##sparse linear
est = SparseLinearDMLCateEstimator(model_y=RandomForestRegressor(),
                              model_t=RandomForestRegressor(),
                              featurizer=PolynomialFeatures(degree=3),
                              random_state=123)
est.fit(Y, T, X, W,inference='debiasedlasso')
"""

"\n##sparse linear\nest = SparseLinearDMLCateEstimator(model_y=RandomForestRegressor(),\n                              model_t=RandomForestRegressor(),\n                              featurizer=PolynomialFeatures(degree=3),\n                              random_state=123)\nest.fit(Y, T, X, W,inference='debiasedlasso')\n"

In [72]:
"""
##forest dml
est = ForestDMLCateEstimator(model_y=RandomForestRegressor(),
                              model_t=RandomForestRegressor())
est.fit(Y, T, X, W,inference='blb')
"""

"\n##forest dml\nest = ForestDMLCateEstimator(model_y=RandomForestRegressor(),\n                              model_t=RandomForestRegressor())\nest.fit(Y, T, X, W,inference='blb')\n"

In [73]:
est.const_marginal_effect(X_test)

array([[0.21790203, 0.59645789, 0.32154596],
       [0.75333562, 1.1113758 , 0.86153439],
       [1.28876921, 1.62629371, 1.40152282],
       [1.8242028 , 2.14121162, 1.94151124],
       [2.35963639, 2.65612953, 2.48149967],
       [2.89506998, 3.17104745, 3.0214881 ],
       [3.43050358, 3.68596536, 3.56147653],
       [3.96593717, 4.20088327, 4.10146496],
       [4.50137076, 4.71580118, 4.64145338],
       [5.03680435, 5.23071909, 5.18144181]])

In [74]:
est.const_marginal_effect_interval(X_test)

(array([[-0.02271208,  0.31594049,  0.05978908],
        [ 0.55970617,  0.88441206,  0.64809086],
        [ 1.13599985,  1.44637582,  1.23012231],
        [ 1.69997122,  1.9953907 ,  1.8001725 ],
        [ 2.24225926,  2.52125024,  2.34976635],
        [ 2.75953329,  3.01887488,  2.87482791],
        [ 3.2595887 ,  3.4958182 ,  3.38135637],
        [ 3.75075703,  3.96173733,  3.87750994],
        [ 4.23747308,  4.42209894,  4.36823746],
        [ 4.72179582,  4.87948325,  4.85599358]]),
 array([[0.45851613, 0.87697529, 0.58330284],
        [0.94696506, 1.33833954, 1.07497792],
        [1.44153856, 1.8062116 , 1.57292332],
        [1.94843438, 2.28703255, 2.08284999],
        [2.47701353, 2.79100883, 2.61323299],
        [3.03060668, 3.32322001, 3.16814829],
        [3.60141845, 3.87611252, 3.74159668],
        [4.1811173 , 4.44002921, 4.32541997],
        [4.76526844, 5.00950342, 4.91466931],
        [5.35181288, 5.58195493, 5.50689004]]))

In [78]:
est.const_marginal_effect_inference(X_test).summary_frame()

Unnamed: 0,Unnamed: 1,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,Y0,0.218,0.146,1.49,0.136,-0.023,0.459
0,Y1,0.596,0.171,3.497,0.0,0.316,0.877
0,Y2,0.322,0.159,2.021,0.043,0.06,0.583
1,Y0,0.753,0.118,6.399,0.0,0.56,0.947
1,Y1,1.111,0.138,8.054,0.0,0.884,1.338
1,Y2,0.862,0.13,6.639,0.0,0.648,1.075
2,Y0,1.289,0.093,13.876,0.0,1.136,1.442
2,Y1,1.626,0.109,14.868,0.0,1.446,1.806
2,Y2,1.402,0.104,13.45,0.0,1.23,1.573
3,Y0,1.824,0.076,24.153,0.0,1.7,1.948


In [71]:
# effect
est.effect(X_test)

array([[0.21790203, 0.59645789, 0.32154596],
       [0.75333562, 1.1113758 , 0.86153439],
       [1.28876921, 1.62629371, 1.40152282],
       [1.8242028 , 2.14121162, 1.94151124],
       [2.35963639, 2.65612953, 2.48149967],
       [2.89506998, 3.17104745, 3.0214881 ],
       [3.43050358, 3.68596536, 3.56147653],
       [3.96593717, 4.20088327, 4.10146496],
       [4.50137076, 4.71580118, 4.64145338],
       [5.03680435, 5.23071909, 5.18144181]])

In [72]:
est.effect_interval(X_test)

(array([[-0.02271208,  0.31594049,  0.05978908],
        [ 0.55970617,  0.88441206,  0.64809086],
        [ 1.13599985,  1.44637582,  1.23012231],
        [ 1.69997122,  1.9953907 ,  1.8001725 ],
        [ 2.24225926,  2.52125024,  2.34976635],
        [ 2.75953329,  3.01887488,  2.87482791],
        [ 3.2595887 ,  3.4958182 ,  3.38135637],
        [ 3.75075703,  3.96173733,  3.87750994],
        [ 4.23747308,  4.42209894,  4.36823746],
        [ 4.72179582,  4.87948325,  4.85599358]]),
 array([[0.45851613, 0.87697529, 0.58330284],
        [0.94696506, 1.33833954, 1.07497792],
        [1.44153856, 1.8062116 , 1.57292332],
        [1.94843438, 2.28703255, 2.08284999],
        [2.47701353, 2.79100883, 2.61323299],
        [3.03060668, 3.32322001, 3.16814829],
        [3.60141845, 3.87611252, 3.74159668],
        [4.1811173 , 4.44002921, 4.32541997],
        [4.76526844, 5.00950342, 4.91466931],
        [5.35181288, 5.58195493, 5.50689004]]))

In [74]:
est.effect_inference(X_test).summary_frame()

Unnamed: 0,Unnamed: 1,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,Y0,0.218,0.146,1.49,0.136,-0.023,0.459
0,Y1,0.596,0.171,3.497,0.0,0.316,0.877
0,Y2,0.322,0.159,2.021,0.043,0.06,0.583
1,Y0,0.753,0.118,6.399,0.0,0.56,0.947
1,Y1,1.111,0.138,8.054,0.0,0.884,1.338
1,Y2,0.862,0.13,6.639,0.0,0.648,1.075
2,Y0,1.289,0.093,13.876,0.0,1.136,1.442
2,Y1,1.626,0.109,14.868,0.0,1.446,1.806
2,Y2,1.402,0.104,13.45,0.0,1.23,1.573
3,Y0,1.824,0.076,24.153,0.0,1.7,1.948


In [77]:
est.const_marginal_effect_inference(X_test).population_summary().print()

[<class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>]


0,1
,mean_point
,T0
Y0,2.627
Y1,2.914
Y2,2.751

0,1,2,3,4,5
,stderr_mean,zstat,pvalue,ci_mean_lower,ci_mean_upper
,T0,T0,T0,T0,T0
Y0,0.123,21.325,0.0,2.425,2.83
Y1,0.14,20.853,0.0,2.684,3.143
Y2,0.131,20.978,0.0,2.536,2.967

0,1,2,3
,std_point,pct_point_lower,pct_point_upper
,T0,T0,T0
Y0,1.538,0.459,4.796
Y1,1.479,0.828,4.999
Y2,1.551,0.565,4.938

0,1,2,3
,std_point,ci_point_lower,ci_point_upper
,T0,T0,T0
Y0,1.543,0.216,5.041
Y1,1.486,0.594,5.228
Y2,1.557,0.32,5.185


### 6. multi Y, continuous multi T

In [79]:
n = 1000
n_w = 30
support_size = 5
n_x = 1
n_t=2
# Outcome support
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=(support_size,3))
def epsilon_sample(n): return np.random.uniform(-1, 1, size=(n,3))
# Treatment support
support_T = support_Y
coefs_T = np.random.uniform(0, 1, size=(support_size,n_t))
def eta_sample(n): return np.random.uniform(-1, 1, size=(n,n_t))
# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effects
TE1 = np.array([np.exp(2 * x_i) for x_i in X]).flatten()
TE2 = np.array([2*x_i for x_i in X]).flatten()
T = np.dot(W[:, support_T], coefs_T) + eta_sample(n)
Y = (TE1 * T[:,0]).reshape(-1,1)+(TE2 * T[:,1]).reshape(-1,1) + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)
X_test = np.array(list(product(np.arange(0, 1, 0.1), repeat=n_x)))

In [80]:
Y.shape,T.shape,X.shape,W.shape,X_test.shape

((1000, 3), (1000, 2), (1000, 1), (1000, 30), (10, 1))

In [81]:
est = LinearDMLCateEstimator(model_y=RandomForestRegressor(),
                             model_t=RandomForestRegressor(),
                             random_state=123)
est.fit(Y, T, X, W,inference='statsmodels')

<econml.dml.LinearDMLCateEstimator at 0x1f92eed66a0>

In [82]:
"""
##sparse linear
est = SparseLinearDMLCateEstimator(model_y=RandomForestRegressor(),
                              model_t=RandomForestRegressor(),
                              featurizer=PolynomialFeatures(degree=3),
                              random_state=123)
est.fit(Y, T, X, W,inference='debiasedlasso')
"""

"\n##sparse linear\nest = SparseLinearDMLCateEstimator(model_y=RandomForestRegressor(),\n                              model_t=RandomForestRegressor(),\n                              featurizer=PolynomialFeatures(degree=3),\n                              random_state=123)\nest.fit(Y, T, X, W,inference='debiasedlasso')\n"

In [83]:
est.const_marginal_effect(X_test)

array([[[0.12803405, 0.11935974],
        [0.37315658, 0.50213732],
        [0.24461697, 0.25465818]],

       [[0.65871914, 0.29671327],
        [0.90320195, 0.66217789],
        [0.77131113, 0.42520059]],

       [[1.18940423, 0.4740668 ],
        [1.43324731, 0.82221846],
        [1.29800529, 0.595743  ]],

       [[1.72008931, 0.65142032],
        [1.96329267, 0.98225903],
        [1.82469945, 0.76628541]],

       [[2.2507744 , 0.82877385],
        [2.49333803, 1.1422996 ],
        [2.3513936 , 0.93682781]],

       [[2.78145949, 1.00612737],
        [3.02338339, 1.30234017],
        [2.87808776, 1.10737022]],

       [[3.31214457, 1.1834809 ],
        [3.55342875, 1.46238074],
        [3.40478192, 1.27791263]],

       [[3.84282966, 1.36083442],
        [4.08347411, 1.62242131],
        [3.93147608, 1.44845504]],

       [[4.37351474, 1.53818795],
        [4.61351947, 1.78246187],
        [4.45817024, 1.61899745]],

       [[4.90419983, 1.71554148],
        [5.14356483, 1.9425024

In [84]:
est.const_marginal_effect_interval(X_test)

(array([[[-0.12855737, -0.12086662],
         [ 0.08532946,  0.23781172],
         [-0.02386447, -0.00534645]],
 
        [[ 0.45557539,  0.10674171],
         [ 0.67225497,  0.45081239],
         [ 0.55701256,  0.21750939]],
 
        [[ 1.02838101,  0.32423474],
         [ 1.24715039,  0.65288374],
         [ 1.12612427,  0.43014961]],
 
        [[ 1.5792971 ,  0.52184384],
         [ 1.7997961 ,  0.83439206],
         [ 1.67324021,  0.62325041]],
 
        [[ 2.09928417,  0.69053331],
         [ 2.3211936 ,  0.98658655],
         [ 2.18979432,  0.78766631]],
 
        [[ 2.59355258,  0.83463162],
         [ 2.81520129,  1.11307848],
         [ 2.68043562,  0.9262857 ]],
 
        [[ 3.07360402,  0.96509796],
         [ 3.29294961,  1.22450484],
         [ 3.15619386,  1.04968827]],
 
        [[ 3.54664212,  1.08889439],
         [ 3.76228494,  1.32824208],
         [ 3.62438768,  1.16537657]],
 
        [[ 4.01604391,  1.20926299],
         [ 4.22715294,  1.42793485],
         [ 4.0

In [85]:
est.const_marginal_effect_inference(X_test).summary_frame()

Unnamed: 0_level_0,Unnamed: 1_level_0,point_estimate,point_estimate,stderr,stderr,zstat,zstat,pvalue,pvalue,ci_lower,ci_lower,ci_upper,ci_upper
Unnamed: 0_level_1,Unnamed: 1_level_1,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1
0,Y0,0.128,0.119,0.156,0.146,0.821,0.817,0.412,0.414,-0.129,-0.121,0.385,0.36
0,Y1,0.373,0.502,0.175,0.161,2.132,3.125,0.033,0.002,0.085,0.238,0.661,0.766
0,Y2,0.245,0.255,0.163,0.158,1.499,1.611,0.134,0.107,-0.024,-0.005,0.513,0.515
1,Y0,0.659,0.297,0.124,0.115,5.334,2.569,0.0,0.01,0.456,0.107,0.862,0.487
1,Y1,0.903,0.662,0.14,0.129,6.433,5.153,0.0,0.0,0.672,0.451,1.134,0.874
1,Y2,0.771,0.425,0.13,0.126,5.92,3.367,0.0,0.001,0.557,0.218,0.986,0.633
2,Y0,1.189,0.474,0.098,0.091,12.15,5.204,0.0,0.0,1.028,0.324,1.35,0.624
2,Y1,1.433,0.822,0.113,0.103,12.668,7.987,0.0,0.0,1.247,0.653,1.619,0.992
2,Y2,1.298,0.596,0.104,0.101,12.422,5.918,0.0,0.0,1.126,0.43,1.47,0.761
3,Y0,1.72,0.651,0.086,0.079,20.096,8.269,0.0,0.0,1.579,0.522,1.861,0.781


In [86]:
# effect
est.effect(X_test)

  warn("A scalar was specified but there are multiple treatments; "
  warn("A scalar was specified but there are multiple treatments; "


array([[0.2473938 , 0.87529391, 0.49927515],
       [0.95543241, 1.56537984, 1.19651172],
       [1.66347102, 2.25546576, 1.89374829],
       [2.37150964, 2.94555169, 2.59098485],
       [3.07954825, 3.63563762, 3.28822142],
       [3.78758686, 4.32572355, 3.98545799],
       [4.49562547, 5.01580948, 4.68269455],
       [5.20366408, 5.70589541, 5.37993112],
       [5.9117027 , 6.39598134, 6.07716768],
       [6.61974131, 7.08606727, 6.77440425]])

In [87]:
est.effect_interval(X_test)

  warn("A scalar was specified but there are multiple treatments; "
  warn("A scalar was specified but there are multiple treatments; "


(array([[-0.04829345,  0.54733268,  0.18342229],
        [ 0.7224818 ,  1.30438655,  0.94512885],
        [ 1.48349883,  2.05071212,  1.69696718],
        [ 2.22385136,  2.77532173,  2.42865371],
        [ 2.92953188,  3.46452345,  3.12673108],
        [ 3.60185634,  4.11877064,  3.79076336],
        [ 4.25525969,  4.75194098,  4.43403413],
        [ 4.8997816 ,  5.37472885,  5.06710985],
        [ 5.53994688,  5.9922938 ,  5.69509253],
        [ 6.17775807,  6.60700193,  6.32030642]]),
 array([[0.54308105, 1.20325513, 0.81512802],
        [1.18838302, 1.82637312, 1.4478946 ],
        [1.84344322, 2.4602194 , 2.0905294 ],
        [2.51916791, 3.11578166, 2.75331599],
        [3.22956462, 3.80675179, 3.44971176],
        [3.97331738, 4.53267647, 4.18015261],
        [4.73599125, 5.27967799, 4.93135497],
        [5.50754657, 6.03706198, 5.69275238],
        [6.28345851, 6.79966889, 6.45924284],
        [7.06172454, 7.56513262, 7.22850208]]))

In [88]:
est.effect_inference(X_test).summary_frame()

  warn("A scalar was specified but there are multiple treatments; "
  warn("A scalar was specified but there are multiple treatments; "


Unnamed: 0,Unnamed: 1,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,Y0,0.247,0.18,1.376,0.169,-0.048,0.543
0,Y1,0.875,0.199,4.39,0.0,0.547,1.203
0,Y2,0.499,0.192,2.6,0.009,0.183,0.815
1,Y0,0.955,0.142,6.746,0.0,0.722,1.188
1,Y1,1.565,0.159,9.865,0.0,1.304,1.826
1,Y2,1.197,0.153,7.829,0.0,0.945,1.448
2,Y0,1.663,0.109,15.203,0.0,1.483,1.843
2,Y1,2.255,0.124,18.119,0.0,2.051,2.46
2,Y2,1.894,0.12,15.829,0.0,1.697,2.091
3,Y0,2.372,0.09,26.418,0.0,2.224,2.519


In [89]:
est.const_marginal_effect_inference(X_test).population_summary().print()

[<class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>]


0,1,2
,mean_point,mean_point
,T0,T1
Y0,2.516,0.917
Y1,2.758,1.222
Y2,2.615,1.022

0,1,2,3,4,5,6,7,8,9,10
,stderr_mean,stderr_mean,zstat,zstat,pvalue,pvalue,ci_mean_lower,ci_mean_lower,ci_mean_upper,ci_mean_upper
,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1
Y0,0.156,0.144,16.102,6.367,0.0,0.0,2.259,0.68,2.773,1.154
Y1,0.172,0.157,16.071,7.779,0.0,0.0,2.476,0.964,3.041,1.481
Y2,0.163,0.152,16.062,6.723,0.0,0.0,2.347,0.772,2.883,1.272

0,1,2,3,4,5,6
,std_point,std_point,pct_point_lower,pct_point_lower,pct_point_upper,pct_point_upper
,T0,T1,T0,T1,T0,T1
Y0,1.524,0.509,0.367,0.199,4.665,1.636
Y1,1.522,0.46,0.612,0.574,4.905,1.87
Y2,1.513,0.49,0.482,0.331,4.748,1.713

0,1,2,3,4,5,6
,std_point,std_point,ci_point_lower,ci_point_lower,ci_point_upper,ci_point_upper
,T0,T1,T0,T1,T0,T1
Y0,1.532,0.529,0.126,0.102,4.905,1.785
Y1,1.532,0.486,0.375,0.474,5.149,2.031
Y2,1.522,0.513,0.246,0.23,4.993,1.867


### 7. multi Y, discrete binary T

In [90]:
# Treatment effect function
def exp_te(x):
    return np.exp(2 * x[0])# DGP constants

np.random.seed(1234)
n = 1000
n_w = 30
support_size = 5
n_x = 4
# Outcome support
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=(support_size,3))
epsilon_sample = lambda n:np.random.uniform(-1, 1, size=(n,3))
# Treatment support
support_T = support_Y
coefs_T = np.random.uniform(0, 1, size=support_size)
eta_sample = lambda n: np.random.uniform(-1, 1, size=n) 

# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effects
TE = np.array([exp_te(x_i) for x_i in X])
# Define treatment
log_odds = np.dot(W[:, support_T], coefs_T) + eta_sample(n)
T_sigmoid = 1/(1 + np.exp(-log_odds))
T = np.array([np.random.binomial(1, p) for p in T_sigmoid])
# Define the outcome
Y = (TE * T).reshape(-1,1) + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)

# get testing data
X_test = np.random.uniform(0, 1, size=(10, n_x))
X_test[:, 0] = np.linspace(0, 1, 10)

In [91]:
Y.shape,T.shape,X.shape,W.shape,X_test.shape

((1000, 3), (1000,), (1000, 4), (1000, 30), (10, 4))

In [92]:
est = LinearDMLCateEstimator(model_y=RandomForestRegressor(),
                             model_t=RandomForestClassifier(),
                             random_state=123,
                            discrete_treatment=True)
est.fit(Y, T, X, W,inference='statsmodels')
te_pred = est.effect(X_test)

In [93]:
"""
##sparse linear
est = SparseLinearDMLCateEstimator(model_y=RandomForestRegressor(),
                              model_t=RandomForestClassifier(),
                              featurizer=PolynomialFeatures(degree=3),
                              random_state=123,discrete_treatment=True)
est.fit(Y, T, X, W,inference='debiasedlasso')
"""

"\n##sparse linear\nest = SparseLinearDMLCateEstimator(model_y=RandomForestRegressor(),\n                              model_t=RandomForestClassifier(),\n                              featurizer=PolynomialFeatures(degree=3),\n                              random_state=123,discrete_treatment=True)\nest.fit(Y, T, X, W,inference='debiasedlasso')\n"

In [94]:
"""
##forest dml
est = ForestDMLCateEstimator(model_y=RandomForestRegressor(),
                              model_t=RandomForestClassifier(min_samples_leaf=10),
                              discrete_treatment=True)
est.fit(Y, T, X, W,inference='blb')
"""

"\n##forest dml\nest = ForestDMLCateEstimator(model_y=RandomForestRegressor(),\n                              model_t=RandomForestClassifier(min_samples_leaf=10),\n                              discrete_treatment=True)\nest.fit(Y, T, X, W,inference='blb')\n"

In [95]:
est.const_marginal_effect(X_test)

array([[[-0.28999832],
        [-0.13850692],
        [-0.23044226]],

       [[ 0.84285901],
        [ 1.24121914],
        [ 1.47950661]],

       [[ 1.14456611],
        [ 1.31521787],
        [ 1.31855025]],

       [[ 1.86434513],
        [ 2.08942017],
        [ 2.21609769]],

       [[ 2.5143868 ],
        [ 2.77575318],
        [ 2.93662437]],

       [[ 2.83872925],
        [ 3.12076148],
        [ 3.27235528]],

       [[ 3.61221935],
        [ 3.88602137],
        [ 4.10374909]],

       [[ 4.06851269],
        [ 4.39414637],
        [ 4.642164  ]],

       [[ 5.07501697],
        [ 5.44058721],
        [ 5.85298553]],

       [[ 5.09942359],
        [ 5.29113028],
        [ 5.49780606]]])

In [96]:
est.const_marginal_effect_interval(X_test)

(array([[[-0.65501037],
         [-0.5588186 ],
         [-0.61677974]],
 
        [[ 0.51984522],
         [ 0.86229739],
         [ 1.13310276]],
 
        [[ 0.88084829],
         [ 1.0086018 ],
         [ 1.02276538]],
 
        [[ 1.70719024],
         [ 1.89339029],
         [ 2.02975295]],
 
        [[ 2.26917184],
         [ 2.4884486 ],
         [ 2.66529647]],
 
        [[ 2.55972978],
         [ 2.80756335],
         [ 2.98769548]],
 
        [[ 3.41220706],
         [ 3.66225648],
         [ 3.89332092]],
 
        [[ 3.70993005],
         [ 3.99580542],
         [ 4.27318804]],
 
        [[ 4.70500075],
         [ 5.01347049],
         [ 5.44916974]],
 
        [[ 4.66546071],
         [ 4.79956976],
         [ 5.04771866]]]), array([[[0.07501374],
         [0.28180475],
         [0.15589522]],
 
        [[1.1658728 ],
         [1.62014089],
         [1.82591046]],
 
        [[1.40828393],
         [1.62183393],
         [1.61433513]],
 
        [[2.02150001],
         [2.

In [97]:
est.const_marginal_effect_inference(X_test).summary_frame()

Unnamed: 0,Unnamed: 1,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,Y0,-0.29,0.222,-1.307,0.191,-0.655,0.075
0,Y1,-0.139,0.256,-0.542,0.588,-0.559,0.282
0,Y2,-0.23,0.235,-0.981,0.327,-0.617,0.156
1,Y0,0.843,0.196,4.292,0.0,0.52,1.166
1,Y1,1.241,0.23,5.388,0.0,0.862,1.62
1,Y2,1.48,0.211,7.025,0.0,1.133,1.826
2,Y0,1.145,0.16,7.139,0.0,0.881,1.408
2,Y1,1.315,0.186,7.056,0.0,1.009,1.622
2,Y2,1.319,0.18,7.332,0.0,1.023,1.614
3,Y0,1.864,0.096,19.513,0.0,1.707,2.022


In [99]:
# effect
est.effect(X_test)

array([[-0.28999832, -0.13850692, -0.23044226],
       [ 0.84285901,  1.24121914,  1.47950661],
       [ 1.14456611,  1.31521787,  1.31855025],
       [ 1.86434513,  2.08942017,  2.21609769],
       [ 2.5143868 ,  2.77575318,  2.93662437],
       [ 2.83872925,  3.12076148,  3.27235528],
       [ 3.61221935,  3.88602137,  4.10374909],
       [ 4.06851269,  4.39414637,  4.642164  ],
       [ 5.07501697,  5.44058721,  5.85298553],
       [ 5.09942359,  5.29113028,  5.49780606]])

In [100]:
est.effect_interval(X_test)

(array([[-0.65501037, -0.5588186 , -0.61677974],
        [ 0.51984522,  0.86229739,  1.13310276],
        [ 0.88084829,  1.0086018 ,  1.02276538],
        [ 1.70719024,  1.89339029,  2.02975295],
        [ 2.26917184,  2.4884486 ,  2.66529647],
        [ 2.55972978,  2.80756335,  2.98769548],
        [ 3.41220706,  3.66225648,  3.89332092],
        [ 3.70993005,  3.99580542,  4.27318804],
        [ 4.70500075,  5.01347049,  5.44916974],
        [ 4.66546071,  4.79956976,  5.04771866]]),
 array([[0.07501374, 0.28180475, 0.15589522],
        [1.1658728 , 1.62014089, 1.82591046],
        [1.40828393, 1.62183393, 1.61433513],
        [2.02150001, 2.28545005, 2.40244243],
        [2.75960176, 3.06305777, 3.20795227],
        [3.11772873, 3.43395962, 3.55701509],
        [3.81223165, 4.10978625, 4.31417726],
        [4.42709532, 4.79248732, 5.01113996],
        [5.44503319, 5.86770394, 6.25680132],
        [5.53338648, 5.7826908 , 5.94789345]]))

In [101]:
est.effect_inference(X_test).summary_frame()

Unnamed: 0,Unnamed: 1,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,Y0,-0.29,0.222,-1.307,0.191,-0.655,0.075
0,Y1,-0.139,0.256,-0.542,0.588,-0.559,0.282
0,Y2,-0.23,0.235,-0.981,0.327,-0.617,0.156
1,Y0,0.843,0.196,4.292,0.0,0.52,1.166
1,Y1,1.241,0.23,5.388,0.0,0.862,1.62
1,Y2,1.48,0.211,7.025,0.0,1.133,1.826
2,Y0,1.145,0.16,7.139,0.0,0.881,1.408
2,Y1,1.315,0.186,7.056,0.0,1.009,1.622
2,Y2,1.319,0.18,7.332,0.0,1.023,1.614
3,Y0,1.864,0.096,19.513,0.0,1.707,2.022


In [102]:
est.const_marginal_effect_inference(X_test).population_summary().print()

[<class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>]


0,1
,mean_point
,T0
Y0,2.677
Y1,2.942
Y2,3.109

0,1,2,3,4,5
,stderr_mean,zstat,pvalue,ci_mean_lower,ci_mean_upper
,T0,T0,T0,T0,T0
Y0,0.189,14.185,0.0,2.367,2.987
Y1,0.216,13.598,0.0,2.586,3.297
Y2,0.201,15.474,0.0,2.778,3.439

0,1,2,3
,std_point,pct_point_lower,pct_point_upper
,T0,T0,T0
Y0,1.722,0.22,5.088
Y1,1.749,0.482,5.373
Y2,1.856,0.467,5.693

0,1,2,3
,std_point,ci_point_lower,ci_point_upper
,T0,T0,T0
Y0,1.732,-0.289,5.254
Y1,1.762,-0.134,5.565
Y2,1.867,-0.226,5.896


### 8. multi Y, discrete multi T

In [103]:
np.random.seed(1234)
n = 1000
n_w = 30
support_size = 5
n_x = 4
# Outcome support
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=(support_size,3))
epsilon_sample = lambda n:np.random.uniform(-1, 1, size=(n,3))

# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effects
TE = np.array([2*x_i[0] for x_i in X])
# Define treatment
T =np.random.choice(3, n)
# Define the outcome
Y =(TE*T).reshape(-1,1) + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)

# get testing data
X_test = np.random.uniform(0, 1, size=(10, n_x))
X_test[:, 0] = np.linspace(0, 1, 10)

In [104]:
est = LinearDMLCateEstimator(model_y=RandomForestRegressor(),
                             model_t=RandomForestClassifier(),
                             random_state=123,
                            discrete_treatment=True)
est.fit(Y, T, X, W,inference='statsmodels')

<econml.dml.LinearDMLCateEstimator at 0x1f92ef1aa20>

In [105]:
"""
##sparse linear
est = SparseLinearDMLCateEstimator(model_y=RandomForestRegressor(),
                              model_t=RandomForestClassifier(),
                              featurizer=PolynomialFeatures(degree=3),
                              random_state=123,discrete_treatment=True)
est.fit(Y, T, X, W,inference='debiasedlasso')
"""

"\n##sparse linear\nest = SparseLinearDMLCateEstimator(model_y=RandomForestRegressor(),\n                              model_t=RandomForestClassifier(),\n                              featurizer=PolynomialFeatures(degree=3),\n                              random_state=123,discrete_treatment=True)\nest.fit(Y, T, X, W,inference='debiasedlasso')\n"

In [106]:
est.const_marginal_effect_inference(X_test).population_summary().print()

[<class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>]


0,1,2
,mean_point,mean_point
,T0,T1
Y0,0.872,1.701
Y1,0.973,1.752
Y2,0.983,1.818

0,1,2,3,4,5,6,7,8,9,10
,stderr_mean,stderr_mean,zstat,zstat,pvalue,pvalue,ci_mean_lower,ci_mean_lower,ci_mean_upper,ci_mean_upper
,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1
Y0,0.15,0.161,5.801,10.542,0.0,0.0,0.625,1.436,1.119,1.967
Y1,0.201,0.207,4.85,8.477,0.0,0.0,0.643,1.412,1.303,2.092
Y2,0.206,0.217,4.777,8.378,0.0,0.0,0.644,1.461,1.321,2.175

0,1,2,3,4,5,6
,std_point,std_point,pct_point_lower,pct_point_lower,pct_point_upper,pct_point_upper
,T0,T1,T0,T1,T0,T1
Y0,0.491,1.055,0.231,0.171,1.611,3.13
Y1,0.459,1.117,0.394,0.17,1.647,3.352
Y2,0.639,1.003,0.143,0.424,1.948,3.207

0,1,2,3,4,5,6
,std_point,std_point,ci_point_lower,ci_point_lower,ci_point_upper,ci_point_upper
,T0,T1,T0,T1,T0,T1
Y0,0.514,1.067,0.093,0.006,1.714,3.322
Y1,0.501,1.136,0.199,-0.024,1.781,3.577
Y2,0.671,1.026,-0.039,0.194,2.085,3.369


In [107]:
est.const_marginal_effect(X_test)

array([[[0.13077804, 0.02332393],
        [0.25547372, 0.00938365],
        [0.02415333, 0.2152192 ]],

       [[0.3541333 , 0.3519374 ],
        [0.56279053, 0.36521078],
        [0.28732139, 0.67938142]],

       [[0.52149326, 0.72967782],
        [0.65649485, 0.84908201],
        [0.53406124, 0.91182323]],

       [[0.50094157, 1.2731858 ],
        [0.57551043, 1.19889876],
        [0.52189689, 1.22348235]],

       [[0.64698214, 1.63190329],
        [0.71810747, 1.47267158],
        [0.69027919, 1.65070377]],

       [[1.04337743, 1.81531876],
        [1.14435473, 2.02001235],
        [1.21221194, 1.96034615]],

       [[1.08505939, 2.29733447],
        [1.27276342, 2.23438365],
        [1.22465871, 2.49918927]],

       [[1.22819938, 2.67202881],
        [1.25252386, 2.71712541],
        [1.45758627, 2.64670391]],

       [[1.54423043, 2.89559918],
        [1.62496971, 3.08528142],
        [1.84215341, 3.08587816]],

       [[1.66543734, 3.32143631],
        [1.66536014, 3.5711236

In [108]:
est.const_marginal_effect_interval(X_test)

(array([[[-0.13762973, -0.25261339],
         [-0.1211709 , -0.35776153],
         [-0.35187154, -0.14816464]],
 
        [[ 0.03499393,  0.02493207],
         [ 0.13048679, -0.05370933],
         [-0.14722734,  0.25960719]],
 
        [[ 0.31212803,  0.53024587],
         [ 0.36354088,  0.57914919],
         [ 0.24887897,  0.64992537]],
 
        [[ 0.33415382,  1.07060752],
         [ 0.34827067,  0.95237471],
         [ 0.29313873,  0.96294004]],
 
        [[ 0.41405034,  1.36756418],
         [ 0.39459682,  1.13168459],
         [ 0.35499746,  1.30095382]],
 
        [[ 0.87127566,  1.63654286],
         [ 0.92258963,  1.79170129],
         [ 0.98522643,  1.71194206]],
 
        [[ 0.78759008,  1.99041965],
         [ 0.87602695,  1.83107936],
         [ 0.79987841,  2.0689614 ]],
 
        [[ 1.00638181,  2.42827684],
         [ 0.95688237,  2.4066607 ],
         [ 1.15799161,  2.31954015]],
 
        [[ 1.32487798,  2.64530238],
         [ 1.35785542,  2.78181122],
         [ 1.5

In [109]:
est.const_marginal_effect_inference(X_test).summary_frame()

Unnamed: 0_level_0,Unnamed: 1_level_0,point_estimate,point_estimate,stderr,stderr,zstat,zstat,pvalue,pvalue,ci_lower,ci_lower,ci_upper,ci_upper
Unnamed: 0_level_1,Unnamed: 1_level_1,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1
0,Y0,0.131,0.023,0.163,0.168,0.801,0.139,0.423,0.889,-0.138,-0.253,0.399,0.299
0,Y1,0.255,0.009,0.229,0.223,1.116,0.042,0.265,0.966,-0.121,-0.358,0.632,0.377
0,Y2,0.024,0.215,0.229,0.221,0.106,0.974,0.916,0.33,-0.352,-0.148,0.4,0.579
1,Y0,0.354,0.352,0.194,0.199,1.825,1.77,0.068,0.077,0.035,0.025,0.673,0.679
1,Y1,0.563,0.365,0.263,0.255,2.141,1.434,0.032,0.152,0.13,-0.054,0.995,0.784
1,Y2,0.287,0.679,0.264,0.255,1.088,2.662,0.277,0.008,-0.147,0.26,0.722,1.099
2,Y0,0.521,0.73,0.127,0.121,4.097,6.018,0.0,0.0,0.312,0.53,0.731,0.929
2,Y1,0.656,0.849,0.178,0.164,3.686,5.174,0.0,0.0,0.364,0.579,0.949,1.119
2,Y2,0.534,0.912,0.173,0.159,3.08,5.727,0.002,0.0,0.249,0.65,0.819,1.174
3,Y0,0.501,1.273,0.101,0.123,4.94,10.338,0.0,0.0,0.334,1.071,0.668,1.476


In [110]:
# effect
est.effect(X_test)

array([[0.13077804, 0.25547372, 0.02415333],
       [0.3541333 , 0.56279053, 0.28732139],
       [0.52149326, 0.65649485, 0.53406124],
       [0.50094157, 0.57551043, 0.52189689],
       [0.64698214, 0.71810747, 0.69027919],
       [1.04337743, 1.14435473, 1.21221194],
       [1.08505939, 1.27276342, 1.22465871],
       [1.22819938, 1.25252386, 1.45758627],
       [1.54423043, 1.62496971, 1.84215341],
       [1.66543734, 1.66536014, 2.03373201]])

In [111]:
est.effect_interval(X_test)

(array([[-0.13762973, -0.1211709 , -0.35187154],
        [ 0.03499393,  0.13048679, -0.14722734],
        [ 0.31212803,  0.36354088,  0.24887897],
        [ 0.33415382,  0.34827067,  0.29313873],
        [ 0.41405034,  0.39459682,  0.35499746],
        [ 0.87127566,  0.92258963,  0.98522643],
        [ 0.78759008,  0.87602695,  0.79987841],
        [ 1.00638181,  0.95688237,  1.15799161],
        [ 1.32487798,  1.35785542,  1.55478265],
        [ 1.35524696,  1.27482056,  1.6280135 ]]),
 array([[0.3991858 , 0.63211833, 0.4001782 ],
        [0.67327267, 0.99509427, 0.72187011],
        [0.73085849, 0.94944882, 0.81924351],
        [0.66772932, 0.80275019, 0.75065506],
        [0.87991395, 1.04161813, 1.02556092],
        [1.21547921, 1.36611983, 1.43919745],
        [1.38252869, 1.66949988, 1.649439  ],
        [1.45001695, 1.54816535, 1.75718092],
        [1.76358289, 1.89208399, 2.12952417],
        [1.97562771, 2.05589972, 2.43945051]]))

In [112]:
est.effect_inference(X_test).summary_frame()

Unnamed: 0,Unnamed: 1,point_estimate,stderr,zstat,pvalue,ci_lower,ci_upper
0,Y0,0.131,0.163,0.801,0.423,-0.138,0.399
0,Y1,0.255,0.229,1.116,0.265,-0.121,0.632
0,Y2,0.024,0.229,0.106,0.916,-0.352,0.4
1,Y0,0.354,0.194,1.825,0.068,0.035,0.673
1,Y1,0.563,0.263,2.141,0.032,0.13,0.995
1,Y2,0.287,0.264,1.088,0.277,-0.147,0.722
2,Y0,0.521,0.127,4.097,0.0,0.312,0.731
2,Y1,0.656,0.178,3.686,0.0,0.364,0.949
2,Y2,0.534,0.173,3.08,0.002,0.249,0.819
3,Y0,0.501,0.101,4.94,0.0,0.334,0.668


In [113]:
est.const_marginal_effect_inference(X_test).population_summary().print()

[<class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>, <class 'statsmodels.iolib.table.SimpleTable'>]


0,1,2
,mean_point,mean_point
,T0,T1
Y0,0.872,1.701
Y1,0.973,1.752
Y2,0.983,1.818

0,1,2,3,4,5,6,7,8,9,10
,stderr_mean,stderr_mean,zstat,zstat,pvalue,pvalue,ci_mean_lower,ci_mean_lower,ci_mean_upper,ci_mean_upper
,T0,T1,T0,T1,T0,T1,T0,T1,T0,T1
Y0,0.15,0.161,5.801,10.542,0.0,0.0,0.625,1.436,1.119,1.967
Y1,0.201,0.207,4.85,8.477,0.0,0.0,0.643,1.412,1.303,2.092
Y2,0.206,0.217,4.777,8.378,0.0,0.0,0.644,1.461,1.321,2.175

0,1,2,3,4,5,6
,std_point,std_point,pct_point_lower,pct_point_lower,pct_point_upper,pct_point_upper
,T0,T1,T0,T1,T0,T1
Y0,0.491,1.055,0.231,0.171,1.611,3.13
Y1,0.459,1.117,0.394,0.17,1.647,3.352
Y2,0.639,1.003,0.143,0.424,1.948,3.207

0,1,2,3,4,5,6
,std_point,std_point,ci_point_lower,ci_point_lower,ci_point_upper,ci_point_upper
,T0,T1,T0,T1,T0,T1
Y0,0.514,1.067,0.093,0.006,1.714,3.322
Y1,0.501,1.136,0.199,-0.024,1.781,3.577
Y2,0.671,1.026,-0.039,0.194,2.085,3.369
