In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os

# Load Data

In [2]:
def load_scenario_data(h5_file_path, scenario_num):
    """Read one scenario's table and its metadata from an HDF5 store.

    The scenario is looked up under the key ``scenario_<scenario_num>/data``.

    Parameters
    ----------
    h5_file_path : path of the HDF5 file, opened read-only.
    scenario_num : scenario label interpolated into the store key.

    Returns
    -------
    dict or None
        ``{"dataset": <stored object>, "metadata": <storer's metadata attr>}``
        when the key exists in the store, otherwise ``None``.
    """
    node_path = f"scenario_{scenario_num}/data"
    with pd.HDFStore(h5_file_path, mode='r') as h5:
        if node_path in h5:
            frame = h5[node_path]
            # The metadata is kept as an attribute on the node's storer.
            meta = h5.get_storer(node_path).attrs.metadata
            return {"dataset": frame, "metadata": meta}
    # Scenario not present in this file.
    return None

In [None]:
# HDF5 stores to load; each file's stem (e.g. "RCT_0_5") becomes its setup name.
store_files = [
    "../synthetic_data/RCT_0_5.h5",
    "../synthetic_data/RCT_0_05.h5",
    "../synthetic_data/e_X.h5",
    "../synthetic_data/e_X_U.h5",
    "../synthetic_data/e_X_no_overlap.h5",
    "../synthetic_data/e_X_info_censor.h5",
    "../synthetic_data/e_X_U_info_censor.h5",
    "../synthetic_data/e_X_no_overlap_info_censor.h5"
]

# Maps setup name -> {"scenario_<label>": {"dataset": ..., "metadata": ...}}.
experiment_setups = {}

for path in store_files:
    base_name = os.path.splitext(os.path.basename(path))[0]
    scenario_dict = {}
    for scenario in ['A', 'B', 'C', 'D', 'E']:
        try:
            result = load_scenario_data(path, scenario)
        except Exception as e:
            # Best effort: report the failure and carry on with the next scenario.
            print(f"Error loading {path} scenario {scenario}: {e}")
        else:
            # None means the scenario is absent from this store; leave it out.
            if result is not None:
                scenario_dict[f"scenario_{scenario}"] = result
    experiment_setups[base_name] = scenario_dict

In [None]:
# Sanity check: show the first rows of one loaded dataset (setup "RCT_0_5", scenario "B").
experiment_setups['RCT_0_5']['scenario_B']['dataset'].head()

Unnamed: 0,id,observed_time,event,W,X1,X2,X3,X4,X5,U1,U2,T0,T1,T,C
0,0,0.054267,1,0,0.135488,0.887852,0.932606,0.445568,0.388236,0.151609,0.205535,0.054267,0.061394,0.054267,1.803019
1,1,0.73263,1,1,0.257596,0.657368,0.492617,0.964238,0.800984,0.597208,0.255785,0.228566,0.73263,0.73263,1.689546
2,2,0.162856,1,1,0.455205,0.801058,0.041718,0.769458,0.003171,0.370382,0.223214,0.176016,0.162856,0.162856,1.256329
3,3,0.05034,1,1,0.292809,0.610914,0.913027,0.300115,0.248599,0.038464,0.409829,0.381909,0.05034,0.05034,1.241777
4,4,0.524607,1,0,0.666392,0.987533,0.46827,0.123287,0.916031,0.342961,0.79133,0.524607,1.121968,0.524607,1.516613


In [5]:
# Load the table of id orderings used to build the repeated train/test splits;
# prepare_data_split consumes one column of this table per repeat.
experiment_repeat_setups = pd.read_csv("../synthetic_data/idx_split.csv").set_index("idx")
experiment_repeat_setups

