In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os

# Load Data

In [2]:
def load_scenario_data(h5_file_path, scenario_num):
    """Fetch one scenario's table and its stored metadata from an HDF5 store.

    Parameters
    ----------
    h5_file_path : str
        Path of the HDF5 file; it is opened read-only.
    scenario_num : int
        Scenario number. The table is looked up under ``scenario_<num>/data``.

    Returns
    -------
    dict or None
        ``{"dataset": <object stored under the key>, "metadata": <the storer's
        ``attrs.metadata``>}``, or ``None`` when the key is absent from the store.
    """
    node = f"scenario_{scenario_num}/data"
    with pd.HDFStore(h5_file_path, mode='r') as hdf:
        if node in hdf:
            # Both reads must happen while the store is still open.
            return {
                "dataset": hdf[node],
                "metadata": hdf.get_storer(node).attrs.metadata,
            }
    return None  # Scenario not found

In [3]:
# HDF5 stores to load; each file name (without extension) becomes a key of experiment_setups.
store_files = [
    "synthetic_data/RCT_0_5.h5",
    "synthetic_data/RCT_0_05.h5",
    "synthetic_data/e_X.h5",
    "synthetic_data/e_X_U.h5",
    "synthetic_data/e_X_no_overlap.h5",
    "synthetic_data/e_X_info_censor.h5",
    "synthetic_data/e_X_U_info_censor.h5",
    "synthetic_data/e_X_no_overlap_info_censor.h5"
]

# experiment_setups[<file base name>]["scenario_<n>"] -> {"dataset": ..., "metadata": ...}
experiment_setups = {}

for path in store_files:
    base_name = os.path.splitext(os.path.basename(path))[0]  # e.g. RCT_0_5
    scenario_dict = {}
    for scenario in range(1, 11):
        try:
            result = load_scenario_data(path, scenario)
        except Exception as exc:
            # Loading stays best-effort (one bad scenario must not abort the rest),
            # but the failure is reported instead of being dropped silently.
            print(f"Skipping scenario_{scenario} in '{path}': {type(exc).__name__}: {exc}")
            continue
        # None means the scenario key does not exist in this store.
        if result is not None:
            scenario_dict[f"scenario_{scenario}"] = result
    experiment_setups[base_name] = scenario_dict

In [4]:
# Sanity check: show the first rows of one loaded scenario table.
experiment_setups['RCT_0_5']['scenario_1']['dataset'].head()

Unnamed: 0,id,observed_time,event,W,X1,X2,X3,X4,X5,U1,U2,T0,T1,T,C
0,0,0.054267,1,0,0.135488,0.887852,0.932606,0.445568,0.388236,0.151609,0.205535,0.054267,0.061394,0.054267,1.803019
1,1,0.73263,1,1,0.257596,0.657368,0.492617,0.964238,0.800984,0.597208,0.255785,0.228566,0.73263,0.73263,1.689546
2,2,0.162856,1,1,0.455205,0.801058,0.041718,0.769458,0.003171,0.370382,0.223214,0.176016,0.162856,0.162856,1.256329
3,3,0.05034,1,1,0.292809,0.610914,0.913027,0.300115,0.248599,0.038464,0.409829,0.381909,0.05034,0.05034,1.241777
4,4,0.524607,1,0,0.666392,0.987533,0.46827,0.123287,0.916031,0.342961,0.79133,0.524607,1.121968,0.524607,1.516613


In [5]:
# Load the index-split table and key it by its "idx" column. Each remaining column is
# read by prepare_data_split as a sequence of row ids: train ids come from the head of
# the column, test ids from its tail.
experiment_repeat_setups = pd.read_csv("synthetic_data/idx_split.csv").set_index("idx")
experiment_repeat_setups

Unnamed: 0_level_0,random_idx0,random_idx1,random_idx2,random_idx3,random_idx4,random_idx5,random_idx6,random_idx7,random_idx8,random_idx9
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0,47390,5618,14210,46970,4203,16369,24535,45204,45725,45885
1,38566,46218,39045,7253,22759,34401,28889,38471,45822,37471
2,32814,20226,40012,4854,27351,39165,25359,14516,25717,29860
3,41393,39492,27153,19041,33009,19822,21243,41228,955,23901
4,12564,17823,48976,18458,22756,28169,45851,36620,29824,12711
...,...,...,...,...,...,...,...,...,...,...
49995,15948,39245,30779,48178,45056,4892,528,7486,31042,38267
49996,11102,29624,40779,3136,45904,41903,45682,36621,33204,38070
49997,16338,8986,19293,35651,10172,17947,38843,18310,2765,12581
49998,32478,32134,11955,36939,33266,41932,43910,21691,40801,33527


# EXPERIMENT CONSTANTS

In [6]:
# NUM_REPEATS_TO_INCLUDE = 10  # max 10
# NUM_TRAINING_DATA_POINTS = 5000 # max 45000
# Number of ids taken from the tail of each random index column as the test set
# (read as a global by prepare_data_split).
TEST_SIZE = 5000
# Name of the imputation method to run; also the top-level key in the imputed_times lookup.
IMPUTATION_METHOD = "Pseudo_obs" # "Margin", "IPCW-T", "Pseudo_obs"

In [7]:
# Pickle file holding the imputed_times lookup from earlier runs (plain string: there is
# nothing to interpolate, so no f-prefix).
output_pickle_path = "synthetic_data/imputed_times_lookup.pkl"

In [8]:
import pickle

# Start from an empty lookup unless an earlier run left a pickle behind.
# NOTE: pickle.load can execute arbitrary code; only point this at files you produced yourself.
if not os.path.exists(output_pickle_path):
    print("Imputation times not found, creating new file.")
    imputed_times = {}
else:
    print("Loading imputation times from existing file.")
    with open(output_pickle_path, 'rb') as cache_file:
        imputed_times = pickle.load(cache_file)

Imputation times not found, creating new file.


# Run Experiments

In [9]:
import time
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm

