# ==============================================================================
# Phase 1: Data Preparation
# This script handles loading, preprocessing, and exporting gene expression
# and phenotype data.
#
# Before running:
# 1. Ensure you have pandas and numpy installed: pip install pandas numpy
# 2. Update the 'raw_expr_file' and 'raw_pheno_file' paths below
#    to point to your actual raw data files.
# ==============================================================================

In [None]:
import pandas as pd
import numpy as np
import os

def load_data(expr_path, pheno_path):
    """
    Loads gene expression and phenotype data from specified file paths.

    Args:
        expr_path (str): Path to the gene expression data file.
        pheno_path (str): Path to the phenotype data file.

    Returns:
        tuple: A tuple containing:
            - expr (pd.DataFrame): Gene expression DataFrame (genes x samples).
            - pheno (pd.DataFrame): Phenotype DataFrame.
    """
    print(f"\n--- Data Preparation: Loading Data ---")
    print(f"Loading gene expression data from: {expr_path}")
    expr = pd.read_csv(expr_path, sep="\t", index_col=0)
    print(f"Gene expression data shape: {expr.shape}")

    print(f"Loading phenotype data from: {pheno_path}")
    pheno = pd.read_csv(pheno_path, sep="\t", index_col=0)
    print(f"Phenotype data shape: {pheno.shape}")

    return expr, pheno

def preprocess_data(expr, pheno):
    """
    Transposes gene expression data and matches samples with phenotype data.

    Args:
        expr (pd.DataFrame): Gene expression DataFrame (genes x samples).
        pheno (pd.DataFrame): Phenotype DataFrame.

    Returns:
        tuple: A tuple containing:
            - expr_T (pd.DataFrame): Transposed gene expression DataFrame (samples x genes).
            - pheno_matched (pd.DataFrame): Matched phenotype DataFrame.
    """
    print("\n--- Data Preparation: Preprocessing Data ---")
    # Transpose gene expression data to (samples x genes)
    print("Transposing gene expression data...")
    expr_T = expr.T
    expr_T.index.name = 'sample_id'
    print(f"Transposed gene expression data shape: {expr_T.shape}")

    # Match samples in both files
    print("Matching samples between expression and phenotype data...")
    common_samples = expr_T.index.intersection(pheno.index)
    expr_matched = expr_T.loc[common_samples]
    pheno_matched = pheno.loc[common_samples]
    print(f"Number of matched samples: {len(common_samples)}")
    print(f"Matched expression data shape: {expr_matched.shape}")
    print(f"Matched phenotype data shape: {pheno_matched.shape}")

    return expr_matched, pheno_matched

def export_data(dataframe, output_path, file_format='csv'):
    """
    Exports a DataFrame to a specified file format.

    Args:
        dataframe (pd.DataFrame): The DataFrame to export.
        output_path (str): The path to save the exported file.
        file_format (str): The format to save the file ('csv', 'tsv', 'parquet', 'excel').
    """
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Created output directory: {output_dir}")

    if file_format == 'csv':
        dataframe.to_csv(output_path, index=True)
    elif file_format == 'tsv':
        dataframe.to_csv(output_path, sep='\t', index=True)
    elif file_format == 'parquet':
        dataframe.to_parquet(output_path, index=True)
    elif file_format == 'excel':
        dataframe.to_excel(output_path, index=True)
    else:
        print(f"Unsupported file format: {file_format}. Supported formats are 'csv', 'tsv', 'parquet', 'excel'.")
        return

    print(f"Data successfully exported to: {output_path} in {file_format} format.")

# ==============================================================================
# Main Execution Block for Data Preparation
# ==============================================================================
if __name__ == "__main__":
    print("Starting Data Preparation Phase...")

    # --- Configuration ---
    # IMPORTANT: Update these paths to your actual raw data files
    raw_expr_file = r"C:\Users\shrav\Desktop\PYTHON\Cancer\Pan Cancer Analysis\EB++AdjustPANCAN_IlluminaHiSeq_RNASeqV2.geneExp.xena"
    raw_pheno_file = r"C:\Users\shrav\Desktop\PYTHON\Cancer\Pan Cancer Analysis\TCGA_phenotype_denseDataOnlyDownload.tsv"
    
    # Output directory for processed data
    processed_data_dir = "processed_data"
    os.makedirs(processed_data_dir, exist_ok=True) # Ensure output directory exists

    # Define paths for processed files
    processed_expr_file = os.path.join(processed_data_dir, "expr_processed.tsv")
    processed_pheno_file = os.path.join(processed_data_dir, "pheno_processed.tsv")

    try:
        if os.path.exists(raw_expr_file) and os.path.exists(raw_pheno_file):
            expr_raw, pheno_raw = load_data(raw_expr_file, raw_pheno_file)
            expr_processed, pheno_processed = preprocess_data(expr_raw, pheno_raw)
            
            export_data(expr_processed, processed_expr_file, file_format='tsv')
            export_data(pheno_processed, processed_pheno_file, file_format='tsv')
        else:
            print(f"Raw data files not found at '{raw_expr_file}' or '{raw_pheno_file}'.")
            print("Generating dummy data for demonstration purposes.")
            np.random.seed(42)
            num_samples = 100
            num_genes = 500
            genes = [f'Gene_{i}' for i in range(num_genes)]
            samples = [f'Sample_{i}' for i in range(num_samples)]
            
            expr_processed = pd.DataFrame(np.random.rand(num_samples, num_genes), index=samples, columns=genes)
            tumor_types = ['BRCA', 'LUAD', 'COAD', 'KIRC', 'LIHC']
            pheno_processed = pd.DataFrame({
                '_primary_site': np.random.choice(tumor_types, num_samples),
                'age_at_diagnosis': np.random.randint(30, 80, num_samples)
            }, index=samples)
            
            export_data(expr_processed, processed_expr_file, file_format='tsv')
            export_data(pheno_processed, processed_pheno_file, file_format='tsv')
            print("Dummy processed data saved to 'processed_data' directory.")

    except Exception as e:
        print(f"Error during Data Preparation: {e}")
        print("Please check your file paths and data format. Exiting.")

    print("\nData Preparation Phase complete.")



# ==============================================================================
# Phase 1.1: EDA & Visualization (Enhanced for Publication/Sharing)
# This script performs exploratory data analysis, dimensionality reduction
# (PCA, UMAP), and generates various plots with improved aesthetics for
# sharing on platforms like LinkedIn and blogs.
#
# Before running:
# 1. Ensure you have pandas, numpy, matplotlib, seaborn, scikit-learn,
#    and umap-learn installed: pip install pandas numpy matplotlib seaborn scikit-learn umap-learn
# 2. This script assumes 'processed_data/expr_processed.tsv' and
#    'processed_data/pheno_processed.tsv' exist from Phase 1.
# ==============================================================================

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
import umap.umap_ as umap
import os

# Set a consistent style for all plots for a professional look
# 'seaborn-v0_8-darkgrid' is a good default for clean, readable plots
# You can experiment with other styles like 'ggplot', 'seaborn-v0_8-whitegrid', etc.
plt.style.use('seaborn-v0_8-darkgrid')

def generate_summary_statistics(dataframe, name="Data"):
    """
    Generates and prints summary statistics for a given DataFrame.

    Args:
        dataframe (pd.DataFrame): The DataFrame for which to generate statistics.
        name (str): A descriptive name for the DataFrame (e.g., "Gene Expression", "Phenotype").
    """
    print(f"\n--- EDA: Summary Statistics for {name} ---")
    print(dataframe.describe())
    print(f"\nMissing values in {name}:\n{dataframe.isnull().sum().sum()} total missing values.")
    if dataframe.isnull().sum().sum() > 0:
        print(f"Missing values per column:\n{dataframe.isnull().sum()[dataframe.isnull().sum() > 0]}")
    print(f"\nDataFrame Info for {name}:")
    dataframe.info()
    print("-" * (25 + len(name)))

def plot_tumor_type_distribution(phenotype_df, tumor_type_column='_primary_disease', output_path=None):
    """
    Plots the distribution of tumor types from the phenotype DataFrame.
    Enhanced for readability and aesthetics for sharing.

    Args:
        phenotype_df (pd.DataFrame): The phenotype DataFrame.
        tumor_type_column (str): The name of the column containing tumor type information.
        output_path (str, optional): Path to save the plot. If None, displays the plot.
    """
    if tumor_type_column not in phenotype_df.columns:
        print(f"Error: '{tumor_type_column}' not found in phenotype DataFrame. Cannot plot tumor type distribution.")
        return

    print(f"\n--- EDA: Plotting Tumor Type Distribution ---")
    plt.figure(figsize=(14, 8)) # Slightly larger figure size
    
    # Use a vibrant and distinct palette, as requested (e.g., 'mako' or 'Spectral')
    # Using 'mako' as per user's snippet
    tumor_counts = phenotype_df[tumor_type_column].value_counts().sort_values(ascending=False)
    sns.barplot(y=tumor_counts.index, x=tumor_counts.values, palette='mako')
    
    plt.title(f'Distribution of {tumor_type_column.replace("_", " ").title()} Across Samples', fontsize=18, weight='bold', color='darkblue') # Larger, bold title
    plt.xlabel('Number of Samples', fontsize=14, color='dimgray') # Larger label
    plt.ylabel(tumor_type_column.replace("_", " ").title(), fontsize=14, color='dimgray') # Larger label
    plt.xticks(fontsize=12) # Larger tick labels
    plt.yticks(fontsize=12) # Larger tick labels
    plt.grid(axis='x', linestyle='--', alpha=0.6, color='lightgray') # Subtler grid
    plt.tight_layout()

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
            print(f"Created output directory: {output_dir}")
        plt.savefig(output_path, dpi=300, bbox_inches='tight') # High DPI, ensure tight bounding box
        print(f"Tumor type distribution plot saved to: {output_path}")
    else:
        plt.show()
    plt.close()

def plot_expression_summary_histograms(expr_df, output_path=None):
    """
    Plots histograms of mean and standard deviation of gene expression.
    Enhanced for readability and aesthetics for sharing.

    Args:
        expr_df (pd.DataFrame): Transposed gene expression DataFrame (samples x genes).
        output_path (str, optional): Base path to save the plots.
    """
    print(f"\n--- EDA: Plotting Gene Expression Summary Histograms ---")
    summary_stats = expr_df.describe().T
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 6)) # Two subplots for mean and std
    
    # Plot Mean Distribution
    sns.histplot(summary_stats['mean'], bins=50, kde=True, ax=axes[0], color='skyblue', edgecolor='black')
    axes[0].set_title('Distribution of Gene Expression Mean', fontsize=16, weight='bold', color='darkblue')
    axes[0].set_xlabel('Mean Expression', fontsize=12, color='dimgray')
    axes[0].set_ylabel('Number of Genes', fontsize=12, color='dimgray')
    axes[0].tick_params(labelsize=10)
    axes[0].grid(axis='y', linestyle='--', alpha=0.6, color='lightgray')

    # Plot Standard Deviation Distribution
    sns.histplot(summary_stats['std'], bins=50, kde=True, ax=axes[1], color='salmon', edgecolor='black')
    axes[1].set_title('Distribution of Gene Expression Standard Deviation', fontsize=16, weight='bold', color='darkblue')
    axes[1].set_xlabel('Standard Deviation of Expression', fontsize=12, color='dimgray')
    axes[1].set_ylabel('Number of Genes', fontsize=12, color='dimgray')
    axes[1].tick_params(labelsize=10)
    axes[1].grid(axis='y', linestyle='--', alpha=0.6, color='lightgray')

    plt.suptitle("Distribution of Gene Expression Mean and Standard Deviation Across All Genes", fontsize=20, weight='bold', color='black', y=1.02)
    plt.tight_layout(rect=[0, 0, 1, 0.98]) # Adjust layout for suptitle
    
    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Expression summary histograms saved to: {output_path}")
    else:
        plt.show()
    plt.close()


def perform_pca(expr_df, n_components=2):
    """
    Performs Principal Component Analysis (PCA) on the gene expression data.

    Args:
        expr_df (pd.DataFrame): Transposed gene expression DataFrame (samples x genes).
        n_components (int): Number of principal components to compute.

    Returns:
        tuple: A tuple containing:
            - pca_result (pd.DataFrame): DataFrame with PCA components.
            - pca_model (PCA): The fitted PCA model.
    """
    print(f"\n--- EDA: Performing PCA with {n_components} components ---")
    expr_df_filled = expr_df.fillna(expr_df.mean()) # Safeguard for NaNs

    pca = PCA(n_components=n_components)
    principal_components = pca.fit_transform(expr_df_filled)
    pca_result = pd.DataFrame(data=principal_components,
                              columns=[f'PC{i+1}' for i in range(n_components)],
                              index=expr_df.index)
    print(f"Explained variance ratio by components: {pca.explained_variance_ratio_}")
    print(f"Cumulative explained variance: {np.sum(pca.explained_variance_ratio_)}")
    return pca_result, pca

def plot_pca(pca_result_df, phenotype_df, color_column='_primary_disease', output_path=None, pca_model=None):
    """
    Plots the PCA results, colored by a specified phenotype column.
    Enhanced for readability and aesthetics for sharing.

    Args:
        pca_result_df (pd.DataFrame): DataFrame with PCA components.
        phenotype_df (pd.DataFrame): Matched phenotype DataFrame.
        color_column (str): The column in phenotype_df to use for coloring the plot.
        output_path (str, optional): Path to save the plot. If None, displays the plot.
        pca_model (PCA, optional): The fitted PCA model to extract explained variance.
    """
    if color_column not in phenotype_df.columns:
        print(f"Error: '{color_column}' not found in phenotype DataFrame. Cannot color PCA plot.")
        return

    plot_df = pca_result_df.merge(phenotype_df[[color_column]], left_index=True, right_index=True)

    print(f"\n--- EDA: Plotting PCA results, colored by '{color_column}' ---")
    plt.figure(figsize=(12, 10)) # Adjusted figure size for better aspect ratio
    sns.scatterplot(x='PC1', y='PC2', hue=color_column, data=plot_df,
                    palette='tab20', s=80, alpha=0.85, edgecolor='black', linewidth=0.7) # Slightly larger points, black edge
    
    # Add explained variance to axis labels if pca_model is provided
    pc1_variance = f" ({pca_model.explained_variance_ratio_[0]*100:.1f}% variance)" if pca_model else ""
    pc2_variance = f" ({pca_model.explained_variance_ratio_[1]*100:.1f}% variance)" if pca_model else ""

    plt.title(f'PCA of Gene Expression (Colored by {color_column.replace("_", " ").title()})', fontsize=18, weight='bold', color='darkblue')
    plt.xlabel(f'Principal Component 1{pc1_variance}', fontsize=14, color='dimgray')
    plt.ylabel(f'Principal Component 2{pc2_variance}', fontsize=14, color='dimgray')
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    
    # Clean up legend as per user's snippet
    handles, labels = plt.gca().get_legend_handles_labels()
    plt.legend(handles=handles, labels=labels, title=color_column.replace("_", " ").title(),
               bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0, fontsize='small', title_fontsize=12) # Larger legend title/text
    
    plt.grid(True, linestyle=':', alpha=0.5, color='lightgray') # Dotted grid, subtler
    plt.tight_layout(rect=[0, 0, 0.85, 1]) # Adjust layout to prevent legend overlap

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
            print(f"Created output directory: {output_dir}")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"PCA plot saved to: {output_path}")
    else:
        plt.show()
    plt.close()


def perform_umap(expr_df, n_components=2, random_state=42):
    """
    Performs UMAP dimensionality reduction on the gene expression data.

    Args:
        expr_df (pd.DataFrame): Transposed gene expression DataFrame (samples x genes).
        n_components (int): Number of dimensions for the UMAP embedding.
        random_state (int): Random seed for reproducibility.

    Returns:
        pd.DataFrame: DataFrame with UMAP components.
    """
    print(f"\n--- EDA: Performing UMAP with {n_components} components ---")
    expr_df_filled = expr_df.fillna(expr_df.mean()) # Safeguard for NaNs

    reducer = umap.UMAP(n_components=n_components, random_state=random_state)
    umap_embedding = reducer.fit_transform(expr_df_filled)
    umap_result = pd.DataFrame(data=umap_embedding,
                               columns=[f'UMAP{i+1}' for i in range(n_components)],
                               index=expr_df.index)
    return umap_result

def plot_umap(umap_result_df, phenotype_df, color_column='_primary_disease', output_path=None):
    """
    Plots the UMAP results, colored by a specified phenotype column.
    Enhanced for readability and aesthetics for sharing.

    Args:
        umap_result_df (pd.DataFrame): DataFrame with UMAP components.
        phenotype_df (pd.DataFrame): Matched phenotype DataFrame.
        color_column (str): The column in phenotype_df to use for coloring the plot.
        output_path (str, optional): Path to save the plot. If None, displays the plot.
    """
    if color_column not in phenotype_df.columns:
        print(f"Error: '{color_column}' not found in phenotype DataFrame. Cannot color UMAP plot.")
        return

    plot_df = umap_result_df.merge(phenotype_df[[color_column]], left_index=True, right_index=True)

    print(f"\n--- EDA: Plotting UMAP results, colored by '{color_column}' ---")
    plt.figure(figsize=(12, 10)) # Adjusted figure size for better aspect ratio
    sns.scatterplot(x='UMAP1', y='UMAP2', hue=color_column, data=plot_df,
                    palette='tab20', s=80, alpha=0.85, edgecolor='black', linewidth=0.7) # Slightly larger points, black edge
    plt.title(f'UMAP of Gene Expression (Colored by {color_column.replace("_", " ").title()})', fontsize=18, weight='bold', color='darkblue')
    plt.xlabel(f'UMAP Component 1', fontsize=14, color='dimgray')
    plt.ylabel(f'UMAP Component 2', fontsize=14, color='dimgray')
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.legend(title=color_column.replace("_", " ").title(), bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10, title_fontsize=12)
    plt.grid(True, linestyle=':', alpha=0.5, color='lightgray')
    plt.tight_layout(rect=[0, 0, 0.85, 1])

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
            print(f"Created output directory: {output_dir}")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"UMAP plot saved to: {output_path}")
    else:
        plt.show()
    plt.close()

def plot_top_variable_genes(expr_df, top_n=20, output_path=None):
    """
    Identifies and plots the top N most variable genes by standard deviation.
    Enhanced for readability and aesthetics for sharing.

    Args:
        expr_df (pd.DataFrame): Transposed gene expression DataFrame (samples x genes).
        top_n (int): Number of top variable genes to plot.
        output_path (str, optional): Path to save the plot.
    """
    print(f"\n--- EDA: Plotting Top {top_n} Most Variable Genes ---")
    gene_std = expr_df.std().sort_values(ascending=False)
    top_genes = gene_std.head(top_n)
    print(f"Top {top_n} most variable genes:\n{top_genes}")

    plt.figure(figsize=(12, 8))
    sns.barplot(x=top_genes.values, y=top_genes.index, palette='viridis')
    plt.title(f'Top {top_n} Most Variable Genes by Standard Deviation', fontsize=18, weight='bold', color='darkblue')
    plt.xlabel('Standard Deviation of Expression', fontsize=14, color='dimgray')
    plt.ylabel('Gene', fontsize=14, color='dimgray')
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.grid(axis='x', linestyle='--', alpha=0.6, color='lightgray')
    plt.tight_layout()

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Top variable genes plot saved to: {output_path}")
    else:
        plt.show()
    plt.close()

def plot_gene_correlation_heatmap(expr_df, gene_list, output_path=None):
    """
    Plots a correlation heatmap for a given list of genes.
    Enhanced for readability and aesthetics for sharing.

    Args:
        expr_df (pd.DataFrame): Transposed gene expression DataFrame (samples x genes).
        gene_list (list): List of gene names for which to plot correlations.
        output_path (str, optional): Path to save the plot.
    """
    present_genes = [gene for gene in gene_list if gene in expr_df.columns]
    if not present_genes:
        print(f"Error: None of the specified genes {gene_list} found in expression data. Cannot plot correlation heatmap.")
        return
    if len(present_genes) < 2:
        print(f"Warning: Only {len(present_genes)} gene(s) found. Need at least 2 for a correlation heatmap. Skipping.")
        return

    print(f"\n--- EDA: Plotting Correlation Heatmap for Selected Genes ---")
    
    # Select only the present genes and compute correlation
    corr_matrix = expr_df[present_genes].corr()

    plt.figure(figsize=(12, 10))
    sns.heatmap(corr_matrix, cmap='coolwarm', center=0, annot=True, fmt=".2f", linewidths=.5, linecolor='lightgray',
                cbar_kws={'label': 'Pearson Correlation Coefficient'})
    plt.title(f'Correlation Between Selected Genes ({len(present_genes)} Genes)', fontsize=18, weight='bold', color='darkblue')
    plt.xticks(fontsize=10, rotation=90)
    plt.yticks(fontsize=10, rotation=0)
    plt.tight_layout()

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Gene correlation heatmap saved to: {output_path}")
    else:
        plt.show()
    plt.close()

def plot_specific_gene_expression_boxplots(expr_df, phenotype_df, gene_list, tumor_type_column='_primary_disease', output_dir_base=None):
    """
    Plots box plots for expression of specific genes across different tumor types.
    Enhanced for readability and aesthetics for sharing.

    Args:
        expr_df (pd.DataFrame): Transposed gene expression DataFrame (samples x genes).
        phenotype_df (pd.DataFrame): Matched phenotype DataFrame.
        gene_list (list): List of gene names to plot.
        tumor_type_column (str): The column in phenotype_df representing tumor types.
        output_dir_base (str, optional): Base directory to save individual gene plots.
    """
    if tumor_type_column not in phenotype_df.columns:
        print(f"Error: '{tumor_type_column}' not found in phenotype DataFrame. Cannot plot gene expression boxplots.")
        return

    print(f"\n--- EDA: Plotting Specific Gene Expression Across Tumor Types ---")
    
    if output_dir_base:
        os.makedirs(output_dir_base, exist_ok=True)

    for gene in gene_list:
        if gene in expr_df.columns:
            plot_df = pd.DataFrame({
                'Expression': expr_df[gene],
                'TumorType': phenotype_df.loc[expr_df.index, tumor_type_column]
            }).dropna() # Drop NA if any for plotting

            if plot_df.empty:
                print(f"Warning: No valid data for gene '{gene}' after merging with phenotype. Skipping plot.")
                continue

            plt.figure(figsize=(12, 7))
            # FIX: Add hue=TumorType and legend=False to suppress FutureWarning
            sns.boxplot(x='TumorType', y='Expression', data=plot_df, palette='Set2', hue='TumorType', legend=False)
            plt.xticks(rotation=90, fontsize=10)
            plt.yticks(fontsize=10)
            plt.title(f'{gene} Expression Across {tumor_type_column.replace("_", " ").title()}', fontsize=18, weight='bold', color='darkblue')
            plt.xlabel(tumor_type_column.replace("_", " ").title(), fontsize=14, color='dimgray')
            plt.ylabel(f'{gene} Expression', fontsize=14, color='dimgray')
            plt.grid(axis='y', linestyle='--', alpha=0.6, color='lightgray')
            plt.tight_layout()

            if output_dir_base:
                output_path = os.path.join(output_dir_base, f"{gene}_expression_across_cancers.png")
                plt.savefig(output_path, dpi=300, bbox_inches='tight')
                print(f"Plot for {gene} saved to: {output_path}")
            else:
                plt.show()
            plt.close()
        else:
            print(f"Warning: Gene '{gene}' not found in expression data. Skipping boxplot for this gene.")


