# In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import ttest_rel
import sys
import os

# statsmodels is required for the FDR correction below. If it cannot be
# imported, it is installed on the fly with pip for the interpreter that is
# running this file, and the import is retried.
# NOTE: this means that merely running/importing this file may modify the
# active Python environment; check_call raises CalledProcessError if pip fails.
try:
    import statsmodels
except ImportError:
    import subprocess
    print("statsmodels not found. Installing...")
    # sys.executable ensures the package lands in the current interpreter's
    # environment rather than whatever "pip" happens to be on PATH.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "statsmodels"])
    import statsmodels

from statsmodels.stats.multitest import multipletests
import numpy as np
import warnings

# Suppress UserWarning and FutureWarning messages for cleaner output.
# NOTE: these filters are installed at module level, so they apply to every
# library used after this point, not only to the code in this file.
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)


def analyze_vs_ablation(combined_df: pd.DataFrame, metric: str = 'c_index'):
    """Test whether each imputation method beats the ablation baseline.

    For every test condition except 'full_data' and 'cancer_label_only', each
    test type whose name contains 'imputed' is compared with the 'ablation'
    rows of the same condition using a one-sided paired t-test
    (alternative: imputed > ablation). All resulting p-values are then
    adjusted together with the Benjamini-Hochberg FDR procedure (alpha=0.05).

    A comparison is skipped when the two groups do not have the same number
    of non-missing scores, or when fewer than two scores are available.

    NOTE(review): pairing is positional. Missing scores are dropped from each
    group independently, so the test assumes both groups list their runs in
    the same order and lose the same runs -- confirm against the script that
    writes the results file.

    Args:
        combined_df: Results table with 'test_condition' and 'test_type'
            columns plus the ``metric`` column.
        metric: Name of the score column to compare.

    Returns:
        A DataFrame with one row per comparison and the columns
        'test_condition', 'comparison', 'mean_gain', 'p_value', 'p_adj_fdr'
        and 'significant_gain'; or None when no comparison could be made
        (this includes a missing 'test_type' or ``metric`` column).
        Comparisons whose p-value is not finite get a NaN 'p_adj_fdr' and
        'significant_gain' set to False.

    Raises:
        KeyError: If ``combined_df`` has no 'test_condition' column.
    """
    # Exclude non-ablation baselines from this specific analysis.
    excluded_conditions = ('full_data', 'cancer_label_only')
    test_conditions = [
        c for c in combined_df['test_condition'].unique()
        if c not in excluded_conditions
    ]

    # Without these columns there is nothing to compare. Checking once here
    # replaces a broad try/except around the whole loop body that also hid
    # unrelated errors.
    if 'test_type' not in combined_df.columns or metric not in combined_df.columns:
        return None

    statistical_results = []
    for condition in test_conditions:
        condition_df = combined_df[combined_df['test_condition'] == condition]
        is_ablation = condition_df['test_type'] == 'ablation'
        scores_ablation = condition_df.loc[is_ablation, metric].dropna()

        # Missing test_type values show up as None/NaN in unique(); the
        # isinstance guard keeps the substring test from raising TypeError.
        imputation_types = [
            t for t in condition_df['test_type'].unique()
            if isinstance(t, str) and 'imputed' in t
        ]

        for imp_type in imputation_types:
            is_imputed = condition_df['test_type'] == imp_type
            scores_imputed = condition_df.loc[is_imputed, metric].dropna()
            if len(scores_imputed) != len(scores_ablation) or len(scores_imputed) < 2:
                continue

            try:
                # Plain arrays make the positional pairing explicit.
                _, p_value = ttest_rel(
                    scores_imputed.to_numpy(),
                    scores_ablation.to_numpy(),
                    alternative='greater',
                )
            except ValueError:
                # Skip only this comparison; the remaining imputation types
                # of the same condition are still evaluated.
                continue

            statistical_results.append({
                'test_condition': condition,
                'comparison': f"{imp_type} vs. ablation",
                'mean_gain': scores_imputed.mean() - scores_ablation.mean(),
                'p_value': p_value
            })

    if not statistical_results:
        return None
    stats_df = pd.DataFrame(statistical_results)

    # The paired t-test can return a NaN p-value (e.g. when the paired
    # differences have zero variance). Passing NaN into the BH step-up would
    # contaminate the adjusted p-values of the valid comparisons, so the
    # correction is applied to the finite p-values only.
    p_values = stats_df['p_value'].to_numpy(dtype=float)
    testable = np.isfinite(p_values)
    p_adj = np.full(p_values.shape, np.nan)
    reject = np.zeros(p_values.shape, dtype=bool)
    if testable.any():
        reject_ok, p_adj_ok, _, _ = multipletests(
            p_values[testable], alpha=0.05, method='fdr_bh'
        )
        reject[testable] = reject_ok
        p_adj[testable] = p_adj_ok

    stats_df['p_adj_fdr'] = p_adj
    stats_df['significant_gain'] = reject
    return stats_df