Unnamed: 0_level_0,random_idx0,random_idx1,random_idx2,random_idx3,random_idx4,random_idx5,random_idx6,random_idx7,random_idx8,random_idx9
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0,47390,5618,14210,46970,4203,16369,24535,45204,45725,45885
1,38566,46218,39045,7253,22759,34401,28889,38471,45822,37471
2,32814,20226,40012,4854,27351,39165,25359,14516,25717,29860
3,41393,39492,27153,19041,33009,19822,21243,41228,955,23901
4,12564,17823,48976,18458,22756,28169,45851,36620,29824,12711
...,...,...,...,...,...,...,...,...,...,...
49995,15948,39245,30779,48178,45056,4892,528,7486,31042,38267
49996,11102,29624,40779,3136,45904,41903,45682,36621,33204,38070
49997,16338,8986,19293,35651,10172,17947,38843,18310,2765,12581
49998,32478,32134,11955,36939,33266,41932,43910,21691,40801,33527


# EXPERIMENT CONSTANTS

In [6]:
# Number of leading repeat columns of the index-split table to run.
num_repeats_to_include = 10  # max 10
# Upper bound on the number of training ids taken per repeat.
num_training_data_points = 100 # max 45000
# Intended size of the held-out test set per repeat.
test_size = 5000
# Base survival model; the hyperparameter grids below define entries for
# 'RandomSurvivalForest', 'DeepSurv' and 'DeepHit'.
survival_model = 'DeepHit'
# Meta-learner to run: 't_learner_survival', 's_learner_survival' or 'matching_survival'.
meta_learner_type = 't_learner_survival'
# Disabled options for loading imputed times:
# load_imputed_values = True
# imputed_times_path = f"synthetic_data/imputed_times_lookup.pkl"

In [7]:
# Results file name encoding the learner, survival model, repeat count and training size.
# NOTE: the experiment cell below rebinds this name using a different file-name pattern.
output_pickle_path = f"results/{meta_learner_type}_{survival_model}_num_repeats_{num_repeats_to_include}_train_size_{num_training_data_points}.pkl"

# Run Experiments

In [8]:
# from models_causal_impute.meta_learners import TLearner, SLearner, XLearner
# from models_causal_impute.survival_eval_impute import SurvivalEvalImputer
import sys
import os

# Add the parent directory of "notebooks" to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from models_causal_survival_meta.meta_learners_survival import TLearnerSurvival, SLearnerSurvival, MatchingLearnerSurvival
import time
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import os

In [9]:
def prepare_data_split(dataset_df, experiment_repeat_setups, random_idx_col_list, num_training_data_points=5000, test_size=5000):
    """Build train/test arrays for each requested repeat.

    For every column name in ``random_idx_col_list`` the matching column of
    ``experiment_repeat_setups`` is read as an ordering of ``id`` values. The
    last ``test_size`` entries form the test set; the leading entries (at most
    ``num_training_data_points``, never reaching into the test tail) form the
    training set. Rows are selected from ``dataset_df`` by ``id`` membership,
    so they keep ``dataset_df``'s own row order.

    Parameters
    ----------
    dataset_df : DataFrame
        Must contain ``id``, ``W``, ``observed_time``, ``event``, ``T0``, ``T1``
        and covariate columns named ``X`` followed by digits.
    experiment_repeat_setups : DataFrame
        One column of ids per repeat.
    random_idx_col_list : iterable of str
        Names of the repeat columns to process.
    num_training_data_points : int, default 5000
        Maximum number of training ids per repeat.
    test_size : int, default 5000
        Number of trailing ids used as the test set.

    Returns
    -------
    dict
        Maps each repeat column name to the tuple
        ``(X_train, W_train, Y_train, X_test, W_test, Y_test, cate_test_true)``
        where ``Y_*`` holds the ``[observed_time, event]`` columns and
        ``cate_test_true`` is ``T1 - T0`` on the test rows.

    Raises
    ------
    ValueError
        If a size is negative or ``test_size`` exceeds the number of ids in a
        repeat column.
    """
    if num_training_data_points < 0:
        raise ValueError(f"num_training_data_points must be >= 0, got {num_training_data_points}")
    if test_size < 0:
        raise ValueError(f"test_size must be >= 0, got {test_size}")

    # Covariate columns do not depend on the repeat, so select them once.
    X_cols = [c for c in dataset_df.columns if c.startswith("X") and c[1:].isdigit()]

    split_results = {}

    for rand_idx in random_idx_col_list:
        random_idx = experiment_repeat_setups[rand_idx].values
        if test_size > len(random_idx):
            raise ValueError(
                f"test_size ({test_size}) exceeds the {len(random_idx)} ids available in '{rand_idx}'"
            )

        # Use an explicit boundary instead of random_idx[-test_size:]: with
        # test_size == 0 the negative-index form becomes [-0:], which selects
        # the whole array rather than nothing.
        test_start = len(random_idx) - test_size
        test_ids = random_idx[test_start:]
        train_ids = random_idx[:min(num_training_data_points, test_start)]

        train_df = dataset_df[dataset_df['id'].isin(train_ids)]
        test_df = dataset_df[dataset_df['id'].isin(test_ids)]

        X_train = train_df[X_cols].to_numpy()
        W_train = train_df["W"].to_numpy()
        Y_train = train_df[["observed_time", "event"]].to_numpy()

        X_test = test_df[X_cols].to_numpy()
        W_test = test_df["W"].to_numpy()
        Y_test = test_df[["observed_time", "event"]].to_numpy()

        # Ground-truth individual effect from the two potential outcome times.
        cate_test_true = (test_df["T1"] - test_df["T0"]).to_numpy()

        split_results[rand_idx] = (X_train, W_train, Y_train, X_test, W_test, Y_test, cate_test_true)

    return split_results

