In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os

### Load data

In [None]:
# CSV files of the twin data sets to evaluate; the file stem (e.g. "twin180")
# becomes the setup name used as key everywhere below.
store_files = [
    "../real_data/twin.csv",
    "../real_data/twin30.csv",
    "../real_data/twin180.csv",
]

# Maps setup name -> {"scenario_<k>": DataFrame read from the CSV}.
experiment_setups = {}

for path in store_files:
    base_name = os.path.splitext(os.path.basename(path))[0]  # e.g. twin
    scenario_dict = {}
    for scenario in range(1, 2):  # only one scenario per twin data file
        try:
            scenario_dict[f"scenario_{scenario}"] = pd.read_csv(path)
        except Exception as e:
            # Best effort: a file that cannot be read leaves this setup without
            # the scenario, but the failure is reported instead of being hidden.
            print(f"WARNING: could not load {path!r} ({type(e).__name__}: {e}); "
                  f"setup '{base_name}' will have no scenario_{scenario}.")
    experiment_setups[base_name] = scenario_dict

In [3]:
# Preview the loaded twin180 data set (pandas truncates the display of large frames).
experiment_setups['twin180']['scenario_1']

Unnamed: 0,idx,observed_time,event,T0,T1,T,C,W,true_cate,anemia,...,resstatb_4,mpcb_1,mpcb_2,mpcb_3,mpcb_4,mpcb_5,mpcb_6,mpcb_7,mpcb_8,mpcb_9
0,0,180.000000,1,180,180,180,195.954330,0,0,0,...,0,0,0,1,0,0,0,0,0,0
1,1,2.000000,1,180,2,2,20.717881,1,-178,0,...,0,0,0,0,0,0,0,0,0,0
2,2,55.907504,0,180,180,180,55.907504,0,0,0,...,0,0,1,0,0,0,0,0,0,0
3,3,4.168915,0,180,180,180,4.168915,1,0,0,...,0,0,1,0,0,0,0,0,0,0
4,4,0.000042,0,180,5,5,0.000042,1,-175,0,...,0,0,0,1,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11395,11395,15.080879,0,180,180,180,15.080879,1,0,0,...,0,0,1,0,0,0,0,0,0,0
11396,11396,180.000000,1,180,180,180,233.524041,0,0,0,...,0,1,0,0,0,0,0,0,0,0
11397,11397,117.663304,0,180,180,180,117.663304,1,0,1,...,0,0,1,0,0,0,0,0,0,0
11398,11398,30.417666,0,38,180,38,30.417666,0,142,0,...,0,0,1,0,0,0,0,0,0,0


In [4]:
# Table of pre-computed index orderings: one `random_idx<k>` column per repeat,
# consumed by `prepare_twin_data_split` to build the train / hold-out splits.
# NOTE(review): presumably every column is a permutation of `idx` — confirm.
experiment_repeat_setups = [pd.read_csv('../real_data/idx_split_twin.csv')]
experiment_repeat_setups[0]

Unnamed: 0,idx,random_idx0,random_idx1,random_idx2,random_idx3,random_idx4,random_idx5,random_idx6,random_idx7,random_idx8,random_idx9
0,0,4673,8193,2015,1620,11124,2792,2763,3687,4907,9965
1,1,8847,4943,7700,4045,5430,5809,2248,6564,10370,386
2,2,5554,1711,1268,1411,11188,10742,6328,1891,1937,9248
3,3,9208,2687,11191,2758,8578,10274,10448,2396,5476,4819
4,4,337,9283,9089,878,3289,3300,9654,3714,7741,2289
...,...,...,...,...,...,...,...,...,...,...,...
11395,11395,4892,3659,6350,6294,1565,10981,3671,10546,1215,8200
11396,11396,8346,8007,5825,11,7334,8926,6832,5698,6831,8875
11397,11397,4426,6524,2822,10916,6607,8329,4656,5140,4654,10104
11398,11398,3360,3408,5652,10744,8170,4830,4917,3441,10438,948


