In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import seaborn as sns

# =============================================================================
# 1. CONFIGURATION
# =============================================================================
# All modalities in the dataset, in the order their columns are concatenated.
ALL_MODALITIES = ['cna', 'rna', 'rppa', 'wsi']

# --- Main Setting ---
# Modalities hidden ("ablated") during the analysis. Kept as a tuple because it
# is used as a key into the generated-data mappings below, e.g. ('rna',) or
# ('wsi',).
ABLATED_MODALITIES = ('rna',)

# Root directory of the generated (counterfactual) samples.
gen_path = '../../results/32/'

# Ablation key -> (file stem of the generated modality, conditioning
# modalities as they are spelled in the generated file name).
_GENERATED_FILE_PARTS = {
    ('rna',): ('rnaseq', 'cna_rppa_wsi'),
    ('wsi',): ('wsi', 'cna_rnaseq_rppa'),
}

# Generated-sample CSV produced by the "coherent" generator, per ablation key.
GENERATED_DATA_MAPPING_coh = {
    key: f'{gen_path}{target}_from_coherent/test/generated_samples_from_{sources}_best_mse.csv'
    for key, (target, sources) in _GENERATED_FILE_PARTS.items()
}

# Same file layout for the "multi" generator.
GENERATED_DATA_MAPPING_multi = {
    key: f'{gen_path}{target}_from_multi/test/generated_samples_from_{sources}_best_mse.csv'
    for key, (target, sources) in _GENERATED_FILE_PARTS.items()
}

# --- Experiment Parameters ---
# Repetitions of the random-ablation baseline; its scores are averaged.
N_ITERATIONS = 30
# Fractions of test samples to ablate: 0.00, 0.05, ..., 1.00.
ABLATION_STEPS = np.arange(0, 1.05, 0.05)
# Number of generated candidates stored per test sample in the generated files.
K_NEIGHBORS = 10


# =============================================================================
# 2. DATA LOADING AND PREPARATION
# =============================================================================
print("Loading datasets...")

# Modality key -> file-name stem on disk (RNA files are named 'rnaseq').
_MODALITY_FILE_STEMS = {'cna': 'cna', 'rna': 'rnaseq', 'rppa': 'rppa', 'wsi': 'wsi'}

# Raw per-modality training tables.
train_path = '../../datasets_TCGA/07_normalized/32/'
train_raw = {
    mod: pd.read_csv(f'{train_path}{stem}_train.csv', sep=',')
    for mod, stem in _MODALITY_FILE_STEMS.items()
}

# Raw per-modality test tables, read from test_path so that the train and
# test locations can differ without silently loading the wrong directory.
test_path = '../../datasets_TCGA/07_normalized/32/'
test_raw = {
    mod: pd.read_csv(f'{test_path}{stem}_test.csv', sep=',')
    for mod, stem in _MODALITY_FILE_STEMS.items()
}

# Stage labels for the downstream classification task.
labels_path = '../../datasets_TCGA/downstream_labels/'
train_stage = pd.read_csv(f'{labels_path}train_stage.csv')
test_stage = pd.read_csv(f'{labels_path}test_stage.csv')


def clean_and_unify_tables(df_dict, modalities):
    """
    Unify modality-specific dataframes into a single wide dataframe.

    Every dataframe must contain a ``sample_id`` column listing the same
    samples in the same order. That column is dropped, the remaining columns
    are renamed ``{modality}_{k}`` (k starting at 1), and the tables are
    concatenated column-wise in ``modalities`` order.

    Args:
        df_dict: Mapping from modality name to its dataframe.
        modalities: Ordered sequence of modality names to include.

    Returns:
        One dataframe with a row per sample and a fresh 0..n-1 index.

    Raises:
        ValueError: If any modality's ``sample_id`` column differs in length,
            content or order from that of the first modality.
    """
    reference_ids = df_dict[modalities[0]]['sample_id'].to_numpy()
    for m in modalities:
        # np.array_equal also reports length mismatches cleanly, where an
        # elementwise ``==`` would fail to broadcast. A real exception (not
        # ``assert``) keeps the check active under ``python -O``.
        if not np.array_equal(df_dict[m]['sample_id'].to_numpy(), reference_ids):
            raise ValueError(f"Sample IDs for {m} do not match.")

    processed_dfs = []
    for m in modalities:
        # Reset the index so that pd.concat aligns rows by position, matching
        # the positional sample-ID check above.
        df = df_dict[m].drop(columns=['sample_id']).reset_index(drop=True)
        df.columns = [f'{m}_{i+1}' for i in range(df.shape[1])]
        processed_dfs.append(df)

    return pd.concat(processed_dfs, axis=1)

print("Processing and unifying tables...")
# Wide per-split tables: one row per sample, modality-prefixed columns.
train_full, test_full = (
    clean_and_unify_tables(split_raw, ALL_MODALITIES)
    for split_raw in (train_raw, test_raw)
)

for _split_name, _split_df in (("Train", train_full), ("Test", test_full)):
    print(f"{_split_name} data shape: {_split_df.shape}")


# =============================================================================
# 3. LOAD GENERATED DATA FOR COUNTERFACTUALS
# =============================================================================
generated_file_key = ABLATED_MODALITIES
# Both generators must provide a file for this setting; checking only one
# mapping would let the other fail later with a bare KeyError.
for _mapping in (GENERATED_DATA_MAPPING_coh, GENERATED_DATA_MAPPING_multi):
    if generated_file_key not in _mapping:
        raise ValueError(f"No generated data file specified for the modality combination: {generated_file_key}")

generated_file_name_coh = GENERATED_DATA_MAPPING_coh[generated_file_key]
generated_file_name_multi = GENERATED_DATA_MAPPING_multi[generated_file_key]


def _read_generated(path):
    """Read one generated-samples CSV, naming the missing file on failure.

    Raising (instead of calling ``exit()``) keeps a notebook kernel alive and
    reports the file that is actually missing.
    """
    try:
        return pd.read_csv(path, sep=',')
    except FileNotFoundError as err:
        raise FileNotFoundError(f"The generated data file '{path}' was not found.") from err


# NOTE(review): the same CSV is loaded for every ablated modality, which is only
# meaningful while each mapping key names a single modality.
gen_data_long_coh = {mod: _read_generated(generated_file_name_coh) for mod in ABLATED_MODALITIES}
gen_data_long_multi = {mod: _read_generated(generated_file_name_multi) for mod in ABLATED_MODALITIES}

# Generated files are "long": K_NEIGHBORS consecutive blocks of n_test_samples
# rows, where row k * n_test_samples + i is candidate k for test sample i.
n_test_samples = len(test_full)
_expected_rows = n_test_samples * K_NEIGHBORS
for _label, _frames in (("coherent", gen_data_long_coh), ("multi", gen_data_long_multi)):
    for _mod, _df in _frames.items():
        if len(_df) != _expected_rows:
            raise ValueError(
                f"Generated data size mismatch ({_label}, {_mod}): "
                f"expected {_expected_rows} rows, got {len(_df)}."
            )


def _split_candidates(long_df):
    """Return a list whose i-th entry holds the K_NEIGHBORS candidate rows of test sample i."""
    return [
        long_df.iloc[np.arange(K_NEIGHBORS) * n_test_samples + i]
        for i in range(n_test_samples)
    ]


generated_candidates_coh = {mod: _split_candidates(gen_data_long_coh[mod]) for mod in ABLATED_MODALITIES}
generated_candidates_multi = {mod: _split_candidates(gen_data_long_multi[mod]) for mod in ABLATED_MODALITIES}

print(f"Reshaped generated data for {len(generated_candidates_coh[ABLATED_MODALITIES[0]])} test samples.")


# =============================================================================
# 4. MODEL TRAINING AND BASELINE EVALUATION
# =============================================================================
print("\n--- Training and Baseline Evaluation ---")

# Samples without a stage label cannot be scored, so drop them from both splits.
train_mask = train_stage['stage'].notna()
test_mask = test_stage['stage'].notna()

train_data = train_full[train_mask]
train_labels_filtered = train_stage['stage'][train_mask]

# The filtered test split is re-indexed 0..n-1 so rows can be addressed positionally.
test_data = test_full[test_mask].reset_index(drop=True)
test_labels_filtered = test_stage['stage'][test_mask].reset_index(drop=True)

# Integer-encode the stage labels; the encoder is fit on training labels only.
le = LabelEncoder()
train_labels, test_labels = (
    le.fit_transform(train_labels_filtered),
    le.transform(test_labels_filtered),
)

# Random Forest on all modalities (fixed seed, OOB score computed during fit).
rf = RandomForestClassifier(n_estimators=100, random_state=42, oob_score=True)
print("Training Random Forest model...")
rf.fit(train_data.to_numpy(), train_labels)

# --- Baseline Performance: nothing ablated ---
test_predictions = rf.predict(test_data.to_numpy())
test_f1 = f1_score(test_labels, test_predictions, average='weighted')
print(f"Baseline F1-Score on Test Set (all modalities present): {test_f1:.4f}")



# =============================================================================
# 6. PROGRESSIVE ABLATION ANALYSIS
# =============================================================================
print("\n--- Generating Progressive Ablation Comparison ---")
ablated_cols = [c for c in test_data.columns if c.split('_')[0] in ABLATED_MODALITIES]
# Columns of each ablated modality; loop-invariant, so computed once.
mod_cols_by_modality = {
    mod: [c for c in test_data.columns if c.startswith(f'{mod}_')]
    for mod in ABLATED_MODALITIES
}

# --- 6a. Selective Ablation (Generative Model) ---
print("Calculating selective ablation scores (Generative Model)...")
# Generated candidates are indexed by position in the unfiltered test set;
# keep only samples with a known stage so entry i lines up with row i of test_data.
original_indices = test_mask[test_mask].index
filtered_gen_candidates_coh = {
    mod: [generated_candidates_coh[mod][i] for i in original_indices]
    for mod in ABLATED_MODALITIES
}
filtered_gen_candidates_multi = {
    mod: [generated_candidates_multi[mod][i] for i in original_indices]
    for mod in ABLATED_MODALITIES
}


def _candidate_inputs(filtered_candidates):
    """For each test sample, build a K_NEIGHBORS-row input matrix: the sample
    repeated K_NEIGHBORS times with its ablated-modality columns replaced by
    the generated candidates."""
    inputs = []
    for i in range(len(test_data)):
        repeated = pd.concat([test_data.iloc[[i]]] * K_NEIGHBORS, ignore_index=True)
        for mod, mod_cols in mod_cols_by_modality.items():
            repeated[mod_cols] = filtered_candidates[mod][i].values
        inputs.append(repeated.values)
    return inputs


def _disagreement_rates(model, candidate_inputs, ablated_predictions):
    """Per test sample, the fraction of generated candidates whose predicted
    class differs from the prediction made with the ablated modalities missing."""
    return [
        np.mean(model.predict(inputs) != ablated_prediction)
        for inputs, ablated_prediction in zip(candidate_inputs, ablated_predictions)
    ]


def _f1_after_ablating(model, indices):
    """Weighted F1 of model on test_data after NaN-ing ablated_cols in the rows labelled by indices."""
    data_for_pred = test_data.copy()
    if len(indices) > 0:
        data_for_pred.loc[indices, ablated_cols] = np.nan
    preds = model.predict(data_for_pred.values)
    return f1_score(test_labels, preds, average='weighted', zero_division=0)


def _selective_f1_curve(model, variance_df):
    """F1 at each ABLATION_STEPS ratio, ablating samples in variance_df order (lowest rate first)."""
    n = len(test_data)
    return [
        _f1_after_ablating(model, variance_df.head(int(n * ratio))['original_index'].values)
        for ratio in ABLATION_STEPS
    ]


# Whole test set with the ablated modalities blanked out. One batched predict
# matches the former per-sample calls: RandomForestClassifier.predict scores
# each row independently.
ablated_test_data = test_data.copy()
ablated_test_data[ablated_cols] = np.nan
ablated_predictions = rf.predict(ablated_test_data.values)