# ==============================================================================
# Main Execution Block for EDA & Visualization
# ==============================================================================
if __name__ == "__main__":
    print("Starting EDA & Visualization Phase...")

    # --- Configuration ---
    processed_data_dir = "processed_data"
    eda_plots_dir = "eda_plots"
    os.makedirs(eda_plots_dir, exist_ok=True) # Ensure output directory exists

    processed_expr_file = os.path.join(processed_data_dir, "expr_processed.tsv")
    processed_pheno_file = os.path.join(processed_data_dir, "pheno_processed.tsv")
    
    # IMPORTANT: Use '_primary_disease' as the target column as per your snippet
    target_column = '_primary_disease' 

    expr_processed = None
    pheno_processed = None

    try:
        if os.path.exists(processed_expr_file) and os.path.exists(processed_pheno_file):
            print("Loading processed data for EDA example...")
            expr_processed = pd.read_csv(processed_expr_file, sep='\t', index_col=0)
            pheno_processed = pd.read_csv(processed_pheno_file, sep='\t', index_col=0)
            print("Processed data loaded successfully.")
        else:
            print("Processed data files not found. Generating dummy data for demonstration.")
            np.random.seed(42)
            num_samples = 100
            num_genes = 500
            genes = [f'Gene_{i}' for i in range(num_genes)]
            samples = [f'Sample_{i}' for i in range(num_samples)]
            
            expr_processed = pd.DataFrame(np.random.rand(num_samples, num_genes), index=samples, columns=genes)
            tumor_types = ['BRCA', 'LUAD', 'COAD', 'KIRC', 'LIHC', 'STAD', 'BLCA', 'HNSC', 'LGG', 'OV'] # More types for dummy
            pheno_processed = pd.DataFrame({
                '_primary_disease': np.random.choice(tumor_types, num_samples),
                'age_at_diagnosis': np.random.randint(30, 80, num_samples)
            }, index=samples)
            print("Dummy data generated.")

        # Ensure the phenotype column exists for plotting
        if target_column not in pheno_processed.columns:
            print(f"'{target_column}' column not found in phenotype data. Creating a dummy column for plotting.")
            pheno_processed[target_column] = np.random.choice(['DiseaseA', 'DiseaseB', 'DiseaseC'], len(pheno_processed))

        # ======================================================================
        # EDA Step 1: Generate Summary Statistics
        # ======================================================================
        generate_summary_statistics(expr_processed, "Gene Expression (Processed)")
        generate_summary_statistics(pheno_processed, "Phenotype (Processed)")

        # ======================================================================
        # EDA Step 2: Tumor Type Distribution
        # (Corresponds to your "Step 5: Tumor Type Distribution")
        # ======================================================================
        plot_tumor_type_distribution(pheno_processed, target_column,
                                     output_path=os.path.join(eda_plots_dir, "tumor_type_distribution.png"))

        # ======================================================================
        # EDA Step 3: Summary Statistics of Gene Expression (Histograms)
        # (Corresponds to your "Step 6: Summary Statistics of Expression")
        # ======================================================================
        plot_expression_summary_histograms(expr_processed,
                                           output_path=os.path.join(eda_plots_dir, "expression_mean_std_hist.png"))

        # ======================================================================
        # EDA Step 4: PCA of Gene Expression
        # (Corresponds to your "Step 7: PCA")
        # ======================================================================
        pca_results, pca_model = perform_pca(expr_processed, n_components=2)
        plot_pca(pca_results, pheno_processed, target_column,
                 output_path=os.path.join(eda_plots_dir, "pca_tumor_type.png"),
                 pca_model=pca_model)

        # ======================================================================
        # EDA Step 5: UMAP of Gene Expression
        # (Corresponds to your "UMAP" section)
        # ======================================================================
        umap_results = perform_umap(expr_processed, n_components=2)
        plot_umap(umap_results, pheno_processed, target_column,
                  output_path=os.path.join(eda_plots_dir, "umap_tumor_type.png"))

        # ======================================================================
        # EDA Step 6: Top Variable Genes
        # (Corresponds to your "gene_std" and "top_genes" section)
        # ======================================================================
        plot_top_variable_genes(expr_processed, top_n=20,
                                output_path=os.path.join(eda_plots_dir, "top_variable_genes.png"))

        # ======================================================================
        # EDA Step 7: Correlation Heatmap for Top Variable Genes
        # (Corresponds to your "top_corr_genes" section)
        # Re-using the top 20 genes from the previous step for correlation
        # ======================================================================
        gene_std_for_corr = expr_processed.std().sort_values(ascending=False)
        top_20_variable_genes = gene_std_for_corr.head(20).index.tolist()
        plot_gene_correlation_heatmap(expr_processed, top_20_variable_genes,
                                      output_path=os.path.join(eda_plots_dir, "gene_correlation_heatmap.png"))

        # ======================================================================
        # EDA Step 8: Specific Cancer Gene Expression Boxplots
        # (Corresponds to your "cancer_genes" section)
        # ======================================================================
        cancer_genes_to_plot = ['TP53', 'EGFR', 'MYC', 'BRCA1', 'CDKN2A']
        plot_specific_gene_expression_boxplots(expr_processed, pheno_processed, cancer_genes_to_plot,
                                               tumor_type_column=target_column,
                                               output_dir_base=os.path.join(eda_plots_dir, "gene_expression_boxplots"))


        print("\nEDA and Visualization steps completed. Check 'eda_plots' directory for generated plots.")

    except Exception as e:
        print(f"Error during EDA & Visualization: {e}")
        print("Please ensure processed data files are available or paths are correct.")

    print("\nEDA & Visualization Phase complete.")



# ==============================================================================
# Phase 1.2: Machine Learning (Enhanced Plots for Publication/Sharing)
# This script handles data splitting, model training, evaluation,
# confusion matrix plotting, feature importance extraction, and
# advanced visualizations like correlation heatmaps and similarity networks.
#
# Before running:
# 1. Ensure you have all necessary libraries installed:
#    pip install pandas numpy matplotlib seaborn scikit-learn joblib networkx
# 2. This script assumes 'processed_data/expr_processed.tsv' and
#    'processed_data/pheno_processed.tsv' exist from Phase 1.
# ==============================================================================

In [None]:
# ==============================================================================
# Phase 3: Machine Learning (Enhanced Plots for Publication/Sharing)
# This script handles data splitting, model training, evaluation,
# confusion matrix plotting, feature importance extraction, and
# advanced visualizations like correlation heatmaps and similarity networks.
#
# Before running:
# 1. Ensure you have all necessary libraries installed:
#    pip install pandas numpy matplotlib seaborn scikit-learn joblib networkx
# 2. This script assumes 'processed_data/expr_processed.tsv' and
#    'processed_data/pheno_processed.tsv' exist from Phase 1.
# ==============================================================================

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import joblib # For saving/loading models
import networkx as nx # For the similarity network plot

from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, classification_report

# Set a consistent style for all plots for a professional look
plt.style.use('seaborn-v0_8-darkgrid') # A good default for clean, readable plots

def split_data(X, y, test_size=0.2, random_state=42, stratify=None):
    """
    Splits the data into training and testing sets.

    Args:
        X (pd.DataFrame): Feature DataFrame (gene expression).
        y (pd.Series): Target Series (phenotype column, e.g., tumor type).
        test_size (float): Proportion of the dataset to include in the test split.
        random_state (int): Controls the shuffling applied to the data before applying the split.
        stratify (array-like or None): If not None, data is split in a stratified fashion,
                                       using this as the class labels. Useful for imbalanced datasets.

    Returns:
        tuple: X_train, X_test, y_train, y_test
    """
    print(f"\n--- Machine Learning: Splitting data (test_size={test_size}) ---")
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=stratify
    )
    print(f"X_train shape: {X_train.shape}")
    print(f"X_test shape: {X_test.shape}")
    print(f"y_train shape: {y_train.shape}")
    print(f"y_test shape: {y_test.shape}")
    return X_train, X_test, y_train, y_test

def train_model(model_name, X_train, y_train, **kwargs):
    """
    Trains a specified machine learning model.

    Args:
        model_name (str): Name of the model to train ('RandomForest', 'SVM', 'LogisticRegression').
        X_train (pd.DataFrame): Training features.
        y_train (pd.Series): Training target.
        **kwargs: Additional arguments for the model constructor.

    Returns:
        trained_model: The fitted machine learning model.
    """
    print(f"\n--- Machine Learning: Training {model_name} model ---")
    model = None
    if model_name == 'RandomForest':
        model = RandomForestClassifier(random_state=42, **kwargs)
    elif model_name == 'SVM':
        model = SVC(probability=True, random_state=42, **kwargs)
    elif model_name == 'LogisticRegression':
        model = LogisticRegression(random_state=42, solver='liblinear', **kwargs)
    else:
        raise ValueError(f"Unsupported model_name: {model_name}")

    model.fit(X_train, y_train)
    print(f"{model_name} model training complete.")
    return model

def evaluate_model(model, X_test, y_test, model_name="Model"):
    """
    Evaluates the trained model and prints performance metrics.

    Args:
        model: The trained machine learning model.
        X_test (pd.DataFrame): Test features.
        y_test (pd.Series): Test target.
        model_name (str): Name of the model for reporting.

    Returns:
        dict: A dictionary of evaluation metrics.
    """
    print(f"\n--- Machine Learning: Evaluating {model_name} model ---")
    y_pred = model.predict(X_test)

    metrics = {
        'accuracy': accuracy_score(y_test, y_pred),
        'precision': precision_score(y_test, y_pred, average='weighted', zero_division=0),
        'recall': recall_score(y_test, y_pred, average='weighted', zero_division=0),
        'f1_score': f1_score(y_test, y_pred, average='weighted', zero_division=0)
    }

    print(f"--- {model_name} Performance Metrics ---")
    for metric, value in metrics.items():
        print(f"{metric.replace('_', ' ').title()}: {value:.4f}")

    try:
        if len(np.unique(y_test)) > 2:
            y_proba = model.predict_proba(X_test)
            metrics['roc_auc'] = roc_auc_score(y_test, y_proba, multi_class='ovr', average='weighted')
            print(f"ROC AUC (Weighted OvR): {metrics['roc_auc']:.4f}")
        else:
            y_proba = model.predict_proba(X_test)[:, 1]
            metrics['roc_auc'] = roc_auc_score(y_test, y_proba)
            print(f"ROC AUC: {metrics['roc_auc']:.4f}")
    except AttributeError:
        print("Model does not support predict_proba for ROC AUC calculation.")
        metrics['roc_auc'] = np.nan
    except ValueError as e:
        print(f"Could not calculate ROC AUC: {e}")
        metrics['roc_auc'] = np.nan

    print("\nClassification Report:")
    print(classification_report(y_test, y_pred, zero_division=0))

    return metrics

def plot_confusion_matrix(model, X_test, y_test, class_names_for_plot, model_name="Model", output_path=None):
    """
    Plots the confusion matrix for the trained model.
    Enhanced for readability and aesthetics for sharing.

    Args:
        model: The trained machine learning model.
        X_test (pd.DataFrame): Test features.
        y_test (pd.Series): Test target.
        class_names_for_plot (list): List of original string class labels.
        model_name (str): Name of the model for plot title.
        output_path (str, optional): Path to save the plot. If None, displays the plot.
    """
    print(f"\n--- Machine Learning: Plotting confusion matrix for {model_name} ---")
    y_pred = model.predict(X_test)

    # Get the unique integer labels that confusion_matrix expects (0, 1, ..., N-1)
    # These correspond to the order of class_names_for_plot
    integer_labels = range(len(class_names_for_plot))
    cm = confusion_matrix(y_test, y_pred, labels=integer_labels)
    
    # Use the original string labels for the DataFrame for readability
    cm_df = pd.DataFrame(cm, index=class_names_for_plot, columns=class_names_for_plot)

    plt.figure(figsize=(14, 12)) # Adjusted figure size for better readability with many classes
    sns.heatmap(cm_df, annot=True, fmt='d', cmap='Blues', cbar=True, # Set annot=True to show numbers
                linewidths=.7, linecolor='black', annot_kws={"size": 10}) # Adjusted annotation font size
    plt.title(f'Confusion Matrix for {model_name}', fontsize=18, weight='bold', color='darkblue')
    plt.xlabel('Predicted Label', fontsize=14, color='dimgray')
    plt.ylabel('True Label', fontsize=14, color='dimgray')
    plt.xticks(rotation=45, ha='right', fontsize=10) # Adjusted tick font size
    plt.yticks(rotation=0, fontsize=10) # Adjusted tick font size
    plt.tight_layout()

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
            print(f"Created output directory: {output_dir}")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Confusion matrix plot saved to: {output_path}")
    else:
        plt.show()
    plt.close()

def get_feature_importance(model, feature_names, model_name="Model", top_n=20):
    """
    Extracts and prints feature importance (if available for the model).

    Args:
        model: The trained machine learning model.
        feature_names (list): List of feature names (gene names).
        model_name (str): Name of the model.
        top_n (int): Number of top features to display.

    Returns:
        pd.Series or None: A Series of feature importances, or None if not applicable.
    """
    print(f"\n--- Machine Learning: Extracting feature importance for {model_name} ---")
    importances = None
    if hasattr(model, 'feature_importances_'):
        importances = pd.Series(model.feature_importances_, index=feature_names)
        importances = importances.sort_values(ascending=False)
        print(f"Top {top_n} Feature Importances for {model_name}:\n{importances.head(top_n)}")
    elif hasattr(model, 'coef_'):
        if model.coef_.ndim > 1:
            importances = pd.Series(np.sum(np.abs(model.coef_), axis=0), index=feature_names)
        else:
            importances = pd.Series(np.abs(model.coef_), index=feature_names)
        importances = importances.sort_values(ascending=False)
        print(f"Top {top_n} Absolute Coefficients (Feature Importance) for {model_name}:\n{importances.head(top_n)}")
    else:
        print(f"Feature importance/coefficients not available for {model_name} model type.")
    return importances

def plot_top_important_genes(feature_importances_series, top_n=30, output_path=None):
    """
    Plots the top N most important genes based on feature importance.
    Enhanced for readability and aesthetics for sharing.

    Args:
        feature_importances_series (pd.Series): Series of gene importances (index=gene_name, value=importance).
        top_n (int): Number of top genes to plot.
        output_path (str, optional): Path to save the plot.
    """
    print(f"\n--- Machine Learning: Plotting Top {top_n} Most Important Genes ---")
    top_genes = feature_importances_series.head(top_n) # Assumes series is already sorted
    
    plt.figure(figsize=(12, 9)) # Adjusted figure size
    sns.barplot(x=top_genes.values, y=top_genes.index, palette='viridis')
    plt.title(f"Top {top_n} Most Important Genes (Random Forest Feature Importance)", fontsize=18, weight='bold', color='darkblue')
    plt.xlabel("Feature Importance (Gini Importance)", fontsize=14, color='dimgray')
    plt.ylabel("Gene", fontsize=14, color='dimgray')
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.grid(axis='x', linestyle='--', alpha=0.6, color='lightgray')
    plt.tight_layout()

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Top important genes plot saved to: {output_path}")
    else:
        plt.show()
    plt.close()

def plot_correlation_heatmap_top_genes(expr_df, feature_importances_series, top_n=50, output_path=None):
    """
    Plots a correlation heatmap for the top N most important genes.
    Enhanced for readability and aesthetics for sharing.

    Args:
        expr_df (pd.DataFrame): Transposed gene expression DataFrame (samples x genes).
        feature_importances_series (pd.Series): Series of gene importances (index=gene_name, value=importance).
        top_n (int): Number of top genes to include in the heatmap.
        output_path (str, optional): Path to save the plot.
    """
    print(f"\n--- Machine Learning: Plotting Correlation Heatmap of Top {top_n} Important Genes ---")
    top_genes_for_corr = feature_importances_series.head(top_n).index.tolist()
    
    # Ensure genes are actually in the expression data
    present_genes = [gene for gene in top_genes_for_corr if gene in expr_df.columns]
    if not present_genes:
        print(f"Error: None of the top {top_n} genes found in expression data. Cannot plot correlation heatmap.")
        return
    if len(present_genes) < 2:
        print(f"Warning: Only {len(present_genes)} top gene(s) found. Need at least 2 for a correlation heatmap. Skipping.")
        return

    corr_matrix = expr_df[present_genes].corr()

    plt.figure(figsize=(16, 14)) # Larger figure for better visibility of labels
    sns.heatmap(corr_matrix, cmap='coolwarm', center=0, annot=False, fmt=".2f", # No annotations for very large heatmaps
                linewidths=.5, linecolor='lightgray', cbar_kws={'label': 'Pearson Correlation Coefficient'})
    plt.title(f"Correlation Heatmap of Top {len(present_genes)} Important Genes", fontsize=18, weight='bold', color='darkblue')
    plt.xticks(fontsize=8, rotation=90) # Smaller font for many labels
    plt.yticks(fontsize=8, rotation=0)
    plt.tight_layout()

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Correlation heatmap saved to: {output_path}")
    else:
        plt.show()
    plt.close()

def plot_tumor_similarity_network(y_true, y_pred, class_names, threshold=5, output_path=None):
    """
    Builds and plots a network graph showing similarity between tumor types
    based on misclassification patterns in the confusion matrix.
    Enhanced for readability and aesthetics for sharing.

    Args:
        y_true (np.array): True labels (encoded).
        y_pred (np.array): Predicted labels (encoded).
        class_names (list): Original class names corresponding to encoded labels.
        threshold (int): Minimum number of misclassifications between two classes
                         to draw an edge in the network.
        output_path (str, optional): Path to save the plot.
    """
    print(f"\n--- Machine Learning: Plotting Tumor Type Similarity Network ---")
    # Get the unique integer labels that confusion_matrix expects (0, 1, ..., N-1)
    integer_labels = range(len(class_names))
    cm = confusion_matrix(y_true, y_pred, labels=integer_labels)
    conf_df = pd.DataFrame(cm, index=class_names, columns=class_names)

    G = nx.Graph()

    # Add nodes
    for class_name in class_names:
        G.add_node(class_name)

    # Add edges based on off-diagonal confusion matrix values
    for i, true_label in enumerate(class_names):
        for j, pred_label in enumerate(class_names):
            if i != j: # Only consider misclassifications (off-diagonal)
                misclassification_count = conf_df.iloc[i, j]
                if misclassification_count >= threshold:
                    G.add_edge(true_label, pred_label, weight=misclassification_count)

    if not G.edges():
        print(f"No strong misclassification links found above threshold {threshold}. Skipping network plot.")
        return

    plt.figure(figsize=(18, 16)) # Increased figure size further for more space
    
    # Use spring_layout for a more organic, force-directed layout
    # Adjusted k for more spacing, and increased iterations for better convergence
    pos = nx.spring_layout(G, seed=42, k=1.0, iterations=100) 

    # Node sizes based on total misclassifications involving that node
    node_sizes = []
    for node in G.nodes():
        # Sum of misclassifications for this node (both as true and predicted)
        total_misclass = conf_df.loc[node, :].drop(node, errors='ignore').sum() + \
                         conf_df.loc[:, node].drop(node, errors='ignore').sum()
        node_sizes.append(500 + total_misclass * 10) # Base size + scaled by misclassifications

    # Edge widths based on misclassification count
    edges = G.edges(data=True)
    weights = [edge[2]['weight'] for edge in edges]
    edge_widths = [w * 0.2 for w in weights] # Scale down for visualization

    # Draw network elements
    nx.draw_networkx_nodes(G, pos, node_color='skyblue', node_size=node_sizes, edgecolors='black', linewidths=1.0, alpha=0.9)
    nx.draw_networkx_edges(G, pos, width=edge_widths, alpha=0.6, edge_color='gray', style='dashed') # Dashed edges
    # FIX: Reduced font size for labels to prevent overlap
    nx.draw_networkx_labels(G, pos, font_size=7, font_weight='bold', font_color='black') # Smaller font for contrast

    # Add edge labels (misclassification counts)
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=6, font_color='darkred') # Smaller font for edge labels

    plt.title(f"Tumor Type Similarity Network (Misclassifications > {threshold})", fontsize=20, weight='bold', color='darkblue', pad=20)
    plt.axis('off') # Turn off axis
    plt.tight_layout() # Ensure tight layout for saving

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        plt.savefig(output_path, dpi=300, bbox_inches='tight') # Ensure tight bounding box for saving
        print(f"Tumor similarity network saved to: {output_path}")
    else:
        plt.show()
    plt.close()


def save_model(model, path):
    """Saves the trained model to a file."""
    output_dir = os.path.dirname(path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Created output directory: {output_dir}")
    joblib.dump(model, path)
    print(f"Model saved to {path}")

def load_model(path):
    """Loads a trained model from a file."""
    model = joblib.load(path)
    print(f"Model loaded from {path}")
    return model

# ==============================================================================
# Main Execution Block for Machine Learning
# ==============================================================================
if __name__ == "__main__":
    print("Starting Machine Learning Phase...")

    # --- Configuration ---
    processed_data_dir = "processed_data"
    ml_results_dir = "ml_results"
    os.makedirs(ml_results_dir, exist_ok=True) # Ensure output directory exists

    processed_expr_file = os.path.join(processed_data_dir, "expr_processed.tsv")
    processed_pheno_file = os.path.join(processed_data_dir, "pheno_processed.tsv")
    trained_model_path = os.path.join(ml_results_dir, "random_forest_model.joblib")
    
    # IMPORTANT: Use '_primary_disease' as the target column as per your snippet
    target_column = '_primary_disease' 

    expr_processed = None
    pheno_processed = None

    try:
        if os.path.exists(processed_expr_file) and os.path.exists(processed_pheno_file):
            print("Loading processed data for Machine Learning example...")
            expr_processed = pd.read_csv(processed_expr_file, sep='\t', index_col=0)
            pheno_processed = pd.read_csv(processed_pheno_file, sep='\t', index_col=0)
            print("Processed data loaded successfully.")
        else:
            print("Processed data files not found. Generating dummy data for demonstration.")
            np.random.seed(42)
            num_samples = 100
            num_genes = 500
            genes = [f'Gene_{i}' for i in range(num_genes)]
            samples = [f'Sample_{i}' for i in range(num_samples)]
            
            expr_processed = pd.DataFrame(np.random.rand(num_samples, num_genes), index=samples, columns=genes)
            tumor_types = ['BRCA', 'LUAD', 'COAD', 'KIRC', 'LIHC', 'STAD', 'BLCA', 'HNSC', 'LGG', 'OV']
            pheno_processed = pd.DataFrame({
                '_primary_disease': np.random.choice(tumor_types, num_samples),
                'age_at_diagnosis': np.random.randint(30, 80, num_samples)
            }, index=samples)
            print("Dummy data generated.")

        # === Data Preprocessing for ML ===
        # Ensure the target column exists
        if target_column not in pheno_processed.columns:
            print(f"Error: Target column '{target_column}' not found in phenotype data.")
            print("Creating a dummy target column for demonstration.")
            pheno_processed[target_column] = np.random.choice(['TypeA', 'TypeB', 'TypeC'], len(pheno_processed))

        # Labels (target variable)
        labels = pheno_processed.loc[expr_processed.index, target_column]
        le = LabelEncoder()
        y = le.fit_transform(labels)
        class_names = le.classes_ # Store original class names

        # Features (all genes)
        X = expr_processed.fillna(0) # Fill NaN values, if any, before scaling

        # Scaling gene expression
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)

        # === Train-test split ===
        X_train, X_test, y_train, y_test = split_data(
            X_scaled, y, test_size=0.2, stratify=y, random_state=42
        )

        print(f"\nTraining set shape: {X_train.shape}")
        print(f"Test set shape: {X_test.shape}")
        print(f"Number of tumor types: {len(class_names)}")
        print(f"Tumor types: {class_names.tolist()}")


        # === Train Random Forest Model ===
        rf_model = train_model(
            'RandomForest',
            X_train, y_train,
            n_estimators=200,
            max_depth=None, # Let the model determine depth
            n_jobs=-1,      # Use all available cores
            verbose=1       # Print training progress
        )

        # === Evaluate Random Forest Model ===
        rf_metrics = evaluate_model(rf_model, X_test, y_test, model_name='RandomForest')
        
        # Plot Confusion Matrix
        plot_confusion_matrix(rf_model, X_test, y_test, class_names, 'RandomForest',
                              output_path=os.path.join(ml_results_dir, "rf_confusion_matrix.png"))
        
        # Save the trained Random Forest model
        save_model(rf_model, trained_model_path)

        # === Feature Importance ===
        # Get feature importances from the trained RF model
        rf_feature_importances_series = get_feature_importance(rf_model, X.columns, 'RandomForest', top_n=len(X.columns))
        
        if rf_feature_importances_series is not None and not rf_feature_importances_series.empty:
            # Plot Top Important Genes (top 30 as per your snippet)
            plot_top_important_genes(rf_feature_importances_series, top_n=30,
                                     output_path=os.path.join(ml_results_dir, "top_rf_genes_barplot.png"))

            # Plot Correlation Heatmap of Top 50 Important Genes
            plot_correlation_heatmap_top_genes(X, rf_feature_importances_series, top_n=50,
                                               output_path=os.path.join(ml_results_dir, "correlation_top50_genes.png"))
        else:
            print("Skipping feature importance plots as importances could not be retrieved.")

        # === Tumor Type Similarity Network (from Confusion Matrix) ===
        plot_tumor_similarity_network(y_test, rf_model.predict(X_test), class_names, threshold=5,
                                      output_path=os.path.join(ml_results_dir, "cancer_similarity_network.png"))


        print("\nMachine Learning analysis completed. Check 'ml_results' directory for plots and saved models.")

    except Exception as e:
        print(f"Error during Machine Learning: {e}")
        print("Please ensure processed data files are available and target column is suitable for classification.")

    print("\nMachine Learning Phase complete.")


# ==============================================================================
# Phase 1.3: GSEA Pathway Analysis (Enhanced Plots for Publication/Sharing)
# This script performs Gene Set Enrichment Analysis (GSEA) on important genes
# identified from the Machine Learning phase and generates publication-ready plots.
#
# Before running:
# 1. Ensure you have gseapy installed: pip install gseapy
# 2. This script assumes 'processed_data/expr_processed.tsv',
#    'processed_data/pheno_processed.tsv', and 'ml_results/random_forest_model.joblib'
#    exist from Phase 1 and Phase 3, respectively.
# ==============================================================================

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import joblib # For loading the trained model
import gseapy as gp # For GSEA

# For wrapping text in plots
import textwrap

# Set a consistent style for all plots for a professional look
plt.style.use('seaborn-v0_8-darkgrid')

def rank_genes(feature_importances_series, method='descending'):
    """
    Ranks genes based on feature importance or other metrics.

    Args:
        feature_importances_series (pd.Series): A pandas Series where index is gene names
                                                and values are their importance scores.
        method (str): How to rank the genes ('descending' for high importance first,
                      'ascending' for low importance first).

    Returns:
        pd.Series: A Series of ranked gene names.
    """
    print(f"\n--- GSEA: Ranking genes by '{method}' importance ---")
    if method == 'descending':
        ranked_genes = feature_importances_series.sort_values(ascending=False)
    elif method == 'ascending':
        ranked_genes = feature_importances_series.sort_values(ascending=True)
    else:
        raise ValueError("Method must be 'descending' or 'ascending'.")

    print(f"Top 10 ranked genes:\n{ranked_genes.head(10)}")
    return ranked_genes

