In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os

### Load data

In [2]:
store_files = [
    "../real_data/twin.csv",
]

experiment_setups = {}

for path in store_files:
    base_name = os.path.splitext(os.path.basename(path))[0]  # e.g. twin
    scenario_dict = {}
    for scenario in range(1, 2): # only one scenario per HIV data
        try:
            result = pd.read_csv(path)
            if result is not None:
                scenario_dict[f"scenario_{scenario}"] = result
        except Exception as e:
            # Log or ignore as needed
            continue
    experiment_setups[base_name] = scenario_dict

In [3]:
experiment_setups["twin"]["scenario_1"]  # preview of the loaded twin table (single scenario)

Unnamed: 0,idx,observed_time,event,T0,T1,T,C,W,true_cate,anemia,...,resstatb_4,mpcb_1,mpcb_2,mpcb_3,mpcb_4,mpcb_5,mpcb_6,mpcb_7,mpcb_8,mpcb_9
0,0,195.954330,0,365,365,365,195.954330,0,0,0,...,0,0,0,1,0,0,0,0,0,0
1,1,2.000000,1,365,2,2,20.717881,1,-363,0,...,0,0,0,0,0,0,0,0,0,0
2,2,55.907504,0,365,365,365,55.907504,0,0,0,...,0,0,1,0,0,0,0,0,0,0
3,3,4.168915,0,365,365,365,4.168915,1,0,0,...,0,0,1,0,0,0,0,0,0,0
4,4,0.000042,0,365,5,5,0.000042,1,-360,0,...,0,0,0,1,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11395,11395,15.080879,0,365,365,365,15.080879,1,0,0,...,0,0,1,0,0,0,0,0,0,0
11396,11396,233.524041,0,365,365,365,233.524041,0,0,0,...,0,1,0,0,0,0,0,0,0,0
11397,11397,117.663304,0,365,365,365,117.663304,1,0,1,...,0,0,1,0,0,0,0,0,0,0
11398,11398,30.417666,0,38,365,38,30.417666,0,327,0,...,0,0,1,0,0,0,0,0,0,0


In [4]:
# Split table for the single twin dataset: one column per random repeat (random_idx0 .. random_idx9).
experiment_repeat_setups = [pd.read_csv("../real_data/idx_split_twin.csv")]
experiment_repeat_setups[0]

Unnamed: 0,idx,random_idx0,random_idx1,random_idx2,random_idx3,random_idx4,random_idx5,random_idx6,random_idx7,random_idx8,random_idx9
0,0,4673,8193,2015,1620,11124,2792,2763,3687,4907,9965
1,1,8847,4943,7700,4045,5430,5809,2248,6564,10370,386
2,2,5554,1711,1268,1411,11188,10742,6328,1891,1937,9248
3,3,9208,2687,11191,2758,8578,10274,10448,2396,5476,4819
4,4,337,9283,9089,878,3289,3300,9654,3714,7741,2289
...,...,...,...,...,...,...,...,...,...,...,...
11395,11395,4892,3659,6350,6294,1565,10981,3671,10546,1215,8200
11396,11396,8346,8007,5825,11,7334,8926,6832,5698,6831,8875
11397,11397,4426,6524,2822,10916,6607,8329,4656,5140,4654,10104
11398,11398,3360,3408,5652,10744,8170,4830,4917,3441,10438,948


### Experiment constants

In [5]:
NUM_REPEATS_TO_INCLUDE = 10  # the split table provides random_idx0..random_idx9, so at most 10
TRAIN_SIZE = 0.5
# NOTE(review): VAL_SIZE / TEST_SIZE are not read by the visible cells; the split helper
# treats everything after the training fraction as one held-out (validation + test) block.
VAL_SIZE = 0.25
TEST_SIZE = 0.25

imputation_methods_list = ['Pseudo_obs', 'Margin', 'IPCW-T']

# Fail fast on inconsistent settings instead of deep inside the imputation loop.
assert 1 <= NUM_REPEATS_TO_INCLUDE <= 10, "idx_split_twin.csv only has random_idx0..random_idx9"
assert abs(TRAIN_SIZE + VAL_SIZE + TEST_SIZE - 1.0) < 1e-9, "split fractions must sum to 1"

In [6]:
output_pickle_path = "../real_data/imputed_times_lookup_twin.pkl"  # on-disk cache of imputation results

In [7]:
import pickle

# Resume from the on-disk cache when it exists so finished imputations are not redone.
# NOTE: pickle.load can execute arbitrary code -- only point this at caches this notebook wrote.
if not os.path.exists(output_pickle_path):
    print("Imputation times not found, creating new file.")
    imputed_times = {}
else:
    print("Loading imputation times from existing file.")
    with open(output_pickle_path, 'rb') as f:
        imputed_times = pickle.load(f)