In [5]:
# Reference ATE per (setup name, scenario key). The experiment loop looks these
# up as `ate_true` when computing the ATE bias; how the values were derived is
# not shown in this notebook.
TRUE_ATE = {
    ('twin', 'scenario_1'): 5.038157894736842,
    ('twin30', 'scenario_1'): 0.32824561403508773,
    ('twin180', 'scenario_1'): 2.260438596491228,
}

### EXPERIMENT CONSTANTS

In [6]:
# Number of `random_idx<k>` split columns (repeats) to evaluate.
NUM_REPEATS_TO_INCLUDE = 10

# Fractions of the data used for training / validation / testing.
TRAIN_SIZE = 0.5
VAL_SIZE = 0.25
TEST_SIZE = 0.25
assert abs(TRAIN_SIZE + VAL_SIZE + TEST_SIZE - 1.0) < 1e-9, "split fractions must sum to 1"

# Defaults handed to CausalSurvivalForestGRF; `horizon` is overridden per
# data set inside the experiment loop.
horizon = 365
min_node_size = 18
train_size = TRAIN_SIZE  # lowercase alias, derived so it cannot drift from TRAIN_SIZE

# Results of all setups are cached in (and reloaded from) this pickle file.
output_pickle_path = "../results/real_data/models_causal_survival/causal_survival_forest/"
output_pickle_path += f"twin_causal_survival_forest_repeats_{NUM_REPEATS_TO_INCLUDE}.pkl"

### Run Experiments

In [7]:
import sys
import os

# Add the parent directory of "notebooks" to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from models_causal_survival.causal_survival_forest import CausalSurvivalForestGRF
from sklearn.metrics import mean_squared_error
import time
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm

In [8]:
def prepare_twin_data_split(dataset_df, X_cols, W_col, cate_base_col, experiment_repeat_setup,
                            num_repeats=None, train_size=None):
    """Build the train / hold-out arrays for every repeat of the experiment.

    For repeat ``k`` the column ``random_idx<k>`` of ``experiment_repeat_setup``
    is cut at ``int(len(experiment_repeat_setup) * train_size)``: ids before the
    cut form the training set, the remaining ids form the hold-out set (the
    caller later divides the hold-out set into validation and test parts).

    Rows are selected from ``dataset_df`` with ``isin`` on its ``idx`` column, so
    within each subset they keep the row order of ``dataset_df``, not the order
    of the ids in the split column.

    Parameters
    ----------
    dataset_df : pandas.DataFrame
        Data with an ``idx`` column, the columns ``observed_time`` and ``event``,
        and the columns named by ``X_cols``, ``W_col`` and ``cate_base_col``.
    X_cols : list of str
        Covariate columns.
    W_col : str or list of str
        Treatment column(s); the extracted values are flattened to 1-D.
    cate_base_col : str
        Column holding the reference CATE values.
    experiment_repeat_setup : pandas.DataFrame
        Table with one ``random_idx<k>`` column of ids per repeat.
    num_repeats : int, optional
        Number of repeats to prepare. Defaults to the notebook-level
        ``NUM_REPEATS_TO_INCLUDE``.
    train_size : float, optional
        Fraction of ids assigned to training. Defaults to the notebook-level
        ``TRAIN_SIZE``.

    Returns
    -------
    dict
        ``{k: (X_train, W_train, Y_train, X_test, W_test, Y_test, cate_test_true)}``
        where ``Y_*`` has the two columns ``observed_time`` and ``event``.
    """
    # Fall back to the notebook-level constants so existing callers keep working.
    if num_repeats is None:
        num_repeats = NUM_REPEATS_TO_INCLUDE
    if train_size is None:
        train_size = TRAIN_SIZE

    y_cols = ['observed_time', 'event']
    n_train = int(len(experiment_repeat_setup) * train_size)

    split_results = {}
    for rand_idx in range(num_repeats):
        ids = experiment_repeat_setup[f'random_idx{rand_idx}'].to_numpy()
        train_ids = ids[:n_train]
        test_ids = ids[n_train:]  # this includes both validation and test data

        train_df = dataset_df[dataset_df['idx'].isin(train_ids)]
        test_df = dataset_df[dataset_df['idx'].isin(test_ids)]

        split_results[rand_idx] = (
            train_df[X_cols].to_numpy(),
            train_df[W_col].to_numpy().flatten(),
            train_df[y_cols].to_numpy(),
            test_df[X_cols].to_numpy(),
            test_df[W_col].to_numpy().flatten(),
            test_df[y_cols].to_numpy(),
            test_df[cate_base_col].to_numpy(),
        )

    return split_results