In [10]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Fail fast on a mistyped learner name. Without this check none of the branches
# below would bind `learner`, so the loop would either hit a NameError or
# silently reuse a learner left over from an earlier run of this cell.
if meta_learner_type not in ("t_learner_survival", "s_learner_survival", "matching_survival"):
    raise ValueError(f"Unsupported meta_learner_type: {meta_learner_type!r}")

# Re-read the repeat-index table so this cell does not rely on earlier cell state.
experiment_repeat_setups = pd.read_csv("../synthetic_data/idx_split.csv").set_index("idx")
random_idx_col_list = experiment_repeat_setups.columns.to_list()[:num_repeats_to_include]

output_pickle_path = f"results/{meta_learner_type}_{survival_model}_{'mean'}_repeats_{num_repeats_to_include}_train_{num_training_data_points}.pkl"
print("Output results path:", output_pickle_path)

# Base survival model used by the meta-learner
base_model = survival_model
# Nested results: setup name -> scenario key -> repeat column / "average"
results_dict = {}

# Hyperparameter grids for each supported base model.
# The whole dict is handed to the learner; presumably it picks the entry
# matching base_model_name - TODO confirm against the learner implementation.
hyperparameter_grids = {
    'RandomSurvivalForest': {
        'n_estimators': [50, 100],
        'min_samples_split': [5, 10],
        'min_samples_leaf': [3, 5]
    },
    'DeepSurv': {
        'num_nodes': [32, 64],
        'dropout': [0.1, 0.2],
        'lr': [0.01, 0.001],
        'epochs': [100, 500]
    },
    'DeepHit': {
        'num_nodes': [32, 64],
        'dropout': [0.1, 0.2],
        'lr': [0.01, 0.001],
        'epochs': [100, 500]
    }
}

# Not used anywhere in this cell; kept in case a later cell reads it.
counter = 0