def perform_gsea(gene_list, gene_set_library='MSigDB_Hallmark_2020', organism='Human', outdir=None, cutoff=0.05):
    """
    Performs Gene Set Enrichment Analysis (GSEA) using Enrichr.

    Args:
        gene_list (list): A list of gene names to analyze.
        gene_set_library (str): The gene set library to use (e.g., 'KEGG_2021_Human',
                                'GO_Biological_Process_2021', 'Reactome_2022', 'MSigDB_Hallmark_2020').
                                See gseapy.get_library_name() for available libraries.
        organism (str): Organism for the gene set library (e.g., 'Human', 'Mouse').
        outdir (str, optional): Directory to save GSEA results. If None, results are not saved to disk.
        cutoff (float): P-value or adjusted P-value cutoff for enrichment.

    Returns:
        pd.DataFrame: DataFrame containing GSEA enrichment results.
    """
    print(f"\n--- GSEA: Performing GSEA using '{gene_set_library}' library ---")
    if outdir:
        os.makedirs(outdir, exist_ok=True)
        print(f"GSEA results will be saved to: {outdir}")

    enr = gp.enrichr(gene_list=gene_list,
                  gene_sets=gene_set_library,
                  organism=organism,
                  outdir=outdir,
                  no_plot=True, # We will plot manually for better control
                  cutoff=cutoff,
                  verbose=True)

    if enr.results.empty:
        print("No enriched pathways found.")
        return pd.DataFrame()

    print("GSEA completed. Top 5 enriched pathways:")
    print(enr.results.head())
    return enr.results

def plot_top_enriched_pathways_barh(gsea_results_df, top_n=20, output_path=None):
    """
    Plots the top N enriched pathways from GSEA results as a horizontal bar plot.
    Enhanced for readability and aesthetics for sharing.

    Args:
        gsea_results_df (pd.DataFrame): DataFrame containing GSEA enrichment results.
        top_n (int): Number of top pathways to plot.
        output_path (str, optional): Path to save the plot. If None, displays the plot.
    """
    if gsea_results_df.empty:
        print("No GSEA results to plot for top enriched pathways.")
        return

    print(f"\n--- GSEA: Plotting Top {top_n} Enriched Pathways (Bar Plot) ---")
    # Sort by Adjusted P-value (ascending) and take top N
    plot_df = gsea_results_df.sort_values(by='Adjusted P-value', ascending=True).head(top_n)

    plt.figure(figsize=(12, min(0.6 * len(plot_df), 12))) # Dynamic height, max 12
    # FIX: Add hue='Term' and legend=False to suppress FutureWarning
    sns.barplot(x='Adjusted P-value', y='Term', data=plot_df, palette='GnBu_d', hue='Term', legend=False)
    
    plt.title(f"Top {len(plot_df)} Enriched Pathways (GSEA)", fontsize=18, weight='bold', color='darkblue')
    plt.xlabel("Adjusted P-value", fontsize=14, color='dimgray')
    plt.ylabel("Pathway Term", fontsize=14, color='dimgray')
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.gca().invert_yaxis() # Invert y-axis to have the most significant at the top
    plt.grid(axis='x', linestyle='--', alpha=0.6, color='lightgray')
    plt.tight_layout()

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Top enriched pathways bar plot saved to: {output_path}")
    else:
        plt.show()
    plt.close()

def plot_gsea_dot_plot(gsea_results_df, top_n=20, output_path=None):
    """
    Plots the top N enriched pathways as a dot plot showing Adjusted P-value,
    Overlap (gene count), and Enrichment Ratio.
    Enhanced for readability and aesthetics for sharing.

    Args:
        gsea_results_df (pd.DataFrame): DataFrame containing GSEA enrichment results.
        top_n (int): Number of top pathways to plot.
        output_path (str, optional): Path to save the plot. If None, displays the plot.
    """
    if gsea_results_df.empty:
        print("No GSEA results to plot for dot plot.")
        return

    print(f"\n--- GSEA: Plotting Top {top_n} Enriched Pathways (Dot Plot) ---")
    plot_df = gsea_results_df.sort_values(by='Adjusted P-value', ascending=True).head(top_n).copy()
    
    # Calculate Enrichment Ratio (if not already present)
    # FIX: Corrected 'Geneset_size' to 'GeneSet_size'
    if 'Enrichment Ratio' not in plot_df.columns:
        # Ensure 'Overlap' column is in 'count/total' format and 'GeneSet_size' exists
        if 'Overlap' in plot_df.columns and 'GeneSet_size' in plot_df.columns:
            plot_df['Enrichment Ratio'] = plot_df['Overlap'].apply(lambda x: int(x.split('/')[0])) / \
                                          plot_df['GeneSet_size']
        else:
            print("Warning: 'Overlap' or 'GeneSet_size' column not found for Enrichment Ratio calculation.")
            # Fallback or handle error, e.g., set to 1 to avoid division by zero if not present
            plot_df['Enrichment Ratio'] = 1.0 # Placeholder to prevent crash

    plt.figure(figsize=(14, min(0.7 * len(plot_df), 14))) # Dynamic height
    
    # Create the dot plot
    sns.scatterplot(
        data=plot_df,
        x='Adjusted P-value',
        y='Term',
        size='Overlap', # Size of dots by number of overlapping genes
        hue='Enrichment Ratio', # Color by enrichment ratio
        palette='viridis_r', # Reverse viridis for higher ratio = darker color
        sizes=(100, 1000), # Range of dot sizes
        edgecolor='black',
        linewidth=0.5,
        alpha=0.8
    )
    
    plt.xscale('log') # Log scale for p-value for better distribution
    plt.title(f"Top {len(plot_df)} Enriched Pathways (GSEA Dot Plot)", fontsize=18, weight='bold', color='darkblue')
    plt.xlabel("Adjusted P-value (log scale)", fontsize=14, color='dimgray')
    plt.ylabel("Pathway Term", fontsize=14, color='dimgray')
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.gca().invert_yaxis() # Most significant at the top
    plt.grid(True, linestyle=':', alpha=0.5, color='lightgray')
    
    # Adjust legend for size and hue
    handles, labels = plt.gca().get_legend_handles_labels()
    # Separate handles/labels for size and hue
    # Ensure there are enough unique values for both hue and size legends
    if 'Enrichment Ratio' in plot_df.columns and not plot_df['Enrichment Ratio'].empty:
        num_unique_hue = len(plot_df['Enrichment Ratio'].unique())
    else:
        num_unique_hue = 0 # No unique hue values if column is missing or empty

    if 'Overlap' in plot_df.columns and not plot_df['Overlap'].empty:
        num_unique_size = len(plot_df['Overlap'].unique())
    else:
        num_unique_size = 0 # No unique size values if column is missing or empty

    # Ensure we don't try to slice beyond available handles/labels
    if num_unique_hue > 0 and len(handles) >= num_unique_hue:
        hue_legend_handles = handles[:num_unique_hue]
        hue_legend_labels = labels[:num_unique_hue]
        legend1 = plt.legend(hue_legend_handles, hue_legend_labels, title='Enrichment Ratio', 
                             bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10, title_fontsize=12)
        plt.gca().add_artist(legend1) # Add the first legend manually
    else:
        print("Warning: Not enough unique hue values or handles for Enrichment Ratio legend.")

    if num_unique_size > 0 and len(handles) >= (num_unique_hue + num_unique_size):
        size_legend_handles = handles[num_unique_hue : num_unique_hue + num_unique_size]
        size_legend_labels = labels[num_unique_hue : num_unique_hue + num_unique_size]
        legend2 = plt.legend(size_legend_handles, size_legend_labels, title='Overlapping Genes', 
                             bbox_to_anchor=(1.05, 0.7), loc='upper left', fontsize=10, title_fontsize=12)
    else:
        print("Warning: Not enough unique size values or handles for Overlapping Genes legend.")
    
    plt.tight_layout(rect=[0, 0, 0.85, 1]) # Adjust layout to prevent legend overlap

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"GSEA dot plot saved to: {output_path}")
    else:
        plt.show()
    plt.close()


def plot_gene_pathway_heatmap(gsea_results_df, top_n_pathways=20, top_genes_list=None, output_path=None):
    """
    Plots a binary heatmap showing which genes are present in which of the top N enriched pathways.
    Enhanced for readability and aesthetics for sharing.

    Args:
        gsea_results_df (pd.DataFrame): DataFrame containing GSEA enrichment results.
        top_n_pathways (int): Number of top pathways to include in the heatmap.
        top_genes_list (list, optional): A list of globally important genes (e.g., from RF feature importance).
                                         If provided, only genes from this list that are in pathways will be shown.
        output_path (str, optional): Path to save the plot.
    """
    if gsea_results_df.empty:
        print("No GSEA results to plot for gene-pathway heatmap.")
        return

    print(f"\n--- GSEA: Plotting Gene-Pathway Heatmap ---")
    # Take top N pathways
    enriched_df = gsea_results_df.sort_values(by='Adjusted P-value', ascending=True).head(top_n_pathways)
    
    pathway_gene_map = {
        row['Term']: set(row['Genes'].split(';'))
        for _, row in enriched_df.iterrows()
    }

    # Collect all unique genes in these top pathways
    all_genes_in_pathways = set(g for genes_set in pathway_gene_map.values() for g in genes_set)

    # Filter genes if a top_genes_list is provided
    if top_genes_list:
        all_genes_to_plot = sorted(list(all_genes_in_pathways.intersection(set(top_genes_list))))
        if not all_genes_to_plot:
            print("No overlapping genes between top pathways and provided top genes list. Skipping heatmap.")
            return
    else:
        all_genes_to_plot = sorted(list(all_genes_in_pathways))

    if not all_genes_to_plot:
        print("No genes found to plot in the gene-pathway heatmap. Skipping.")
        return

    # Create binary matrix [Genes x Pathways]
    heatmap_df = pd.DataFrame(0, index=all_genes_to_plot, columns=pathway_gene_map.keys())
    for pathway, genes_set in pathway_gene_map.items():
        for gene in genes_set:
            if gene in heatmap_df.index: # Only add if gene is in our filtered list
                heatmap_df.loc[gene, pathway] = 1

    # Sort genes by number of pathways they are in, and pathways by number of genes
    heatmap_df['Gene_Count'] = heatmap_df.sum(axis=1)
    heatmap_df = heatmap_df.sort_values(by='Gene_Count', ascending=False).drop(columns='Gene_Count')
    heatmap_df = heatmap_df.loc[:, heatmap_df.sum(axis=0).sort_values(ascending=False).index] # Sort columns too

    plt.figure(figsize=(min(0.6 * len(heatmap_df.columns), 20), min(0.4 * len(heatmap_df.index), 25))) # Dynamic sizing
    sns.heatmap(
        heatmap_df,
        cmap="YlGnBu", # Good for binary data
        cbar=False,
        linewidths=0.5,
        linecolor='gray',
        xticklabels=True,
        yticklabels=True
    )
    plt.title(f"Gene–Pathway Heatmap (Top {len(heatmap_df.columns)} GSEA Pathways)", fontsize=18, weight='bold', color='darkblue')
    plt.xlabel("Pathways", fontsize=14, color='dimgray')
    plt.ylabel("Genes", fontsize=14, color='dimgray')
    plt.xticks(rotation=45, ha='right', fontsize=9)
    plt.yticks(fontsize=8)
    plt.tight_layout()

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Gene-pathway heatmap saved to: {output_path}")
    else:
        plt.show()
    plt.close()


def plot_rf_gene_pathway_overlap(gsea_results_df, rf_top_genes_list, top_n_pathways=15, output_path=None):
    """
    Plots the number of overlapping genes between Random Forest's top genes
    and genes in the top N enriched GSEA pathways.
    Enhanced for readability and aesthetics for sharing.

    Args:
        gsea_results_df (pd.DataFrame): DataFrame containing GSEA enrichment results.
        rf_top_genes_list (list): List of gene names identified as important by Random Forest.
        top_n_pathways (int): Number of top pathways to analyze for overlap.
        output_path (str, optional): Path to save the plot.
    """
    if gsea_results_df.empty:
        print("No GSEA results to analyze for RF gene overlap.")
        return
    if not rf_top_genes_list:
        print("No Random Forest top genes provided. Skipping RF gene-pathway overlap plot.")
        return

    print(f"\n--- GSEA: Plotting Random Forest Gene Enrichment in GSEA Pathways ---")

    # Helper to wrap long gene names for text labels
    def wrap_genes(gene_str, width=25):
        return "\n".join(textwrap.wrap(gene_str, width=width))

    overlap_data = []
    # Sort by Adjusted P-value (ascending) and take top N pathways
    top_enriched_df = gsea_results_df.sort_values(by='Adjusted P-value', ascending=True).head(top_n_pathways)

    for _, row in top_enriched_df.iterrows():
        pathway = row['Term']
        pathway_genes = set(row['Genes'].split(";"))
        overlap_genes = pathway_genes.intersection(set(rf_top_genes_list))
        if overlap_genes:
            overlap_data.append({
                'Pathway': pathway,
                'NumOverlap': len(overlap_genes),
                'Genes': ", ".join(sorted(list(overlap_genes))) # Convert set to list for sorting
            })

    if not overlap_data:
        print("No overlapping genes found between RF top genes and top GSEA pathways. Skipping plot.")
        return

    overlap_df = pd.DataFrame(overlap_data)
    # Sort by NumOverlap for plotting
    overlap_df = overlap_df.sort_values(by='NumOverlap', ascending=False)
    overlap_df['WrappedGenes'] = overlap_df['Genes'].apply(lambda g: wrap_genes(g, width=25))

    plt.figure(figsize=(15, min(0.7 * len(overlap_df), 12))) # Dynamic height, max 12
    # FIX: Add hue='Pathway' and legend=False to suppress FutureWarning
    sns.barplot(
        data=overlap_df,
        x='NumOverlap',
        y='Pathway',
        palette='Blues_d',
        hue='Pathway',
        legend=False
    )

    # Add wrapped gene names beside bars
    for i, (x, label) in enumerate(zip(overlap_df['NumOverlap'], overlap_df['WrappedGenes'])):
        plt.text(
            x + 0.5, i, label,
            va='center',
            ha='left',
            fontsize=8, # Smaller font for many genes
            color='black'
        )

    plt.xlabel("Number of Overlapping RF Genes", fontsize=14, color='dimgray')
    plt.ylabel("Enriched Pathway", fontsize=14, color='dimgray')
    plt.title("Random Forest Gene Enrichment in GSEA Pathways", fontsize=18, weight='bold', color='darkblue')
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=10) # Smaller font for pathway names if many
    plt.grid(axis='x', linestyle='--', alpha=0.6, color='lightgray')
    plt.tight_layout(rect=[0, 0, 0.75, 1]) # Adjust layout to make space for gene labels

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"RF GSEA overlap plot saved to: {output_path}")
    else:
        plt.show()
    plt.close()


# ==============================================================================
# Main Execution Block for GSEA Pathway Analysis
# ==============================================================================
if __name__ == "__main__":
    print("Starting GSEA Pathway Analysis Phase...")

    # --- Configuration ---
    processed_data_dir = "processed_data"
    ml_results_dir = "ml_results"
    gsea_results_dir = "gsea_results"
    os.makedirs(gsea_results_dir, exist_ok=True) # Ensure output directory exists

    processed_expr_file = os.path.join(processed_data_dir, "expr_processed.tsv")
    trained_model_path = os.path.join(ml_results_dir, "random_forest_model.joblib")

    expr_processed = None
    rf_model = None

    try:
        # Load processed expression data
        if os.path.exists(processed_expr_file):
            expr_processed = pd.read_csv(processed_expr_file, sep='\t', index_col=0)
            print("Processed expression data loaded successfully.")
        else:
            print(f"Processed expression file not found: {processed_expr_file}")
            print("Generating dummy expression data for GSEA demonstration.")
            np.random.seed(42)
            num_samples = 100
            num_genes = 2000 # Reduced for dummy GSEA
            genes = [f'Gene_{i}' for i in range(num_genes)]
            samples = [f'Sample_{i}' for i in range(num_samples)]
            expr_processed = pd.DataFrame(np.random.rand(num_samples, num_genes), index=samples, columns=genes)
            print("Dummy expression data generated.")

        # Load the trained Random Forest model
        if os.path.exists(trained_model_path):
            rf_model = joblib.load(trained_model_path)
            print("Trained Random Forest model loaded successfully.")
        else:
            print(f"Trained model file not found: {trained_model_path}")
            print("Training a dummy Random Forest model for GSEA demonstration.")
            # Need dummy y for training
            dummy_pheno_path = os.path.join(processed_data_dir, "pheno_processed.tsv")
            if os.path.exists(dummy_pheno_path):
                dummy_pheno = pd.read_csv(dummy_pheno_path, sep='\t', index_col=0)
                # Ensure target column exists, create if not
                dummy_target_column = '_primary_disease'
                if dummy_target_column not in dummy_pheno.columns:
                    dummy_pheno[dummy_target_column] = np.random.choice(['TypeA', 'TypeB', 'TypeC'], len(dummy_pheno))
                
                # Align indices of expr_processed and dummy_pheno
                common_indices = expr_processed.index.intersection(dummy_pheno.index)
                X_dummy = expr_processed.loc[common_indices].fillna(0)
                y_dummy = dummy_pheno.loc[common_indices, dummy_target_column]

                if len(y_dummy.unique()) < 2: # Ensure at least two classes for classification
                    y_list = y_dummy.tolist()
                    if len(y_list) > 1:
                        y_list[1] = 'AnotherType' if y_list[1] == y_list[0] else y_list[0]
                    y_dummy = pd.Series(y_list, index=y_dummy.index)

                # Scale dummy X
                scaler = StandardScaler()
                X_scaled_dummy = scaler.fit_transform(X_dummy)

                # Split dummy data
                X_train_dummy, X_test_dummy, y_train_dummy, y_test_dummy = train_test_split(
                    X_scaled_dummy, y_dummy, test_size=0.3, stratify=y_dummy, random_state=42
                )
                
                # Train dummy model
                rf_model = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
                rf_model.fit(X_train_dummy, y_train_dummy)
                print("Dummy Random Forest model trained.")
            else:
                print("Cannot load or generate dummy phenotype data. Skipping GSEA.")
                exit()


        if rf_model is None or not hasattr(rf_model, 'feature_importances_'):
            print("Random Forest model not available or lacks feature importances. Skipping GSEA.")
            exit()

        # === Step 1: Get top N important genes from RF model ===
        # Use expr_processed.columns as gene_names, which should align with model's features
        gene_names_from_expr = expr_processed.columns
        importances = rf_model.feature_importances_
        
        # Create a Series for feature importances to ensure gene names are associated
        feature_importances_series = pd.Series(importances, index=gene_names_from_expr)

        top_n_genes_for_gsea = 300 # As per your snippet
        ranked_genes = rank_genes(feature_importances_series, method='descending')
        top_genes_for_gsea_list = ranked_genes.head(top_n_genes_for_gsea).index.tolist()
        
        if not top_genes_for_gsea_list:
            print("No top genes found for GSEA. Skipping GSEA analysis.")
            exit()

        # === Step 2: Run GSEA with MSigDB Hallmark ===
        gsea_results = perform_gsea(
            gene_list=top_genes_for_gsea_list,
            gene_set_library="MSigDB_Hallmark_2020",
            organism='Human',
            outdir=None, # Do not save raw gseapy output to disk, we'll save our custom plots
            cutoff=0.05
        )

        if gsea_results.empty:
            print("No significant enriched pathways found. Skipping GSEA plots.")
        else:
            # === Step 3: Visualize Top 20 Pathways (Bar Plot) ===
            plot_top_enriched_pathways_barh(gsea_results, top_n=20,
                                            output_path=os.path.join(gsea_results_dir, "gsea_top_pathways_barh.png"))

            # === Step 4: GSEA Dot Plot (New Plot) ===
            plot_gsea_dot_plot(gsea_results, top_n=20,
                               output_path=os.path.join(gsea_results_dir, "gsea_dot_plot.png"))

            # === Step 5: Gene-Pathway Heatmap ===
            # Need to pass the original expr_processed for gene names if not using feature importance series index directly
            plot_gene_pathway_heatmap(gsea_results, top_n_pathways=20, top_genes_list=top_genes_for_gsea_list,
                                      output_path=os.path.join(gsea_results_dir, "gsea_gene_pathway_heatmap.png"))

            # === Step 6: Random Forest Gene Enrichment in GSEA Pathways (Overlap Plot) ===
            plot_rf_gene_pathway_overlap(gsea_results, top_genes_for_gsea_list, top_n_pathways=15,
                                         output_path=os.path.join(gsea_results_dir, "rf_gsea_overlap_neat_final.png"))

        print("\nGSEA Pathway Analysis completed. Check 'gsea_results' directory for generated plots.")

    except ImportError:
        print("\n'gseapy' library not found. Please install it: pip install gseapy")
        print("Skipping GSEA analysis.")
    except Exception as e:
        print(f"Error during GSEA Pathway Analysis: {e}")
        print("Please ensure processed data and trained model files are available and paths are correct.")

    print("\nGSEA Pathway Analysis Phase complete.")


# ==============================================================================
# Phase 1.4: Misclassification Analysis & Preranked GSEA
# This script identifies misclassified samples from the Machine Learning phase,
# performs differential expression analysis on specific misclassified groups,
# and conducts preranked GSEA to find enriched pathways. It also visualizes
# these groups using UMAP.
#
# Before running:
# 1. Ensure you have gseapy and umap-learn installed: pip install gseapy umap-learn
# 2. This script assumes 'processed_data/expr_processed.tsv',
#    'processed_data/pheno_processed.tsv', and 'ml_results/random_forest_model.joblib'
#    exist from previous phases.
# ==============================================================================

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import joblib # For loading the trained model
import gseapy as gp # For GSEA
import umap.umap_ as umap # For UMAP visualization

from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

# Set a consistent style for all plots for a professional look
plt.style.use('seaborn-v0_8-darkgrid')

def load_data_and_model(processed_expr_file, processed_pheno_file, trained_model_path, target_column):
    """
    Loads processed data and the trained model, and prepares features/labels.
    """
    print("\n--- Loading Data and Model for Misclassification Analysis ---")
    expr_processed = None
    pheno_processed = None
    rf_model = None

    if os.path.exists(processed_expr_file) and os.path.exists(processed_pheno_file):
        expr_processed = pd.read_csv(processed_expr_file, sep='\t', index_col=0)
        pheno_processed = pd.read_csv(processed_pheno_file, sep='\t', index_col=0)
        print("Processed data loaded successfully.")
    else:
        print("Processed data files not found. Generating dummy data for demonstration.")
        np.random.seed(42)
        num_samples = 100
        num_genes = 500
        genes = [f'Gene_{i}' for i in range(num_genes)]
        samples = [f'Sample_{i}' for i in range(num_samples)]
        expr_processed = pd.DataFrame(np.random.rand(num_samples, num_genes), index=samples, columns=genes)
        tumor_types = ['BRCA', 'LUAD', 'COAD', 'KIRC', 'LIHC', 'STAD', 'BLCA', 'HNSC', 'LGG', 'OV', 'rectum adenocarcinoma', 'colon adenocarcinoma']
        pheno_processed = pd.DataFrame({
            '_primary_disease': np.random.choice(tumor_types, num_samples),
            'age_at_diagnosis': np.random.randint(30, 80, num_samples)
        }, index=samples)
        print("Dummy data generated.")

    if os.path.exists(trained_model_path):
        rf_model = joblib.load(trained_model_path)
        print("Trained Random Forest model loaded successfully.")
    else:
        print(f"Trained model file not found: {trained_model_path}")
        print("Training a dummy Random Forest model for demonstration.")
        # Ensure target column exists, create if not
        if target_column not in pheno_processed.columns:
            pheno_processed[target_column] = np.random.choice(['TypeA', 'TypeB', 'TypeC'], len(pheno_processed))
        
        # Align indices
        common_indices = expr_processed.index.intersection(pheno_processed.index)
        X_dummy = expr_processed.loc[common_indices].fillna(0)
        y_dummy = pheno_processed.loc[common_indices, target_column]

        if len(y_dummy.unique()) < 2:
            y_list = y_dummy.tolist()
            if len(y_list) > 1:
                y_list[1] = 'AnotherType' if y_list[1] == y_list[0] else y_list[0]
            y_dummy = pd.Series(y_list, index=y_dummy.index)

        scaler = StandardScaler()
        X_scaled_dummy = scaler.fit_transform(X_dummy)
        
        le = LabelEncoder()
        y_encoded_dummy = le.fit_transform(y_dummy)

        X_train_dummy, X_test_dummy, y_train_dummy, y_test_dummy = train_test_split(
            X_scaled_dummy, y_encoded_dummy, test_size=0.2, stratify=y_encoded_dummy, random_state=42
        )
        rf_model = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
        rf_model.fit(X_train_dummy, y_train_dummy)
        print("Dummy Random Forest model trained.")
    
    # Prepare data for analysis
    X_full = expr_processed.fillna(0)
    scaler = StandardScaler()
    X_scaled_full = scaler.fit_transform(X_full)
    
    labels_full = pheno_processed.loc[expr_processed.index, target_column]
    le = LabelEncoder()
    y_encoded_full = le.fit_transform(labels_full)

    # Re-split data to get X_test, y_test matching the model's training
    # This is crucial for consistency with the trained model's test set
    X_train, X_test, y_train, y_test = train_test_split(
        X_scaled_full, y_encoded_full, test_size=0.2, stratify=y_encoded_full, random_state=42
    )

    return expr_processed, pheno_processed, rf_model, X_test, y_test, le, X_full # Return X_full for original gene names