prediction_variances_gen_coh = _disagreement_rates(
    rf, _candidate_inputs(filtered_gen_candidates_coh), ablated_predictions)
prediction_variances_gen_multi = _disagreement_rates(
    rf, _candidate_inputs(filtered_gen_candidates_multi), ablated_predictions)

# 'variance' holds the disagreement rate; the lowest values are ablated first.
variance_df_gen_coh = pd.DataFrame({
    'variance': prediction_variances_gen_coh,
    'original_index': test_data.index
}).sort_values(by='variance', ascending=True)

variance_df_gen_multi = pd.DataFrame({
    'variance': prediction_variances_gen_multi,
    'original_index': test_data.index
}).sort_values(by='variance', ascending=True)

gen_selective_f1_scores_coh = _selective_f1_curve(rf, variance_df_gen_coh)
gen_selective_f1_scores_multi = _selective_f1_curve(rf, variance_df_gen_multi)

# --- 6b. Random Progressive Ablation ---
print("Calculating random progressive ablation scores...")
# Seeded generator so the random baseline is reproducible across re-runs.
rng = np.random.default_rng(42)
random_ablation_f1_scores = []
for ratio in ABLATION_STEPS:
    n_to_ablate = int(len(test_data) * ratio)
    iter_scores = [
        _f1_after_ablating(rf, rng.choice(test_data.index, size=n_to_ablate, replace=False))
        for _ in range(N_ITERATIONS)
    ]
    random_ablation_f1_scores.append(np.mean(iter_scores))


# =============================================================================
# 7. PLOTTING RESULTS
# =============================================================================
print("Plotting results...")
plt.style.use('seaborn-v0_8-whitegrid')
plt.figure(figsize=(14, 8))

# One curve per strategy: (F1 scores, marker, line style, legend label).
_curves = (
    (random_ablation_f1_scores, '^', ':', 'Random Ablation'),
    (gen_selective_f1_scores_coh, 'o', '--', 'Selective Ablation (Coherent)'),
    (gen_selective_f1_scores_multi, 's', '-', 'Selective Ablation (Multi)'),
)
for _scores, _marker, _style, _label in _curves:
    plt.plot(ABLATION_STEPS, _scores, marker=_marker, linestyle=_style, linewidth=2, label=_label)

_modality_label = ', '.join(ABLATED_MODALITIES).upper()
plt.title(f"Impact of Ablating {_modality_label}: Random vs. Counterfactual Inference", fontsize=16)
plt.xlabel('Ratio of Samples with Ablated Modalities', fontsize=12)
plt.ylabel('Stage Prediction F1-Score', fontsize=12)
plt.xticks(ABLATION_STEPS, rotation=45)
plt.xlim(-0.02, 1.02)
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.legend(fontsize=11)
plt.tight_layout()
plt.show()



In [None]:

# =============================================================================
# 7. PLOTTING RESULTS
# =============================================================================
print("Plotting results...")
plt.style.use('seaborn-v0_8-whitegrid')
plt.figure(figsize=(14, 8))

# Express the x values as the fraction of samples whose modality is observed.
inverted_steps = list(1 - ABLATION_STEPS)

# One curve per strategy: (F1 scores, marker, line style, legend label).
_series = (
    (random_ablation_f1_scores, '^', ':', 'Random Prioritization'),
    (gen_selective_f1_scores_coh, 'o', '--', 'Informed Prioritization (Coherent)'),
    (gen_selective_f1_scores_multi, 's', '-', 'Informed Prioritization (Multi)'),
)
for _scores, _marker, _style, _label in _series:
    plt.plot(inverted_steps, _scores, marker=_marker, linestyle=_style, linewidth=2, label=_label)

_modality_label = ', '.join(ABLATED_MODALITIES).upper()
plt.title(f"Impact of Observing {_modality_label}: Random Prioritization vs. Counterfactual Inference", fontsize=16)
plt.xlabel('Ratio of Samples with Observed Modalities', fontsize=12)
plt.ylabel('Stage Prediction F1-Score', fontsize=12)
plt.xticks(inverted_steps, rotation=45)
plt.xlim(-0.02, 1.02)
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.legend(fontsize=11)
plt.tight_layout()
plt.show()


In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import seaborn as sns

# =============================================================================
# 1. CONFIGURATION
# =============================================================================
# All modalities in the dataset, in the order their columns are concatenated.
ALL_MODALITIES = ['cna', 'rna', 'rppa', 'wsi']

# --- Main Setting ---
# Modalities hidden ("ablated") during the analysis; a tuple so it can key the
# generated-data mappings below.
ABLATED_MODALITIES = ('rna',)

# Root directory of the generated (counterfactual) samples.
gen_path = '../../results/32/'

# Ablation key -> (file stem of the generated modality, conditioning
# modalities as they are spelled in the generated file name).
_GENERATED_FILE_PARTS = {
    ('rna',): ('rnaseq', 'cna_rppa_wsi'),
    ('wsi',): ('wsi', 'cna_rnaseq_rppa'),
}

# Generated-sample CSV produced by the "coherent" generator, per ablation key.
GENERATED_DATA_MAPPING_coh = {
    key: f'{gen_path}{target}_from_coherent/test/generated_samples_from_{sources}_best_mse.csv'
    for key, (target, sources) in _GENERATED_FILE_PARTS.items()
}

# Same file layout for the "multi" generator.
GENERATED_DATA_MAPPING_multi = {
    key: f'{gen_path}{target}_from_multi/test/generated_samples_from_{sources}_best_mse.csv'
    for key, (target, sources) in _GENERATED_FILE_PARTS.items()
}

# --- Experiment Parameters ---
# Full experiment repeats (one freshly seeded model each); used for mean/std.
N_EXPERIMENT_REPEATS = 10
# Repetitions of the random-ablation baseline within each experiment run.
N_ITERATIONS = 30
# Fractions of test samples to ablate: 0.00, 0.05, ..., 1.00.
ABLATION_STEPS = np.arange(0, 1.05, 0.05)
# Number of generated candidates stored per test sample in the generated files.
K_NEIGHBORS = 10


# =============================================================================
# 2. DATA LOADING AND PREPARATION
# =============================================================================
print("Loading datasets...")

# Modality key -> file-name stem on disk (RNA files are named 'rnaseq').
_MODALITY_FILE_STEMS = {'cna': 'cna', 'rna': 'rnaseq', 'rppa': 'rppa', 'wsi': 'wsi'}

# Raw per-modality training tables.
train_path = '../../datasets_TCGA/07_normalized/32/'
train_raw = {
    mod: pd.read_csv(f'{train_path}{stem}_train.csv', sep=',')
    for mod, stem in _MODALITY_FILE_STEMS.items()
}

# Raw per-modality test tables, read from test_path so that the train and
# test locations can differ without silently loading the wrong directory.
test_path = '../../datasets_TCGA/07_normalized/32/'
test_raw = {
    mod: pd.read_csv(f'{test_path}{stem}_test.csv', sep=',')
    for mod, stem in _MODALITY_FILE_STEMS.items()
}

# Stage labels for the downstream classification task.
labels_path = '../../datasets_TCGA/downstream_labels/'
train_stage = pd.read_csv(f'{labels_path}train_stage.csv')
test_stage = pd.read_csv(f'{labels_path}test_stage.csv')


def clean_and_unify_tables(df_dict, modalities):
    """
    Unify modality-specific dataframes into a single wide dataframe.

    Every dataframe must contain a ``sample_id`` column listing the same
    samples in the same order. That column is dropped, the remaining columns
    are renamed ``{modality}_{k}`` (k starting at 1), and the tables are
    concatenated column-wise in ``modalities`` order.

    Args:
        df_dict: Mapping from modality name to its dataframe.
        modalities: Ordered sequence of modality names to include.

    Returns:
        One dataframe with a row per sample and a fresh 0..n-1 index.

    Raises:
        ValueError: If any modality's ``sample_id`` column differs in length,
            content or order from that of the first modality.
    """
    reference_ids = df_dict[modalities[0]]['sample_id'].to_numpy()
    for m in modalities:
        # np.array_equal also reports length mismatches cleanly, where an
        # elementwise ``==`` would fail to broadcast. A real exception (not
        # ``assert``) keeps the check active under ``python -O``.
        if not np.array_equal(df_dict[m]['sample_id'].to_numpy(), reference_ids):
            raise ValueError(f"Sample IDs for {m} do not match.")

    processed_dfs = []
    for m in modalities:
        # Reset the index so that pd.concat aligns rows by position, matching
        # the positional sample-ID check above.
        df = df_dict[m].drop(columns=['sample_id']).reset_index(drop=True)
        df.columns = [f'{m}_{i+1}' for i in range(df.shape[1])]
        processed_dfs.append(df)

    return pd.concat(processed_dfs, axis=1)

print("Processing and unifying tables...")
train_full = clean_and_unify_tables(train_raw, ALL_MODALITIES)
test_full = clean_and_unify_tables(test_raw, ALL_MODALITIES)

print(f"Train data shape: {train_full.shape}")
print(f"Test data shape: {test_full.shape}")

# Keep only samples with a known cancer stage. This is done once, before the
# experiment loop, because it does not depend on the trained model.
train_mask = ~train_stage['stage'].isna()
train_data = train_full[train_mask]
train_labels_filtered = train_stage['stage'][train_mask]

test_mask = ~test_stage['stage'].isna()
test_data_full = test_full[test_mask].reset_index(drop=True)
test_labels_filtered = test_stage['stage'][test_mask].reset_index(drop=True)

# Integer-encode the stage labels; the encoder is fit on training labels only.
le = LabelEncoder()
train_labels = le.fit_transform(train_labels_filtered)
test_labels = le.transform(test_labels_filtered)

# --- Generated counterfactual data ---
generated_file_key = ABLATED_MODALITIES
# Both generators must provide a file for this setting; checking only one
# mapping would let the other fail later with a bare KeyError.
for _mapping in (GENERATED_DATA_MAPPING_coh, GENERATED_DATA_MAPPING_multi):
    if generated_file_key not in _mapping:
        raise ValueError(f"No generated data file specified for the modality combination: {generated_file_key}")

generated_file_name_coh = GENERATED_DATA_MAPPING_coh[generated_file_key]
generated_file_name_multi = GENERATED_DATA_MAPPING_multi[generated_file_key]


def _read_generated(path):
    """Read one generated-samples CSV, naming the missing file on failure.

    Raising (instead of calling ``exit()``) keeps a notebook kernel alive and
    reports the file that is actually missing.
    """
    try:
        return pd.read_csv(path, sep=',')
    except FileNotFoundError as err:
        raise FileNotFoundError(f"The generated data file '{path}' was not found.") from err


gen_data_long_coh = {mod: _read_generated(generated_file_name_coh) for mod in ABLATED_MODALITIES}
gen_data_long_multi = {mod: _read_generated(generated_file_name_multi) for mod in ABLATED_MODALITIES}

# Each generated file must hold K_NEIGHBORS candidates for every (unfiltered)
# test sample; the reshaping in the experiment loop relies on this layout.
_expected_rows = len(test_full) * K_NEIGHBORS
for _label, _frames in (("coherent", gen_data_long_coh), ("multi", gen_data_long_multi)):
    for _mod, _df in _frames.items():
        if len(_df) != _expected_rows:
            raise ValueError(
                f"Generated data size mismatch ({_label}, {_mod}): "
                f"expected {_expected_rows} rows, got {len(_df)}."
            )

# =============================================================================
# 3. FULL EXPERIMENT LOOP
# =============================================================================
# Per-run F1 curves (one value per ABLATION_STEPS ratio), aggregated afterwards.
all_runs_gen_coh = []
all_runs_gen_multi = []
all_runs_random = []