def analyze_generative_vs_standard(combined_df: pd.DataFrame, metric: str = 'c_index'):
    """Test whether the generative imputers beat the standard imputers.

    For every test condition except 'full_data' and 'cancer_label_only', the
    scores of 'imputed_multi' and 'imputed_coherent' are each compared with
    those of 'imputed_mean' and 'imputed_knn' using a one-sided paired t-test
    (alternative: generative > standard). All resulting p-values are then
    adjusted together with the Benjamini-Hochberg FDR procedure (alpha=0.05).

    A comparison is skipped when the two groups do not have the same number
    of non-missing scores, or when fewer than two scores are available.

    NOTE(review): pairing is positional. Missing scores are dropped from each
    group independently, so the test assumes both groups list their runs in
    the same order and lose the same runs -- confirm against the script that
    writes the results file.

    Args:
        combined_df: Results table with 'test_condition' and 'test_type'
            columns plus the ``metric`` column.
        metric: Name of the score column to compare.

    Returns:
        A DataFrame with one row per comparison and the columns
        'test_condition', 'comparison', 'mean_gain', 'p_value', 'p_adj_fdr'
        and 'significant_gain'; or None when no comparison could be made
        (this includes a missing 'test_type' or ``metric`` column).
        Comparisons whose p-value is not finite get a NaN 'p_adj_fdr' and
        'significant_gain' set to False.

    Raises:
        KeyError: If ``combined_df`` has no 'test_condition' column.
    """
    excluded_conditions = ('full_data', 'cancer_label_only')
    test_conditions = [
        c for c in combined_df['test_condition'].unique()
        if c not in excluded_conditions
    ]

    # Without these columns there is nothing to compare; checking once here
    # replaces a per-condition "except KeyError: continue".
    if 'test_type' not in combined_df.columns or metric not in combined_df.columns:
        return None

    # Method suffixes; the test_type values are 'imputed_<suffix>'.
    generative_methods = ('multi', 'coherent')
    standard_methods = ('mean', 'knn')

    statistical_results = []
    for condition in test_conditions:
        condition_df = combined_df[combined_df['test_condition'] == condition]
        scores = {
            name: condition_df.loc[
                condition_df['test_type'] == f"imputed_{name}", metric
            ].dropna()
            for name in generative_methods + standard_methods
        }

        for gen_name in generative_methods:
            gen_scores = scores[gen_name]
            for std_name in standard_methods:
                std_scores = scores[std_name]
                if len(gen_scores) != len(std_scores) or len(gen_scores) < 2:
                    continue
                # Plain arrays make the positional pairing explicit.
                _, p_value = ttest_rel(
                    gen_scores.to_numpy(),
                    std_scores.to_numpy(),
                    alternative='greater',
                )
                statistical_results.append({
                    'test_condition': condition,
                    'comparison': f"imputed_{gen_name} vs. imputed_{std_name}",
                    'mean_gain': gen_scores.mean() - std_scores.mean(),
                    'p_value': p_value
                })

    if not statistical_results:
        return None
    stats_df = pd.DataFrame(statistical_results)

    # The paired t-test can return a NaN p-value (e.g. when the paired
    # differences have zero variance). Passing NaN into the BH step-up would
    # contaminate the adjusted p-values of the valid comparisons, so the
    # correction is applied to the finite p-values only.
    p_values = stats_df['p_value'].to_numpy(dtype=float)
    testable = np.isfinite(p_values)
    p_adj = np.full(p_values.shape, np.nan)
    reject = np.zeros(p_values.shape, dtype=bool)
    if testable.any():
        reject_ok, p_adj_ok, _, _ = multipletests(
            p_values[testable], alpha=0.05, method='fdr_bh'
        )
        reject[testable] = reject_ok
        p_adj[testable] = p_adj_ok

    stats_df['p_adj_fdr'] = p_adj
    stats_df['significant_gain'] = reject
    return stats_df