Loading imputation times from existing file.


### Run Imputation Experiments

In [8]:
import sys
import os

# Make the project root importable by appending the parent of the current working directory
# (presumably the repo root when the notebook runs from its own folder -- confirm).
# This must run before the `models_causal_impute` import at the bottom of this cell.
# NOTE(review): os, pickle, numpy, pandas and tqdm are also imported in earlier cells;
# re-importing is harmless but could be consolidated into the first cell.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

import time
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
from models_causal_impute.survival_eval_impute import SurvivalEvalImputer

In [9]:
def prepare_twin_data_split(dataset_df, X_cols, W_col, cate_base_col, experiment_repeat_setup,
                            num_repeats=None, train_size=None):
    """Build train / held-out NumPy arrays for each random repeat of the twin data.

    Parameters
    ----------
    dataset_df : pd.DataFrame
        Full dataset. Must contain an ``idx`` column, ``X_cols``, ``W_col``,
        ``cate_base_col``, ``observed_time`` and ``event``.
    X_cols : list of str
        Covariate columns.
    W_col : str or list of str
        Treatment column(s), passed straight to ``dataset_df[...]``.
    cate_base_col : str
        Column holding the true per-row CATE (returned for held-out rows only).
    experiment_repeat_setup : pd.DataFrame
        Split table with columns ``random_idx0``, ``random_idx1``, ... . The first
        ``int(len(experiment_repeat_setup) * train_size)`` entries of each column are the
        training ``idx`` values; the remaining entries are held out.
    num_repeats : int, optional
        Number of repeats to build (``random_idx0`` .. ``random_idx{num_repeats - 1}``).
        Defaults to the notebook constant ``NUM_REPEATS_TO_INCLUDE``.
    train_size : float, optional
        Fraction of the split table used for training. Defaults to ``TRAIN_SIZE``.

    Returns
    -------
    dict[int, tuple]
        ``rand_idx -> (X_train, W_train, Y_train, X_test, W_test, Y_test, cate_test_true)``.
        ``Y_*`` has columns ``(observed_time, event)``. The "test" part holds every
        non-training row, i.e. validation and test together. Rows keep the order of
        ``dataset_df`` (selection is via ``isin``), not the order of the split table.
    """
    # Resolve defaults at call time so edits to the constants cell are picked up.
    if num_repeats is None:
        num_repeats = NUM_REPEATS_TO_INCLUDE
    if train_size is None:
        train_size = TRAIN_SIZE

    y_cols = ['observed_time', 'event']
    # Loop-invariant cut point between training ids and held-out ids.
    n_train = int(len(experiment_repeat_setup) * train_size)

    split_results = {}
    for rand_idx in range(num_repeats):
        permuted_ids = experiment_repeat_setup[f'random_idx{rand_idx}']
        train_ids = permuted_ids.iloc[:n_train].to_numpy()
        test_ids = permuted_ids.iloc[n_train:].to_numpy()  # validation + test

        train_df = dataset_df[dataset_df['idx'].isin(train_ids)]
        test_df = dataset_df[dataset_df['idx'].isin(test_ids)]

        split_results[rand_idx] = (
            train_df[X_cols].to_numpy(),
            train_df[W_col].to_numpy(),
            train_df[y_cols].to_numpy(),
            test_df[X_cols].to_numpy(),
            test_df[W_col].to_numpy(),
            test_df[y_cols].to_numpy(),
            test_df[cate_base_col].to_numpy(),
        )

    return split_results

In [10]:
# NOTE(review): gestat, dmage, dmeduc, dmar, nprevist and adequacy are grouped with the binary
# flags but may not all be 0/1 -- confirm against the twin data dictionary.
X_binary_cols = [
    'anemia', 'cardiac', 'lung', 'diabetes', 'herpes', 'hydra',
    'hemo', 'chyper', 'phyper', 'eclamp', 'incervix', 'pre4000', 'preterm',
    'renal', 'rh', 'uterine', 'othermr',
    'gestat', 'dmage', 'dmeduc', 'dmar', 'nprevist', 'adequacy',
]
X_num_cols = ['dtotord', 'cigar', 'drink', 'wtgain']
# One-hot dummy columns: pldel_2..5, resstatb_2..4, mpcb_1..9.
X_ohe_cols = (
    [f'pldel_{k}' for k in range(2, 6)]
    + [f'resstatb_{k}' for k in range(2, 5)]
    + [f'mpcb_{k}' for k in range(1, 10)]
)

# NOTE(review): `W` and `y_cols` are not read by the visible cells -- the split helper uses
# ['observed_time', 'event'] and `W_col` instead. Kept so any other cell relying on them still works.
W = ['trt']
y_cols = ['observed_time_month', 'effect_non_censor']  # ['time', 'cid']

