In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os

# Load Data

In [2]:
def load_scenario_data(h5_file_path, scenario_num):
    key = f"scenario_{scenario_num}/data"
    with pd.HDFStore(h5_file_path, mode='r') as store:
        if key not in store:
            return None  # Scenario not found
        df = store[key]
        metadata = store.get_storer(key).attrs.metadata
    return {"dataset": df, "metadata": metadata}

In [3]:
store_files = [
    "../synthetic_data/RCT_0_5.h5",
    "../synthetic_data/RCT_0_05.h5",
    "../synthetic_data/e_X.h5",
    "../synthetic_data/e_X_U.h5",
    "../synthetic_data/e_X_no_overlap.h5",
    "../synthetic_data/e_X_info_censor.h5",
    "../synthetic_data/e_X_U_info_censor.h5",
    "../synthetic_data/e_X_no_overlap_info_censor.h5"
]

experiment_setups = {}

for path in store_files:
    base_name = os.path.splitext(os.path.basename(path))[0]  # e.g. RCT_0_5
    scenario_dict = {}
    for scenario in range(1, 11):
        try:
            result = load_scenario_data(path, scenario)
            if result is not None:
                scenario_dict[f"scenario_{scenario}"] = result
        except Exception as e:
            # Log or ignore as needed
            continue
    experiment_setups[base_name] = scenario_dict

In [5]:
experiment_setups['RCT_0_5']['scenario_1']['dataset'].head()

Unnamed: 0,id,observed_time,event,W,X1,X2,X3,X4,X5,U1,U2,T0,T1,T,C
0,0,0.054267,1,0,0.135488,0.887852,0.932606,0.445568,0.388236,0.151609,0.205535,0.054267,0.061394,0.054267,1.803019
1,1,0.73263,1,1,0.257596,0.657368,0.492617,0.964238,0.800984,0.597208,0.255785,0.228566,0.73263,0.73263,1.689546
2,2,0.162856,1,1,0.455205,0.801058,0.041718,0.769458,0.003171,0.370382,0.223214,0.176016,0.162856,0.162856,1.256329
3,3,0.05034,1,1,0.292809,0.610914,0.913027,0.300115,0.248599,0.038464,0.409829,0.381909,0.05034,0.05034,1.241777
4,4,0.524607,1,0,0.666392,0.987533,0.46827,0.123287,0.916031,0.342961,0.79133,0.524607,1.121968,0.524607,1.516613


In [6]:
experiment_repeat_setups = pd.read_csv("../synthetic_data/idx_split.csv").set_index("idx")
experiment_repeat_setups

Unnamed: 0_level_0,random_idx0,random_idx1,random_idx2,random_idx3,random_idx4,random_idx5,random_idx6,random_idx7,random_idx8,random_idx9
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0,47390,5618,14210,46970,4203,16369,24535,45204,45725,45885
1,38566,46218,39045,7253,22759,34401,28889,38471,45822,37471
2,32814,20226,40012,4854,27351,39165,25359,14516,25717,29860
3,41393,39492,27153,19041,33009,19822,21243,41228,955,23901
4,12564,17823,48976,18458,22756,28169,45851,36620,29824,12711
...,...,...,...,...,...,...,...,...,...,...
49995,15948,39245,30779,48178,45056,4892,528,7486,31042,38267
49996,11102,29624,40779,3136,45904,41903,45682,36621,33204,38070
49997,16338,8986,19293,35651,10172,17947,38843,18310,2765,12581
49998,32478,32134,11955,36939,33266,41932,43910,21691,40801,33527


# EXPERIMENT CONSTANTS

In [7]:
# NUM_REPEATS_TO_INCLUDE = 10  # max 10
# NUM_TRAINING_DATA_POINTS = 5000 # max 45000
TEST_SIZE = 5000
# IMPUTATION_METHOD = "Pseudo_obs" # "Margin", "IPCW-T", "Pseudo_obs"