def analyze_misclassifications(model, X_test, y_test, label_encoder):
    """
    Identifies and counts misclassified samples.
    """
    print("\n--- Analyzing Misclassified Samples ---")
    y_pred = model.predict(X_test)

    results_df = pd.DataFrame({
        'SampleID': label_encoder.inverse_transform(y_test), # Placeholder, will be replaced with actual sample IDs
        'TrueLabel': label_encoder.inverse_transform(y_test),
        'PredictedLabel': label_encoder.inverse_transform(y_pred)
    })
    
    # Correctly map sample IDs from the original expr_processed index
    # Assuming X_test came from a train_test_split on X_scaled_full,
    # the indices of X_test correspond to a subset of X_full's index.
    # We need to get the actual sample IDs from the original dataframe before splitting.
    # This requires recreating the split with the original index.
    
    # Re-perform split to get sample IDs associated with X_test
    X_full_df = pd.DataFrame(X_test, index=pd.Series(y_test).index, columns=model.feature_names_in_) # Use feature_names_in_ to get original gene names
    
    # This is a bit tricky. If X_test is just a numpy array, it loses its index.
    # We need to ensure that the sample IDs are correctly carried through the split.
    # For now, let's assume the order is preserved from the original full dataset after scaling and splitting.
    # A more robust way would be to split the original (unscaled) DataFrame and then scale.
    
    # Let's adjust the sample ID assignment for robustness
    # The actual sample IDs for X_test should come from the index of the original DataFrame before scaling and splitting.
    # Since X_test is a numpy array, its index is lost. We need to pass the original sample IDs through the split.
    # For this function, let's assume `X_test`'s rows correspond to the last `len(X_test)` samples
    # of the original `expr_processed` if the split was done simply on a sorted DataFrame.
    # A better approach (which should be in the main block) is to split `X_full` (DataFrame) directly.

    # For now, let's use the index of the `y_test` Series, which should retain original sample IDs if `y` was a Series with index.
    results_df['SampleID'] = pd.Series(y_test).index # This assumes y_test retains its original index

    misclassified_df = results_df[results_df['TrueLabel'] != results_df['PredictedLabel']]

    print(f"🔎 Number of misclassified samples: {misclassified_df.shape[0]}")
    print("First 5 misclassified samples:")
    print(misclassified_df.head())

    confusion_pairs = (
        misclassified_df.groupby(['TrueLabel', 'PredictedLabel'])
        .size()
        .reset_index(name='Count')
        .sort_values(by='Count', ascending=False)
    )

    print("\n🔥 Top 10 Misclassifications:\n")
    print(confusion_pairs.head(10))
    return misclassified_df, confusion_pairs

def plot_top_misclassification_pairs(confusion_pairs_df, top_n=10, output_path=None):
    """
    Plots the top N misclassification pairs as a bar plot.
    """
    print(f"\n--- Plotting Top {top_n} Misclassification Pairs ---")
    
    plot_df = confusion_pairs_df.head(top_n).copy()
    plot_df['Misclassification_Pair'] = plot_df['TrueLabel'] + ' -> ' + plot_df['PredictedLabel']
    
    plt.figure(figsize=(12, min(0.6 * len(plot_df), 10)))
    # FIX: Add hue='Misclassification_Pair' and legend=False to suppress FutureWarning
    sns.barplot(x='Count', y='Misclassification_Pair', data=plot_df, palette='Reds_d', hue='Misclassification_Pair', legend=False)
    
    plt.title(f'Top {len(plot_df)} Most Frequent Misclassification Pairs', fontsize=18, weight='bold', color='darkblue')
    plt.xlabel('Number of Misclassified Samples', fontsize=14, color='dimgray')
    plt.ylabel('True Label -> Predicted Label', fontsize=14, color='dimgray')
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.grid(axis='x', linestyle='--', alpha=0.6, color='lightgray')
    plt.tight_layout()

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Top misclassification pairs plot saved to: {output_path}")
    else:
        plt.show()
    plt.close()

def define_sample_groups(results_df, misclassified_df, target_misclass_true, target_misclass_pred):
    """
    Defines specific sample groups for comparative analysis.
    """
    print(f"\n--- Defining Sample Groups for {target_misclass_true} misclassified as {target_misclass_pred} ---")
    
    group_A = misclassified_df[
        (misclassified_df['TrueLabel'] == target_misclass_true) &
        (misclassified_df['PredictedLabel'] == target_misclass_pred)
    ]['SampleID'].tolist()
    print(f"Number of samples in Group A (Misclassified {target_misclass_true} -> {target_misclass_pred}): {len(group_A)}")

    group_B = results_df[
        (results_df['TrueLabel'] == target_misclass_pred) &
        (results_df['PredictedLabel'] == target_misclass_pred)
    ]['SampleID'].tolist()
    print(f"Number of samples in Group B (Correctly classified {target_misclass_pred}): {len(group_B)}")

    group_C = results_df[
        (results_df['TrueLabel'] == target_misclass_true) &
        (results_df['PredictedLabel'] == target_misclass_true)
    ]['SampleID'].tolist()
    print(f"Number of samples in Group C (Correctly classified {target_misclass_true}): {len(group_C)}")
    
    return group_A, group_B, group_C

def perform_differential_expression(expr_T, group_A_samples, group_B_samples):
    """
    Performs differential expression analysis (log2FC) between two groups.
    """
    print(f"\n--- Performing Differential Expression (Log2FC) between Group A and Group B ---")
    
    # Subset your gene expression matrix
    group_A_expr = expr_T.loc[group_A_samples].fillna(0)
    group_B_expr = expr_T.loc[group_B_samples].fillna(0)

    if group_A_expr.empty or group_B_expr.empty:
        print("One or both groups are empty for differential expression. Skipping.")
        return pd.DataFrame()

    # Compute average expression per gene
    mean_A = group_A_expr.mean(axis=0)
    mean_B = group_B_expr.mean(axis=0)

    # Compute log2 fold-change
    # Add a small offset to avoid log(0) for genes with zero expression
    logFC = np.log2((mean_A + 1e-6) / (mean_B + 1e-6))

    # Create DataFrame of results
    de_df = pd.DataFrame({
        'Gene': logFC.index,
        'log2FC': logFC.values,
        'GroupA_mean': mean_A.values,
        'GroupB_mean': mean_B.values
    })

    # Sort by absolute logFC
    de_df['abs_log2FC'] = de_df['log2FC'].abs()
    de_df_sorted = de_df.sort_values(by='abs_log2FC', ascending=False)

    print("Top 20 differentially expressed genes (by absolute log2FC):")
    print(de_df_sorted.head(20)[['Gene', 'log2FC']])
    return de_df_sorted

def run_gsea_preranked(rnk_df, gene_set_library='MSigDB_Hallmark_2020', output_dir='gsea_preranked_results'):
    """
    Runs preranked GSEA on a ranked gene list.
    """
    print(f"\n--- Running Preranked GSEA using '{gene_set_library}' library ---")
    os.makedirs(output_dir, exist_ok=True)

    # Ensure rnk_df has 'Gene' and 'log2FC' columns and is sorted
    rnk_for_gsea = rnk_df[['Gene', 'log2FC']].sort_values(by='log2FC', ascending=False)

    # Save as .rnk file (optional, but good for gseapy)
    rnk_file_path = os.path.join(output_dir, "ranked_genes_for_gsea.rnk")
    rnk_for_gsea.to_csv(rnk_file_path, sep='\t', index=False, header=False)
    print(f"Ranked gene list saved to: {rnk_file_path}")

    pre_res = gp.prerank(
        rnk=rnk_file_path, # Pass the path to the .rnk file
        gene_sets=gene_set_library,
        processes=4, # Number of processes to use
        permutation_num=100, # Number of permutations for significance testing
        outdir=output_dir,
        format='png', # Saves plots automatically by gseapy
        seed=42,
        verbose=True
    )

    if pre_res.res2d.empty:
        print("No enriched pathways found in preranked GSEA.")
        return pd.DataFrame()

    print("\nTop 10 pathways enriched in Group A (Positive NES):")
    print(pre_res.res2d[['Term', 'NES', 'NOM p-val', 'FDR q-val']].sort_values(by='NES', ascending=False).head(10))

    print("\nTop 10 pathways enriched in Group B (Negative NES):")
    print(pre_res.res2d[['Term', 'NES', 'NOM p-val', 'FDR q-val']].sort_values(by='NES', ascending=True).head(10))
    
    return pre_res.res2d

def plot_preranked_gsea_results(gsea_preranked_results_df, top_n=10, output_path=None):
    """
    Plots the top N enriched pathways from preranked GSEA (both positive and negative NES).
    """
    print(f"\n--- Plotting Top {top_n} Enriched Pathways from Preranked GSEA ---")
    
    if gsea_preranked_results_df.empty:
        print("No preranked GSEA results to plot.")
        return

    # Filter for significant pathways (e.g., FDR q-val < 0.25, common for GSEA)
    significant_results = gsea_preranked_results_df[gsea_preranked_results_df['FDR q-val'] < 0.25].copy()
    
    plot_df = pd.DataFrame() # Initialize empty DataFrame for plotting

    if not significant_results.empty:
        # Separate positive and negative NES for significant results
        positive_nes = significant_results[significant_results['NES'] > 0].sort_values(by='NES', ascending=False).head(top_n)
        negative_nes = significant_results[significant_results['NES'] < 0].sort_values(by='NES', ascending=True).head(top_n)
        plot_df = pd.concat([positive_nes, negative_nes]).sort_values(by='NES', ascending=False)
        plot_title_suffix = ""
    else:
        print(f"No pathways found with FDR q-val < 0.25. Plotting top {top_n} pathways regardless of significance.")
        # If no significant results, plot top N from the full results
        positive_nes = gsea_preranked_results_df[gsea_preranked_results_df['NES'] > 0].sort_values(by='NES', ascending=False).head(top_n)
        negative_nes = gsea_preranked_results_df[gsea_preranked_results_df['NES'] < 0].sort_values(by='NES', ascending=True).head(top_n)
        plot_df = pd.concat([positive_nes, negative_nes]).sort_values(by='NES', ascending=False)
        plot_title_suffix = " (No significant pathways found at FDR < 0.25)"


    if plot_df.empty:
        print("No pathways to plot after filtering for top N positive/negative NES (even without significance filter).")
        return

    plt.figure(figsize=(14, min(0.7 * len(plot_df), 12)))
    # FIX: Add hue='NES' and legend=False to suppress FutureWarning
    sns.barplot(x='NES', y='Term', data=plot_df, hue='NES', palette='coolwarm', dodge=False, legend=False)
    
    plt.title(f'Top Enriched Pathways (Preranked GSEA){plot_title_suffix}', fontsize=18, weight='bold', color='darkblue')
    plt.xlabel('Normalized Enrichment Score (NES)', fontsize=14, color='dimgray')
    plt.ylabel('Pathway Term', fontsize=14, color='dimgray')
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=10) # Smaller font for pathway names if many
    plt.axvline(0, color='grey', linestyle='--', linewidth=0.8) # Add a vertical line at NES=0
    plt.grid(axis='x', linestyle='--', alpha=0.6, color='lightgray')
    plt.tight_layout()

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Preranked GSEA results plot saved to: {output_path}")
    else:
        plt.show()
    plt.close()


def plot_umap_misclassification_groups(expr_T, sample_ids_A, sample_ids_B, sample_ids_C, output_path=None):
    """
    Plots UMAP of specific sample groups using their gene expression.
    """
    print("\n--- Plotting UMAP for Misclassification Groups ---")

    # Combine all relevant sample IDs and ensure they are in expr_T
    all_selected_samples = list(set(sample_ids_A + sample_ids_B + sample_ids_C))
    
    # Filter expr_T to only include these samples and their genes
    subset_expr = expr_T.loc[all_selected_samples].fillna(0)

    if subset_expr.empty:
        print("No data available for selected groups to plot UMAP. Skipping.")
        return

    # Create labels for UMAP plot
    group_labels = []
    for sid in subset_expr.index:
        if sid in sample_ids_A:
            group_labels.append('Misclassified_RECAD→COAD')
        elif sid in sample_ids_B:
            group_labels.append('Correct_COAD')
        elif sid in sample_ids_C:
            group_labels.append('Correct_RECAD')
        else:
            group_labels.append('Other') # Should not happen if all_selected_samples is correct

    # Perform UMAP
    reducer = umap.UMAP(n_neighbors=15, min_dist=0.3, metric='correlation', random_state=42)
    umap_coords = reducer.fit_transform(subset_expr)

    # Create DataFrame for UMAP plotting
    umap_df = pd.DataFrame(umap_coords, columns=['UMAP1', 'UMAP2'], index=subset_expr.index)
    umap_df['Group'] = group_labels

    plt.figure(figsize=(12, 8)) # Adjusted figure size
    sns.scatterplot(
        data=umap_df,
        x='UMAP1', y='UMAP2',
        hue='Group',
        style='Group', # Use style to differentiate groups visually
        palette='Set1', # A distinct palette
        s=80, # Point size
        edgecolor='k', # Black edge for points
        alpha=0.8
    )
    plt.title("UMAP: Misclassified Rectum Adenocarcinoma vs. Colon Adenocarcinoma", fontsize=18, weight='bold', color='darkblue')
    plt.xlabel("UMAP1", fontsize=14, color='dimgray')
    plt.ylabel("UMAP2", fontsize=14, color='dimgray')
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.legend(title="Sample Group", bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10, title_fontsize=12)
    plt.grid(True, linestyle=':', alpha=0.5, color='lightgray')
    plt.tight_layout(rect=[0, 0, 0.85, 1])

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"UMAP plot for misclassification groups saved to: {output_path}")
    else:
        plt.show()
    plt.close()


# ==============================================================================
# Main Execution Block for Misclassification Analysis & Preranked GSEA
# ==============================================================================
if __name__ == "__main__":
    print("Starting Misclassification Analysis & Preranked GSEA Phase...")

    # --- Configuration ---
    processed_data_dir = "processed_data"
    ml_results_dir = "ml_results"
    gsea_preranked_dir = "gsea_preranked_results" # Directory for preranked GSEA outputs
    os.makedirs(gsea_preranked_dir, exist_ok=True)

    processed_expr_file = os.path.join(processed_data_dir, "expr_processed.tsv")
    processed_pheno_file = os.path.join(processed_data_dir, "pheno_processed.tsv")
    trained_model_path = os.path.join(ml_results_dir, "random_forest_model.joblib")
    target_column = '_primary_disease'

    try:
        # Load data and model, and prepare X_test, y_test, le, expr_T for analysis
        # expr_T here is the full processed expression data (samples x genes)
        expr_T, pheno_full, rf_model, X_test_scaled, y_test_encoded, le, X_full_original_genes = \
            load_data_and_model(processed_expr_file, processed_pheno_file, trained_model_path, target_column)

        if expr_T is None or rf_model is None:
            print("Required data or model could not be loaded. Exiting Misclassification Analysis.")
            exit()

        # === Step 1: Analyze Misclassifications ===
        # We need to ensure results_df gets the correct SampleIDs.
        # The y_test_encoded (numpy array) doesn't have original sample IDs.
        # We need to re-create the train_test_split on the original DataFrame with index.
        
        # Prepare full dataset with original indices
        X_for_split = expr_T.fillna(0) # Use expr_T directly as it has sample IDs as index
        y_for_split = pheno_full.loc[X_for_split.index, target_column]
        y_for_split_encoded = le.transform(y_for_split) # Use the same LabelEncoder

        # Perform the split again to get the indices of test samples
        X_train_df, X_test_df, y_train_series, y_test_series = train_test_split(
            X_for_split, y_for_split_encoded, test_size=0.2, stratify=y_for_split_encoded, random_state=42
        )
        
        # Now X_test_df has the correct sample IDs as its index
        y_pred_encoded = rf_model.predict(X_test_scaled) # Use the scaled X_test for prediction

        results_df = pd.DataFrame({
            'SampleID': X_test_df.index, # Correctly assign SampleIDs
            'TrueLabel': le.inverse_transform(y_test_series),
            'PredictedLabel': le.inverse_transform(y_pred_encoded)
        })

        misclassified_df = results_df[results_df['TrueLabel'] != results_df['PredictedLabel']]

        print(f"🔎 Number of misclassified samples: {misclassified_df.shape[0]}")
        print("First 5 misclassified samples:")
        print(misclassified_df.head())

        confusion_pairs = (
            misclassified_df.groupby(['TrueLabel', 'PredictedLabel'])
            .size()
            .reset_index(name='Count')
            .sort_values(by='Count', ascending=False)
        )
        print("\n🔥 Top 10 Misclassifications:\n")
        print(confusion_pairs.head(10))

        # NEW PLOT: Top Misclassification Pairs
        plot_top_misclassification_pairs(confusion_pairs, top_n=10,
                                         output_path=os.path.join(gsea_preranked_dir, "top_misclassification_pairs.png"))


        # === Step 2: Define Specific Sample Groups (Example: Rectum Adenocarcinoma misclassified as Colon Adenocarcinoma) ===
        target_misclass_true = 'rectum adenocarcinoma'
        target_misclass_pred = 'colon adenocarcinoma'

        group_A_samples, group_B_samples, group_C_samples = define_sample_groups(
            results_df, misclassified_df, target_misclass_true, target_misclass_pred
        )

        if not group_A_samples or not group_B_samples:
            print(f"Insufficient samples in target misclassified or correctly classified groups for DE/GSEA. Skipping.")
        else:
            # === Step 3: Perform Differential Expression Analysis ===
            # Use X_full_original_genes for DE, as it's the unscaled version with original gene names
            de_results_df = perform_differential_expression(X_full_original_genes, group_A_samples, group_B_samples)

            if not de_results_df.empty:
                # === Step 4: Run Preranked GSEA ===
                gsea_preranked_results = run_gsea_preranked(
                    de_results_df[['Gene', 'log2FC']],
                    gene_set_library='MSigDB_Hallmark_2020',
                    output_dir=gsea_preranked_dir
                )
                # NEW PLOT: Preranked GSEA Top Enriched Pathways
                if not gsea_preranked_results.empty:
                    plot_preranked_gsea_results(gsea_preranked_results, top_n=10,
                                                output_path=os.path.join(gsea_preranked_dir, "preranked_gsea_top_pathways.png"))
                else:
                    print("Preranked GSEA results are empty. Skipping preranked GSEA plot.")
            else:
                print("Differential expression results are empty. Skipping Preranked GSEA and its plot.")

            # === Step 5: Plot UMAP for Misclassification Groups ===
            plot_umap_misclassification_groups(
                X_full_original_genes, # Use original gene expression for UMAP
                group_A_samples,
                group_B_samples,
                group_C_samples,
                output_path=os.path.join(gsea_preranked_dir, "umap_misclassified_groups.png")
            )

        print("\nMisclassification Analysis & Preranked GSEA Phase complete.")

    except ImportError:
        print("\nRequired libraries (gseapy or umap-learn) not found. Please install them.")
        print("Skipping Misclassification Analysis & Preranked GSEA.")
    except Exception as e:
        print(f"Error during Misclassification Analysis & Preranked GSEA: {e}")
        print("Please ensure processed data, trained model files, and paths are correct.")

    print("\nMisclassification Analysis & Preranked GSEA Phase complete.")


# ==============================================================================
# Phase 1.5: Multi-Omics Integration & Analysis
# This script integrates gene expression and mutation data, performs
# comprehensive EDA, dimensionality reduction (PCA, UMAP), and trains
# a machine learning model for cancer type classification using multi-omics data.
#
# Before running:
# 1. Ensure you have necessary libraries installed:
#    pip install pandas numpy matplotlib seaborn scikit-learn umap-learn mygene
# 2. Ensure you have the following data files:
#    - EB++AdjustPANCAN_IlluminaHiSeq_RNASeqV2.geneExp.xena (Expression)
#    - mc3.v0.2.8.PUBLIC.maf (Mutation)
#    - TCGA_phenotype_denseDataOnlyDownload.tsv (Phenotype)
#    - mart_export.txt (BioMart mapping for gene IDs - if needed for fallback)
# =============================================================================

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import mygene # For robust gene ID mapping
import umap.umap_ as umap # For UMAP visualization
import time # Import the time module for time.sleep()

from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

# Set a consistent style for all plots for a professional look
plt.style.use('seaborn-v0_8-darkgrid')

# --- Configuration ---
# FIX: Updated data_dir to the absolute path provided by the user.
data_dir = r"C:\Users\shrav\Desktop\PYTHON\Cancer\Pan Cancer Analysis"
output_dir = "multi_omics_results"
os.makedirs(output_dir, exist_ok=True)

expr_path = os.path.join(data_dir, "EB++AdjustPANCAN_IlluminaHiSeq_RNASeqV2.geneExp.xena")
maf_path = os.path.join(data_dir, "mc3.v0.2.8.PUBLIC.maf")
pheno_path = os.path.join(data_dir, "TCGA_phenotype_denseDataOnlyDownload.tsv")
mart_path = os.path.join(data_dir, "mart_export.txt") # Fallback for gene mapping

target_column = '_primary_disease' # Column in phenotype for cancer type

# ==============================================================================
# 1. Data Loading and Preprocessing (Multi-Omics)
# ==============================================================================

