In [None]:
import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report
from itertools import combinations

# Data loading function remains the same
def load_and_prepare_pancancer_data(labels_dir: str, data_dir: str):
    """Load pan-cancer features and labels, keeping only fully labelled samples.

    Reads ``real_data_train.csv`` / ``test_data.csv`` from ``data_dir`` and the
    ``train_stage.csv`` / ``train_cancer_type.csv`` / ``test_stage.csv`` /
    ``test_cancer_type.csv`` label files from ``labels_dir`` (each indexed by
    its first column). Samples without a ``stage`` or ``cancertype`` label, or
    missing from the feature table, are dropped. The cancer type is appended
    to the features as one-hot ``cancer_*`` columns; train and test dummies
    are aligned to the union of cancer types, with absent types filled with 0.

    Returns:
        ``(X_train, y_train, X_test, y_test)``: features and ``stage`` labels,
        each sorted by sample index.
    """
    print("--- Loading and preparing all pan-cancer data... ---")

    def _read(folder, filename):
        return pd.read_csv(os.path.join(folder, filename), index_col=0)

    train_features = _read(data_dir, "real_data_train.csv")
    test_features = _read(data_dir, "test_data.csv")
    # Stage and cancer-type labels joined on sample id, one frame per split.
    label_frames = {
        part: _read(labels_dir, f"{part}_stage.csv").join(_read(labels_dir, f"{part}_cancer_type.csv"))
        for part in ("train", "test")
    }

    def _select(features, labels):
        """Keep samples that have both labels and features; return (X, stage, cancertype)."""
        labels = labels.dropna(subset=['stage', 'cancertype'])
        keep = labels.index.intersection(features.index)
        return (
            features.loc[keep].sort_index(),
            labels.loc[keep, 'stage'].sort_index(),
            labels.loc[keep, 'cancertype'].sort_index(),
        )

    X_train, y_train, train_types = _select(train_features, label_frames["train"])
    X_test, y_test, test_types = _select(test_features, label_frames["test"])
    print(f"  Found {len(X_train)} training samples and {len(X_test)} test samples across all cancers.")

    # Outer alignment gives both splits the same dummy columns, even when a
    # cancer type occurs in only one of them.
    train_dummies, test_dummies = pd.get_dummies(train_types, prefix='cancer').align(
        pd.get_dummies(test_types, prefix='cancer'), join='outer', axis=1, fill_value=0
    )
    return (
        pd.concat([X_train, train_dummies], axis=1),
        y_train,
        pd.concat([X_test, test_dummies], axis=1),
        y_test,
    )


# MODIFIED FUNCTION - Now accepts a random_seed
def run_pancancer_modality_analysis(labels_dir: str, data_dir: str, random_seed: int):
    """Fit one pan-cancer random forest and score it under every modality ablation.

    A ``RandomForestClassifier`` (100 trees, ``random_state=random_seed``) is
    trained on the full training features. It is evaluated on the unmodified
    test set (condition ``'full_data'``) and on one copy of the test set per
    non-empty combination of the modalities found among ``cna``, ``rnaseq``,
    ``rppa`` and ``wsi``. In each copy, that combination's ``<modality>_*``
    columns are set to NaN (condition ``'no_<mod1>_<mod2>...'``).

    Returns:
        DataFrame with one row per condition and columns ``test_condition``,
        ``balanced_accuracy`` and ``macro_f1_score``; or ``None`` if the loader
        returned no training data.
    """
    X_train, y_train, X_test, y_test = load_and_prepare_pancancer_data(labels_dir, data_dir)
    # Defensive: the loader in this file returns DataFrames or raises.
    if X_train is None:
        return None

    print(f"  Training a single pan-cancer model with random_state={random_seed}...")
    model = RandomForestClassifier(n_estimators=100, random_state=random_seed, n_jobs=-1)
    model.fit(X_train, y_train)

    # A modality is recognised by the text before the first '_' in a column name.
    prefixes_present = {name.split('_')[0] for name in X_test.columns}
    modalities = sorted(m for m in ['cna', 'rnaseq', 'rppa', 'wsi'] if m in prefixes_present)

    conditions = {'full_data': X_test}
    for n_dropped in range(1, len(modalities) + 1):  # up to removing every modality
        for dropped in combinations(modalities, n_dropped):
            masked_cols = [
                name
                for modality in dropped
                for name in X_test.columns
                if name.startswith(modality + '_')
            ]
            masked = X_test.copy()
            masked[masked_cols] = np.nan
            conditions[f"no_{'_'.join(dropped)}"] = masked

    def _score(features):
        """Return (balanced accuracy, macro F1) of the model on ``features``."""
        predicted = model.predict(features)
        b_acc = balanced_accuracy_score(y_test, predicted)
        report = classification_report(y_test, predicted, output_dict=True, zero_division=0)
        return b_acc, report['macro avg']['f1-score']

    rows = []
    for name, features in conditions.items():
        b_acc, f1 = _score(features)
        rows.append({
            'test_condition': name,
            'balanced_accuracy': b_acc,
            'macro_f1_score': f1
        })
    return pd.DataFrame(rows)

# --- Main Execution ---
if __name__ == '__main__':
    LABELS_DIR = "../../datasets_TCGA/downstream_labels"
    DATA_DIR = "./data_task_02"
    N_RUNS = 10  # Define how many times to run the experiment

    all_results_dfs = []

    # --- Main Repetition Loop ---
    for i in range(N_RUNS):
        print(f"\n==================== Starting Run {i+1}/{N_RUNS} ====================")
        # Run i uses random_state=i: reproducible, yet a different forest per run.
        results_df = run_pancancer_modality_analysis(
            labels_dir=LABELS_DIR,
            data_dir=DATA_DIR,
            random_seed=i
        )

        if results_df is not None:
            results_df['run'] = i + 1  # 1-based run identifier
            all_results_dfs.append(results_df)

    # --- Aggregate and Analyze Final Results ---
    if all_results_dfs:
        # Combine results from all runs into a single DataFrame
        final_results_all_runs = pd.concat(all_results_dfs, ignore_index=True)

        print("\n\n===== SUMMARY STATISTICS ACROSS ALL RUNS =====")
        # Summary statistics of balanced accuracy for each test condition
        summary_stats = final_results_all_runs.groupby('test_condition')['balanced_accuracy'].agg(
            ['mean', 'std', 'median', 'min', 'max']
        ).sort_values(by='mean', ascending=False)
        print(summary_stats)

        print("\n\n--- Generating plot to visualize performance and variance ---")

        # 1. Order the x axis by number of removed modalities, then by mean score.
        # A 'no_a_b' condition has one '_' per removed modality. 'full_data' also
        # contains an underscore, so it is pinned to 0 explicitly; otherwise it
        # would be grouped (and re-sorted) with the single-modality removals.
        summary_stats['n_removed'] = np.where(
            summary_stats.index == 'full_data', 0, summary_stats.index.str.count('_')
        )
        plot_order = summary_stats.sort_values(by=['n_removed', 'mean'], ascending=[True, False]).index.tolist()

        # 2. Box plot of the per-run distribution for each condition
        plt.figure(figsize=(14, 8))
        ax = sns.boxplot(
            data=final_results_all_runs,
            x='test_condition',
            y='balanced_accuracy',
            order=plot_order,
            showfliers=False  # individual points are drawn by the strip plot below
        )
        # Overlay the individual run results
        sns.stripplot(
            data=final_results_all_runs,
            x='test_condition',
            y='balanced_accuracy',
            order=plot_order,
            jitter=True,
            alpha=0.6,
            color='black'
        )

        # 3. Final plot formatting
        ax.set_title('Pan-Cancer Model Performance and Stability Across Multiple Runs', fontsize=16)
        ax.set_xlabel('Test Data Condition', fontsize=12)
        ax.set_ylabel('Balanced Accuracy', fontsize=12)
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.show()

In [None]:
import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report
from itertools import combinations
import warnings
import torch
import json
import pathlib
from types import SimpleNamespace
import sys

# Suppress warnings
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

# --- Add your custom library path ---
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) 
from lib.test import coherent_test_cos_rejection, test_model
from lib.config import modalities_list
from lib.get_models import get_diffusion_model
from lib.diffusion_models import GaussianDiffusion

