! This code was pasted from IBD_pathway_to_cell in similarity_mvp and has not been run !

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql.functions import collect_list, concat_ws, col, when, udf, row_number, sum as spark_sum, max as spark_max, create_map, lit, min as spark_min
from pyspark.sql.types import DoubleType
from pyspark.sql import Window
from itertools import chain
from pyspark.sql import DataFrame
from pyspark.sql import Row
import pandas as pd
import os
import numpy as np
from sklearn.metrics import jaccard_score
from scipy.spatial.distance import pdist, squareform
from scipy.stats import spearmanr, kendalltau
import gcsfs
from pathlib import Path
import blitzgsea as blitz
from functools import reduce
from sklearn.metrics import roc_auc_score, precision_recall_curve
from sklearn.utils import resample
import statsmodels.api as sm
from scipy.stats import ttest_1samp

In [None]:
# Obtain the notebook's Spark session (reuses an active session if one exists, otherwise creates one).
spark = SparkSession.builder.getOrCreate()

In [None]:
def perform_gsea_propagation_missing_dir_gcs(input_gcs_dir, output_gcs_dir, libraries, file_suffixes, pval_cutoff=0.05):
    """
    Run GSEA for every matching input file against every library and store the results on GCS.

    For each input CSV whose base name contains one of ``file_suffixes``, rows with a
    non-numeric ``overallScore`` are dropped, the remaining genes are ranked by score
    (descending) and passed to ``blitz.gsea`` once per library. Results with
    ``pval <= pval_cutoff`` are written to ``<output_gcs_dir>/<library>/`` together with a
    ``propagated_edge`` column holding the comma-joined genes of each term. A summary
    CSV (``gsea_summary.csv``) is written to the output directory at the end.

    Args:
        input_gcs_dir (str): GCS path to the input directory (e.g., "gs://bucket-name/input-dir/").
        output_gcs_dir (str): GCS path to the output directory (e.g., "gs://bucket-name/output-dir/").
            A trailing slash is tolerated.
        libraries (list): List of library names to use for GSEA.
        file_suffixes (list): Substrings a file's base name must contain to be processed
            (e.g., "_gsea_in").
        pval_cutoff (float): P-value cutoff for filtering significant results.

    Returns:
        None

    Raises:
        ValueError: If either directory does not start with "gs://".

    Note:
        The summary only contains file/library pairs that produced at least one
        significant result; pairs without significant results and files without any
        valid ``overallScore`` are skipped.
    """
    # Validate the paths before creating the GCS client so bad arguments fail fast.
    if not input_gcs_dir.startswith("gs://") or not output_gcs_dir.startswith("gs://"):
        raise ValueError("Both input and output directories must be GCS paths starting with 'gs://'.")

    # Initialize GCS filesystem
    fs = gcsfs.GCSFileSystem()

    # Strip trailing slashes so joined output paths never contain an empty "//" segment.
    output_root = output_gcs_dir.rstrip("/")

    # Match suffixes against the file name only, so a directory whose name happens to
    # contain a suffix does not select every file beneath it.
    input_files = [
        path for path in fs.ls(input_gcs_dir)
        if any(suffix in os.path.basename(path) for suffix in file_suffixes)
    ]

    summary_data = []

    # Per-library data does not depend on the input file, so it is loaded once on first
    # use and reused: lib name -> (output folder, gene-set library, all genes in library).
    library_cache = {}

    for input_file_path in input_files:
        file_name = os.path.basename(input_file_path)

        # Read the file
        print(f"Processing file: {input_file_path}")
        with fs.open(input_file_path, 'r') as input_handle:
            df = pd.read_csv(input_handle)

        # Count rows before removing NaN values
        initial_row_count = df.shape[0]

        # Convert 'overallScore' to numeric; values that cannot be parsed become NaN
        df['overallScore'] = pd.to_numeric(df['overallScore'], errors='coerce')

        # Drop rows where 'overallScore' is NaN
        df_cleaned = df.dropna(subset=['overallScore'])

        # Count rows after removing NaN values
        final_row_count = df_cleaned.shape[0]
        print(f"File: {file_name}, Initial Rows: {initial_row_count}, Rows after filtering: {final_row_count}")

        # Skip if no valid rows remain
        if final_row_count == 0:
            print(f"No valid data after filtering for file: {input_file_path}")
            continue

        # Rename columns: 'approvedSymbol' -> '0' (gene), 'overallScore' -> '1' (ranking score)
        df_cleaned = df_cleaned.rename(columns={'overallScore': '1', 'approvedSymbol': '0'})

        # Extract gene symbols
        gene_symbols = set(df_cleaned['0'].unique())

        # The ranked signature is identical for every library, so build it once per file.
        signature = df_cleaned[['0', '1']].sort_values(by='1', ascending=False)

        # Perform GSEA for each library
        for lib in libraries:
            if lib not in library_cache:
                # Create a unique subfolder for this library in the output directory
                lib_folder = f"{output_root}/{lib}"
                fs.mkdirs(lib_folder, exist_ok=True)
                print(f"Library folder created: {lib_folder}")

                # Get the library from enrichr and aggregate all of its unique genes
                library = blitz.enrichr.get_library(lib)
                library_genes = set().union(*library.values())
                library_cache[lib] = (lib_folder, library, library_genes)

            lib_folder, library, library_genes = library_cache[lib]

            # Input genes that do not occur in any gene set of this library
            num_missing_genes = len(gene_symbols - library_genes)

            # Perform GSEA
            result = blitz.gsea(signature, library, processes=4)

            # Apply the p-value cutoff (copy so the new column is not added to a view)
            result_sign = result[result['pval'] <= pval_cutoff].copy()
            print(f"Significant results after pval filtering: {result_sign.shape[0]}")

            # Skip if no significant results remain
            if result_sign.empty:
                print(f"No significant results for file: {input_file_path} with library {lib}")
                continue

            # Assumes the result index holds the term names used as keys of `library`;
            # terms not found in the library get an empty gene string.
            result_sign['propagated_edge'] = result_sign.index.map(
                lambda term: ','.join(library.get(term, []))
            )

            # Save the results to the library-specific subfolder
            output_file_name = f"{os.path.splitext(file_name)[0]}_gsea_{lib}_pval{pval_cutoff}.csv"
            output_file_path = f"{lib_folder}/{output_file_name}"
            with fs.open(output_file_path, 'w') as f_out:
                result_sign.to_csv(f_out, index=True)
            print(f"Results saved to {output_file_path}")

            # Summarize the results
            summary_data.append({
                'file': file_name,
                'library': lib,
                'valid_targets': final_row_count,
                'initial_targets': initial_row_count,
                'missing_genes_count': num_missing_genes,
            })

    # Save summary data
    summary_df = pd.DataFrame(summary_data)
    summary_output_path = f"{output_root}/gsea_summary.csv"
    with fs.open(summary_output_path, 'w') as summary_file:
        summary_df.to_csv(summary_file, index=False)
    print(f"Summary results saved to {summary_output_path}")

In [None]:
# GCS locations: input target-disease association files and the GSEA output root.
input_gcs_dir = "gs://ot-team/polina/pathway_propagation_validation_v2/target_disease_as"
output_gcs_dir = "gs://ot-team/polina/pathway_propagation_validation_v2/gsea_output"

# Library names handed to blitz.enrichr.get_library, one GSEA run per library.
libraries = [
    "KEGG_2021_Human",
    "Reactome_Pathways_2024",
    "WikiPathways_2024_Human",
    "GO_Biological_Process_2023",
]
pval_cutoff = 0.05
# Files are selected by substring match; note "_ge_mm" is itself contained in "_ge_mm_som".
file_suffixes = ["_ge_mm", "_ge_mm_som"]

perform_gsea_propagation_missing_dir_gcs(
    input_gcs_dir=input_gcs_dir,
    output_gcs_dir=output_gcs_dir,
    libraries=libraries,
    file_suffixes=file_suffixes,
    pval_cutoff=pval_cutoff,
)