def load_and_preprocess_multi_omics_data(expr_file, maf_file, pheno_file, mart_file, target_col):
    """
    Loads and preprocesses expression, mutation, and phenotype data,
    aligning samples and genes, and preparing for multi-omics analysis.
    """
    print("\n--- Multi-Omics Data Loading and Preprocessing ---")

    # --- Load Data ---
    try:
        expr_df = pd.read_csv(expr_file, sep="\t", index_col=0)
        maf_df = pd.read_csv(maf_file, sep="\t", comment='#', low_memory=False)
        pheno_df = pd.read_csv(pheno_file, sep="\t", low_memory=False)
        print("Raw expression, mutation (MAF), and phenotype data loaded.")
    except FileNotFoundError as e:
        print(f"Error loading data: {e}. Please ensure files are in the '{data_dir}' directory.")
        return None, None, None, None

    # --- Debug: Print Phenotype Columns ---
    print("\n--- Debug: Phenotype DataFrame Columns ---")
    print(pheno_df.columns.tolist())
    print("------------------------------------------")

    # --- Preprocess Expression Data ---
    expr_T = expr_df.T
    print(f"Initial Expression data shape (samples x genes): {expr_T.shape}")
    print(f"First 5 expression gene IDs (before mapping): {expr_T.columns[:5].tolist()}")

    # --- Preprocess Mutation Data ---
    if 'FILTER' in maf_df.columns:
        maf_df = maf_df[maf_df['FILTER'] == 'PASS']
        print(f"Filtered MAF to 'PASS' mutations. New MAF shape: {maf_df.shape}")
    
    maf_df = maf_df[["Tumor_Sample_Barcode", "Hugo_Symbol"]].dropna().drop_duplicates()
    print(f"Mutation records after dropping NaNs and duplicates: {maf_df.shape}")

    # --- Gene ID Mapping Helper Functions ---
    mg = mygene.MyGeneInfo()
    
    def batch_query_mygene_robust(ids, scopes, fields='ensembl.gene', species='human', batch_size=1000):
        mapping = {}
        successful_queries = 0
        not_found_queries_list = []
        for i in range(0, len(ids), batch_size):
            batch = ids[i:i + batch_size]
            try:
                res = mg.querymany(batch, scopes=scopes, fields=fields, species=species, returnall=True)
                
                for nf_item in res.get('notfound', []):
                    not_found_queries_list.append(nf_item['query'])

                for r in res['out']:
                    query_id = str(r['query'])
                    ensembl_info = r.get(fields)
                    if ensembl_info:
                        ensembl_id = None
                        if isinstance(ensembl_info, list):
                            for item in ensembl_info:
                                if isinstance(item, dict) and 'gene' in item:
                                    ensembl_id = item['gene']
                                    break
                        elif isinstance(ensembl_info, dict) and 'gene' in ensembl_info:
                            ensembl_id = ensembl_info['gene']
                        
                        if ensembl_id:
                            mapping[query_id] = ensembl_id
                            successful_queries += 1
            except Exception as e:
                print(f"❌ Error in mygene batch query for scopes '{scopes}': {e}")
            time.sleep(0.1)
        print(f"Successfully mapped {successful_queries} out of {len(ids)} terms for scopes '{scopes}'.")
        if not_found_queries_list:
            print(f"First 10 unmapped queries for scopes '{scopes}': {not_found_queries_list[:10]}")
        return mapping

    def load_mart_mapping(mart_file_path):
        hugo_to_ensembl = {}
        entrez_to_ensembl = {}
        try:
            mart_df = pd.read_csv(mart_file_path, sep="\t", low_memory=False)
            print(f"BioMart mapping file '{os.path.basename(mart_file_path)}' loaded.")
            print(f"BioMart columns: {mart_df.columns.tolist()}")

            # For Hugo Symbol to Ensembl (Gene name to Gene stable ID)
            if 'Gene name' in mart_df.columns and 'Gene stable ID' in mart_df.columns:
                temp_map_df = mart_df[['Gene name', 'Gene stable ID']].dropna().drop_duplicates()
                hugo_to_ensembl = dict(zip(temp_map_df['Gene name'], temp_map_df['Gene stable ID']))
                print(f"Found {len(hugo_to_ensembl)} Hugo Symbol to Ensembl mappings in BioMart file.")
            
            # For Entrez Gene ID to Ensembl (NCBI gene ID to Gene stable ID)
            if 'NCBI gene ID' in mart_df.columns and 'Gene stable ID' in mart_df.columns:
                temp_map_df = mart_df[['NCBI gene ID', 'Gene stable ID']].dropna().drop_duplicates()
                entrez_to_ensembl = dict(zip(temp_map_df['NCBI gene ID'].astype(str), temp_map_df['Gene stable ID']))
                print(f"Found {len(entrez_to_ensembl)} Entrez ID to Ensembl mappings in BioMart file.")

        except FileNotFoundError:
            print(f"Warning: BioMart mapping file '{os.path.basename(mart_file_path)}' not found. Skipping BioMart mapping.")
        except Exception as e:
            print(f"Error loading or processing BioMart mapping file: {e}")
        return hugo_to_ensembl, entrez_to_ensembl

    # Load BioMart mappings once at the beginning
    hugo_to_ensembl_mart, entrez_to_ensembl_mart = load_mart_mapping(mart_path)

    # --- Expression Gene Mapping ---
    expr_T_ensembl = pd.DataFrame()
    
    # Check if expression gene IDs are already Ensembl (e.g., ENSG00000123456.7)
    is_expr_ensembl = False
    if len(expr_T.columns) > 0 and isinstance(expr_T.columns[0], str):
        if all(col.startswith('ENSG') and '.' in col for col in expr_T.columns[:min(5, len(expr_T.columns))]):
            is_expr_ensembl = True
            expr_T_ensembl = expr_T.copy()
            expr_T_ensembl.columns = expr_T_ensembl.columns.str.split('.').str[0]
            print("Expression gene IDs appear to be Ensembl. Skipping mygene mapping for expression and removing version numbers.")
    
    if not is_expr_ensembl:
        all_expr_genes = list(expr_T.columns.astype(str))
        expr_gene_map = {}

        # FIX: Prioritize BioMart mapping for expression if available
        if entrez_to_ensembl_mart:
            print("Attempting BioMart mapping for expression data (Entrez to Ensembl)...")
            expr_gene_map = {k: entrez_to_ensembl_mart.get(k, None) for k in all_expr_genes}
            expr_gene_map = {k: v for k, v in expr_gene_map.items() if v is not None} # Filter out unmapped
            print(f"BioMart mapped {len(expr_gene_map)} expression genes.")

        if not expr_gene_map: # If BioMart mapping failed or wasn't available, try mygene
            print("BioMart mapping for expression failed or not available. Attempting mygene mapping (Entrez/Symbol to Ensembl)...")
            expr_gene_map = batch_query_mygene_robust(all_expr_genes, scopes=['entrezgene', 'symbol'])
        
        if expr_gene_map:
            # Create a new DataFrame with mapped columns
            mapped_cols_data = {}
            for original_col, mapped_col in expr_gene_map.items():
                if original_col in expr_T.columns:
                    mapped_cols_data[mapped_col] = expr_T[original_col]
            
            expr_T_ensembl = pd.DataFrame(mapped_cols_data, index=expr_T.index)
            expr_T_ensembl = expr_T_ensembl.loc[:, ~expr_T_ensembl.columns.duplicated()] # Remove duplicate Ensembl IDs
            print(f"Expression data after mapping to Ensembl: {expr_T_ensembl.shape}")
        else:
            expr_T_ensembl = expr_T.copy() # Keep original columns if all mapping attempts fail
            print("Warning: Expression gene ID mapping failed completely. Proceeding with original expression gene IDs. This will prevent gene-level alignment with mutation data.")
            print(f"Expression data shape (original IDs): {expr_T_ensembl.shape}")

    # --- Mutation Gene Mapping ---
    maf_df_mapped = maf_df.copy()
    unique_hugo_symbols = maf_df_mapped['Hugo_Symbol'].unique().tolist()
    mutation_gene_map = {}

    # FIX: Prioritize BioMart mapping for mutation if available
    if hugo_to_ensembl_mart:
        print("Attempting BioMart mapping for mutation data (Hugo Symbol to Ensembl)...")
        mutation_gene_map = {k: hugo_to_ensembl_mart.get(k, None) for k in unique_hugo_symbols}
        mutation_gene_map = {k: v for k, v in mutation_gene_map.items() if v is not None} # Filter out unmapped
        print(f"BioMart mapped {len(mutation_gene_map)} mutation genes.")

    if not mutation_gene_map: # If BioMart mapping failed or wasn't available, try mygene
        print("BioMart mapping for mutation failed or not available. Attempting mygene mapping (Hugo Symbol to Ensembl)...")
        mutation_gene_map = batch_query_mygene_robust(unique_hugo_symbols, scopes=['symbol'])

    if mutation_gene_map:
        maf_df_mapped['Ensembl_ID'] = maf_df_mapped['Hugo_Symbol'].map(mutation_gene_map)
        maf_df_mapped = maf_df_mapped.dropna(subset=['Ensembl_ID'])
        print(f"Mutation records after mapping Hugo_Symbol to Ensembl: {maf_df_mapped.shape}")
    else:
        maf_df_mapped['Ensembl_ID'] = maf_df_mapped['Hugo_Symbol'] # Keep original Hugo Symbols as 'Ensembl_ID'
        print("Warning: Mutation gene ID mapping failed completely. Proceeding with original Hugo Symbols as gene IDs.")
        print(f"Mutation records shape (original Hugo Symbols): {maf_df_mapped.shape}")

    # Create binary mutation matrix (samples x Ensembl_IDs or Hugo_Symbols)
    if not maf_df_mapped.empty:
        # Ensure 'Ensembl_ID' column is unique before pivoting
        maf_df_mapped = maf_df_mapped.drop_duplicates(subset=['Tumor_Sample_Barcode', 'Ensembl_ID'])
        mutation_matrix = pd.crosstab(maf_df_mapped['Tumor_Sample_Barcode'], maf_df_mapped['Ensembl_ID'])
        mutation_matrix = mutation_matrix.clip(upper=1)
        print(f"Binary mutation matrix shape: {mutation_matrix.shape}")
    else:
        mutation_matrix = pd.DataFrame()
        print("Empty mutation matrix created due to no successful gene mapping or no mutation records.")


    # --- Standardize Sample IDs ---
    pheno_sample_id_col = 'sample' # Confirmed from previous debug output
    
    if pheno_sample_id_col not in pheno_df.columns:
        print(f"Error: The specified sample ID column '{pheno_sample_id_col}' not found in phenotype data.")
        return None, None, None, None
    
    print(f"Using '{pheno_sample_id_col}' as the sample ID column in phenotype data.")
    pheno_df_cleaned = pheno_df.copy()
    pheno_df_cleaned[pheno_sample_id_col] = pheno_df_cleaned[pheno_sample_id_col].astype(str).str.slice(0, 15).str.upper()
    
    # Remove duplicate sample IDs by taking the first occurrence for phenotype
    pheno_df_cleaned = pheno_df_cleaned.loc[~pheno_df_cleaned[pheno_sample_id_col].duplicated(keep='first')]
    pheno_df_cleaned = pheno_df_cleaned.set_index(pheno_sample_id_col)

    expr_T_ensembl.index = expr_T_ensembl.index.astype(str).str.slice(0, 15).str.upper()
    expr_T_ensembl = expr_T_ensembl.groupby(expr_T_ensembl.index).mean() # Aggregate duplicate sample IDs

    if not mutation_matrix.empty:
        mutation_matrix.index = mutation_matrix.index.astype(str).str.slice(0, 15).str.upper()
        mutation_matrix = mutation_matrix.loc[~mutation_matrix.index.duplicated(keep='first')] # Aggregate duplicate sample IDs

    # --- Intersect Samples Across All Omics ---
    common_samples = expr_T_ensembl.index.intersection(pheno_df_cleaned.index)
    if not mutation_matrix.empty:
        common_samples = common_samples.intersection(mutation_matrix.index)
    
    print(f"Common samples across all omics datasets: {len(common_samples)}")

    if len(common_samples) == 0:
        print("No common samples found across all datasets. Cannot proceed with multi-omics analysis.")
        return None, None, None, None

    # Subset all dataframes to common samples
    expr_aligned = expr_T_ensembl.loc[common_samples]
    mut_aligned = mutation_matrix.loc[common_samples] if not mutation_matrix.empty else pd.DataFrame(0, index=common_samples, columns=[])
    pheno_aligned = pheno_df_cleaned.loc[common_samples]

    # --- Align Genes (Columns) for Multi-Omics ---
    # Determine if both expression and mutation data have Ensembl IDs for gene-level alignment
    expr_cols_are_ensembl = not expr_aligned.empty and expr_aligned.shape[1] > 0 and \
                            all(isinstance(col, str) and col.startswith('ENSG') for col in expr_aligned.columns[:min(5, expr_aligned.shape[1])])
    mut_cols_are_ensembl = not mut_aligned.empty and mut_aligned.shape[1] > 0 and \
                           all(isinstance(col, str) and col.startswith('ENSG') for col in mut_aligned.columns[:min(5, mut_aligned.shape[1])])
    
    if expr_cols_are_ensembl and mut_cols_are_ensembl:
        print("Both expression and mutation genes are Ensembl IDs. Attempting gene-level alignment.")
        all_genes_union = expr_aligned.columns.union(mut_aligned.columns)
        expr_aligned = expr_aligned.reindex(columns=all_genes_union, fill_value=0)
        mut_aligned = mut_aligned.reindex(columns=all_genes_union, fill_value=0)
    else:
        print("Gene-level alignment skipped (either not all Ensembl IDs or one/both omics data are empty/unmapped). Features will be concatenated as-is.")
        # If not aligned by gene, ensure columns are unique before concatenation
        # This step is now handled during concatenation by adding suffixes
            
    print(f"Aligned expression data shape: {expr_aligned.shape}")
    print(f"Aligned mutation data shape: {mut_aligned.shape}")

    # --- Impute Missing Values (Expression) ---
    if expr_aligned.shape[1] > 0:
        imputer = SimpleImputer(strategy='median')
        expr_imputed = pd.DataFrame(imputer.fit_transform(expr_aligned),
                                    columns=expr_aligned.columns,
                                    index=expr_aligned.index)
        print(f"Expression data imputed. Missing values: {expr_imputed.isnull().sum().sum()}")
    else:
        expr_imputed = expr_aligned.copy()
        print("No expression columns to impute. Skipping imputation.")

    # --- Scale Data ---
    expr_scaled = pd.DataFrame()
    if expr_imputed.shape[1] > 0:
        scaler_expr = StandardScaler()
        expr_scaled = pd.DataFrame(scaler_expr.fit_transform(expr_imputed),
                                   columns=expr_imputed.columns,
                                   index=expr_imputed.index)
    else:
        print("No expression columns to scale. Skipping scaling.")

    mut_scaled = pd.DataFrame()
    if mut_aligned.shape[1] > 0:
        scaler_mut = StandardScaler()
        mut_scaled = pd.DataFrame(scaler_mut.fit_transform(mut_aligned),
                                  columns=mut_aligned.columns,
                                  index=mut_aligned.index)
    else:
        print("No mutation columns to scale. Skipping scaling.")
    
    print("Expression and mutation data scaled (if columns available).")

    # --- Concatenate Multi-Omics Data ---
    # Only rename columns if the DataFrame is not empty AND gene-level alignment was skipped
    # If gene-level alignment happened, columns are already aligned and unique
    if not (expr_cols_are_ensembl and mut_cols_are_ensembl): # Only add suffix if not aligned by gene
        if not expr_scaled.empty:
            expr_scaled.columns = [f"{col}_expr" for col in expr_scaled.columns]
        if not mut_scaled.empty:
            mut_scaled.columns = [f"{col}_mut" for col in mut_scaled.columns]

    if not expr_scaled.empty and not mut_scaled.empty:
        X_multiomics = pd.concat([expr_scaled, mut_scaled], axis=1)
    elif not expr_scaled.empty:
        X_multiomics = expr_scaled
    elif not mut_scaled.empty:
        X_multiomics = mut_scaled
    else:
        X_multiomics = pd.DataFrame(index=common_samples)
        print("Warning: Both expression and mutation data are empty after preprocessing. Multi-omics matrix is empty.")

    y_labels = pheno_aligned[target_col]

    print(f"Final Multi-omics feature matrix shape: {X_multiomics.shape}")
    print(f"Target labels shape: {y_labels.shape}")

    return X_multiomics, y_labels, pheno_aligned, expr_aligned

# ==============================================================================
# 2. Multi-Omics EDA & Visualization
# ==============================================================================

def perform_multi_omics_eda(X_multiomics, y_labels, pheno_aligned, expr_original_aligned):
    """
    Performs Exploratory Data Analysis and generates plots for multi-omics data.
    """
    print("\n--- Multi-Omics EDA and Visualization ---")

    # --- Basic Stats ---
    print(f"Combined Multi-omics data shape: {X_multiomics.shape}")
    print(f"Number of unique cancer types: {y_labels.nunique()}")
    print("Samples per cancer type:\n", y_labels.value_counts().head())

    # --- Plot: Cancer Type Distribution ---
    plt.figure(figsize=(14, 8))
    sns.countplot(y=y_labels, order=y_labels.value_counts().index, palette='viridis')
    plt.title("Sample Count per Cancer Type (Multi-Omics Cohort)", fontsize=18, weight='bold', color='darkblue')
    plt.xlabel("Number of Samples", fontsize=14, color='dimgray')
    plt.ylabel("Cancer Type", fontsize=14, color='dimgray')
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=10)
    plt.grid(axis='x', linestyle='--', alpha=0.6)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "multiomics_cancer_type_distribution.png"), dpi=300, bbox_inches='tight')
    plt.close()
    print("Cancer type distribution plot saved.")

    # --- Plot: Top Variable Expression Genes ---
    expr_original_numeric = expr_original_aligned.copy()
    
    if not expr_original_numeric.empty and expr_original_numeric.shape[1] > 0:
        gene_std = expr_original_numeric.std().sort_values(ascending=False)
        top_expr_genes = gene_std.head(50).index.tolist()

        if top_expr_genes:
            plt.figure(figsize=(15, 8))
            sns.boxplot(data=expr_original_numeric[top_expr_genes], palette='Set3')
            plt.xticks(rotation=90, fontsize=10)
            plt.yticks(fontsize=10)
            plt.title("Top 50 Most Variable Expression Genes", fontsize=18, weight='bold', color='darkblue')
            plt.xlabel("Gene", fontsize=14, color='dimgray')
            plt.ylabel("Expression Level", fontsize=14, color='dimgray')
            plt.grid(axis='y', linestyle='--', alpha=0.6)
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, "multiomics_top_variable_expr_genes.png"), dpi=300, bbox_inches='tight')
            plt.close()
            print("Top variable expression genes plot saved.")
        else:
            print("No top variable expression genes found for plotting.")
    else:
        print("Original expression data is empty or has no columns. Skipping top variable expression genes plot.")

    # --- Plot: Top Mutated Genes ---
    mut_data_cols = [col for col in X_multiomics.columns if col.endswith('_mut')]
    if mut_data_cols and len(mut_data_cols) > 0:
        mut_numeric = X_multiomics[mut_data_cols].copy()
        # Remove '_mut' suffix to get original gene names
        mut_numeric.columns = [col.replace('_mut', '') for col in mut_numeric.columns]

        top_mutated_genes_counts = mut_numeric.sum().sort_values(ascending=False).head(50)

        if not top_mutated_genes_counts.empty:
            # Only attempt mapping for plotting if genes are not already Ensembl
            # This logic needs to be careful if the columns are original Hugo Symbols
            # Let's assume for plotting we want Hugo Symbols if not mapped to Ensembl
            
            # Check if the columns are Ensembl IDs (from successful mapping)
            if all(isinstance(col, str) and col.startswith('ENSG') for col in top_mutated_genes_counts.index[:min(5, len(top_mutated_genes_counts))]):
                # If they are Ensembl, try to map them to Hugo Symbols for readability in plot
                mg = mygene.MyGeneInfo()
                ensembl_ids_to_map = top_mutated_genes_counts.index.tolist()
                gene_info = mg.querymany(ensembl_ids_to_map, scopes='ensembl.gene', fields='symbol', species='human', returnall=True)

                ensg_to_symbol = {}
                for entry in gene_info['out']:
                    if 'notfound' not in entry and 'symbol' in entry:
                        ensg_to_symbol[entry['query']] = entry['symbol']
                
                mapped_gene_names = [ensg_to_symbol.get(gene_id, gene_id) for gene_id in top_mutated_genes_counts.index]
                top_mutated_genes_counts.index = mapped_gene_names
                print("Mapped top mutated Ensembl IDs to Hugo Symbols for plotting.")
            else:
                print("Top mutated genes are not Ensembl IDs. Plotting with their current IDs.")
            
            plt.figure(figsize=(12, 8))
            sns.barplot(x=top_mutated_genes_counts.values, y=top_mutated_genes_counts.index, palette='Reds_d')
            plt.title("Top 50 Most Frequently Mutated Genes (by Sample Count)", fontsize=18, weight='bold', color='darkblue')
            plt.xlabel("Number of Mutated Samples", fontsize=14, color='dimgray')
            plt.ylabel("Gene", fontsize=14, color='dimgray')
            plt.xticks(fontsize=12)
            plt.yticks(fontsize=12)
            plt.grid(axis='x', linestyle='--', alpha=0.6)
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, "multiomics_top_mutated_genes_barplot.png"), dpi=300, bbox_inches='tight')
            plt.close()
            print("Top mutated genes bar plot saved.")
        else:
            print("No top mutated genes found for plotting.")
    else:
        print("No mutation data columns found. Skipping top mutated genes plot.")

    # --- NEW PLOT: Top Mutation Frequencies Across Cancer Types Heatmap ---
    # This plot requires 'mut_numeric' and 'y_labels' to be available and not empty.
    # It also requires a list of top mutated gene IDs (either Ensembl or Hugo for plotting).
    
    # Re-evaluate conditions for heatmap plotting
    if mut_data_cols and len(mut_data_cols) > 0 and not y_labels.empty:
        mut_numeric = X_multiomics[[col for col in X_multiomics.columns if col.endswith('_mut')]].copy()
        mut_numeric.columns = [col.replace('_mut', '') for col in mut_numeric.columns] # Remove suffix for gene names

        top_20_mutated_genes_ensembl_or_original = mut_numeric.sum().sort_values(ascending=False).head(20).index.tolist()

        if top_20_mutated_genes_ensembl_or_original:
            mut_df_for_heatmap = mut_numeric[top_20_mutated_genes_ensembl_or_original].copy()
            mut_df_for_heatmap['CancerType'] = y_labels

            mutation_frequency_per_cancer = mut_df_for_heatmap.groupby('CancerType')[top_20_mutated_genes_ensembl_or_original].sum()
            
            sample_counts = y_labels.value_counts()
            mutation_frequency_per_cancer = mutation_frequency_per_cancer.div(sample_counts, axis=0) * 100

            # Map gene names for heatmap labels if they are Ensembl IDs
            heatmap_gene_names = top_20_mutated_genes_ensembl_or_original
            if all(isinstance(col, str) and col.startswith('ENSG') for col in top_20_mutated_genes_ensembl_or_original[:min(5, len(top_20_mutated_genes_ensembl_or_original))]):
                mg = mygene.MyGeneInfo()
                gene_info_heatmap = mg.querymany(top_20_mutated_genes_ensembl_or_original, scopes='ensembl.gene', fields='symbol', species='human', returnall=True)
                ensg_to_symbol_heatmap = {}
                for entry in gene_info_heatmap['out']:
                    if 'notfound' not in entry and 'symbol' in entry:
                        ensg_to_symbol_heatmap[entry['query']] = entry['symbol']
                heatmap_gene_names = [ensg_to_symbol_heatmap.get(gene_id, gene_id) for gene_id in top_20_mutated_genes_ensembl_or_original]
                print("Mapped heatmap gene IDs to Hugo Symbols for readability.")
            else:
                print("Heatmap gene IDs are not Ensembl. Plotting with their current IDs.")

            mutation_frequency_per_cancer.columns = heatmap_gene_names

            plt.figure(figsize=(16, 12))
            sns.heatmap(mutation_frequency_per_cancer.T, cmap='YlGnBu', annot=True, fmt=".1f", linewidths=.5, linecolor='gray',
                        cbar_kws={'label': 'Mutation Frequency (%)'})
            plt.title("Top 20 Mutation Frequencies Across Cancer Types", fontsize=18, weight='bold', color='darkblue')
            plt.xlabel("Cancer Type", fontsize=14, color='dimgray')
            plt.ylabel("Gene", fontsize=14, color='dimgray')
            plt.xticks(rotation=45, ha='right', fontsize=10)
            plt.yticks(fontsize=10)
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, "multiomics_mutation_frequency_heatmap.png"), dpi=300, bbox_inches='tight')
            plt.close()
            print("Mutation frequency heatmap saved.")
        else:
            print("No top 20 mutated genes to plot in heatmap.")
    else:
        print("Insufficient mutation data or labels for mutation frequency heatmap.")


# ==============================================================================
# 3. Multi-Omics Machine Learning & Dimensionality Reduction
# ==============================================================================

def perform_multi_omics_ml(X_multiomics, y_labels):
    """
    Performs dimensionality reduction and trains a Random Forest classifier
    on multi-omics data.
    """
    print("\n--- Multi-Omics Machine Learning & Dimensionality Reduction ---")

    le = LabelEncoder()
    y_encoded = le.fit_transform(y_labels)
    class_names = le.classes_
    print(f"Encoded {len(class_names)} cancer types.")

    n_components_pca = min(50, X_multiomics.shape[1])
    if X_multiomics.shape[1] > 1 and n_components_pca > 1:
        pca = PCA(n_components=n_components_pca, random_state=42)
        X_pca = pca.fit_transform(X_multiomics)
        print(f"PCA reduced data shape: {X_pca.shape}")
        print(f"PCA explained variance ratio (first {n_components_pca} components): {np.sum(pca.explained_variance_ratio_):.2f}")
    else:
        print("Not enough features for PCA. Skipping PCA.")
        X_pca = X_multiomics.values

    if X_pca.shape[1] >= 2:
        reducer = umap.UMAP(n_components=2, random_state=42, metric='euclidean')
        X_umap = reducer.fit_transform(X_pca)
        print(f"UMAP reduced data shape: {X_umap.shape}")

        plt.figure(figsize=(12, 10))
        sns.scatterplot(x=X_umap[:, 0], y=X_umap[:, 1], hue=y_labels, palette='tab20',
                        s=80, alpha=0.85, edgecolor='black', linewidth=0.7)
        plt.title("UMAP Projection of Multi-Omics Data (Expression + Mutation)", fontsize=18, weight='bold', color='darkblue')
        plt.xlabel("UMAP Component 1", fontsize=14, color='dimgray')
        plt.ylabel("UMAP Component 2", fontsize=14, color='dimgray')
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        plt.legend(title="Cancer Type", bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0, fontsize='small', title_fontsize=12)
        plt.grid(True, linestyle=':', alpha=0.5)
        plt.tight_layout(rect=[0, 0, 0.85, 1])
        plt.savefig(os.path.join(output_dir, "multiomics_umap_projection.png"), dpi=300, bbox_inches='tight')
        plt.close()
        print("UMAP projection plot saved.")
    else:
        print("Not enough dimensions after PCA for UMAP. Skipping UMAP plot.")

    if X_multiomics.shape[0] > 1 and X_multiomics.shape[1] > 0 and len(np.unique(y_encoded)) > 1:
        X_train, X_test, y_train, y_test = train_test_split(
            X_multiomics, y_encoded, test_size=0.2, stratify=y_encoded, random_state=42
        )
        print(f"Training data shape: {X_train.shape}, Test data shape: {X_test.shape}")

        clf = RandomForestClassifier(n_estimators=200, random_state=42, n_jobs=-1, verbose=1)
        print("Training Random Forest Classifier...")
        clf.fit(X_train, y_train)
        print("Random Forest Classifier training complete.")

        y_pred = clf.predict(X_test)

        print("\n📊 Classification Report (Multi-Omics Random Forest):")
        print(classification_report(y_test, y_pred, target_names=class_names, zero_division=0))

        cm = confusion_matrix(y_test, y_pred, normalize='true', labels=range(len(class_names)))
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)

        fig, ax = plt.subplots(figsize=(16, 16))
        disp.plot(cmap='Blues', ax=ax, colorbar=False, xticks_rotation='vertical')
        plt.title("Normalized Confusion Matrix (Multi-Omics Random Forest)", fontsize=18, weight='bold', color='darkblue')
        plt.xlabel("Predicted Label", fontsize=14, color='dimgray')
        plt.ylabel("True Label", fontsize=14, color='dimgray')
        plt.xticks(fontsize=10)
        plt.yticks(fontsize=10)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "multiomics_confusion_matrix.png"), dpi=300, bbox_inches='tight')
        plt.close()
        print("Normalized Confusion Matrix plot saved.")
    else:
        print("Insufficient data or classes for Machine Learning. Skipping model training and evaluation.")

    print("\nMulti-Omics Machine Learning & Dimensionality Reduction complete.")