In [8]:
imputation_methods_list = ['Pseudo_obs', 'Margin', 'IPCW-T']

In [9]:
# num_training_data_points_list = [50, 100, 200, 300, 400, 500, 700, 1000, 2000, 3000, 4000, 5000, 10000, 20000, 30000]
num_training_data_points_list = [200, 300, 500, 1000, 2000, 5000, 10000, 20000]
len(num_training_data_points_list)

8

In [10]:
output_pickle_path = f"../synthetic_data/imputed_times_lookup.pkl"

In [11]:
import pickle

if os.path.exists(output_pickle_path):
    print("Loading imputation times from existing file.")
    with open(output_pickle_path, 'rb') as f:
        imputed_times = pickle.load(f)
else:
    print("Imputation times not found, creating new file.")
    imputed_times = {}

Loading imputation times from existing file.


# Run Imputation Experiments

In [12]:
import sys
import os

# Add the parent directory of "notebooks" to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

import time
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
from models_causal_impute.survival_eval_impute import SurvivalEvalImputer

In [13]:
def prepare_data_split(dataset_df, experiment_repeat_setups, random_idx_col_list, num_training_data_points=5000, test_size=5000):
    split_results = {}

    for rand_idx in random_idx_col_list:
        random_idx = experiment_repeat_setups[rand_idx].values
        test_ids = random_idx[-test_size:]
        train_ids = random_idx[:min(num_training_data_points, len(random_idx) - test_size)]

        X_cols = [c for c in dataset_df.columns if c.startswith("X") and c[1:].isdigit()]
        
        train_df = dataset_df[dataset_df['id'].isin(train_ids)]
        test_df = dataset_df[dataset_df['id'].isin(test_ids)]

        X_train = train_df[X_cols].to_numpy()
        W_train = train_df["W"].to_numpy()
        Y_train = train_df[["observed_time", "event"]].to_numpy()

        X_test = test_df[X_cols].to_numpy()
        W_test = test_df["W"].to_numpy()
        Y_test = test_df[["observed_time", "event"]].to_numpy()

        cate_test_true = (test_df["T1"] - test_df["T0"]).to_numpy()

        split_results[rand_idx] = (X_train, W_train, Y_train, X_test, W_test, Y_test, cate_test_true)

    return split_results

In [14]:
# random_idx_col_list = experiment_repeat_setups.columns.to_list()
# start_time = end_time = 0

# for imputation_method_idx in trange(len(imputation_methods_list), desc="Imputation Methods"):
#     imputation_method = imputation_methods_list[imputation_method_idx]

#     if imputed_times.get(imputation_method) is None:
#         print(f"Imputation times not found for {imputation_method}, creating new entry.")
#         imputed_times[imputation_method] = {}

#     # for setup_name, setup_dict in tqdm(experiment_setups.items(), desc="Experiment Setups"):
#     for setup_name, setup_dict in experiment_setups.items():

#         # Check if imputed_times[imputation_method] has the setup_name
#         if setup_name not in imputed_times[imputation_method]:
#             print(f"Creating new entry for '{setup_name}' in imputed times['{imputation_method}'].")
#             imputed_times[imputation_method][setup_name] = {}

#         for scenario_key in setup_dict:
#             dataset_df = setup_dict[scenario_key]["dataset"]

#             # check if imputed_times[imputation_method][setup_name] has the scenario_key
#             if scenario_key not in imputed_times[imputation_method][setup_name]:
#                 print(f"Creating new entry for '{scenario_key}' in imputed times['{imputation_method}']['{setup_name}'].")
#                 imputed_times[imputation_method][setup_name][scenario_key] = {}

#             for num_training_data_points in num_training_data_points_list:

#                 split_dict = prepare_data_split(dataset_df, experiment_repeat_setups, random_idx_col_list, num_training_data_points)