# --- SECTION 1: HELPER FUNCTIONS (UNCHANGED) ---
def load_and_prepare_pancancer_data(labels_dir: str, data_dir: str):
    """Load pan-cancer features and labels, keeping only fully labelled samples.

    Reads ``real_data_train.csv`` / ``test_data.csv`` from ``data_dir`` and the
    ``train_stage.csv`` / ``train_cancer_type.csv`` / ``test_stage.csv`` /
    ``test_cancer_type.csv`` label files from ``labels_dir`` (each indexed by
    its first column). Samples without a ``stage`` or ``cancertype`` label, or
    missing from the feature table, are dropped. The cancer type is appended
    to the features as one-hot ``cancer_*`` columns; train and test dummies
    are aligned to the union of cancer types, with absent types filled with 0.

    Returns:
        ``(X_train, y_train, X_test, y_test)``: features and ``stage`` labels,
        each sorted by sample index.
    """
    print("--- Loading and preparing all pan-cancer data... ---")

    def _read(folder, filename):
        return pd.read_csv(os.path.join(folder, filename), index_col=0)

    train_features = _read(data_dir, "real_data_train.csv")
    test_features = _read(data_dir, "test_data.csv")
    # Stage and cancer-type labels joined on sample id, one frame per split.
    label_frames = {
        part: _read(labels_dir, f"{part}_stage.csv").join(_read(labels_dir, f"{part}_cancer_type.csv"))
        for part in ("train", "test")
    }

    def _select(features, labels):
        """Keep samples that have both labels and features; return (X, stage, cancertype)."""
        labels = labels.dropna(subset=['stage', 'cancertype'])
        keep = labels.index.intersection(features.index)
        return (
            features.loc[keep].sort_index(),
            labels.loc[keep, 'stage'].sort_index(),
            labels.loc[keep, 'cancertype'].sort_index(),
        )

    X_train, y_train, train_types = _select(train_features, label_frames["train"])
    X_test, y_test, test_types = _select(test_features, label_frames["test"])
    print(f"  Found {len(X_train)} training samples and {len(X_test)} test samples across all cancers.")

    # Outer alignment gives both splits the same dummy columns, even when a
    # cancer type occurs in only one of them.
    train_dummies, test_dummies = pd.get_dummies(train_types, prefix='cancer').align(
        pd.get_dummies(test_types, prefix='cancer'), join='outer', axis=1, fill_value=0
    )
    return (
        pd.concat([X_train, train_dummies], axis=1),
        y_train,
        pd.concat([X_test, test_dummies], axis=1),
        y_test,
    )

def load_single_model(target_mod, cond_mod, diffusion, config_args, device):
    """Load the best single-condition diffusion model generating ``target_mod`` from ``cond_mod``.

    The checkpoint path is
    ``../../{folder}/{dim}/{target_mod}_from_{cond_mod}/train/best_by_{metric}.pth``
    (relative to the working directory), built from ``config_args``.

    Returns:
        ``(model, config, best_loss)``: the model in eval mode on ``device``,
        the training config stored in the checkpoint as a ``SimpleNamespace``,
        and the checkpoint's ``'best_loss'`` entry.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    run_dir = pathlib.Path(f'../../{config_args.folder}/{config_args.dim}/{target_mod}_from_{cond_mod}')
    checkpoint_file = run_dir / f'train/best_by_{config_args.metric}.pth'
    if not checkpoint_file.exists():
        raise FileNotFoundError(f"Checkpoint not found for single model: {checkpoint_file}")

    # Load onto CPU first; the constructed model is moved to `device` below.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
    train_config = SimpleNamespace(**checkpoint["config"])
    target_dim = config_args.modality_dims[target_mod]
    condition_dim = config_args.modality_dims[cond_mod]
    network = get_diffusion_model(
        train_config.architecture,
        diffusion,
        train_config,
        x_dim=target_dim,
        cond_dims=condition_dim,
    ).to(device)
    network.load_state_dict(checkpoint[f"best_model_{config_args.metric}"])
    network.eval()
    return network, train_config, checkpoint['best_loss']

def load_multi_model(target_mod, diffusion, config_args, device):
    """Load the best multi-condition diffusion model generating ``target_mod``.

    The run directory is ``../../{folder}/{dim}/{target_mod}_from_multi``, with
    a ``_masked`` suffix when ``config_args.mask`` is truthy. It must contain
    ``train/best_by_{metric}.pth`` and ``cond_order.json``. The JSON file lists
    the conditioning modalities in the order the model expects them.

    Returns:
        ``(model, config, cond_order)``: the model in eval mode on ``device``,
        the stored training config as a ``SimpleNamespace``, and the
        conditioning order loaded from ``cond_order.json``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    run_dir = pathlib.Path(f"../../{config_args.folder}/{config_args.dim}/{target_mod}_from_multi{'_masked' if config_args.mask else ''}")
    checkpoint_file = run_dir / 'train' / f'best_by_{config_args.metric}.pth'
    if not checkpoint_file.exists():
        raise FileNotFoundError(f"Checkpoint not found for multi model: {checkpoint_file}")

    with open(run_dir / 'cond_order.json', 'r') as handle:
        cond_order = json.load(handle)

    # Load onto CPU first; the constructed model is moved to `device` below.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
    train_config = SimpleNamespace(**checkpoint['config'])
    target_dim = config_args.modality_dims[target_mod]
    condition_dims = [config_args.modality_dims[name] for name in cond_order]
    network = get_diffusion_model(
        train_config.architecture,
        diffusion,
        train_config,
        x_dim=target_dim,
        cond_dims=condition_dims,
    ).to(device)
    network.load_state_dict(checkpoint[f'best_model_{config_args.metric}'])
    network.eval()
    return network, train_config, cond_order

def impute_missing_modalities(X_test_with_nan, modalities_to_impute, available_modalities, gen_mode, config_args, diffusion, device):
    """Fill the columns of removed modalities with samples from trained diffusion models.

    Modalities in ``modalities_to_impute`` are generated one at a time in
    alphabetical order. Each target is conditioned on the modalities that
    were not removed, plus every modality already imputed in an earlier step.

    Args:
        X_test_with_nan: Test features. A column belongs to a modality when it
            starts with ``'<modality>_'``. This frame is not modified.
        modalities_to_impute: Modalities whose columns are generated.
        available_modalities: All modalities present in the data.
        gen_mode: ``'coherent'`` uses one single-condition model per
            conditioning modality, combined by ``coherent_test_cos_rejection``
            with each checkpoint's ``best_loss`` as its weight. ``'multi'`` uses
            one multi-condition model; conditions it expects but that are not
            available at this step are passed as zeros.
        config_args: Provides ``modality_dims`` and the fields used by the
            checkpoint loaders.
        diffusion: Diffusion process passed to the models and sampling helpers.
        device: Device the models are placed on.

    Returns:
        A copy of ``X_test_with_nan`` with the target modality columns replaced
        by generated values.

    Raises:
        ValueError: If ``gen_mode`` is neither ``'coherent'`` nor ``'multi'``.
    """
    # Fail fast; an unknown mode would otherwise surface later as an
    # UnboundLocalError on the generated frame.
    if gen_mode not in ('coherent', 'multi'):
        raise ValueError(f"Unknown gen_mode {gen_mode!r}; expected 'coherent' or 'multi'.")

    X_imputed = X_test_with_nan.copy()
    generation_order = sorted(modalities_to_impute)
    conditioning_modalities = [m for m in available_modalities if m not in modalities_to_impute]
    n_samples = X_imputed.shape[0]

    def modality_frame(mod):
        """Current columns of X_imputed that belong to `mod`."""
        return X_imputed[[c for c in X_imputed.columns if c.startswith(mod + '_')]]

    for i, target_mod in enumerate(generation_order):
        print(f"    Imputing '{target_mod}' (step {i+1}/{len(generation_order)}) with '{gen_mode}' model...")
        # Earlier imputed modalities become extra conditions for later targets.
        current_conds = conditioning_modalities + generation_order[:i]
        # Zero placeholder for the target modality; presumably only its shape is
        # used by the sampling helpers. TODO confirm.
        target_placeholder = pd.DataFrame(np.zeros((n_samples, config_args.modality_dims[target_mod])))

        if gen_mode == 'coherent':
            cond_data_list = [modality_frame(c) for c in current_conds]
            # Load each checkpoint once and reuse both the model and its best loss.
            loaded = [load_single_model(target_mod, c, diffusion, config_args, device) for c in current_conds]
            models = [model for model, _, _ in loaded]
            weights = [best_loss for _, _, best_loss in loaded]
            _, generated_df, _ = coherent_test_cos_rejection(
                target_placeholder,
                cond_data_list, models, diffusion, test_iterations=1, max_retries=10,
                device=device, weights_list=weights
            )
        else:  # 'multi'
            model, _, cond_order = load_multi_model(target_mod, diffusion, config_args, device)
            # The multi model expects every condition in its training order;
            # conditions not available at this step are passed as zeros.
            final_cond_list = [
                modality_frame(c_name) if c_name in current_conds
                else pd.DataFrame(np.zeros((n_samples, config_args.modality_dims[c_name])))
                for c_name in cond_order
            ]
            _, generated_df = test_model(
                target_placeholder,
                final_cond_list, model, diffusion, test_iterations=1, device=device
            )

        target_cols = [c for c in X_imputed.columns if c.startswith(target_mod + '_')]
        # Relabel the generated frame so it lines up with the target columns and rows.
        generated_df.columns = target_cols
        generated_df.index = X_imputed.index
        X_imputed[target_cols] = generated_df
    return X_imputed

