In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os

### Load data

In [2]:
store_files = [
    "../real_data/ACTG_175_HIV1.csv",
    "../real_data/ACTG_175_HIV2.csv",
    "../real_data/ACTG_175_HIV3.csv",
]

experiment_setups = {}

for path in store_files:
    base_name = os.path.splitext(os.path.basename(path))[0]  # e.g. ACTG_175_HIV1
    scenario_dict = {}
    for scenario in range(1, 2): # only one scenario per HIV data
        try:
            result = pd.read_csv(path)
            if result is not None:
                scenario_dict[f"scenario_{scenario}"] = result
        except Exception as e:
            # Log or ignore as needed
            continue
    experiment_setups[base_name] = scenario_dict

In [3]:
experiment_setups['ACTG_175_HIV1']['scenario_1']

Unnamed: 0,id,observed_time_month,effect_non_censor,trt,z30,gender,race,hemo,homo,drugs,...,e1,e2,e3,e4,e5,e6,e7,e8,e9,cate_base
0,1,30.000000,1,0,1,1,0,0,1,0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.788055
1,2,30.000000,1,1,1,1,0,0,1,1,...,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.450623
2,3,26.466667,1,0,1,1,0,0,1,0,...,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,2.516565
3,4,30.000000,1,0,1,1,0,0,1,1,...,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,2.994492
4,5,6.266667,1,0,1,1,0,0,1,0,...,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.468248
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1049,1050,5.133333,1,0,1,1,0,1,0,0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.586317
1050,1051,19.600000,1,0,1,1,0,1,0,0,...,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,2.132294
1051,1052,30.000000,1,1,0,0,0,1,0,0,...,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,2.734120
1052,1053,13.166667,0,0,1,1,1,1,0,0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.265178


In [5]:
experiment_repeat_setups = [pd.read_csv(f'../real_data/idx_split_HIV{i}.csv') for i in range(1, 4)]
experiment_repeat_setups[0]

Unnamed: 0,idx,random_idx0,random_idx1,random_idx2,random_idx3,random_idx4,random_idx5,random_idx6,random_idx7,random_idx8,random_idx9
0,0,149,935,151,727,1002,396,553,971,933,398
1,1,887,653,215,709,611,1043,858,424,490,602
2,2,1021,26,406,345,914,746,41,701,507,25
3,3,219,433,896,191,579,321,54,932,726,435
4,4,539,356,448,113,928,973,76,314,590,137
...,...,...,...,...,...,...,...,...,...,...,...
1049,1049,1034,333,856,963,835,253,528,956,223,648
1050,1050,160,77,59,835,0,796,626,417,111,900
1051,1051,323,952,984,635,848,943,902,355,322,293
1052,1052,862,794,861,622,498,972,364,214,436,759


### EXPERIMENT CONSTANTS

In [None]:
NUM_REPEATS_TO_INCLUDE = 10  # max 10
# NUM_TRAINING_DATA_POINTS = 5000 # max 45000
TRAIN_SIZE = 0.75


imputation_methods_list = ['Pseudo_obs', 'Margin', 'IPCW-T']

In [13]:
output_pickle_path = f"../real_data/imputed_times_lookup.pkl"

In [14]:
import pickle

if os.path.exists(output_pickle_path):
    print("Loading imputation times from existing file.")
    with open(output_pickle_path, 'rb') as f:
        imputed_times = pickle.load(f)
else:
    print("Imputation times not found, creating new file.")
    imputed_times = {}

Imputation times not found, creating new file.


### Run Imputation Experiments

In [15]:
import sys
import os

# Add the parent directory of "notebooks" to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

import time
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
from models_causal_impute.survival_eval_impute import SurvivalEvalImputer

In [None]:
def prepare_actg_data_split(dataset_df, X_cols, W_col, cate_base_col, experiment_repeat_setup):
    split_results = {}
    length = len(experiment_repeat_setup)
    
    for rand_idx in range(NUM_REPEATS_TO_INCLUDE):
        y_cols = [f't{rand_idx}', f'e{rand_idx}']
        # take the first half of the dataset for training and the second half for testing
        train_ids = experiment_repeat_setup[f'random_idx{rand_idx}'][:int(length*TRAIN_SIZE)].values
        test_ids =  experiment_repeat_setup[f'random_idx{rand_idx}'][int(length*TRAIN_SIZE):].values
        # test_ids = dataset_df['id'] # same as train_ids
        # train_ids = dataset_df['id']
        
        train_df = dataset_df[dataset_df['id'].isin(train_ids)]
        test_df = dataset_df[dataset_df['id'].isin(test_ids)]

        X_train = train_df[X_cols].to_numpy()
        W_train = train_df[W_col].to_numpy()
        Y_train = train_df[y_cols].to_numpy()

        X_test = test_df[X_cols].to_numpy()
        W_test = test_df[W_col].to_numpy()
        Y_test = test_df[y_cols].to_numpy()

        cate_test_true = test_df[cate_base_col].to_numpy()

        split_results[rand_idx] = (X_train, W_train, Y_train, X_test, W_test, Y_test, cate_test_true)

    return split_results

