## Prepare input for blitzGSEA

For each diseaseId with at least 500 distinct genes (approvedSymbol values):
- keep the columns approvedSymbol and overallScore
- sort by overallScore in descending order
- rename the columns: approvedSymbol → '0', overallScore → '1'
- save each disease as its own partition (diseaseId=<id>) inside a single Parquet output directory

In [1]:
import os
import gcsfs
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, countDistinct
from pyspark.sql import functions as F

ModuleNotFoundError: No module named 'gcsfs'

In [2]:
# Reuse the already-running SparkSession if there is one; otherwise create it.
spark = (
    SparkSession
    .builder
    .getOrCreate()
)

NameError: name 'SparkSession' is not defined

In [None]:
# ─── GCS locations ─────────────────────────────────────────────────────────────
# Source: processed oncology disease data, stored as Parquet.
INPUT_PATH = "gs://ot-team/polina/pathwaganda/processed_diseases/oncology"
# Destination: Parquet dataset that will hold one diseaseId=<id> folder per disease.
OUTPUT_BASE = "gs://ot-team/polina/pathwaganda/input_4_gsea/oncology"

# ─── Verify the input exists on GCS, then load every Parquet file under it ─────
fs = gcsfs.GCSFileSystem()
if fs.exists(INPUT_PATH):
    df = spark.read.parquet(INPUT_PATH)
else:
    # Fail fast with a clear message instead of a lower-level Spark error.
    raise FileNotFoundError(f"Input path not found: {INPUT_PATH}")


                                                                                

[Stage 1391:>                                                       (0 + 2) / 2]

In [47]:
# A disease is kept only if it has at least this many distinct approvedSymbol values.
MIN_TARGETS = 500

# One-column DataFrame of the diseaseIds that meet the MIN_TARGETS threshold.
valid_diseases = (
    df.groupBy("diseaseId")
    .agg(countDistinct("approvedSymbol").alias("uniqueCount"))
    .where(F.col("uniqueCount") >= MIN_TARGETS)
    .select("diseaseId")
)


In [48]:
# Inner-join the full table against the eligible disease list so that only
# rows for diseases present in valid_diseases remain.
df_filtered = df.join(valid_diseases, on=["diseaseId"], how="inner")

In [49]:
# Keep only the columns needed downstream. Rename them in the same pass:
# approvedSymbol -> "0" and overallScore -> "1", as the blitzGSEA input spec requires.
df2 = df_filtered.select(
    "diseaseId",
    col("approvedSymbol").alias("0"),
    col("overallScore").alias("1"),
)


In [50]:
# Put all rows of a disease into one partition, then order them by score,
# highest first, as the blitzGSEA input spec requires.
#
# Sort key rationale:
# - diseaseId comes first so the ordering already satisfies the
#   partitionBy("diseaseId") write. Otherwise Spark's file writer may
#   insert its own sort on diseaseId alone, which would discard the
#   score order.
# - The gene symbol ("0") breaks score ties, so the file contents are
#   deterministic across runs.
df2 = (
    df2
    .repartition("diseaseId")
    .sortWithinPartitions("diseaseId", col("1").desc(), col("0"))
)

In [None]:
# Write everything as one Parquet dataset under OUTPUT_BASE. There is one
# diseaseId=<id> sub-directory per disease, and any earlier output there is replaced.
(
    df2.write
    .partitionBy("diseaseId")
    .mode("overwrite")
    .parquet(OUTPUT_BASE)
)

                                                                                

                                                                                

In [None]:
# spark.read.parquet("gs://ot-team/polina/pathwaganda/input_4_gsea/non_oncology/diseaseId=EFO_0000195").count()

                                                                                

815

[Stage 1391:>                                                       (0 + 2) / 2]