In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os

# Load Data

In [2]:
def load_scenario_data(h5_file_path, scenario_num):
    """Read one scenario's DataFrame and its stored metadata from an HDF5 store.

    Parameters
    ----------
    h5_file_path : str
        Path to the HDF5 file, opened read-only.
    scenario_num : int
        Scenario number; the data is read from key ``scenario_<num>/data``.

    Returns
    -------
    dict or None
        ``{"dataset": DataFrame, "metadata": <storer attrs.metadata>}``, or
        ``None`` when the key is not present in the store. Other failures
        (e.g. a missing ``metadata`` attribute) propagate to the caller.
    """
    hdf_key = f"scenario_{scenario_num}/data"
    with pd.HDFStore(h5_file_path, mode='r') as hdf_store:
        if hdf_key not in hdf_store:
            # Scenario not present in this file.
            return None
        scenario_frame = hdf_store[hdf_key]
        scenario_metadata = hdf_store.get_storer(hdf_key).attrs.metadata
    return {"dataset": scenario_frame, "metadata": scenario_metadata}

In [3]:
# HDF5 stores produced by the data generator; each holds several scenarios.
store_files = [
    "synthetic_data/RCT_0_5.h5",
    "synthetic_data/RCT_0_05.h5",
    "synthetic_data/e_X.h5",
    "synthetic_data/e_X_U.h5",
    "synthetic_data/e_X_info_censor.h5",
    "synthetic_data/e_X_U_info_censor.h5"
]

# Scenario numbers probed in every store; absent ones are skipped.
SCENARIO_NUMBERS = range(1, 11)

# experiment_setups[<store base name>][f"scenario_<n>"] -> {"dataset", "metadata"}
experiment_setups = {}

for path in store_files:
    base_name = os.path.splitext(os.path.basename(path))[0]  # e.g. RCT_0_5
    scenario_dict = {}
    for scenario in SCENARIO_NUMBERS:
        try:
            result = load_scenario_data(path, scenario)
        except Exception as exc:
            # Keep going (best effort), but make the skip visible instead of
            # silently dropping the scenario from every downstream table.
            warnings.warn(
                f"Skipping {base_name}/scenario_{scenario}: "
                f"failed to load from {path!r} ({exc!r})"
            )
            continue
        if result is not None:  # None means the scenario key is absent.
            scenario_dict[f"scenario_{scenario}"] = result
    experiment_setups[base_name] = scenario_dict

In [48]:
# Peek at one loaded scenario to check the column layout.
rct_0_5_scenario_1_df = experiment_setups["RCT_0_5"]["scenario_1"]["dataset"]
rct_0_5_scenario_1_df

Unnamed: 0,id,observed_time,event,W,X1,X2,X3,X4,X5,U1,U2,T0,T1,T,C
0,0,0.054267,1,0,0.135488,0.887852,0.932606,0.445568,0.388236,0.151609,0.205535,0.054267,0.061394,0.054267,1.803019
1,1,0.732630,1,1,0.257596,0.657368,0.492617,0.964238,0.800984,0.597208,0.255785,0.228566,0.732630,0.732630,1.689546
2,2,0.162856,1,1,0.455205,0.801058,0.041718,0.769458,0.003171,0.370382,0.223214,0.176016,0.162856,0.162856,1.256329
3,3,0.050340,1,1,0.292809,0.610914,0.913027,0.300115,0.248599,0.038464,0.409829,0.381909,0.050340,0.050340,1.241777
4,4,0.524607,1,0,0.666392,0.987533,0.468270,0.123287,0.916031,0.342961,0.791330,0.524607,1.121968,0.524607,1.516613
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
49995,49995,0.281175,1,0,0.484593,0.998236,0.668208,0.070638,0.960140,0.497815,0.206792,0.281175,0.061038,0.281175,1.365563
49996,49996,0.029867,1,1,0.036391,0.268106,0.043117,0.426886,0.342038,0.812595,0.437775,0.163239,0.029867,0.029867,0.658388
49997,49997,0.077500,1,1,0.061915,0.411210,0.426204,0.414266,0.601355,0.116056,0.416950,0.200592,0.077500,0.077500,1.118571
49998,49998,0.423983,1,0,0.178390,0.656522,0.817355,0.347013,0.060741,0.201218,0.935754,0.423983,0.205703,0.423983,0.939400