In [9]:
# Covariate groups; concatenated below into the feature list X_cols.
# NOTE(review): the last row of X_binary_cols (gestat ... adequacy) may not be
# strictly binary — confirm against the data dictionary.
X_binary_cols = [
    'anemia', 'cardiac', 'lung', 'diabetes', 'herpes', 'hydra',
    'hemo', 'chyper', 'phyper', 'eclamp', 'incervix', 'pre4000', 'preterm',
    'renal', 'rh', 'uterine', 'othermr',
    'gestat', 'dmage', 'dmeduc', 'dmar', 'nprevist', 'adequacy',
]
X_num_cols = ['dtotord', 'cigar', 'drink', 'wtgain']
X_ohe_cols = [
    'pldel_2', 'pldel_3', 'pldel_4', 'pldel_5',
    'resstatb_2', 'resstatb_3', 'resstatb_4',
    'mpcb_1', 'mpcb_2', 'mpcb_3', 'mpcb_4', 'mpcb_5', 'mpcb_6', 'mpcb_7', 'mpcb_8', 'mpcb_9',
]

# NOTE(review): this module-level y_cols is not read by the visible cells;
# prepare_twin_data_split defines its own local y_cols — confirm whether it is still needed.
y_cols = ['observed_time_month', 'effect_non_censor']

# Treatment column (kept as a list) and the column holding the reference CATE.
W_col = ['W']
cate_true_col = 'true_cate'
X_cols = X_binary_cols + X_num_cols + X_ohe_cols


In [10]:
failure_times_grid_size = 200
# failure_times_grid: non-uniform discretization –
# i.e. resolution of days in the first 30 days and months after the first 30 days
failure_times_grid = np.concatenate([np.arange(0, 30), np.arange(30, 365, 30)]) # every day for 1 month, then every month
print("Failure times grid:", failure_times_grid)
horizon = 365
min_node_size = 18

# Horizon per setup. It is looked up explicitly for every setup so the value never
# depends on which setup happened to be processed before it.
default_horizon = horizon
horizon_by_setup = {"twin30": 30, "twin180": 180}

# Per-repeat metrics that are aggregated (mean / std) over all repeats, once for
# the test split (no suffix) and once for the validation split ("_val" suffix).
metric_names = ["cate_mse", "ate_pred", "ate_true", "ate_bias"]

print("Output results path:", output_pickle_path)

if os.path.exists(output_pickle_path):
    print(f"Pickle file already exists. Loading from {output_pickle_path}...")
    # NOTE: pickle.load can execute arbitrary code; only load files written by this notebook.
    # NOTE: setups cached by a run that predates the "cate_pred_val" key store the
    # validation predictions under "cate_pred"; delete the file to recompute them.
    with open(output_pickle_path, "rb") as f:
        results_dict = pickle.load(f)
else:
    results_dict = {}

