In [1]:
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

# Load Data

In [2]:
def load_scenario_data(h5_file_path, scenario_num):
    """Load one scenario's dataset and metadata from an HDF5 store.

    Opens ``h5_file_path`` read-only and reads the pandas object stored under
    ``scenario_<scenario_num>/data`` together with the ``metadata`` attribute
    attached to that node.

    Parameters
    ----------
    h5_file_path : str
        Path to the HDF5 file.
    scenario_num : int
        Scenario number used to build the key ``scenario_<scenario_num>/data``.

    Returns
    -------
    dict or None
        ``{"dataset": df, "metadata": metadata}``, or ``None`` when the key is
        not present in the store. ``metadata`` is ``None`` if the node was
        saved without a ``metadata`` attribute.
    """
    key = f"scenario_{scenario_num}/data"
    with pd.HDFStore(h5_file_path, mode='r') as store:
        if key not in store:
            return None  # Scenario not found
        df = store[key]
        # getattr with a default: a node saved without metadata should still
        # yield its dataset instead of raising AttributeError.
        metadata = getattr(store.get_storer(key).attrs, "metadata", None)
    return {"dataset": df, "metadata": metadata}

In [3]:
# HDF5 files to load; each becomes one entry of `experiment_setups`,
# keyed by the file's base name (e.g. "RCT_0_5").
store_files = [
    "synthetic_data/RCT_0_5.h5",
    "synthetic_data/RCT_0_05.h5",
    "synthetic_data/e_X.h5",
    "synthetic_data/e_X_U.h5",
    "synthetic_data/e_X_info_censor.h5",
    "synthetic_data/e_X_U_info_censor.h5"
]
n_scenarios = 10  # scenarios are looked up as scenario_1 ... scenario_{n_scenarios}

# {setup_name: {"scenario_<k>": {"dataset": DataFrame, "metadata": ...}}}
experiment_setups = {}

for path in store_files:
    base_name = os.path.splitext(os.path.basename(path))[0]  # e.g. RCT_0_5
    scenario_dict = {}
    for scenario in range(1, n_scenarios + 1):
        try:
            result = load_scenario_data(path, scenario)
        except Exception as e:
            # Best-effort loading: skip unreadable scenarios, but report why
            # instead of silently dropping them from every downstream table.
            warnings.warn(f"Skipping {path} scenario_{scenario}: {type(e).__name__}: {e}")
            continue
        if result is not None:  # None means the scenario key is absent from this file
            scenario_dict[f"scenario_{scenario}"] = result
    experiment_setups[base_name] = scenario_dict

In [4]:
# Sanity check: first five rows of one loaded scenario.
experiment_setups["RCT_0_5"]["scenario_1"]["dataset"].head(n=5)

Unnamed: 0,id,observed_time,event,W,X1,X2,X3,X4,X5,U1,U2,T0,T1,T,C
0,0,0.054267,1,0,0.135488,0.887852,0.932606,0.445568,0.388236,0.151609,0.205535,0.054267,0.061394,0.054267,1.803019
1,1,0.73263,1,1,0.257596,0.657368,0.492617,0.964238,0.800984,0.597208,0.255785,0.228566,0.73263,0.73263,1.689546
2,2,0.162856,1,1,0.455205,0.801058,0.041718,0.769458,0.003171,0.370382,0.223214,0.176016,0.162856,0.162856,1.256329
3,3,0.05034,1,1,0.292809,0.610914,0.913027,0.300115,0.248599,0.038464,0.409829,0.381909,0.05034,0.05034,1.241777
4,4,0.524607,1,0,0.666392,0.987533,0.46827,0.123287,0.916031,0.342961,0.79133,0.524607,1.121968,0.524607,1.516613


In [5]:
# One column per repeat (random_idx0, random_idx1, ...), each an ordering of
# dataset ids used to carve out train/test splits; rows indexed by `idx`.
experiment_repeat_setups = pd.read_csv("synthetic_data/idx_split.csv", index_col="idx")
experiment_repeat_setups

Unnamed: 0_level_0,random_idx0,random_idx1,random_idx2,random_idx3,random_idx4,random_idx5,random_idx6,random_idx7,random_idx8,random_idx9
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0,47390,5618,14210,46970,4203,16369,24535,45204,45725,45885
1,38566,46218,39045,7253,22759,34401,28889,38471,45822,37471
2,32814,20226,40012,4854,27351,39165,25359,14516,25717,29860
3,41393,39492,27153,19041,33009,19822,21243,41228,955,23901
4,12564,17823,48976,18458,22756,28169,45851,36620,29824,12711
...,...,...,...,...,...,...,...,...,...,...
49995,15948,39245,30779,48178,45056,4892,528,7486,31042,38267
49996,11102,29624,40779,3136,45904,41903,45682,36621,33204,38070
49997,16338,8986,19293,35651,10172,17947,38843,18310,2765,12581
49998,32478,32134,11955,36939,33266,41932,43910,21691,40801,33527


# Run Experiments

In [6]:
from models_causal_impute.meta_learners import TLearner, SLearner, XLearner
import time
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm

In [7]:
def prepare_data_split(dataset_df, experiment_repeat_setups, random_idx_col_list, test_size=5000):
    """Build train/test numpy arrays for each requested repeat.

    For every name in ``random_idx_col_list``, the matching column of
    ``experiment_repeat_setups`` is read as an ordered list of ``id`` values:
    the last ``test_size`` ids form the test set, the rest the training set.
    Rows are then selected from ``dataset_df`` by its ``id`` column and keep
    their order in ``dataset_df`` (not the order of the id list).

    Parameters
    ----------
    dataset_df : pd.DataFrame
        Must contain ``id``, ``W``, ``observed_time``, ``event``, ``T0``,
        ``T1`` and covariate columns named ``X<digits>`` (``X1``, ``X2``, ...).
    experiment_repeat_setups : pd.DataFrame
        One column of ids per repeat.
    random_idx_col_list : list of str
        Columns of ``experiment_repeat_setups`` to build splits for.
    test_size : int, default 5000
        Number of trailing ids assigned to the test set. If it exceeds the
        number of ids, every id goes to the test set.

    Returns
    -------
    dict
        Maps each column name to the tuple
        ``(X_train, W_train, Y_train, X_test, W_test, Y_test, CATE_test_true)``
        where ``Y_*`` has columns ``[observed_time, event]`` and
        ``CATE_test_true`` is ``T1 - T0`` on the test rows.

    Raises
    ------
    ValueError
        If ``test_size`` is negative.
    """
    if test_size < 0:
        raise ValueError(f"test_size must be non-negative, got {test_size}")

    # Covariates: columns named "X" followed only by digits (X1, X2, ...),
    # which excludes other columns such as U1/U2 and T/T0/T1.
    X_cols = [c for c in dataset_df.columns if c.startswith("X") and c[1:].isdigit()]

    def _to_arrays(df):
        # Covariates, treatment, and (observed_time, event) as numpy arrays.
        return (
            df[X_cols].to_numpy(),
            df["W"].to_numpy(),
            df[["observed_time", "event"]].to_numpy(),
        )

    split_results = {}
    for col in random_idx_col_list:
        random_idx = experiment_repeat_setups[col].to_numpy()
        # Explicit boundary: random_idx[:-test_size] would be empty for test_size=0.
        n_train = max(len(random_idx) - test_size, 0)
        train_ids, test_ids = random_idx[:n_train], random_idx[n_train:]

        train_df = dataset_df[dataset_df["id"].isin(train_ids)]
        test_df = dataset_df[dataset_df["id"].isin(test_ids)]

        X_train, W_train, Y_train = _to_arrays(train_df)
        X_test, W_test, Y_test = _to_arrays(test_df)
        CATE_test_true = (test_df["T1"] - test_df["T0"]).to_numpy()

        split_results[col] = (X_train, W_train, Y_train, X_test, W_test, Y_test, CATE_test_true)

    return split_results

In [8]:
# Experiment configuration.
random_idx_col_list = ["random_idx0"]  # columns of idx_split.csv used as repeats
failure_times_grid_size = 500          # NOTE(review): not used in this cell; presumably reserved for the imputation step
output_pickle_path = "causal_impute_x_learner_random_idx0_train_45000.pkl"
save_progress = False                  # True: checkpoint `experiment_setups` to output_pickle_path after each base model

base_regressors = ['ridge', 'lasso', 'rf', 'gbr', 'xgb']

for setup_name, setup_dict in tqdm(experiment_setups.items(), desc="Experiment Setups"):
    for scenario_key, scenario in tqdm(setup_dict.items(), desc=f"{setup_name} Scenarios"):
        split_dict = prepare_data_split(scenario["dataset"], experiment_repeat_setups, random_idx_col_list)
        scenario_results = scenario.setdefault("result", {})

        # Fit and evaluate an X-Learner once per base regressor.
        for base_model in tqdm(base_regressors, desc="Base Models", leave=False):
            model_results = {}
            scenario_results[base_model] = model_results  # replaces any earlier run of this base model

            start_time = time.perf_counter()
            for col in random_idx_col_list:
                X_train, W_train, Y_train, X_test, W_test, Y_test, CATE_test_true = split_dict[col]

                # TODO: Imputation step. For now the observed_time column of Y
                # is used directly as the outcome; the event column is ignored.
                y_train = Y_train[:, 0]

                learner = XLearner(base_model_name=base_model)
                learner.fit(X_train, W_train, y_train)
                mse_test, CATE_test_pred = learner.evaluate(X_test, CATE_test_true)

                model_results[col] = {
                    "CATE_true": CATE_test_true,
                    "CATE_Pred": CATE_test_pred,
                    "CATE_MSE": mse_test,
                }
            elapsed = time.perf_counter() - start_time

            mses = [model_results[col]["CATE_MSE"] for col in random_idx_col_list]
            model_results["average"] = {
                "mean_CATE_MSE": np.mean(mses),
                "std_CATE_MSE": np.std(mses),
                "runtime": elapsed / len(random_idx_col_list),  # seconds per repeat
            }

            if save_progress:
                # Checkpoint so a kernel crash mid-run does not lose finished results.
                with open(output_pickle_path, "wb") as f:
                    pickle.dump(experiment_setups, f)
            