In [32]:
def _describe_series(series, prefix):
    """Return min/median/max/mean/std of `series` keyed as f"{prefix}_<stat>"."""
    return {
        f"{prefix}_min": series.min(),
        f"{prefix}_median": series.median(),
        f"{prefix}_max": series.max(),
        f"{prefix}_mean": series.mean(),
        f"{prefix}_std": series.std(),
    }


def _summarize_scenario(df):
    """Descriptive statistics for one scenario DataFrame.

    Reads columns 'event', 'W', 'T', 'C', 'T0' and 'T1'. The censoring rate
    is 1 - mean(event); T1 - T0 is treated as the per-row treatment effect,
    summarized as 'ate' (its mean) and 'cate_*' (its min/median/max).

    NOTE: 'C' contains inf in some scenarios; the censoring mean/max then
    report inf and the std is NaN (numpy emits RuntimeWarnings for it).
    """
    cate = df["T1"] - df["T0"]  # computed once instead of four times
    return {
        "censoring_rate": 1 - df["event"].mean(),
        "treatment_rate": df["W"].mean(),
        **_describe_series(df["T"], "event_time"),
        **_describe_series(df["C"], "censoring_time"),
        "ate": cate.mean(),
        "cate_min": cate.min(),
        "cate_median": cate.median(),
        "cate_max": cate.max(),
    }


# summary_characteristics[setup_name][scenario_key] -> dict of statistics
summary_characteristics = {}
for setup_name, setup_dict in tqdm(experiment_setups.items(), desc="Experiment Setups"):
    summary_characteristics[setup_name] = {
        scenario_key: _summarize_scenario(setup_dict[scenario_key]["dataset"])
        for scenario_key in tqdm(setup_dict, desc=f"{setup_name} Scenarios", leave=False)
    }

# One row per (setup, scenario), indexed by that pair.
summary_df = pd.DataFrame.from_dict(
    {
        (setup_name, scenario_key): stats
        for setup_name, scenarios in summary_characteristics.items()
        for scenario_key, stats in scenarios.items()
    },
    orient="index",
).round(2)

summary_df

  sqr = _ensure_numeric((avg - values) ** 2)
  sqr = _ensure_numeric((avg - values) ** 2)
  sqr = _ensure_numeric((avg - values) ** 2)
  sqr = _ensure_numeric((avg - values) ** 2)
  sqr = _ensure_numeric((avg - values) ** 2)
  sqr = _ensure_numeric((avg - values) ** 2)
  sqr = _ensure_numeric((avg - values) ** 2)
  sqr = _ensure_numeric((avg - values) ** 2)
Experiment Setups: 100%|██████████| 6/6 [00:00<00:00, 39.85it/s]


Unnamed: 0,Unnamed: 1,censoring_rate,treatment_rate,event_time_min,event_time_median,event_time_max,event_time_mean,event_time_std,censoring_time_min,censoring_time_median,censoring_time_max,censoring_time_mean,censoring_time_std,ate,cate_min,cate_median,cate_max
RCT_0_5,scenario_1,0.07,0.5,0.0,0.21,21.86,0.4,0.64,0.01,1.61,10.22,1.86,1.17,0.12,-16.62,0.04,21.79
RCT_0_5,scenario_2,0.2,0.5,0.0,0.17,116.49,0.95,2.69,0.0,1.5,3.0,1.5,0.86,0.16,-80.55,-0.0,116.23
RCT_0_5,scenario_3,0.08,0.5,0.0,7.0,21.0,7.2,2.77,1.0,13.0,33.0,12.97,3.63,0.75,-16.0,1.0,20.0
RCT_0_5,scenario_5,0.39,0.5,0.0,7.0,21.0,7.2,2.77,1.0,inf,inf,inf,,0.75,-16.0,1.0,20.0
RCT_0_5,scenario_6,0.71,0.5,0.0,7.0,21.0,7.2,2.77,0.0,5.0,19.0,4.74,2.23,0.75,-16.0,1.0,20.0
RCT_0_5,scenario_8,0.92,0.5,0.0,8.0,23.0,8.2,2.94,0.0,3.0,12.0,3.01,1.73,0.75,-18.0,1.0,20.0
RCT_0_5,scenario_9,0.91,0.5,0.01,1.67,151.25,3.16,4.95,0.0,0.29,3.3,0.36,0.28,0.72,-122.23,0.1,150.79
RCT_0_5,scenario_10,0.7,0.5,0.0,0.17,116.49,0.95,2.69,0.0,0.03,inf,inf,,0.16,-80.55,-0.0,116.23
RCT_0_05,scenario_1,0.04,0.05,0.0,0.19,17.01,0.34,0.51,0.02,2.17,10.22,2.34,1.27,0.12,-16.62,0.04,21.79
RCT_0_05,scenario_2,0.2,0.05,0.0,0.17,81.23,0.88,2.27,0.0,1.5,3.0,1.5,0.86,0.16,-80.55,-0.0,116.23