In [10]:
def prepare_data_split(dataset_df, experiment_repeat_setups, random_idx_col_list,
                       num_training_data_points=5000, test_size=None):
    """Build train/test arrays for every requested repeat of an experiment.

    For each column named in ``random_idx_col_list`` the column's values are treated as
    row ids. The last ``test_size`` ids form the test set; the first
    ``num_training_data_points`` ids form the training set, capped so it never reaches
    into the test tail. Rows are selected by membership of ``dataset_df['id']`` in those
    ids, so they come back in ``dataset_df`` order rather than in the order of the id column.

    Parameters
    ----------
    dataset_df : pandas.DataFrame
        Must contain the columns ``id``, ``W``, ``observed_time``, ``event``, ``T0``,
        ``T1`` and the covariates, i.e. every column named ``X<digits>``.
    experiment_repeat_setups : pandas.DataFrame
        One column of row ids per repeat.
    random_idx_col_list : list of str
        Names of the columns of ``experiment_repeat_setups`` to process.
    num_training_data_points : int, default 5000
        Maximum number of training ids per repeat.
    test_size : int, optional
        Number of test ids per repeat. ``None`` (the default) falls back to the
        notebook-level ``TEST_SIZE`` constant, which is the original behaviour.

    Returns
    -------
    dict
        Maps each column name to the tuple
        ``(X_train, W_train, Y_train, X_test, W_test, Y_test, cate_test_true)``.
        ``Y_*`` holds the two columns ``observed_time`` and ``event``;
        ``cate_test_true`` is ``T1 - T0`` on the test rows.
    """
    if test_size is None:
        test_size = TEST_SIZE

    # The covariate columns depend only on the frame, so resolve them once, not per repeat.
    X_cols = [c for c in dataset_df.columns if c.startswith("X") and c[1:].isdigit()]

    split_results = {}

    for rand_idx in random_idx_col_list:
        random_idx = experiment_repeat_setups[rand_idx].values
        n_ids = len(random_idx)

        # Clamp both sizes so the two id sets can never overlap, even for degenerate inputs.
        n_test = min(max(test_size, 0), n_ids)
        n_train = max(min(num_training_data_points, n_ids - n_test), 0)

        # Slice from an explicit start offset: ``random_idx[-n_test:]`` would return the
        # whole array when n_test == 0.
        test_ids = random_idx[n_ids - n_test:]
        train_ids = random_idx[:n_train]

        train_df = dataset_df[dataset_df['id'].isin(train_ids)]
        test_df = dataset_df[dataset_df['id'].isin(test_ids)]

        X_train = train_df[X_cols].to_numpy()
        W_train = train_df["W"].to_numpy()
        Y_train = train_df[["observed_time", "event"]].to_numpy()

        X_test = test_df[X_cols].to_numpy()
        W_test = test_df["W"].to_numpy()
        Y_test = test_df[["observed_time", "event"]].to_numpy()

        cate_test_true = (test_df["T1"] - test_df["T0"]).to_numpy()

        split_results[rand_idx] = (X_train, W_train, Y_train, X_test, W_test, Y_test, cate_test_true)

    return split_results

In [11]:
# Method whose results this run fills in; it is the top-level key of imputed_times.
imputation_method = IMPUTATION_METHOD

# Make sure the lookup has a dict for this method. A missing key and a key explicitly
# mapped to None are treated the same way.
if imputed_times.get(imputation_method) is None:
    print(f"Imputation times not found for {imputation_method}, creating new entry.")
    imputed_times[imputation_method] = {}

Imputation times not found for Pseudo_obs, creating new entry.


In [12]:
# num_training_data_points_list = [50, 100, 200, 300, 400, 500, 700, 1000, 2000, 3000, 4000, 5000, 10000, 20000, 30000]
# Training-set sizes iterated over by the experiment loop; each is passed to
# prepare_data_split as num_training_data_points.
num_training_data_points_list = [50, 100, 200, 300, 500, 1000, 2000, 5000, 10000, 20000]
len(num_training_data_points_list)

10

In [13]:
random_idx_col_list = experiment_repeat_setups.columns.to_list()

# Results are stored as
#   imputed_times[method][setup][scenario][num_training][rand_idx] -> result dict
# A leaf that already has a "runtime" key counts as finished and is skipped, so the
# cell can be re-run to resume. Each level is created on first use and then held in a
# local name to avoid repeating the full chain of subscripts.
for setup_name, setup_dict in tqdm(experiment_setups.items(), desc="Experiment Setups"):

    if setup_name not in imputed_times[imputation_method]:
        print(f"Creating new entry for '{setup_name}' in imputed times['{imputation_method}'].")
        imputed_times[imputation_method][setup_name] = {}
    setup_results = imputed_times[imputation_method][setup_name]

    for scenario_key in setup_dict:
        dataset_df = setup_dict[scenario_key]["dataset"]

        if scenario_key not in setup_results:
            print(f"Creating new entry for '{scenario_key}' in imputed times['{imputation_method}']['{setup_name}'].")
            setup_results[scenario_key] = {}
        scenario_results = setup_results[scenario_key]

        for num_training_data_points in num_training_data_points_list:

            split_dict = prepare_data_split(dataset_df, experiment_repeat_setups, random_idx_col_list, num_training_data_points)

            if num_training_data_points not in scenario_results:
                print(f"Creating new entry for '{num_training_data_points}' in imputed times['{imputation_method}']['{setup_name}']['{scenario_key}'].")
                scenario_results[num_training_data_points] = {}
            size_results = scenario_results[num_training_data_points]

            for rand_idx in random_idx_col_list:
                X_train, W_train, Y_train, X_test, W_test, Y_test, cate_test_true = split_dict[rand_idx]

                # setdefault creates the (empty) leaf on first sight, exactly as before;
                # only a leaf that recorded a runtime is treated as done.
                if "runtime" in size_results.setdefault(rand_idx, {}):
                    print(f"Skipping existing entry for '{rand_idx}' in imputed times['{imputation_method}']['{setup_name}']['{scenario_key}'][{num_training_data_points}].")
                    continue

                # perf_counter is monotonic and high-resolution; time.time() is wall-clock
                # and too coarse for short intervals.
                start_time = time.perf_counter()

                Y_train_imputed = None  # TODO: Imputation Step
                Y_test_imputed = None   # TODO: Imputation Step

                end_time = time.perf_counter()

                size_results[rand_idx] = {
                    "Y_train_imputed": Y_train_imputed,
                    "Y_test_imputed": Y_test_imputed,
                    "runtime": end_time - start_time
                }
                print(f"'{imputation_method}' Imputation completed for '{setup_name}', '{scenario_key}', " +
                      f"num_training: {num_training_data_points}, {rand_idx} in {end_time - start_time:.0f} seconds.")

        # NOTE(review): saving is commented out, so nothing computed in this cell is
        # written to output_pickle_path.
        # Save progress to disk
        # with open(output_pickle_path, "wb") as f:
        #     pickle.dump(imputed_times, f)
            


