# Evaluating Policy-making with CATE values

This notebook evaluates Qini metrics on several RCT datasets, spanning causal-inference benchmarks and uplift-marketing datasets.

In [1]:
# Parameters
# Defaults for one run of the notebook. NOTE(review): the "# Parameters" header
# suggests these are overridden by a parameter-injection tool — confirm.

# Comma-separated dataset-name patterns. The setup cell splits this string on ","
# and every entry is compared against the known dataset names with check_match.
dataset_patterns = "Hill (1),Hill (2),*(Sub)*"  # ,*(Sub)*
# Comma-separated method-name patterns, split and matched the same way; they
# are also forwarded to result_saver for every run.
method_patterns = "CausalPFN,T Learner,S Learner,X Learner,DA Learner,Forest DR Learner"
# Forwarded to CATEEstimator(max_context_length=...).
max_context_length = 50_000
# Forwarded to result_saver(replace=...).
replace_runs = False  # whether to replace existing runs

In [2]:
%load_ext autoreload
%autoreload 2
import warnings
import numpy as np
import torch
import random
from dotenv import load_dotenv
import os
import sys
from utils import *

# Make the parent directory importable. Guarded so that re-running this cell
# does not keep appending the same entry.
# NOTE(review): presumably needed for the `src` / `benchmarks` imports used by
# later cells — confirm.
if ".." not in sys.path:
    sys.path.append("..")
load_dotenv(override=True)


def _split_patterns(patterns):
    """Return a clean list of patterns from a comma-separated string or a list.

    Accepting an already-split list makes this cell idempotent: the original
    unconditional ``.split(",")`` raised AttributeError when the cell was run a
    second time. Surrounding whitespace is stripped and empty entries (e.g.
    from a trailing comma) are dropped.
    """
    if isinstance(patterns, str):
        patterns = patterns.split(",")
    return [entry.strip() for entry in patterns if entry.strip()]


dataset_patterns = _split_patterns(dataset_patterns)
method_patterns = _split_patterns(method_patterns)

warnings.filterwarnings('ignore') # ignore warnings

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Set seeds for reproducibility
seed = 82718
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# make the qini directory if it doesn't exist
if "OUTPUT_DIR" not in os.environ:
    # Same exception type as the plain lookup, but with an actionable message.
    raise KeyError(
        "OUTPUT_DIR is not set; export it or define it in the .env file read by load_dotenv()"
    )
qini_dir = os.path.join(os.environ["OUTPUT_DIR"], "qini")
os.makedirs(qini_dir, exist_ok=True)


## Setup all the different datasets

In [None]:
%autoreload 2
import re
from benchmarks import LentaDataset, CriteoDataset, HillstromDataset, X5Dataset, MegafonDataset
# this might take some time to run & load
N_FOLDS = 5  # number of folds requested from every dataset
SUBSAMPLE_MAX_ROWS = 50_000  # row cap passed to the "(Sub)" variants

# Display name -> zero-argument constructor. Construction is deferred so that
# only the datasets selected by `dataset_patterns` are actually loaded.
dataset_factories = {
    "Hill (1)": lambda: HillstromDataset(
        seed=seed,
        n_folds=N_FOLDS,
        outcome_col="visit",
        control_arm="No E-Mail",
        treatment_arm="Womens E-Mail",
    ),
    "Hill (2)": lambda: HillstromDataset(
        seed=seed,
        n_folds=N_FOLDS,
        outcome_col="visit",
        control_arm="No E-Mail",
        treatment_arm="Mens E-Mail",
    ),
    "Criteo RCT": lambda: CriteoDataset(
        seed=seed,
        n_folds=N_FOLDS,
        outcome_col="visit",
        treatment_col="treatment",
    ),
    "Lenta RCT": lambda: LentaDataset(seed=seed, n_folds=N_FOLDS),
    "X5 RCT": lambda: X5Dataset(seed=seed, n_folds=N_FOLDS),
    "Megafon RCT": lambda: MegafonDataset(seed=seed, n_folds=N_FOLDS),
    "Criteo RCT (Sub)": lambda: CriteoDataset(
        seed=seed,
        n_folds=N_FOLDS,
        outcome_col="visit",
        treatment_col="treatment",
        subsample_max_rows=SUBSAMPLE_MAX_ROWS,
    ),
    "X5 RCT (Sub)": lambda: X5Dataset(
        seed=seed,
        n_folds=N_FOLDS,
        subsample_max_rows=SUBSAMPLE_MAX_ROWS,
    ),
    "Lenta RCT (Sub)": lambda: LentaDataset(
        seed=seed,
        n_folds=N_FOLDS,
        subsample_max_rows=SUBSAMPLE_MAX_ROWS,
    ),
    "Megafon RCT (Sub)": lambda: MegafonDataset(
        seed=seed,
        n_folds=N_FOLDS,
        subsample_max_rows=SUBSAMPLE_MAX_ROWS,
    ),
}