In [5]:
# Pre-generated id orderings, one column per repeat (random_idx0 ... random_idx9).
experiment_repeat_setups = (
    pd.read_csv("synthetic_data/idx_split.csv")
    .set_index("idx")
)
experiment_repeat_setups

Unnamed: 0_level_0,random_idx0,random_idx1,random_idx2,random_idx3,random_idx4,random_idx5,random_idx6,random_idx7,random_idx8,random_idx9
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0,47390,5618,14210,46970,4203,16369,24535,45204,45725,45885
1,38566,46218,39045,7253,22759,34401,28889,38471,45822,37471
2,32814,20226,40012,4854,27351,39165,25359,14516,25717,29860
3,41393,39492,27153,19041,33009,19822,21243,41228,955,23901
4,12564,17823,48976,18458,22756,28169,45851,36620,29824,12711
...,...,...,...,...,...,...,...,...,...,...
49995,15948,39245,30779,48178,45056,4892,528,7486,31042,38267
49996,11102,29624,40779,3136,45904,41903,45682,36621,33204,38070
49997,16338,8986,19293,35651,10172,17947,38843,18310,2765,12581
49998,32478,32134,11955,36939,33266,41932,43910,21691,40801,33527


# Run Experiments

In [6]:
from models_causal_survival.causal_survival_forest import CausalSurvivalForestGRF
from sklearn.metrics import mean_squared_error
import time
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm

In [7]:
def prepare_data_split(dataset_df, experiment_repeat_setups, random_idx_col_list, n_test=5000):
    """Build train/test arrays for each pre-generated ordering of dataset ids.

    Parameters
    ----------
    dataset_df : pandas.DataFrame
        One scenario's data. Must contain 'id', covariate columns named
        'X<digits>' (X1, X2, ...), 'W', 'observed_time', 'event', 'T0', 'T1'.
    experiment_repeat_setups : pandas.DataFrame
        One column per repeat; each column lists dataset ids, and their order
        determines the split.
    random_idx_col_list : list of str
        Columns of `experiment_repeat_setups` to build splits for.
    n_test : int, default 5000
        Number of ids taken from the END of each column as the test set; the
        remaining ids form the training set.

    Returns
    -------
    dict
        Maps each column name to the tuple
        ``(X_train, W_train, Y_train, X_test, W_test, Y_test, CATE_test_true)``
        where ``Y_*`` stacks (observed_time, event) and
        ``CATE_test_true = T1 - T0``. Rows follow their order in `dataset_df`,
        not the order of the id column.

    Raises
    ------
    ValueError
        If `n_test` is negative.
    """
    if n_test < 0:
        raise ValueError(f"n_test must be non-negative, got {n_test}")

    # Covariates are exactly the X<k> columns; loop-invariant, so computed once.
    X_cols = [c for c in dataset_df.columns if c.startswith("X") and c[1:].isdigit()]

    split_results = {}
    for col in random_idx_col_list:
        random_idx = experiment_repeat_setups[col].to_numpy()
        # Explicit split point: slicing with [-n_test:] / [:-n_test] would put
        # every id into the test set when n_test == 0.
        split_at = max(len(random_idx) - n_test, 0)
        train_ids, test_ids = random_idx[:split_at], random_idx[split_at:]

        train_df = dataset_df[dataset_df["id"].isin(train_ids)]
        test_df = dataset_df[dataset_df["id"].isin(test_ids)]

        split_results[col] = (
            train_df[X_cols].to_numpy(),
            train_df["W"].to_numpy(),
            train_df[["observed_time", "event"]].to_numpy(),
            test_df[X_cols].to_numpy(),
            test_df["W"].to_numpy(),
            test_df[["observed_time", "event"]].to_numpy(),
            (test_df["T1"] - test_df["T0"]).to_numpy(),
        )

    return split_results

In [8]:
# Run configuration: which split columns to evaluate, the failure-time grid
# size passed to the forest, and where progress is checkpointed.
random_idx_col_list = ["random_idx0"]
failure_times_grid_size = 500
output_pickle_path = "causal_survival_forest_random_idx0_train_45000.pkl"


