In [1]:
%load_ext autoreload
%autoreload 2
import warnings
import numpy as np
import torch
import random
import sys
import pandas as pd
sys.path.append("..")

warnings.filterwarnings('ignore') # ignore warnings
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Set seeds for reproducibility
seed = 82718
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


<torch._C.Generator at 0x7fe57b30ccd0>

# Compare CausalPFN to other Baselines

In the following, we run causalPFN alongside other baselines on the four datasets that provide ground-truth effects and have multiple realizations. Note that the following list is not exhaustive, but it includes some of the most commonly used baselines in the literature. Running all of the baselines would have bloated the notebook size and code complexity, so we have chosen a representative subset from EconML.

Run the following to run CausalPFN alongside the baselines on a suite of datasets that have multiple realizations and ground-truth effects. The results (e.g. ATE relative error and CATE PEHE) will be saved in the `results` dataframe, alongside their runtime per 1,000 samples:

In [None]:
# Load datasets and required functions
%autoreload 2
from benchmarks import IHDPDataset, ACIC2016Dataset
from benchmarks import RealCauseLalondeCPSDataset, RealCauseLalondePSIDDataset
import time
from causalpfn import ATEEstimator, CATEEstimator
from benchmarks.base import CATE_Dataset, ATE_Dataset
from benchmarks.baselines import (
    TLearnerBaseline,
    SLearnerBaseline,
    XLearnerBaseline,
    BaselineModel
)

from causalpfn.evaluation import calculate_pehe
from tqdm import tqdm

# Get different realizations for each dataset (only the first realization - you can change `n_tables`)
datasets = {
    "IHDP": IHDPDataset(n_tables=1),
    "ACIC 2016": ACIC2016Dataset(n_tables=1),
    "RealCause Lalonde CPS": RealCauseLalondeCPSDataset(n_tables=1),
    "RealCause Lalonde PSID": RealCauseLalondePSIDDataset(n_tables=1),
}
# get all of the baselines to compare with (not exhaustive -- feel free to comment out some)
baselines = {
    "X-Learner (no HPO)": XLearnerBaseline(hpo=False),
    "S-Learner (no HPO)": SLearnerBaseline(hpo=False),
    "T-Learner (no HPO)": TLearnerBaseline(hpo=False),
    ################################################
    # Ucomment the following lines to run with HPO #
    ################################################
    # "X-Learner (HPO)": XLearnerBaseline(hpo=True),
    # "S-Learner (HPO)": SLearnerBaseline(hpo=True),
    # "T-Learner (HPO)": TLearnerBaseline(hpo=True),
}
# Initialize results DataFrame
results = pd.DataFrame(columns=["dataset", "realization", "method", "ate_rel_err", "cate_pehe", "ate_time", "cate_time"])

# Iterate through datasets and realizations
pbar = tqdm(
    total=sum(len(dataset) * (1 + len(baselines)) for dataset in datasets.values()),
    desc="Processing datasets",
)
for dataset_name, dataset in datasets.items():
    for realization_idx in range(len(dataset)):
        pbar.set_postfix({"dataset": dataset_name, "method": "CausalPFN"})
        res = dataset[realization_idx]
        cate_dset: CATE_Dataset = res[0]
        ate_dset: ATE_Dataset = res[1]

        # run CausalPFN estimator for ATE
        start_time = time.time()
        causalpfn_ate = ATEEstimator(
            device=device,
        )
        causalpfn_ate.fit(ate_dset.X, ate_dset.t, ate_dset.y)
        true_ate = ate_dset.true_ate
        causalpfn_ate_hat = causalpfn_ate.estimate_ate()
        causalpfn_rel_error = abs(causalpfn_ate_hat - true_ate) / abs(true_ate)
        ate_time = time.time() - start_time

        # run CausalPFN estimator for CATE
        start_time = time.time()
        causalpfn_cate = CATEEstimator(
            device=device,
        )
        causalpfn_cate.fit(cate_dset.X_train, cate_dset.t_train, cate_dset.y_train)
        causalpfn_cate_hat = causalpfn_cate.estimate_cate(cate_dset.X_test)
        cate_pehe = calculate_pehe(cate_dset.true_cate, causalpfn_cate_hat)
        cate_time = time.time() - start_time

        # add results for CausalPFN
        row = dict(
            dataset=dataset_name,
            realization=realization_idx,
            method="CausalPFN",
            ate_rel_err=round(causalpfn_rel_error, 2),
            cate_pehe=round(cate_pehe, 2),
            ate_time=round(ate_time / (ate_dset.X.shape[0] + ate_dset.X.shape[0]) * 100, 2),
            cate_time=round(cate_time / (cate_dset.X_train.shape[0] + cate_dset.X_test.shape[0]) * 100, 2),
        )
        pbar.update(1)
        results = pd.concat([results, pd.DataFrame([row])], ignore_index=True)

        for method_name, baseline in baselines.items():
            pbar.set_postfix({"dataset": dataset_name, "method": method_name})
            baseline: BaselineModel

            # run baseline estimator for ATE
            start_time = time.time()
            ate_pred = baseline.estimate_ate(X=ate_dset.X, t=ate_dset.t, y=ate_dset.y)
            rel_err = np.abs(ate_pred - true_ate) / np.abs(true_ate)
            ate_time = time.time() - start_time

            # run baseline estimator for CATE
            start_time = time.time()
            cate_pred = baseline.estimate_cate(X_train=cate_dset.X_train, t_train=cate_dset.t_train, y_train=cate_dset.y_train, X_test=cate_dset.X_test)
            cate_pehe = calculate_pehe(cate_dset.true_cate, cate_pred)
            cate_time = time.time() - start_time

            # add results for baseline
            row = dict(
                dataset=dataset_name,
                realization=realization_idx,
                method=method_name,
                ate_rel_err=round(rel_err, 2),
                cate_pehe=round(cate_pehe, 2),
                ate_time=round(ate_time / (ate_dset.X.shape[0] + ate_dset.X.shape[0]) * 100, 2),
                cate_time=round(cate_time / (cate_dset.X_train.shape[0] + cate_dset.X_test.shape[0]) * 100, 2),
            )
            pbar.update(1)
            results = pd.concat([results, pd.DataFrame([row])], ignore_index=True)