Experiment Setups:   0%|          | 0/8 [00:00<?, ?it/s]

Creating new entry for 'RCT_0_5' in imputed times['Pseudo_obs'].
Creating new entry for 'scenario_1' in imputed times['Pseudo_obs']['RCT_0_5'].
Creating new entry for '50' in imputed times['Pseudo_obs']['RCT_0_5']['scenario_1'].
'Pseudo_obs' Imputation completed for 'RCT_0_5', 'scenario_1', num_training: 50, random_idx0 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_5', 'scenario_1', num_training: 50, random_idx1 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_5', 'scenario_1', num_training: 50, random_idx2 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_5', 'scenario_1', num_training: 50, random_idx3 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_5', 'scenario_1', num_training: 50, random_idx4 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_5', 'scenario_1', num_training: 50, random_idx5 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_5', 'scenario_1', num_training: 50, random_idx6 in 0 seconds.
'Pseudo_obs' Imputatio

Experiment Setups:  12%|█▎        | 1/8 [00:01<00:11,  1.67s/it]

Creating new entry for '20000' in imputed times['Pseudo_obs']['RCT_0_5']['scenario_9'].
'Pseudo_obs' Imputation completed for 'RCT_0_5', 'scenario_9', num_training: 20000, random_idx0 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_5', 'scenario_9', num_training: 20000, random_idx1 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_5', 'scenario_9', num_training: 20000, random_idx2 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_5', 'scenario_9', num_training: 20000, random_idx3 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_5', 'scenario_9', num_training: 20000, random_idx4 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_5', 'scenario_9', num_training: 20000, random_idx5 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_5', 'scenario_9', num_training: 20000, random_idx6 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_5', 'scenario_9', num_training: 20000, random_idx7 in 0 seconds.
'Pseudo_obs' Imputation complete

Experiment Setups:  25%|██▌       | 2/8 [00:03<00:09,  1.66s/it]

Creating new entry for '10000' in imputed times['Pseudo_obs']['RCT_0_05']['scenario_9'].
'Pseudo_obs' Imputation completed for 'RCT_0_05', 'scenario_9', num_training: 10000, random_idx0 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_05', 'scenario_9', num_training: 10000, random_idx1 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_05', 'scenario_9', num_training: 10000, random_idx2 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_05', 'scenario_9', num_training: 10000, random_idx3 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_05', 'scenario_9', num_training: 10000, random_idx4 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_05', 'scenario_9', num_training: 10000, random_idx5 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_05', 'scenario_9', num_training: 10000, random_idx6 in 0 seconds.
'Pseudo_obs' Imputation completed for 'RCT_0_05', 'scenario_9', num_training: 10000, random_idx7 in 0 seconds.
'Pseudo_obs' Imputation

Experiment Setups:  38%|███▊      | 3/8 [00:04<00:08,  1.63s/it]

Creating new entry for '2000' in imputed times['Pseudo_obs']['e_X']['scenario_9'].
'Pseudo_obs' Imputation completed for 'e_X', 'scenario_9', num_training: 2000, random_idx0 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X', 'scenario_9', num_training: 2000, random_idx1 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X', 'scenario_9', num_training: 2000, random_idx2 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X', 'scenario_9', num_training: 2000, random_idx3 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X', 'scenario_9', num_training: 2000, random_idx4 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X', 'scenario_9', num_training: 2000, random_idx5 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X', 'scenario_9', num_training: 2000, random_idx6 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X', 'scenario_9', num_training: 2000, random_idx7 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X', 'scenario_9', num_training: 2000

Experiment Setups:  50%|█████     | 4/8 [00:06<00:06,  1.68s/it]

Creating new entry for '20000' in imputed times['Pseudo_obs']['e_X_U']['scenario_9'].
'Pseudo_obs' Imputation completed for 'e_X_U', 'scenario_9', num_training: 20000, random_idx0 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_U', 'scenario_9', num_training: 20000, random_idx1 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_U', 'scenario_9', num_training: 20000, random_idx2 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_U', 'scenario_9', num_training: 20000, random_idx3 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_U', 'scenario_9', num_training: 20000, random_idx4 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_U', 'scenario_9', num_training: 20000, random_idx5 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_U', 'scenario_9', num_training: 20000, random_idx6 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_U', 'scenario_9', num_training: 20000, random_idx7 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_U', 'sc

Experiment Setups:  62%|██████▎   | 5/8 [00:08<00:05,  1.68s/it]

Creating new entry for '2000' in imputed times['Pseudo_obs']['e_X_no_overlap']['scenario_9'].
'Pseudo_obs' Imputation completed for 'e_X_no_overlap', 'scenario_9', num_training: 2000, random_idx0 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_no_overlap', 'scenario_9', num_training: 2000, random_idx1 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_no_overlap', 'scenario_9', num_training: 2000, random_idx2 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_no_overlap', 'scenario_9', num_training: 2000, random_idx3 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_no_overlap', 'scenario_9', num_training: 2000, random_idx4 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_no_overlap', 'scenario_9', num_training: 2000, random_idx5 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_no_overlap', 'scenario_9', num_training: 2000, random_idx6 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_no_overlap', 'scenario_9', num_training: 2000, rand

Experiment Setups:  75%|███████▌  | 6/8 [00:10<00:03,  1.71s/it]

Creating new entry for '1000' in imputed times['Pseudo_obs']['e_X_info_censor']['scenario_9'].
'Pseudo_obs' Imputation completed for 'e_X_info_censor', 'scenario_9', num_training: 1000, random_idx0 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_info_censor', 'scenario_9', num_training: 1000, random_idx1 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_info_censor', 'scenario_9', num_training: 1000, random_idx2 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_info_censor', 'scenario_9', num_training: 1000, random_idx3 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_info_censor', 'scenario_9', num_training: 1000, random_idx4 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_info_censor', 'scenario_9', num_training: 1000, random_idx5 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_info_censor', 'scenario_9', num_training: 1000, random_idx6 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_info_censor', 'scenario_9', num_training: 1