# ==============================================================================
# Main Execution Block for Multi-Omics Analysis
# ==============================================================================
if __name__ == "__main__":
    print("Starting Multi-Omics Integration & Analysis Phase...")

    try:
        X_multiomics_features, y_multiomics_labels, pheno_aligned_df, expr_original_aligned_df = \
            load_and_preprocess_multi_omics_data(expr_path, maf_path, pheno_path, mart_path, target_column)

        if X_multiomics_features is None or X_multiomics_features.empty or y_multiomics_labels.empty:
            print("Multi-omics data preparation failed or resulted in empty data. Exiting.")
        else:
            perform_multi_omics_eda(X_multiomics_features, y_multiomics_labels, pheno_aligned_df, expr_original_aligned_df)
            perform_multi_omics_ml(X_multiomics_features, y_multiomics_labels)

        print("\nMulti-Omics Integration & Analysis Phase complete.")

    except ImportError:
        print("\nOne or more required libraries (mygene, umap-learn) not found. Please install them.")
        print("Skipping Multi-Omics Integration & Analysis.")
    except Exception as e:
        print(f"An unexpected error occurred during Multi-Omics Analysis: {e}")
        import traceback
        traceback.print_exc()


# ==============================================================================
# Phase 1.6: Advanced Multi-Omics Integration & Machine Learning
# This script integrates gene expression, mutation, copy number variation (CNV),
# and miRNA expression data. It performs comprehensive EDA, dimensionality
# reduction (PCA, UMAP), and trains multiple machine learning models for
# cancer type classification using the integrated multi-omics data.
#
# Before running:
# 1. Ensure you have necessary libraries installed:
#    pip install pandas numpy matplotlib seaborn scikit-learn umap-learn mygene lightgbm
# 2. Ensure you have the following data files in the specified 'data_dir':
#    - EB++AdjustPANCAN_IlluminaHiSeq_RNASeqV2.geneExp.xena (Gene Expression)
#    - mc3.v0.2.8.PUBLIC.maf (Somatic Mutation)
#    - TCGA_phenotype_denseDataOnlyDownload.tsv (Phenotype)
#    - mart_export.txt (BioMart gene mapping - if available)
#    - pancanMiRs_EBadjOnProtocolPlatformWithoutRepsWithUnCorrectMiRs_08_04_16.xena (miRNA Expression)
#    - TCGA.PANCAN.sampleMap_Gistic2_CopyNumber_Gistic2_all_data_by_genes (Copy Number Variation)
#
# NOTE: Methylation data (e.g., beta values) is NOT included as only a probe map
#       was provided. If you have the actual methylation data, please specify
#       its filename for inclusion in a future update.
# ==============================================================================

In [3]:
# ==============================================================================
# Multi-Omics Data Analysis Workflow for Jupyter Notebook (EDA Focus)
# This script integrates Data Preparation and comprehensive Exploratory Data Analysis (EDA)
# and Visualization. All machine learning, GSEA, and misclassification analysis
# components have been removed to focus purely on data understanding.
#
# Before running:
# 1. Ensure you have all necessary libraries installed:
#    pip install pandas numpy matplotlib seaborn scikit-learn umap-learn mygene scipy
# 2. Update the 'raw_expr_file' and 'raw_pheno_file' paths in the 'Data Preparation' section
#    to point to your actual raw data files.
# ==============================================================================

# --- General Imports (common to all modules) ---
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import mygene # For robust gene ID mapping
import time # Import the time module for time.sleep()
from scipy import stats # For t-test in Volcano Plot (if used in EDA)

# --- Specific Imports for EDA & Visualization ---
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import PCA
import umap.umap_ as umap


# Set a consistent style for all plots for a professional look
plt.style.use('seaborn-v0_8-darkgrid')


# ==============================================================================
# 1. Data Preparation Functions
# ==============================================================================

def load_and_preprocess_multi_omics_data_advanced(expr_file, maf_file, pheno_file, mart_file, mirna_file, cnv_file, rppa_file, meth_file, target_col):
    """
    Loads and preprocesses expression, mutation, miRNA, CNV, RPPA, Methylation,
    and phenotype data, aligning samples and genes/features, and preparing for
    multi-omics analysis.
    Includes log transformation for expression/miRNA and low-variance feature filtering.
    """
    print("\n--- Multi-Omics Data Loading and Preprocessing (Advanced) ---")

    # --- Load Data ---
    try:
        expr_df = pd.read_csv(expr_file, sep="\t", index_col=0, on_bad_lines='skip')
        maf_df = pd.read_csv(maf_file, sep="\t", comment='#', low_memory=False, on_bad_lines='skip')
        pheno_df = pd.read_csv(pheno_file, sep="\t", low_memory=False, on_bad_lines='skip')
        mirna_df = pd.read_csv(mirna_file, sep="\t", index_col=0, on_bad_lines='skip')
        cnv_df = pd.read_csv(cnv_file, sep="\t", index_col=0, on_bad_lines='skip')
        rppa_df = pd.read_csv(rppa_file, sep="\t", index_col=0, on_bad_lines='skip')
        
        print(f"Attempting to load methylation data in chunks from: {meth_file}")
        CHUNK_SIZE = 50000
        TOP_METHYLATION_PROBES_TO_KEEP = 1000 

        all_probes_std = pd.Series(dtype=float)
        header_df = pd.read_csv(meth_file, sep="\t", nrows=0, on_bad_lines='skip')
        index_col_name = header_df.columns[0]
        data_columns_in_file = header_df.columns[1:].tolist()

        for i, chunk in enumerate(pd.read_csv(meth_file, sep="\t", index_col=0, chunksize=CHUNK_SIZE, on_bad_lines='skip')):
            chunk_numeric = chunk.select_dtypes(include=np.number)
            if not chunk_numeric.empty:
                current_chunk_std = chunk_numeric.std()
                all_probes_std = all_probes_std.add(current_chunk_std, fill_value=0)
            print(f"  Processed chunk {i+1}, current chunk shape: {chunk.shape}")

        if not all_probes_std.empty:
            all_probes_std = all_probes_std[all_probes_std.index.isin(data_columns_in_file)]
            top_variable_probes = all_probes_std.nlargest(TOP_METHYLATION_PROBES_TO_KEEP).index.tolist()
            print(f"Identified top {len(top_variable_probes)} most variable methylation probes from {len(all_probes_std)} total probes.")
        else:
            top_variable_probes = []
            print("No numeric methylation probes found to determine variability or no probes in data_columns_in_file.")

        if top_variable_probes:
            columns_to_load_by_name = [index_col_name] + top_variable_probes
            meth_df = pd.read_csv(meth_file, sep="\t", index_col=0, usecols=columns_to_load_by_name, on_bad_lines='skip')
            print(f"Methylation data re-loaded with only top {len(top_variable_probes)} variable probes. Final shape: {meth_df.shape}")
        else:
            meth_df = pd.DataFrame()
            print("No top variable methylation probes identified or loaded. Methylation DataFrame is empty.")

        print("Raw expression, mutation (MAF), phenotype, miRNA, CNV, RPPA, and Methylation data loaded.")
    except FileNotFoundError as e:
        print(f"Error loading data: {e}. Please ensure all files are in the '{data_dir}' directory.")
        return None, None, None, None, None, None, None, None

    print("\n--- Debug: Phenotype DataFrame Columns ---")
    print(pheno_df.columns.tolist())
    print("------------------------------------------")

    # --- Gene/miRNA ID Mapping Helper Functions ---
    mg = mygene.MyGeneInfo()
    
    def batch_query_mygene_robust(ids, scopes, fields='ensembl.gene', species='human', batch_size=1000):
        mapping = {}
        successful_queries = 0
        not_found_queries_list = []
        for i in range(0, len(ids), batch_size):
            batch = ids[i:i + batch_size]
            try:
                res = mg.querymany(batch, scopes=scopes, fields=fields, species=species, returnall=True)
                
                for nf_item in res.get('notfound', []):
                    not_found_queries_list.append(nf_item['query'])

                for r in res['out']:
                    query_id = str(r['query'])
                    ensembl_info = r.get(fields)
                    if ensembl_info:
                        ensembl_id = None
                        if isinstance(ensembl_info, list):
                            for item in ensembl_info:
                                if isinstance(item, dict) and 'gene' in item:
                                    ensembl_id = item['gene']
                                    break
                        elif isinstance(ensembl_info, dict) and 'gene' in ensembl_info:
                            ensembl_id = ensembl_info['gene']
                        
                        if ensembl_id:
                            mapping[query_id] = ensembl_id
                            successful_queries += 1
            except Exception as e:
                print(f"❌ Error in mygene batch query for scopes '{scopes}': {e}")
            time.sleep(0.1)
        print(f"Successfully mapped {successful_queries} out of {len(ids)} terms for scopes '{scopes}'.")
        if not_found_queries_list:
            print(f"First 10 unmapped queries for scopes '{scopes}': {not_found_queries_list[:10]}")
        return mapping

    def load_mart_mapping(mart_file_path):
        hugo_to_ensembl = {}
        entrez_to_ensembl = {}
        try:
            mart_df = pd.read_csv(mart_file_path, sep="\t", low_memory=False, on_bad_lines='skip')
            print(f"BioMart mapping file '{os.path.basename(mart_file_path)}' loaded.")
            print(f"BioMart columns: {mart_df.columns.tolist()}")

            if 'Gene name' in mart_df.columns and 'Gene stable ID' in mart_df.columns:
                temp_map_df = mart_df[['Gene name', 'Gene stable ID']].dropna().drop_duplicates()
                hugo_to_ensembl = dict(zip(temp_map_df['Gene name'], temp_map_df['Gene stable ID']))
                print(f"Found {len(hugo_to_ensembl)} Hugo Symbol to Ensembl mappings in BioMart file.")
            
            if 'NCBI gene ID' in mart_df.columns and 'Gene stable ID' in mart_df.columns:
                temp_map_df = mart_df[['NCBI gene ID', 'Gene stable ID']].dropna().drop_duplicates()
                entrez_to_ensembl = dict(zip(temp_map_df['NCBI gene ID'].astype(str), temp_map_df['Gene stable ID']))
                print(f"Found {len(entrez_to_ensembl)} Entrez ID to Ensembl mappings in BioMart file.")

        except FileNotFoundError:
            print(f"Warning: BioMart mapping file '{os.path.basename(mart_file_path)}' not found. Skipping BioMart mapping.")
        except Exception as e:
            print(f"Error loading or processing BioMart mapping file: {e}")
        return hugo_to_ensembl, entrez_to_ensembl

    hugo_to_ensembl_mart, entrez_to_ensembl_mart = load_mart_mapping(mart_file)

    # --- Preprocess Gene Expression Data ---
    expr_T = expr_df.T
    print(f"Initial Gene Expression data shape (samples x genes): {expr_T.shape}")
    print(f"First 5 gene expression IDs (before mapping): {expr_T.columns[:5].tolist()}")

    expr_T_ensembl = pd.DataFrame()
    is_expr_ensembl = False
    if len(expr_T.columns) > 0 and isinstance(expr_T.columns[0], str):
        if all(col.startswith('ENSG') and '.' in col for col in expr_T.columns[:min(5, len(expr_T.columns))]):
            is_expr_ensembl = True
            expr_T_ensembl = expr_T.copy()
            expr_T_ensembl.columns = expr_T_ensembl.columns.str.split('.').str[0]
            print("Gene Expression IDs appear to be Ensembl. Removing version numbers.")
    
    if not is_expr_ensembl:
        all_expr_genes = list(expr_T.columns.astype(str))
        expr_gene_map = {}

        if entrez_to_ensembl_mart:
            print("Attempting BioMart mapping for Gene Expression (Entrez to Ensembl)...")
            expr_gene_map = {k: entrez_to_ensembl_mart.get(k, None) for k in all_expr_genes}
            expr_gene_map = {k: v for k, v in expr_gene_map.items() if v is not None}
            print(f"BioMart mapped {len(expr_gene_map)} gene expression IDs.")

        if not expr_gene_map:
            print("BioMart mapping for Gene Expression failed or not available. Attempting mygene mapping (Entrez/Symbol to Ensembl)...")
            expr_gene_map = batch_query_mygene_robust(all_expr_genes, scopes=['entrezgene', 'symbol'])
        
        if expr_gene_map:
            mapped_cols_data = {}
            for original_col, mapped_col in expr_gene_map.items():
                if original_col in expr_T.columns:
                    mapped_cols_data[mapped_col] = expr_T[original_col]
            expr_T_ensembl = pd.DataFrame(mapped_cols_data, index=expr_T.index)
            expr_T_ensembl = expr_T_ensembl.loc[:, ~expr_T_ensembl.columns.duplicated()]
            print(f"Gene Expression data after mapping to Ensembl: {expr_T_ensembl.shape}")
        else:
            expr_T_ensembl = expr_T.copy()
            print("Warning: Gene Expression ID mapping failed completely. Proceeding with original gene IDs. This may affect gene-level alignment.")
            print(f"Gene Expression data shape (original IDs): {expr_T_ensembl.shape}")

    # --- Preprocess Somatic Mutation Data ---
    if 'FILTER' in maf_df.columns:
        maf_df = maf_df[maf_df['FILTER'] == 'PASS']
    maf_df = maf_df[["Tumor_Sample_Barcode", "Hugo_Symbol"]].dropna().drop_duplicates()
    print(f"Mutation records after dropping NaNs and duplicates: {maf_df.shape}")

    maf_df_mapped = maf_df.copy()
    unique_hugo_symbols = maf_df_mapped['Hugo_Symbol'].unique().tolist()
    mutation_gene_map = {}

    if hugo_to_ensembl_mart:
        print("Attempting BioMart mapping for Mutation data (Hugo Symbol to Ensembl)...")
        mutation_gene_map = {k: hugo_to_ensembl_mart.get(k, None) for k in unique_hugo_symbols}
        mutation_gene_map = {k: v for k, v in mutation_gene_map.items() if v is not None}
        print(f"BioMart mapped {len(mutation_gene_map)} mutation genes.")

    if not mutation_gene_map:
        print("BioMart mapping for Mutation failed or not available. Attempting mygene mapping (Hugo Symbol to Ensembl)...")
        mutation_gene_map = batch_query_mygene_robust(unique_hugo_symbols, scopes=['symbol'])

    if mutation_gene_map:
        maf_df_mapped['Ensembl_ID'] = maf_df_mapped['Hugo_Symbol'].map(mutation_gene_map)
        maf_df_mapped = maf_df_mapped.dropna(subset=['Ensembl_ID'])
        print(f"Mutation records after mapping Hugo_Symbol to Ensembl: {maf_df_mapped.shape}")
    else:
        maf_df_mapped['Ensembl_ID'] = maf_df_mapped['Hugo_Symbol']
        print("Warning: Mutation gene ID mapping failed completely. Proceeding with original Hugo Symbols as gene IDs.")
        print(f"Mutation records shape (original Hugo Symbols): {maf_df_mapped.shape}")

    mutation_matrix = pd.DataFrame()
    if not maf_df_mapped.empty:
        maf_df_mapped = maf_df_mapped.drop_duplicates(subset=['Tumor_Sample_Barcode', 'Ensembl_ID'])
        mutation_matrix = pd.crosstab(maf_df_mapped['Tumor_Sample_Barcode'], maf_df_mapped['Ensembl_ID'])
        mutation_matrix = mutation_matrix.clip(upper=1)
        print(f"Binary mutation matrix shape: {mutation_matrix.shape}")
    else:
        print("Empty mutation matrix created due to no successful gene mapping or no mutation records.")

    # --- Preprocess miRNA Expression Data ---
    mirna_T = mirna_df.T
    print(f"Initial miRNA Expression data shape (samples x miRNAs): {mirna_T.shape}")
    print(f"First 5 miRNA IDs: {mirna_T.columns[:5].tolist()}")

    # --- Preprocess Copy Number Variation (CNV) Data ---
    cnv_T = cnv_df.T # Samples as rows, genes as columns
    print(f"Initial CNV data shape (samples x genes): {cnv_T.shape}")
    print(f"First 5 CNV gene IDs: {cnv_T.columns[:5].tolist()}")

    cnv_T_ensembl = pd.DataFrame()
    is_cnv_ensembl = False
    if len(cnv_T.columns) > 0 and isinstance(cnv_T.columns[0], str):
        if all(col.startswith('ENSG') and '.' in col for col in cnv_T.columns[:min(5, len(cnv_T.columns))]):
            is_cnv_ensembl = True
            cnv_T_ensembl = cnv_T.copy()
            cnv_T_ensembl.columns = cnv_T_ensembl.columns.str.split('.').str[0]
            print("CNV gene IDs appear to be Ensembl. Removing version numbers.")
    
    if not is_cnv_ensembl:
        all_cnv_genes = list(cnv_T.columns.astype(str))
        cnv_gene_map = {}

        if hugo_to_ensembl_mart:
            print("Attempting BioMart mapping for CNV data (Hugo Symbol to Ensembl)...")
            cnv_gene_map = {k: hugo_to_ensembl_mart.get(k, None) for k in all_cnv_genes}
            cnv_gene_map = {k: v for k, v in cnv_gene_map.items() if v is not None}
            print(f"BioMart mapped {len(cnv_gene_map)} CNV genes.")

        if not cnv_gene_map:
            print("BioMart mapping for CNV failed or not available. Attempting mygene mapping (Symbol to Ensembl)...")
            cnv_gene_map = batch_query_mygene_robust(all_cnv_genes, scopes=['symbol'])
        
        if cnv_gene_map:
            mapped_cols_data = {}
            for original_col, mapped_col in cnv_gene_map.items():
                if original_col in cnv_T.columns:
                    mapped_cols_data[mapped_col] = cnv_T[original_col]
            cnv_T_ensembl = pd.DataFrame(mapped_cols_data, index=cnv_T.index)
            cnv_T_ensembl = cnv_T_ensembl.loc[:, ~cnv_T_ensembl.columns.duplicated()]
            print(f"CNV data after mapping to Ensembl: {cnv_T_ensembl.shape}")
        else:
            cnv_T_ensembl = cnv_T.copy()
            print("Warning: CNV gene ID mapping failed completely. Proceeding with original CNV gene IDs. This may affect gene-level alignment.")
            print(f"CNV data shape (original IDs): {cnv_T_ensembl.shape}")

    # NEW: Preprocess RPPA Data
    rppa_T = rppa_df.T
    print(f"Initial RPPA data shape (samples x proteins): {rppa_T.shape}")
    print(f"First 5 RPPA protein IDs: {rppa_T.columns[:5].tolist()}")

    # NEW: Preprocess Methylation Data
    meth_T = meth_df.T
    print(f"Initial Methylation data shape (samples x probes): {meth_T.shape}")
    print(f"First 5 Methylation probe IDs: {meth_T.columns[:5].tolist()}")

    # --- Standardize Sample IDs Across All Omics and Phenotype ---
    pheno_sample_id_col = 'sample'
    
    if pheno_sample_id_col not in pheno_df.columns:
        print(f"Error: The specified sample ID column '{pheno_sample_id_col}' not found in phenotype data.")
        return None, None, None, None, None, None, None, None

    # Ensure all sample IDs are consistent (first 15 characters, uppercase)
    pheno_df_cleaned = pheno_df.copy()
    pheno_df_cleaned[pheno_sample_id_col] = pheno_df_cleaned[pheno_sample_id_col].astype(str).str.slice(0, 15).str.upper()
    pheno_df_cleaned = pheno_df_cleaned.loc[~pheno_df_cleaned[pheno_sample_id_col].duplicated(keep='first')]
    pheno_df_cleaned = pheno_df_cleaned.set_index(pheno_sample_id_col)

    expr_T_ensembl.index = expr_T_ensembl.index.astype(str).str.slice(0, 15).str.upper()
    expr_T_ensembl = expr_T_ensembl.groupby(expr_T_ensembl.index).mean()

    if not mutation_matrix.empty:
        mutation_matrix.index = mutation_matrix.index.astype(str).str.slice(0, 15).str.upper()
        mutation_matrix = mutation_matrix.loc[~mutation_matrix.index.duplicated(keep='first')]

    mirna_T.index = mirna_T.index.astype(str).str.slice(0, 15).str.upper()
    mirna_T = mirna_T.groupby(mirna_T.index).mean()

    cnv_T_ensembl.index = cnv_T_ensembl.index.astype(str).str.slice(0, 15).str.upper()
    cnv_T_ensembl = cnv_T_ensembl.groupby(cnv_T_ensembl.index).mean()

    rppa_T.index = rppa_T.index.astype(str).str.slice(0, 15).str.upper()
    rppa_T = rppa_T.groupby(rppa_T.index).mean()

    meth_T.index = meth_T.index.astype(str).str.slice(0, 15).str.upper()
    meth_T = meth_T.groupby(meth_T.index).mean()


    # --- Intersect Samples Across All Omics ---
    common_samples = expr_T_ensembl.index.intersection(pheno_df_cleaned.index)
    if not mutation_matrix.empty:
        common_samples = common_samples.intersection(mutation_matrix.index)
    if not mirna_T.empty:
        common_samples = common_samples.intersection(mirna_T.index)
    if not cnv_T_ensembl.empty:
        common_samples = common_samples.intersection(cnv_T_ensembl.index)
    if not rppa_T.empty:
        common_samples = common_samples.intersection(rppa_T.index)
    if not meth_T.empty:
        common_samples = common_samples.intersection(meth_T.index)
    
    print(f"Common samples across all omics datasets: {len(common_samples)}")

    if len(common_samples) == 0:
        print("No common samples found across all datasets. Cannot proceed with multi-omics analysis.")
        return None, None, None, None, None, None, None, None

    # Subset all dataframes to common samples
    expr_aligned = expr_T_ensembl.loc[common_samples]
    mut_aligned = mutation_matrix.loc[common_samples] if not mutation_matrix.empty else pd.DataFrame(0, index=common_samples, columns=[])
    mirna_aligned = mirna_T.loc[common_samples] if not mirna_T.empty else pd.DataFrame(0, index=common_samples, columns=[])
    cnv_aligned = cnv_T_ensembl.loc[common_samples] if not cnv_T_ensembl.empty else pd.DataFrame(0, index=common_samples, columns=[])
    rppa_aligned = rppa_T.loc[common_samples] if not rppa_T.empty else pd.DataFrame(0, index=common_samples, columns=[])
    meth_aligned = meth_T.loc[common_samples] if not meth_T.empty else pd.DataFrame(0, index=common_samples, columns=[])
    pheno_aligned = pheno_df_cleaned.loc[common_samples]

    # --- Align Genes/Features (Columns) for Gene-Level Omics ---
    expr_cols_are_ensembl = not expr_aligned.empty and expr_aligned.shape[1] > 0 and \
                            all(isinstance(col, str) and col.startswith('ENSG') for col in expr_aligned.columns[:min(5, len(expr_aligned.columns))])
    mut_cols_are_ensembl = not mut_aligned.empty and mut_aligned.shape[1] > 0 and \
                           all(isinstance(col, str) and col.startswith('ENSG') for col in mut_aligned.columns[:min(5, len(mut_aligned.columns))])
    cnv_cols_are_ensembl = not cnv_aligned.empty and cnv_aligned.shape[1] > 0 and \
                           all(isinstance(col, str) and col.startswith('ENSG') for col in cnv_aligned.columns[:min(5, len(cnv_aligned.columns))])
    
    if expr_cols_are_ensembl and mut_cols_are_ensembl and cnv_cols_are_ensembl:
        print("All gene-level omics (Expression, Mutation, CNV) are Ensembl IDs. Attempting gene-level alignment.")
        all_gene_features_union = expr_aligned.columns.union(mut_aligned.columns).union(cnv_aligned.columns)
        
        expr_aligned = expr_aligned.reindex(columns=all_gene_features_union, fill_value=0)
        mut_aligned = mut_aligned.reindex(columns=all_gene_features_union, fill_value=0)
        cnv_aligned = cnv_aligned.reindex(columns=all_gene_features_union, fill_value=0)
    else:
        print("Gene-level alignment skipped (either not all Ensembl IDs or one/more gene-level omics data are empty/unmapped). Features will be concatenated as-is.")
            
    print(f"Aligned gene expression data shape: {expr_aligned.shape}")
    print(f"Aligned mutation data shape: {mut_aligned.shape}")
    print(f"Aligned miRNA expression data shape: {mirna_aligned.shape}")
    print(f"Aligned CNV data shape: {cnv_aligned.shape}")
    print(f"Aligned RPPA data shape: {rppa_aligned.shape}")
    print(f"Aligned Methylation data shape: {meth_aligned.shape}")


    # --- Impute Missing Values and Scale Data for Each Omics ---
    
    # Helper function for filtering low variance and processing
    def process_omics_data(df_aligned, omics_name, apply_log_transform=False):
        if df_aligned.empty or df_aligned.shape[1] == 0:
            print(f"No {omics_name} columns to process. Skipping imputation and scaling.")
            return pd.DataFrame(index=df_aligned.index), df_aligned # Return empty scaled, and original aligned

        # Drop columns that are entirely NaN before imputation
        df_filtered = df_aligned.dropna(axis=1, how='all')
        if df_filtered.empty:
            print(f"{omics_name} data became empty after dropping all-NaN columns. Skipping imputation and scaling.")
            return pd.DataFrame(index=df_aligned.index), df_aligned

        # Apply log transformation if specified (for expression/miRNA)
        if apply_log_transform:
            # Add a small constant (1) to avoid log(0) issues
            df_filtered = np.log2(df_filtered + 1)
            print(f"{omics_name} data log2 transformed.")

        # Remove features with zero variance (after log transform if applied)
        # This is a simple form of noise elimination
        initial_cols = df_filtered.shape[1]
        df_filtered = df_filtered.loc[:, df_filtered.std() > 1e-6] # Keep columns with std > a very small number
        if df_filtered.shape[1] < initial_cols:
            print(f"Removed {initial_cols - df_filtered.shape[1]} zero-variance features from {omics_name} data.")
        
        if df_filtered.empty:
            print(f"{omics_name} data became empty after dropping zero-variance columns. Skipping imputation and scaling.")
            return pd.DataFrame(index=df_aligned.index), df_aligned

        imputer = SimpleImputer(strategy='median')
        df_imputed = pd.DataFrame(imputer.fit_transform(df_filtered),
                                  columns=df_filtered.columns,
                                  index=df_filtered.index)
        print(f"{omics_name} data imputed. Missing values: {df_imputed.isnull().sum().sum()}")

        scaler = StandardScaler()
        df_scaled = pd.DataFrame(scaler.fit_transform(df_imputed),
                                 columns=df_imputed.columns,
                                 index=df_imputed.index)
        return df_scaled, df_aligned # Return scaled data and original aligned (potentially log-transformed and filtered)

    expr_scaled, expr_original_aligned_processed = process_omics_data(expr_aligned, "Gene Expression", apply_log_transform=True)
    mut_scaled, _ = process_omics_data(mut_aligned, "Mutation") # Mutation is binary, no log transform
    mirna_scaled, mirna_original_aligned_processed = process_omics_data(mirna_aligned, "miRNA Expression", apply_log_transform=True)
    cnv_scaled, cnv_original_aligned_processed = process_omics_data(cnv_aligned, "CNV")
    rppa_scaled, rppa_original_aligned_processed = process_omics_data(rppa_aligned, "RPPA")
    meth_scaled, meth_original_aligned_processed = process_omics_data(meth_aligned, "Methylation")
    
    print("All omics data imputed and scaled (if columns available).")

    # --- Concatenate All Scaled Multi-Omics Data ---
    final_omics_dfs = []
    
    if not expr_scaled.empty:
        if not (expr_cols_are_ensembl and mut_cols_are_ensembl and cnv_cols_are_ensembl):
            expr_scaled.columns = [f"{col}_expr" for col in expr_scaled.columns]
        final_omics_dfs.append(expr_scaled)
    
    if not mut_scaled.empty:
        if not (expr_cols_are_ensembl and mut_cols_are_ensembl and cnv_cols_are_ensembl):
            mut_scaled.columns = [f"{col}_mut" for col in mut_scaled.columns]
        final_omics_dfs.append(mut_scaled)

    if not mirna_scaled.empty:
        mirna_scaled.columns = [f"{col}_mirna" for col in mirna_scaled.columns]
        final_omics_dfs.append(mirna_scaled)

    if not cnv_scaled.empty:
        if not (expr_cols_are_ensembl and mut_cols_are_ensembl and cnv_cols_are_ensembl):
            cnv_scaled.columns = [f"{col}_cnv" for col in cnv_scaled.columns]
        final_omics_dfs.append(cnv_scaled)

    if not rppa_scaled.empty:
        rppa_scaled.columns = [f"{col}_rppa" for col in rppa_scaled.columns]
        final_omics_dfs.append(rppa_scaled)

    if not meth_scaled.empty:
        meth_scaled.columns = [f"{col}_meth" for col in meth_scaled.columns]
        final_omics_dfs.append(meth_scaled)

    if final_omics_dfs:
        X_multiomics = pd.concat(final_omics_dfs, axis=1)
        # Ensure no duplicate columns after concatenation, which can happen if original IDs overlap
        X_multiomics = X_multiomics.loc[:, ~X_multiomics.columns.duplicated()]
    else:
        X_multiomics = pd.DataFrame(index=common_samples)
        print("Warning: All omics dataframes are empty after preprocessing. Multi-omics matrix is empty.")

    y_labels = pheno_aligned[target_col]

    print(f"Final Multi-omics feature matrix shape: {X_multiomics.shape}")
    print(f"Target labels shape: {y_labels.shape}")

    return X_multiomics, y_labels, pheno_aligned, expr_original_aligned_processed, mirna_original_aligned_processed, cnv_original_aligned_processed, rppa_original_aligned_processed, meth_original_aligned_processed

