## Run the following setup code


In [11]:
# Parameters
# Comma-separated list of the datasets to run.
dataset_patterns = "IHDP,ACIC 2016,RealCause Lalonde CPS,RealCause Lalonde PSID"
# Comma-separated list of method patterns to run ("*" selects all of the methods).
method_patterns = "*"
# Optional single realization index; None runs every realization of each dataset.
index = None
# Whether to replace runs that already exist.
replace_runs = False

In [12]:
# `index` may arrive as a string when injected as a notebook parameter; coerce it to an
# int and leave None untouched (None means: run every realization).
if index is not None:
    index = int(index)

In [None]:
%load_ext autoreload
%autoreload 2
import warnings
import numpy as np
import torch
import random
from dotenv import load_dotenv
import os
import sys
from utils import *

sys.path.append("..")
load_dotenv(override=True)

warnings.filterwarnings('ignore') # ignore warnings
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

dataset_patterns = dataset_patterns.split(",")
method_patterns = method_patterns.split(",")

# Set seeds for reproducibility
seed = 82718
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


## Load all the datasets


In [14]:
# Load datasets and required functions
%autoreload 2
from benchmarks import IHDPDataset, ACIC2016Dataset
from benchmarks import RealCauseLalondeCPSDataset, RealCauseLalondePSIDDataset

# Store all the results for all the datasets

IHDP = IHDPDataset()
ACIC2016 = ACIC2016Dataset()
RealCauseLalondeCPS = RealCauseLalondeCPSDataset()
RealCauseLalondePSID = RealCauseLalondePSIDDataset()

datasets = {
    "IHDP": IHDP if index is None else [IHDP[index%len(IHDP)]],
    "ACIC 2016": ACIC2016 if index is None else [ACIC2016[index%len(ACIC2016)]],
    "RealCause Lalonde CPS": RealCauseLalondeCPS if index is None else [RealCauseLalondeCPS[index%len(RealCauseLalondeCPS)]],
    "RealCause Lalonde PSID": RealCauseLalondePSID if index is None else [RealCauseLalondePSID[index%len(RealCauseLalondePSID)]],
}

causal_effect_path = os.path.join(os.environ["OUTPUT_DIR"], "causal_effect/")
os.makedirs(causal_effect_path, exist_ok=True)


## CausalPFN - TabDPT


In [None]:
from src.causalpfn import CATEEstimator, calculate_pehe, ATEEstimator
import time
from tqdm import tqdm

# Run CausalPFN on every realization of every dataset, recording for each one the CATE
# PEHE, the ATE relative error, and the wall-clock time per sample (in ms) of each task.
# `result_saver` yields None when a run should not be executed -- presumably because it is
# filtered out by the patterns or already stored with replace_runs=False; confirm in utils.
# The progress bar is context-managed so it is closed even if a run raises.
with tqdm(
    total=sum(len(dataset) for dataset in datasets.values()),
    desc="CausalPFN",
) as pbar:
    for dataset_name, dataset in datasets.items():
        pbar.set_postfix(dataset=dataset_name)
        for i in range(len(dataset)):
            with result_saver(
                dataset_name=dataset_name,
                method_name="CausalPFN",
                all_method_patterns=method_patterns,
                all_datasets_patterns=dataset_patterns,
                # With a single requested `index`, `dataset` holds only that realization,
                # so the result is stored under `index` instead of the local position 0.
                idx=i if index is None else index,
                artifact_dir=causal_effect_path,
                replace=replace_runs,
            ) as result:
                if result is not None:
                    cate_dset, ate_dset = dataset[i]

                    # CATE: the timed span covers fit + prediction.
                    # perf_counter is monotonic, so system clock adjustments cannot skew timings.
                    time_start = time.perf_counter()
                    cate_estimator = CATEEstimator(device=device)
                    cate_estimator.fit(cate_dset.X_train, cate_dset.t_train, cate_dset.y_train)
                    estimated_cate = cate_estimator.estimate_cate(X=cate_dset.X_test)
                    time_spent = time.perf_counter() - time_start
                    result["pehe"] = calculate_pehe(cate_pred=estimated_cate, cate_true=cate_dset.true_cate)
                    # milliseconds per (train + test) sample
                    result["time_cate"] = time_spent / (len(cate_dset.X_test) + len(cate_dset.X_train)) * 1000

                    # ATE
                    time_start = time.perf_counter()
                    ate_estimator = ATEEstimator(device=device)
                    ate_estimator.fit(ate_dset.X, ate_dset.t, ate_dset.y)
                    estimated_ate = ate_estimator.estimate_ate()
                    time_spent = time.perf_counter() - time_start
                    result["ate_rel_err"] = abs(estimated_ate - ate_dset.true_ate) / abs(ate_dset.true_ate)
                    # milliseconds per sample
                    result["time_ate"] = time_spent / len(ate_dset.X) * 1000

                pbar.update(1)

## Baselines


In [None]:
# Baselines (Base)
from benchmarks.baselines import BaselineModel

# Baselines (EconML)
from benchmarks.baselines import (
    ForestDMLBaseline,
    TLearnerBaseline,
    SLearnerBaseline,
    XLearnerBaseline,
    DALearnerBaseline,
    ForestDRLearnerBaseline,
)

# Baselines (CATE Net)
from benchmarks.baselines import TarNetBaseline, DragonNetBaseline, RANetBaseline

# GRF & BART & IPW
from benchmarks.baselines import GRFBaseline, BartBaseline, IPWBaseline

# Imported here too, so this cell does not depend on the CausalPFN cell having been run.
import time
from tqdm import tqdm
from src.causalpfn import calculate_pehe


# Each baseline is evaluated with and without hyperparameter optimisation (HPO),
# except BART, which is only run with hpo=False.
baselines = {
    "T Learner (no HPO)": TLearnerBaseline(hpo=False),
    "T Learner (HPO)": TLearnerBaseline(hpo=True),
    "S Learner (no HPO)": SLearnerBaseline(hpo=False),
    "S Learner (HPO)": SLearnerBaseline(hpo=True),
    "X Learner (no HPO)": XLearnerBaseline(hpo=False),
    "X Learner (HPO)": XLearnerBaseline(hpo=True),
    "DA Learner (no HPO)": DALearnerBaseline(hpo=False),
    "DA Learner (HPO)": DALearnerBaseline(hpo=True),
    "Forest DR Learner (no HPO)": ForestDRLearnerBaseline(hpo=False),
    "Forest DR Learner (HPO)": ForestDRLearnerBaseline(hpo=True),
    "Forest DML (no HPO)": ForestDMLBaseline(hpo=False),
    "Forest DML (HPO)": ForestDMLBaseline(hpo=True),
    "DragonNet (no HPO)": DragonNetBaseline(hpo=False),
    "DragonNet (HPO)": DragonNetBaseline(hpo=True),
    "TarNet (no HPO)": TarNetBaseline(hpo=False),
    "TarNet (HPO)": TarNetBaseline(hpo=True),
    "RA Net (no HPO)": RANetBaseline(hpo=False),
    "RA Net (HPO)": RANetBaseline(hpo=True),
    "GRF (no HPO)": GRFBaseline(hpo=False),
    "GRF (HPO)": GRFBaseline(hpo=True),
    "BART": BartBaseline(hpo=False),
    "IPW (no HPO)": IPWBaseline(hpo=False),
    "IPW (HPO)": IPWBaseline(hpo=True),
}


# Run every baseline on every realization of every dataset, recording the same metrics as
# the CausalPFN cell: CATE PEHE, ATE relative error and per-sample time (ms) of each task.
# The progress bar is context-managed so it is closed even if a run raises.
with tqdm(
    total=sum(len(dataset) * len(baselines) for dataset in datasets.values()),
    desc="Baselines",
) as pbar:
    for dataset_name, dataset in datasets.items():
        for baseline_name, baseline in baselines.items():
            pbar.set_postfix(dataset=dataset_name, baseline=baseline_name)
            for i in range(len(dataset)):
                with result_saver(
                    dataset_name=dataset_name,
                    method_name=baseline_name,
                    all_method_patterns=method_patterns,
                    all_datasets_patterns=dataset_patterns,
                    # With a single requested `index`, store under that index (see CausalPFN cell).
                    idx=i if index is None else index,
                    artifact_dir=causal_effect_path,
                    replace=replace_runs,
                ) as result:
                    if result is not None:
                        baseline: BaselineModel
                        cate_dset, ate_dset = dataset[i]

                        # CATE: the timed span covers fit + prediction (monotonic clock).
                        start_time = time.perf_counter()
                        cate_pred = baseline.estimate_cate(
                            X_train=cate_dset.X_train,
                            t_train=cate_dset.t_train,
                            y_train=cate_dset.y_train,
                            X_test=cate_dset.X_test,
                        )
                        time_spent = time.perf_counter() - start_time
                        result["pehe"] = calculate_pehe(cate_true=cate_dset.true_cate, cate_pred=cate_pred)
                        # milliseconds per (train + test) sample
                        result["time_cate"] = time_spent / (len(cate_dset.X_test) + len(cate_dset.X_train)) * 1000

                        # ATE
                        start_time = time.perf_counter()
                        ate_pred = baseline.estimate_ate(
                            X=ate_dset.X,
                            t=ate_dset.t,
                            y=ate_dset.y,
                        )
                        time_spent = time.perf_counter() - start_time
                        result["ate_rel_err"] = abs(ate_pred - ate_dset.true_ate) / abs(ate_dset.true_ate)
                        # milliseconds per sample
                        result["time_ate"] = time_spent / len(ate_dset.X) * 1000
                    pbar.update(1)

## Parse and visualize all of the results

Once all runs have finished, the following cell parses the stored results into a dataframe with one row per (dataset, method, realization). Its columns are the dataset, the method, the realization index, and the metrics: CATE PEHE, ATE relative error, and the per-sample CATE and ATE estimation times in milliseconds.


In [7]:
import pandas as pd

# Only report methods selected by `method_patterns`, CausalPFN first.
methods_to_show = ["CausalPFN"] + list(baselines.keys())
methods_to_show = [method for method in methods_to_show if any(check_match(method, m) for m in method_patterns)]

# One record per (dataset, method, realization). The column names match what the
# summary cells below read ("cate_pehe", "ate_rel_err", "cate_time", "ate_time").
# Times are the per-sample milliseconds stored by the run cells.
result_columns = ["dataset", "method", "cate_pehe", "ate_rel_err", "cate_time", "ate_time", "realization"]
records = []
for dataset_name in datasets:
    dset_result = load_all_results(dataset_name, causal_effect_path)
    for method in methods_to_show:
        all_rows = dset_result[method]
        for fold_idx in range(len(all_rows["pehe"])):
            records.append(
                dict(
                    dataset=dataset_name,
                    method=method,
                    cate_pehe=float(all_rows["pehe"][fold_idx]),
                    ate_rel_err=float(all_rows["ate_rel_err"][fold_idx]),
                    # each time goes to its own column (these were previously swapped)
                    cate_time=float(all_rows["time_cate"][fold_idx]),
                    ate_time=float(all_rows["time_ate"][fold_idx]),
                    realization=fold_idx,
                )
            )

# Build the frame once instead of concatenating row by row (which is quadratic).
results_df = pd.DataFrame(records, columns=result_columns)

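Next, summarize the results per dataset. The first table gives each method's mean CATE and ATE estimation time over realizations. The second gives each error metric as mean ± standard error over realizations.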


In [9]:
# summarize all of the ATE and CATE estimation times by averaging over realizations
time_spent_df = (
    results_df.pivot_table(
        index="method",  # rows: one per method
        columns="dataset",  # multi‐columns: first level will be dataset
        values=["ate_time", "cate_time"],  # the values to aggregate
        aggfunc="mean",  # take the mean over realizations
    )
    .swaplevel(0, 1, axis=1)
    .sort_index(axis=1, level=0)
)

# Compute mean and standard error for ATE and CATE metrics
metrics = ["cate_pehe", "ate_rel_err"]
grp = results_df.groupby(["method", "dataset"])[metrics].agg(["mean", "sem"])  # MultiIndex cols: (metric, agg)
methods = grp.index.levels[0]
datasets_index = grp.index.levels[1]
data = {}
for ds in datasets_index:
    for m in metrics:
        means = grp[(m, "mean")].xs(ds, level="dataset")
        sems = grp[(m, "sem")].xs(ds, level="dataset")
        # combine into "xx.xx ± yy.yy" strings
        data[(ds, m)] = means.combine(sems, lambda mu, se: f"{mu:.2f} ± {se:.2f}")
causal_effect_errors = pd.DataFrame(data, index=methods)
causal_effect_errors.columns = pd.MultiIndex.from_tuples(causal_effect_errors.columns, names=["dataset", "metric"])
causal_effect_errors = causal_effect_errors.sort_index(axis=1, level=0)

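Compute overall performance across all causal tasks (every realization of every dataset). Each method gets the average rank of its CATE PEHE and the average rank of its ATE relative error among all displayed methods, each with a standard error. Rank 1 is best, so lower is better. The final overall rank is the mean of these two average ranks, and the rows of the error table are sorted by it before display.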


In [None]:
from collections import defaultdict

# Pool the per-realization errors of every displayed method across all datasets.
# Entries are appended in the same (dataset, realization) order for every method,
# which is what get_ranks relies on to compare methods task by task.
ate_errors = defaultdict(list)
cate_pehes = defaultdict(list)
for dataset_name in datasets:
    dset_result = load_all_results(dataset_name, causal_effect_path)
    for method in methods_to_show:
        rows = dset_result[method]
        for k in range(len(rows["pehe"])):
            pehe_k = float(rows["pehe"][k])
            ate_err_k = float(rows["ate_rel_err"][k])
            ate_errors[method].append(ate_err_k)
            cate_pehes[method].append(pehe_k)


def get_ranks(res: dict, methods=None):
    """Average rank of each method's error across all causal tasks, plus its standard error.

    For every task position ``idx``, a method's rank is the number of methods (itself
    included) whose error is less than or equal to its own. The best method on a task
    therefore gets rank 1, and ties count against every tied method. Lower is better.
    A NaN error (e.g. a failed run) is treated as infinitely bad. Otherwise it would
    compare False against everything and receive the best possible rank.

    Args:
        res: maps method name -> list of per-task errors. All lists must have the same
            length and be aligned, i.e. position ``idx`` refers to the same
            (dataset, realization) for every method.
        methods: methods to rank against each other; defaults to ``methods_to_show``.

    Returns:
        ``(ranks, ranks_ste)``: dicts mapping each method to its mean rank and to the
        standard error of that mean. The standard error uses the sample std (``ddof=1``),
        matching the pandas ``sem`` used for the per-dataset columns. It is NaN with a
        single task. Both values are NaN for a method with no tasks.

    Raises:
        ValueError: if the methods do not all have the same number of tasks.
    """
    if methods is None:
        methods = methods_to_show
    methods = list(methods)

    num_tasks = {method: len(res[method]) for method in methods}
    if len(set(num_tasks.values())) > 1:
        raise ValueError(f"Cannot rank methods with different numbers of tasks: {num_tasks}")

    def _error(method, idx):
        # NaN -> +inf so a failed run ranks last instead of first.
        value = res[method][idx]
        return np.inf if np.isnan(value) else value

    ranks = {}
    ranks_ste = {}
    for method in methods:
        n = num_tasks[method]
        if n == 0:
            ranks[method] = float("nan")
            ranks_ste[method] = float("nan")
            continue
        all_ranks = [
            sum(_error(method, idx) >= _error(other, idx) for other in methods)
            for idx in range(n)
        ]
        ranks[method] = sum(all_ranks) / n
        ranks_ste[method] = np.std(all_ranks, ddof=1) / np.sqrt(n)
    return ranks, ranks_ste


cate_ranks, cate_ranks_ste = get_ranks(cate_pehes)
ate_ranks, ate_ranks_ste = get_ranks(ate_errors)


def _rank_summary(mean_ranks, rank_stes):
    """Series of "mean ± ste" rank strings for every displayed method."""
    return pd.Series({method: f"{mean_ranks[method]:.2f} ± {rank_stes[method]:.2f}" for method in methods_to_show})


# Append an "overall" column group next to the per-dataset (dataset, metric) columns.
causal_effect_errors[("overall", "cate_rank ± ste")] = _rank_summary(cate_ranks, cate_ranks_ste)
causal_effect_errors[("overall", "ate_rank ± ste")] = _rank_summary(ate_ranks, ate_ranks_ste)

# Overall score is the mean of the CATE and ATE average ranks (lower is better); best first.
combined_rank = (pd.Series(cate_ranks) + pd.Series(ate_ranks)) / 2
causal_effect_errors[("overall", "rank")] = combined_rank
causal_effect_errors = causal_effect_errors.sort_values(by=("overall", "rank"))
causal_effect_errors

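Compute the median per-sample CATE and ATE time (in milliseconds) of each method over all realizations and datasets. Add these as an "overall" column group to the timing table, then sort the methods by their median CATE time.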


In [None]:
from collections import defaultdict
import numpy as np

# Pool the per-sample CATE/ATE times (ms) of every realization across all datasets.
ate_times = defaultdict(list)
cate_times = defaultdict(list)
for dataset_name in datasets:
    dset_result = load_all_results(dataset_name, causal_effect_path)
    for method in methods_to_show:
        rows = dset_result[method]
        for k in range(len(rows["pehe"])):
            cate_t = float(rows["time_cate"][k])
            ate_t = float(rows["time_ate"][k])
            ate_times[method].append(ate_t)
            cate_times[method].append(cate_t)

# The overall time is the median (not the mean) over all pooled realizations.
med_ate_times = {method: np.median(ate_times[method]) for method in methods_to_show}
med_cate_times = {method: np.median(cate_times[method]) for method in methods_to_show}

# Add an "overall" column group to the timing table, then order methods by median CATE time.
time_spent_df[("overall", "ate_time")] = pd.Series(med_ate_times)
time_spent_df[("overall", "cate_time")] = pd.Series(med_cate_times)
time_spent_df = time_spent_df.sort_values(by=("overall", "cate_time"))
time_spent_df