def _save_checkpoint(obj, path):
    """Pickle `obj` to `path` via a temp file + rename.

    A crash or kernel kill mid-dump then leaves the previous checkpoint intact
    instead of a truncated, unloadable pickle.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(obj, f)
    os.replace(tmp_path, path)


for setup_name, setup_dict in tqdm(experiment_setups.items(), desc="Experiment Setups"):
    for scenario_key in tqdm(setup_dict, desc=f"{setup_name} Scenarios", leave=False):
        scenario = setup_dict[scenario_key]  # same dict as experiment_setups[setup_name][scenario_key]
        split_dict = prepare_data_split(scenario["dataset"], experiment_repeat_setups, random_idx_col_list)
        results = scenario.setdefault("result", {})

        # Timing starts after the split, so runtime covers fit + predict + MSE only.
        start_time = time.perf_counter()
        for col in random_idx_col_list:
            X_train, W_train, Y_train, X_test, W_test, Y_test, CATE_test_true = split_dict[col]

            csf = CausalSurvivalForestGRF(failure_times_grid_size=failure_times_grid_size)
            csf.fit(X_train, W_train, Y_train)
            CATE_test_pred = csf.predict_cate(X_test, W_test)

            results[col] = {
                "CATE_true": CATE_test_true,
                "CATE_Pred": CATE_test_pred,
                "CATE_MSE": mean_squared_error(CATE_test_true, CATE_test_pred),
            }
        elapsed = time.perf_counter() - start_time

        split_mses = [results[col]["CATE_MSE"] for col in random_idx_col_list]
        results["average"] = {
            "mean_CATE_MSE": np.mean(split_mses),
            "std_CATE_MSE": np.std(split_mses),
            "runtime": elapsed / len(random_idx_col_list),  # seconds per split
        }

        # Checkpoint the whole structure (datasets included) after every scenario
        # so an interrupted run keeps all finished results.
        _save_checkpoint(experiment_setups, output_pickle_path)
            


R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

Experiment Setups: 100%|██████████| 6/6 [3:38:00<00:00, 2180.02s/it]


In [9]:
def summarize_experiment_results(experiment_setups):
    """Flatten averaged results into one row per (setup, scenario).

    Each scenario's ``result["average"]`` entry supplies ``mean_CATE_MSE``,
    ``std_CATE_MSE`` and ``runtime``; any that are missing (or NaN) become NaN.
    Present values are rounded to 4 decimal places.

    Returns
    -------
    pandas.DataFrame
        Columns: setup_name, scenario_key, mean_CATE_mse, std_CATE_mse,
        runtime [s].
    """
    def _rounded(value):
        return np.nan if pd.isna(value) else round(value, 4)

    rows = []
    for setup_name, setup_dict in experiment_setups.items():
        for scenario_key, scenario in setup_dict.items():
            average = scenario.get("result", {}).get("average", {})
            rows.append({
                "setup_name": setup_name,
                "scenario_key": scenario_key,
                "mean_CATE_mse": _rounded(average.get("mean_CATE_MSE", np.nan)),
                "std_CATE_mse": _rounded(average.get("std_CATE_MSE", np.nan)),
                "runtime [s]": _rounded(average.get("runtime", np.nan)),
            })

    return pd.DataFrame.from_records(rows)

In [10]:
# Reload the checkpoint written by the experiment loop. pickle.load can execute
# arbitrary code, so only load checkpoints this notebook produced itself.
output_pickle_path = "causal_survival_forest_random_idx0_train_45000.pkl"
with open(output_pickle_path, mode="rb") as checkpoint_file:
    experiment_setups = pickle.load(checkpoint_file)

In [13]:
summary_df = summarize_experiment_results(experiment_setups)

# Only one split column (random_idx0) was run, so the std across splits is 0;
# hide it from the displayed table.
summary_df.drop(columns="std_CATE_mse")

Unnamed: 0,setup_name,scenario_key,mean_CATE_mse,runtime [s]
0,RCT_0_5,scenario_1,0.7282,158.8549
1,RCT_0_5,scenario_2,18.4623,35.9607
2,RCT_0_5,scenario_3,14.5105,222.0538
3,RCT_0_5,scenario_5,14.519,80.4551
4,RCT_0_5,scenario_6,14.6071,44.7158
5,RCT_0_5,scenario_8,16.8886,727.4688
6,RCT_0_5,scenario_9,42.8919,724.4563
7,RCT_0_5,scenario_10,15.4362,149.9056
8,RCT_0_05,scenario_1,0.684,247.1146
9,RCT_0_05,scenario_2,16.6639,74.7034