for setup_name, setup_dict in tqdm(experiment_setups.items(), desc="Experiment Setups"):
    # Every setup gets an entry, even the ones filtered out below.
    results_dict[setup_name] = {}

    # Only this setup is run; the others keep an empty result dict.
    if setup_name != "e_X_no_overlap_info_censor":
        continue

    for scenario_key in tqdm(setup_dict, desc=f"{setup_name} Scenarios"):

        dataset_df = setup_dict[scenario_key]["dataset"]
        # Pass the configured test_size explicitly so that changing the
        # constant actually changes the split (it was previously ignored in
        # favour of the function default).
        split_dict = prepare_data_split(
            dataset_df,
            experiment_repeat_setups,
            random_idx_col_list,
            num_training_data_points,
            test_size=test_size,
        )
        results_dict[setup_name][scenario_key] = {}

        start_time = time.time()

        for rand_idx in random_idx_col_list:

            X_train, W_train, Y_train, X_test, W_test, Y_test, cate_test_true = split_dict[rand_idx]
            print(f"Processing {setup_name} {scenario_key} {rand_idx}")

            # The T-learner fits each arm separately, so skip the repeat before
            # building any model when an arm has at most one observed event.
            if meta_learner_type == "t_learner_survival":
                if Y_train[W_train == 1, 1].sum() <= 1:
                    print(f"[Warning]: For {meta_learner_type}, at most one event in treatment group. Skipping iteration {rand_idx}.")
                    continue
                if Y_train[W_train == 0, 1].sum() <= 1:
                    print(f"[Warning]: For {meta_learner_type}, at most one event in control group. Skipping iteration {rand_idx}.")
                    continue

            # Largest observed time in the training split, passed to the learner.
            max_time = Y_train[:, 0].max()

            # Initialize the appropriate meta-learner
            if meta_learner_type == "t_learner_survival":
                learner = TLearnerSurvival(
                    base_model_name=base_model,
                    base_model_grid=hyperparameter_grids,
                    metric="mean",
                    max_time=max_time
                )
            elif meta_learner_type == "s_learner_survival":
                learner = SLearnerSurvival(
                    base_model_name=base_model,
                    base_model_grid=hyperparameter_grids,
                    metric="mean",
                    max_time=max_time
                )
            elif meta_learner_type == "matching_survival":
                learner = MatchingLearnerSurvival(
                    base_model_name=base_model,
                    base_model_grid=hyperparameter_grids,
                    metric="mean",
                    num_matches=5,
                    max_time=max_time
                )

            # Fit the learner
            learner.fit(X_train, W_train, Y_train)

            # Evaluate base survival models on test data
            base_model_eval = learner.evaluate_test(X_test, Y_test, W_test)

            # Evaluate causal effect predictions
            mse_test, cate_test_pred, ate_test_pred = learner.evaluate(X_test, cate_test_true, W_test)

            results_dict[setup_name][scenario_key][rand_idx] = {
                "cate_true": cate_test_true,
                "cate_pred": cate_test_pred,
                "ate_true": cate_test_true.mean(),
                "ate_pred": ate_test_pred,
                "cate_mse": mse_test,
                "ate_bias": ate_test_pred - cate_test_true.mean(),
                "base_model_eval": base_model_eval  # Store base model evaluation results
            }

        end_time = time.time()
        avg = results_dict[setup_name][scenario_key]
        # Results of the repeats that were not skipped, in repeat order.
        completed_runs = [avg[i] for i in random_idx_col_list if i in avg]

        if completed_runs:
            # Mean/std of every base-model metric across the completed repeats;
            # the model and metric names are taken from the first completed repeat.
            base_model_eval_performance = {
                base_model_k: {
                    f"{stat}_{metric_j}": func([
                        run['base_model_eval'][base_model_k][metric_j] for run in completed_runs
                    ])
                    for metric_j in metric_j_dict
                    for stat, func in zip(['mean', 'std'], [np.nanmean, np.nanstd])
                }
                for base_model_k, metric_j_dict in completed_runs[0]['base_model_eval'].items()
            }
        else:
            base_model_eval_performance = {}

        # Per-scenario summary. When every repeat was skipped, np.mean/np.std
        # of the empty lists evaluate to NaN.
        average_stats = {}
        for field in ("cate_mse", "ate_pred", "ate_true", "ate_bias"):
            values = [run[field] for run in completed_runs]
            average_stats[f"mean_{field}"] = np.mean(values)
            average_stats[f"std_{field}"] = np.std(values)
        # Wall-clock seconds per completed repeat (0 when none completed).
        average_stats["runtime"] = (end_time - start_time) / len(completed_runs) if completed_runs else 0
        average_stats["base_model_eval"] = base_model_eval_performance

        results_dict[setup_name][scenario_key]["average"] = average_stats

Output results path: results/t_learner_survival_DeepHit_mean_repeats_10_train_100.pkl