# --- Model-independent preparation, hoisted out of the run loop ---
test_data = test_data_full.copy()
ablated_cols = [c for c in test_data.columns if c.split('_')[0] in ABLATED_MODALITIES]
mod_cols_by_modality = {
    mod: [c for c in test_data.columns if c.startswith(f'{mod}_')]
    for mod in ABLATED_MODALITIES
}

# Generated files hold K_NEIGHBORS consecutive blocks of n_test_samples rows;
# row k * n_test_samples + i is candidate k for (unfiltered) test sample i.
n_test_samples = len(test_full)
generated_candidates_coh = {
    mod: [gen_data_long_coh[mod].iloc[np.arange(K_NEIGHBORS) * n_test_samples + i]
          for i in range(n_test_samples)]
    for mod in ABLATED_MODALITIES
}
generated_candidates_multi = {
    mod: [gen_data_long_multi[mod].iloc[np.arange(K_NEIGHBORS) * n_test_samples + i]
          for i in range(n_test_samples)]
    for mod in ABLATED_MODALITIES
}

# Keep only candidates of samples with a known stage so that entry i lines up
# with row i of test_data.
original_indices = test_mask[test_mask].index
filtered_gen_candidates_coh = {mod: [generated_candidates_coh[mod][i] for i in original_indices] for mod in ABLATED_MODALITIES}
filtered_gen_candidates_multi = {mod: [generated_candidates_multi[mod][i] for i in original_indices] for mod in ABLATED_MODALITIES}

# Whole test set with the ablated modalities blanked out.
ablated_test_data = test_data.copy()
ablated_test_data[ablated_cols] = np.nan


def _candidate_inputs(filtered_candidates):
    """For each test sample, build a K_NEIGHBORS-row input matrix: the sample
    repeated K_NEIGHBORS times with its ablated-modality columns replaced by
    the generated candidates."""
    inputs = []
    for i in range(len(test_data)):
        repeated = pd.concat([test_data.iloc[[i]]] * K_NEIGHBORS, ignore_index=True)
        for mod, mod_cols in mod_cols_by_modality.items():
            repeated[mod_cols] = filtered_candidates[mod][i].values
        inputs.append(repeated.values)
    return inputs


def _disagreement_rates(model, candidate_inputs, ablated_predictions):
    """Per test sample, the fraction of generated candidates whose predicted
    class differs from the prediction made with the ablated modalities missing."""
    return [
        np.mean(model.predict(inputs) != ablated_prediction)
        for inputs, ablated_prediction in zip(candidate_inputs, ablated_predictions)
    ]


def _f1_after_ablating(model, indices):
    """Weighted F1 of model on test_data after NaN-ing ablated_cols in the rows labelled by indices."""
    data_for_pred = test_data.copy()
    if len(indices) > 0:
        data_for_pred.loc[indices, ablated_cols] = np.nan
    preds = model.predict(data_for_pred.values)
    return f1_score(test_labels, preds, average='weighted', zero_division=0)


def _selective_f1_curve(model, variance_df):
    """F1 at each ABLATION_STEPS ratio, ablating samples in variance_df order (lowest rate first)."""
    n = len(test_data)
    return [
        _f1_after_ablating(model, variance_df.head(int(n * ratio))['original_index'].values)
        for ratio in ABLATION_STEPS
    ]


# The candidate inputs do not depend on the model, so build them only once.
candidate_inputs_coh = _candidate_inputs(filtered_gen_candidates_coh)
candidate_inputs_multi = _candidate_inputs(filtered_gen_candidates_multi)

for run in range(N_EXPERIMENT_REPEATS):
    print(f"\n{'='*20} STARTING EXPERIMENT RUN {run + 1}/{N_EXPERIMENT_REPEATS} {'='*20}")

    # --- 3a. Model training: a fresh forest per run, seeded by the run number ---
    rf = RandomForestClassifier(n_estimators=100, random_state=42 + run, oob_score=True)
    print(f"Run {run+1}: Training Random Forest model with random_state={42+run}...")
    rf.fit(train_data.values, train_labels)

    # --- 3b. Progressive ablation analysis ---
    print(f"Run {run+1}: Calculating progressive ablation scores...")
    print(f"Run {run+1}: Calculating selective ablation variances...")

    # One batched predict matches per-sample calls: RandomForestClassifier.predict
    # scores each row independently.
    ablated_predictions = rf.predict(ablated_test_data.values)
    prediction_variances_gen_coh = _disagreement_rates(rf, candidate_inputs_coh, ablated_predictions)
    prediction_variances_gen_multi = _disagreement_rates(rf, candidate_inputs_multi, ablated_predictions)

    # 'variance' holds the disagreement rate; the lowest values are ablated first.
    variance_df_gen_coh = pd.DataFrame({'variance': prediction_variances_gen_coh, 'original_index': test_data.index}).sort_values(by='variance', ascending=True)
    variance_df_gen_multi = pd.DataFrame({'variance': prediction_variances_gen_multi, 'original_index': test_data.index}).sort_values(by='variance', ascending=True)

    gen_selective_f1_scores_coh = _selective_f1_curve(rf, variance_df_gen_coh)
    gen_selective_f1_scores_multi = _selective_f1_curve(rf, variance_df_gen_multi)

    # Random baseline, seeded per run so the whole experiment is reproducible.
    rng = np.random.default_rng(42 + run)
    random_ablation_f1_scores = []
    for ratio in ABLATION_STEPS:
        n_to_ablate = int(len(test_data) * ratio)
        iter_scores = [
            _f1_after_ablating(rf, rng.choice(test_data.index, size=n_to_ablate, replace=False))
            for _ in range(N_ITERATIONS)
        ]
        random_ablation_f1_scores.append(np.mean(iter_scores))

    # --- Store the results of the current run ---
    all_runs_gen_coh.append(gen_selective_f1_scores_coh)
    all_runs_gen_multi.append(gen_selective_f1_scores_multi)
    all_runs_random.append(random_ablation_f1_scores)

print(f"\n{'='*20} ALL EXPERIMENT RUNS COMPLETED {'='*20}")

# =============================================================================
# 4. AGGREGATE RESULTS AND CALCULATE STATISTICS
# =============================================================================
print("\nAggregating results and calculating statistics...")


def _mean_and_std(runs):
    """Return the per-ratio mean and std across experiment runs (rows = runs).

    NOTE(review): np.std defaults to ddof=0 (population std); with only
    N_EXPERIMENT_REPEATS runs, ddof=1 (sample std) may be what is intended.
    """
    stacked = np.asarray(runs)
    return stacked.mean(axis=0), stacked.std(axis=0)


mean_gen_coh, std_gen_coh = _mean_and_std(all_runs_gen_coh)
mean_gen_multi, std_gen_multi = _mean_and_std(all_runs_gen_multi)
mean_random, std_random = _mean_and_std(all_runs_random)

# =============================================================================
# 5. PLOTTING RESULTS
# =============================================================================
print("Plotting aggregated results...")
plt.style.use('seaborn-v0_8-whitegrid')

_modality_label = ', '.join(ABLATED_MODALITIES).upper()


def _draw_mean_std(x, mean, std, marker, linestyle, label):
    """Draw a mean curve followed by its ±1 std band (default colours, same call order as before)."""
    plt.plot(x, mean, marker=marker, linestyle=linestyle, linewidth=2, label=label)
    plt.fill_between(x, mean - std, mean + std, alpha=0.2)


def _finish_axes(title, xlabel, ticks, xlim):
    """Apply title, labels, percentage ticks, x-limits, grid and legend, then show the figure."""
    plt.title(title, fontsize=16)
    plt.xlabel(xlabel, fontsize=12)
    plt.ylabel('Stage Prediction F1-Score', fontsize=12)
    plt.xticks(ticks, [f"{x:.0%}" for x in ticks], rotation=45)
    plt.xlim(*xlim)
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.legend(fontsize=11)
    plt.tight_layout()
    plt.show()


# --- Plot 1: Ablation ---
plt.figure(figsize=(14, 8))
for _mean, _std, _marker, _linestyle, _label in (
    (mean_random, std_random, '^', ':', 'Random Ablation'),
    (mean_gen_coh, std_gen_coh, 'o', '--', 'Selective Ablation (Coherent)'),
    (mean_gen_multi, std_gen_multi, 's', '-', 'Selective Ablation (Multi)'),
):
    _draw_mean_std(ABLATION_STEPS, _mean, _std, _marker, _linestyle, _label)
_finish_axes(
    f"Impact of Ablating {_modality_label} (Avg. of {N_EXPERIMENT_REPEATS} Runs)",
    'Ratio of Samples with Ablated Modalities',
    ABLATION_STEPS,
    (-0.02, 1.02),
)

# --- Plot 2: Inverted (Prioritization) ---
plt.figure(figsize=(14, 8))
inverted_steps = 1 - ABLATION_STEPS
for _mean, _std, _marker, _linestyle, _label in (
    (mean_random, std_random, '^', ':', 'Random Prioritization'),
    (mean_gen_coh, std_gen_coh, 'o', '--', 'Informed Prioritization (Coherent)'),
    (mean_gen_multi, std_gen_multi, 's', '-', 'Informed Prioritization (Multi)'),
):
    _draw_mean_std(inverted_steps, _mean, _std, _marker, _linestyle, _label)
# NOTE(review): x is already 1 - ABLATION_STEPS and the limits are descending,
# so the axis runs from 100% observed (left) to 0% (right). Confirm that this
# double inversion is intended.
_finish_axes(
    f"Impact of Observing {_modality_label} - Random Prioritization vs. Counterfactual Inference",
    'Ratio of Samples with Observed Modalities',
    inverted_steps,
    (1.02, -0.02),
)

In [None]:
# --- Plot 2: Inverted (Prioritization) ---
# Mean F1 per strategy with a ±1 std band, plotted against the fraction of
# samples whose modality is observed.
plt.figure(figsize=(14, 8))
inverted_steps = 1 - ABLATION_STEPS

# Keep the plot -> fill_between order per strategy so default colours match.
for _mean, _std, _marker, _linestyle, _label in (
    (mean_random, std_random, '^', ':', 'Random Prioritization'),
    (mean_gen_coh, std_gen_coh, 'o', '--', 'Informed Prioritization (Coherent)'),
    (mean_gen_multi, std_gen_multi, 's', '-', 'Informed Prioritization (Multi)'),
):
    plt.plot(inverted_steps, _mean, marker=_marker, linestyle=_linestyle, linewidth=2, label=_label)
    plt.fill_between(inverted_steps, _mean - _std, _mean + _std, alpha=0.2)

_modality_label = ', '.join(ABLATED_MODALITIES).upper()
plt.title(f"Impact of Observing {_modality_label} - Random Prioritization vs. Counterfactual Inference", fontsize=16)
plt.xlabel('Ratio of Samples with Observed Modalities', fontsize=12)
plt.ylabel('Stage Prediction F1-Score', fontsize=12)
_tick_labels = [f"{x:.0%}" for x in inverted_steps]
plt.xticks(inverted_steps, _tick_labels, rotation=45)
# NOTE(review): descending limits on already-inverted x values put 100% observed
# on the left; confirm this orientation is intended.
plt.xlim(1.02, -0.02)
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.legend(fontsize=11)
plt.tight_layout()
plt.show()

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional

def plot_prioritization_impact_polished(
    ablation_steps: np.ndarray,
    mean_random: np.ndarray,
    std_random: np.ndarray,
    mean_gen_coh: np.ndarray,
    std_gen_coh: np.ndarray,
    mean_gen_multi: np.ndarray,
    std_gen_multi: np.ndarray,
    ablated_modalities: List[str],
    color_map: Optional[Dict[str, str]] = None,
    legend_labels: Optional[Dict[str, str]] = None,
    figsize: Tuple[float, float] = (14, 8),
    title: Optional[str] = None,
    xlabel: str = 'Ratio of Samples with Observed Modalities',
    ylabel: str = 'Stage Prediction F1-Score',
    savepath: Optional[str] = None
):
    """
    Plot performance against the fraction of samples whose ablated modalities
    are observed, one line per prioritization strategy.

    Each strategy (random, informed/coherent, informed/multi) is drawn as a
    mean line with a shaded mean +/- std band. The x-axis shows
    ``1 - ablation_steps`` and is reversed, so 1.0 (all observed) is on the
    left. The 'seaborn-v0_8-white' style is applied only while this figure is
    built, saved and shown; the global matplotlib style is left unchanged.

    Args:
        ablation_steps: Fractions of samples ablated (e.g. 0.0 to 1.0); array-like.
        mean_random: Mean score per step for random prioritization.
        std_random: Standard deviation per step for random prioritization.
        mean_gen_coh: Mean score per step for informed (coherent) prioritization.
        std_gen_coh: Standard deviation per step for informed (coherent) prioritization.
        mean_gen_multi: Mean score per step for informed (multi) prioritization.
        std_gen_multi: Standard deviation per step for informed (multi) prioritization.
        ablated_modalities: Names of the ablated modalities (list or tuple);
            used in the default title.
        color_map: Optional colors keyed by 'random', 'coherent' and 'multi'.
        legend_labels: Optional legend labels with the same keys.
        figsize: Figure size in inches.
        title: Plot title; defaults to one naming the ablated modalities.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        savepath: Optional output path (saved at 300 dpi); missing parent
            directories are created.
    """
    from pathlib import Path

    # --- 1. AESTHETIC DEFAULTS & SETUP ---
    if color_map is None:
        color_map = {
            'random': '#9c2409',        # Muted Red
            'coherent': '#56b4e9',      # Sky Blue
            'multi': '#0072b2',         # Deeper Blue
        }

    if legend_labels is None:
        legend_labels = {
            'random': 'Random Prioritization',
            'coherent': 'Informed Prioritization (Coherent)',
            'multi': 'Informed Prioritization (Multi)'
        }

    if title is None:
        title = f"Impact of Observing {', '.join(ablated_modalities).upper()} - Random Prioritization vs. Counterfactual Inference"

    # --- 2. DATA PREPARATION ---
    # Accept plain lists as well as arrays (the band arithmetic needs ndarrays).
    ablation_steps = np.asarray(ablation_steps, dtype=float)
    series = [
        ('random', np.asarray(mean_random, dtype=float), np.asarray(std_random, dtype=float), 'o', '-'),
        ('coherent', np.asarray(mean_gen_coh, dtype=float), np.asarray(std_gen_coh, dtype=float), 's', '--'),
        ('multi', np.asarray(mean_gen_multi, dtype=float), np.asarray(std_gen_multi, dtype=float), '^', '--'),
    ]
    # The x-axis is the ratio of observed data, which is 1 minus the ablation ratio
    inverted_steps = 1 - ablation_steps

    # Scoped style: avoids permanently changing rcParams for every later plot.
    with plt.style.context('seaborn-v0_8-white'):
        fig, ax = plt.subplots(figsize=figsize)

        # --- 3. PLOTTING: mean line above a shaded +/- std band per strategy ---
        for key, mean, std, marker, linestyle in series:
            ax.plot(inverted_steps, mean, marker=marker, linestyle=linestyle, linewidth=2,
                    label=legend_labels[key], color=color_map[key], zorder=10)
            ax.fill_between(inverted_steps, mean - std, mean + std,
                            alpha=0.15, color=color_map[key], zorder=5)

        # --- 4. VISUAL REFINEMENTS ---
        # Light grid on both axes, kept beneath the data
        ax.yaxis.grid(True, linewidth=0.7, color="#CDCCCC", zorder=0)
        ax.xaxis.grid(True, linewidth=0.7, color="#CDCCCC", zorder=0)
        ax.set_axisbelow(True)

        # Less prominent axes and ticks
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['bottom'].set_color('#B0B0B0')
        ax.tick_params(axis='x', colors='#505050', length=0)
        ax.tick_params(axis='y', colors='#505050', length=0)
        ax.xaxis.label.set_color('#303030')
        ax.yaxis.label.set_color('#303030')

        # X-axis ticks every 0.1; linspace avoids float-step arange endpoint issues
        new_ticks = np.linspace(0.0, 1.0, 11)
        ax.set_xticks(new_ticks)
        ax.set_xticklabels([f"{x:.1f}" for x in new_ticks], rotation=0)
        ax.set_xlim(1.02, -0.02)  # Invert the axis from 100% down to 0%

        # Y-axis limits: finite data range (mean +/- std of all series) plus 5% padding
        lows = np.concatenate([mean - std for _, mean, std, _, _ in series])
        highs = np.concatenate([mean + std for _, mean, std, _, _ in series])
        lows, highs = lows[np.isfinite(lows)], highs[np.isfinite(highs)]
        if lows.size and highs.size:
            plot_min, plot_max = lows.min(), highs.max()
            span = plot_max - plot_min
            # A flat range would give identical limits; fall back to a fixed pad.
            padding = span * 0.05 if span > 0 else 0.05
            ax.set_ylim(plot_min - padding, plot_max + padding)

        # Final labels and title
        ax.set_xlabel(xlabel, fontsize=12, labelpad=15)
        ax.set_ylabel(ylabel, fontsize=12, labelpad=15)
        ax.set_title(title, fontsize=16, pad=20, weight='bold')

        # Legend on the right, outside the axes
        ax.legend(
            title='Prioritization Strategy',
            loc='center left',
            bbox_to_anchor=(1.02, 0.5),
            frameon=False,
            fontsize=11,
            title_fontsize=12
        )

        # --- 5. FINAL LAYOUT ADJUSTMENTS ---
        fig.subplots_adjust(
            left=0.07,
            right=0.82,  # Make space for the legend
            bottom=0.1,
            top=0.9
        )

        if savepath:
            Path(savepath).parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(savepath, dpi=300, bbox_inches='tight')

        plt.show()




# Draw the shaded-band prioritization plot for the current ABLATED_MODALITIES
# and save it under the task_07_counterfactual results folder.
plot_prioritization_impact_polished(
    ablated_modalities=ABLATED_MODALITIES,
    ablation_steps=ABLATION_STEPS,
    mean_random=mean_random, std_random=std_random,
    mean_gen_coh=mean_gen_coh, std_gen_coh=std_gen_coh,
    mean_gen_multi=mean_gen_multi, std_gen_multi=std_gen_multi,
    savepath=f"../../results/downstream/task_07_counterfactual/{'_'.join(ABLATED_MODALITIES)}_prioritization_impact_plot.png",
)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional

def plot_prioritization_impact_polished(
    ablation_steps: np.ndarray,
    mean_random: np.ndarray,
    std_random: np.ndarray,
    mean_gen_coh: np.ndarray,
    std_gen_coh: np.ndarray,
    mean_gen_multi: np.ndarray,
    std_gen_multi: np.ndarray,
    ablated_modalities: List[str],
    color_map: Optional[Dict[str, str]] = None,
    legend_labels: Optional[Dict[str, str]] = None,
    figsize: Tuple[float, float] = (14, 8),
    title: Optional[str] = None,
    xlabel: str = 'Ratio of Samples with Observed Modalities',
    ylabel: str = 'Stage Prediction F1-Score',
    savepath: Optional[str] = None
):
    """
    Plot performance against the fraction of samples whose ablated modalities
    are observed, with +/- one-std error bars, one line per strategy.

    Strategies are random, informed/coherent and informed/multi. The x-axis
    shows ``1 - ablation_steps`` and is reversed, so 1.0 (all observed) is on
    the left. The 'seaborn-v0_8-white' style is applied only while this figure
    is built, saved and shown; the global matplotlib style is left unchanged.

    Args:
        ablation_steps: Fractions of samples ablated (e.g. 0.0 to 1.0); array-like.
        mean_random: Mean score per step for random prioritization.
        std_random: Standard deviation per step for random prioritization.
        mean_gen_coh: Mean score per step for informed (coherent) prioritization.
        std_gen_coh: Standard deviation per step for informed (coherent) prioritization.
        mean_gen_multi: Mean score per step for informed (multi) prioritization.
        std_gen_multi: Standard deviation per step for informed (multi) prioritization.
        ablated_modalities: Names of the ablated modalities (list or tuple);
            used in the default title.
        color_map: Optional colors keyed by 'random', 'coherent' and 'multi'.
        legend_labels: Optional legend labels with the same keys.
        figsize: Figure size in inches.
        title: Plot title; defaults to one naming the ablated modalities.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        savepath: Optional output path (saved at 300 dpi); missing parent
            directories are created.
    """
    from pathlib import Path

    # --- 1. AESTHETIC DEFAULTS & SETUP ---
    if color_map is None:
        color_map = {
            'random': '#9c2409',        # Muted Red
            'coherent': '#56b4e9',      # Sky Blue
            'multi': '#0072b2',         # Deeper Blue
        }

    if legend_labels is None:
        legend_labels = {
            'random': 'Random Prioritization',
            'coherent': 'Informed Prioritization (Coherent)',
            'multi': 'Informed Prioritization (Multi)'
        }

    if title is None:
        title = f"Impact of Observing {', '.join(ablated_modalities).upper()} - Random Prioritization vs. Counterfactual Inference"

    # --- 2. DATA PREPARATION ---
    # Accept plain lists as well as arrays (the limit arithmetic needs ndarrays).
    ablation_steps = np.asarray(ablation_steps, dtype=float)
    series = [
        ('random', np.asarray(mean_random, dtype=float), np.asarray(std_random, dtype=float), 'o', '-'),
        ('coherent', np.asarray(mean_gen_coh, dtype=float), np.asarray(std_gen_coh, dtype=float), 's', '--'),
        ('multi', np.asarray(mean_gen_multi, dtype=float), np.asarray(std_gen_multi, dtype=float), '^', '--'),
    ]
    # The x-axis is the ratio of observed data, which is 1 minus the ablation ratio
    inverted_steps = 1 - ablation_steps

    # Scoped style: avoids permanently changing rcParams for every later plot.
    with plt.style.context('seaborn-v0_8-white'):
        fig, ax = plt.subplots(figsize=figsize)

        # --- 3. PLOTTING: mean line with +/- std error bars per strategy ---
        for key, mean, std, marker, linestyle in series:
            ax.errorbar(inverted_steps, mean, yerr=std, marker=marker, linestyle=linestyle,
                        linewidth=2, label=legend_labels[key], color=color_map[key],
                        zorder=10, capsize=5)

        # --- 4. VISUAL REFINEMENTS ---
        # Light grid on both axes, kept beneath the data
        ax.yaxis.grid(True, linewidth=0.7, color="#CDCCCC", zorder=0)
        ax.xaxis.grid(True, linewidth=0.7, color="#CDCCCC", zorder=0)
        ax.set_axisbelow(True)

        # Less prominent axes and ticks
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['bottom'].set_color('#B0B0B0')
        ax.tick_params(axis='x', colors='#505050', length=0)
        ax.tick_params(axis='y', colors='#505050', length=0)
        ax.xaxis.label.set_color('#303030')
        ax.yaxis.label.set_color('#303030')

        # X-axis ticks every 0.1; linspace avoids float-step arange endpoint issues
        new_ticks = np.linspace(0.0, 1.0, 11)
        ax.set_xticks(new_ticks)
        ax.set_xticklabels([f"{x:.1f}" for x in new_ticks], rotation=0)
        ax.set_xlim(1.02, -0.02)  # Invert the axis from 100% down to 0%

        # Y-axis limits: finite data range (mean +/- std of all series) plus 5% padding
        lows = np.concatenate([mean - std for _, mean, std, _, _ in series])
        highs = np.concatenate([mean + std for _, mean, std, _, _ in series])
        lows, highs = lows[np.isfinite(lows)], highs[np.isfinite(highs)]
        if lows.size and highs.size:
            plot_min, plot_max = lows.min(), highs.max()
            span = plot_max - plot_min
            # A flat range would give identical limits; fall back to a fixed pad.
            padding = span * 0.05 if span > 0 else 0.05
            ax.set_ylim(plot_min - padding, plot_max + padding)

        # Final labels and title
        ax.set_xlabel(xlabel, fontsize=12, labelpad=15)
        ax.set_ylabel(ylabel, fontsize=12, labelpad=15)
        ax.set_title(title, fontsize=16, pad=20, weight='bold')

        # Legend on the right, outside the axes
        ax.legend(
            title='Prioritization Strategy',
            loc='center left',
            bbox_to_anchor=(1.02, 0.5),
            frameon=False,
            fontsize=11,
            title_fontsize=12
        )

        # --- 5. FINAL LAYOUT ADJUSTMENTS ---
        fig.subplots_adjust(
            left=0.07,
            right=0.82,  # Make space for the legend
            bottom=0.1,
            top=0.9
        )

        if savepath:
            Path(savepath).parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(savepath, dpi=300, bbox_inches='tight')

        plt.show()