for setup_name, setup_dict in tqdm(experiment_setups.items(), desc="Experiment Setups"):
    horizon = horizon_by_setup.get(setup_name, default_horizon)
    if setup_name in results_dict:
        print(f"Skipping setup {setup_name} as it already exists in results.")
        continue
    results_dict[setup_name] = {}
    experiment_repeat_setup = experiment_repeat_setups[0]

    for scenario_key in tqdm(setup_dict, desc=f"{setup_name} Scenarios", leave=False):
        dataset_df = setup_dict[scenario_key]
        split_dict = prepare_twin_data_split(dataset_df, X_cols, W_col, cate_true_col,
                                             experiment_repeat_setup)

        # Number of hold-out rows (taken from the front) that form the validation set.
        n_val = int(len(dataset_df) * VAL_SIZE)

        # Per-repeat results for this setup and scenario, keyed by repeat index.
        scenario_results = {}
        results_dict[setup_name][scenario_key] = scenario_results

        start_time = time.time()

        for rand_idx in range(NUM_REPEATS_TO_INCLUDE):
            (X_train, W_train, Y_train,
             X_holdout, W_holdout, Y_holdout, cate_holdout_true) = split_dict[rand_idx]

            # First n_val hold-out rows -> validation set, the rest -> test set.
            X_val, W_val, Y_val = X_holdout[:n_val], W_holdout[:n_val], Y_holdout[:n_val]
            cate_val_true = cate_holdout_true[:n_val]
            X_test, W_test, Y_test = X_holdout[n_val:], W_holdout[n_val:], Y_holdout[n_val:]
            cate_test_true = cate_holdout_true[n_val:]

            # Train the model (a different seed for every repeat)
            csf = CausalSurvivalForestGRF(failure_times_grid_size=failure_times_grid_size,
                                          horizon=horizon, min_node_size=min_node_size, seed=2025+rand_idx)
            csf.fit(X_train, W_train, Y_train, failure_times_grid=failure_times_grid)

            # Use the tabulated reference ATE when available, otherwise the split's mean CATE.
            ate_true = TRUE_ATE.get((setup_name, scenario_key), cate_test_true.mean())
            ate_true_val = TRUE_ATE.get((setup_name, scenario_key), cate_val_true.mean())

            # evaluate(...) is unpacked as (CATE MSE, predicted CATEs, predicted ATE)
            mse_test, cate_test_pred, ate_test_pred = csf.evaluate(X_test, cate_test_true, W_test)
            mse_val, cate_val_pred, ate_val_pred = csf.evaluate(X_val, cate_val_true, W_val)

            # Save results
            scenario_results[rand_idx] = {
                "cate_true": cate_test_true,
                "cate_pred": cate_test_pred,
                "ate_true": ate_true,
                "ate_pred": ate_test_pred,
                "cate_mse": mse_test,
                "ate_bias": ate_test_pred - ate_true,
                # val set (own key for the predictions so they no longer overwrite "cate_pred"):
                "cate_true_val": cate_val_true,
                "cate_pred_val": cate_val_pred,
                "ate_true_val": ate_true_val,
                "ate_pred_val": ate_val_pred,
                "cate_mse_val": mse_val,
                "ate_bias_val": ate_val_pred - ate_true_val,
            }

        end_time = time.time()

        # Aggregate every metric over the repeats: mean_<metric>[_val] / std_<metric>[_val].
        average = {}
        for suffix in ("", "_val"):
            for metric in metric_names:
                values = [scenario_results[rand_idx][metric + suffix]
                          for rand_idx in range(NUM_REPEATS_TO_INCLUDE)]
                average[f"mean_{metric}{suffix}"] = np.mean(values)
                average[f"std_{metric}{suffix}"] = np.std(values)
        # Average wall-clock seconds per repeat (split, fit and both evaluations).
        average["runtime"] = (end_time - start_time) / len(range(NUM_REPEATS_TO_INCLUDE))
        scenario_results["average"] = average

        # Save progress to disk (create the output directory on first use)
        output_dir = os.path.dirname(output_pickle_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_pickle_path, "wb") as f:
            pickle.dump(results_dict, f)

Failure times grid: [  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  60  90 120 150 180
 210 240 270 300 330 360]
Output results path: ../results/real_data/models_causal_survival/causal_survival_forest/twin_causal_survival_forest_repeats_10.pkl
Pickle file already exists. Loading from ../results/real_data/models_causal_survival/causal_survival_forest/twin_causal_survival_forest_repeats_10.pkl...