# --- NEW: Focused plotting function ---
def create_generative_focused_plot(data: pd.DataFrame, metric: str, title: str, save_path: str):
    """
    Draw a grouped bar chart of the generative imputation methods next to the
    'full_data' and 'ablation' baselines and save it to ``save_path``.

    Only rows whose 'test_type' is one of 'full_data', 'ablation',
    'imputed_multi' or 'imputed_coherent' are plotted; for the
    'cancer_label_only' condition only the 'ablation' rows are kept.
    Bars are grouped by 'test_condition' and drawn with standard-deviation
    error bars. The figure is written at 300 dpi and closed afterwards;
    progress messages are printed to stdout. Nothing is returned.
    """
    print(f"\n--- Generating focused plot for: {metric} ---")

    # Test types shown in this plot, in legend order, with their bar colours.
    shown_types = ['full_data', 'ablation', 'imputed_multi', 'imputed_coherent']
    bar_colours = {
        'full_data': '#4C72B0',
        'ablation': '#A9A9A9',
        'imputed_multi': '#FFB6C1',
        'imputed_coherent': '#DC143C'
    }

    def removal_rank(condition):
        # Sort key for the x-axis: 'full_data' first, 'cancer_label_only'
        # last, everything else by its underscore count plus one
        # (presumably the number of removed modalities -- TODO confirm).
        if condition == 'full_data':
            return 0
        if condition == 'cancer_label_only':
            return 99
        return condition.count('_') + 1

    subset = data[data['test_type'].isin(shown_types)].copy()

    # For 'cancer_label_only' keep nothing but the ablation rows.
    not_label_only = subset['test_condition'] != 'cancer_label_only'
    is_ablation = subset['test_type'] == 'ablation'
    subset = subset[not_label_only | is_ablation]

    subset['n_removed'] = subset['test_condition'].apply(removal_rank)
    ordered = subset.sort_values(by=['n_removed', 'test_condition'])
    condition_order = ordered.test_condition.unique()

    # Restrict the hue levels to the test types that actually occur.
    present_types = subset['test_type'].unique()
    hue_levels = [name for name in shown_types if name in present_types]

    grid = sns.catplot(
        data=subset,
        x='test_condition',
        y=metric,
        hue='test_type',
        order=condition_order,
        hue_order=hue_levels,
        kind='bar',
        height=7,
        aspect=2.2,
        palette=bar_colours,
        errorbar='sd',
    )

    sns.move_legend(grid, "center right", bbox_to_anchor=(1.1, 0.5), frameon=True, title='Test Type')
    grid.fig.suptitle(title, y=1.03, fontsize=18)

    metric_label = metric.replace('_', ' ').title()
    grid.set_axis_labels("Test Condition (Modalities Removed)", f"Mean {metric_label}", fontsize=14)
    grid.set_xticklabels(rotation=45, ha='right')

    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(grid.fig)
    print(f"Plot saved to: {save_path}")

# --- Main Execution ---
if __name__ == '__main__':
    # Script entry point: load the results of the "long_rf" experiment, print
    # both significance tables and save the focused comparison plot into the
    # same results directory.
    # NOTE: both paths are relative to the current working directory.
    results_file = '../../results/downstream/task_06_imputing_test_set_surv/all_imputations_results_long_rf.csv'
    plots_dir = '../../results/downstream/task_06_imputing_test_set_surv'
    os.makedirs(plots_dir, exist_ok=True)

    try:
        final_results = pd.read_csv(results_file)
    except FileNotFoundError:
        print(f"Error: The results file '{results_file}' was not found.")
        print("Please run the previous experiment script first to generate the results.")
        # sys.exit with a non-zero status reports the failure to the caller;
        # the bare exit() helper is injected by the `site` module (so it is
        # not always defined) and would have exited with status 0.
        sys.exit(1)
    print(f"Successfully loaded final results from '{results_file}'")

    # Significant rows first, then by ascending adjusted p-value.
    sort_columns = ['significant_gain', 'p_adj_fdr']
    sort_ascending = [False, True]

    # --- Run and Print Statistical Analyses ---
    print("\n\n===== Analysis 1: Significance of Gain vs. Ablation (C-Index) =====")
    significance_vs_ablation = analyze_vs_ablation(final_results)
    if significance_vs_ablation is not None:
        print(significance_vs_ablation.sort_values(by=sort_columns, ascending=sort_ascending).to_string())
    else:
        # The analysis returns None when no comparison could be made; say so
        # instead of leaving the section silently empty.
        print("No valid paired comparisons were available.")

    print("\n\n===== Analysis 2: Significance of Gain for Generative vs. Standard Methods (C-Index) =====")
    significance_gen_vs_std = analyze_generative_vs_standard(final_results)
    if significance_gen_vs_std is not None:
        print(significance_gen_vs_std.sort_values(by=sort_columns, ascending=sort_ascending).to_string())
    else:
        print("No valid paired comparisons were available.")

    # --- Create the new, focused plot ---
    final_plot_path = os.path.join(plots_dir, 'generative_model_comparison.png')
    create_generative_focused_plot(
        final_results,
        metric='c_index',
        title='Generative Models for Survival Analysis',
        save_path=final_plot_path
    )