Experiment Setups:  88%|████████▊ | 7/8 [00:11<00:01,  1.76s/it]

Creating new entry for '2000' in imputed times['Pseudo_obs']['e_X_U_info_censor']['scenario_9'].
'Pseudo_obs' Imputation completed for 'e_X_U_info_censor', 'scenario_9', num_training: 2000, random_idx0 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_U_info_censor', 'scenario_9', num_training: 2000, random_idx1 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_U_info_censor', 'scenario_9', num_training: 2000, random_idx2 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_U_info_censor', 'scenario_9', num_training: 2000, random_idx3 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_U_info_censor', 'scenario_9', num_training: 2000, random_idx4 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_U_info_censor', 'scenario_9', num_training: 2000, random_idx5 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_U_info_censor', 'scenario_9', num_training: 2000, random_idx6 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_U_info_censor', 'scenario_9

Experiment Setups: 100%|██████████| 8/8 [00:13<00:00,  1.73s/it]

Creating new entry for '20000' in imputed times['Pseudo_obs']['e_X_no_overlap_info_censor']['scenario_9'].
'Pseudo_obs' Imputation completed for 'e_X_no_overlap_info_censor', 'scenario_9', num_training: 20000, random_idx0 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_no_overlap_info_censor', 'scenario_9', num_training: 20000, random_idx1 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_no_overlap_info_censor', 'scenario_9', num_training: 20000, random_idx2 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_no_overlap_info_censor', 'scenario_9', num_training: 20000, random_idx3 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_no_overlap_info_censor', 'scenario_9', num_training: 20000, random_idx4 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_no_overlap_info_censor', 'scenario_9', num_training: 20000, random_idx5 in 0 seconds.
'Pseudo_obs' Imputation completed for 'e_X_no_overlap_info_censor', 'scenario_9', num_training: 20000, random_idx6 in 0




In [14]:
# Inspect the results of the method that was just run. Indexing by imputation_method
# (not a hard-coded name) keeps this cell working when IMPUTATION_METHOD is changed.
imputed_times[imputation_method]