#                 # check if imputed_times[imputation_method][setup_name][scenario_key] has the num_training_data_points
#                 if num_training_data_points not in imputed_times[imputation_method][setup_name][scenario_key]:
#                     # print(f"Creating new entry for '{num_training_data_points}' in imputed times['{imputation_method}']['{setup_name}']['{scenario_key}'].")
#                     imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points] = {}

#                 for rand_idx in random_idx_col_list:
#                     X_train, W_train, Y_train, X_test, W_test, Y_test, cate_test_true = split_dict[rand_idx]

#                     # Check if imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points] has the rand_idx
#                     if rand_idx not in imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points]:
#                         # print(f"Creating new entry for '{rand_idx}' in imputed times['{imputation_method}']['{setup_name}']['{scenario_key}']['{num_training_data_points}'].")
#                         imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points][rand_idx] = {}
#                     elif rand_idx in imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points] and \
#                         "runtime" in imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points][rand_idx]:
#                         # print(f"Skipping existing entry for '{rand_idx}' in imputed times['{imputation_method}']['{setup_name}']['{scenario_key}'][{num_training_data_points}].")
#                         continue

#                     start_time = time.time()

#                     # impute the missing values
#                     survival_imputer = SurvivalEvalImputer(imputation_method=imputation_method, verbose=False)
#                     Y_train_imputed, Y_test_imputed = survival_imputer.fit_transform(Y_train, Y_test)

#                     end_time = time.time()

#                     imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points][rand_idx] = {
#                         "Y_train_imputed": Y_train_imputed,
#                         "Y_test_imputed": Y_test_imputed,
#                         "runtime": end_time - start_time
#                     }
#                 print(f"'{imputation_method}' Imputation completed for '{setup_name}', '{scenario_key}', " +
#                       f"num_training: {num_training_data_points}, {rand_idx} in {end_time - start_time:.0f} seconds.")
                
#             # Save progress to disk
#             with open(output_pickle_path, "wb") as f:
#                 pickle.dump(imputed_times, f)
            
#             # break
#         # break
            


In [None]:
# start_time = time.time()

# # impute the missing values
# survival_imputer = SurvivalEvalImputer(imputation_method=imputation_method, verbose=False)
# Y_train_imputed, Y_test_imputed = survival_imputer.fit_transform(Y_train, Y_test)

# end_time = time.time()

# Update Pseudo_obs Imputation Times For e_X_U and e_X_U_info_censor

In [23]:
imputed_times['Pseudo_obs']['e_X_U'].keys()

dict_keys(['scenario_1', 'scenario_2', 'scenario_5', 'scenario_8', 'scenario_9'])

In [None]:
random_idx_col_list = experiment_repeat_setups.columns.to_list()
start_time = end_time = 0