def get_metrics(y_true, y_pred):
    """Return ``(balanced_accuracy, macro_f1)`` for a set of predictions.

    The macro F1 is the ``'macro avg'`` row of scikit-learn's classification
    report, computed with ``zero_division=0``.
    """
    score_balanced = balanced_accuracy_score(y_true, y_pred)
    macro_row = classification_report(y_true, y_pred, output_dict=True, zero_division=0)['macro avg']
    return score_balanced, macro_row['f1-score']

# --- SECTION 2: MAIN ANALYSIS PIPELINE (MODIFIED) ---

def run_pancancer_analysis_with_imputation(random_seed: int, config_args, diffusion, device):
    """Train one pan-cancer classifier and compare ablation against generative imputation.

    1. Load data from ``config_args.labels_dir`` / ``config_args.data_dir`` and fit
       a ``RandomForestClassifier`` (100 trees, ``random_state=random_seed``).
    2. Score the unmodified test set (``test_type='full_data'``).
    3. For every non-empty combination of the available modalities (those of
       ``cna``, ``rnaseq``, ``rppa``, ``wsi`` that appear as column prefixes):
       set their columns to NaN and score (``'ablation'``). Then refill them with
       ``impute_missing_modalities`` in ``'multi'`` and ``'coherent'`` mode and
       score again (``'imputed_multi'`` / ``'imputed_coherent'``).

    When every modality is removed, the condition is named ``'cancer_label_only'``.
    Its imputed rows are recorded with NaN metrics because nothing is left to
    condition on.

    Returns:
        DataFrame with columns ``test_condition``, ``test_type``,
        ``balanced_accuracy`` and ``macro_f1_score``, or ``None`` if the loader
        returned no training data.
    """
    X_train, y_train, X_test, y_test = load_and_prepare_pancancer_data(
        config_args.labels_dir, config_args.data_dir
    )
    if X_train is None:
        return None

    print(f"\n--- Training pan-cancer model with random_state={random_seed}... ---")
    classifier = RandomForestClassifier(n_estimators=100, random_state=random_seed, n_jobs=-1)
    classifier.fit(X_train, y_train)

    column_prefixes = {name.split('_')[0] for name in X_test.columns if '_' in name}
    available_modalities = sorted(m for m in ['cna', 'rnaseq', 'rppa', 'wsi'] if m in column_prefixes)
    print(f"  Available modalities for ablation/imputation: {available_modalities}")

    rows = []

    def record(condition, test_type, features=None):
        """Append one result row; ``features=None`` records NaN metrics (evaluation skipped)."""
        if features is None:
            b_acc, f1 = np.nan, np.nan
        else:
            b_acc, f1 = get_metrics(y_test, classifier.predict(features))
        rows.append({
            'test_condition': condition,
            'test_type': test_type,
            'balanced_accuracy': b_acc,
            'macro_f1_score': f1
        })

    # Baseline on the complete test set.
    print("\n--- Processing Test Condition: full_data ---")
    record('full_data', 'full_data', X_test)

    n_available = len(available_modalities)
    for n_removed in range(1, n_available + 1):
        for removed in combinations(available_modalities, n_removed):
            everything_removed = len(removed) == n_available
            condition = "cancer_label_only" if everything_removed else f"no_{'_'.join(removed)}"
            to_process = list(removed)
            print(f"\n--- Processing Test Condition: {condition} ---")

            X_ablated = X_test.copy()
            ablated_cols = [c for mod in to_process for c in X_test.columns if c.startswith(mod + '_')]
            X_ablated[ablated_cols] = np.nan
            record(condition, 'ablation', X_ablated)

            if everything_removed:
                print(f"  Skipping generative imputation for {condition}.")
                for gen_mode in ['multi', 'coherent']:
                    record(condition, f'imputed_{gen_mode}')
                continue

            for gen_mode in ['multi', 'coherent']:
                X_imputed = impute_missing_modalities(
                    X_ablated, to_process, available_modalities,
                    gen_mode, config_args, diffusion, device
                )
                record(condition, f'imputed_{gen_mode}', X_imputed)

    return pd.DataFrame(rows)

# --- SECTION 3: MAIN EXECUTION AND VISUALIZATION (MODIFIED) ---

def create_summary_plot(data: pd.DataFrame, metric: str, title: str):
    """Plot mean ``metric`` per test condition as grouped bars, one bar per test type.

    Error bars show one standard deviation across runs (``errorbar='sd'``).
    Conditions are ordered by the number of removed modalities: ``full_data``
    first, then the ``no_*`` conditions (alphabetical within each size), and
    ``cancer_label_only`` last.

    Args:
        data: Long-format results with ``test_condition``, ``test_type`` and
            ``metric`` columns. It is not modified.
        metric: Numeric column plotted on the y axis.
        title: Figure title.
    """
    print(f"\n--- Generating plot for: {metric} ---")

    # Work on a copy so the helper column does not leak into the caller's table.
    plot_data = data.copy()
    # 'no_a_b' has one '_' per removed modality; 'full_data' is pinned to 0.
    plot_data['n_removed'] = plot_data['test_condition'].apply(
        lambda x: 0 if x == 'full_data' else x.count('_')
    )
    # cancer_label_only removes every modality, so it always sorts last.
    plot_data.loc[plot_data['test_condition'] == 'cancer_label_only', 'n_removed'] = plot_data['n_removed'].max() + 1

    plot_order = plot_data.sort_values(by=['n_removed', 'test_condition'])['test_condition'].unique()

    g = sns.catplot(
        data=plot_data, x='test_condition', y=metric, hue='test_type',
        order=plot_order, kind='bar', height=6, aspect=2.5, legend_out=False, errorbar='sd'
    )
    g.fig.suptitle(title, y=1.03, fontsize=16)
    g.set_axis_labels("Test Condition (Modalities Removed / Available)", f"Mean {metric.replace('_', ' ').title()}")
    g.set_xticklabels(rotation=45, ha='right')
    plt.legend(title='Test Type')
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()

