In [1]:
import numpy as np
import pandas as pd

from sklearn.model_selection import KFold
from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier
from sklearn.linear_model import LinearRegression
from sklearn.neighbors import NearestNeighbors

import statsmodels.api as sm
np.random.seed(42)
n = 2000
p = 5

X = np.random.normal(0, 1, (n, p))

# True heterogeneous treatment effect
tau = 2 + X[:, 0] - 0.5 * X[:, 1]

# Propensity score (non-random assignment)
logit_p = 0.5 * X[:, 0] - 0.25 * X[:, 2]
p_score = 1 / (1 + np.exp(-logit_p))
D = np.random.binomial(1, p_score)

# Outcome
Y = tau * D + X[:, 0] + 0.5 * X[:, 1] + np.random.normal(0, 1, n)
X_ols = sm.add_constant(np.column_stack([D, X]))
ols_model = sm.OLS(Y, X_ols).fit()
print("Naive OLS ATE:", ols_model.params[1])
ps_model = GradientBoostingClassifier()
ps_model.fit(X, D)
ps_hat = ps_model.predict_proba(X)[:, 1]

treated = np.where(D == 1)[0]
control = np.where(D == 0)[0]

nn = NearestNeighbors(n_neighbors=1)
nn.fit(ps_hat[control].reshape(-1, 1))

matches = nn.kneighbors(ps_hat[treated].reshape(-1, 1), return_distance=False)

psm_effect = np.mean(Y[treated] - Y[control][matches.flatten()])
print("PSM ATE:", psm_effect)
K = 2
kf = KFold(n_splits=K, shuffle=True, random_state=42)

Y_res = np.zeros(n)
D_res = np.zeros(n)

for train_idx, test_idx in kf.split(X):
    X_train, X_test = X[train_idx], X[test_idx]
    Y_train, Y_test = Y[train_idx], Y[test_idx]
    D_train, D_test = D[train_idx], D[test_idx]

    # Outcome model
    y_model = GradientBoostingRegressor()
    y_model.fit(X_train, Y_train)
    m_hat = y_model.predict(X_test)

    # Treatment model
    d_model = GradientBoostingClassifier()
    d_model.fit(X_train, D_train)
    p_hat = d_model.predict_proba(X_test)[:, 1]

    Y_res[test_idx] = Y_test - m_hat
    D_res[test_idx] = D_test - p_hat
dml_model = LinearRegression()
dml_model.fit(D_res.reshape(-1, 1), Y_res)

print("DML ATE:", dml_model.coef_[0])
cate_hat = tau  # true CATE for comparison

print("True CATE (first 5):", cate_hat[:5])


Naive OLS ATE: 2.0212737059762373
PSM ATE: 1.8004320386529626
DML ATE: 1.9979549220259805
True CATE (first 5): [2.5658463  0.97625664 1.76944718 1.94412803 3.57853692]