datasets = {}
# Patterns stay in the outer loop so that `datasets` keeps the same insertion
# order as before (pattern order first, then registry order).
for dataset_pattern in dataset_patterns:
    for candidate_name, build_dataset in dataset_factories.items():
        if candidate_name in datasets:
            # Already built for an earlier pattern; avoid loading it again.
            continue
        if not check_match(candidate_name, dataset_pattern):
            continue
        print(f"Hit {candidate_name}!")
        datasets[candidate_name] = build_dataset()

## Sanity Checks for the Data

Print the standardized mean difference (SMD) between the treated and control groups for each feature of each dataset (the code below checks at most the first 10 features), and print the average of this value for each data suite. This checks whether the evaluation benchmarks actually contain RCT data, which is required for the Qini and uplift curves to be valid.

In [None]:
MAX_SMD_FEATURES = 10  # only the leading feature columns are checked


def _abs_smd_per_feature(X, t, max_features=MAX_SMD_FEATURES):
    """Absolute standardized mean difference of the leading feature columns.

    Args:
        X: 2-D feature array (rows are samples, columns are features).
        t: 1-D treatment indicator aligned with the rows of ``X``; rows with
            ``t == 1`` form the treated group and rows with ``t == 0`` the
            control group.
        max_features: upper bound on the number of leading columns examined.

    Returns:
        1-D array with one ``|SMD|`` per examined column, where
        ``SMD = (mean_treated - mean_control) / (0.5 * (std_treated + std_control + 1e-3))``.
    """
    n_features = min(X.shape[1], max_features)
    # Mask each group once instead of once per feature.
    treated = X[t == 1][:, :n_features]
    control = X[t == 0][:, :n_features]
    # float64 accumulation keeps the column-wise reductions accurate for
    # low-precision inputs with many rows.
    mean_gap = np.mean(treated, axis=0, dtype=np.float64) - np.mean(control, axis=0, dtype=np.float64)
    # The 1e-3 term guards against division by zero for constant features.
    avg_spread = 0.5 * (
        np.std(treated, axis=0, dtype=np.float64) + np.std(control, axis=0, dtype=np.float64) + 1e-3
    )
    return np.abs(mean_gap / avg_spread)


print("Testing if the data is RCT:")
print(
    "\tAll the test splits should have small SMD and the train splits should also have small SMDs if they are also RCTs:"
)
for data_suite_name, dataset in datasets.items():
    print(f"Data suite: {data_suite_name}")
    train_smd = []
    test_smd = []
    # Pool the per-feature values over every split yielded by the dataset.
    for qini_data in dataset:
        train_smd.extend(_abs_smd_per_feature(qini_data.X_train, qini_data.t_train))
        test_smd.extend(_abs_smd_per_feature(qini_data.X_test, qini_data.t_test))
    print(f">>>\tTest SMD: {np.mean(test_smd):.4f} ± {np.std(test_smd):.4f}")
    print(f">>>\tTrain SMD: {np.mean(train_smd):.4f} ± {np.std(train_smd):.4f}")

