On the original low censoring rate ACTG data (L stands for Low censoring rate)

In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os

### Load data

In [5]:
store_files = [
    "../real_data/ACTG_175_HIV1.csv",
    "../real_data/ACTG_175_HIV2.csv",
    "../real_data/ACTG_175_HIV3.csv",
]

experiment_setups = {}

for path in store_files:
    base_name = os.path.splitext(os.path.basename(path))[0]  # e.g. ACTG_175_HIV1
    scenario_dict = {}
    for scenario in range(1, 2): # only one scenario per HIV data
        try:
            result = pd.read_csv(path)
            if result is not None:
                scenario_dict[f"scenario_{scenario}"] = result
        except Exception as e:
            # Log or ignore as needed
            continue
    experiment_setups[base_name] = scenario_dict

In [6]:
experiment_setups['ACTG_175_HIV1']['scenario_1']

Unnamed: 0,id,observed_time_month,effect_non_censor,trt,z30,gender,race,hemo,homo,drugs,...,e1,e2,e3,e4,e5,e6,e7,e8,e9,cate_base
0,1,30.000000,1,0,1,1,0,0,1,0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.788055
1,2,30.000000,1,1,1,1,0,0,1,1,...,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.450623
2,3,26.466667,1,0,1,1,0,0,1,0,...,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,2.516565
3,4,30.000000,1,0,1,1,0,0,1,1,...,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,2.994492
4,5,6.266667,1,0,1,1,0,0,1,0,...,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.468248
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1049,1050,5.133333,1,0,1,1,0,1,0,0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.586317
1050,1051,19.600000,1,0,1,1,0,1,0,0,...,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,2.132294
1051,1052,30.000000,1,1,0,0,0,1,0,0,...,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,2.734120
1052,1053,13.166667,0,0,1,1,1,1,0,0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.265178


In [7]:
TRUE_ATE = {('ACTG_175_HIV1', 'scenario_1'): 2.7977461375268904,
            ('ACTG_175_HIV2', 'scenario_1'): 2.603510045518606,
            ('ACTG_175_HIV3', 'scenario_1'): 2.051686700212568}

### EXPERIMENT CONSTANTS

In [8]:
NUM_REPEATS_TO_INCLUDE = 10
output_pickle_path = f"../results/actgL_csf_num_repeats_{NUM_REPEATS_TO_INCLUDE}.pkl"

horizon = 30
min_node_size = 18
train_size = 0.75

output_pickle_path = f"../results/real_data/models_causal_survival/causal_survival_forest/"
output_pickle_path += f"actgL_causal_survival_forest_repeats_{NUM_REPEATS_TO_INCLUDE}.pkl"

### Run Experiments

In [9]:
import sys
import os

# Add the parent directory of "notebooks" to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from models_causal_survival.causal_survival_forest import CausalSurvivalForestGRF
from sklearn.metrics import mean_squared_error
import time
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm

In [11]:
def prepare_actg_data_split(dataset_df, X_cols, W_col, cate_base_col, experiment_repeat_setup, train_size=0.75):
    split_results = {}
    length = len(experiment_repeat_setup)
    
    for rand_idx in range(NUM_REPEATS_TO_INCLUDE):
        # y_cols = [f't{rand_idx}', f'e{rand_idx}']
        y_cols = ['observed_time_month', 'effect_non_censor']
        train_ids = experiment_repeat_setup[f'random_idx{rand_idx}'][:int(length*train_size)].values
        test_ids =  experiment_repeat_setup[f'random_idx{rand_idx}'][int(length*train_size):].values
        
        train_df = dataset_df[dataset_df['id'].isin(train_ids)]
        test_df = dataset_df[dataset_df['id'].isin(test_ids)]

        X_train = train_df[X_cols].to_numpy()
        W_train = train_df[W_col].to_numpy()
        Y_train = train_df[y_cols].to_numpy()

        X_test = test_df[X_cols].to_numpy()
        W_test = test_df[W_col].to_numpy()
        Y_test = test_df[y_cols].to_numpy()

        cate_test_true = test_df[cate_base_col].to_numpy()

        split_results[rand_idx] = (X_train, W_train, Y_train, X_test, W_test, Y_test, cate_test_true)

    return split_results


In [12]:
X_bi_cols = ['gender', 'race', 'hemo', 'homo', 'drugs', 'str2', 'symptom']
X_cont_cols = ['age', 'wtkg',  'karnof', 'cd40', 'cd80']
U = ['z30']
W = ['trt']
y_cols = ['observed_time_month', 'effect_non_censor'] # ['time', 'cid']

X_cols = X_bi_cols + X_cont_cols
W_col = W[0]
cate_base_col = 'cate_base'
experiment_repeat_setups = [pd.read_csv(f'../real_data/idx_split_HIV{i}.csv') for i in range(1, 4)]

In [13]:
failure_times_grid_size = 200
horizon = 30
min_node_size = 18

print("Output results path:", output_pickle_path)

results_dict = {}