Experiment Setups:   0%|          | 0/8 [00:00<?, ?it/s]

Processing e_X_no_overlap_info_censor scenario_1 random_idx0
Processing e_X_no_overlap_info_censor scenario_1 random_idx1
Processing e_X_no_overlap_info_censor scenario_1 random_idx2
Processing e_X_no_overlap_info_censor scenario_1 random_idx3
Processing e_X_no_overlap_info_censor scenario_1 random_idx4
Processing e_X_no_overlap_info_censor scenario_1 random_idx5
Processing e_X_no_overlap_info_censor scenario_1 random_idx6
Processing e_X_no_overlap_info_censor scenario_1 random_idx7
Processing e_X_no_overlap_info_censor scenario_1 random_idx8
Processing e_X_no_overlap_info_censor scenario_1 random_idx9




Processing e_X_no_overlap_info_censor scenario_2 random_idx0
Processing e_X_no_overlap_info_censor scenario_2 random_idx1
Processing e_X_no_overlap_info_censor scenario_2 random_idx2
Processing e_X_no_overlap_info_censor scenario_2 random_idx3
Processing e_X_no_overlap_info_censor scenario_2 random_idx4
Processing e_X_no_overlap_info_censor scenario_2 random_idx5
Processing e_X_no_overlap_info_censor scenario_2 random_idx6
Processing e_X_no_overlap_info_censor scenario_2 random_idx7
Processing e_X_no_overlap_info_censor scenario_2 random_idx8
Processing e_X_no_overlap_info_censor scenario_2 random_idx9




Processing e_X_no_overlap_info_censor scenario_5 random_idx0
Processing e_X_no_overlap_info_censor scenario_5 random_idx1
Processing e_X_no_overlap_info_censor scenario_5 random_idx2
Processing e_X_no_overlap_info_censor scenario_5 random_idx3
Processing e_X_no_overlap_info_censor scenario_5 random_idx4
Processing e_X_no_overlap_info_censor scenario_5 random_idx5
Processing e_X_no_overlap_info_censor scenario_5 random_idx6
Processing e_X_no_overlap_info_censor scenario_5 random_idx7
Processing e_X_no_overlap_info_censor scenario_5 random_idx8
Processing e_X_no_overlap_info_censor scenario_5 random_idx9




Processing e_X_no_overlap_info_censor scenario_8 random_idx0
Processing e_X_no_overlap_info_censor scenario_8 random_idx1
Processing e_X_no_overlap_info_censor scenario_8 random_idx2
Processing e_X_no_overlap_info_censor scenario_8 random_idx3
Processing e_X_no_overlap_info_censor scenario_8 random_idx4
Processing e_X_no_overlap_info_censor scenario_8 random_idx5
Processing e_X_no_overlap_info_censor scenario_8 random_idx6
Processing e_X_no_overlap_info_censor scenario_8 random_idx7
Processing e_X_no_overlap_info_censor scenario_8 random_idx8
Processing e_X_no_overlap_info_censor scenario_8 random_idx9




Processing e_X_no_overlap_info_censor scenario_9 random_idx0
Processing e_X_no_overlap_info_censor scenario_9 random_idx1
Processing e_X_no_overlap_info_censor scenario_9 random_idx2
Processing e_X_no_overlap_info_censor scenario_9 random_idx3
Processing e_X_no_overlap_info_censor scenario_9 random_idx4
Processing e_X_no_overlap_info_censor scenario_9 random_idx5
Processing e_X_no_overlap_info_censor scenario_9 random_idx6
Processing e_X_no_overlap_info_censor scenario_9 random_idx7
Processing e_X_no_overlap_info_censor scenario_9 random_idx8
Processing e_X_no_overlap_info_censor scenario_9 random_idx9


e_X_no_overlap_info_censor Scenarios: 100%|██████████| 5/5 [12:47<00:00, 153.53s/it]
Experiment Setups: 100%|██████████| 8/8 [12:47<00:00, 95.96s/it]






In [11]:
# `avg` is left bound by the experiment loop to the result dict of the last
# scenario it processed; display it for inspection.
avg