## Run our estimator

In [None]:
from src.causalpfn import CATEEstimator
from benchmarks.base import Qini_Dataset
import time
from tqdm import tqdm

# Fit CausalPFN on every fold of every selected dataset and store the test-set
# CATE estimates through result_saver. The progress bar is a context manager
# so it is closed even when a fold raises.
with tqdm(
    total=sum(len(dataset) for dataset in datasets.values()),
    desc="CausalPFN",
) as pbar:
    for dataset_name, dataset in datasets.items():
        pbar.set_postfix(dataset=dataset_name)
        for i in range(len(dataset)):
            with result_saver(
                dataset_name=dataset_name,
                method_name="CausalPFN",
                all_method_patterns=method_patterns,
                all_datasets_patterns=dataset_patterns,
                idx=i,
                artifact_dir=qini_dir,
                replace=replace_runs,
            ) as result:
                # NOTE(review): result_saver presumably yields None when this run
                # should be skipped (already saved / filtered out) — confirm in utils.
                if result is not None:

                    qini_data: Qini_Dataset = dataset[i]
                    # perf_counter is monotonic, unlike time.time(), so the
                    # measured duration cannot be skewed by clock adjustments.
                    time_start = time.perf_counter()
                    cate_estimator = CATEEstimator(
                        device=device,
                        max_context_length=max_context_length,
                    )

                    cate_estimator.fit(X=qini_data.X_train, y=qini_data.y_train, t=qini_data.t_train)
                    estimated_tau = cate_estimator.estimate_cate(X=qini_data.X_test)
                    time_spent = time.perf_counter() - time_start
                    result["estimated_effect"] = estimated_tau
                    result["t"] = qini_data.t_test
                    result["y"] = qini_data.y_test
                    # Seconds per 1000 samples (train and test rows combined).
                    result["time_spent"] = time_spent / ((len(qini_data.X_test) + len(qini_data.X_train)) / 1000)

                pbar.update(1)

## Run the Baselines

In [None]:
# Baselines (Base)
from benchmarks.baselines import BaselineModel

# Baselines (EconML)
from benchmarks.baselines import (
    TLearnerBaseline,
    SLearnerBaseline,
    XLearnerBaseline,
    DALearnerBaseline,
    XLearnerBaseline,
    ForestDRLearnerBaseline,
)


# Baseline estimators keyed by the display name used in result files, tables
# and plots. Each instance is reused for every dataset and fold.
baselines = {
    "T Learner": TLearnerBaseline(hpo=False),
    "S Learner": SLearnerBaseline(hpo=False),
    "X Learner": XLearnerBaseline(hpo=False),
    "DA Learner": DALearnerBaseline(hpo=False),
    "Forest DR Learner": ForestDRLearnerBaseline(hpo=False),
}


# The progress bar is a context manager so it is closed even when a fold raises.
with tqdm(
    total=sum(len(dataset) * len(baselines) for dataset in datasets.values()),
    desc="Baselines",
) as pbar:
    for dataset_name, dataset in datasets.items():
        for baseline_name, baseline in baselines.items():
            pbar.set_postfix(dataset=dataset_name, baseline=baseline_name)
            for i in range(len(dataset)):
                with result_saver(
                    dataset_name=dataset_name,
                    method_name=baseline_name,
                    all_method_patterns=method_patterns,
                    all_datasets_patterns=dataset_patterns,
                    idx=i,
                    artifact_dir=qini_dir,
                    replace=replace_runs,
                ) as result:
                    # NOTE(review): result_saver presumably yields None when this
                    # run should be skipped — confirm in utils.
                    if result is not None:
                        baseline: BaselineModel

                        qini_data: Qini_Dataset = dataset[i]
                        # perf_counter is monotonic, unlike time.time().
                        time_start = time.perf_counter()
                        estimated_tau = baseline.estimate_cate(
                            X_train=qini_data.X_train,
                            t_train=qini_data.t_train,
                            y_train=qini_data.y_train,
                            X_test=qini_data.X_test,
                        )
                        time_spent = time.perf_counter() - time_start
                        result["estimated_effect"] = estimated_tau
                        result["t"] = qini_data.t_test
                        result["y"] = qini_data.y_test
                        # Seconds per 1000 samples (train and test rows combined).
                        result["time_spent"] = time_spent / ((len(qini_data.X_test) + len(qini_data.X_train)) / 1000)

                    pbar.update(1)