In [17]:
X_bi_cols = ['gender', 'race', 'hemo', 'homo', 'drugs', 'str2', 'symptom']
X_cont_cols = ['age', 'wtkg',  'karnof', 'cd40', 'cd80']
U = ['z30']
W = ['trt']
y_cols = ['observed_time_month', 'effect_non_censor'] # ['time', 'cid']

X_cols = X_bi_cols + X_cont_cols
W_col = W[0]
cate_base_col = 'cate_base'


In [None]:
start_time = end_time = 0

for imputation_method_idx in trange(len(imputation_methods_list), desc="Imputation Methods"):
    imputation_method = imputation_methods_list[imputation_method_idx]

    if imputed_times.get(imputation_method) is None:
        print(f"Imputation times not found for {imputation_method}, creating new entry.")
        imputed_times[imputation_method] = {}

    # for setup_name, setup_dict in tqdm(experiment_setups.items(), desc="Experiment Setups"):
    for setup_name, setup_dict in experiment_setups.items():

        hiv_dataset_idx = int(setup_name[-1])
        experiment_repeat_setup = experiment_repeat_setups[hiv_dataset_idx-1]
        # Check if imputed_times[imputation_method] has the setup_name
        if setup_name not in imputed_times[imputation_method]:
            print(f"Creating new entry for '{setup_name}' in imputed times['{imputation_method}'].")
            imputed_times[imputation_method][setup_name] = {}

        for scenario_key in setup_dict:
            dataset_df = setup_dict[scenario_key]

            # check if imputed_times[imputation_method][setup_name] has the scenario_key
            if scenario_key not in imputed_times[imputation_method][setup_name]:
                print(f"Creating new entry for '{scenario_key}' in imputed times['{imputation_method}']['{setup_name}'].")
                imputed_times[imputation_method][setup_name][scenario_key] = {}

            for num_training_data_points in [f'{int(TRAIN_SIZE*100)}%']:

                split_dict = prepare_actg_data_split(dataset_df, X_cols, W_col, cate_base_col, experiment_repeat_setup)

                # check if imputed_times[imputation_method][setup_name][scenario_key] has the num_training_data_points
                if num_training_data_points not in imputed_times[imputation_method][setup_name][scenario_key]:
                    # print(f"Creating new entry for '{num_training_data_points}' in imputed times['{imputation_method}']['{setup_name}']['{scenario_key}'].")
                    imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points] = {}

                for rand_idx in range(NUM_REPEATS_TO_INCLUDE):
                    X_train, W_train, Y_train, X_test, W_test, Y_test, cate_test_true = split_dict[rand_idx]

                    # Check if imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points] has the rand_idx
                    if rand_idx not in imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points]:
                        # print(f"Creating new entry for '{rand_idx}' in imputed times['{imputation_method}']['{setup_name}']['{scenario_key}']['{num_training_data_points}'].")
                        imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points][rand_idx] = {}
                    elif rand_idx in imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points] and \
                        "runtime" in imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points][rand_idx]:
                        # print(f"Skipping existing entry for '{rand_idx}' in imputed times['{imputation_method}']['{setup_name}']['{scenario_key}'][{num_training_data_points}].")
                        continue

                    start_time = time.time()

                    # impute the missing values
                    survival_imputer = SurvivalEvalImputer(imputation_method=imputation_method, verbose=False)
                    Y_train_imputed, Y_test_imputed = survival_imputer.fit_transform(Y_train, Y_test)

                    end_time = time.time()

                    imputed_times[imputation_method][setup_name][scenario_key][num_training_data_points][rand_idx] = {
                        "Y_train_imputed": Y_train_imputed,
                        "Y_test_imputed": Y_test_imputed,
                        "runtime": end_time - start_time
                    }
                print(f"'{imputation_method}' Imputation completed for '{setup_name}', '{scenario_key}', " +
                      f"num_training: {num_training_data_points}, {rand_idx} in {end_time - start_time:.0f} seconds.")
                
            # Save progress to disk
            with open(output_pickle_path, "wb") as f:
                pickle.dump(imputed_times, f)