rpy2 not installed, skipping BART baseline.
catenets not installed, skipping CATENet baselines.
rpy2 not installed, skipping GRF baseline.


Processing datasets: 100%|██████████| 16/16 [01:20<00:00,  2.36s/it, dataset=RealCause Lalonde PSID, method=T-Learner (no HPO)]

Next, run the following to visualize the average of different metrics across different realizations of each dataset.

In [3]:
# summarize all of the ATE and CATE estimation times by averaging over realizations
time_spent_df = (
    results.pivot_table(
        index="method",  # rows: one per method
        columns="dataset",  # multi‐columns: first level will be dataset
        values=["ate_time", "cate_time"],  # the values to aggregate
        aggfunc="mean",  # take the mean over realizations
    )
    .swaplevel(0, 1, axis=1)
    .sort_index(axis=1, level=0)
)

# Compute mean and standard error for ATE and CATE metrics
metrics = ["cate_pehe", "ate_rel_err"]
grp = results.groupby(["method", "dataset"])[metrics].agg(["mean", "sem"])  # MultiIndex cols: (metric, agg)
methods = grp.index.levels[0]
datasets = grp.index.levels[1]
data = {}
for ds in datasets:
    for m in metrics:
        means = grp[(m, "mean")].xs(ds, level="dataset")
        sems = grp[(m, "sem")].xs(ds, level="dataset")
        # combine into "xx.xx ± yy.yy" strings
        data[(ds, m)] = means.combine(sems, lambda mu, se: f"{mu:.2f} ± {se:.2f}")
causal_effect_errors = pd.DataFrame(data, index=methods)
causal_effect_errors.columns = pd.MultiIndex.from_tuples(causal_effect_errors.columns, names=["dataset", "metric"])
causal_effect_errors = causal_effect_errors.sort_index(axis=1, level=0)

Visualize the mean and standard errors for the causal effect estimates:

In [None]:
causal_effect_errors

dataset,ACIC 2016,ACIC 2016,IHDP,IHDP,RealCause Lalonde CPS,RealCause Lalonde CPS,RealCause Lalonde PSID,RealCause Lalonde PSID
metric,ate_rel_err,cate_pehe,ate_rel_err,cate_pehe,ate_rel_err,cate_pehe,ate_rel_err,cate_pehe
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
CausalPFN,0.07 ± nan,0.73 ± nan,0.00 ± nan,0.17 ± nan,0.00 ± nan,8877.69 ± nan,0.09 ± nan,13549.69 ± nan
S-Learner (no HPO),0.60 ± nan,3.47 ± nan,0.02 ± nan,0.50 ± nan,1.00 ± nan,12844.74 ± nan,1.02 ± nan,22972.74 ± nan
T-Learner (no HPO),0.35 ± nan,2.17 ± nan,0.02 ± nan,0.55 ± nan,0.28 ± nan,9000.14 ± nan,0.00 ± nan,13335.38 ± nan
X-Learner (no HPO),0.28 ± nan,1.99 ± nan,0.00 ± nan,0.55 ± nan,0.96 ± nan,12824.34 ± nan,0.93 ± nan,21588.71 ± nan


Visualize the average errors of causal effects

In [None]:
time_spent_df

dataset,ACIC 2016,ACIC 2016,IHDP,IHDP,RealCause Lalonde CPS,RealCause Lalonde CPS,RealCause Lalonde PSID,RealCause Lalonde PSID
Unnamed: 0_level_1,ate_time,cate_time,ate_time,cate_time,ate_time,cate_time,ate_time,cate_time
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
CausalPFN,0.17,0.06,0.27,0.08,0.1,0.06,0.12,0.03
S-Learner (no HPO),0.0,0.0,0.01,0.01,0.0,0.0,0.0,0.0
T-Learner (no HPO),0.0,0.0,0.01,0.02,0.0,0.0,0.0,0.01
X-Learner (no HPO),0.01,0.01,0.04,0.07,0.0,0.0,0.01,0.02