# Draw the error-bar prioritization plot for the current ABLATED_MODALITIES
# and save it (with an "_err" suffix) under the task_07_counterfactual folder.
plot_prioritization_impact_polished(
    ablated_modalities=ABLATED_MODALITIES,
    ablation_steps=ABLATION_STEPS,
    mean_random=mean_random, std_random=std_random,
    mean_gen_coh=mean_gen_coh, std_gen_coh=std_gen_coh,
    mean_gen_multi=mean_gen_multi, std_gen_multi=std_gen_multi,
    savepath=f"../../results/downstream/task_07_counterfactual/{'_'.join(ABLATED_MODALITIES)}_prioritization_impact_plot_err.png",
)

In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import seaborn as sns

# =============================================================================
# 1. CONFIGURATION
# =============================================================================
# Define all modalities available in the dataset
ALL_MODALITIES = ['cna', 'rna', 'rppa', 'wsi']

# --- Main Setting ---
# Modalities treated as "missing"/"ablated" in the analysis. The tuple must be a
# key of both GENERATED_DATA_MAPPING_* dictionaries below.
ABLATED_MODALITIES = ('wsi',)

# Mapping from the ablated-modalities tuple to the file of generated samples.
gen_path = '../../results/32/'

GENERATED_DATA_MAPPING_coh = {
    ('rna',): f'{gen_path}rnaseq_from_coherent/test/generated_samples_from_cna_rppa_wsi_best_mse.csv',
    ('wsi',): f'{gen_path}wsi_from_coherent/test/generated_samples_from_cna_rnaseq_rppa_best_mse.csv',
}

GENERATED_DATA_MAPPING_multi = {
    ('rna',): f'{gen_path}rnaseq_from_multi/test/generated_samples_from_cna_rppa_wsi_best_mse.csv',
    ('wsi',): f'{gen_path}wsi_from_multi/test/generated_samples_from_cna_rnaseq_rppa_best_mse.csv',
}

# --- Experiment Parameters ---
# Number of full experiment repeats (each retrains the classifier with a new seed);
# the reported standard deviation is taken across these repeats.
N_EXPERIMENT_REPEATS = 10
# Number of random draws for the random-ablation baseline (within each repeat)
N_ITERATIONS = 30
# Fractions of samples ablated: 0% to 100% in 5% steps (21 points). linspace is
# used because arange with a float step does not guarantee the endpoint.
ABLATION_STEPS = np.linspace(0.0, 1.0, 21)
# Number of generated candidates per test sample: the generated files hold
# K_NEIGHBORS rows per test sample (name kept for compatibility).
K_NEIGHBORS = 10


# =============================================================================
# 2. DATA LOADING AND PREPARATION
# =============================================================================
print("Loading datasets...")

# Raw per-modality training tables
train_path = '../../datasets_TCGA/07_normalized/32/'
train_raw = {
    'cna': pd.read_csv(f'{train_path}cna_train.csv', sep=','),
    'rna': pd.read_csv(f'{train_path}rnaseq_train.csv', sep=','),
    'rppa': pd.read_csv(f'{train_path}rppa_train.csv', sep=','),
    'wsi': pd.read_csv(f'{train_path}wsi_train.csv', sep=','),
}

# Raw per-modality test tables, read from test_path so that setting actually applies
test_path = '../../datasets_TCGA/07_normalized/32/'
test_raw = {
    'cna': pd.read_csv(f'{test_path}cna_test.csv', sep=','),
    'rna': pd.read_csv(f'{test_path}rnaseq_test.csv', sep=','),
    'rppa': pd.read_csv(f'{test_path}rppa_test.csv', sep=','),
    'wsi': pd.read_csv(f'{test_path}wsi_test.csv', sep=','),
}

# Stage labels (a 'stage' column, possibly with missing values)
labels_path = '../../datasets_TCGA/downstream_labels/'
train_stage = pd.read_csv(f'{labels_path}train_stage.csv')
test_stage = pd.read_csv(f'{labels_path}test_stage.csv')


def clean_and_unify_tables(df_dict, modalities):
    """
    Concatenate per-modality feature tables column-wise into one dataframe.

    Every table must have a 'sample_id' column listing the same samples in the
    same order. That column is dropped, and the remaining feature columns of
    modality ``m`` are renamed ``m_1, m_2, ...``. Rows are combined by position,
    matching the positional sample-ID check.

    Args:
        df_dict: Mapping from modality name to its dataframe.
        modalities: Modalities to include, in output column order.

    Returns:
        A dataframe (fresh 0..n-1 index) with the renamed feature columns of all
        requested modalities.

    Raises:
        ValueError: If ``modalities`` is empty, or if any table's sample IDs
            differ from the first table's (in value, order or count).
    """
    if len(modalities) == 0:
        raise ValueError("At least one modality is required.")

    sample_ids = df_dict[modalities[0]]['sample_id'].to_numpy()
    for m in modalities:
        # array_equal also reports length mismatches as "not equal" (== would raise a
        # broadcasting error); an explicit raise survives `python -O`, unlike assert.
        if not np.array_equal(df_dict[m]['sample_id'].to_numpy(), sample_ids):
            raise ValueError(f"Sample IDs for {m} do not match.")

    processed_dfs = []
    for m in modalities:
        features = df_dict[m].drop(columns=['sample_id']).reset_index(drop=True)
        features.columns = [f'{m}_{i}' for i in range(1, features.shape[1] + 1)]
        processed_dfs.append(features)

    return pd.concat(processed_dfs, axis=1)

print("Processing and unifying tables...")
train_full = clean_and_unify_tables(train_raw, ALL_MODALITIES)
test_full = clean_and_unify_tables(test_raw, ALL_MODALITIES)

print(f"Train data shape: {train_full.shape}")
print(f"Test data shape: {test_full.shape}")

# Keep only samples with a stage label and encode the labels (done once, before the repeats).
# NOTE(review): the label tables are assumed to be row-aligned with the feature tables;
# no sample-ID join is performed here.
train_mask = ~train_stage['stage'].isna()
train_data = train_full[train_mask]
train_labels_filtered = train_stage['stage'][train_mask]

test_mask = ~test_stage['stage'].isna()
test_data_full = test_full[test_mask].reset_index(drop=True)
test_labels_filtered = test_stage['stage'][test_mask].reset_index(drop=True)

le = LabelEncoder()
train_labels = le.fit_transform(train_labels_filtered)
test_labels = le.transform(test_labels_filtered)

# --- Generated (counterfactual) samples for the ablated modalities ---
generated_file_key = ABLATED_MODALITIES
if (generated_file_key not in GENERATED_DATA_MAPPING_coh
        or generated_file_key not in GENERATED_DATA_MAPPING_multi):
    raise ValueError(f"No generated data file specified for the modality combination: {generated_file_key}")

generated_file_name_coh = GENERATED_DATA_MAPPING_coh[generated_file_key]
generated_file_name_multi = GENERATED_DATA_MAPPING_multi[generated_file_key]

try:
    gen_data_long_coh = {mod: pd.read_csv(generated_file_name_coh, sep=',') for mod in ABLATED_MODALITIES}
    gen_data_long_multi = {mod: pd.read_csv(generated_file_name_multi, sep=',') for mod in ABLATED_MODALITIES}
except FileNotFoundError as err:
    # Raise rather than print + exit(): the run stops and the traceback names the missing file.
    raise FileNotFoundError(f"Error: A generated data file was not found ({err}). Please check paths.") from err


def _counterfactual_inputs(base_values, columns, gen_data_long, gather_idx, k):
    """Repeat every test row k times and overwrite its ablated-modality columns
    with that sample's k generated candidates.

    Row ``i * k + j`` of the result is test sample ``i`` completed with candidate ``j``.
    """
    stacked = np.repeat(base_values, k, axis=0)
    for mod, generated in gen_data_long.items():
        mod_positions = [j for j, col in enumerate(columns) if col.startswith(f'{mod}_')]
        candidates = generated.iloc[gather_idx].to_numpy(dtype=float)
        # numpy would silently broadcast a single generated column; require an exact match.
        if candidates.shape[1] != len(mod_positions):
            raise ValueError(
                f"Generated data for {mod} has {candidates.shape[1]} columns; "
                f"expected {len(mod_positions)}."
            )
        stacked[:, mod_positions] = candidates
    return stacked


def _prediction_disagreement(model, cf_inputs, ablated_preds, k):
    """Per sample, the fraction of its k counterfactual predictions that differ
    from the prediction made with the ablated modalities set to NaN."""
    cf_preds = model.predict(cf_inputs).reshape(-1, k)
    return (cf_preds != ablated_preds[:, None]).mean(axis=1)


def _weighted_f1_with_rows_ablated(model, values, rows, cols, y_true):
    """Weighted F1 of `model` after setting columns `cols` of row positions `rows` to NaN."""
    masked = values.copy()
    if len(rows) > 0:
        masked[np.ix_(rows, cols)] = np.nan
    return f1_score(y_true, model.predict(masked), average='weighted', zero_division=0)


# --- Loop-invariant inputs (identical for every repeat, so built once) ---
test_data = test_data_full.copy()
ablated_cols = [c for c in test_data.columns if c.split('_')[0] in ABLATED_MODALITIES]
ablated_positions = np.array([test_data.columns.get_loc(c) for c in ablated_cols], dtype=int)
test_values = test_data.to_numpy(dtype=float)
n_eval = len(test_values)

# Generated files are "long": candidate k of test sample i (position i in the FULL,
# unfiltered test set) is stored at row k * n_test_samples + i.
n_test_samples = len(test_full)
original_indices = np.flatnonzero(test_mask.to_numpy())
# Sample-major gather order: K_NEIGHBORS consecutive rows per labelled test sample.
gather_indices = (np.arange(K_NEIGHBORS)[None, :] * n_test_samples + original_indices[:, None]).ravel()

ablated_inputs = test_values.copy()
ablated_inputs[:, ablated_positions] = np.nan
cf_inputs_coh = _counterfactual_inputs(test_values, test_data.columns, gen_data_long_coh, gather_indices, K_NEIGHBORS)
cf_inputs_multi = _counterfactual_inputs(test_values, test_data.columns, gen_data_long_multi, gather_indices, K_NEIGHBORS)

# =============================================================================
# 3. FULL EXPERIMENT LOOP
# =============================================================================
# Per-repeat F1 curves (one list of len(ABLATION_STEPS) scores per repeat)
all_runs_gen_coh = []
all_runs_gen_multi = []
all_runs_random = []