X_cols = X_binary_cols + X_num_cols + X_ohe_cols
W_col = ['W']                # treatment column (0/1 in the data preview)
cate_true_col = 'true_cate'  # ground-truth per-row CATE


In [11]:
# Run every imputation method on every (setup, scenario, training-size, repeat) combination,
# reusing anything already cached. Results are stored as
#   imputed_times[method][setup][scenario][train_size_label][rand_idx] =
#       {"Y_train_imputed": ..., "Y_test_imputed": ..., "runtime": seconds}
# An entry without "runtime" is treated as incomplete and recomputed.
for imputation_method in tqdm(imputation_methods_list, desc="Imputation Methods"):
    if imputed_times.get(imputation_method) is None:
        print(f"Imputation times not found for {imputation_method}, creating new entry.")
        imputed_times[imputation_method] = {}
    method_results = imputed_times[imputation_method]

    for setup_name, setup_dict in experiment_setups.items():
        experiment_repeat_setup = experiment_repeat_setups[0]  # there's only one twin dataset
        if setup_name not in method_results:
            print(f"Creating new entry for '{setup_name}' in imputed times['{imputation_method}'].")
            method_results[setup_name] = {}
        setup_results = method_results[setup_name]

        for scenario_key, dataset_df in setup_dict.items():
            if scenario_key not in setup_results:
                print(f"Creating new entry for '{scenario_key}' in imputed times['{imputation_method}']['{setup_name}'].")
                setup_results[scenario_key] = {}
            scenario_results = setup_results[scenario_key]
            scenario_changed = False

            for num_training_data_points in [f'{int(TRAIN_SIZE*100)}%']:
                size_results = scenario_results.setdefault(num_training_data_points, {})
                pending = [r for r in range(NUM_REPEATS_TO_INCLUDE)
                           if "runtime" not in size_results.get(r, {})]

                # Splits are only needed when at least one repeat is left to impute.
                split_dict = (
                    prepare_twin_data_split(dataset_df, X_cols, W_col, cate_true_col, experiment_repeat_setup)
                    if pending else {}
                )

                total_runtime = 0.0
                for rand_idx in pending:
                    _, _, Y_train, _, _, Y_test, _ = split_dict[rand_idx]

                    start_time = time.time()
                    survival_imputer = SurvivalEvalImputer(imputation_method=imputation_method, verbose=False)
                    Y_train_imputed, Y_test_imputed = survival_imputer.fit_transform(Y_train, Y_test)
                    runtime = time.time() - start_time
                    total_runtime += runtime

                    size_results[rand_idx] = {
                        "Y_train_imputed": Y_train_imputed,
                        "Y_test_imputed": Y_test_imputed,
                        "runtime": runtime,
                    }
                    scenario_changed = True

                # Report what this pass actually did. The old message showed only the last repeat's
                # runtime, or a stale value when every repeat came from the cache.
                print(f"'{imputation_method}' Imputation completed for '{setup_name}', '{scenario_key}', "
                      f"num_training: {num_training_data_points}, {len(pending)} new / "
                      f"{NUM_REPEATS_TO_INCLUDE - len(pending)} cached repeats in {total_runtime:.0f} seconds.")

            # Save progress to disk. Write to a temp file first and swap it in atomically so an
            # interrupted dump cannot corrupt the existing cache.
            if scenario_changed:
                tmp_path = output_pickle_path + ".tmp"
                with open(tmp_path, "wb") as f:
                    pickle.dump(imputed_times, f)
                os.replace(tmp_path, output_pickle_path)


Imputation Methods: 100%|██████████| 3/3 [00:00<00:00, 13.49it/s]

'Pseudo_obs' Imputation completed for 'twin', 'scenario_1', num_training: 50%, 9 in 0 seconds.
'Margin' Imputation completed for 'twin', 'scenario_1', num_training: 50%, 9 in 0 seconds.
'IPCW-T' Imputation completed for 'twin', 'scenario_1', num_training: 50%, 9 in 0 seconds.





In [12]:
imputed_times["Pseudo_obs"]["twin"]["scenario_1"].keys()  # training-size labels cached for this method/scenario

dict_keys(['50%'])

### Compute the true ATE (mean of the ground-truth CATE)

In [13]:
# Ground-truth ATE per (setup, scenario): the mean of the per-row true CATE column.
# Uses the shared `cate_true_col` constant so the column name is defined in one place.
TRUE_ATE = {
    (setup_name, scenario_key): dataset_df[cate_true_col].mean()
    for setup_name, setup_dict in experiment_setups.items()
    for scenario_key, dataset_df in setup_dict.items()
}

In [14]:
TRUE_ATE  # ground-truth ATE per (setup, scenario): mean of the 'true_cate' column

{('twin', 'scenario_1'): 5.038157894736842}