if __name__ == '__main__':
    config_args = SimpleNamespace(
        folder='results', metric='mse', dim='32', mask=False,
        labels_dir="../../datasets_TCGA/downstream_labels",
        data_dir="./data_task_02",
        modality_dims={'cna': 32, 'rnaseq': 32, 'rppa': 32, 'wsi': 32}
    )
    device = torch.device(f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    diffusion = GaussianDiffusion(num_timesteps=1000).to(device)
    N_RUNS = 5
    all_run_dfs = []

    for i in range(N_RUNS):
        print(f"\n{'='*25} Starting Run {i+1}/{N_RUNS} {'='*25}")
        results_df = run_pancancer_analysis_with_imputation(
            random_seed=i, config_args=config_args, diffusion=diffusion, device=device
        )
        if results_df is not None:
            results_df['run'] = i + 1
            all_run_dfs.append(results_df)

    if all_run_dfs:
        final_results = pd.concat(all_run_dfs, ignore_index=True)
        # Rows with NaN metrics (skipped imputations) are dropped.
        final_results.dropna(subset=['balanced_accuracy', 'macro_f1_score'], inplace=True)

        print("\n\n===== SUMMARY STATISTICS ACROSS ALL RUNS =====")
        summary_stats = final_results.groupby(['test_condition', 'test_type'])[['balanced_accuracy', 'macro_f1_score']].agg(
            ['mean', 'std', 'median']
        )
        # Sort key: 0 for full_data; k for a 'no_*' condition with k removed
        # modalities (one '_' each); one past the largest for cancer_label_only.
        # A fixed offset for cancer_label_only would tie it with 'no_a_b_c'
        # when four modalities are available.
        summary_conditions = summary_stats.index.get_level_values('test_condition')
        summary_n_removed = np.array([0 if c == 'full_data' else c.count('_') for c in summary_conditions])
        summary_n_removed[np.asarray(summary_conditions == 'cancer_label_only')] = summary_n_removed.max(initial=0) + 1
        summary_stats['n_removed'] = summary_n_removed
        print(summary_stats.sort_values(by=['n_removed', ('balanced_accuracy', 'mean')], ascending=[True, False]).drop(columns='n_removed').to_string())

        # Create a plot for each metric
        create_summary_plot(final_results, 'balanced_accuracy', 'Comparison of Imputation Strategies (Balanced Accuracy)')
        create_summary_plot(final_results, 'macro_f1_score', 'Comparison of Imputation Strategies (Macro F1-Score)')

In [None]:
def create_summary_plot(data: pd.DataFrame, metric: str, title: str, save_path=None):
    """Plot mean ``metric`` per test condition as grouped bars, one bar per test type.

    Error bars show one standard deviation across runs (``errorbar='sd'``).
    Conditions are ordered by the number of removed modalities: ``full_data``
    first, then the ``no_*`` conditions (alphabetical within each size), and
    ``cancer_label_only`` last. The legend is placed outside the axes on the right.

    Args:
        data: Long-format results with ``test_condition``, ``test_type`` and
            ``metric`` columns. It is not modified.
        metric: Numeric column plotted on the y axis.
        title: Figure title.
        save_path: If given, the figure is also written to this path.
    """
    print(f"\n--- Generating plot for: {metric} ---")

    # Work on a copy so the helper column does not leak into the caller's table
    # (e.g. into a results CSV written afterwards).
    plot_data = data.copy()
    # 'no_a_b' has one '_' per removed modality; 'full_data' is pinned to 0.
    plot_data['n_removed'] = plot_data['test_condition'].apply(
        lambda x: 0 if x == 'full_data' else x.count('_')
    )
    # cancer_label_only removes every modality, so it always sorts last.
    plot_data.loc[plot_data['test_condition'] == 'cancer_label_only', 'n_removed'] = plot_data['n_removed'].max() + 1

    plot_order = plot_data.sort_values(by=['n_removed', 'test_condition'])['test_condition'].unique()

    g = sns.catplot(
        data=plot_data, x='test_condition', y=metric, hue='test_type',
        order=plot_order, kind='bar', height=6, aspect=2.5, legend_out=True, errorbar='sd'
    )
    g.fig.suptitle(title, y=1.03, fontsize=16)
    g.set_axis_labels("Test Condition (Modalities Removed)", f"Mean {metric.replace('_', ' ').title()}")
    g.set_xticklabels(rotation=45, ha='right')
    plt.tight_layout(rect=[0, 0, 1, 0.97])

    # Use the public accessor; it is None when no legend was drawn.
    legend = g.legend
    if legend is not None:
        legend.set_title('Test Type')
        legend.set_bbox_to_anchor((1.02, 0.5))
        legend.set_frame_on(True)
        legend.set_loc('center left')

    if save_path is not None:
        g.savefig(save_path)
        print(f"Plot saved to {save_path}")
    plt.show()


final_results = pd.concat(all_run_dfs, ignore_index=True)
# Rows with NaN metrics (skipped imputations) are dropped before summarizing and saving.
final_results.dropna(subset=['balanced_accuracy', 'macro_f1_score'], inplace=True)

print("\n\n===== SUMMARY STATISTICS ACROSS ALL RUNS =====")
summary_stats = final_results.groupby(['test_condition', 'test_type'])[['balanced_accuracy', 'macro_f1_score']].agg(['mean', 'std', 'median'])

# Sort key: 0 for full_data; k for a 'no_*' condition with k removed modalities
# (one '_' each); one past the largest for cancer_label_only. A fixed offset for
# cancer_label_only would tie it with 'no_a_b_c' when four modalities are available.
summary_conditions = summary_stats.index.get_level_values('test_condition')
summary_n_removed = np.array([0 if c == 'full_data' else c.count('_') for c in summary_conditions])
summary_n_removed[np.asarray(summary_conditions == 'cancer_label_only')] = summary_n_removed.max(initial=0) + 1
summary_stats['n_removed'] = summary_n_removed
print(summary_stats.sort_values(by=['n_removed', ('balanced_accuracy', 'mean')], ascending=[True, False]).drop(columns='n_removed').to_string())

results_path = '../../results/downstream/task_05_imputing_test_set'
# Ensure the results directory exists
os.makedirs(results_path, exist_ok=True)

# Save the per-run results to a CSV file
final_results.to_csv(os.path.join(results_path, 'results.csv'), index=False)


# Create a plot for each metric
create_summary_plot(final_results, 'balanced_accuracy', 'Comparison of Imputation Strategies (Balanced Accuracy)', save_path=os.path.join(results_path, 'balanced_accuracy_plot.png'))
create_summary_plot(final_results, 'macro_f1_score', 'Comparison of Imputation Strategies (Macro F1-Score)', save_path=os.path.join(results_path, 'f1_score_plot.png'))

In [None]:
import pandas as pd
from scipy.stats import wilcoxon
from statsmodels.stats.multitest import multipletests
import numpy as np

def _pair_scores_by_run(condition_df: pd.DataFrame, imputed_type: str, base_type: str, metric: str):
    """Return ``(imputed_scores, base_scores)`` for one condition, paired run by run.

    If a ``run`` column is present and each test type has at most one row per
    run, scores are matched on ``run``, and runs where either score is missing
    are dropped. Otherwise the non-missing scores are paired by position, which
    requires both sides to have the same length; ``None`` is returned if they
    do not.
    """
    imputed_rows = condition_df[condition_df['test_type'] == imputed_type]
    base_rows = condition_df[condition_df['test_type'] == base_type]

    if 'run' in condition_df.columns and imputed_rows['run'].is_unique and base_rows['run'].is_unique:
        paired = imputed_rows[['run', metric]].merge(
            base_rows[['run', metric]], on='run', suffixes=('_imputed', '_base')
        ).dropna()
        return paired[f'{metric}_imputed'], paired[f'{metric}_base']

    imputed_scores = imputed_rows[metric].dropna().reset_index(drop=True)
    base_scores = base_rows[metric].dropna().reset_index(drop=True)
    if len(imputed_scores) != len(base_scores):
        return None
    return imputed_scores, base_scores


def analyze_significance_of_imputation(results_df: pd.DataFrame, metric: str = 'balanced_accuracy'):
    """Test, per condition, whether imputation beats plain ablation (NaN features).

    For every test condition except ``full_data``, the ``imputed_multi`` and
    ``imputed_coherent`` scores are each compared with the ``ablation`` scores.
    Each comparison uses a one-sided Wilcoxon signed-rank test
    (H1: imputed > ablation) on run-paired values. The p-values of all
    comparisons are then adjusted with the Benjamini-Hochberg FDR procedure at
    alpha = 0.05.

    NOTE: with n paired runs, the exact one-sided signed-rank test cannot
    report p below 2**-n (about 0.031 for n = 5). After FDR correction over
    many comparisons, few or no gains may therefore reach significance.

    Args:
        results_df: Long-format results (as saved to ``results.csv``) with
            ``test_condition``, ``test_type``, the metric column and,
            preferably, ``run``.
        metric: The metric column to analyze ('balanced_accuracy' or 'macro_f1_score').

    Returns:
        DataFrame with columns ``test_condition``, ``comparison``,
        ``mean_gain`` (mean imputed minus mean ablation over the paired runs),
        ``p_value``, ``p_adj_fdr`` and ``significant_gain``. Returns ``None``
        if the input is empty, lacks required columns, or no comparison could
        be made.
    """
    if results_df.empty or not all(c in results_df.columns for c in ['test_condition', 'test_type', metric]):
        print("DataFrame is empty or missing required columns.")
        return None

    comparisons = {
        "imputed_multi vs. ablation": 'imputed_multi',
        "imputed_coherent vs. ablation": 'imputed_coherent',
    }
    statistical_results = []

    # The 'full_data' baseline has no ablation/imputation counterpart.
    test_conditions = [c for c in results_df['test_condition'].unique() if c != 'full_data']

    for condition in test_conditions:
        condition_df = results_df[results_df['test_condition'] == condition]

        for comp_name, imputed_type in comparisons.items():
            pair = _pair_scores_by_run(condition_df, imputed_type, 'ablation', metric)
            # Conditions without imputed rows (e.g. cancer_label_only) are skipped.
            if pair is None or len(pair[0]) == 0:
                continue
            scores_imputed, scores_base = pair

            mean_gain = scores_imputed.mean() - scores_base.mean()

            # One-sided Wilcoxon signed-rank test.
            # H1: the distribution of scores_imputed is greater than scores_base.
            try:
                p_value = wilcoxon(scores_imputed, scores_base, alternative='greater').pvalue
            except ValueError:
                # Raised by some SciPy versions when all differences are zero.
                p_value = 1.0
            # Other SciPy versions may return NaN for degenerate input instead.
            # A NaN can propagate through the BH step-up adjustment, so treat
            # it like "no evidence of gain".
            if not np.isfinite(p_value):
                p_value = 1.0

            statistical_results.append({
                'test_condition': condition,
                'comparison': comp_name,
                'mean_gain': mean_gain,
                'p_value': p_value
            })

    if not statistical_results:
        print("No valid comparisons could be made.")
        return None

    # --- Multiple testing correction (Benjamini-Hochberg FDR) ---
    stats_df = pd.DataFrame(statistical_results)
    reject, p_adj, _, _ = multipletests(stats_df['p_value'], alpha=0.05, method='fdr_bh')

    stats_df['p_adj_fdr'] = p_adj
    stats_df['significant_gain'] = reject  # True if rejected by BH at alpha = 0.05
    return stats_df

# --- Main Execution ---
if __name__ == '__main__':
    results_file = '../../results/downstream/task_05_imputing_test_set/results.csv'

    # Load the results file saved by the previous run.
    # A missing file skips the analysis instead of calling exit(). exit() is a
    # helper meant for interactive sessions, and inside a Jupyter kernel it
    # shuts the kernel down rather than just stopping this cell.
    try:
        all_results = pd.read_csv(results_file)
    except FileNotFoundError:
        all_results = None
        print(f"Error: The results file '{results_file}' was not found.")
        print("Please run the previous script first to generate the results.")
    else:
        print(f"Successfully loaded results from '{results_file}'")

    if all_results is not None:
        # Analyze the results based on Balanced Accuracy
        print("\n\n===== Analyzing Significance for Balanced Accuracy =====")
        significance_df_accuracy = analyze_significance_of_imputation(all_results, metric='balanced_accuracy')

        if significance_df_accuracy is not None:
            # Significant gains first, then by adjusted p-value
            sorted_accuracy_results = significance_df_accuracy.sort_values(
                by=['significant_gain', 'p_adj_fdr'],
                ascending=[False, True]
            )
            print("The table below shows if an imputation method provided a statistically significant gain over ablation.")
            print("A 'significant_gain' of True means the Benjamini-Hochberg adjusted p-value (FDR) is below 0.05.")
            print(sorted_accuracy_results.to_string())

        # Analyze the results based on Macro F1-Score
        print("\n\n===== Analyzing Significance for Macro F1-Score =====")
        significance_df_f1 = analyze_significance_of_imputation(all_results, metric='macro_f1_score')

        if significance_df_f1 is not None:
            sorted_f1_results = significance_df_f1.sort_values(
                by=['significant_gain', 'p_adj_fdr'],
                ascending=[False, True]
            )
            print("The table below shows if an imputation method provided a statistically significant gain over ablation.")
            print("A 'significant_gain' of True means the Benjamini-Hochberg adjusted p-value (FDR) is below 0.05.")
            print(sorted_f1_results.to_string())

In [None]:
import pandas as pd
from scipy.stats import ttest_rel 
from statsmodels.stats.multitest import multipletests
import numpy as np

def analyze_significance_with_ttest(results_df: pd.DataFrame, metric: str = 'balanced_accuracy'):
    """
    Test whether generative imputation beats the ablation (NaN) baseline per test condition.

    For every ``test_condition`` except ``'full_data'``, ``imputed_multi`` and
    ``imputed_coherent`` are each compared with ``ablation`` using a one-sided
    paired t-test (H1: mean(imputed) > mean(ablation)). The p-values of all
    comparisons are then Benjamini-Hochberg (FDR) corrected together.

    ``ttest_rel`` pairs observations by position. Each condition's rows are
    therefore sorted by ``run`` (when that column exists), so the i-th imputed
    score is compared with the ablation score of the same run regardless of
    row order in the file.

    Args:
        results_df: Long-format results with 'test_condition', 'test_type' and
            ``metric`` columns (optionally 'run').
        metric: The performance metric to analyze ('balanced_accuracy' or
            'macro_f1_score').

    Returns:
        A DataFrame with one row per comparison: test_condition, comparison,
        mean_gain, p_value, p_adj_fdr, significant_gain. Returns None when the
        input is empty/missing columns or no comparison could be made.
        Comparisons whose p-value is NaN (e.g. identical scores in every run)
        are kept with p_adj_fdr = NaN and significant_gain = False. They are
        left out of the FDR family.
    """
    if results_df.empty or not all(c in results_df.columns for c in ['test_condition', 'test_type', metric]):
        print("DataFrame is empty or missing required columns.")
        return None

    statistical_results = []

    test_conditions = [c for c in results_df['test_condition'].unique() if c != 'full_data']

    # --- Loop through each condition to perform comparisons ---
    for condition in test_conditions:
        condition_df = results_df[results_df['test_condition'] == condition]
        if 'run' in condition_df.columns:
            # Stable sort so positional pairing below lines up runs.
            condition_df = condition_df.sort_values('run', kind='mergesort')

        # A missing test type yields an empty Series and is skipped by the length check.
        scores_ablation = condition_df[condition_df['test_type'] == 'ablation'][metric].dropna()
        scores_multi = condition_df[condition_df['test_type'] == 'imputed_multi'][metric].dropna()
        scores_coherent = condition_df[condition_df['test_type'] == 'imputed_coherent'][metric].dropna()

        comparisons = {
            "imputed_multi vs. ablation": (scores_multi, scores_ablation),
            "imputed_coherent vs. ablation": (scores_coherent, scores_ablation),
        }

        for comp_name, (scores_imputed, scores_base) in comparisons.items():
            if len(scores_imputed) != len(scores_base) or len(scores_imputed) < 2:  # t-test needs at least 2 observations
                continue

            mean_gain = scores_imputed.mean() - scores_base.mean()

            # One-sided paired t-test. H1: the mean of scores_imputed is greater than the mean of scores_base.
            stat, p_value = ttest_rel(scores_imputed, scores_base, alternative='greater')

            statistical_results.append({
                'test_condition': condition,
                'comparison': comp_name,
                'mean_gain': mean_gain,
                'p_value': p_value
            })

    if not statistical_results:
        print("No valid comparisons could be made.")
        return None

    # --- Apply Multiple Testing Correction ---
    stats_df = pd.DataFrame(statistical_results)
    stats_df['p_adj_fdr'] = np.nan
    stats_df['significant_gain'] = False

    # statsmodels' multipletests does not skip NaN: with fdr_bh a single NaN
    # p-value propagates into every adjusted value. Correct only finite p-values.
    tested = stats_df['p_value'].notna()
    if tested.any():
        reject, p_adj, _, _ = multipletests(stats_df.loc[tested, 'p_value'], alpha=0.05, method='fdr_bh')
        stats_df.loc[tested, 'p_adj_fdr'] = p_adj
        stats_df.loc[tested, 'significant_gain'] = reject

    return stats_df

# --- Main Execution ---
if __name__ == '__main__':
    # Load the results file you saved from the previous run
    results_file = '../../results/downstream/task_05_imputing_test_set/results.csv'
    try:
        all_results = pd.read_csv(results_file)
    except FileNotFoundError:
        print(f"Error: The results file '{results_file}' was not found.")
        print("Please run the previous script first to generate the results.")
        # Abort with a failing status. exit() would report success (status 0) and,
        # in an IPython kernel, is rebound to shut the kernel down.
        raise SystemExit(1)
    print(f"Successfully loaded results from '{results_file}'")

    # Analyze the results based on Balanced Accuracy
    print("\n\n===== Analyzing Significance for Balanced Accuracy (using Paired T-Test) =====")
    significance_df_accuracy = analyze_significance_with_ttest(all_results, metric='balanced_accuracy')

    if significance_df_accuracy is not None:
        sorted_accuracy_results = significance_df_accuracy.sort_values(
            by=['significant_gain', 'p_adj_fdr'], 
            ascending=[False, True]
        )
        print("The table below shows if an imputation method provided a statistically significant gain over ablation.")
        print("A 'significant_gain' of True means we can be confident the improvement was not due to random chance.")
        print(sorted_accuracy_results.to_string())
    
    # Analyze the results based on Macro F1-Score
    print("\n\n===== Analyzing Significance for Macro F1-Score (using Paired T-Test) =====")
    significance_df_f1 = analyze_significance_with_ttest(all_results, metric='macro_f1_score')

    if significance_df_f1 is not None:
        sorted_f1_results = significance_df_f1.sort_values(
            by=['significant_gain', 'p_adj_fdr'], 
            ascending=[False, True]
        )
        print("The table below shows if an imputation method provided a statistically significant gain over ablation.")
        print("A 'significant_gain' of True means we can be confident the improvement was not due to random chance.")
        print(sorted_f1_results.to_string())

In [None]:
import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report
from itertools import combinations
import warnings
# NEW: Import standard imputers
from sklearn.impute import SimpleImputer, KNNImputer

# Silence every UserWarning and FutureWarning for the rest of the session.
# Filters are registered in the same order: UserWarning first, then FutureWarning.
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings('ignore', category=_ignored_category)
del _ignored_category


# --- SECTION 1: HELPER FUNCTIONS (NOW WITH STANDARD IMPUTERS) ---

def load_and_prepare_pancancer_data(labels_dir: str, data_dir: str):
    """Load the pan-cancer features and labels and align them sample by sample.

    Reads ``real_data_train.csv`` / ``test_data.csv`` from ``data_dir``, then
    ``{train,test}_stage.csv`` and ``{train,test}_cancer_type.csv`` from
    ``labels_dir`` (all indexed by their first column). A sample is kept only
    if it has both a ``stage`` and a ``cancertype`` label and appears in the
    feature matrix.

    One-hot ``cancer_<type>`` columns are appended to the features. The dummy
    columns of the two splits are outer-aligned, so a cancer type seen in only
    one split still gets a column (filled with 0) in the other.

    Returns:
        ``(X_train, y_train, X_test, y_test)``, each sorted by sample ID.
    """
    print("--- Loading and preparing all pan-cancer data... ---")

    def _read(folder, file_name):
        return pd.read_csv(os.path.join(folder, file_name), index_col=0)

    def _split(features, split):
        # Join stage with cancer type, drop unlabeled samples, keep those with features.
        labels = _read(labels_dir, f"{split}_stage.csv").join(_read(labels_dir, f"{split}_cancer_type.csv"))
        labels = labels.dropna(subset=['stage', 'cancertype'])
        shared = labels.index.intersection(features.index)
        return (
            features.loc[shared].sort_index(),
            labels.loc[shared, 'stage'].sort_index(),
            labels.loc[shared, 'cancertype'].sort_index(),
        )

    features_train = _read(data_dir, "real_data_train.csv")
    features_test = _read(data_dir, "test_data.csv")
    X_train, y_train, types_train = _split(features_train, "train")
    X_test, y_test, types_test = _split(features_test, "test")

    print(f"  Found {len(X_train)} training samples and {len(X_test)} test samples across all cancers.")

    dummies_train, dummies_test = pd.get_dummies(types_train, prefix='cancer').align(
        pd.get_dummies(types_test, prefix='cancer'), join='outer', axis=1, fill_value=0
    )

    return (
        pd.concat([X_train, dummies_train], axis=1),
        y_train,
        pd.concat([X_test, dummies_test], axis=1),
        y_test,
    )

def get_metrics(y_true, y_pred):
    """Score one set of predictions.

    Returns:
        ``(balanced_accuracy, macro_f1)``. Macro F1 is taken from the
        ``'macro avg'`` entry of sklearn's ``classification_report``. Ill-defined
        per-class scores are reported as 0 (``zero_division=0``) instead of warning.
    """
    score_balanced = balanced_accuracy_score(y_true, y_pred)
    macro_avg = classification_report(y_true, y_pred, output_dict=True, zero_division=0)['macro avg']
    return score_balanced, macro_avg['f1-score']


# --- SECTION 2: MAIN ANALYSIS PIPELINE (MODIFIED TO USE STANDARD IMPUTERS) ---

def run_pancancer_analysis_with_standard_imputation(random_seed: int):
    """
    Train one pan-cancer classifier and score it on ablated and standard-imputed test data.

    The classifier is fit once on the training features. It is evaluated on:
      * the unmodified test set (test_condition/test_type 'full_data');
      * for every non-empty combination of available modalities, three test sets:
        - 'ablation': a copy of the test set with that modality's columns set to NaN;
        - 'imputed_mean': the ablated copy filled by a mean SimpleImputer;
        - 'imputed_knn': the ablated copy filled by a 5-neighbour KNNImputer.
        Both imputers are fitted on the training features only.

    Args:
        random_seed: random_state passed to the RandomForestClassifier.

    Returns:
        DataFrame with columns test_condition, test_type, balanced_accuracy and
        macro_f1_score (one row per evaluated test set).
    """
    X_train, y_train, X_test, y_test = load_and_prepare_pancancer_data(
        "../../datasets_TCGA/downstream_labels",
        "./data_task_02"
    )

    print(f"\n--- Training pan-cancer model with random_state={random_seed}... ---")
    classifier = RandomForestClassifier(n_estimators=100, random_state=random_seed, n_jobs=-1)
    # Fit on X_train as loaded (no imputation). Predicting on the NaN-filled 'ablation'
    # test sets below relies on the forest accepting missing values (scikit-learn >= 1.4).
    classifier.fit(X_train, y_train)

    all_prefixes = {col.split('_')[0] for col in X_test.columns if '_' in col}
    possible_modalities = ['cna', 'rnaseq', 'rppa', 'wsi'] 
    available_modalities = sorted([m for m in possible_modalities if m in all_prefixes])
    print(f"  Available modalities for ablation/imputation: {available_modalities}")

    all_results = []

    def _evaluate(condition_name, test_type, X_eval):
        """Predict on X_eval and append its metrics to all_results."""
        y_pred = classifier.predict(X_eval)
        b_acc, f1 = get_metrics(y_test, y_pred)
        all_results.append({
            'test_condition': condition_name,
            'test_type': test_type,
            'balanced_accuracy': b_acc,
            'macro_f1_score': f1
        })

    # --- Step 1 - Evaluate on the full, unmodified test set first ---
    print("\n--- Processing Test Condition: full_data ---")
    _evaluate('full_data', 'full_data', X_test)

    # The imputers depend only on X_train, so fit them once instead of once per
    # modality combination. keep_empty_features=True keeps any column that is
    # entirely NaN in X_train in the output. Without it such a column is dropped
    # and the imputed array no longer matches X_test.columns.
    mean_imputer = SimpleImputer(strategy='mean', keep_empty_features=True).fit(X_train)
    knn_imputer = KNNImputer(n_neighbors=5, keep_empty_features=True).fit(X_train)

    def _impute(imputer, X_ablated):
        """Fill X_ablated with a fitted imputer, restoring the test set's labels."""
        return pd.DataFrame(imputer.transform(X_ablated), index=X_test.index, columns=X_test.columns)

    # --- Step 2 - Loop through combinations of modalities to remove/impute ---
    for r in range(1, len(available_modalities) + 1):
        for combo in combinations(available_modalities, r):

            if len(combo) == len(available_modalities):
                condition_name = "cancer_label_only"
            else:
                condition_name = f"no_{'_'.join(combo)}"

            print(f"\n--- Processing Test Condition: {condition_name} ---")

            cols_to_nullify = [col for mod in combo for col in X_test.columns if col.startswith(mod + '_')]
            X_test_ablated = X_test.copy()
            X_test_ablated[cols_to_nullify] = np.nan

            imputation_strategies = {'ablation': X_test_ablated}

            print("    Imputing with 'mean'...")
            imputation_strategies['imputed_mean'] = _impute(mean_imputer, X_test_ablated)

            print("    Imputing with 'knn'...")
            imputation_strategies['imputed_knn'] = _impute(knn_imputer, X_test_ablated)

            for test_type, X_test_current in imputation_strategies.items():
                _evaluate(condition_name, test_type, X_test_current)

    return pd.DataFrame(all_results)

# --- SECTION 3: MAIN EXECUTION AND VISUALIZATION ---

def create_summary_plot(data: pd.DataFrame, metric: str, title: str):
    """Bar-plot the mean (error bar: one std. dev.) of ``metric`` for every test condition.

    There is one bar per ``test_type`` within each condition. Conditions are
    ordered by an approximate removed-modality count ('full_data' first,
    'cancer_label_only' last), then alphabetically.

    Args:
        data: Long-format results with 'test_condition', 'test_type' and
            ``metric`` columns. It is not modified; the ordering helper column
            is added to a copy.
        metric: Column plotted on the y-axis.
        title: Figure title.
    """
    print(f"\n--- Generating plot for: {metric} ---")
    # Work on a copy so the caller's DataFrame does not gain an 'n_removed' column
    # (previously it leaked into anything later saved from that frame).
    plot_data = data.copy()
    plot_data['n_removed'] = plot_data['test_condition'].apply(lambda x: 0 if x == 'full_data' else x.count('_') + 1)
    plot_data.loc[plot_data['test_condition'] == 'cancer_label_only', 'n_removed'] = plot_data['n_removed'].max() + 1
    plot_order = plot_data.sort_values(by=['n_removed', 'test_condition']).test_condition.unique()
    g = sns.catplot(
        data=plot_data, x='test_condition', y=metric, hue='test_type',
        order=plot_order, kind='bar', height=6, aspect=2.5, legend_out=False, errorbar='sd'
    )
    g.fig.suptitle(title, y=1.03, fontsize=16)
    g.set_axis_labels("Test Condition (Modalities Removed / Available)", f"Mean {metric.replace('_', ' ').title()}")
    g.set_xticklabels(rotation=45, ha='right')
    plt.legend(title='Test Type')
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()

if __name__ == '__main__':
    N_RUNS = 5
    all_run_dfs = []

    for i in range(N_RUNS):
        print(f"\n{'='*25} Starting Run {i+1}/{N_RUNS} {'='*25}")
        results_df = run_pancancer_analysis_with_standard_imputation(random_seed=i)

        if results_df is not None:
            results_df['run'] = i + 1
            all_run_dfs.append(results_df)

    if all_run_dfs:
        final_results = pd.concat(all_run_dfs, ignore_index=True)

        print("\n\n===== SUMMARY STATISTICS ACROSS ALL RUNS =====")
        summary_stats = final_results.groupby(['test_condition', 'test_type'])[['balanced_accuracy', 'macro_f1_score']].agg(
            ['mean', 'std', 'median']
        )

        def _n_removed(condition):
            """Number of removed modalities encoded in a test-condition name (table ordering only)."""
            if condition == 'full_data':
                return 0
            if condition == 'cancer_label_only':
                # Every modality removed. A fixed large rank keeps it last; the old
                # count('_') + 2 tied it with 3-modality removals like 'no_cna_rnaseq_rppa'.
                return 99
            # 'no_cna_rppa' -> 2 modalities removed.
            return condition.count('_')

        summary_stats['n_removed'] = summary_stats.index.get_level_values('test_condition').map(_n_removed)
        print(summary_stats.sort_values(by=['n_removed', ('balanced_accuracy', 'mean')], ascending=[True, False]).drop(columns='n_removed').to_string())

        create_summary_plot(final_results, 'balanced_accuracy', 'Comparison of Standard Imputation Strategies (Balanced Accuracy)')
        create_summary_plot(final_results, 'macro_f1_score', 'Comparison of Standard Imputation Strategies (Macro F1-Score)')

In [None]:

# Persist every run's standard-imputation scores; the comparison cell below reloads this path.
_standard_results_csv = '../../results/downstream/task_05_imputing_test_set/standard_imputation_results.csv'
final_results.to_csv(_standard_results_csv, index=False)
      

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import ttest_rel
from statsmodels.stats.multitest import multipletests
import numpy as np
import os

# The analysis functions from the previous step are unchanged.
def analyze_vs_ablation(combined_df: pd.DataFrame, metric: str = 'balanced_accuracy'):
    """
    Test whether each imputation method beats the ablation (NaN) baseline.

    For every test condition except 'full_data' and 'cancer_label_only', each
    test type whose name contains 'imputed' is compared with 'ablation' using a
    one-sided paired t-test (H1: mean(imputed) > mean(ablation)). All p-values
    are then Benjamini-Hochberg (FDR) corrected together.

    ``ttest_rel`` pairs scores by position. Each condition's rows are therefore
    sorted by 'run' (when that column exists), so scores from the same run are paired.

    Args:
        combined_df: Long-format results with 'test_condition', 'test_type' and
            ``metric`` columns (optionally 'run').
        metric: Metric column to test. If it is missing, every condition is skipped.

    Returns:
        DataFrame with test_condition, comparison, mean_gain, p_value,
        p_adj_fdr and significant_gain, or None if no comparison could be made.
        Comparisons with a NaN p-value (e.g. identical scores in every run) get
        p_adj_fdr = NaN and significant_gain = False. They are left out of the
        FDR correction.
    """
    statistical_results = []
    test_conditions = [c for c in combined_df['test_condition'].unique() if c not in ['full_data', 'cancer_label_only']]

    for condition in test_conditions:
        condition_df = combined_df[combined_df['test_condition'] == condition]
        if 'run' in condition_df.columns:
            # Stable sort so the positional pairing in ttest_rel lines up runs.
            condition_df = condition_df.sort_values('run', kind='mergesort')
        try:
            scores_ablation = condition_df[condition_df['test_type'] == 'ablation'][metric].dropna()
            imputation_types = [t for t in condition_df['test_type'].unique() if 'imputed' in t]
            for imp_type in imputation_types:
                scores_imputed = condition_df[condition_df['test_type'] == imp_type][metric].dropna()
                if len(scores_imputed) != len(scores_ablation) or len(scores_imputed) < 2: continue
                mean_gain = scores_imputed.mean() - scores_ablation.mean()
                stat, p_value = ttest_rel(scores_imputed, scores_ablation, alternative='greater')
                statistical_results.append({
                    'test_condition': condition,
                    'comparison': f"{imp_type} vs. ablation",
                    'mean_gain': mean_gain,
                    'p_value': p_value
                })
        except (KeyError, ValueError):
            continue

    if not statistical_results: return None
    stats_df = pd.DataFrame(statistical_results)
    stats_df['p_adj_fdr'] = np.nan
    stats_df['significant_gain'] = False
    # statsmodels' multipletests does not skip NaN: with fdr_bh one NaN p-value
    # propagates into every adjusted value. Correct only the finite p-values.
    tested = stats_df['p_value'].notna()
    if tested.any():
        reject, p_adj, _, _ = multipletests(stats_df.loc[tested, 'p_value'], alpha=0.05, method='fdr_bh')
        stats_df.loc[tested, 'p_adj_fdr'] = p_adj
        stats_df.loc[tested, 'significant_gain'] = reject
    return stats_df

def analyze_generative_vs_standard(combined_df: pd.DataFrame, metric: str = 'balanced_accuracy'):
    """
    Test whether the generative imputers beat the standard (mean / KNN) imputers.

    For every test condition except 'full_data' and 'cancer_label_only', each of
    imputed_multi and imputed_coherent is compared with each of imputed_mean and
    imputed_knn using a one-sided paired t-test (H1: mean(generative) >
    mean(standard)). All p-values are then Benjamini-Hochberg (FDR) corrected
    together.

    ``ttest_rel`` pairs scores by position. Each condition's rows are therefore
    sorted by 'run' (when that column exists), so scores from the same run are paired.

    Args:
        combined_df: Long-format results with 'test_condition', 'test_type' and
            ``metric`` columns (optionally 'run').
        metric: Metric column to test. If it is missing, every condition is skipped.

    Returns:
        DataFrame with test_condition, comparison, mean_gain, p_value,
        p_adj_fdr and significant_gain, or None if no comparison could be made.
        Comparisons with a NaN p-value (e.g. identical scores in every run) get
        p_adj_fdr = NaN and significant_gain = False. They are left out of the
        FDR correction.
    """
    statistical_results = []
    test_conditions = [c for c in combined_df['test_condition'].unique() if c not in ['full_data', 'cancer_label_only']]

    for condition in test_conditions:
        condition_df = combined_df[combined_df['test_condition'] == condition]
        if 'run' in condition_df.columns:
            # Stable sort so the positional pairing in ttest_rel lines up runs.
            condition_df = condition_df.sort_values('run', kind='mergesort')
        try:
            scores_mean = condition_df[condition_df['test_type'] == 'imputed_mean'][metric].dropna()
            scores_knn = condition_df[condition_df['test_type'] == 'imputed_knn'][metric].dropna()
            scores_multi = condition_df[condition_df['test_type'] == 'imputed_multi'][metric].dropna()
            scores_coherent = condition_df[condition_df['test_type'] == 'imputed_coherent'][metric].dropna()
        except KeyError:
            # Raised when the metric column is absent.
            continue

        standard_models = {'mean': scores_mean, 'knn': scores_knn}
        generative_models = {'multi': scores_multi, 'coherent': scores_coherent}

        for gen_name, gen_scores in generative_models.items():
            for std_name, std_scores in standard_models.items():
                if len(gen_scores) != len(std_scores) or len(gen_scores) < 2: continue

                mean_gain = gen_scores.mean() - std_scores.mean()
                stat, p_value = ttest_rel(gen_scores, std_scores, alternative='greater')

                statistical_results.append({
                    'test_condition': condition,
                    'comparison': f"imputed_{gen_name} vs. imputed_{std_name}",
                    'mean_gain': mean_gain,
                    'p_value': p_value
                })

    if not statistical_results: return None
    stats_df = pd.DataFrame(statistical_results)
    stats_df['p_adj_fdr'] = np.nan
    stats_df['significant_gain'] = False
    # statsmodels' multipletests does not skip NaN: with fdr_bh one NaN p-value
    # propagates into every adjusted value. Correct only the finite p-values.
    tested = stats_df['p_value'].notna()
    if tested.any():
        reject, p_adj, _, _ = multipletests(stats_df.loc[tested, 'p_value'], alpha=0.05, method='fdr_bh')
        stats_df.loc[tested, 'p_adj_fdr'] = p_adj
        stats_df.loc[tested, 'significant_gain'] = reject
    return stats_df


# --- CORRECTED PLOTTING FUNCTION ---
def create_final_comparison_plot(data: pd.DataFrame, metric: str, title: str, save_path=None):
    """Grouped bar chart of ``metric`` for every test condition and imputation strategy.

    Bars show the mean with a one-std.-dev. error bar. For 'cancer_label_only'
    only the ablation bar is drawn. Conditions are ordered by an approximate
    removed-modality count ('full_data' first, 'cancer_label_only' last), then
    alphabetically. ``data`` itself is not modified.

    Args:
        data: Long-format results with 'test_condition', 'test_type' and ``metric`` columns.
        metric: Column plotted on the y-axis.
        title: Figure title.
        save_path: If not None, the figure is written there via ``FacetGrid.savefig``
            before being shown.
    """
    print(f"\n--- Generating final comparison plot for: {metric} ---")

    # Keep everything except the imputed bars of 'cancer_label_only'.
    keep = (data['test_condition'] != 'cancer_label_only') | (data['test_type'] == 'ablation')
    plot_data = data[keep].copy()

    def _rank(condition):
        if condition == 'full_data':
            return 0
        if condition == 'cancer_label_only':
            return 99
        return condition.count('_') + 1

    plot_data['n_removed'] = plot_data['test_condition'].apply(_rank)
    plot_order = plot_data.sort_values(by=['n_removed', 'test_condition']).test_condition.unique()

    # Fixed colours per strategy; only the strategies present in the data enter the legend.
    palette = {
        'full_data': "#67F122",
        'ablation': "#E7C923",
        'imputed_mean': "#EE8DCC",
        'imputed_knn': "#AD2C9C",
        'imputed_multi': "#115887",
        'imputed_coherent': "#14C1DC"
    }
    preferred_order = ['full_data', 'ablation', 'imputed_mean', 'imputed_knn', 'imputed_multi', 'imputed_coherent']
    present_types = set(plot_data['test_type'].unique())
    plot_data_hue_order = [name for name in preferred_order if name in present_types]

    g = sns.catplot(
        data=plot_data, x='test_condition', y=metric, hue='test_type',
        order=plot_order, hue_order=plot_data_hue_order, kind='bar', height=7, aspect=2.2,
        palette=palette,
        errorbar='sd',
    )

    # catplot builds the legend itself; relocate it to the right of the axes.
    sns.move_legend(
        g, "center right",
        bbox_to_anchor=(1.1, 0.5),
        frameon=True,
        title='Test Type'
    )

    y_label = f"Mean {metric.replace('_', ' ').title()}"
    g.fig.suptitle(title, y=1.03, fontsize=18)
    g.set_axis_labels("Test Condition (Modalities Removed)", y_label, fontsize=14)
    g.set_xticklabels(rotation=45, ha='right')
    if save_path is not None:
        g.savefig(save_path)
        print(f"Plot saved to {save_path}")
    plt.show()

# --- Main Execution ---
if __name__ == '__main__':
    # 1. Load both result files. Paths are bound before the try so the error
    #    message can always name them, even if the first read fails.
    results_generative_file = '../../results/downstream/task_05_imputing_test_set/results.csv'
    results_standard_file = '../../results/downstream/task_05_imputing_test_set/standard_imputation_results.csv'
    try:
        df_generative = pd.read_csv(results_generative_file)
        print(f"Successfully loaded generative results from '{results_generative_file}'")

        df_standard = pd.read_csv(results_standard_file)
        print(f"Successfully loaded standard imputation results from '{results_standard_file}'")

    except FileNotFoundError as e:
        print(f"Error: Could not find a results file. {e}")
        # Name the files this cell actually reads (the old message listed non-existent file names).
        print(f"Please ensure both '{results_generative_file}' and '{results_standard_file}' are present.")
        # Abort with a failing status. exit() would report success (status 0) and,
        # in an IPython kernel, is rebound to shut the kernel down.
        raise SystemExit(1)

    # Ablation, mean/KNN imputation and full_data come from the standard run;
    # the generative imputations come from the generative run.
    df_standard_subset = df_standard[df_standard['test_type'].isin(['ablation', 'imputed_mean', 'imputed_knn', 'full_data'])]
    df_generative_subset = df_generative[df_generative['test_type'].isin(['imputed_multi', 'imputed_coherent'])]
    final_combined_results = pd.concat([df_standard_subset, df_generative_subset], ignore_index=True)
    print("\nDataFrames successfully combined.")

    # Print the full sorted tables, not just the significant results.
    print("\n\n===== Analysis 1: Significance of Gain vs. Ablation (Balanced Accuracy) =====")
    significance_vs_ablation = analyze_vs_ablation(final_combined_results, metric='balanced_accuracy')
    if significance_vs_ablation is not None:
        sorted_results_1 = significance_vs_ablation.sort_values(by=['significant_gain', 'p_adj_fdr'], ascending=[False, True])
        print(sorted_results_1.to_string())

    print("\n\n===== Analysis 2: Significance of Gain for Generative vs. Standard Methods (Balanced Accuracy) =====")
    significance_gen_vs_std = analyze_generative_vs_standard(final_combined_results, metric='balanced_accuracy')
    if significance_gen_vs_std is not None:
        sorted_results_2 = significance_gen_vs_std.sort_values(by=['significant_gain', 'p_adj_fdr'], ascending=[False, True])
        print(sorted_results_2.to_string())

    results_path = '../../results/downstream/task_05_imputing_test_set'
    # Ensure the results directory exists
    os.makedirs(results_path, exist_ok=True)

    # --- Final Visualization ---
    create_final_comparison_plot(final_combined_results, 'balanced_accuracy', 'Final Comparison of All Imputation Strategies (Balanced Accuracy)', save_path=os.path.join(results_path, 'balanced_accuracy_imputations.png'))
    create_final_comparison_plot(final_combined_results, 'macro_f1_score', 'Final Comparison of All Imputation Strategies (Macro F1-Score)', save_path=os.path.join(results_path, 'f1_score_imputations.png'))