R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

R[write to console]: 
 

Experiment Setups: 100%|██████████| 2/2 [02:18<00:00, 69.17s/it]


In [11]:
def summarize_experiment_results(results_dict):
    """Collect the per-setup "average" entries into one summary table.

    Parameters
    ----------
    results_dict : dict
        ``{setup_name: {scenario_key: {"average": {...}, ...}}}`` where the
        "average" dict holds ``mean_<metric>`` / ``std_<metric>`` values and
        ``runtime``. A missing "average" entry or a missing mean yields NaN cells.

    Returns
    -------
    pandas.DataFrame
        One row per (setup, scenario). Metric columns are strings of the form
        ``"mean ± std"`` with three decimals; ``runtime [s]`` is the rounded runtime.
    """
    # Output column -> metric name used in the "mean_<metric>" / "std_<metric>" keys.
    column_to_metric = {
        "CATE_MSE": "cate_mse",
        "ATE_pred": "ate_pred",
        "ATE_true": "ate_true",
        "ATE_bias": "ate_bias",
        "CATE_MSE_val": "cate_mse_val",
        "ATE_pred_val": "ate_pred_val",
        "ATE_true_val": "ate_true_val",
        "ATE_bias_val": "ate_bias_val",
    }

    def _mean_pm_std(avg_result, metric):
        # Format "mean ± std"; NaN when the mean is unavailable.
        mean = avg_result.get(f"mean_{metric}", np.nan)
        std = avg_result.get(f"std_{metric}", np.nan)
        return np.nan if pd.isna(mean) else f"{mean:.3f} ± {std:.3f}"

    records = []
    for setup_name, setup_dict in results_dict.items():
        for scenario_key in setup_dict:
            avg_result = setup_dict[scenario_key].get("average", {})
            runtime = avg_result.get("runtime", np.nan)

            record = {"setup_name": setup_name, "scenario_key": scenario_key}
            for column, metric in column_to_metric.items():
                record[column] = _mean_pm_std(avg_result, metric)
            record["runtime [s]"] = round(runtime) if not pd.isna(runtime) else np.nan
            records.append(record)

    return pd.DataFrame.from_records(records)

In [12]:
# Build the summary table (one row per setup / scenario) from the collected results and display it.
summary_df = summarize_experiment_results(results_dict)
summary_df

Unnamed: 0,setup_name,scenario_key,CATE_MSE,ATE_pred,ATE_true,ATE_bias,CATE_MSE_val,ATE_pred_val,ATE_true_val,ATE_bias_val,runtime [s]
0,twin,scenario_1,10864.104 ± 411.275,9.235 ± 4.103,5.038 ± 0.000,4.197 ± 4.103,11450.804 ± 442.565,9.163 ± 4.054,5.038 ± 0.000,4.125 ± 4.054,7
1,twin180,scenario_1,2444.692 ± 100.620,6.146 ± 2.539,2.260 ± 0.000,3.886 ± 2.539,2605.536 ± 114.053,5.996 ± 2.570,2.260 ± 0.000,3.736 ± 2.570,7
2,twin30,scenario_1,75.075 ± 13.229,2.167 ± 1.962,0.328 ± 0.000,1.839 ± 1.962,79.151 ± 12.305,2.030 ± 1.962,0.328 ± 0.000,1.702 ± 1.962,7


In [13]:
# CATE RMSE for twin30: 75.075 is the mean test CATE MSE copied by hand from the summary table.
np.sqrt(75.075)

8.664583082872483

In [16]:
# CATE RMSE for twin180: 2444.692 is the mean test CATE MSE copied by hand from the summary table.
np.sqrt(2444.692)

49.44382671274545

In [14]:
# CATE RMSE for twin: 10864.104 is the mean test CATE MSE copied by hand from the summary table.
np.sqrt(10864.104)

104.2310126593808