def export_data_to_pkl(dataframe, output_path, name="DataFrame"):
    """
    Exports a DataFrame to a .pkl file.

    Args:
        dataframe (pd.DataFrame or pd.Series): The DataFrame or Series to export.
        output_path (str): The path to save the .pkl file.
        name (str): A descriptive name for the data being saved.
    """
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Created output directory: {output_dir}")

    try:
        dataframe.to_pickle(output_path)
        print(f"{name} successfully exported to: {output_path}")
    except Exception as e:
        print(f"Error exporting {name} to .pkl: {e}")


# ==============================================================================
# 2. EDA & Visualization Functions
# ==============================================================================

def generate_summary_statistics(dataframe, name="Data"):
    """
    Generates and prints summary statistics for a given DataFrame.

    Args:
        dataframe (pd.DataFrame): The DataFrame for which to generate statistics.
        name (str): A descriptive name for the DataFrame (e.g., "Gene Expression", "Phenotype").
    """
    print(f"\n--- EDA: Summary Statistics for {name} ---")
    print(dataframe.describe())
    print(f"\nMissing values in {name}:\n{dataframe.isnull().sum().sum()} total missing values.")
    if dataframe.isnull().sum().sum() > 0:
        print(f"Missing values per column:\n{dataframe.isnull().sum()[dataframe.isnull().sum() > 0]}")
    print(f"\nDataFrame Info for {name}:")
    dataframe.info()
    print("-" * (25 + len(name)))

def plot_tumor_type_distribution(phenotype_df, tumor_type_column='_primary_disease', output_path=None):
    """
    Plots the distribution of tumor types from the phenotype DataFrame.

    Args:
        phenotype_df (pd.DataFrame): The phenotype DataFrame.
        tumor_type_column (str): The name of the column containing tumor type information.
        output_path (str, optional): Path to save the plot. If None, displays the plot.
    """
    if tumor_type_column not in phenotype_df.columns:
        print(f"Error: '{tumor_type_column}' not found in phenotype DataFrame. Cannot plot tumor type distribution.")
        return

    print(f"\n--- EDA: Plotting Tumor Type Distribution ---")
    plt.figure(figsize=(12, 7))
    sns.countplot(y=phenotype_df[tumor_type_column], order=phenotype_df[tumor_type_column].value_counts().index, palette='viridis')
    plt.title(f'Distribution of {tumor_type_column.replace("_", " ").title()}', fontsize=16, weight='bold')
    plt.xlabel('Number of Samples', fontsize=12)
    plt.ylabel(tumor_type_column.replace("_", " ").title(), fontsize=12)
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    plt.grid(axis='x', linestyle='--', alpha=0.7)
    plt.tight_layout()

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
            print(f"Created output directory: {output_dir}")
        plt.savefig(output_path, dpi=300)
        print(f"Tumor type distribution plot saved to: {output_path}")
    else:
        plt.show()
    plt.close()

def perform_pca(data_df, n_components=2):
    """
    Performs Principal Component Analysis (PCA) on the given data.

    Args:
        data_df (pd.DataFrame): DataFrame (samples x features).
        n_components (int): Number of principal components to compute.

    Returns:
        tuple: A tuple containing:
            - pca_result (pd.DataFrame): DataFrame with PCA components.
            - pca_model (PCA): The fitted PCA model.
    """
    print(f"\n--- EDA: Performing PCA with {n_components} components ---")
    # Handle potential NaN values by filling them (e.g., with mean or 0)
    data_df_filled = data_df.fillna(data_df.mean())

    pca = PCA(n_components=n_components)
    principal_components = pca.fit_transform(data_df_filled)
    pca_result = pd.DataFrame(data=principal_components,
                              columns=[f'PC{i+1}' for i in range(n_components)],
                              index=data_df.index)
    print(f"Explained variance ratio by components: {pca.explained_variance_ratio_}")
    print(f"Cumulative explained variance: {np.sum(pca.explained_variance_ratio_)}")
    return pca_result, pca

def plot_pca(pca_result_df, phenotype_df, color_column='_primary_disease', output_path=None):
    """
    Plots the PCA results, colored by a specified phenotype column.

    Args:
        pca_result_df (pd.DataFrame): DataFrame with PCA components.
        phenotype_df (pd.DataFrame): Matched phenotype DataFrame.
        color_column (str): The column in phenotype_df to use for coloring the plot.
        output_path (str, optional): Path to save the plot. If None, displays the plot.
    """
    if color_column not in phenotype_df.columns:
        print(f"Error: '{color_column}' not found in phenotype DataFrame. Cannot color PCA plot.")
        return

    # Merge PCA results with phenotype data for coloring
    plot_df = pca_result_df.merge(phenotype_df[[color_column]], left_index=True, right_index=True)

    print(f"\n--- EDA: Plotting PCA results, colored by '{color_column}' ---")
    plt.figure(figsize=(10, 8))
    sns.scatterplot(x='PC1', y='PC2', hue=color_column, data=plot_df,
                    palette='tab20', s=70, alpha=0.8, edgecolor='w', linewidth=0.5)
    plt.title(f'PCA of Multi-Omics Data (Colored by {color_column.replace("_", " ").title()})', fontsize=16, weight='bold')
    plt.xlabel(f'Principal Component 1', fontsize=12)
    plt.ylabel(f'Principal Component 2', fontsize=12)
    plt.legend(title=color_column.replace("_", " ").title(), bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9)
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.tight_layout(rect=[0, 0, 0.85, 1]) # Adjust layout to prevent legend overlap

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
            print(f"Created output directory: {output_dir}")
        plt.savefig(output_path, dpi=300)
        print(f"PCA plot saved to: {output_path}")
    else:
        plt.show()
    plt.close()


def perform_umap(data_df, n_components=2, random_state=42):
    """
    Performs UMAP dimensionality reduction on the given data.

    Args:
        data_df (pd.DataFrame): DataFrame (samples x features).
        n_components (int): Number of dimensions for the UMAP embedding.
        random_state (int): Random seed for reproducibility.

    Returns:
        pd.DataFrame: DataFrame with UMAP components.
    """
    print(f"\n--- EDA: Performing UMAP with {n_components} components ---")
    # Handle potential NaN values by filling them (e.g., with mean or 0)
    data_df_filled = data_df.fillna(data_df.mean())

    reducer = umap.UMAP(n_components=n_components, random_state=random_state)
    umap_embedding = reducer.fit_transform(data_df_filled)
    umap_result = pd.DataFrame(data=umap_embedding,
                               columns=[f'UMAP{i+1}' for i in range(n_components)],
                               index=data_df.index)
    return umap_result

def plot_umap(umap_result_df, phenotype_df, color_column='_primary_disease', output_path=None):
    """
    Plots the UMAP results, colored by a specified phenotype column.

    Args:
        umap_result_df (pd.DataFrame): DataFrame with UMAP components.
        phenotype_df (pd.DataFrame): Matched phenotype DataFrame.
        color_column (str): The column in phenotype_df to use for coloring the plot.
        output_path (str, optional): Path to save the plot. If None, displays the plot.
    """
    if color_column not in phenotype_df.columns:
        print(f"Error: '{color_column}' not found in phenotype DataFrame. Cannot color UMAP plot.")
        return

    # Merge UMAP results with phenotype data for coloring
    plot_df = umap_result_df.merge(phenotype_df[[color_column]], left_index=True, right_index=True)

    print(f"\n--- EDA: Plotting UMAP results, colored by '{color_column}' ---")
    plt.figure(figsize=(10, 8))
    sns.scatterplot(x='UMAP1', y='UMAP2', hue=color_column, data=plot_df,
                    palette='tab20', s=70, alpha=0.8, edgecolor='w', linewidth=0.5)
    plt.title(f'UMAP of Multi-Omics Data (Colored by {color_column.replace("_", " ").title()})', fontsize=16, weight='bold')
    plt.xlabel(f'UMAP Component 1', fontsize=12)
    plt.ylabel(f'UMAP Component 2', fontsize=12)
    plt.legend(title=color_column.replace("_", " ").title(), bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9)
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.tight_layout(rect=[0, 0, 0.85, 1])

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
            print(f"Created output directory: {output_dir}")
        plt.savefig(output_path, dpi=300)
        print(f"UMAP plot saved to: {output_path}")
    else:
        plt.show()
    plt.close()

def perform_clustering_and_survival_analysis(X_data, y_labels, pheno_df, output_dir, n_clusters=3):
    """
    Performs K-Means clustering on the data and then conducts Kaplan-Meier survival
    analysis for the identified clusters.
    """
    print(f"\n--- Performing K-Means Clustering (k={n_clusters}) and Survival Analysis ---")

    # Check for empty numpy array using .size
    if X_data.size == 0 or X_data.shape[0] < n_clusters:
        print(f"Insufficient data for clustering with {n_clusters} clusters. Skipping.")
        return

    try:
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10) # n_init for robust centroid initialization
        cluster_labels = kmeans.fit_predict(X_data)
        print(f"K-Means clustering complete. Found {n_clusters} clusters.")

        # Add cluster labels to phenotype data
        pheno_with_clusters = pheno_df.copy()
        # Ensure the index of pheno_with_clusters matches X_data for Series creation
        # If X_data is a numpy array, its index is not directly available. Use pheno_df.index.
        pheno_with_clusters['Cluster'] = pd.Series(cluster_labels, index=pheno_df.index)


        # Prepare survival data
        # Common TCGA survival columns are '_OS_TIME' (event indicator) and '_OS_IND' (time in days)
        survival_time_col = '_OS_TIME' # Overall Survival Time in days
        survival_event_col = '_OS_IND' # Overall Survival Indicator (1=dead, 0=alive/censored)

        # Check if the required survival columns exist in the phenotype data
        if survival_time_col not in pheno_with_clusters.columns or survival_event_col not in pheno_with_clusters.columns:
            print(f"Survival data columns ('{survival_time_col}', '{survival_event_col}') not found in phenotype data.")
            print("Skipping survival analysis. Please ensure your phenotype file contains these columns or update the column names in the script.")
            return

        # Convert survival status to boolean (True for event, False for censored)
        # Ensure that 'OS_IND' is treated as 1 for event, 0 for censored.
        # Sometimes it might be 'Death'/'Alive' or other strings, so convert to int first.
        pheno_with_clusters['Event'] = pheno_with_clusters[survival_event_col].astype(int) == 1
        pheno_with_clusters['Time'] = pheno_with_clusters[survival_time_col].astype(float)

        # Drop rows with NaN in survival data
        pheno_with_clusters = pheno_with_clusters.dropna(subset=['Time', 'Event', 'Cluster'])
        
        if pheno_with_clusters.empty:
            print("No valid survival data after cleaning. Skipping survival analysis.")
            return

        # Kaplan-Meier Plotting
        kmf = KaplanMeierFitter()
        plt.figure(figsize=(10, 7))

        cluster_groups = pheno_with_clusters.groupby('Cluster')
        for name, group in cluster_groups:
            if len(group) > 1: # Need at least 2 samples to plot a curve
                kmf.fit(group['Time'], event_observed=group['Event'], label=f'Cluster {name}')
                kmf.plot_survival_function(ax=plt.gca())
            else:
                print(f"Cluster {name} has only {len(group)} sample(s), skipping KM plot for this cluster.")

        plt.title(f"Kaplan-Meier Survival Curves by K-Means Clusters (k={n_clusters})", fontsize=16, weight='bold', color='darkblue')
        plt.xlabel("Time (days)", fontsize=12, color='dimgray')
        plt.ylabel("Survival Probability", fontsize=12, color='dimgray')
        plt.grid(True, linestyle=':', alpha=0.5)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"survival_curves_kmeans_k{n_clusters}.png"), dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Kaplan-Meier Survival Curves plot saved for {n_clusters} clusters.")

        # Log-rank test for all pairs of clusters
        print("\n--- Log-Rank Test Results (P-values between clusters) ---")
        cluster_ids = sorted(pheno_with_clusters['Cluster'].unique())
        if len(cluster_ids) >= 2:
            for i in range(len(cluster_ids)):
                for j in range(i + 1, len(cluster_ids)):
                    cluster1_data = pheno_with_clusters[pheno_with_clusters['Cluster'] == cluster_ids[i]]
                    cluster2_data = pheno_with_clusters[pheno_with_clusters['Cluster'] == cluster_ids[j]]

                    if len(cluster1_data) > 1 and len(cluster2_data) > 1:
                        results = logrank_test(cluster1_data['Time'], cluster2_data['Time'],
                                               event_observed_A=cluster1_data['Event'],
                                               event_observed_B=cluster2_data['Event'])
                        print(f"Cluster {cluster_ids[i]} vs Cluster {cluster_ids[j]}: p-value = {results.p_value:.4f}")
                    else:
                        print(f"Skipping log-rank test for Cluster {cluster_ids[i]} vs Cluster {cluster_ids[j]} due to insufficient samples.")
        else:
            print("Fewer than 2 clusters with sufficient samples for log-rank test.")

    except Exception as e:
        print(f"Error during clustering or survival analysis: {e}")
        import traceback
        traceback.print_exc()


# ==============================================================================
# Main Execution Block for Multi-Omics Workflow (EDA Focus)
# ==============================================================================
if __name__ == "__main__":
    print("Starting Multi-Omics Workflow (EDA Focus)...")

    # --- Configuration ---
    # IMPORTANT: Update these paths to your actual raw data files
    data_dir = r"C:\Users\shrav\Desktop\PYTHON\Cancer\Pan Cancer Analysis"
    raw_expr_file = os.path.join(data_dir, "EB++AdjustPANCAN_IlluminaHiSeq_RNASeqV2.geneExp.xena")
    raw_maf_file = os.path.join(data_dir, "mc3.v0.2.8.PUBLIC.maf")
    raw_pheno_file = os.path.join(data_dir, "TCGA_phenotype_denseDataOnlyDownload.tsv")
    raw_mart_file = os.path.join(data_dir, "mart_export.txt") # BioMart mapping file
    raw_mirna_file = os.path.join(data_dir, "pancanMiRs_EBadjOnProtocolPlatformWithoutRepsWithUnCorrectMiRs_08_04_16.xena")
    raw_cnv_file = os.path.join(data_dir, "Gistic2_CopyNumber_Gistic2_all_thresholded.by_genes")
    raw_rppa_file = os.path.join(data_dir, "TCGA-RPPA-pancan-clean.xena")
    raw_meth_file = os.path.join(data_dir, "jhu-usc.edu_PANCAN_HumanMethylation450.betaValue_whitelisted.tsv.synapse_download_5096262.xena")
    
    # Output directories for processed data and plots
    output_dir = "multi_omics_results_advanced" # This is the main output directory
    os.makedirs(output_dir, exist_ok=True)

    # Define the target column for cancer type in phenotype data
    target_column = '_primary_disease'

    # ==========================================================================
    # PHASE 1: DATA PREPARATION (Advanced Multi-Omics Loading and Preprocessing)
    # ==========================================================================
    print("\n" + "="*50)
    print("EXECUTING PHASE 1: DATA PREPARATION (Advanced Multi-Omics)")
    print("="*50)
    
    X_multiomics_features = None
    y_multiomics_labels = None
    pheno_aligned_df = None
    expr_original_aligned_processed = None
    mirna_original_aligned_processed = None
    cnv_original_aligned_processed = None
    rppa_original_aligned_processed = None
    meth_original_aligned_processed = None

    try:
        # Check if all raw data files exist
        required_files = {
            raw_expr_file: "Gene Expression",
            raw_maf_file: "Somatic Mutation",
            raw_pheno_file: "Phenotype",
            raw_mart_file: "BioMart Mapping",
            raw_mirna_file: "miRNA Expression",
            raw_cnv_file: "Copy Number Variation",
            raw_rppa_file: "RPPA Expression",
            raw_meth_file: "Methylation Beta Values"
        }
        all_files_exist = True
        for f_path, f_name in required_files.items():
            if not os.path.exists(f_path):
                print(f"Error: Required file not found for {f_name}: {f_path}")
                all_files_exist = False
        
        if all_files_exist:
            (X_multiomics_features, y_multiomics_labels, pheno_aligned_df,
             expr_original_aligned_processed, mirna_original_aligned_processed,
             cnv_original_aligned_processed, rppa_original_aligned_processed,
             meth_original_aligned_processed) = load_and_preprocess_multi_omics_data_advanced(
                raw_expr_file, raw_maf_file, raw_pheno_file, raw_mart_file,
                raw_mirna_file, raw_cnv_file, raw_rppa_file, raw_meth_file, target_column
            )

            if X_multiomics_features is not None and not X_multiomics_features.empty:
                # Save the processed data for Phase 8
                export_data_to_pkl(X_multiomics_features, os.path.join(output_dir, 'X_multiomics_features.pkl'), 'X_multiomics_features')
                export_data_to_pkl(y_multiomics_labels, os.path.join(output_dir, 'y_multiomics_labels.pkl'), 'y_multiomics_labels')
                export_data_to_pkl(pheno_aligned_df, os.path.join(output_dir, 'pheno_aligned_df.pkl'), 'pheno_aligned_df')
                print("\nProcessed multi-omics data saved to .pkl files for subsequent phases.")
            else:
                print("Multi-omics data processing resulted in empty features. Skipping saving.")
        else:
            print("\nSkipping multi-omics data preparation due to missing raw input files.")
            print("Please ensure all required raw data files are in the specified 'data_dir'.")
            # Generate dummy data for EDA demonstration if raw files are missing
            print("\nGenerating dummy data for EDA demonstration...")
            np.random.seed(42)
            num_samples = 500
            num_features = 1000 # A reasonable number for multi-omics
            
            X_multiomics_features = pd.DataFrame(np.random.rand(num_samples, num_features),
                                                 index=[f'sample_{i}' for i in range(num_samples)],
                                                 columns=[f'feature_{j}' for j in range(num_features)])
            
            cancer_types = ['BRCA', 'LUAD', 'COAD', 'KIRC', 'LIHC', 'STAD', 'THCA', 'OV', 'LGG', 'SKCM']
            y_multiomics_labels = pd.Series(np.random.choice(cancer_types, num_samples),
                                            index=X_multiomics_features.index)
            
            pheno_aligned_df = pd.DataFrame({
                '_primary_disease': y_multiomics_labels,
                'gender': np.random.choice(['Male', 'Female'], num_samples),
                'age_at_diagnosis': np.random.randint(30, 80, num_samples),
                '_OS_TIME': np.random.randint(100, 3000, num_samples), # Dummy survival time
                '_OS_IND': np.random.choice([0, 1], num_samples, p=[0.7, 0.3]) # Dummy survival indicator
            }, index=X_multiomics_features.index)
            
            print("Dummy multi-omics data generated for EDA.")

    except Exception as e:
        print(f"Error during Advanced Data Preparation: {e}")
        print("Exiting workflow as data preparation is crucial.")
        exit()

    # ==========================================================================
    # PHASE 2: EDA & VISUALIZATION
    # ==========================================================================
    print("\n" + "="*50)
    print("EXECUTING PHASE 2: EDA & VISUALIZATION")
    print("="*50)

    # Check if data was successfully loaded or generated before proceeding with EDA
    if X_multiomics_features is not None and not X_multiomics_features.empty and \
       y_multiomics_labels is not None and not y_multiomics_labels.empty and \
       pheno_aligned_df is not None and not pheno_aligned_df.empty:
        try:
            # --- Extensive EDA and Data Inspection ---
            print("\n--- Detailed EDA of Loaded Multi-Omics Data ---")

            print("\nOriginal Multi-omics Feature Matrix (X_multiomics_features):")
            print(f"Shape: {X_multiomics_features.shape}")
            print("First 5 rows:")
            print(X_multiomics_features.head())
            print("\nDescriptive statistics for X_multiomics_features (first 50 columns):")
            print(X_multiomics_features.iloc[:, :50].describe()) # Show for a subset of columns
            print("\nMissing values in X_multiomics_features:")
            print(X_multiomics_features.isnull().sum().sum())
            
            print("\nOriginal Target Labels (y_multiomics_labels):")
            print(f"Shape: {y_multiomics_labels.shape}")
            print("Value counts:")
            print(y_multiomics_labels.value_counts().head(10)) # Show top 10 cancer types
            
            print("\nAligned Phenotype Data (pheno_aligned_df):")
            print(f"Shape: {pheno_aligned_df.shape}")
            print("First 5 rows:")
            print(pheno_aligned_df.head())
            print("\nDescriptive statistics for pheno_aligned_df:")
            print(pheno_aligned_df.describe(include='all')) # Include non-numeric columns
            print("\nMissing values in pheno_aligned_df:")
            print(pheno_aligned_df.isnull().sum())

            # Filter out classes with only one sample before encoding
            class_counts = y_multiomics_labels.value_counts()
            classes_to_keep = class_counts[class_counts >= 2].index
            
            if len(classes_to_keep) < 2:
                print("After filtering, fewer than 2 cancer types remain with at least 2 samples. Skipping further analysis.")
            else:
                filtered_samples_mask = y_multiomics_labels.isin(classes_to_keep)
                X_filtered = X_multiomics_features[filtered_samples_mask]
                y_filtered = y_multiomics_labels[filtered_samples_mask]
                pheno_filtered = pheno_aligned_df[filtered_samples_mask] # Keep phenotype aligned for survival analysis

                # Remove duplicate columns from X_filtered before analysis
                initial_cols_X_filtered = X_filtered.shape[1]
                X_filtered = X_filtered.loc[:, ~X_filtered.columns.duplicated()]
                if X_filtered.shape[1] < initial_cols_X_filtered:
                    print(f"Removed {initial_cols_X_filtered - X_filtered.shape[1]} duplicate columns from X_filtered.")

                print("\n--- Detailed EDA of Filtered Data (after removing single-sample classes) ---")
                print(f"Filtered Multi-omics Feature Matrix (X_filtered) shape: {X_filtered.shape}")
                print("First 5 rows of X_filtered:")
                print(X_filtered.head())
                print("\nDescriptive statistics for X_filtered (first 50 columns):")
                print(X_filtered.iloc[:, :50].describe())

                print(f"\nFiltered Target Labels (y_filtered) shape: {y_filtered.shape}")
                print("Value counts for y_filtered:")
                print(y_filtered.value_counts().head(10))

                print(f"\nFiltered Phenotype Data (pheno_filtered) shape: {pheno_filtered.shape}")
                print("First 5 rows of pheno_filtered:")
                print(pheno_filtered.head())
                print("\nDescriptive statistics for pheno_filtered:")
                print(pheno_filtered.describe(include='all'))


                le = LabelEncoder()
                y_encoded = le.fit_transform(y_filtered)
                class_names = le.classes_
                print(f"\nEncoded {len(class_names)} cancer types for EDA.")

                # --- Dimensionality Reduction: PCA ---
                n_components_pca = min(50, X_filtered.shape[1])
                if X_filtered.shape[1] > 1 and n_components_pca > 1:
                    pca = PCA(n_components=n_components_pca, random_state=42)
                    X_pca = pca.fit_transform(X_filtered)
                    # Ensure X_pca has an index for consistency with pheno_filtered
                    X_pca_df = pd.DataFrame(X_pca, index=X_filtered.index)
                    print(f"\nPCA reduced data shape: {X_pca.shape}")
                    print(f"PCA explained variance ratio (first {n_components_pca} components): {np.sum(pca.explained_variance_ratio_):.2f}")
                    print("First 5 rows of PCA reduced data (X_pca_df):")
                    print(X_pca_df.head())
                else:
                    print("Not enough features for PCA. Skipping PCA.")
                    X_pca = X_filtered.values # Keep as numpy array if PCA skipped
                    X_pca_df = X_filtered # Keep as DataFrame if PCA skipped

                # --- Dimensionality Reduction: UMAP ---
                if X_pca.shape[1] >= 2: # Use X_pca (numpy array) directly for UMAP
                    reducer = umap.UMAP(n_components=2, random_state=42, metric='euclidean')
                    X_umap = reducer.fit_transform(X_pca)
                    print(f"\nUMAP reduced data shape: {X_umap.shape}")

                    plt.figure(figsize=(12, 10))
                    sns.scatterplot(x=X_umap[:, 0], y=X_umap[:, 1], hue=y_filtered, palette='tab20',
                                    s=80, alpha=0.85, edgecolor='black', linewidth=0.7)
                    plt.title("UMAP Projection of Multi-Omics Data (Integrated Features)", fontsize=18, weight='bold', color='darkblue')
                    plt.xlabel("UMAP Component 1", fontsize=14, color='dimgray')
                    plt.ylabel("UMAP Component 2", fontsize=14, color='dimgray')
                    plt.xticks(fontsize=12)
                    plt.yticks(fontsize=12)
                    plt.legend(title="Cancer Type", bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0, fontsize='small', title_fontsize=12)
                    plt.grid(True, linestyle=':', alpha=0.5)
                    plt.tight_layout(rect=[0, 0, 0.85, 1])
                    plt.savefig(os.path.join(output_dir, "multiomics_umap_projection_advanced.png"), dpi=300, bbox_inches='tight')
                    plt.close()
                    print("UMAP projection plot saved.")
                else:
                    print("Not enough dimensions after PCA for UMAP. Skipping UMAP plot.")

                # Perform Clustering and Survival Analysis
                # Pass the PCA-reduced data (NumPy array) and original phenotype DataFrame
                perform_clustering_and_survival_analysis(X_pca, y_filtered, pheno_filtered, output_dir, n_clusters=3) # You can adjust n_clusters

                print("\nEDA and Visualization steps completed.")
            
        except Exception as e:
            print(f"Error during EDA & Visualization: {e}")
            import traceback
            traceback.print_exc()
    else:
        print("Multi-omics data was not loaded or generated successfully. Skipping EDA & Visualization steps.")

    print("\nMulti-Omics Workflow (EDA Focus) execution complete.")


