In [7]:
# A few different package options for double machine learning (DML)
# import doubleml as dml  # DoubleML; the same authors also publish an R package
import econml # Created by Microsoft
from econml.dml import DML, LinearDML

In [2]:
import numpy as np
from itertools import product
from sklearn.linear_model import (Lasso, LassoCV, LogisticRegression,
                                  LogisticRegressionCV,LinearRegression,
                                  MultiTaskElasticNet,MultiTaskElasticNetCV)
from sklearn.ensemble import RandomForestRegressor,RandomForestClassifier
from sklearn.preprocessing import PolynomialFeatures
import matplotlib.pyplot as plt
import matplotlib
from sklearn.model_selection import train_test_split

In [4]:
# Data-generating process (DGP) for a partially linear model:
#   T = W[:, support_T] . coefs_T + eta
#   Y = TE(X) * T + W[:, support_Y] . coefs_Y + epsilon
# Every draw below comes from NumPy's global RNG, so the ORDER of the random
# calls determines the generated data -- do not reorder them.

# DGP constants
np.random.seed(123)
n = 2000          # number of samples
n_w = 30          # number of control columns in W
support_size = 5  # number of control columns that enter T and Y
n_x = 1           # number of covariate columns in X


def exp_te(x):
    """Return the true treatment effect for a single covariate row ``x``.

    NOTE(review): no visible cell of this notebook defines ``exp_te`` (the
    execution counts skip 3, so the defining cell was probably deleted), which
    makes "Restart & Run All" fail with a NameError. The definition below is
    restored from the EconML DML example this DGP follows -- confirm that
    exp(2 * x[0]) is the intended effect function.
    """
    return np.exp(2 * x[0])


def epsilon_sample(n):
    """Draw ``n`` outcome-noise values, uniform on [-1, 1)."""
    return np.random.uniform(-1, 1, size=n)


def eta_sample(n):
    """Draw ``n`` treatment-noise values, uniform on [-1, 1)."""
    return np.random.uniform(-1, 1, size=n)


# Outcome support: which control columns drive Y, and with what weights
support_Y = np.random.choice(np.arange(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=support_size)
# Treatment support: the same control columns drive T (this is the confounding)
support_T = support_Y
coefs_T = np.random.uniform(0, 1, size=support_size)

# Generate controls, covariates, treatments and outcomes
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))
# Heterogeneous treatment effects: one scalar effect per row of X
TE = np.array([exp_te(x_i) for x_i in X])
# Guard against an effect function that returns an array per row: a (n, 1) TE
# would silently broadcast TE * T into an (n, n) matrix.
assert TE.shape == (n,), f"expected one scalar effect per sample, got shape {TE.shape}"
T = np.dot(W[:, support_T], coefs_T) + eta_sample(n)
Y = TE * T + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)
assert Y.shape == (n,)

# 80/20 train/validation split. No random_state is passed, so the shuffle is
# drawn from the global NumPy RNG seeded at the top of this cell.
Y_train, Y_val, T_train, T_val, X_train, X_val, W_train, W_val = train_test_split(Y, T, X, W, test_size=.2)
# Generate test data: a regular grid over [0, 1) with step 0.01 in each X dimension
X_test = np.array(list(product(np.arange(0, 1, 0.01), repeat=n_x)))


In [5]:
# First-stage (nuisance) learners: one forest for the outcome Y and one for the
# treatment T, both built with default hyperparameters.
# NOTE(review): the forests receive no random_state of their own; only LinearDML
# does. Re-running this cell in isolation may therefore give slightly different
# estimates -- confirm whether the nuisance models should be seeded as well.
outcome_model = RandomForestRegressor()
treatment_model = RandomForestRegressor()

est = LinearDML(
    model_y=outcome_model,
    model_t=treatment_model,
    random_state=123,
)
# Fit on the training split: X holds the covariates, W the controls.
est.fit(Y_train, T_train, X=X_train, W=W_train)

# Estimated treatment effect at each point of the X_test grid.
te_pred = est.effect(X_test)


In [6]:
# Display the effects returned by est.effect(X_test), i.e. the estimates on the X_test grid.
te_pred

array([0.275715  , 0.3413093 , 0.40690359, 0.47249788, 0.53809217,
       0.60368647, 0.66928076, 0.73487505, 0.80046934, 0.86606363,
       0.93165793, 0.99725222, 1.06284651, 1.1284408 , 1.1940351 ,
       1.25962939, 1.32522368, 1.39081797, 1.45641226, 1.52200656,
       1.58760085, 1.65319514, 1.71878943, 1.78438372, 1.84997802,
       1.91557231, 1.9811666 , 2.04676089, 2.11235519, 2.17794948,
       2.24354377, 2.30913806, 2.37473235, 2.44032665, 2.50592094,
       2.57151523, 2.63710952, 2.70270382, 2.76829811, 2.8338924 ,
       2.89948669, 2.96508098, 3.03067528, 3.09626957, 3.16186386,
       3.22745815, 3.29305244, 3.35864674, 3.42424103, 3.48983532,
       3.55542961, 3.62102391, 3.6866182 , 3.75221249, 3.81780678,
       3.88340107, 3.94899537, 4.01458966, 4.08018395, 4.14577824,
       4.21137254, 4.27696683, 4.34256112, 4.40815541, 4.4737497 ,
       4.539344  , 4.60493829, 4.67053258, 4.73612687, 4.80172116,
       4.86731546, 4.93290975, 4.99850404, 5.06409833, 5.12969