## Compute Qini Curves & Scores 

To compute the Qini curves, we concatenate the outcomes, treatments, and estimated effects from all folds of each dataset into single arrays, sort by the estimated effects, and compute the Qini curve and score. We do this for every dataset and every method (CausalPFN and the baselines).

In [None]:
from src.causalpfn.evaluation import get_qini_curve
import pandas as pd

# Methods / datasets to report: those that were run and match the configured patterns.
methods_to_show = [
    method
    for method in ["CausalPFN", *baselines.keys()]
    if any(check_match(method, pattern) for pattern in method_patterns)
]
datasets_to_show = [
    name
    for name in datasets.keys()
    if any(check_match(name, pattern) for pattern in dataset_patterns)
]

qini_curves = {}  # dataset name -> {method name -> curve}
qini_scores = {}  # dataset name -> {method name -> score}
for dataset_name in datasets_to_show:
    dset_result = load_all_results(dataset_name, qini_dir)
    qini_curves[dataset_name] = {}
    qini_scores[dataset_name] = {}
    for method in methods_to_show:
        all_rows = dset_result[method]
        num_realizations = len(all_rows["t"])
        # Pool every stored fold into one array per quantity.
        all_t = np.concatenate([all_rows["t"][fold_idx] for fold_idx in range(num_realizations)])
        all_y = np.concatenate([all_rows["y"][fold_idx] for fold_idx in range(num_realizations)])
        all_tau = np.concatenate(
            [all_rows["estimated_effect"][fold_idx] for fold_idx in range(num_realizations)]
        )
        qini_curve, qini_score = get_qini_curve(
            rct_treatments=all_t,
            rct_outcomes=all_y,
            estimated_cate=all_tau,
            normalize=True,
        )
        qini_curves[dataset_name][method] = qini_curve
        qini_scores[dataset_name][method] = qini_score

# Build the table in one go (rows are methods, columns are datasets) with a
# numeric dtype, instead of growing an object-dtype frame cell by cell.
qini_df = pd.DataFrame(qini_scores, index=methods_to_show, columns=datasets_to_show).astype(float)
# Divide every value by the maximum of its column, so the best method on each
# dataset scores 1 (comment this line out to keep the raw scores).
# NOTE(review): a column whose maximum is zero or negative would be distorted
# by this rescaling — confirm the scores are positive.
qini_df = qini_df / qini_df.max()
# add an average of each row as the last column
qini_df["Average"] = qini_df.mean(axis=1)
# sort the dataframe by the average column
qini_df = qini_df.sort_values(by="Average", ascending=False)
qini_df

Visualize the actual curves:

In [None]:
import matplotlib.pyplot as plt

# One figure per dataset: every method's Qini curve over the population
# fraction, plus the dashed diagonal labelled "Random".
for dset in datasets_to_show:
    fig, ax = plt.subplots()
    curves_for_dataset = qini_curves[dset]
    for method in methods_to_show:
        curve = curves_for_dataset[method]
        population_fraction = np.linspace(0, 1, len(curve))
        ax.plot(population_fraction, curve, label=method)
    ax.plot([0, 1], [0, 1], linestyle="--", color="black", label="Random")
    ax.set_title(f"Qini Curve for {dset}")
    ax.legend()
    ax.set_xlabel("Fraction of Population")
    ax.set_ylabel("Qini")
    plt.show()