{'RCT_0_5': {'scenario_1': {50: {'random_idx0': {'Y_train_imputed': None,
     'Y_test_imputed': None,
     'runtime': 7.152557373046875e-07},
    'random_idx1': {'Y_train_imputed': None,
     'Y_test_imputed': None,
     'runtime': 2.384185791015625e-07},
    'random_idx2': {'Y_train_imputed': None,
     'Y_test_imputed': None,
     'runtime': 0.0},
    'random_idx3': {'Y_train_imputed': None,
     'Y_test_imputed': None,
     'runtime': 0.0},
    'random_idx4': {'Y_train_imputed': None,
     'Y_test_imputed': None,
     'runtime': 0.0},
    'random_idx5': {'Y_train_imputed': None,
     'Y_test_imputed': None,
     'runtime': 2.384185791015625e-07},
    'random_idx6': {'Y_train_imputed': None,
     'Y_test_imputed': None,
     'runtime': 0.0},
    'random_idx7': {'Y_train_imputed': None,
     'Y_test_imputed': None,
     'runtime': 0.0},
    'random_idx8': {'Y_train_imputed': None,
     'Y_test_imputed': None,
     'runtime': 0.0},
    'random_idx9': {'Y_train_imputed': None,
     'Y_

# Imputation Methods

In [20]:
from dataclasses import InitVar, dataclass, field
import warnings
import numpy as np

@dataclass
class KaplanMeier:
    """
    Kaplan-Meier estimate of a survival curve.

    This class is borrowed from survival_evaluation package.

    Construct it with two equally long array-likes: ``event_times`` (the observed time
    of every sample) and ``event_indicators`` (summed per time point to count events, so
    1/True marks an event and 0/False contributes nothing). Both are only used during
    initialisation; the fitted curve is exposed through the attributes below.
    """
    event_times: InitVar[np.ndarray]
    event_indicators: InitVar[np.ndarray]
    # Sorted unique observed times.
    survival_times: np.ndarray = field(init=False)
    # For each unique time, the number of samples whose time is >= that time.
    population_count: np.ndarray = field(init=False)
    # For each unique time, the sum of the indicators observed at that time.
    events: np.ndarray = field(init=False)
    # Running product of (1 - events / population_count), one value per unique time.
    survival_probabilities: np.ndarray = field(init=False)
    # 1 - survival_probabilities.
    cumulative_dens: np.ndarray = field(init=False)
    # Forward differences of cumulative_dens; the last entry is 1 - cumulative_dens[-1].
    probability_dens: np.ndarray = field(init=False)

    def __post_init__(self, event_times, event_indicators):
        """Fit the curve from the raw times and indicators."""
        # Accept any array-like (lists, tuples, ...): the fancy indexing below needs ndarrays.
        event_times = np.asarray(event_times)
        event_indicators = np.asarray(event_indicators)

        # Sort by time, breaking ties by indicator.
        order = np.lexsort((event_indicators, event_times))
        self.survival_times, counts = np.unique(event_times[order], return_counts=True)
        # Reversed cumulative count = samples still present at each unique time.
        self.population_count = np.flip(np.flip(counts).cumsum())

        # Offset, in the sorted sample, at which each unique time's run starts.
        run_starts = np.append(0, counts.cumsum()[:-1])
        # reduceat sums the indicators from each start up to the next one (the last run
        # extends to the end). The appended 0 does not change any sum; it promotes
        # boolean indicators to integers and keeps index 0 valid for empty input.
        self.events = np.add.reduceat(np.append(event_indicators[order], 0), run_starts)

        event_ratios = 1 - self.events / self.population_count
        self.survival_probabilities = np.cumprod(event_ratios)
        self.cumulative_dens = 1 - self.survival_probabilities
        self.probability_dens = np.diff(np.append(self.cumulative_dens, 1))

    def predict(self, prediction_times: np.ndarray):
        """Evaluate the step-function survival curve.

        :param prediction_times: time(s) at which to evaluate the curve.
        :return: survival probability per time: 1 before the first unique time,
            otherwise the value attached to the latest unique time <= the query.
        """
        # digitize returns, for each query, how many unique times are <= it, i.e. a value
        # in [0, survival_times.size]. That is exactly the index into the curve once a
        # leading 1 (probability before the first time) is prepended.
        probability_index = np.digitize(prediction_times, self.survival_times)
        return np.append(1, self.survival_probabilities)[probability_index]


@dataclass
class KaplanMeierArea(KaplanMeier):
    """
    Kaplan-Meier curve with the area under it pre-computed.

    The curve is treated as piecewise linear between its knots and, if it does not end
    at 0, is extended by the straight line through (0, 1) and the last Kaplan-Meier
    point until that line reaches 0. The stored tail areas let ``best_guess`` return
    ``t + (area under the curve after t) / S(t)`` for censoring times ``t``.
    """
    # Knots: 0, the unique survival times, the extrapolated zero (only when the curve
    # does not already end at 0) and a trailing +inf sentinel.
    area_times: np.ndarray = field(init=False)
    # Curve value at every finite knot of area_times.
    area_probabilities: np.ndarray = field(init=False)
    # area[i] = trapezoidal area under the curve from area_times[i] to the last finite
    # knot; the final entry is 0.
    area: np.ndarray = field(init=False)
    # Time at which the line through (0, 1) and the last Kaplan-Meier point reaches 0.
    km_linear_zero: float = field(init=False)

    def __post_init__(self, event_times, event_indicators):
        """Fit the base curve, then tabulate the knots and the tail areas."""
        super().__post_init__(event_times, event_indicators)
        area_probabilities = np.append(1, self.survival_probabilities)
        area_times = np.append(0, self.survival_times)
        # NOTE(review): when the last survival probability is 1 (no event contributes),
        # this divides by zero and the areas become infinite; confirm callers never
        # pass such data.
        self.km_linear_zero = area_times[-1] / (1 - area_probabilities[-1])
        if self.survival_probabilities[-1] != 0:
            area_times = np.append(area_times, self.km_linear_zero)
            area_probabilities = np.append(area_probabilities, 0)

        # we are facing the choice of using the trapzoidal rule or directly using the area under the step function
        # we choose to use trapz because it is more accurate
        area_diff = np.diff(area_times, 1)
        average_probabilities = (area_probabilities[0:-1] + area_probabilities[1:]) / 2
        # Reversed cumulative sum turns per-segment areas into "area from here to the end".
        area = np.flip(np.flip(area_diff * average_probabilities).cumsum())
        # area = np.flip(np.flip(area_diff * area_probabilities[0:-1]).cumsum())

        self.area_times = np.append(area_times, np.inf)
        self.area_probabilities = area_probabilities
        self.area = np.append(area, 0)

    @property
    def mean(self):
        """``best_guess`` evaluated at time 0, returned as a Python scalar."""
        return self.best_guess(np.array([0])).item()

    def best_guess(self, censor_times: np.ndarray):
        """Return ``t + (area under the curve after t) / S(t)`` for each censoring time.

        :param censor_times: array of censoring times.
        :return: float array with the same shape as ``censor_times``.
        """
        last_time = self.survival_times.max()
        # calculate the slope using the [0, 1] - [max_time, S(t|x)]
        slope = (1 - self.survival_probabilities.min()) / (0 - last_time)
        # if after the last time point, then the best guess is the linear function
        before_last_idx = censor_times <= last_time
        after_last_idx = censor_times > last_time
        surv_prob = np.empty_like(censor_times).astype(float)
        surv_prob[after_last_idx] = 1 + censor_times[after_last_idx] * slope
        surv_prob[before_last_idx] = self.predict(censor_times[before_last_idx])
        # do not use np.clip(a_min=0) here because we will use surv_prob as the denominator,
        # if surv_prob is below 0 (or 1e-10 after clip), the nominator will be 0 anyway.
        surv_prob = np.clip(surv_prob, a_min=1e-10, a_max=None)

        # Index of the first knot strictly greater than each censoring time. digitize
        # returns values in [0, area_times.size], so no further adjustment is needed.
        censor_indexes = np.digitize(censor_times, self.area_times)

        # for those beyond the end point, censor_area = 0
        beyond_idx = censor_indexes > len(self.area_times) - 2
        inside_idx = ~beyond_idx
        next_knot = censor_indexes[inside_idx]
        censor_area = np.zeros_like(censor_times).astype(float)
        # trapzoidal rule:  (x1 - x0) * (f(x0) + f(x1)) * 0.5
        # partial segment from the censoring time up to the next knot ...
        censor_area[inside_idx] = ((self.area_times[next_knot] - censor_times[inside_idx]) *
                                   (self.area_probabilities[next_knot] + surv_prob[inside_idx])
                                   * 0.5)
        # ... plus the tabulated area from that knot to the end of the curve.
        censor_area[inside_idx] += self.area[next_knot]
        return censor_times + censor_area / surv_prob

    def _km_linear_predict(self, times):
        """Curve value at ``times``: the Kaplan-Meier step function up to the last
        observed time, then the straight line through (0, 1) and the last point,
        floored at 0.

        :param times: array of times.
        :return: float array of probabilities with the same shape as ``times``.
        """
        last_time = self.survival_times.max()
        slope = (1 - self.survival_probabilities.min()) / (0 - last_time)

        # Always allocate floats: with the input's dtype, integer ``times`` would
        # truncate every probability to 0 or 1.
        predict_prob = np.empty_like(times, dtype=float)
        before_last_time_idx = times <= last_time
        after_last_time_idx = times > last_time
        predict_prob[before_last_time_idx] = self.predict(times[before_last_time_idx])
        predict_prob[after_last_time_idx] = np.clip(1 + times[after_last_time_idx] * slope, a_min=0, a_max=None)
        return predict_prob

    def _compute_best_guess(self, time: float, restricted: bool = False):
        """
        Given a censor time, compute the decensor event time based on the residual mean survival time on KM curves.

        Deprecated: emits a DeprecationWarning; use ``best_guess`` instead.

        :param time: the censoring time.
        :param restricted: if True, integrate only up to the last observed time; otherwise up to
            ``km_linear_zero``.
        :return: ``time`` itself when the curve is 0 at ``time``, otherwise ``time`` plus the
            numerically integrated area after ``time`` divided by the curve value at ``time``.
        """
        # Using integrate.quad from Scipy should be more accurate, but also making the program unbearably slow.
        # The compromised method uses the composite trapezoidal rule to approximate the integral.
        warnings.warn("This method is deprecated. Use best_guess instead.", DeprecationWarning)
        if restricted:
            last_time = max(self.survival_times)
        else:
            last_time = self.km_linear_zero
        time_range = np.linspace(time, last_time, 2000)
        if self.predict(time) == 0:
            best_guess = time
        else:
            # np.trapezoid exists from NumPy 2.0 on; older releases only provide np.trapz.
            integrate = getattr(np, "trapezoid", None)
            if integrate is None:
                integrate = np.trapz
            best_guess = time + integrate(self._km_linear_predict(time_range), time_range) / self.predict(time)

        return best_guess

In [None]:
from sklearn.model_selection import KFold
import numpy as np
from tqdm import trange

class SurvivalEvalImputer:
    """
    Replace censored survival times with Kaplan-Meier (KM) based surrogate ("best guess") times.

    All inputs are 2-column arrays: column 0 holds the observed time and column 1 the event
    indicator (non-zero = event observed, 0 = censored). Rows with an observed event are
    returned unchanged; only censored rows are imputed.

    Supported ``imputation_method`` values:

    * ``"Pseudo_obs"`` - pseudo-observation of the KM mean survival time.
    * ``"Margin"``     - L1-margin method (https://www.jmlr.org/papers/v21/18-772.html).
    * ``"IPCW-T"``     - dispatched to ``_ipcw_t_imputation`` when such a method is available
      (e.g. provided by a subclass); this class itself does not implement it.

    The KM curves are fitted with ``KaplanMeierArea``, which is defined outside this class.
    """

    def __init__(self, imputation_method="Pseudo_obs", verbose=True):
        """
        :param imputation_method: one of ``"Pseudo_obs"``, ``"Margin"``, ``"IPCW-T"``.
        :param verbose: show progress bars for the pseudo-observation loops.
        """
        self.imputation_method = imputation_method
        self.verbose = verbose

    def fit_transform(self, Y_train, Y_test, impute_train=True):
        """
        Impute censored times of the training and test sets with the configured method.

        Note in our setup, Y_test imputation is not important as we calculate CATE from X_test and W_test
        where our CATE estimator is trained using X_train, W_train and Y_train.
        Nevertheless, we still impute Y_test for the sake of consistency.

        :param Y_train: np.ndarray, shape = (n_train, 2), observed time and event indicator.
        :param Y_test: np.ndarray, shape = (n_test, 2), observed time and event indicator.
        :param impute_train: if False, the observed training times are returned as-is.
        :return: tuple ``(imputed_train_times, imputed_test_times)``.
        :raises NotImplementedError: for ``"IPCW-T"`` when no ``_ipcw_t_imputation`` exists.
        :raises ValueError: for an unknown imputation method.
        """
        if self.imputation_method == "Pseudo_obs":
            return self._pseudo_obs_imputation(Y_train, Y_test, impute_train=impute_train)
        if self.imputation_method == "Margin":
            return self._margin_imputation(Y_train, Y_test, impute_train=impute_train)
        if self.imputation_method == "IPCW-T":
            # The method is looked up lazily so that subclasses providing it keep working,
            # while the base class fails with an explicit message instead of AttributeError.
            ipcw_t_imputation = getattr(self, "_ipcw_t_imputation", None)
            if ipcw_t_imputation is None:
                raise NotImplementedError(
                    "Imputation method 'IPCW-T' is not implemented by SurvivalEvalImputer"
                )
            return ipcw_t_imputation(Y_train, Y_test, impute_train=impute_train)
        raise ValueError(f"Unknown imputation method: {self.imputation_method}")

    def _km_mean(self, times: np.ndarray, survival_probabilities: np.ndarray) -> float:
        """
        Calculate the mean of the Kaplan-Meier curve.

        The curve is treated as piecewise linear (trapezoid rule) starting at ``(0, 1)``.
        If it does not reach 0 at the last time point, it is extended by the straight line
        through ``(0, 1)`` and the last point until it hits 0.

        Parameters
        ----------
        times: np.ndarray, shape = (n_samples, )
            Survival times for KM curve of the testing samples
        survival_probabilities: np.ndarray, shape = (n_samples, )
            Survival probabilities for KM curve of the testing samples

        Returns
        -------
        The mean of the Kaplan-Meier curve, i.e. the total area divided by the survival
        probability at time 0. The result is not finite when the last survival probability
        is exactly 1, because the extrapolation line then never reaches 0.
        """
        area_probabilities = np.append(1, survival_probabilities)
        area_times = np.append(0, times)
        if survival_probabilities[-1] != 0:
            # Time at which the line through (0, 1) and the last KM point reaches probability 0.
            km_linear_zero = -1 / ((area_probabilities[-1] - 1) / area_times[-1])
            area_times = np.append(area_times, km_linear_zero)
            area_probabilities = np.append(area_probabilities, 0)
        area_diff = np.diff(area_times, 1)
        # we are using trap rule
        average_probabilities = (area_probabilities[0:-1] + area_probabilities[1:]) / 2
        # Reverse cumulative sum: area[k] is the area to the right of interval k; area[0] is the total.
        area = np.flip(np.flip(area_diff * average_probabilities).cumsum())
        area = np.append(area, 0)

        # Survival probability at time 0 (1 unless the curve has time points <= 0).
        probability_index = np.digitize(0, times)
        surv_prob = np.append(1, survival_probabilities)[probability_index]

        return area[0] / surv_prob

    def _pseudo_obs_terms(self, km_model):
        """
        Precompute the quantities of a fitted KM model that every pseudo-observation reuses.

        :param km_model: fitted ``KaplanMeierArea``; its ``events``, ``population_count``,
            ``survival_times`` and ``survival_probabilities`` attributes are read.
        :return: tuple ``(times, multiplier, multiplier_total, sub_expect_time)`` where
            ``times`` are the event times (plus the last time point),
            ``multiplier`` are the KM factors ``1 - d / n``,
            ``multiplier_total`` are the factors ``1 - d / (n + 1)`` after adding one subject
            to the risk set, and ``sub_expect_time`` is the KM mean without the extra subject.
        :raises ValueError: if the model contains no event at all.
        """
        events = km_model.events
        # get the discrete time points where the event happens, then calculate the area under those discrete time only
        # this doesn't make any difference for step function, but it does for trapezoid rule.
        unique_idx = np.where(events != 0)[0]
        if unique_idx.size == 0:
            raise ValueError(
                "Pseudo-observation imputation needs at least one observed event in the training data"
            )
        if unique_idx[-1] != len(events) - 1:
            unique_idx = np.append(unique_idx, len(events) - 1)

        times = km_model.survival_times[unique_idx]
        population_counts = km_model.population_count[unique_idx]
        events = events[unique_idx]
        probs = km_model.survival_probabilities[unique_idx]
        sub_expect_time = self._km_mean(times, probs)

        # use the idea of dynamic programming to calculate the multiplier of the KM estimator in advance.
        # if we add a new time point to the KM curve, the multiplier before the new time point will be
        # 1 - event_counts / (population_counts + 1), and the multiplier after the new time point will be
        # the same as before.
        multiplier = 1 - events / population_counts
        multiplier_total = 1 - events / (population_counts + 1)
        return times, multiplier, multiplier_total, sub_expect_time

    def _pseudo_obs_surrogate(self, censor_time, times, multiplier, multiplier_total, n_train, sub_expect_time):
        """
        Pseudo-observation of the KM mean for one censored subject.

        :param censor_time: censoring time of the subject that is added to the KM curve.
        :param times, multiplier, multiplier_total, sub_expect_time: output of ``_pseudo_obs_terms``.
        :param n_train: number of subjects the KM model was fitted on.
        :return: ``(n_train + 1) * mean_with_subject - n_train * mean_without_subject``.
        """
        total_multiplier = multiplier.copy()
        # Event times <= censor_time see one more subject at risk.
        insert_index = np.searchsorted(times, censor_time, side='right')
        total_multiplier[:insert_index] = multiplier_total[:insert_index]
        survival_probabilities = np.cumprod(total_multiplier)
        if insert_index == len(times):
            # The subject is censored at/after the last time point: extend the curve to it.
            times_addition = np.append(times, censor_time)
            survival_probabilities_addition = np.append(survival_probabilities, survival_probabilities[-1])
            total_expect_time = self._km_mean(times_addition, survival_probabilities_addition)
        else:
            total_expect_time = self._km_mean(times, survival_probabilities)
        return (n_train + 1) * total_expect_time - n_train * sub_expect_time

    def _pseudo_obs_imputation_train(self, Y_train):
        """
        Pseudo-observation imputation of the training set (leave-one-out).

        For every censored subject a KM curve is fitted on all *other* training subjects and
        the surrogate time is the subject's contribution to the KM mean.

        :param Y_train: np.ndarray, shape = (n_samples, 2), observed time and event indicator.
        :return: np.ndarray, shape = (n_samples, ), imputed times for the training set.
        """
        event_times = Y_train[:, 0]
        event_indicators = (Y_train[:, 1]).astype(bool)

        best_guesses = event_times.copy().astype(float)

        for i in trange(len(event_times), desc="Calculating surrogate times for Pseudo-observation", disable=not self.verbose):
            if event_indicators[i]:
                continue

            # Fit the KM curve on every subject except the current one.
            loo_event_times = np.delete(event_times, i)
            loo_event_indicators = np.delete(event_indicators, i)
            km_model = KaplanMeierArea(loo_event_times, loo_event_indicators)

            times, multiplier, multiplier_total, sub_expect_time = self._pseudo_obs_terms(km_model)
            best_guesses[i] = self._pseudo_obs_surrogate(
                event_times[i], times, multiplier, multiplier_total, loo_event_times.size, sub_expect_time
            )

        assert np.all(best_guesses >= 0), "Best guesses should be non-negative"
        assert np.all(best_guesses[~event_indicators] >= event_times[~event_indicators]), "Best guesses should be greater than or equal to censor times"
        assert np.all(best_guesses[event_indicators] == event_times[event_indicators]), "Best guesses should be equal to event times for uncensored subjects"

        return best_guesses

    def _pseudo_obs_imputation(self, Y_train, Y_test, impute_train=True):
        """
        Pseudo-observation imputation method.
        Calculate the best guess time (surrogate time) by the contribution of the censored subjects to KM curve

        Note: We do not need to impute Y_test in our setup, but we still do it for consistency.
        The KM mean is computed by ``_km_mean``, which linearly extends the curve to 0.

        :param Y_train: np.ndarray, shape = (n_samples, 2)
            The training set with observed time and event indicator
        :param Y_test: np.ndarray, shape = (n_samples, 2)
            The test set with observed time and event indicator
        :param impute_train: bool
            Whether to impute the training set

        :return best_guesses_train: np.ndarray, shape = (n_samples, )
            The imputed time for training set.
            (if impute_train is False, the observed time for the training set is returned)
        :return best_guesses: np.ndarray, shape = (n_samples, )
            The imputed time for test set.
        """
        train_event_times = Y_train[:, 0]
        train_event_indicators = (Y_train[:, 1]).astype(bool)
        test_event_times = Y_test[:, 0]
        test_event_indicators = (Y_test[:, 1]).astype(bool)

        n_train = train_event_times.size
        n_test = test_event_times.size

        # One KM fit on the full training set serves every test subject.
        km_model = KaplanMeierArea(train_event_times, train_event_indicators)
        times, multiplier, multiplier_total, sub_expect_time = self._pseudo_obs_terms(km_model)

        best_guesses = test_event_times.copy().astype(float)
        for i in trange(n_test, desc="Calculating surrogate times for Pseudo-observation", disable=not self.verbose):
            if test_event_indicators[i]:
                continue
            best_guesses[i] = self._pseudo_obs_surrogate(
                test_event_times[i], times, multiplier, multiplier_total, n_train, sub_expect_time
            )

        assert np.all(best_guesses >= 0), "Best guesses should be non-negative"
        assert np.all(best_guesses[~test_event_indicators] >= test_event_times[~test_event_indicators]), "Best guesses should be greater than or equal to censor times"
        assert np.all(best_guesses[test_event_indicators] == test_event_times[test_event_indicators]), "Best guesses should be equal to event times for uncensored subjects"

        if impute_train:
            best_guesses_train = self._pseudo_obs_imputation_train(Y_train)
        else:
            best_guesses_train = Y_train[:, 0].copy()

        return best_guesses_train, best_guesses

    def _margin_imputation_train(self, Y_train, num_folds=5):
        """
        Margin imputation of the training set with cross-fitting.

        The L1-margin method proposed by https://www.jmlr.org/papers/v21/18-772.html
        Calculate the best guess survival time given the KM curve and censoring time of that patient.
        The censored subjects are split into folds; each fold is imputed from a KM curve
        fitted on all training subjects outside that fold.

        :param Y_train: np.ndarray, shape = (n_samples, 2)
            The training set with observed time and event indicator
        :param num_folds: int
            The number of folds for cross-validation. It is capped at the number of censored
            subjects; a single censored subject is simply held out on its own.
        :return best_guesses: np.ndarray, shape = (n_samples, )
            The imputed time for training set.
        """
        event_times = Y_train[:, 0]
        event_indicators = Y_train[:, 1].astype(bool)

        n = len(Y_train)
        best_guesses = event_times.copy().astype(float)

        # Only impute censored points
        censored_indices = np.where(~event_indicators)[0]
        n_censored = censored_indices.size
        if n_censored == 0:
            # Nothing to impute; KFold would otherwise reject an empty sample set.
            return best_guesses

        if n_censored == 1:
            # KFold needs at least two samples: hold the single censored subject out directly.
            val_folds = [np.array([0])]
        else:
            kf = KFold(n_splits=min(num_folds, n_censored), shuffle=True, random_state=42)
            # Split only censored indices
            val_folds = [val_index for _, val_index in kf.split(censored_indices)]

        max_horizon_time = np.max(event_times)

        for val_index in val_folds:
            censored_val_idx = censored_indices[val_index]

            # Build KM on full training data excluding the censored test fold
            km_train_idx = np.setdiff1d(np.arange(n), censored_val_idx)
            km_model = KaplanMeierArea(event_times[km_train_idx], event_indicators[km_train_idx])

            val_censor_times = event_times[censored_val_idx]

            imputed_val = km_model.best_guess(val_censor_times)
            # Safeguard kept for parity with the test-set path: subjects censored after the
            # horizon keep their censoring time.
            beyond_horizon = val_censor_times > max_horizon_time
            imputed_val[beyond_horizon] = val_censor_times[beyond_horizon]

            best_guesses[censored_val_idx] = imputed_val

        assert np.all(best_guesses >= 0), "Best guesses must be non-negative"
        assert np.all(best_guesses[~event_indicators] >= event_times[~event_indicators]), "Imputed must be ≥ censoring time"
        assert np.all(best_guesses[event_indicators] == event_times[event_indicators]), "Uncensored should match original"

        return best_guesses

    def _margin_imputation(self, Y_train, Y_test, impute_train=True):
        """
        Margin imputation method.

        The L1-margin method proposed by https://www.jmlr.org/papers/v21/18-772.html

        Calculate the best guess survival time given the KM curve and censoring time of that patient.
        The KM curve is fitted on the full training set; test subjects censored after the
        last KM time keep their censoring time.

        :param Y_train: np.ndarray, shape = (n_samples, 2)
            The training set with observed time and event indicator
        :param Y_test: np.ndarray, shape = (n_samples, 2)
            The test set with observed time and event indicator
        :param impute_train: bool
            Whether to impute the training set
        :return best_guesses_train: np.ndarray, shape = (n_samples, )
            The imputed time for training set.
            (if impute_train is False, the observed time for the training set is returned)
        :return best_guesses: np.ndarray, shape = (n_samples, )
            The imputed time for test set.
        """
        train_event_times = Y_train[:, 0]
        train_event_indicators = (Y_train[:, 1]).astype(bool)
        test_event_times = Y_test[:, 0]
        test_event_indicators = (Y_test[:, 1]).astype(bool)

        km_model = KaplanMeierArea(train_event_times, train_event_indicators)

        # Survival eval extrapolates the KM curve to the right until survival probability reaches 0
        # (km_model.km_linear_zero). We instead use the max time of the KM curve as the horizon.
        last_km_time = max(km_model.survival_times)

        censored_mask = ~test_event_indicators
        test_censor_times = test_event_times[censored_mask]

        best_guesses = test_event_times.copy().astype(float)
        if test_censor_times.size > 0:
            best_guesses_censored_data = km_model.best_guess(test_censor_times)
            beyond_horizon = test_censor_times > last_km_time
            best_guesses_censored_data[beyond_horizon] = test_censor_times[beyond_horizon]
            best_guesses[censored_mask] = best_guesses_censored_data

        assert np.all(best_guesses >= 0), "Best guesses should be non-negative"
        assert np.all(best_guesses[censored_mask] >= test_censor_times), "Best guesses should be greater than or equal to censor times"
        assert np.all(best_guesses[test_event_indicators] == test_event_times[test_event_indicators]), "Best guesses should be equal to event times for uncensored subjects"

        if impute_train:
            best_guesses_train = self._margin_imputation_train(Y_train)
        else:
            best_guesses_train = Y_train[:, 0].copy()

        return best_guesses_train, best_guesses


In [None]:
# Smoke test of the imputer using the "Margin" method.
# NOTE(review): Y_train / Y_test are not defined in this cell — presumably (n, 2) arrays of
# [observed_time, event] built in an earlier cell; confirm this survives Restart & Run All.
survival_eval_imputer = SurvivalEvalImputer(imputation_method="Margin", verbose=True)

# fit_transform returns (imputed training times, imputed test times).
aa, bb = survival_eval_imputer.fit_transform(Y_train, Y_test)

In [61]:
# Sanity check: the imputed training times form a 1-D array with one entry per training row.
aa.shape

(20000,)