for setup_name, setup_dict in tqdm(experiment_setups.items(), desc="Experiment Setups"):
     
    results_dict[setup_name] = {}
    hiv_dataset_idx = int(setup_name[-1])
    experiment_repeat_setup = experiment_repeat_setups[hiv_dataset_idx-1]

    for scenario_key in tqdm(setup_dict, desc=f"{setup_name} Scenarios", leave=False):
        dataset_df = setup_dict[scenario_key]
        split_dict = prepare_actg_data_split(dataset_df, X_cols, W_col, cate_base_col, 
                                             experiment_repeat_setup, train_size=train_size)

        # Initialize results dictionary for this setup and scenario
        results_dict[setup_name][scenario_key] = {}

        start_time = time.time()

        for rand_idx in range(NUM_REPEATS_TO_INCLUDE):
            X_train, W_train, Y_train, X_test, W_test, Y_test, cate_test_true = split_dict[rand_idx]

            # Store placeholder for later population
            results_dict[setup_name][scenario_key][rand_idx] = {}

            # Train the model
            csf = CausalSurvivalForestGRF(failure_times_grid_size=failure_times_grid_size, 
                                          horizon=horizon, min_node_size=min_node_size, seed=2025+rand_idx)
            csf.fit(X_train, W_train, Y_train)

            ate_true = TRUE_ATE.get((setup_name, scenario_key), cate_test_true.mean())

            # Predict CATE
            # cate_test_pred = csf.predict_cate(X_test, W_test)
            mse_test, cate_test_pred, ate_test_pred = csf.evaluate(X_test, cate_test_true, W_test)

            # Save results
            results_dict[setup_name][scenario_key][rand_idx] = {
                "cate_true": cate_test_true,
                "cate_pred": cate_test_pred,
                "ate_true": ate_true,
                "ate_pred": ate_test_pred,
                "cate_mse": mse_test,
                "ate_bias": ate_test_pred - ate_true,
            }

        end_time = time.time()

        # Save results to the setup dictionary
        results_dict[setup_name][scenario_key]["average"] = {
            "mean_cate_mse": np.mean([results_dict[setup_name][scenario_key][rand_idx]["cate_mse"]
                                      for rand_idx in range(NUM_REPEATS_TO_INCLUDE)]),
            "std_cate_mse": np.std([results_dict[setup_name][scenario_key][rand_idx]["cate_mse"]
                                    for rand_idx in range(NUM_REPEATS_TO_INCLUDE)]),
            "mean_ate_pred": np.mean([results_dict[setup_name][scenario_key][rand_idx]["ate_pred"]
                                      for rand_idx in range(NUM_REPEATS_TO_INCLUDE)]),
            "std_ate_pred": np.std([results_dict[setup_name][scenario_key][rand_idx]["ate_pred"]
                                    for rand_idx in range(NUM_REPEATS_TO_INCLUDE)]),
            "mean_ate_true": np.mean([results_dict[setup_name][scenario_key][rand_idx]["ate_true"]
                                      for rand_idx in range(NUM_REPEATS_TO_INCLUDE)]),
            "std_ate_true": np.std([results_dict[setup_name][scenario_key][rand_idx]["ate_true"]
                                    for rand_idx in range(NUM_REPEATS_TO_INCLUDE)]),
            "mean_ate_bias": np.mean([results_dict[setup_name][scenario_key][rand_idx]["ate_bias"]
                                      for rand_idx in range(NUM_REPEATS_TO_INCLUDE)]),
            "std_ate_bias": np.std([results_dict[setup_name][scenario_key][rand_idx]["ate_bias"]
                                    for rand_idx in range(NUM_REPEATS_TO_INCLUDE)]),
            "runtime": (end_time - start_time) / len(range(NUM_REPEATS_TO_INCLUDE))
        }

        # Save progress to disk
        with open(output_pickle_path, "wb") as f:
            pickle.dump(results_dict, f)

Output results path: ../results/real_data/models_causal_survival/causal_survival_forest/actgL_causal_survival_forest_repeats_10.pkl


Experiment Setups: 100%|██████████| 3/3 [00:23<00:00,  7.80s/it]


In [14]:
def summarize_experiment_results(results_dict):
    records = []

    for setup_name, setup_dict in results_dict.items():
        for scenario_key in setup_dict:
            avg_result = setup_dict[scenario_key].get("average", {})
            mean_mse = avg_result.get("mean_cate_mse", np.nan)
            std_mse = avg_result.get("std_cate_mse", np.nan)
            mean_ate_pred = avg_result.get("mean_ate_pred", np.nan)
            std_ate_pred = avg_result.get("std_ate_pred", np.nan)
            mean_ate_true = avg_result.get("mean_ate_true", np.nan)
            std_ate_true = avg_result.get("std_ate_true", np.nan)
            mean_ate_bias = avg_result.get("mean_ate_bias", np.nan)
            std_ate_bias = avg_result.get("std_ate_bias", np.nan)
            runtime = avg_result.get("runtime", np.nan)

            records.append({
                "setup_name": setup_name,
                "scenario_key": scenario_key,
                "CATE_MSE": f"{mean_mse:.3f} ± {std_mse:.3f}" if not pd.isna(mean_mse) else np.nan,
                "ATE_pred": f"{mean_ate_pred:.3f} ± {std_ate_pred:.3f}" if not pd.isna(mean_ate_pred) else np.nan,
                "ATE_true": f"{mean_ate_true:.3f} ± {std_ate_true:.3f}" if not pd.isna(mean_ate_true) else np.nan,
                "ATE_bias": f"{mean_ate_bias:.3f} ± {std_ate_bias:.3f}" if not pd.isna(mean_ate_bias) else np.nan,
                "runtime [s]": round(runtime) if not pd.isna(runtime) else np.nan
            })

    df = pd.DataFrame.from_records(records)
    return df

In [15]:
summary_df = summarize_experiment_results(results_dict)
summary_df

Unnamed: 0,setup_name,scenario_key,CATE_MSE,ATE_pred,ATE_true,ATE_bias,runtime [s]
0,ACTG_175_HIV1,scenario_1,0.117 ± 0.089,2.748 ± 0.236,2.798 ± 0.000,-0.050 ± 0.236,1
1,ACTG_175_HIV2,scenario_1,0.087 ± 0.062,2.571 ± 0.193,2.604 ± 0.000,-0.033 ± 0.193,1
2,ACTG_175_HIV3,scenario_1,0.169 ± 0.084,2.026 ± 0.251,2.052 ± 0.000,-0.026 ± 0.251,1
