In [1]:
import pickle
import os
import numpy as np
import pandas as pd

def load_all_simulation_results(n, parts=(1, 2, 3, 4, 5), base_dir=None):
    """
    Load and pool the pickled simulation results for sample size ``n``.

    Each part file is expected at
    ``<base_dir>/out/part_<part>_cv_censored_1D_survival_results_TargetingGRID0.02_nSamples<n>.pkl``
    and to hold a dict whose values for the keys below are lists (one entry per
    replication). Lists from all parts are concatenated key by key.

    Args:
        n: Sample size encoded in the file names.
        parts: Iterable of part numbers to load (default ``(1, 2, 3, 4, 5)``).
        base_dir: Directory containing the ``out`` folder. Defaults to the parent
            of the current working directory (the original behaviour).

    Returns:
        A list of 13-tuples, one per replication, ordered as:
        (survival_dfs, cv_CI_REG, cv_CI_REG_CV, cv_CI_DELTA,
         reg_CI_REG, reg_CI_REG_CV, reg_CI_DELTA,
         delta_CI_REG, delta_CI_REG_CV, delta_CI_DELTA,
         relax_CI_REG, relax_CI_REG_CV, relax_CI_DELTA).

    Raises:
        OSError: If a part file cannot be opened.
        KeyError: If a part file lacks one of the expected keys.
        ValueError: If the pooled lists end up with different lengths, which
            would otherwise be silently truncated by ``zip`` and misalign
            replications.
    """
    keys = [
        "survival_dfs", "cv_CI_REG", "cv_CI_REG_CV", "cv_CI_DELTA",
        "reg_CI_REG", "reg_CI_REG_CV", "reg_CI_DELTA",
        "delta_CI_REG", "delta_CI_REG_CV", "delta_CI_DELTA",
        "relax_CI_REG", "relax_CI_REG_CV", "relax_CI_DELTA",
    ]
    if base_dir is None:
        base_dir = os.path.dirname(os.getcwd())

    combined = {key: [] for key in keys}
    for part in parts:
        file_path = os.path.join(
            base_dir,
            "out",
            f"part_{part}_cv_censored_1D_survival_results_TargetingGRID0.02_nSamples{n}.pkl",
        )
        # pickle.load can execute arbitrary code: only point this at trusted,
        # locally generated simulation output.
        with open(file_path, "rb") as f:
            results = pickle.load(f)
        for key in keys:
            combined[key].extend(results[key])

    lengths = {key: len(values) for key, values in combined.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(f"Inconsistent number of replications across result lists: {lengths}")

    # One tuple per replication, fields in the order of `keys`.
    return list(zip(*(combined[key] for key in keys)))

# ------------------------------------------
# 2. True Survival Function (evaluated at s)
# ------------------------------------------
def true_survival(s):
    """Return the true survival curve S(s) = 1 - 3*s**2 + 2*s**3.

    Uses only arithmetic operators, so it accepts a scalar or applies
    element-wise to a NumPy array. S(0) = 1 and S(1) = 0.
    """
    quadratic_term = 3 * (s ** 2)
    cubic_term = 2 * (s ** 3)
    return 1 - quadratic_term + cubic_term

# ------------------------------------------
# 3. Compute Performance Metrics and Average CI Length
# ------------------------------------------
def compute_metrics_from_ci_dfs(ci_dfs, true_fn=None):
    """
    Compute performance metrics from CI DataFrames by calculating metrics at each grid point
    and then averaging across grid points.

    Each DataFrame is one simulation replication with columns "s",
    "survival_est", "ci_lower" and "ci_upper". Rows are matched by position,
    so every replication must use the same grid of "s" values as ci_dfs[0].

    Args:
        ci_dfs: Sequence of per-replication DataFrames (at least one).
        true_fn: Callable mapping a scalar s to the true survival value.
            Defaults to ``true_survival``.

    Returns:
        abs_bias: Mean of |est - truth| over replications, averaged over grid
            points. NOTE(review): this is mean(|error|), which upper-bounds the
            conventional absolute bias |mean(error)|.
        std_err: Empirical standard deviation (ddof=1) of the errors across
            replications, averaged over grid points (NaN with one replication).
        mse: Mean squared error, averaged over grid points.
        coverage: Percentage (0-100) of replications whose [ci_lower, ci_upper]
            contains the truth, averaged over grid points.
        avg_ci_length: Mean of (ci_upper - ci_lower), averaged over grid points.

    Raises:
        ValueError: If ci_dfs is empty or the replications do not share a grid.
    """
    if len(ci_dfs) == 0:
        raise ValueError("ci_dfs must contain at least one replication")
    if true_fn is None:
        true_fn = true_survival

    grid = ci_dfs[0]["s"].to_numpy(dtype=float)
    for rep, ci_df in enumerate(ci_dfs):
        s_rep = ci_df["s"].to_numpy(dtype=float)
        # Rows are matched by position; a different grid would silently compare
        # estimates against the wrong true value.
        if s_rep.shape != grid.shape or not np.allclose(s_rep, grid):
            raise ValueError(f"replication {rep} does not share the grid of replication 0")

    # Truth evaluated once per grid point (hoisted out of the replication loop).
    truth = np.array([true_fn(s) for s in grid], dtype=float)[:, np.newaxis]

    # Arrays of shape (n_grid, n_replications).
    est = np.column_stack([df["survival_est"].to_numpy(dtype=float) for df in ci_dfs])
    lower = np.column_stack([df["ci_lower"].to_numpy(dtype=float) for df in ci_dfs])
    upper = np.column_stack([df["ci_upper"].to_numpy(dtype=float) for df in ci_dfs])

    errors = est - truth
    covered = (lower <= truth) & (truth <= upper)  # NaN bounds count as not covered
    lengths = upper - lower

    # Per-grid-point statistics (reduce over replications), then average over the grid.
    abs_bias = np.mean(np.mean(np.abs(errors), axis=1))
    std_err = np.mean(np.std(errors, axis=1, ddof=1))
    mse = np.mean(np.mean(np.square(errors), axis=1))
    coverage = np.mean(np.mean(covered, axis=1)) * 100
    avg_ci_length = np.mean(np.mean(lengths, axis=1))

    return abs_bias, std_err, mse, coverage, avg_ci_length

# ------------------------------------------
# 4. Compute Oracle Coverage and Average Oracle CI Length
# ------------------------------------------
def compute_oracle_coverage(ci_dfs, true_fn=None, z=1.96):
    """
    For each grid point s, compute the oracle CI length as:
         2 * z * std_est   (with the default z = 1.96: 3.92 * std_est),
    where std_est is the standard deviation (ddof=1) of "survival_est" across
    replications at s. Each replication's oracle interval is
    est ± z * std_est, and coverage is the fraction of those intervals that
    contain the truth.

    Args:
        ci_dfs: Sequence of per-replication DataFrames (at least one) with
            columns "s" and "survival_est", all on the same grid of s values.
        true_fn: Callable mapping a scalar s to the true survival value.
            Defaults to ``true_survival``.
        z: Normal quantile used for the half-width (default 1.96).

    Returns:
        oracle_cov: Oracle coverage in percent (0-100), averaged over grid points.
        avg_oracle_ci_length: Oracle CI length averaged over grid points.

    Raises:
        ValueError: If ci_dfs is empty or the replications do not share a grid.
    """
    if len(ci_dfs) == 0:
        raise ValueError("ci_dfs must contain at least one replication")
    if true_fn is None:
        true_fn = true_survival

    grid = ci_dfs[0]["s"].to_numpy(dtype=float)
    for rep, ci_df in enumerate(ci_dfs):
        s_rep = ci_df["s"].to_numpy(dtype=float)
        # Rows are matched by position; a different grid would mix up grid points.
        if s_rep.shape != grid.shape or not np.allclose(s_rep, grid):
            raise ValueError(f"replication {rep} does not share the grid of replication 0")

    # Truth evaluated once per grid point, as a column for broadcasting.
    truth = np.array([true_fn(s) for s in grid], dtype=float)[:, np.newaxis]

    # Shape (n_grid, n_replications).
    estimates = np.column_stack([df["survival_est"].to_numpy(dtype=float) for df in ci_dfs])
    std_est = np.std(estimates, axis=1, ddof=1)  # one SD per grid point
    half_width = (z * std_est)[:, np.newaxis]

    covered = ((estimates - half_width) <= truth) & (truth <= (estimates + half_width))
    oracle_ci_lengths = 2 * z * std_est

    oracle_cov = np.mean(np.mean(covered, axis=1)) * 100
    avg_oracle_ci_length = np.mean(oracle_ci_lengths)

    return oracle_cov, avg_oracle_ci_length

# ------------------------------------------
# 5. Summarize Simulation Results Across Sample Sizes
# ------------------------------------------
# Driver: for each sample size, pool all simulation parts, evaluate the three
# targeting strategies (Projection / Delta-method / Relaxed) under the three CI
# constructions (REG / REG_CV / DELTA), and collect one summary row per strategy.
sample_sizes = [500, 1000,1500,2000]
rows = []  # one dict per (n, targeting) pair; turned into summary_df below

for n in sample_sizes:
    sim_results = load_all_simulation_results(n)
    
    # Each res is a 13-tuple in the key order used by load_all_simulation_results:
    #   0 survival_dfs, 1-3 cv_CI_*, 4-6 reg_CI_*, 7-9 delta_CI_*, 10-12 relax_CI_*.
    # Indices 0-3 (survival_dfs and the cv_* lists) are not used in this summary.
    # Projection-based targeting (reads the reg_* entries):
    ci_proj_reg_list    = [res[4] for res in sim_results]
    ci_proj_reg_cv_list = [res[5] for res in sim_results]
    ci_proj_delta_list  = [res[6] for res in sim_results]
    
    # Delta-method targeting:
    ci_delta_reg_list    = [res[7] for res in sim_results]
    ci_delta_reg_cv_list = [res[8] for res in sim_results]
    ci_delta_delta_list  = [res[9] for res in sim_results]
    
    # Relaxed/GLM targeting:
    ci_relax_reg_list    = [res[10] for res in sim_results]
    ci_relax_reg_cv_list = [res[11] for res in sim_results]
    ci_relax_delta_list  = [res[12] for res in sim_results]
    
    # Compute metrics (with CI lengths) for Projection targeting:
    # Bias / Std. Err. / MSE are taken from the *_CI_REG list only; the REG_CV and
    # DELTA lists contribute only coverage and CI length.
    # NOTE(review): this assumes survival_est is identical across the three CI
    # lists of one targeting strategy — confirm.
    abs_bias_proj, std_err_proj, mse_proj, cov_proj_reg, ci_len_proj_reg = compute_metrics_from_ci_dfs(ci_proj_reg_list)
    _, _, _, cov_proj_reg_cv, ci_len_proj_reg_cv = compute_metrics_from_ci_dfs(ci_proj_reg_cv_list)
    _, _, _, cov_proj_delta, ci_len_proj_delta = compute_metrics_from_ci_dfs(ci_proj_delta_list)
    
    # For Delta-method targeting:
    abs_bias_delta, std_err_delta, mse_delta, cov_delta_reg, ci_len_delta_reg = compute_metrics_from_ci_dfs(ci_delta_reg_list)
    _, _, _, cov_delta_reg_cv, ci_len_delta_reg_cv = compute_metrics_from_ci_dfs(ci_delta_reg_cv_list)
    _, _, _, cov_delta_delta, ci_len_delta_delta = compute_metrics_from_ci_dfs(ci_delta_delta_list)
    
    # For Relaxed targeting:
    abs_bias_relax, std_err_relax, mse_relax, cov_relax_reg, ci_len_relax_reg = compute_metrics_from_ci_dfs(ci_relax_reg_list)
    _, _, _, cov_relax_reg_cv, ci_len_relax_reg_cv = compute_metrics_from_ci_dfs(ci_relax_reg_cv_list)
    _, _, _, cov_relax_delta, ci_len_relax_delta = compute_metrics_from_ci_dfs(ci_relax_delta_list)
    
    # Compute oracle coverage and oracle CI length.
    # Oracle intervals use the empirical SD of survival_est across replications;
    # each strategy's oracle is computed from its *_CI_REG list.
    oracle_cov_proj, oracle_ci_len_proj = compute_oracle_coverage(ci_proj_reg_list)
    oracle_cov_delta, oracle_ci_len_delta = compute_oracle_coverage(ci_delta_reg_list)
    oracle_cov_relax, oracle_ci_len_relax = compute_oracle_coverage(ci_relax_reg_list)
    
    # Format each cell as "coverage%(mean CI length)", both to three decimals.
    proj_cov_reg_str    = f"{cov_proj_reg:.3f}({ci_len_proj_reg:.3f})"
    proj_cov_reg_cv_str = f"{cov_proj_reg_cv:.3f}({ci_len_proj_reg_cv:.3f})"
    proj_cov_delta_str  = f"{cov_proj_delta:.3f}({ci_len_proj_delta:.3f})"
    
    delta_cov_reg_str    = f"{cov_delta_reg:.3f}({ci_len_delta_reg:.3f})"
    delta_cov_reg_cv_str = f"{cov_delta_reg_cv:.3f}({ci_len_delta_reg_cv:.3f})"
    delta_cov_delta_str  = f"{cov_delta_delta:.3f}({ci_len_delta_delta:.3f})"
    
    relax_cov_reg_str    = f"{cov_relax_reg:.3f}({ci_len_relax_reg:.3f})"
    relax_cov_reg_cv_str = f"{cov_relax_reg_cv:.3f}({ci_len_relax_reg_cv:.3f})"
    relax_cov_delta_str  = f"{cov_relax_delta:.3f}({ci_len_relax_delta:.3f})"
    
    oracle_cov_proj_str  = f"{oracle_cov_proj:.3f}({oracle_ci_len_proj:.3f})"
    oracle_cov_delta_str = f"{oracle_cov_delta:.3f}({oracle_ci_len_delta:.3f})"
    oracle_cov_relax_str = f"{oracle_cov_relax:.3f}({oracle_ci_len_relax:.3f})"
    
    # Append summary rows with the formatted coverage strings.
    # Coverage column labels name the CI construction, not the targeting:
    #   "Cov Proj (%)" <- *_CI_REG, "Cov Proj CV (%)" <- *_CI_REG_CV,
    #   "Cov Delta (%)" <- *_CI_DELTA.
    rows.append({
        "Estimator": "A-TMLE",
        "n": n,
        "Targeting": "Projection",
        "Abs. Bias": abs_bias_proj,
        "Std. Err.": std_err_proj,
        "MSE": mse_proj,
        "Cov NP (%)": np.nan,   # Not computed.
        "Cov Proj (%)": proj_cov_reg_str,
        "Cov Proj CV (%)": proj_cov_reg_cv_str,
        "Cov Delta (%)": proj_cov_delta_str,
        "Oracle Cov (%)": oracle_cov_proj_str
    })
    rows.append({
        "Estimator": "A-TMLE",
        "n": n,
        "Targeting": "Delta-method",
        "Abs. Bias": abs_bias_delta,
        "Std. Err.": std_err_delta,
        "MSE": mse_delta,
        "Cov NP (%)": np.nan,
        "Cov Proj (%)": delta_cov_reg_str,
        "Cov Proj CV (%)": delta_cov_reg_cv_str,
        "Cov Delta (%)": delta_cov_delta_str,
        "Oracle Cov (%)": oracle_cov_delta_str
    })
    rows.append({
        "Estimator": "A-TMLE",
        "n": n,
        "Targeting": "Relaxed",
        "Abs. Bias": abs_bias_relax,
        "Std. Err.": std_err_relax,
        "MSE": mse_relax,
        "Cov NP (%)": np.nan,
        "Cov Proj (%)": relax_cov_reg_str,
        "Cov Proj CV (%)": relax_cov_reg_cv_str,
        "Cov Delta (%)": relax_cov_delta_str,
        "Oracle Cov (%)": oracle_cov_relax_str
    })

# NOTE: loop variables (sim_results, ci_*_list, metric scalars) stay defined at
# module scope after the loop and hold the values for the last sample size.
summary_df = pd.DataFrame(rows)
print(summary_df)

# Optionally, generate and print a LaTeX version of the summary table.
# float_format only affects the numeric columns (Abs. Bias, Std. Err., MSE); the
# coverage columns are already pre-formatted strings.
# NOTE(review): "%.4f" rounds MSE values near 1e-4 (e.g. 0.000094) to 0.0001 —
# consider more digits or scientific notation for MSE.
latex_table = summary_df.to_latex(index=False, float_format="%.4f")
print(latex_table)


   Estimator     n     Targeting  Abs. Bias  Std. Err.       MSE  Cov NP (%)  \
0     A-TMLE   500    Projection   0.013736   0.016609  0.000367         NaN   
1     A-TMLE   500  Delta-method   0.013929   0.016851  0.000377         NaN   
2     A-TMLE   500       Relaxed   0.358527   0.055079  0.189461         NaN   
3     A-TMLE  1000    Projection   0.009829   0.011679  0.000182         NaN   
4     A-TMLE  1000  Delta-method   0.009916   0.011848  0.000189         NaN   
5     A-TMLE  1000       Relaxed   0.402177   0.046973  0.239075         NaN   
6     A-TMLE  1500    Projection   0.008441   0.009920  0.000133         NaN   
7     A-TMLE  1500  Delta-method   0.008621   0.010320  0.000144         NaN   
8     A-TMLE  1500       Relaxed   0.421148   0.056387  0.263935         NaN   
9     A-TMLE  2000    Projection   0.007104   0.008406  0.000094         NaN   
10    A-TMLE  2000  Delta-method   0.007283   0.008726  0.000100         NaN   
11    A-TMLE  2000       Relaxed   0.438