Starting Multi-Omics Workflow (EDA Focus)...

EXECUTING PHASE 1: DATA PREPARATION (Advanced Multi-Omics)

--- Multi-Omics Data Loading and Preprocessing (Advanced) ---
Attempting to load methylation data in chunks from: C:\Users\shrav\Desktop\PYTHON\Cancer\Pan Cancer Analysis\jhu-usc.edu_PANCAN_HumanMethylation450.betaValue_whitelisted.tsv.synapse_download_5096262.xena
  Processed chunk 1, current chunk shape: (50000, 9664)
  Processed chunk 2, current chunk shape: (50000, 9664)
  Processed chunk 3, current chunk shape: (50000, 9664)
  Processed chunk 4, current chunk shape: (50000, 9664)
  Processed chunk 5, current chunk shape: (50000, 9664)
  Processed chunk 6, current chunk shape: (50000, 9664)
  Processed chunk 7, current chunk shape: (50000, 9664)
  Processed chunk 8, current chunk shape: (46065, 9664)
Identified top 1000 most variable methylation probes from 9664 total probes.
Methylation data re-loaded with only top 1000 variable probes. Final shape: (396065, 1000)
Raw expressi

Input sequence provided is already in string format. No operation performed


BioMart mapping for Gene Expression failed or not available. Attempting mygene mapping (Entrez/Symbol to Ensembl)...


19 input query terms found dup hits:	[('ABCA11P', 2), ('ABCA17P', 2), ('ABCC13', 2), ('ABCC6P1', 2), ('ABCC6P2', 3), ('ADAM21P1', 2), ('A
82 input query terms found no hit:	['100130426', '100133144', '10431', '136542', '317712', '391343', '553137', '57714', '645851', '6529
Input sequence provided is already in string format. No operation performed
6 input query terms found dup hits:	[('ATP6AP1L', 2), ('ATP8B5P', 2), ('BAGE2', 2), ('BIRC8', 2), ('BMS1P4', 2), ('BRD7P3', 2)]
338 input query terms found no hit:	['ARMC4', 'ARNTL2', 'ARNTL', 'ARPM1', 'ARSE', 'ASAM', 'ASAP1IT1', 'ASFMR1', 'ASNA1', 'ATHL1', 'ATP5A
Input sequence provided is already in string format. No operation performed
1 input query terms found dup hits:	[('C3P1', 2)]
668 input query terms found no hit:	['C17orf103', 'C17orf104', 'C17orf105', 'C17orf106', 'C17orf108', 'C17orf28', 'C17orf37', 'C17orf39'
Input sequence provided is already in string format. No operation performed
10 input query terms found dup hits:	[('CAST',

Successfully mapped 0 out of 20531 terms for scopes '['entrezgene', 'symbol']'.
Gene Expression data shape (original IDs): (11069, 20531)
Mutation records after dropping NaNs and duplicates: (2378187, 2)
Attempting BioMart mapping for Mutation data (Hugo Symbol to Ensembl)...
BioMart mapped 18862 mutation genes.
Mutation records after mapping Hugo_Symbol to Ensembl: (2228644, 3)
Binary mutation matrix shape: (9104, 18862)
Initial miRNA Expression data shape (samples x miRNAs): (10824, 743)
First 5 miRNA IDs: ['hsa-let-7a-2-3p', 'hsa-let-7a-3p', 'hsa-let-7a-5p', 'hsa-let-7b-3p', 'hsa-let-7b-5p']
Initial CNV data shape (samples x genes): (10845, 24776)
First 5 CNV gene IDs: ['ACAP3', 'ACTRT2', 'AGRN', 'ANKRD65', 'ATAD3A']
Attempting BioMart mapping for CNV data (Hugo Symbol to Ensembl)...
BioMart mapped 21521 CNV genes.
CNV data after mapping to Ensembl: (10845, 21521)
Initial RPPA data shape (samples x proteins): (7754, 258)
First 5 RPPA protein IDs: ['X1433EPSILON', 'X4EBP1', 'X4EBP1_p

  result = func(self.values, **kwargs)
  result = func(self.values, **kwargs)
  sqr = _ensure_numeric((avg - values) ** 2)


Gene Expression data log2 transformed.
Removed 302 zero-variance features from Gene Expression data.
Gene Expression data imputed. Missing values: 0
Removed 6025 zero-variance features from Mutation data.
Mutation data imputed. Missing values: 0


  result = func(self.values, **kwargs)


miRNA Expression data log2 transformed.
Removed 1 zero-variance features from miRNA Expression data.
miRNA Expression data imputed. Missing values: 0
CNV data imputed. Missing values: 0
RPPA data imputed. Missing values: 0
Methylation data imputed. Missing values: 0
All omics data imputed and scaled (if columns available).
Final Multi-omics feature matrix shape: (568, 451633)
Target labels shape: (568,)
X_multiomics_features successfully exported to: multi_omics_results_advanced\X_multiomics_features.pkl
y_multiomics_labels successfully exported to: multi_omics_results_advanced\y_multiomics_labels.pkl
pheno_aligned_df successfully exported to: multi_omics_results_advanced\pheno_aligned_df.pkl

Processed multi-omics data saved to .pkl files for subsequent phases.

EXECUTING PHASE 2: EDA & VISUALIZATION

--- Detailed EDA of Loaded Multi-Omics Data ---

Original Multi-omics Feature Matrix (X_multiomics_features):
Shape: (568, 451633)
First 5 rows:
                 100130426_expr  10013314

  warn(



UMAP reduced data shape: (565, 2)
UMAP projection plot saved.

--- Performing K-Means Clustering (k=3) and Survival Analysis ---
K-Means clustering complete. Found 3 clusters.
Survival data columns ('_OS_TIME', '_OS_IND') not found in phenotype data.
Skipping survival analysis. Please ensure your phenotype file contains these columns or update the column names in the script.

EDA and Visualization steps completed.

Multi-Omics Workflow (EDA Focus) execution complete.


In [None]:
# ==============================================================================
# Phase 1.8: Multi-Omics Machine Learning and Advanced Analysis
# This script performs dimensionality reduction (PCA, UMAP), trains multiple
# machine learning models for cancer type classification, and includes clustering
# with survival analysis. It is designed to be run AFTER data preparation (Phase 7).
#
# Before running:
# 1. Ensure you have necessary libraries installed:
#    pip install pandas numpy matplotlib seaborn scikit-learn umap-learn lightgbm lifelines
# 2. This script EXPLICITLY LOADS 'X_multiomics_features', 'y_multiomics_labels',
#    and 'pheno_aligned_df' from the 'multi_omics_results_advanced' directory.
#    You MUST run the "Phase 7: Multi-Omics Data Preparation & EDA" Canvas first
#    to generate and save these files.
# ==============================================================================

In [None]:


import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import umap.umap_ as umap # For UMAP visualization

from sklearn.preprocessing import LabelEncoder
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC # Support Vector Classifier
from sklearn.neural_network import MLPClassifier # Multi-layer Perceptron (simple NN)
import lightgbm as lgb # Light Gradient Boosting Machine
from sklearn.cluster import KMeans # For clustering

from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

# For survival analysis
from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test

# Set a consistent style for all plots for a professional look
plt.style.use('seaborn-v0_8-darkgrid')

# --- Configuration ---
# This output directory should match where the data prep script saves its output
output_dir = "multi_omics_results_advanced"
os.makedirs(output_dir, exist_ok=True)

# ==============================================================================
# Helper Functions
# ==============================================================================

def perform_clustering_and_survival_analysis(X_data, y_labels, pheno_df, output_dir, n_clusters=3):
    """
    Performs K-Means clustering on the data and then conducts Kaplan-Meier survival
    analysis for the identified clusters.
    """
    print(f"\n--- Performing K-Means Clustering (k={n_clusters}) and Survival Analysis ---")

    # Check for empty numpy array using .size
    if X_data.size == 0 or X_data.shape[0] < n_clusters:
        print(f"Insufficient data for clustering with {n_clusters} clusters. Skipping.")
        return

    try:
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10) # n_init for robust centroid initialization
        cluster_labels = kmeans.fit_predict(X_data)
        print(f"K-Means clustering complete. Found {n_clusters} clusters.")

        # Add cluster labels to phenotype data
        pheno_with_clusters = pheno_df.copy()
        # Ensure the index of pheno_with_clusters matches X_data for Series creation
        # If X_data is a numpy array, its index is not directly available. Use pheno_df.index.
        pheno_with_clusters['Cluster'] = pd.Series(cluster_labels, index=pheno_df.index)


        # Prepare survival data
        # Common TCGA survival columns are '_OS_TIME' (event indicator) and '_OS_IND' (time in days)
        survival_time_col = '_OS_TIME' # Overall Survival Time in days
        survival_event_col = '_OS_IND' # Overall Survival Indicator (1=dead, 0=alive/censored)

        # Check if the required survival columns exist in the phenotype data
        if survival_time_col not in pheno_with_clusters.columns or survival_event_col not in pheno_with_clusters.columns:
            print(f"Survival data columns ('{survival_time_col}', '{survival_event_col}') not found in phenotype data.")
            print("Skipping survival analysis. Please ensure your phenotype file contains these columns or update the column names in the script.")
            return

        # Convert survival status to boolean (True for event, False for censored)
        # Ensure that 'OS_IND' is treated as 1 for event, 0 for censored.
        # Sometimes it might be 'Death'/'Alive' or other strings, so convert to int first.
        pheno_with_clusters['Event'] = pheno_with_clusters[survival_event_col].astype(int) == 1
        pheno_with_clusters['Time'] = pheno_with_clusters[survival_time_col].astype(float)

        # Drop rows with NaN in survival data
        pheno_with_clusters = pheno_with_clusters.dropna(subset=['Time', 'Event', 'Cluster'])
        
        if pheno_with_clusters.empty:
            print("No valid survival data after cleaning. Skipping survival analysis.")
            return

        # Kaplan-Meier Plotting
        kmf = KaplanMeierFitter()
        plt.figure(figsize=(10, 7))

        cluster_groups = pheno_with_clusters.groupby('Cluster')
        for name, group in cluster_groups:
            if len(group) > 1: # Need at least 2 samples to plot a curve
                kmf.fit(group['Time'], event_observed=group['Event'], label=f'Cluster {name}')
                kmf.plot_survival_function(ax=plt.gca())
            else:
                print(f"Cluster {name} has only {len(group)} sample(s), skipping KM plot for this cluster.")

        plt.title(f"Kaplan-Meier Survival Curves by K-Means Clusters (k={n_clusters})", fontsize=16, weight='bold', color='darkblue')
        plt.xlabel("Time (days)", fontsize=12, color='dimgray')
        plt.ylabel("Survival Probability", fontsize=12, color='dimgray')
        plt.grid(True, linestyle=':', alpha=0.5)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"survival_curves_kmeans_k{n_clusters}.png"), dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Kaplan-Meier Survival Curves plot saved for {n_clusters} clusters.")

        # Log-rank test for all pairs of clusters
        print("\n--- Log-Rank Test Results (P-values between clusters) ---")
        cluster_ids = sorted(pheno_with_clusters['Cluster'].unique())
        if len(cluster_ids) >= 2:
            for i in range(len(cluster_ids)):
                for j in range(i + 1, len(cluster_ids)):
                    cluster1_data = pheno_with_clusters[pheno_with_clusters['Cluster'] == cluster_ids[i]]
                    cluster2_data = pheno_with_clusters[pheno_with_clusters['Cluster'] == cluster_ids[j]]

                    if len(cluster1_data) > 1 and len(cluster2_data) > 1:
                        results = logrank_test(cluster1_data['Time'], cluster2_data['Time'],
                                               event_observed_A=cluster1_data['Event'],
                                               event_observed_B=cluster2_data['Event'])
                        print(f"Cluster {cluster_ids[i]} vs Cluster {cluster_ids[j]}: p-value = {results.p_value:.4f}")
                    else:
                        print(f"Skipping log-rank test for Cluster {cluster_ids[i]} vs Cluster {cluster_ids[j]} due to insufficient samples.")
        else:
            print("Fewer than 2 clusters with sufficient samples for log-rank test.")

    except Exception as e:
        print(f"Error during clustering or survival analysis: {e}")
        import traceback
        traceback.print_exc()


# ==============================================================================
# Main Execution Block for Multi-Omics ML and Advanced Analysis
# ==============================================================================
if __name__ == "__main__":
    print("Starting Multi-Omics Machine Learning and Advanced Analysis Phase...")

    # Initialize variables to None or empty before the try block
    X_multiomics_features = None
    y_multiomics_labels = None
    pheno_aligned_df = None
    data_loaded_successfully = False

    # --- Load preprocessed data from files ---
    # These files are expected to be generated by the "Phase 7: Data Prep" Canvas.
    try:
        print(f"Attempting to load preprocessed data from '{output_dir}'...")
        X_multiomics_features = pd.read_pickle(os.path.join(output_dir, 'X_multiomics_features.pkl'))
        y_multiomics_labels = pd.read_pickle(os.path.join(output_dir, 'y_multiomics_labels.pkl'))
        pheno_aligned_df = pd.read_pickle(os.path.join(output_dir, 'pheno_aligned_df.pkl'))
        print("Preprocessed data loaded successfully.")
        data_loaded_successfully = True

    except FileNotFoundError as e:
        print(f"Error loading data: {e}. One or more required .pkl files were not found.")
        print(f"Please ensure you have run the 'Phase 7: Multi-Omics Data Preparation & EDA' Canvas first to generate these files in '{output_dir}'.")
    except Exception as e:
        print(f"An unexpected error occurred during data loading: {e}")
        import traceback
        traceback.print_exc()

    # Proceed with analysis only if data was successfully loaded
    if data_loaded_successfully and \
       X_multiomics_features is not None and not X_multiomics_features.empty and \
       y_multiomics_labels is not None and not y_multiomics_labels.empty and \
       pheno_aligned_df is not None and not pheno_aligned_df.empty:

        # Filter out classes with only one sample before encoding and splitting
        class_counts = y_multiomics_labels.value_counts()
        classes_to_keep = class_counts[class_counts >= 2].index
        
        if len(classes_to_keep) < 2:
            print("After filtering, fewer than 2 cancer types remain with at least 2 samples. Skipping ML models.")
        else:
            filtered_samples_mask = y_multiomics_labels.isin(classes_to_keep)
            X_filtered = X_multiomics_features[filtered_samples_mask]
            y_filtered = y_multiomics_labels[filtered_samples_mask]
            pheno_filtered = pheno_aligned_df[filtered_samples_mask] # Keep phenotype aligned for survival analysis

            # Remove duplicate columns from X_filtered before training
            initial_cols_X_filtered = X_filtered.shape[1]
            X_filtered = X_filtered.loc[:, ~X_filtered.columns.duplicated()]
            if X_filtered.shape[1] < initial_cols_X_filtered:
                print(f"Removed {initial_cols_X_filtered - X_filtered.shape[1]} duplicate columns from X_filtered.")


            print(f"Original samples: {len(y_multiomics_labels)}, Samples after filtering single-member classes: {len(y_filtered)}")
            print(f"Original unique classes: {y_filtered.nunique()}, Classes after filtering: {y_filtered.nunique()}")

            le = LabelEncoder()
            y_encoded = le.fit_transform(y_filtered)
            class_names = le.classes_
            print(f"Encoded {len(class_names)} cancer types for ML.")

            # --- Dimensionality Reduction: PCA ---
            n_components_pca = min(50, X_filtered.shape[1])
            if X_filtered.shape[1] > 1 and n_components_pca > 1:
                pca = PCA(n_components=n_components_pca, random_state=42)
                X_pca = pca.fit_transform(X_filtered)
                # Ensure X_pca has an index for consistency with pheno_filtered
                X_pca_df = pd.DataFrame(X_pca, index=X_filtered.index)
                print(f"\nPCA reduced data shape: {X_pca.shape}")
                print(f"PCA explained variance ratio (first {n_components_pca} components): {np.sum(pca.explained_variance_ratio_):.2f}")
            else:
                print("Not enough features for PCA. Skipping PCA.")
                X_pca = X_filtered.values # Keep as numpy array if PCA skipped
                X_pca_df = X_filtered # Keep as DataFrame if PCA skipped

            # --- Dimensionality Reduction: UMAP ---
            if X_pca.shape[1] >= 2: # Use X_pca (numpy array) directly for UMAP
                reducer = umap.UMAP(n_components=2, random_state=42, metric='euclidean')
                X_umap = reducer.fit_transform(X_pca)
                print(f"\nUMAP reduced data shape: {X_umap.shape}")

                plt.figure(figsize=(12, 10))
                sns.scatterplot(x=X_umap[:, 0], y=X_umap[:, 1], hue=y_filtered, palette='tab20',
                                s=80, alpha=0.85, edgecolor='black', linewidth=0.7)
                plt.title("UMAP Projection of Multi-Omics Data (Integrated Features)", fontsize=18, weight='bold', color='darkblue')
                plt.xlabel("UMAP Component 1", fontsize=14, color='dimgray')
                plt.ylabel("UMAP Component 2", fontsize=14, color='dimgray')
                plt.xticks(fontsize=12)
                plt.yticks(fontsize=12)
                plt.legend(title="Cancer Type", bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0, fontsize='small', title_fontsize=12)
                plt.grid(True, linestyle=':', alpha=0.5)
                plt.tight_layout(rect=[0, 0, 0.85, 1])
                plt.savefig(os.path.join(output_dir, "multiomics_umap_projection_advanced.png"), dpi=300, bbox_inches='tight')
                plt.close()
                print("UMAP projection plot saved.")
            else:
                print("Not enough dimensions after PCA for UMAP. Skipping UMAP plot.")

            # Perform Clustering and Survival Analysis
            # Pass the PCA-reduced data (NumPy array) and original phenotype DataFrame
            perform_clustering_and_survival_analysis(X_pca, y_filtered, pheno_filtered, output_dir, n_clusters=3) # You can adjust n_clusters


            # --- Machine Learning: Train and Evaluate Multiple Classifiers ---
            if X_filtered.shape[0] > 1 and X_filtered.shape[1] > 0 and len(np.unique(y_encoded)) > 1:
                X_train, X_test, y_train, y_test = train_test_split(
                    X_filtered, y_encoded, test_size=0.2, stratify=y_encoded, random_state=42
                )
                print(f"\nTraining data shape: {X_train.shape}, Test data shape: {X_test.shape}")

                classifiers = {
                    "Random Forest": RandomForestClassifier(n_estimators=200, random_state=42, n_jobs=-1),
                    "Support Vector Machine (SVC)": SVC(random_state=42, probability=True),
                    "LightGBM Classifier": lgb.LGBMClassifier(random_state=42, n_jobs=-1),
                    "MLP Classifier": MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=500, random_state=42, verbose=True)
                }

                for name, clf in classifiers.items():
                    print(f"\n--- Training {name} ---")
                    try:
                        clf.fit(X_train, y_train)
                        y_pred = clf.predict(X_test)

                        # Get unique labels present in y_test and y_pred
                        unique_labels_in_test_pred = np.unique(np.concatenate((y_test, y_pred)))
                        # Filter class_names to match the unique labels present
                        display_class_names = [class_names[i] for i in unique_labels_in_test_pred]

                        print(f"📊 Classification Report for {name}:")
                        print(classification_report(y_test, y_pred, labels=unique_labels_in_test_pred, target_names=display_class_names, zero_division=0))

                        cm = confusion_matrix(y_test, y_pred, normalize='true', labels=unique_labels_in_test_pred)
                        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_class_names)

                        fig, ax = plt.subplots(figsize=(16, 16))
                        disp.plot(cmap='Blues', ax=ax, colorbar=False, xticks_rotation='vertical')
                        plt.title(f"Normalized Confusion Matrix ({name})", fontsize=18, weight='bold', color='darkblue')
                        plt.xlabel("Predicted Label", fontsize=14, color='dimgray')
                        plt.ylabel("True Label", fontsize=14, color='dimgray')
                        plt.xticks(fontsize=10)
                        plt.yticks(fontsize=10)
                        plt.tight_layout()
                        plt.savefig(os.path.join(output_dir, f"multiomics_confusion_matrix_{name.replace(' ', '_').replace('(', '').replace(')', '')}.png"), dpi=300, bbox_inches='tight')
                        plt.close()
                        print(f"Normalized Confusion Matrix plot for {name} saved.")

                        if name == "MLP Classifier" and hasattr(clf, 'loss_curve_'):
                            plt.figure(figsize=(10, 6))
                            plt.plot(clf.loss_curve_)
                            plt.title("MLP Classifier Training Loss Curve", fontsize=18, weight='bold', color='darkblue')
                            plt.xlabel("Iteration", fontsize=14, color='dimgray')
                            plt.ylabel("Loss", fontsize=14, color='dimgray')
                            plt.grid(True, linestyle='--', alpha=0.6)
                            plt.tight_layout()
                            plt.savefig(os.path.join(output_dir, "mlp_classifier_loss_curve.png"), dpi=300, bbox_inches='tight')
                            plt.close()
                            print("MLP Classifier loss curve plot saved.")

                    except Exception as e:
                        print(f"Error training or evaluating {name}: {e}")
                        import traceback
                        traceback.print_exc()
            else:
                print("Insufficient data or classes for Machine Learning. Skipping model training and evaluation.")
    else:
        print("Multi-omics data was not loaded successfully. Skipping ML analysis steps.")

    print("\nMulti-Omics Machine Learning & Advanced Analysis Phase complete.")