for run in range(N_EXPERIMENT_REPEATS):
    seed = 42 + run
    print(f"\n{'='*20} STARTING EXPERIMENT RUN {run + 1}/{N_EXPERIMENT_REPEATS} {'='*20}")

    # --- 3a. A new model per repeat, differing only in its random_state ---
    rf = RandomForestClassifier(n_estimators=100, random_state=seed, oob_score=True)
    print(f"Run {run+1}: Training Random Forest model with random_state={seed}...")
    rf.fit(train_data.values, train_labels)

    # --- 3b. Counterfactual disagreement depends on this run's model ---
    # Batched: one predict call per input matrix instead of one per test sample.
    print(f"Run {run+1}: Calculating selective ablation variances...")
    ablated_predictions = rf.predict(ablated_inputs)
    prediction_variances_gen_coh = _prediction_disagreement(rf, cf_inputs_coh, ablated_predictions, K_NEIGHBORS)
    prediction_variances_gen_multi = _prediction_disagreement(rf, cf_inputs_multi, ablated_predictions, K_NEIGHBORS)

    # Lowest-disagreement samples are ablated first; a stable sort breaks ties by
    # sample order so the ranking does not depend on the sort implementation.
    variance_df_gen_coh = pd.DataFrame({'variance': prediction_variances_gen_coh, 'original_index': test_data.index}).sort_values(by='variance', ascending=True, kind='stable')
    variance_df_gen_multi = pd.DataFrame({'variance': prediction_variances_gen_multi, 'original_index': test_data.index}).sort_values(by='variance', ascending=True, kind='stable')
    order_coh = variance_df_gen_coh['original_index'].to_numpy()
    order_multi = variance_df_gen_multi['original_index'].to_numpy()

    # --- 3c. F1 at each ablation ratio for all strategies ---
    print(f"Run {run+1}: Calculating progressive ablation scores...")
    rng = np.random.default_rng(seed)  # seeded so the random baseline is reproducible
    gen_selective_f1_scores_coh, gen_selective_f1_scores_multi, random_ablation_f1_scores = [], [], []

    for ratio in ABLATION_STEPS:
        n_to_ablate = int(n_eval * ratio)

        # Selective ablation (Coherent / Multi)
        gen_selective_f1_scores_coh.append(
            _weighted_f1_with_rows_ablated(rf, test_values, order_coh[:n_to_ablate], ablated_positions, test_labels))
        gen_selective_f1_scores_multi.append(
            _weighted_f1_with_rows_ablated(rf, test_values, order_multi[:n_to_ablate], ablated_positions, test_labels))

        # Random ablation, averaged over N_ITERATIONS draws
        iter_scores = [
            _weighted_f1_with_rows_ablated(
                rf, test_values, rng.choice(n_eval, size=n_to_ablate, replace=False),
                ablated_positions, test_labels)
            for _ in range(N_ITERATIONS)
        ]
        random_ablation_f1_scores.append(np.mean(iter_scores))

    # --- Store the results of the current run ---
    all_runs_gen_coh.append(gen_selective_f1_scores_coh)
    all_runs_gen_multi.append(gen_selective_f1_scores_multi)
    all_runs_random.append(random_ablation_f1_scores)

print(f"\n{'='*20} ALL EXPERIMENT RUNS COMPLETED {'='*20}")

# =============================================================================
# 4. AGGREGATE RESULTS AND CALCULATE STATISTICS
# =============================================================================
print("\nAggregating results and calculating statistics...")

# Mean and population (ddof=0) standard deviation across the experiment repeats (axis=0)
mean_gen_coh = np.mean(all_runs_gen_coh, axis=0)
std_gen_coh = np.std(all_runs_gen_coh, axis=0)

mean_gen_multi = np.mean(all_runs_gen_multi, axis=0)
std_gen_multi = np.std(all_runs_gen_multi, axis=0)

mean_random = np.mean(all_runs_random, axis=0)
std_random = np.std(all_runs_random, axis=0)



In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Optional

def plot_prioritization_impact_polished(
    ablation_steps: np.ndarray,
    mean_random: np.ndarray,
    std_random: np.ndarray,
    mean_gen_coh: np.ndarray,
    std_gen_coh: np.ndarray,
    mean_gen_multi: np.ndarray,
    std_gen_multi: np.ndarray,
    ablated_modalities: List[str],
    color_map: Optional[Dict[str, str]] = None,
    legend_labels: Optional[Dict[str, str]] = None,
    figsize: Tuple[float, float] = (14, 8),
    title: Optional[str] = None,
    xlabel: str = 'Ratio of Samples with Observed Modalities',
    ylabel: str = 'Stage Prediction F1-Score',
    savepath: Optional[str] = None
):
    """
    Plot performance against the fraction of samples whose ablated modalities
    are observed, one line per prioritization strategy.

    Each strategy (random, informed/coherent, informed/multi) is drawn as a
    mean line with a shaded mean +/- std band. The x-axis shows
    ``1 - ablation_steps`` and is reversed, so 1.0 (all observed) is on the
    left. The 'seaborn-v0_8-white' style is applied only while this figure is
    built, saved and shown; the global matplotlib style is left unchanged.

    Args:
        ablation_steps: Fractions of samples ablated (e.g. 0.0 to 1.0); array-like.
        mean_random: Mean score per step for random prioritization.
        std_random: Standard deviation per step for random prioritization.
        mean_gen_coh: Mean score per step for informed (coherent) prioritization.
        std_gen_coh: Standard deviation per step for informed (coherent) prioritization.
        mean_gen_multi: Mean score per step for informed (multi) prioritization.
        std_gen_multi: Standard deviation per step for informed (multi) prioritization.
        ablated_modalities: Names of the ablated modalities (list or tuple);
            used in the default title.
        color_map: Optional colors keyed by 'random', 'coherent' and 'multi'.
        legend_labels: Optional legend labels with the same keys.
        figsize: Figure size in inches.
        title: Plot title; defaults to one naming the ablated modalities.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        savepath: Optional output path (saved at 300 dpi); missing parent
            directories are created.
    """
    from pathlib import Path

    # --- 1. AESTHETIC DEFAULTS & SETUP ---
    if color_map is None:
        color_map = {
            'random': '#9c2409',        # Muted Red
            'coherent': '#56b4e9',      # Sky Blue
            'multi': '#0072b2',         # Deeper Blue
        }

    if legend_labels is None:
        legend_labels = {
            'random': 'Random Prioritization',
            'coherent': 'Informed Prioritization (Coherent)',
            'multi': 'Informed Prioritization (Multi)'
        }

    if title is None:
        title = f"Impact of Observing {', '.join(ablated_modalities).upper()} - Random Prioritization vs. Counterfactual Inference"

    # --- 2. DATA PREPARATION ---
    # Accept plain lists as well as arrays (the band arithmetic needs ndarrays).
    ablation_steps = np.asarray(ablation_steps, dtype=float)
    series = [
        ('random', np.asarray(mean_random, dtype=float), np.asarray(std_random, dtype=float), 'o', '-'),
        ('coherent', np.asarray(mean_gen_coh, dtype=float), np.asarray(std_gen_coh, dtype=float), 's', '--'),
        ('multi', np.asarray(mean_gen_multi, dtype=float), np.asarray(std_gen_multi, dtype=float), '^', '--'),
    ]
    # The x-axis is the ratio of observed data, which is 1 minus the ablation ratio
    inverted_steps = 1 - ablation_steps

    # Scoped style: avoids permanently changing rcParams for every later plot.
    with plt.style.context('seaborn-v0_8-white'):
        fig, ax = plt.subplots(figsize=figsize)

        # --- 3. PLOTTING: mean line above a shaded +/- std band per strategy ---
        for key, mean, std, marker, linestyle in series:
            ax.plot(inverted_steps, mean, marker=marker, linestyle=linestyle, linewidth=2,
                    label=legend_labels[key], color=color_map[key], zorder=10)
            ax.fill_between(inverted_steps, mean - std, mean + std,
                            alpha=0.15, color=color_map[key], zorder=5)

        # --- 4. VISUAL REFINEMENTS ---
        # Light grid on both axes, kept beneath the data
        ax.yaxis.grid(True, linewidth=0.7, color="#CDCCCC", zorder=0)
        ax.xaxis.grid(True, linewidth=0.7, color="#CDCCCC", zorder=0)
        ax.set_axisbelow(True)

        # Less prominent axes and ticks
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['bottom'].set_color('#B0B0B0')
        ax.tick_params(axis='x', colors='#505050', length=0)
        ax.tick_params(axis='y', colors='#505050', length=0)
        ax.xaxis.label.set_color('#303030')
        ax.yaxis.label.set_color('#303030')

        # X-axis ticks every 0.1; linspace avoids float-step arange endpoint issues
        new_ticks = np.linspace(0.0, 1.0, 11)
        ax.set_xticks(new_ticks)
        ax.set_xticklabels([f"{x:.1f}" for x in new_ticks], rotation=0)
        ax.set_xlim(1.02, -0.02)  # Invert the axis from 100% down to 0%

        # Y-axis limits: finite data range (mean +/- std of all series) plus 5% padding
        lows = np.concatenate([mean - std for _, mean, std, _, _ in series])
        highs = np.concatenate([mean + std for _, mean, std, _, _ in series])
        lows, highs = lows[np.isfinite(lows)], highs[np.isfinite(highs)]
        if lows.size and highs.size:
            plot_min, plot_max = lows.min(), highs.max()
            span = plot_max - plot_min
            # A flat range would give identical limits; fall back to a fixed pad.
            padding = span * 0.05 if span > 0 else 0.05
            ax.set_ylim(plot_min - padding, plot_max + padding)

        # Final labels and title
        ax.set_xlabel(xlabel, fontsize=12, labelpad=15)
        ax.set_ylabel(ylabel, fontsize=12, labelpad=15)
        ax.set_title(title, fontsize=16, pad=20, weight='bold')

        # Legend on the right, outside the axes
        ax.legend(
            title='Prioritization Strategy',
            loc='center left',
            bbox_to_anchor=(1.02, 0.5),
            frameon=False,
            fontsize=11,
            title_fontsize=12
        )

        # --- 5. FINAL LAYOUT ADJUSTMENTS ---
        fig.subplots_adjust(
            left=0.07,
            right=0.82,  # Make space for the legend
            bottom=0.1,
            top=0.9
        )

        if savepath:
            Path(savepath).parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(savepath, dpi=300, bbox_inches='tight')

        plt.show()




# Draw the shaded-band prioritization plot for the current ABLATED_MODALITIES
# and save it under the task_07_counterfactual results folder.
plot_prioritization_impact_polished(
    ablated_modalities=ABLATED_MODALITIES,
    ablation_steps=ABLATION_STEPS,
    mean_random=mean_random, std_random=std_random,
    mean_gen_coh=mean_gen_coh, std_gen_coh=std_gen_coh,
    mean_gen_multi=mean_gen_multi, std_gen_multi=std_gen_multi,
    savepath=f"../../results/downstream/task_07_counterfactual/{'_'.join(ABLATED_MODALITIES)}_prioritization_impact_plot.png",
)

In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple, Optional
import json

# =============================================================================
# 1. CONFIGURATION
# =============================================================================
# Every modality in the dataset; also the column prefix used when the tables are unified.
ALL_MODALITIES = ['cna', 'rna', 'rppa', 'wsi']

# --- Experiment Parameters ---
# Full repeats of the experiment (a freshly trained forest each time), used for the std bands.
N_EXPERIMENT_REPEATS = 10
# Random-ablation draws averaged inside every repeat.
N_ITERATIONS = 30
# Ablation ratios swept by the progressive plot: 0.0, 0.05, ..., 1.0.
ABLATION_STEPS = np.arange(0, 1.05, 0.05)
# Generated candidates per test sample (the k of the k-NN style baseline).
K_NEIGHBORS = 10

# --- File Paths ---
gen_path = '../../results/32/'
# Train and test splits share the same normalized-data directory.
train_path = test_path = '../../datasets_TCGA/07_normalized/32/'
labels_path = '../../datasets_TCGA/downstream_labels/'