for imputation_method_idx in trange(len(imputation_methods_list), desc="Imputation Methods"):
    imputation_method = imputation_methods_list[imputation_method_idx]


    ####################################################################
    if imputation_method != "Pseudo_obs":
        print(f"[No Correction Needed] Skipping imputation method: {imputation_method}.")
        continue
    else:
        print(f"Correcting imputation method: {imputation_method}")
    ####################################################################


    if imputed_times.get(imputation_method) is None:
        print(f"Imputation times not found for {imputation_method}, creating new entry.")
        imputed_times[imputation_method] = {}

    # for setup_name, setup_dict in tqdm(experiment_setups.items(), desc="Experiment Setups"):
    for setup_name, setup_dict in experiment_setups.items():


        ####################################################################
        if setup_name in ['e_X_U', 'e_X_U_info_censor']:
            print("Correcting imputation for updated version of setup: ", setup_name)
            imputed_times[imputation_method][setup_name] = {}
            # continue
        else:
            print("(No Correction Needed) Skipping imputation for setup: ", setup_name)
            continue
        ####################################################################


        # Check if imputed_times[imputation_method] has the setup_name
        if setup_name not in imputed_times[imputation_method]:
            print(f"Creating new entry for '{setup_name}' in imputed times['{imputation_method}'].")
            imputed_times[imputation_method][setup_name] = {}

        for scenario_key in setup_dict:
            dataset_df = setup_dict[scenario_key]["dataset"]

            # check if imputed_times[imputation_method][setup_name] has the scenario_key
            if scenario_key not in imputed_times[imputation_method][setup_name]:
                print(f"Creating new entry for '{scenario_key}' in imputed times['{imputation_method}']['{setup_name}'].")
                imputed_times[imputation_method][setup_name][scenario_key] = {}

            for num_training_data_points in num_training_data_points_list:

                split_dict = prepare_data_split(dataset_df, experiment_repeat_setups, random_idx_col_list, num_training_data_points)

                # check if imputed_times[imputation_method][setup_name][scenario_key] has the num_training_data_points
                if num_training_data_points not in imputed_times[imputation_method][setup_name][scenario_key]:
                    # print(f"Creating new entry for '{num_training_data_points}' in imputed times['{imputation_method}']['{setup_name}']['{scenario_key}'].")
                    imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points] = {}

                for rand_idx in random_idx_col_list:
                    X_train, W_train, Y_train, X_test, W_test, Y_test, cate_test_true = split_dict[rand_idx]

                    # Check if imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points] has the rand_idx
                    if rand_idx not in imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points]:
                        # print(f"Creating new entry for '{rand_idx}' in imputed times['{imputation_method}']['{setup_name}']['{scenario_key}']['{num_training_data_points}'].")
                        imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points][rand_idx] = {}
                    elif rand_idx in imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points] and \
                        "runtime" in imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points][rand_idx]:
                        # print(f"Skipping existing entry for '{rand_idx}' in imputed times['{imputation_method}']['{setup_name}']['{scenario_key}'][{num_training_data_points}].")
                        continue

                    start_time = time.time()

                    # impute the missing values
                    survival_imputer = SurvivalEvalImputer(imputation_method=imputation_method, verbose=False)
                    Y_train_imputed, Y_test_imputed = survival_imputer.fit_transform(Y_train, Y_test)

                    end_time = time.time()

                    imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points][rand_idx] = {
                        "Y_train_imputed": Y_train_imputed,
                        "Y_test_imputed": Y_test_imputed,
                        "runtime": end_time - start_time
                    }
                print(f"'{imputation_method}' Imputation completed for '{setup_name}', '{scenario_key}', " +
                      f"num_training: {num_training_data_points}, {rand_idx} in {end_time - start_time:.0f} seconds.")
                
            # Save progress to disk
            with open(output_pickle_path, "wb") as f:
                pickle.dump(imputed_times, f)
            
            # break
        # break
            


Imputation Methods: 100%|██████████| 3/3 [00:00<00:00, 9313.78it/s]

Correcting imputation method: Pseudo_obs
(No Correction Needed) Skipping imputation for setup:  RCT_0_5
(No Correction Needed) Skipping imputation for setup:  RCT_0_05
(No Correction Needed) Skipping imputation for setup:  e_X
Correcting imputation for updated version of setup:  e_X_U
(No Correction Needed) Skipping imputation for setup:  e_X_no_overlap
(No Correction Needed) Skipping imputation for setup:  e_X_info_censor
Correcting imputation for updated version of setup:  e_X_U_info_censor
(No Correction Needed) Skipping imputation for setup:  e_X_no_overlap_info_censor
[No Correction Needed] Skipping imputation method: Margin.
[No Correction Needed] Skipping imputation method: IPCW-T.





# Update IPCW-T Imputation Times For All Scenarios

In [25]:
imputed_times['IPCW-T'].keys()

dict_keys(['RCT_0_5', 'RCT_0_05', 'e_X', 'e_X_U', 'e_X_no_overlap', 'e_X_info_censor', 'e_X_U_info_censor', 'e_X_no_overlap_info_censor'])

In [None]:
random_idx_col_list = experiment_repeat_setups.columns.to_list()
start_time = end_time = 0

