In [None]:
!pip3 install blitzgsea

In [None]:
# from pyspark.sql import SparkSession
# from pyspark.sql import functions as f
# from pyspark.sql.functions import collect_list, concat_ws, col, when, udf, row_number, sum as spark_sum, max as spark_max, create_map, lit, min as spark_min
# from pyspark.sql.types import DoubleType
# from pyspark.sql import Window
# from itertools import chain
# from pyspark.sql import DataFrame
# from pyspark.sql import Row
# import pandas as pd
# import os
# import numpy as np
# from sklearn.metrics import jaccard_score
# from scipy.spatial.distance import pdist, squareform
# from scipy.stats import spearmanr, kendalltau
# import gcsfs
# from pathlib import Path
# import blitzgsea as blitz
# from functools import reduce
# from sklearn.metrics import roc_auc_score, precision_recall_curve
# from sklearn.utils import resample
# import statsmodels.api as sm
# from scipy.stats import ttest_1samp

In [13]:
spark = SparkSession.builder.getOrCreate()

## Prepare input for blitzGSEA

For each diseaseId with >= 500 genes:
- take columns approvedSymbol, overallScore 
- sort by overallScore
- convert each column name 'overallScore': '1', 'approvedSymbol': '0'
- and saves each partition named after diseaseID into one parquet directory 

In [31]:
import os
import gcsfs
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, countDistinct
from pyspark.sql import functions as F

In [None]:
# Input and output GCS paths
INPUT_PATH  = "gs://ot-team/polina/pathwaganda/processed_diseases/oncology"
OUTPUT_BASE = "gs://ot-team/polina/pathwaganda/input_4_gsea/oncology"

# ─── Initialize GCS filesystem and check input ────────────────────────────────

fs = gcsfs.GCSFileSystem()
if not fs.exists(INPUT_PATH):
    raise FileNotFoundError(f"Input path not found: {INPUT_PATH}")

# ─── Read all Parquet files from GCS ──────────────────────────────────────────

df = spark.read.parquet(INPUT_PATH)


                                                                                

[Stage 1391:>                                                       (0 + 2) / 2]

In [47]:
MIN_TARGETS = 500

# only pull diseaseIds whose countDistinct(approvedSymbol) >= MIN_TARGETS
valid_diseases = (
    df
    .groupBy("diseaseId")
    .agg(countDistinct("approvedSymbol").alias("uniqueCount"))
    .filter(col("uniqueCount") >= MIN_TARGETS)
    .select("diseaseId")
)


In [48]:
# 2) Inner‐join back to keep only those diseases
df_filtered = df.join(valid_diseases, on="diseaseId", how="inner")

In [49]:
# 1) Select & rename once up‑front:
df2 = (
    df_filtered
    .select("diseaseId", "approvedSymbol", "overallScore")
    .withColumnRenamed("approvedSymbol", "0")
    .withColumnRenamed("overallScore",     "1")
)


In [50]:
# 2) Repartition by diseaseId and sort within each partition **descending** by score:
df2 = (
    df2
    .repartition("diseaseId")
    # .sortWithinPartitions(col("1").desc())
)

In [None]:
# 3) Write out in one go, partitioned by diseaseId:
df2.write \
   .mode("overwrite") \
   .partitionBy("diseaseId") \
   .parquet(OUTPUT_BASE)

                                                                                

                                                                                

In [None]:
# spark.read.parquet("gs://ot-team/polina/pathwaganda/input_4_gsea/non_oncology/diseaseId=EFO_0000195").count()

                                                                                

815

[Stage 1391:>                                                       (0 + 2) / 2]

In [10]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
import blitzgsea as blitz

def blitzgsea_with_fileprep_spark(
    input_gcs_dir: str,
    output_gcs_dir: str,
    libraries: list = None,
    fdr_cutoff: float = 0.1
) -> None:
    """
    Perform scalable GSEA on a Parquet dataset with PySpark, grouping by diseaseId and writing partitioned results.

    Args:
        input_gcs_dir (str): GCS path to the Parquet input (file or directory).
        output_gcs_dir (str): GCS path for output library folders.
        libraries (list of str, optional): List of blitz Enrichr library names. If None, fetches all available libraries.
        fdr_cutoff (float): FDR threshold for significant terms.
    """
    # Initialize Spark with GCS connector settings
    spark = (
        SparkSession
        .builder
        .appName("BlitzGSEA")
        # ensure Google Cloud Storage support
        .config("spark.hadoop.fs.gs.impl", "com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem")
        .getOrCreate()
    )

    # Read Parquet dataset directly from GCS
    df = spark.read.parquet(input_gcs_dir)

    # Select and rename columns for GSEA
    df2 = df.select(
        col("diseaseId"),
        col("targetId"),
        col("approvedSymbol").alias("0"),
        col("overallScore").alias("1")
    )

    # Determine libraries to process
    if libraries is None:
        libraries = blitz.enrichr.get_libraries()

    # Pre-load all library sets
    library_sets = {lib: blitz.enrichr.get_library(lib) for lib in libraries}

    # Collect distinct disease IDs
    disease_ids = [row.diseaseId for row in df2.select("diseaseId").distinct().collect()]

    for disease_id in disease_ids:
        # Filter, sort, and convert to pandas for blitz.gsea
        pdf = (
            df2
            .filter(col("diseaseId") == disease_id)
            .orderBy(col("1").desc())
            .select("0", "1")
            .toPandas()
        )
        if pdf.empty:
            continue

        for lib in libraries:
            library = library_sets[lib]

            # Run GSEA
            result = blitz.gsea(pdf, library, processes=4)
            sig = result[result.get("FDR", result.get("fdr", result.get("fdr_qvalue"))) <= fdr_cutoff].copy()
            if sig.empty:
                continue

            # Propagation step
            sig["propagated_edge"] = sig.index.map(lambda term: ",".join(library.get(term, [])))
            sig["diseaseId"] = disease_id

            # Convert back to Spark and write to GCS as Parquet
            sdf = spark.createDataFrame(
                sig.reset_index().rename(columns={"index": "Term"})
            )
            output_path = f"{output_gcs_dir.rstrip('/')}/{lib}/{disease_id}"
            sdf.write.mode("overwrite").parquet(output_path)

    spark.stop()


In [11]:
input_gcs_dir = "gs://ot-team/polina/pathwaganda/processed_diseases/non_oncology"
output_gcs_dir = "gs://ot-team/polina/pathwaganda/gsea_results/non_oncology"

libraries = ["KEGG_2021_Human"
            # "Reactome_Pathways_2024",
            # "WikiPathways_2024_Human", 
            # "GO_Biological_Process_2025"]
            ]

blitzgsea_with_fileprep_spark(input_gcs_dir, output_gcs_dir, libraries)

25/07/22 11:31:16 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.
                                                                                

ZeroDivisionError: float division by zero