{'random_idx0': {'cate_true': array([ 0.24516762,  5.054912  , -0.65175434, ..., -1.31640703,
          1.40203994, -2.31674823]),
  'cate_pred': array([ 0.39975813,  0.01670147, -0.0560206 , ...,  0.02687146,
          0.15669542,  0.16576771]),
  'ate_true': 0.7514350882286424,
  'ate_pred': 0.12791718746913078,
  'cate_mse': 42.3903629312977,
  'ate_bias': -0.6235179007595115,
  'base_model_eval': {'treated': {'concordance_td': 0.6131713841874025,
    'integrated_brier_score': 0.22349982702599056},
   'control': {'concordance_td': 0.5375023556282554,
    'integrated_brier_score': 0.23038016911430076}}},
 'random_idx1': {'cate_true': array([ 3.07188362e-01, -9.94193658e-03, -1.07624696e+00, ...,
          2.06002923e+01, -3.64365556e-01,  1.21556760e+00]),
  'cate_pred': array([0.90999473, 1.75088956, 1.2808498 , ..., 1.34998186, 1.09740341,
         0.97394742]),
  'ate_true': 0.7082314320951666,
  'ate_pred': 0.9814416570335422,
  'cate_mse': 42.9500322421798,
  'ate_bias': 0.27321

In [12]:
# Training outcomes ([observed_time, event] columns) of the rows with W == 1,
# taken from the last repeat unpacked by the experiment loop.
Y_train[W_train == 1]

array([[3.31042832, 0.        ],
       [1.05478829, 0.        ],
       [0.59399601, 1.        ],
       [0.30636157, 0.        ],
       [0.2972404 , 1.        ],
       [1.17606344, 0.        ],
       [2.20752164, 0.        ],
       [0.18416644, 0.        ],
       [0.54998955, 1.        ],
       [1.42736017, 0.        ],
       [1.60473032, 0.        ],
       [1.47408764, 1.        ],
       [1.29709275, 1.        ],
       [0.26189765, 0.        ],
       [1.05090438, 0.        ],
       [2.68932442, 1.        ],
       [1.72891107, 0.        ],
       [1.144921  , 1.        ],
       [0.12230504, 1.        ],
       [1.44478399, 1.        ],
       [0.26931528, 0.        ],
       [2.14008607, 0.        ],
       [4.2538921 , 1.        ],
       [2.07064197, 0.        ],
       [1.52846452, 1.        ],
       [2.93907535, 1.        ],
       [0.76340821, 0.        ],
       [0.82585401, 1.        ],
       [0.38773147, 1.        ],
       [0.79787432, 1.        ],
       [2.

In [13]:
# Full nested results: setup name -> scenario key -> repeat column / "average".
results_dict

{'RCT_0_5': {},
 'RCT_0_05': {},
 'e_X': {},
 'e_X_U': {},
 'e_X_no_overlap': {},
 'e_X_info_censor': {},
 'e_X_U_info_censor': {},
 'e_X_no_overlap_info_censor': {'scenario_1': {'random_idx0': {'cate_true': array([ 0.01964673,  0.67548137,  0.01256466, ..., -0.08345062,
            0.20826006, -0.22013693]),
    'cate_pred': array([ 0.07852027, -0.74223609, -0.86039045, ..., -0.55867068,
            0.1970808 , -0.75922389]),
    'ate_true': 0.12915102003284332,
    'ate_pred': 0.015578006788981486,
    'cate_mse': 1.457846684883849,
    'ate_bias': -0.11357301324386183,
    'base_model_eval': {'treated': {'concordance_td': 0.5683284137598212,
      'integrated_brier_score': 0.07956359080258539},
     'control': {'concordance_td': 0.5301772989457605,
      'integrated_brier_score': 0.1973636824931311}}},
   'random_idx1': {'cate_true': array([ 0.04373459, -0.01348651, -0.12558549, ...,  2.88443699,
           -0.03312028,  0.24602149]),
    'cate_pred': array([0.08602105, 0.14740382, 