# Generated-sample CSV for each ablated-modality tuple, one mapping per generator
# ("coherent" and "multi"); only the generator directory differs between the two.
GENERATED_DATA_MAPPING_coh, GENERATED_DATA_MAPPING_multi = (
    {
        ('rna',): f'{gen_path}rnaseq_from_{generator}/test/generated_samples_from_cna_rppa_wsi_best_mse.csv',
        ('wsi',): f'{gen_path}wsi_from_{generator}/test/generated_samples_from_cna_rnaseq_rppa_best_mse.csv',
    }
    for generator in ('coherent', 'multi')
)


# =============================================================================
# 2. HELPER FUNCTIONS
# =============================================================================
def clean_and_unify_tables(df_dict, modalities):
    """
    Unify modality-specific dataframes into a single wide dataframe.

    Every table must carry a ``sample_id`` column listing the same samples in the same
    order. ``sample_id`` is dropped and the remaining columns are renamed
    ``<modality>_1 .. <modality>_n`` (keeping their order) before the tables are
    concatenated side by side in the order given by ``modalities``.

    Args:
        df_dict: Modality name -> dataframe with a ``sample_id`` column.
        modalities: Modality names to include, in output column order.

    Returns:
        A new dataframe; the input tables are not modified.

    Raises:
        ValueError: If ``modalities`` is empty, or if any table's sample IDs differ
            (in length, content or order) from those of the first modality.
    """
    if len(modalities) == 0:
        raise ValueError("At least one modality is required.")

    sample_ids = df_dict[modalities[0]]['sample_id'].values
    for m in modalities:
        # array_equal also compares lengths (``==`` would broadcast a length-1 column and
        # silently "match"); raising instead of asserting keeps the check under ``python -O``.
        if not np.array_equal(df_dict[m]['sample_id'].values, sample_ids):
            raise ValueError(f"Sample IDs for {m} do not match.")

    processed_dfs = []
    for m in modalities:
        df = df_dict[m].drop(columns=['sample_id'])
        df.columns = [f'{m}_{i + 1}' for i in range(df.shape[1])]
        processed_dfs.append(df)

    return pd.concat(processed_dfs, axis=1)

# =============================================================================
# 3. CORE EXPERIMENT FUNCTION
# =============================================================================
def _load_stage_dataset():
    """Load every modality plus the stage labels, keeping only stage-labelled samples.

    Returns:
        ``(train_data, train_labels, test_data, test_labels, test_mask, n_test_total)``:
        the unified feature tables (``test_data`` re-indexed 0..n-1), the label-encoded
        stages, the boolean mask of labelled rows in the raw test split, and the number
        of rows in the raw (unfiltered) test split.
    """
    print("Loading datasets...")
    # Modality key -> file stem on disk (the RNA files are named "rnaseq").
    file_stems = {'cna': 'cna', 'rna': 'rnaseq', 'rppa': 'rppa', 'wsi': 'wsi'}
    train_raw = {m: pd.read_csv(f'{train_path}{stem}_train.csv', sep=',') for m, stem in file_stems.items()}
    test_raw = {m: pd.read_csv(f'{test_path}{stem}_test.csv', sep=',') for m, stem in file_stems.items()}
    train_stage = pd.read_csv(f'{labels_path}train_stage.csv')
    test_stage = pd.read_csv(f'{labels_path}test_stage.csv')

    print("Processing and unifying tables...")
    train_full = clean_and_unify_tables(train_raw, ALL_MODALITIES)
    test_full = clean_and_unify_tables(test_raw, ALL_MODALITIES)

    # Samples without a stage label cannot be scored, so they are dropped.
    train_mask = ~train_stage['stage'].isna()
    test_mask = ~test_stage['stage'].isna()
    train_data = train_full[train_mask]
    test_data = test_full[test_mask].reset_index(drop=True)

    le = LabelEncoder()
    train_labels = le.fit_transform(train_stage['stage'][train_mask])
    test_labels = le.transform(test_stage['stage'][test_mask].reset_index(drop=True))
    return train_data, train_labels, test_data, test_labels, test_mask, len(test_full)


def _ablate(frame, row_labels, cols):
    """Return a copy of ``frame`` with ``cols`` set to 0 on the rows labelled ``row_labels``.

    0 stands in for a missing modality (NaN would trigger sklearn warnings/errors).
    """
    ablated = frame.copy()
    if len(row_labels) > 0:
        ablated.loc[row_labels, cols] = 0
    return ablated


def _disagreement_rates(model, test_data, candidates_by_mod, reference_preds, k):
    """Per test sample, the fraction of its ``k`` generated completions whose prediction
    differs from ``reference_preds``.

    All ``len(test_data) * k`` completions are scored with a single ``predict`` call: each
    test row is repeated ``k`` times and, for every modality in ``candidates_by_mod``, that
    modality's columns are overwritten with the generated candidate values.

    Args:
        model: Fitted classifier exposing ``predict``.
        test_data: Test features, one row per sample.
        candidates_by_mod: Modality -> list aligned with the rows of ``test_data``; each
            entry is a ``k``-row table of generated values for that modality's columns.
        reference_preds: One prediction per test row (here: with the ablated modalities zeroed).
        k: Number of generated candidates per sample.

    Returns:
        1-D float array of length ``len(test_data)``.
    """
    n_samples = len(test_data)
    if n_samples == 0:
        return np.zeros(0)
    # Row s*k + j of ``completions`` is sample s completed with its j-th candidate.
    completions = pd.DataFrame(np.repeat(test_data.values, k, axis=0), columns=test_data.columns)
    for mod, candidates in candidates_by_mod.items():
        mod_cols = [c for c in test_data.columns if c.startswith(f'{mod}_')]
        completions[mod_cols] = np.concatenate([cand.values for cand in candidates], axis=0)
    preds = np.asarray(model.predict(completions.values)).reshape(n_samples, k)
    return (preds != np.asarray(reference_preds)[:, None]).mean(axis=1)


def _least_disagreement_first(rates, row_labels):
    """Row labels ordered by ascending disagreement rate (same pandas sort as before, so ties keep their order)."""
    ranking = pd.DataFrame({'variance': rates, 'original_index': row_labels}).sort_values(by='variance', ascending=True)
    return ranking['original_index'].to_numpy()


def _weighted_f1(model, frame, labels):
    """Weighted F1-score of ``model`` on ``frame``; ill-defined precision/recall counts as 0."""
    return f1_score(labels, model.predict(frame.values), average='weighted', zero_division=0)


def run_ablation_experiment(ablated_modalities: Tuple[str, ...]):
    """
    Run the full progressive-ablation experiment for one set of ablated modalities.

    In each of ``N_EXPERIMENT_REPEATS`` repeats a fresh random forest (``random_state=42 + run``)
    is trained on all modalities. Each test sample is scored by how often its ``K_NEIGHBORS``
    generated completions (coherent / multi generator) change the prediction obtained with
    the ablated modalities zeroed. For every ratio in ``ABLATION_STEPS`` the ablated
    modalities are zeroed on that fraction of test samples. The informed strategies zero the
    lowest-disagreement samples first. The baseline picks samples uniformly at random
    (``N_ITERATIONS`` draws, averaged, RNG seeded per repeat). The weighted F1-score is
    recorded for each strategy.

    Args:
        ablated_modalities: A tuple containing the names of modalities to treat as missing.

    Returns:
        A dict with ``mean_random``/``std_random``, ``mean_gen_coh``/``std_gen_coh`` and
        ``mean_gen_multi``/``std_gen_multi`` (one value per ablation step, computed across
        repeats), or ``None`` if a generated-data file is missing.

    Raises:
        ValueError: If either generator mapping has no file for ``ablated_modalities``.
    """
    print(f"\n{'='*30}\nRUNNING EXPERIMENT FOR ABLATED MODALITY: {str(ablated_modalities).upper()}\n{'='*30}")

    # Check both mappings up front. A missing "multi" entry would otherwise surface as a
    # bare KeyError, and only after all data had been loaded.
    for mapping in (GENERATED_DATA_MAPPING_coh, GENERATED_DATA_MAPPING_multi):
        if ablated_modalities not in mapping:
            raise ValueError(f"No generated data file specified for the modality combination: {ablated_modalities}")

    # --- Data loading and preparation ---
    train_data, train_labels, test_data, test_labels, test_mask, n_test_total = _load_stage_dataset()

    try:
        gen_coh_df = pd.read_csv(GENERATED_DATA_MAPPING_coh[ablated_modalities], sep=',')
        gen_multi_df = pd.read_csv(GENERATED_DATA_MAPPING_multi[ablated_modalities], sep=',')
    except FileNotFoundError as e:
        print(f"Error: A generated data file was not found: {e}. Please check paths.")
        return None

    # Row j * n_test_total + i of a generated file is candidate j for raw test sample i.
    # Keep only the labelled samples, in the same order as ``test_data``. The stage CSV is
    # read without an index column, so ``test_mask`` labels are raw row positions. This is
    # independent of the trained model, so it is done once, not in every repeat.
    # NOTE(review): with several ablated modalities every modality receives the whole
    # generated table; it presumably needs splitting per modality -- confirm before use.
    labelled_positions = test_mask[test_mask].index
    candidate_offsets = np.arange(K_NEIGHBORS) * n_test_total
    candidates_coh = {mod: [gen_coh_df.iloc[candidate_offsets + i] for i in labelled_positions] for mod in ablated_modalities}
    candidates_multi = {mod: [gen_multi_df.iloc[candidate_offsets + i] for i in labelled_positions] for mod in ablated_modalities}

    ablated_cols = [c for c in test_data.columns if c.split('_')[0] in ablated_modalities]
    row_labels = test_data.index.to_numpy()
    fully_ablated = _ablate(test_data, row_labels, ablated_cols)

    # --- Full experiment loop ---
    all_runs_gen_coh, all_runs_gen_multi, all_runs_random = [], [], []

    for run in range(N_EXPERIMENT_REPEATS):
        print(f"\n--- Starting Experiment Run {run + 1}/{N_EXPERIMENT_REPEATS} for {str(ablated_modalities).upper()} ---")

        # Fresh model and random generator per repeat, both seeded from the run index,
        # so the random baseline is as reproducible as the informed strategies.
        rf = RandomForestClassifier(n_estimators=100, random_state=42 + run, n_jobs=-1)
        rf.fit(train_data.values, train_labels)
        rng = np.random.default_rng(42 + run)

        # Prediction each sample gets when the ablated modalities are missing (zeroed).
        reference_preds = rf.predict(fully_ablated.values)
        order_coh = _least_disagreement_first(
            _disagreement_rates(rf, test_data, candidates_coh, reference_preds, K_NEIGHBORS), row_labels)
        order_multi = _least_disagreement_first(
            _disagreement_rates(rf, test_data, candidates_multi, reference_preds, K_NEIGHBORS), row_labels)

        gen_selective_f1_scores_coh, gen_selective_f1_scores_multi, random_ablation_f1_scores = [], [], []
        for ratio in ABLATION_STEPS:
            n_to_ablate = int(len(test_data) * ratio)

            # Informed ablation: drop the modality where its generated versions agree most
            # with the "missing" prediction, i.e. where observing it matters least.
            gen_selective_f1_scores_coh.append(
                _weighted_f1(rf, _ablate(test_data, order_coh[:n_to_ablate], ablated_cols), test_labels))
            gen_selective_f1_scores_multi.append(
                _weighted_f1(rf, _ablate(test_data, order_multi[:n_to_ablate], ablated_cols), test_labels))

            # Random ablation baseline, averaged over N_ITERATIONS draws.
            iter_scores = [
                _weighted_f1(
                    rf,
                    _ablate(test_data, rng.choice(row_labels, size=n_to_ablate, replace=False), ablated_cols),
                    test_labels,
                )
                for _ in range(N_ITERATIONS)
            ]
            random_ablation_f1_scores.append(np.mean(iter_scores))

        all_runs_gen_coh.append(gen_selective_f1_scores_coh)
        all_runs_gen_multi.append(gen_selective_f1_scores_multi)
        all_runs_random.append(random_ablation_f1_scores)

    # --- Aggregate results across repeats ---
    print("\nAggregating results...")
    results = {
        'mean_random': np.mean(all_runs_random, axis=0),
        'std_random': np.std(all_runs_random, axis=0),
        'mean_gen_coh': np.mean(all_runs_gen_coh, axis=0),
        'std_gen_coh': np.std(all_runs_gen_coh, axis=0),
        'mean_gen_multi': np.mean(all_runs_gen_multi, axis=0),
        'std_gen_multi': np.std(all_runs_gen_multi, axis=0)
    }
    return results