Experiment Setups:   0%|          | 0/6 [00:00<?, ?it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
RCT_0_5 Scenarios: 100%|██████████| 8/8 [22:55<00:00, 171.93s/it]
Experiment Setups:  17%|█▋        | 1/6 [22:55<1:54:37, 1375.47s/it]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
RCT_0_05 Scenarios: 100%|██████████| 8/8 [23:06<00:00, 173.26s/it]
Experiment Setups:  33%|███▎      | 2/6 [46:01<1:32:06, 1381.71s/it]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
e_X Scenario

In [9]:
def summarize_experiment_results(experiment_setups, base_regressors, include_std=False):
    """Tabulate the averaged CATE results of every setup/scenario.

    Produces one row per (setup, scenario) in ``experiment_setups``. For each
    name in ``base_regressors`` it reports values read from
    ``setup_dict[scenario_key]["result"][base_model]["average"]``. Missing
    results become NaN; present values are rounded to 4 decimals.

    Parameters
    ----------
    experiment_setups : dict
        ``{setup_name: {scenario_key: {"result": {base_model: {"average": {...}}}}}}``.
    base_regressors : list of str
        Base-model names to report, in column order.
    include_std : bool, default False
        Also report ``std_CATE_MSE`` as a ``(base_model, "std_CATE_mse")`` column.

    Returns
    -------
    pd.DataFrame
        Two-level columns: ``("setup_name", "")``, ``("scenario_key", "")``,
        then ``(base_model, metric)`` pairs. An empty ``experiment_setups``
        yields an empty frame with the same columns.
    """
    # (output column label, key inside the "average" dict), in column order.
    metrics = [("mean_CATE_mse", "mean_CATE_MSE")]
    if include_std:
        metrics.append(("std_CATE_mse", "std_CATE_MSE"))
    metrics.append(("runtime [s]", "runtime"))

    def _rounded(value):
        return round(value, 4) if not pd.isna(value) else np.nan

    columns = [("setup_name", ""), ("scenario_key", "")]
    columns += [(base_model, label) for base_model in base_regressors for label, _ in metrics]

    rows = []
    for setup_name, setup_dict in experiment_setups.items():
        for scenario_key, scenario in setup_dict.items():
            results = scenario.get("result", {})
            row = [setup_name, scenario_key]
            for base_model in base_regressors:
                avg_result = results.get(base_model, {}).get("average", {})
                row.extend(_rounded(avg_result.get(key, np.nan)) for _, key in metrics)
            rows.append(row)

    # Explicit column tuples (instead of inferring them from records) keep the
    # MultiIndex well-defined even when there are no rows.
    return pd.DataFrame(rows, columns=pd.MultiIndex.from_tuples(columns))


In [10]:
# Collect per-scenario mean CATE MSE and runtime for every base regressor.
print("X-Learner Results")
summary_df = summarize_experiment_results(
    experiment_setups,
    base_regressors,
)
summary_df

X-Learner Results


Unnamed: 0_level_0,setup_name,scenario_key,ridge,ridge,lasso,lasso,rf,rf,gbr,gbr,xgb,xgb
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,mean_CATE_mse,runtime [s],mean_CATE_mse,runtime [s],mean_CATE_mse,runtime [s],mean_CATE_mse,runtime [s],mean_CATE_mse,runtime [s]
0,RCT_0_5,scenario_1,0.6924,2.5586,0.6925,0.8475,0.6901,61.3826,0.6924,53.9787,0.6907,62.2001
1,RCT_0_5,scenario_2,15.8145,0.905,15.8148,1.3545,15.817,65.4204,15.8184,52.9082,15.814,43.4654
2,RCT_0_5,scenario_3,14.5101,0.9737,14.5086,1.2448,14.4913,65.5358,14.4916,50.2098,14.5028,51.2746
3,RCT_0_5,scenario_5,14.6206,0.7092,14.6227,1.6275,14.6572,65.7255,14.6751,57.7855,14.6502,63.5162
4,RCT_0_5,scenario_6,15.0243,0.5753,15.0315,2.1045,15.0477,67.9456,15.0553,51.3198,15.0533,75.5865
5,RCT_0_5,scenario_8,17.2597,0.5088,17.2528,0.5707,17.25,78.4729,17.251,50.0472,17.2497,40.6791
6,RCT_0_5,scenario_9,43.1339,0.4542,43.1349,0.5729,43.1167,62.2665,43.1247,46.5252,43.1267,43.7485
7,RCT_0_5,scenario_10,15.8457,0.5295,15.8513,0.6235,15.8161,68.8255,15.8703,53.3571,15.8799,26.7871
8,RCT_0_05,scenario_1,0.6919,1.603,0.6922,2.3165,0.6909,78.3963,0.7022,53.5324,0.6913,72.1096
9,RCT_0_05,scenario_2,15.8347,0.6236,15.8416,0.7501,15.8477,74.082,15.8843,53.9643,15.8812,13.7508