for imputation_method_idx in trange(len(imputation_methods_list), desc="Imputation Methods"):
    imputation_method = imputation_methods_list[imputation_method_idx]

    ####################################################################
    if imputation_method != "IPCW-T":
        print(f"[No Correction Needed] Skipping imputation method: {imputation_method}.")
        continue
    else:
        print(f"Correcting imputation method: {imputation_method}")
        imputed_times[imputation_method] = {}
    ####################################################################

    if imputed_times.get(imputation_method) is None:
        print(f"Imputation times not found for {imputation_method}, creating new entry.")
        imputed_times[imputation_method] = {}

    # for setup_name, setup_dict in tqdm(experiment_setups.items(), desc="Experiment Setups"):
    for setup_name, setup_dict in experiment_setups.items():

        # Check if imputed_times[imputation_method] has the setup_name
        if setup_name not in imputed_times[imputation_method]:
            print(f"Creating new entry for '{setup_name}' in imputed times['{imputation_method}'].")
            imputed_times[imputation_method][setup_name] = {}

        for scenario_key in setup_dict:
            dataset_df = setup_dict[scenario_key]["dataset"]

            # check if imputed_times[imputation_method][setup_name] has the scenario_key
            if scenario_key not in imputed_times[imputation_method][setup_name]:
                print(f"Creating new entry for '{scenario_key}' in imputed times['{imputation_method}']['{setup_name}'].")
                imputed_times[imputation_method][setup_name][scenario_key] = {}

            for num_training_data_points in num_training_data_points_list:

                split_dict = prepare_data_split(dataset_df, experiment_repeat_setups, random_idx_col_list, num_training_data_points)

                # check if imputed_times[imputation_method][setup_name][scenario_key] has the num_training_data_points
                if num_training_data_points not in imputed_times[imputation_method][setup_name][scenario_key]:
                    # print(f"Creating new entry for '{num_training_data_points}' in imputed times['{imputation_method}']['{setup_name}']['{scenario_key}'].")
                    imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points] = {}

                for rand_idx in random_idx_col_list:
                    X_train, W_train, Y_train, X_test, W_test, Y_test, cate_test_true = split_dict[rand_idx]

                    # Check if imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points] has the rand_idx
                    if rand_idx not in imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points]:
                        # print(f"Creating new entry for '{rand_idx}' in imputed times['{imputation_method}']['{setup_name}']['{scenario_key}']['{num_training_data_points}'].")
                        imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points][rand_idx] = {}
                    elif rand_idx in imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points] and \
                        "runtime" in imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points][rand_idx]:
                        # print(f"Skipping existing entry for '{rand_idx}' in imputed times['{imputation_method}']['{setup_name}']['{scenario_key}'][{num_training_data_points}].")
                        continue

                    start_time = time.time()

                    # impute the missing values
                    survival_imputer = SurvivalEvalImputer(imputation_method=imputation_method, verbose=False)
                    Y_train_imputed, Y_test_imputed = survival_imputer.fit_transform(Y_train, Y_test)

                    end_time = time.time()

                    imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points][rand_idx] = {
                        "Y_train_imputed": Y_train_imputed,
                        "Y_test_imputed": Y_test_imputed,
                        "runtime": end_time - start_time
                    }
                print(f"'{imputation_method}' Imputation completed for '{setup_name}', '{scenario_key}', " +
                      f"num_training: {num_training_data_points}, {rand_idx} in {end_time - start_time:.0f} seconds.")
                
            # Save progress to disk
            with open(output_pickle_path, "wb") as f:
                pickle.dump(imputed_times, f)
            
            # break
        # break
            


Imputation Methods: 100%|██████████| 3/3 [00:00<00:00, 26829.24it/s]

[No Correction Needed] Skipping imputation method: Pseudo_obs.
[No Correction Needed] Skipping imputation method: Margin.
Correcting imputation method: IPCW-T