# =============================================================================
# 4. PLOTTING FUNCTION
# =============================================================================
def plot_multi_panel_prioritization(
    results_dict: Dict[str, Dict],
    ablation_steps: np.ndarray,
    figsize: Tuple[float, float] = (16, 7),
    savepath: Optional[str] = None
):
    """
    Generate a multi-panel line plot comparing prioritization impact across modalities.

    One panel is drawn per entry of ``results_dict``, all sharing the y-axis. Each panel
    shows the mean F1-score (line) and +/- one standard deviation (shaded band) for the
    random, coherent and multi strategies. The x-axis is the fraction of samples that KEEP
    the modality, i.e. ``1 - ablation_steps``. This matches the axis label and the
    "Impact of Observing ..." titles.

    Args:
        results_dict: Panel key (e.g. ``'rna'``) -> dict with ``mean_random``, ``std_random``,
            ``mean_gen_coh``, ``std_gen_coh``, ``mean_gen_multi`` and ``std_gen_multi``.
            Values may be arrays or plain lists (e.g. after a JSON round trip).
        ablation_steps: Fraction of samples ablated at each step, in the results' order.
        figsize: The overall size of the figure.
        savepath: Optional path to save the figure.

    Raises:
        ValueError: If ``results_dict`` is empty.
    """
    if not results_dict:
        raise ValueError("results_dict is empty; nothing to plot.")

    # --- 1. Setup ---
    modalities_to_plot = list(results_dict.keys())
    n_panels = len(modalities_to_plot)
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, axes = plt.subplots(1, n_panels, figsize=figsize, sharey=True)
    if n_panels == 1:  # plt.subplots returns a bare Axes for a single panel
        axes = [axes]

    color_map = {'random': '#9c2409', 'coherent': '#56b4e9', 'multi': '#0072b2'}
    legend_labels = {'random': 'Random', 'coherent': 'Informed (Coherent)', 'multi': 'Informed (Multi)'}
    # Strategy -> (mean key, std key) in the results dicts.
    series_keys = {
        'random': ('mean_random', 'std_random'),
        'coherent': ('mean_gen_coh', 'std_gen_coh'),
        'multi': ('mean_gen_multi', 'std_gen_multi'),
    }
    # The results are indexed by ablated fraction; plot against the observed fraction so
    # the data agree with the x-axis label.
    observed_ratio = 1 - np.asarray(ablation_steps, dtype=float)

    # --- 2. Plotting loop ---
    for ax, modality_key in zip(axes, modalities_to_plot):
        results = results_dict[modality_key]

        for method, (mean_key, std_key) in series_keys.items():
            mean = np.asarray(results[mean_key], dtype=float)
            std = np.asarray(results[std_key], dtype=float)
            ax.plot(observed_ratio, mean, marker='o', linestyle='-', markersize=5,
                    label=legend_labels[method], color=color_map[method])
            ax.fill_between(observed_ratio, mean - std, mean + std, alpha=0.15, color=color_map[method])

        # --- 3. Panel-specific aesthetics ---
        ax.set_title(f"Impact of Observing {modality_key.upper()}", fontsize=14, weight='bold')
        ax.set_xlabel('Ratio of Samples with Observed Modality', fontsize=12)
        # Light grid drawn behind the data
        ax.yaxis.grid(True, linewidth=0.7, color="#CDCCCC", zorder=0)
        ax.xaxis.grid(True, linewidth=0.7, color="#CDCCCC", zorder=0)
        ax.set_axisbelow(True)

        # Less prominent axes and ticks
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['bottom'].set_color('#B0B0B0')
        ax.tick_params(axis='x', colors='#505050', length=0)
        ax.tick_params(axis='y', colors='#505050', length=0)
        ax.xaxis.label.set_color('#303030')
        ax.yaxis.label.set_color('#303030')

        ax.set_xlim(-0.02, 1.02)  # Keep axis from 0 to 1
        ax.set_xticks(np.arange(0, 1.1, 0.2))
        ax.tick_params(labelsize=10)

    # --- 4. Global aesthetics & layout ---
    axes[0].set_ylabel('Stage Prediction F1-Score', fontsize=12)

    # A single legend for the whole figure, placed below the panels
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.02), ncol=3, fontsize=11, frameon=False, title='Prioritization Strategy')

    fig.suptitle("Random Prioritization vs. Counterfactual Inference", fontsize=18, weight='bold', y=1.0)
    fig.tight_layout(rect=[0, 0.05, 1, 0.96])  # Leave room for the suptitle and the legend

    if savepath:
        fig.savefig(savepath, dpi=300, bbox_inches='tight')
    plt.show()


# =============================================================================
# 5. MAIN EXECUTION
# =============================================================================
if __name__ == '__main__':
    import os

    def _json_default(obj):
        """``json.dump`` fallback: NumPy arrays/scalars are converted via ``tolist()``.

        Anything else raises ``TypeError``. Returning the object unchanged would make
        ``json`` fail with a misleading "Circular reference detected" error instead.
        """
        if hasattr(obj, "tolist"):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    # Ablation settings to run; each tuple lists the modalities treated as missing.
    modalities_to_test = [('rna',), ('wsi',)]
    output_dir = "../../results/downstream/task_07_counterfactual"
    all_results = {}

    for modality_tuple in modalities_to_test:
        # Results are keyed by the modality names joined with '_' (e.g. 'rna').
        key = '_'.join(modality_tuple)
        result = run_ablation_experiment(modality_tuple)
        if result:
            all_results[key] = result

    if all_results:
        # The experiments are expensive. Make sure the output folder exists and persist
        # the numbers before plotting, so neither a missing folder nor a plotting error
        # loses them.
        os.makedirs(output_dir, exist_ok=True)
        with open(f"{output_dir}/rna_wsi_results.json", "w") as f:
            json.dump(all_results, f, indent=2, default=_json_default)

        plot_multi_panel_prioritization(
            results_dict=all_results,
            ablation_steps=ABLATION_STEPS,
            savepath=f"{output_dir}/multi_panel_prioritization_impact.png"
        )
    else:
        # Nothing succeeded. Skip plotting and leave any previously saved results file
        # untouched instead of overwriting it with an empty dict.
        print("\nNo experiments were successfully completed. Plotting skipped.")

In [None]:
import json

def _json_default(obj):
    """``json.dump`` fallback: NumPy arrays/scalars are converted via ``tolist()``.

    Anything else raises ``TypeError``. Returning the object unchanged would make
    ``json`` fail with a misleading "Circular reference detected" error instead.
    """
    if hasattr(obj, "tolist"):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save the experiment results. Skip when there are none, so a failed run does not
# overwrite a previously saved results file with "{}".
if all_results:
    with open("../../results/downstream/task_07_counterfactual/rna_wsi_results.json", "w") as f:
        json.dump(all_results, f, indent=2, default=_json_default)
else:
    print("all_results is empty; keeping the existing results file.")


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, Tuple, Optional



def plot_multi_panel_prioritization(
    results_dict: Dict[str, Dict],
    ablation_steps: np.ndarray,
    figsize: Tuple[float, float] = (16, 7),
    savepath: Optional[str] = None
):
    """
    Draw one panel per modality comparing the F1-score of the three prioritization strategies.

    Each panel gets its own y-axis. Lines are the means, shaded bands span +/- one std,
    and the x-axis is the fraction of samples that keep the modality (``1 - ablation_steps``).
    Result values may be plain lists (e.g. loaded from JSON); they are turned into arrays.

    Args:
        results_dict: Panel key -> dict with ``mean_random``, ``std_random``, ``mean_gen_coh``,
            ``std_gen_coh``, ``mean_gen_multi`` and ``std_gen_multi``.
        ablation_steps: Fraction of samples ablated at each step.
        figsize: Overall figure size.
        savepath: If given, the figure is also written there at 300 dpi.
    """
    panel_keys = list(results_dict.keys())
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, axes = plt.subplots(1, len(panel_keys), figsize=figsize, sharey=False)
    if len(panel_keys) == 1:
        axes = [axes]

    # Strategy -> (colour, legend label, mean key, std key), in drawing order.
    strategies = {
        'random': ('#9c2409', 'Random', 'mean_random', 'std_random'),
        'coherent': ('#56b4e9', 'Informed (Coherent Denoising)', 'mean_gen_coh', 'std_gen_coh'),
        'multi': ('#0072b2', 'Informed (Multi-condition)', 'mean_gen_multi', 'std_gen_multi'),
    }
    observed_fraction = 1 - ablation_steps
    grid_style = dict(linewidth=0.7, color="#CDCCCC", zorder=0)

    for ax, panel_key in zip(axes, panel_keys):
        panel = results_dict[panel_key]
        for colour, label, mean_key, std_key in strategies.values():
            centre = np.array(panel[mean_key])
            spread = np.array(panel[std_key])
            ax.plot(observed_fraction, centre, marker='o', linestyle='--', markersize=5,
                    label=label, color=colour)
            ax.fill_between(observed_fraction, centre - spread, centre + spread, alpha=0.15, color=colour)

        ax.set_title(f"Impact of Observing {panel_key.upper()}", fontsize=16, weight='bold')
        ax.set_xlabel('Ratio of Samples with Observed Modality', fontsize=14)
        ax.yaxis.grid(True, **grid_style)
        ax.xaxis.grid(True, **grid_style)
        ax.set_axisbelow(True)
        for side in ('top', 'right', 'left'):
            ax.spines[side].set_visible(False)
        ax.spines['bottom'].set_color('#B0B0B0')
        for axis_name in ('x', 'y'):
            ax.tick_params(axis=axis_name, colors='#505050', length=0)
        ax.xaxis.label.set_color('#303030')
        ax.yaxis.label.set_color('#303030')
        ax.set_xlim(-0.02, 1.02)
        ax.set_xticks(np.arange(0, 1.1, 0.2))
        ax.tick_params(labelsize=10)
        ax.set_ylabel('Stage Prediction F1-Score', fontsize=14)

    # One shared legend below the panels, built from the first panel's lines.
    handles, labels = axes[0].get_legend_handles_labels()
    legend = fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.02), ncol=3, fontsize=12, frameon=False, title='Prioritization Strategy')
    legend_title = legend.get_title()
    if legend_title is not None:
        legend_title.set_fontsize(14)

    fig.suptitle("Random Prioritization vs. Counterfactual Inference", fontsize=18, weight='bold', y=1.0)
    # w_pad adds horizontal room between panels for their separate y-axis labels.
    fig.tight_layout(rect=[0, 0.05, 1, 0.96], w_pad=5)

    if savepath:
        plt.savefig(savepath, dpi=300, bbox_inches='tight')
    plt.show()



# read all_results from the JSON file
import json
# Reload the saved per-modality results so the figure can be redrawn without rerunning the experiments.
results_dir = "../../results/downstream/task_07_counterfactual/"
with open(results_dir + "rna_wsi_results.json", "r") as f:
    all_results = json.load(f)

# Same ablation grid as the experiment: 0.0, 0.05, ..., 1.0.
ABLATION_STEPS = np.arange(0, 1.05, 0.05)

plot_multi_panel_prioritization(
    savepath=results_dir + "counterfactual_rna_wsi.png",
    ablation_steps=ABLATION_STEPS,
    results_dict=all_results,
)


In [None]:
# Show the results dict (as the last expression of a notebook cell it is echoed in the output).
all_results