Imputation Methods:   0%|          | 0/3 [00:00<?, ?it/s]

Imputation times not found for Pseudo_obs, creating new entry.
Creating new entry for 'ACTG_175_HIV1' in imputed times['Pseudo_obs'].
Creating new entry for 'scenario_1' in imputed times['Pseudo_obs']['ACTG_175_HIV1'].
'Pseudo_obs' Imputation completed for 'ACTG_175_HIV1', 'scenario_1', num_training: 75%, 9 in 0 seconds.
Creating new entry for 'ACTG_175_HIV2' in imputed times['Pseudo_obs'].
Creating new entry for 'scenario_1' in imputed times['Pseudo_obs']['ACTG_175_HIV2'].
'Pseudo_obs' Imputation completed for 'ACTG_175_HIV2', 'scenario_1', num_training: 75%, 9 in 0 seconds.
Creating new entry for 'ACTG_175_HIV3' in imputed times['Pseudo_obs'].
Creating new entry for 'scenario_1' in imputed times['Pseudo_obs']['ACTG_175_HIV3'].


Imputation Methods:  33%|███▎      | 1/3 [00:07<00:14,  7.37s/it]

'Pseudo_obs' Imputation completed for 'ACTG_175_HIV3', 'scenario_1', num_training: 75%, 9 in 0 seconds.
Imputation times not found for Margin, creating new entry.
Creating new entry for 'ACTG_175_HIV1' in imputed times['Margin'].
Creating new entry for 'scenario_1' in imputed times['Margin']['ACTG_175_HIV1'].
'Margin' Imputation completed for 'ACTG_175_HIV1', 'scenario_1', num_training: 75%, 9 in 0 seconds.
Creating new entry for 'ACTG_175_HIV2' in imputed times['Margin'].
Creating new entry for 'scenario_1' in imputed times['Margin']['ACTG_175_HIV2'].
'Margin' Imputation completed for 'ACTG_175_HIV2', 'scenario_1', num_training: 75%, 9 in 0 seconds.


Imputation Methods:  67%|██████▋   | 2/3 [00:07<00:03,  3.23s/it]

Creating new entry for 'ACTG_175_HIV3' in imputed times['Margin'].
Creating new entry for 'scenario_1' in imputed times['Margin']['ACTG_175_HIV3'].
'Margin' Imputation completed for 'ACTG_175_HIV3', 'scenario_1', num_training: 75%, 9 in 0 seconds.
Imputation times not found for IPCW-T, creating new entry.
Creating new entry for 'ACTG_175_HIV1' in imputed times['IPCW-T'].
Creating new entry for 'scenario_1' in imputed times['IPCW-T']['ACTG_175_HIV1'].
'IPCW-T' Imputation completed for 'ACTG_175_HIV1', 'scenario_1', num_training: 75%, 9 in 0 seconds.
Creating new entry for 'ACTG_175_HIV2' in imputed times['IPCW-T'].
Creating new entry for 'scenario_1' in imputed times['IPCW-T']['ACTG_175_HIV2'].
'IPCW-T' Imputation completed for 'ACTG_175_HIV2', 'scenario_1', num_training: 75%, 9 in 0 seconds.
Creating new entry for 'ACTG_175_HIV3' in imputed times['IPCW-T'].
Creating new entry for 'scenario_1' in imputed times['IPCW-T']['ACTG_175_HIV3'].


Imputation Methods: 100%|██████████| 3/3 [00:08<00:00,  2.76s/it]

'IPCW-T' Imputation completed for 'ACTG_175_HIV3', 'scenario_1', num_training: 75%, 9 in 0 seconds.





In [19]:
imputed_times['Pseudo_obs']['ACTG_175_HIV1']['scenario_1'].keys()

dict_keys(['75%'])

### Get True ATE

In [14]:
TRUE_ATE = {}
for setup_name, setup_dict in experiment_setups.items():
    for scenario_key in setup_dict:
        dataset_df = setup_dict[scenario_key]
        true_ate = dataset_df['cate_base'].mean()
        TRUE_ATE[(setup_name, scenario_key)] = true_ate

In [15]:
TRUE_ATE

{('ACTG_175_HIV1', 'scenario_1'): 2.7977461375268904,
 ('ACTG_175_HIV2', 'scenario_1'): 2.603510045518606,
 ('ACTG_175_HIV3', 'scenario_1'): 2.051686700212568}