In [1]:
import numpy as np
import pandas as pd
from sklearn.neural_network import MLPClassifier, MLPRegressor


def dgp(
    n_obs: int = 1000,
    n_covariates: int = 10,
    n_confounders: int = 10,
    n_treatments: int = 1,
    n_outcomes: int = 1,
    binary_treatment: bool = False,
    fraction_treated: float = 0.5,
    binary_outcome: bool = False,
    fraction_positive: float = 0.5,
    scale: float = 1,
    seed: int | None = None,
    diagonal_covariance_matrix: bool = True,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Generate data for a treatment effect estimation problem.
    """

    np.random.seed(seed)

    # Generate Covariates and Confounders
    means = np.random.normal(0, 1, n_covariates) * scale
    cov_mat = (
        np.eye(n_covariates)
        if diagonal_covariance_matrix
        else create_random_covariance_matrix(n_covariates)
    )
    covariates = np.random.multivariate_normal(means, cov_mat, n_obs)

    # Generate Confounders
    means = np.random.normal(0, 1, n_confounders) * scale
    cov_mat = (
        np.eye(n_confounders)
        if diagonal_covariance_matrix
        else create_random_covariance_matrix(n_confounders)
    )
    confounders = np.random.multivariate_normal(means, cov_mat, n_obs)

    ## Neural Network for Nuissance Functions
    activation = "relu"

    # Generate Treatment
    hidden_layers = (20, 20, 20)
    pseudo_x = np.random.randn(n_obs, n_confounders) * scale
    if binary_treatment:
        pseudo_t = (np.random.binomial(1, 0.5, (n_obs, n_treatments))).reshape(
            -1, n_treatments
        )
        treatment_model = MLPClassifier(
            hidden_layer_sizes=hidden_layers,
            activation=activation,
            random_state=seed,
            max_iter=1000,
        )
    else:
        pseudo_t = (np.random.randn(n_obs, n_treatments) * scale).reshape(
            -1, n_treatments
        )
        treatment_model = MLPRegressor(
            hidden_layer_sizes=hidden_layers,
            activation=activation,
            random_state=seed,
            max_iter=1000,
        )
    if n_treatments == 1:
        treatment_model.fit(pseudo_x, pseudo_t.ravel())
    else:
        treatment_model.fit(pseudo_x, pseudo_t)

    if n_treatments == 1 and binary_treatment:
        treatments = treatment_model.predict_proba(confounders)[:, 1].reshape(
            -1, n_treatments
        )
    elif binary_treatment:
        treatments = treatment_model.predict_proba(confounders).reshape(
            -1, n_treatments
        )
    else:
        treatments = treatment_model.predict(confounders).reshape(-1, n_treatments)

    treatments = (
        np.where(
            treatments > find_optimal_threshold(treatments, fraction_treated), 1, 0
        )
        if binary_treatment
        else treatments
    )

    if not binary_treatment:
        treatments += np.random.randn(n_obs, n_treatments) * scale

    # Generate Outcome
    hidden_layers = (20, 20, 20)
    pseudo_x = np.random.randn(n_obs, (n_confounders + n_covariates)) * scale

    if binary_outcome:
        pseudo_y = (np.random.binomial(1, 0.5, (n_obs, n_outcomes))).reshape(
            -1, n_outcomes
        )
        outcome_model = MLPClassifier(
            hidden_layer_sizes=hidden_layers,
            activation=activation,
            random_state=seed,
            max_iter=1000,
        )
    else:
        pseudo_y = (np.random.randn(n_obs, n_outcomes) * scale).reshape(-1, n_outcomes)
        outcome_model = MLPRegressor(
            hidden_layer_sizes=hidden_layers,
            activation=activation,
            random_state=seed,
            max_iter=1000,
        )
    if n_outcomes == 1:
        outcome_model.fit(
            np.concatenate((pseudo_x, pseudo_t), axis=1), pseudo_y.ravel()
        )
    else:
        outcome_model.fit(np.concatenate((pseudo_x, pseudo_t), axis=1), pseudo_y)

    X_W = np.concatenate([covariates, confounders], axis=1)
    X_W_t = np.concatenate([X_W, treatments], axis=1)

    if n_outcomes == 1 and binary_outcome:
        outcome = outcome_model.predict_proba(X_W_t)[:, 1].reshape(-1, n_outcomes)
    elif binary_outcome:
        outcome = outcome_model.predict_proba(X_W_t).reshape(-1, n_outcomes)
    else:
        outcome = outcome_model.predict(X_W_t).reshape(-1, n_outcomes)

    outcome = (
        np.where(outcome > find_optimal_threshold(outcome, fraction_positive), 1, 0)
        if binary_outcome
        else outcome
    )

    if not binary_outcome:
        outcome += np.random.randn(n_obs, n_outcomes) * scale

    # Compute true treatment effect
    cates = np.zeros((n_obs, n_treatments * n_outcomes))
    for t in range(n_treatments):
        # Potential Outcome Under Treatment=1
        t1 = np.copy(treatments)
        t1[:, t] = 1
        X_W_t1 = np.concatenate((X_W, t1), axis=1)
        if n_outcomes == 1 and binary_outcome:
            y1 = outcome_model.predict_proba(X_W_t1)[:, 1].reshape(-1, n_outcomes)
        elif binary_outcome:
            y1 = outcome_model.predict_proba(X_W_t1).reshape(-1, n_outcomes)
        else:
            y1 = outcome_model.predict(X_W_t1).reshape(-1, n_outcomes)

        # Potential Outcome Under Treatment=0
        t0 = np.copy(treatments)
        t0[:, t] = 0
        X_W_t0 = np.concatenate((X_W, t0), axis=1)
        if n_outcomes == 1 and binary_outcome:
            y0 = outcome_model.predict_proba(X_W_t0)[:, 1].reshape(-1, n_outcomes)
        elif binary_outcome:
            y0 = outcome_model.predict_proba(X_W_t0).reshape(-1, n_outcomes)
        else:
            y0 = outcome_model.predict(X_W_t0).reshape(-1, n_outcomes)

        cates[:, t * n_outcomes : t * n_outcomes + n_outcomes] = y1 - y0

    # Save results
    results_np = np.concatenate((covariates, confounders, treatments, outcome), axis=1)
    data_columns = (
        [f"W{i}" for i in range(n_covariates)]
        + [f"X{i}" for i in range(n_confounders)]
        + [f"T{i}" for i in range(n_treatments)]
        + [f"Y{i}" for i in range(n_outcomes)]
    )
    data_df = pd.DataFrame(results_np, columns=data_columns)

    cate_columns = [
        f"cate_T{i}_Y{j}" for i in range(n_treatments) for j in range(n_outcomes)
    ]
    cate_df = pd.DataFrame(cates, columns=cate_columns)

    ate_columns = [
        f"ate_T{i}_Y{j}" for i in range(n_treatments) for j in range(n_outcomes)
    ]
    ate_df = pd.DataFrame(cates.mean(axis=0).reshape(1, -1), columns=ate_columns)

    return data_df, cate_df, ate_df


def find_optimal_threshold(pred_probs, approx_percentage=0.5):
    """
    Find the optimal threshold to achieve the approx. percentage of positive class.
    This helps to prevent having no overlap between the treated and control groups.
    """
    sorted_probs = np.sort(pred_probs, axis=0)
    threshold_index = int(len(sorted_probs) * approx_percentage)
    optimal_threshold = sorted_probs[threshold_index]
    return optimal_threshold


def create_random_covariance_matrix(dim):
    """
    Create a random covariance matrix.
    """
    A = np.random.randn(dim, dim)
    cov_matrix = np.dot(A, A.T)

    diag_indices = np.diag_indices_from(cov_matrix)
    cov_matrix[diag_indices] = np.abs(cov_matrix[diag_indices]) + 0.1
    return cov_matrix

In [2]:
data_df, cate_df, ate_df = dgp(
    n_obs=1000,
    n_covariates=10,
    n_confounders=10,
    n_treatments=5,
    n_outcomes=1,
    binary_treatment=True,
    fraction_treated=0.5,
    binary_outcome=False,
    fraction_positive=0.5,
    scale=30,
    seed=None,
    diagonal_covariance_matrix=False,
)



In [3]:
data_df

Unnamed: 0,W0,W1,W2,W3,W4,W5,W6,W7,W8,W9,...,X6,X7,X8,X9,T0,T1,T2,T3,T4,Y0
0,44.068070,5.563959,-15.260422,54.469207,11.474304,47.772324,-15.153069,22.256353,-10.142323,-11.500673,...,-93.289708,-40.226084,-2.854754,29.269818,1.0,1.0,1.0,1.0,0.0,-86.932298
1,47.247715,4.227357,-9.957668,59.428336,15.029105,36.200244,-10.479692,21.018649,1.340407,-9.151722,...,-92.553691,-38.430124,-2.026319,32.385306,1.0,1.0,1.0,0.0,1.0,-53.675490
2,44.356834,9.501696,-10.840262,53.795746,13.213829,43.352723,-5.051027,20.256115,-6.063109,-16.876549,...,-92.787805,-35.600485,2.560935,23.508743,1.0,1.0,1.0,1.0,0.0,-86.602018
3,47.326064,12.472349,-13.018940,59.364736,20.723081,34.246228,-6.243308,15.674363,11.108955,-13.764291,...,-93.520261,-40.254826,-3.625336,33.952704,1.0,1.0,1.0,1.0,1.0,-30.212918
4,46.147671,15.907871,-8.255778,60.102249,13.515695,37.563065,-1.054328,18.394813,-0.670780,-18.218689,...,-91.529224,-33.041514,4.338915,25.131318,0.0,1.0,0.0,0.0,0.0,-70.188366
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,44.172605,10.220925,-7.430808,60.259274,9.841663,36.928944,-7.307646,21.429358,0.116150,-16.581319,...,-94.378536,-38.990963,-1.148649,32.073726,0.0,0.0,0.0,0.0,1.0,-96.245058
996,49.656972,7.159894,-13.386703,55.209077,15.186534,39.876344,-10.065630,18.304867,-3.203812,-8.929702,...,-89.808554,-36.637169,-4.133478,30.448921,1.0,1.0,1.0,1.0,1.0,-73.097483
997,45.511290,12.557860,-10.347858,57.124201,8.368166,38.986198,-6.961067,16.893740,-2.216060,-16.148744,...,-95.542308,-31.513134,1.643670,28.939637,0.0,0.0,0.0,0.0,0.0,-46.979959
998,46.339735,10.409986,-11.765037,61.375303,6.138240,38.341220,-10.319428,19.211055,-4.884522,-14.376638,...,-94.051172,-36.117302,-8.640294,26.661948,1.0,1.0,1.0,1.0,1.0,-71.119958


In [4]:
data_df[[c for c in data_df.columns if "T" in c and "_" not in c]].sum()

T0    499.0
T1    499.0
T2    499.0
T3    499.0
T4    499.0
dtype: float64

In [5]:
data_df[[c for c in data_df.columns if "Y" in c and "_" not in c]].sum()

Y0   -63248.256433
dtype: float64

In [6]:
cate_df

Unnamed: 0,cate_T0_Y0,cate_T1_Y0,cate_T2_Y0,cate_T3_Y0,cate_T4_Y0
0,-0.309689,1.830175,4.480727,0.706532,2.756298
1,0.490595,5.631614,1.849682,1.996809,5.984933
2,-0.039416,1.830175,4.480727,-1.777876,2.097867
3,1.371510,4.909145,0.742526,2.917155,3.846381
4,0.517141,1.603614,3.123140,-2.932657,1.760626
...,...,...,...,...,...
995,1.749345,4.463935,2.536360,-0.248574,1.665973
996,-0.309689,1.830175,4.480727,0.706532,2.756298
997,0.517141,0.523397,2.491214,-2.547378,-0.499110
998,-0.309689,1.830175,4.480727,0.642925,2.756298


In [7]:
print(data_df["Y0"].mean())
print(data_df["Y0"].min())
print(data_df["Y0"].max())
print(data_df["Y0"].std())

-63.248256432799096
-164.99009635855262
40.71974666548617
33.0897900944593


In [8]:
print(cate_df.min(axis=0))
print(cate_df.max(axis=0))
print(cate_df.mean(axis=0))
print(cate_df.std(axis=0))

cate_T0_Y0   -0.811531
cate_T1_Y0   -0.201293
cate_T2_Y0   -1.511691
cate_T3_Y0   -3.141224
cate_T4_Y0   -0.876797
dtype: float64
cate_T0_Y0     2.637253
cate_T1_Y0     6.844015
cate_T2_Y0    12.862326
cate_T3_Y0     4.352056
cate_T4_Y0     8.006148
dtype: float64
cate_T0_Y0    0.668859
cate_T1_Y0    2.329986
cate_T2_Y0    2.780357
cate_T3_Y0   -0.158873
cate_T4_Y0    2.031624
dtype: float64
cate_T0_Y0    0.762433
cate_T1_Y0    1.671822
cate_T2_Y0    1.495751
cate_T3_Y0    1.751056
cate_T4_Y0    1.693494
dtype: float64


In [9]:
ate_df

Unnamed: 0,ate_T0_Y0,ate_T1_Y0,ate_T2_Y0,ate_T3_Y0,ate_T4_Y0
0,0.668859,2.329986,2.780357,-0.158873,2.031624
