# Save target-pathway lists as spark dfs and filter out non-gene targets

Input folder: Pathwaganda/data/GSEA_output

In [1]:
from pyspark.sql import SparkSession
import os
import shutil
from pyspark.sql.functions import split, explode, collect_list, col, concat_ws, input_file_name, regexp_extract, lit
from pyspark.sql import functions as F
from pyspark.sql import Row

In [2]:
spark = SparkSession.builder.getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/08/21 15:37:43 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/08/21 15:37:43 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [9]:
df_test = spark.read.parquet("/Users/polina/Pathwaganda/data/GSEA_output/Reactome_Pathways_2025_diy/diseaseId=EFO_0000094")

In [10]:
df_test.show(5, truncate=False)

+--------------------------------------------------------------------------+-------------+-------------------+------------------+---------------------+-------------------+-------------------+------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Filter by FDR, add hierarchies

1. Read all files in Reactome_Pathways_2025_diy into a single DataFrame and add a new `diseaseId` column (taken from each file's path)
2. Filter by an FDR cutoff of 0.1
3. Add the parent pathway (and the hierarchy level) to the table

In [None]:
def load_reactome_pathways_with_hierarchy(base_dir, hierarchy_path, fdr_cutoff=0.1):
    """
    Load GSEA results for all diseases, keep significant pathways and attach
    each pathway's parent and depth in the pathway hierarchy.

    Args:
        base_dir: Folder that is scanned recursively for parquet files. The
            ``diseaseId`` column is extracted from the ``diseaseId=<value>``
            component of each file's path.
        hierarchy_path: Tab-separated file without header; the first column is
            read as ``parentId`` and the second as ``childId``.
        fdr_cutoff: Rows with ``fdr`` greater than this value are dropped.

    Returns:
        Spark DataFrame with the GSEA columns plus ``diseaseId``, ``parentId``,
        ``childId`` and ``hierLevel``. The hierarchy is attached with a left
        join on ``ID == childId``, so pathways that never occur as a child in
        the hierarchy file keep NULL in the three hierarchy columns.

    Notes:
        - ``hierLevel`` is the number of parent links followed upwards from the
          child; the walk stops after 51 steps as a guard against cycles.
        - When a child is listed with several parents, the last row read decides
          which parent the depth walk follows, while the join still matches
          every (parentId, childId) row of that child.
        - Relies on the notebook-level ``spark`` session.
    """
    # Read every parquet file below base_dir and recover diseaseId from its path
    gsea_df = (
        spark.read.option("recursiveFileLookup", "true")
        .parquet(base_dir)
        .withColumn(
            "diseaseId",
            regexp_extract(input_file_name(), r"diseaseId=([^/]+)", 1),
        )
    )

    # Keep only pathways passing the FDR cutoff
    significant_df = gsea_df.filter(col("fdr") <= fdr_cutoff)

    # Parent/child pairs of the pathway hierarchy
    raw_hierarchy_df = (
        spark.read.option("delimiter", "\t")
        .csv(hierarchy_path, header=False)
        .withColumnRenamed("_c0", "parentId")
        .withColumnRenamed("_c1", "childId")
    )

    # The hierarchy is small enough to be processed on the driver
    edge_rows = raw_hierarchy_df.collect()
    parent_of = {}
    for edge in edge_rows:
        parent_of[edge["childId"]] = edge["parentId"]

    def depth_of(pathway_id):
        """Count parent links above ``pathway_id`` (capped at 51 steps)."""
        depth = 0
        node = pathway_id
        while depth <= 50:  # safety cap for cycles
            parent = parent_of.get(node)
            if parent is None:  # unknown node or no parent recorded
                break
            node = parent
            depth += 1
        return depth

    rows_with_level = []
    for edge in edge_rows:
        rows_with_level.append(
            Row(
                parentId=edge["parentId"],
                childId=edge["childId"],
                hierLevel=depth_of(edge["childId"]),
            )
        )
    hierarchy_df = spark.createDataFrame(rows_with_level)

    # Attach parent and depth to each significant pathway
    return significant_df.join(
        hierarchy_df,
        significant_df["ID"] == hierarchy_df["childId"],
        "left",
    )


In [9]:
base_dir = "/Users/polina/Pathwaganda/data/GSEA_output/Reactome_Pathways_2025_diy"
hierarchy_path = "/Users/polina/Pathwaganda/data/gmt_pathway_files_prep/Reactome/Pathways_hierarchy_relationship.txt"

df_filtered = load_reactome_pathways_with_hierarchy(base_dir, hierarchy_path)

                                                                                

In [16]:
df_filtered.filter(col("hierLevel") > 8).show(10)



+--------------------+------------+-------------------+-------------------+--------------------+------------------+--------------------+------------+--------------------+--------------------+-------------+------------+------------+---------+
|                Term|          ID|                 es|                nes|                pval|             sidak|                 fdr|geneset_size|        leading_edge|     propagated_edge|    diseaseId|    parentId|     childId|hierLevel|
+--------------------+------------+-------------------+-------------------+--------------------+------------------+--------------------+------------+--------------------+--------------------+-------------+------------+------------+---------+
|Formation of HIV-...|R-HSA-167200|-0.6109253369441215|-2.7815364391841655|0.005410225667489277|0.9999213261776687|0.052358961737590665|          17|POLR2K,POLR2C,POL...|CCNT1,CDK7,CDK9,C...|  EFO_0009676|R-HSA-167246|R-HSA-167200|        9|
|Formation of HIV-...|R-HSA-1672

                                                                                

In [17]:
df_filtered.write.mode("overwrite").parquet("/Users/polina/Pathwaganda/data/GSEA-output_filt_hier")

25/08/14 14:21:39 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers


25/08/14 18:08:15 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 987319 ms exceeds timeout 120000 ms
25/08/14 18:08:16 WARN SparkContext: Killing executors is not supported by current scheduler.
25/08/14 18:08:18 ERROR Inbox: Ignoring error
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:53)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:342)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRefByURI(RpcEnv.scala:102)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRef(RpcEnv.scala:110)
	at org.apache.spark.util.RpcUtils$.makeDriverRef(RpcUtils.scala:36)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.driverEndpoint$lzycompute(BlockManagerMasterEndpoint.scala:132)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.org$apache$spark$storage$BlockManagerMasterEndpoint$$

# Target-pathway matrix

Explode targets from propagated_edge and create boolean matrix TxP.

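Hierarchical levels explanation (note: the code and the `tpm_levels/` output folders number these from 0, i.e. level 1 here corresponds to `level=0`, level 2 to `level=1` and level 3 to `level=2`):
- Level 1 - broad propagation: all pathways
- Level 2 - medium propagation: all pathways except the top-level parent pathways
- Level 3 - specific propagation: all pathways except the 1st- and 2nd-level parent pathways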

## Add targetIds

In [None]:
spark.read.parquet("/Users/polina/Pathwaganda/data/GSEA-output_filt_hier").show(5)

+--------------------+-------------+-------------------+------------------+--------------------+--------------------+--------------------+------------+--------------------+--------------------+-----------+-------------+-------------+---------+
|                Term|           ID|                 es|               nes|                pval|               sidak|                 fdr|geneset_size|        leading_edge|     propagated_edge|  diseaseId|     parentId|      childId|hierLevel|
+--------------------+-------------+-------------------+------------------+--------------------+--------------------+--------------------+------------+--------------------+--------------------+-----------+-------------+-------------+---------+
|Extracellular mat...|R-HSA-1474244|0.27986151888660815|4.4332464205552755|9.282466883142604E-6| 0.01783777610591378|2.399827104855134...|         273|SGCG,PCOLCE,ADAMT...|ACAN,ACTA1,ACTA2,...|EFO_0004747|         NULL|         NULL|     NULL|
|Transcriptional r...|R-

In [29]:
target = spark.read.parquet("/Users/polina/Pathwaganda/data/target")

In [42]:
def target_pathway_interm(input_path: str, target_path: str, fdr_threshold: float = 0.1, delimiter: str = ","):
    """
    Explode the propagated gene symbols of each significant pathway and map
    every symbol to target IDs.

    Args:
        input_path: Parquet with at least the columns ``ID``, ``fdr``,
            ``propagated_edge``, ``diseaseId`` and ``hierLevel``.
        target_path: Target parquet with ``id``, ``approvedSymbol`` and
            ``symbolSynonyms`` (an array of structs with a ``label`` field).
        fdr_threshold: Rows with ``fdr`` greater than this value are dropped.
        delimiter: Separator of the symbols inside ``propagated_edge``. It is
            passed to ``F.split``, which interprets it as a regular expression.

    Returns:
        Spark DataFrame with one row per (ID, diseaseId, hierLevel,
        approvedSymbol) and a ``targetId`` column holding the comma-separated,
        sorted target IDs matched by the symbol (via approved symbol or a
        synonym label). Symbols without any match get an empty string.

    Notes:
        Relies on the notebook-level ``spark`` session.
    """
    # --- Read input parquet and explode one row per gene symbol ---
    df = spark.read.parquet(input_path)
    df = df.select("ID", "fdr", "propagated_edge", "diseaseId", "hierLevel")
    df = df.filter(F.col("fdr") <= fdr_threshold)
    df = df.withColumn("propagated_edge_array", F.split(F.col("propagated_edge"), delimiter))
    df = df.withColumn("approvedSymbol", F.explode("propagated_edge_array"))
    df = df.select("ID", "diseaseId", "hierLevel", "approvedSymbol")

    # --- Read target parquet ---
    target_df = spark.read.parquet(target_path).select(
        F.col("id").alias("targetId"),
        "approvedSymbol",
        "symbolSynonyms"
    )

    # --- Explode synonyms into a (targetId, symbol) mapping ---
    synonyms_df = (
        target_df
        .withColumn("syn_struct", F.explode_outer("symbolSynonyms"))
        .withColumn("approvedSymbol", F.col("syn_struct").getField("label"))  # use label
        .select("targetId", "approvedSymbol")
        .filter(F.col("approvedSymbol").isNotNull())
    )

    # --- Union approvedSymbol + synonyms mapping ---
    mapping_df = (
        target_df.select("targetId", "approvedSymbol")
        .unionByName(synonyms_df)
        .dropDuplicates(["targetId", "approvedSymbol"])
    )

    # --- Left join so that unmapped symbols are kept (with no targetId) ---
    joined = df.join(mapping_df, on="approvedSymbol", how="left")

    # --- Aggregate multiple targetIds for the same symbol ---
    # collect_set has no guaranteed element order; sorting makes the joined
    # string identical for identical target sets, so downstream grouping on
    # "targetId" does not split one symbol into several groups.
    result = (
        joined.groupBy("ID", "diseaseId", "hierLevel", "approvedSymbol")
              .agg(F.concat_ws(",", F.sort_array(F.collect_set("targetId"))).alias("targetId"))
    )

    return result

In [43]:
target_pathway_interm_df = target_pathway_interm(
    "/Users/polina/Pathwaganda/data/GSEA-output_filt_hier",
    "/Users/polina/Pathwaganda/data/target",
    fdr_threshold=0.1,
    delimiter=","
)

In [None]:
# Distinct gene symbols that could not be mapped to any targetId (null or empty string).
# Select "ID", "diseaseId", "hierLevel", "approvedSymbol" instead to count unmapped rows.
unmapped = F.col("targetId").isNull() | (F.col("targetId") == "")
empty_targetid = (
    target_pathway_interm_df
    .where(unmapped)
    .select("approvedSymbol")
    .distinct()
)

print(f"Number of distinct rows with empty targetId: {empty_targetid.count()}")


                                                                                

Number of distinct rows with empty targetId: 425


[Stage 80:>                                                         (0 + 1) / 1]

+--------------+
|approvedSymbol|
+--------------+
|          opaE|
|          opaJ|
|       5S rRNA|
| POU5F1 (OCT4)|
|          UL83|
+--------------+
only showing top 5 rows


                                                                                

In [44]:
target_pathway_interm_df.write.mode("overwrite").parquet("/Users/polina/Pathwaganda/data/GSEA-output_by_target_interm")

25/08/18 13:44:37 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
                                                                                

## Save target-pathway matrix

In [51]:
def mtx_prepare_hier(input_path: str, level: int) -> "pyspark.sql.DataFrame":
    """
    Build the target-pathway table for one hierarchy level.

    Reads the intermediate target-pathway parquet, optionally keeps only
    pathways at or below a given depth, drops symbols without a mapped
    ``targetId`` and collapses the pathway IDs of every
    (diseaseId, targetId, approvedSymbol) group into one comma-separated string.
    Nothing is written here; the caller is responsible for saving the result.

    Args:
        input_path (str): Path to the input parquet; must contain ``ID``,
            ``diseaseId``, ``hierLevel``, ``approvedSymbol`` and ``targetId``.
        level (int): Hierarchy filter. If 0 (or negative), keep all rows.
            If n > 0, keep rows where hierLevel >= n; rows whose hierLevel is
            NULL do not satisfy the comparison and are dropped.

    Returns:
        Spark DataFrame with columns ``diseaseId``, ``targetId``,
        ``approvedSymbol`` and ``pathwayIds`` (sorted, comma-separated IDs).

    Notes:
        Relies on the notebook-level ``spark`` session.
    """

    # Read parquet
    df = spark.read.parquet(input_path)

    # Apply hierLevel filter
    if level > 0:
        df = df.filter(F.col("hierLevel") >= level)

    # Filter out null or empty targetId
    df = df.filter(~(F.col("targetId").isNull() | (F.col("targetId") == "")))

    # Group by diseaseId, targetId, approvedSymbol and aggregate IDs.
    # collect_list order depends on partitioning; sort it so repeated runs
    # produce the same pathwayIds string.
    result = (df.groupBy("diseaseId", "targetId", "approvedSymbol")
                .agg(F.concat_ws(",", F.sort_array(F.collect_list("ID"))).alias("pathwayIds")))

    return result

In [54]:
mtx_level_0 = mtx_prepare_hier("/Users/polina/Pathwaganda/data/GSEA-output_by_target_interm", 0)
mtx_level_0.count()

                                                                                

3101968

In [None]:
mtx_level_0.write.mode("overwrite").parquet("/Users/polina/Pathwaganda/data/tpm_levels/0")

25/08/18 14:06:51 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
                                                                                

In [55]:
mtx_level_1 = mtx_prepare_hier("/Users/polina/Pathwaganda/data/GSEA-output_by_target_interm", 1)
mtx_level_1.count()

                                                                                

1786033

In [None]:
mtx_level_1.write.mode("overwrite").parquet("/Users/polina/Pathwaganda/data/tpm_levels/1")

25/08/18 14:07:01 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
                                                                                

In [56]:
mtx_level_2 = mtx_prepare_hier("/Users/polina/Pathwaganda/data/GSEA-output_by_target_interm", 2)
mtx_level_2.count()

                                                                                

1149592

In [None]:
mtx_level_2.write.mode("overwrite").parquet("/Users/polina/Pathwaganda/data/tpm_levels/2")

25/08/18 14:07:12 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
                                                                                

In [3]:
spark.read.parquet("/Users/polina/Pathwaganda/data/tpm_levels/2").show(5)

+-----------+---------------+--------------+------------+
|  diseaseId|       targetId|approvedSymbol|  pathwayIds|
+-----------+---------------+--------------+------------+
|EFO_0000095|ENSG00000004897|         CDC27|R-HSA-212436|
|EFO_0000095|ENSG00000013275|         PSMC4|R-HSA-212436|
|EFO_0000095|ENSG00000047315|        POLR2B|R-HSA-212436|
|EFO_0000095|ENSG00000051180|         RAD51|R-HSA-212436|
|EFO_0000095|ENSG00000060069|         CTDP1|R-HSA-212436|
+-----------+---------------+--------------+------------+
only showing top 5 rows


# Pathway embeddings

## Hierarchical (Poincare ball model)

In [3]:
import os
import math
import sys
import types

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when
from gensim.models.poincare import PoincareModel

# Compatibility shim: register an empty ``numpy.strings`` placeholder, but only
# when the real submodule cannot be imported. Overwriting ``sys.modules``
# unconditionally would replace the genuine module on NumPy versions that ship it.
# NOTE(review): this runs after gensim has been imported, so it can only affect
# imports of ``numpy.strings`` that happen later - confirm it is still needed.
try:
    import numpy.strings  # noqa: F401
except ImportError:
    sys.modules['numpy.strings'] = types.ModuleType('numpy.strings')


spark = SparkSession.builder.getOrCreate()
spark.sparkContext.setLogLevel("ERROR")

In [4]:
def process_hierarchy_diseases(input_file, negative=10, epochs=100, max_hierLevel=None):
    """
    Train one Poincare embedding per disease on its pathway hierarchy edges.

    For every distinct ``diseaseId`` the (parentId, ID) pairs are used as
    training relations of a gensim ``PoincareModel``. Rows without a parent are
    turned into self-edges (parentId = ID). The embedding size is chosen per
    disease as ``max(2, ceil(log2(N)))`` with N the number of distinct nodes.

    Args:
        input_file: Path to input parquet file with ``diseaseId``, ``ID``,
            ``parentId`` (and ``hierLevel`` when ``max_hierLevel`` is given).
        negative: Maximum number of negative samples to use; reduced
            automatically for small graphs.
        epochs: Number of training epochs.
        max_hierLevel: Keep only rows with ``hierLevel <= max_hierLevel``
            (None = no filtering).
            NOTE(review): rows whose hierLevel is NULL do not pass this filter;
            confirm how top-level pathways are encoded in the input.

    Returns:
        Spark DataFrame with ``ID``, ``dim_0`` ... ``dim_k`` and ``diseaseId``.
        Diseases are embedded with different sizes and concatenated, so rows of
        a smaller embedding hold NaN in the higher ``dim_*`` columns. When no
        disease could be embedded an empty DataFrame with ``ID`` and
        ``diseaseId`` is returned.

    Notes:
        Relies on the notebook-level ``spark`` session.
    """

    def calculate_safe_negatives(requested, total_nodes):
        """Cap the requested number of negatives by the size of the graph."""
        max_possible = total_nodes - 2  # Conservative estimate
        safe_neg = min(requested, max(1, max_possible))
        print(f"Negative sampling: Requested {requested}, Safe maximum {max_possible}, Using {safe_neg}")
        return safe_neg

    # 1) Read required columns
    base_cols = ["diseaseId", "ID", "parentId"]
    if max_hierLevel is not None:
        base_cols.append("hierLevel")

    # Keep a handle on the cached frame: df_all may be replaced by a filtered
    # frame below, and unpersist() must be called on the frame that was cached.
    cached_df = spark.read.parquet(input_file).select(*base_cols).cache()
    df_all = cached_df

    all_embeddings = []
    max_attempts = 2

    try:
        # 2) Apply hierarchy level filtering if specified
        if max_hierLevel is not None:
            df_all = df_all.filter(col("hierLevel") <= max_hierLevel)
            print(f"Filtering to hierarchy levels <= {max_hierLevel}")

        # 3) Collect distinct diseases
        disease_ids = [r["diseaseId"] for r in df_all.select("diseaseId").distinct().collect()]

        for disease in disease_ids:
            print(f"\nProcessing disease: {disease}")

            # Filter to one disease
            df = df_all.filter(col("diseaseId") == disease)

            # Fill null parents with self
            df = df.withColumn(
                "parentId",
                when(col("parentId").isNull(), col("ID")).otherwise(col("parentId"))
            )

            # Count unique nodes
            N = df.selectExpr("ID as node").union(
                df.selectExpr("parentId as node")
            ).distinct().count()

            dims = max(2, math.ceil(math.log2(N)))
            print(f"{disease} - Total distinct nodes N: {N}; Chosen d: {dims}")

            # Skip too small graphs
            if N < 3:
                print(f"Skipping {disease}: too few nodes ({N})")
                continue

            # Extract edges (parentId → ID)
            # NOTE(review): gensim documents relations as (u, v) pairs meaning
            # "u is-a v"; here the pair is (parent, child) - confirm the direction.
            edges = (
                df.select("parentId", "ID")
                .dropDuplicates()
                .localCheckpoint()
                .toPandas()
                .values.tolist()
            )

            neg = calculate_safe_negatives(negative, N)

            # Train with automatic retry logic. ``model`` is reset for every
            # disease so that a failed training can never fall back to the model
            # of a previously processed disease.
            model = None
            for attempt in range(max_attempts):
                try:
                    candidate = PoincareModel(edges, negative=neg, size=dims)
                    candidate.train(epochs=epochs)
                except ValueError as e:
                    if attempt == max_attempts - 1:
                        print(f"Failed after {max_attempts} attempts for {disease}: {str(e)}")
                        break
                    print(f"Attempt {attempt + 1} failed: {str(e)}")
                    neg = max(1, neg - 2)
                    print(f"Retrying with {neg} negatives")
                else:
                    model = candidate
                    break

            if model is None:
                continue  # Skip to next disease if all attempts failed

            # Dump embeddings
            emb = [(key, *model.kv[key]) for key in model.kv.index_to_key]
            pdf = pd.DataFrame(emb, columns=["ID"] + [f"dim_{i}" for i in range(dims)])
            pdf["diseaseId"] = disease

            all_embeddings.append(pdf)
    finally:
        cached_df.unpersist()

    # Combine all into one Spark DataFrame
    if all_embeddings:
        final_pdf = pd.concat(all_embeddings, ignore_index=True)
        final_sdf = spark.createDataFrame(final_pdf)
    else:
        print("No embeddings generated.")
        # An explicit schema is required: Spark cannot infer one from no rows
        final_sdf = spark.createDataFrame([], schema="ID string, diseaseId string")

    return final_sdf

In [5]:
pem_level0 = process_hierarchy_diseases("/Users/polina/Pathwaganda/data/GSEA-output_filt_hier")

                                                                                


Processing disease: EFO_0000701
EFO_0000701 - Total distinct nodes N: 59; Chosen d: 6
Negative sampling: Requested 10, Safe maximum 57, Using 10

Processing disease: EFO_0004872
EFO_0004872 - Total distinct nodes N: 61; Chosen d: 6
Negative sampling: Requested 10, Safe maximum 59, Using 10

Processing disease: EFO_0005803
EFO_0005803 - Total distinct nodes N: 367; Chosen d: 9
Negative sampling: Requested 10, Safe maximum 365, Using 10

Processing disease: MONDO_0002715
MONDO_0002715 - Total distinct nodes N: 478; Chosen d: 9
Negative sampling: Requested 10, Safe maximum 476, Using 10

Processing disease: GO_0008150
GO_0008150 - Total distinct nodes N: 24; Chosen d: 5
Negative sampling: Requested 10, Safe maximum 22, Using 10

Processing disease: HP_0012638
HP_0012638 - Total distinct nodes N: 48; Chosen d: 6
Negative sampling: Requested 10, Safe maximum 46, Using 10

Processing disease: MONDO_0023370
MONDO_0023370 - Total distinct nodes N: 493; Chosen d: 9
Negative sampling: Requested

In [6]:
pem_level0.write.mode("overwrite").parquet("/Users/polina/Pathwaganda/data/pem_levels/0")

                                                                                

In [7]:
pem_level1 = process_hierarchy_diseases("/Users/polina/Pathwaganda/data/GSEA-output_filt_hier", 
                                        max_hierLevel=1)

Filtering to hierarchy levels <= 1

Processing disease: EFO_0000701
EFO_0000701 - Total distinct nodes N: 25; Chosen d: 5
Negative sampling: Requested 10, Safe maximum 23, Using 10

Processing disease: EFO_0004872
EFO_0004872 - Total distinct nodes N: 21; Chosen d: 5
Negative sampling: Requested 10, Safe maximum 19, Using 10

Processing disease: EFO_0005803
EFO_0005803 - Total distinct nodes N: 56; Chosen d: 6
Negative sampling: Requested 10, Safe maximum 54, Using 10

Processing disease: MONDO_0002715
MONDO_0002715 - Total distinct nodes N: 66; Chosen d: 7
Negative sampling: Requested 10, Safe maximum 64, Using 10

Processing disease: GO_0008150
GO_0008150 - Total distinct nodes N: 14; Chosen d: 4
Negative sampling: Requested 10, Safe maximum 12, Using 10

Processing disease: HP_0012638
HP_0012638 - Total distinct nodes N: 15; Chosen d: 4
Negative sampling: Requested 10, Safe maximum 13, Using 10

Processing disease: MONDO_0023370
MONDO_0023370 - Total distinct nodes N: 76; Chosen d: 

In [8]:
pem_level1.write.mode("overwrite").parquet("/Users/polina/Pathwaganda/data/pem_levels/1")

In [9]:
pem_level2 = process_hierarchy_diseases("/Users/polina/Pathwaganda/data/GSEA-output_filt_hier", 
                                        max_hierLevel=2)

Filtering to hierarchy levels <= 2

Processing disease: EFO_0000701
EFO_0000701 - Total distinct nodes N: 39; Chosen d: 6
Negative sampling: Requested 10, Safe maximum 37, Using 10

Processing disease: EFO_0004872
EFO_0004872 - Total distinct nodes N: 40; Chosen d: 6
Negative sampling: Requested 10, Safe maximum 38, Using 10

Processing disease: EFO_0005803
EFO_0005803 - Total distinct nodes N: 145; Chosen d: 8
Negative sampling: Requested 10, Safe maximum 143, Using 10

Processing disease: MONDO_0002715
MONDO_0002715 - Total distinct nodes N: 169; Chosen d: 8
Negative sampling: Requested 10, Safe maximum 167, Using 10

Processing disease: GO_0008150
GO_0008150 - Total distinct nodes N: 20; Chosen d: 5
Negative sampling: Requested 10, Safe maximum 18, Using 10

Processing disease: HP_0012638
HP_0012638 - Total distinct nodes N: 19; Chosen d: 5
Negative sampling: Requested 10, Safe maximum 17, Using 10

Processing disease: MONDO_0023370
MONDO_0023370 - Total distinct nodes N: 190; Chose

In [10]:
pem_level2.write.mode("overwrite").parquet("/Users/polina/Pathwaganda/data/pem_levels/2")

In [11]:
pem_level2.show(5)

+-------------+--------------------+-------------------+--------------------+-------------------+--------------------+--------------------+-----------+-----+-----+
|           ID|               dim_0|              dim_1|               dim_2|              dim_3|               dim_4|               dim_5|  diseaseId|dim_6|dim_7|
+-------------+--------------------+-------------------+--------------------+-------------------+--------------------+--------------------+-----------+-----+-----+
|R-HSA-1474290|  0.2979782567242389| 0.3779123131695229| 0.07479504883208896|0.06703926448642958|-0.04872236623855061| 0.48691106425357245|EFO_0000701|  NaN|  NaN|
|R-HSA-2022090| 0.27989850816200074| 0.3564410743901681|  0.0708735905468227|0.06323375140911779|-0.04511225001842...| 0.46039601461244206|EFO_0000701|  NaN|  NaN|
|R-HSA-6805567|-0.47368833256761567| 0.1543839617924824| 0.08013296850391213|-0.5596202237298129| 0.23809138300842622| 0.12701567595722524|EFO_0000701|  NaN|  NaN|
|R-HSA-6809371|-

## Weighted (Jaccard similarity index)

# Target embeddings

## Hierarchical (Poincare coordinates)

Based on the target-pathway matrix (tpm) and the pathway embeddings (pem), create target coordinates (tem) in hyperbolic space.

In [4]:
from pyspark.sql.functions import explode, split, avg, col

In [6]:
from pyspark.sql import functions as F

def create_target_embeddings(pathway_file: str, target2pathway_file: str):
    """
    Build target embeddings per disease as the arithmetic mean of the
    embeddings of the pathways each target is assigned to.

    Args:
        pathway_file: Parquet with pathway embeddings (``ID``, ``diseaseId``
            and the ``dim_*`` coordinate columns).
        target2pathway_file: Parquet with ``targetId``, ``approvedSymbol``,
            ``diseaseId`` and ``pathwayIds`` (comma-separated pathway IDs).

    Returns:
        Spark DataFrame with ``targetId``, ``approvedSymbol``, ``diseaseId``,
        the original ``pathwayIds`` string and one averaged ``dim_*`` column per
        embedding dimension. Pathways without an embedding for the same disease
        are ignored (inner join), so targets with no embedded pathway are absent.

    Notes:
        Prints how many diseases of the target table have no pathway embedding.
        Relies on the notebook-level ``spark`` session.
    """
    # Pathway embeddings and the names of their coordinate columns
    pem = spark.read.parquet(pathway_file).alias("pem")
    dim_names = [name for name in pem.columns if name.startswith("dim_")]

    # Target-to-pathway mapping
    tpm = spark.read.parquet(target2pathway_file).alias("tpm")

    # Report diseases present in the mapping but absent from the embeddings
    tpm_diseases = tpm.select(F.col("tpm.diseaseId").alias("diseaseId")).distinct()
    pem_diseases = pem.select(F.col("pem.diseaseId").alias("diseaseId")).distinct()
    diseases_without_pem = tpm_diseases.join(pem_diseases, on="diseaseId", how="left_anti")
    print(f"Number of diseases in TPM not found in PEM: {diseases_without_pem.count()}")

    # One row per (target, pathway); the original pathwayIds string is kept
    tpm_long = (
        tpm
        .withColumn("pathwayId", F.explode(F.split(F.col("tpm.pathwayIds"), ",")))
        .select("tpm.targetId", "tpm.approvedSymbol", "tpm.diseaseId", "tpm.pathwayIds", "pathwayId")
        .alias("tpm")
    )

    # Match each pathway with its embedding within the same disease
    same_pathway = F.col("tpm.pathwayId") == F.col("pem.ID")
    same_disease = F.col("tpm.diseaseId") == F.col("pem.diseaseId")
    matched = tpm_long.join(pem.alias("pem"), same_pathway & same_disease, "inner")

    # Mean coordinate per target and disease
    mean_dims = [F.avg(F.col(f"pem.{name}")).alias(name) for name in dim_names]
    return (
        matched
        .groupBy("tpm.targetId", "tpm.approvedSymbol", "tpm.diseaseId", "tpm.pathwayIds")
        .agg(*mean_dims)
    )


In [7]:
tem_level0 = create_target_embeddings(
    "/Users/polina/Pathwaganda/data/pem_levels/0",
    "/Users/polina/Pathwaganda/data/tpm_levels/0")

tem_level0.write.mode("overwrite").parquet("/Users/polina/Pathwaganda/data/tem_levels/0")

Number of diseases in TPM not found in PEM: 50


25/08/21 15:43:53 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
                                                                                

In [8]:
spark.read.parquet("/Users/polina/Pathwaganda/data/tem_levels/0").show(5)

+---------------+--------------+-----------+--------------------+-------------------+--------------------+--------------------+-------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----+
|       targetId|approvedSymbol|  diseaseId|          pathwayIds|              dim_0|               dim_1|               dim_2|              dim_3|               dim_4|               dim_5|               dim_6|               dim_7|              dim_8|dim_9|
+---------------+--------------+-----------+--------------------+-------------------+--------------------+--------------------+-------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----+
|ENSG00000000419|          DPM1|EFO_0000199|       R-HSA-1643685|0.03693195721547555|  -0.594752644697465|-0.15494661392907222|-0.4537320177105202| -0.3208464922484456|-0.42881094175268264|-0.16167811177650787|-0.2284632900485

In [9]:
# Level-1 target embeddings from the level-1 pathway embeddings and target-pathway matrix
tem_level1 = create_target_embeddings(
    "/Users/polina/Pathwaganda/data/pem_levels/1",
    "/Users/polina/Pathwaganda/data/tpm_levels/1")

tem_level1.write.mode("overwrite").parquet("/Users/polina/Pathwaganda/data/tem_levels/1")

Number of diseases in TPM not found in PEM: 90


25/08/21 15:44:40 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
25/08/21 15:44:40 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
                                                                                

In [10]:
tem_level2 = create_target_embeddings(
    "/Users/polina/Pathwaganda/data/pem_levels/2",
    "/Users/polina/Pathwaganda/data/tpm_levels/2")

tem_level2.write.mode("overwrite").parquet("/Users/polina/Pathwaganda/data/tem_levels/2")

Number of diseases in TPM not found in PEM: 16


25/08/21 15:44:47 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
                                                                                

In [27]:
tem_level0.filter(col("dim_7") != "NaN").show(5)

+---------------+--------------+-----------+--------------------+--------------------+-------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|       targetId|approvedSymbol|  diseaseId|               dim_0|               dim_1|              dim_2|               dim_3|               dim_4|               dim_5|               dim_6|               dim_7|
+---------------+--------------+-----------+--------------------+--------------------+-------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|ENSG00000058404|        CAMK2B|EFO_0000313| 0.10057969602937376| 0.11319281429042716|0.09188663680975077| 0.15102883960552846|-0.15157995953640524| 0.05882703634601264| 0.08532016216936399|  0.1922430112087569|
|ENSG00000111790|      FGFR1OP2|EFO_0000313|-4.95571735941970...|-0.03676460002274...|0.27457072469634203|   0.239685775645362| -0.3008503539177753| 0.1

                                                                                

OK, now everything is ready for score propagation.

## Prepare target-based metadata files with info about targets per disease

In [3]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/08/05 12:05:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [4]:
spark.read.parquet("/Users/polina/Pathwaganda/data/target-pathway_matrix_opt/test/diseaseId=EFO_0000094").show(5)

                                                                                

+----------------+--------------------+
|  approvedSymbol|                  ID|
+----------------+--------------------+
| complete genome|R-HSA-1643685,R-H...|
|        18S rRNA|       R-HSA-1643685|
|              1B|R-HSA-1643685,R-H...|
|              1C|R-HSA-1643685,R-H...|
|              1a|R-HSA-1643685,R-H...|
+----------------+--------------------+
only showing top 5 rows


Let's use the target-pathway_matrix_opt folder to start with and parse target info from the OT files.

In [None]:
# Take targetId from:

spark.read.parquet("/Users/polina/Pathwaganda/data/target").show(5)

25/08/05 12:08:07 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


+---------------+--------------+--------------+--------------------+--------------------+--------------------+--------------------+----------------+--------------------+--------------------+---------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+----+--------------------+--------------------+--------------+--------------------+--------------------+--------------------+--------------------+---------+
|             id|approvedSymbol|       biotype|       transcriptIds| canonicalTranscript|      canonicalExons|     genomicLocation|alternativeGenes|        approvedName|                  go|hallmarks|            synonyms|      symbolSynonyms|        nameSynonyms|functionDescriptions|subcellularLocations|         targetClass|     obsoleteSymbols|       obsoleteNames|          constraint| tep|          proteinIds|             dbXrefs|chemicalProbes|   

In [27]:
# Take genetic evidence scores from the input_4_gsea files:

spark.read.parquet("/Users/polina/Pathwaganda/data/input_4_gsea/diseaseId=EFO_0000094").show(5)

+--------+-------------------+
|       0|                  1|
+--------+-------------------+
|    CUX1| 0.4559480982087158|
|  NPIPB8|0.14562438358841173|
|    ETV5| 0.3039653988058105|
|  NUTM2D| 0.3039653988058105|
|DCAF12L2| 0.3039653988058105|
+--------+-------------------+
only showing top 5 rows


In [None]:
# Take drug info from:

spark.read.parquet("/Users/polina/Pathwaganda/data/known_drug").show(5)

+-----------+---------------+----------+-----+----------+--------------------+--------------------+---------------+--------------+--------------------+-------------+----------------+--------------------+--------------------+--------------+--------------------+--------------------+
|     drugId|       targetId| diseaseId|phase|    status|                urls|           ancestors|          label|approvedSymbol|        approvedName|  targetClass|        prefName|          tradeNames|            synonyms|      drugType|   mechanismOfAction|          targetName|
+-----------+---------------+----------+-----+----------+--------------------+--------------------+---------------+--------------+--------------------+-------------+----------------+--------------------+--------------------+--------------+--------------------+--------------------+
|CHEMBL52440|ENSG00000183454|DOID_10113|  1.0| Completed|[{ClinicalTrials,...|[MONDO_0002428, E...|trypanosomiasis|        GRIN2A|glutamate ionotro...|[Io

### Merge with known_drug from ChEMBL

In [20]:
import os
import re
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, collect_list, concat_ws, max as spark_max

def merge_files_with_target_drug_info(input_dir: str, 
                                      target_parquet: str, 
                                      known_drug_parquet: str, 
                                      output_dir: str):
    """Annotate every per-disease table with target ids and the max ChEMBL phase.

    For each sub-folder of ``input_dir`` named ``diseaseId=<ID>`` the parquet
    data in that folder is read, enriched and written to
    ``output_dir/diseaseId=<ID>`` (overwriting existing output).

    Two columns are appended to the columns of the input table:

    * ``targetId`` - comma-separated list of all ``id`` values that the target
      dataset associates with the row's ``approvedSymbol`` (NULL if the symbol
      is not found).
    * ``maxPhaseChEMBL`` - highest ``phase`` found in the known_drug dataset for
      the current disease across *any* of those ids (NULL if there is none).

    Parameters
    ----------
    input_dir : str
        Directory holding the ``diseaseId=<ID>`` sub-folders. Each table must
        contain an ``approvedSymbol`` column.
    target_parquet : str
        Path to the target dataset (columns ``approvedSymbol`` and ``id`` are used).
    known_drug_parquet : str
        Path to the known_drug dataset (columns ``phase``, ``targetId`` and
        ``diseaseId`` are used).
    output_dir : str
        Directory that receives one parquet folder per processed disease.

    Returns
    -------
    None
    """
    # Reuse the running Spark session (or start one)
    spark = SparkSession.builder.getOrCreate()

    # Shared lookup tables, read once for all diseases
    target_df = spark.read.parquet(target_parquet).select("approvedSymbol", "id").distinct()
    known_drug_df = spark.read.parquet(known_drug_parquet).select("phase", "targetId", "diseaseId")

    # approvedSymbol -> comma-separated list of ids (output column only)
    target_agg_df = (
        target_df
        .groupBy("approvedSymbol")
        .agg(concat_ws(",", collect_list("id")).alias("targetId"))
    )

    # Drug phases resolved per symbol by matching each *individual* id.
    # Joining on the comma-separated string instead would never match symbols
    # that map to more than one id, leaving their phase NULL.
    symbol_drug_df = (
        target_df
        .join(known_drug_df, target_df["id"] == known_drug_df["targetId"], "inner")
        .select("approvedSymbol", "diseaseId", "phase")
    )

    for folder_name in os.listdir(input_dir):
        folder_path = os.path.join(input_dir, folder_name)
        if not os.path.isdir(folder_path):
            continue

        # Only folders named diseaseId=XXX are processed
        match = re.match(r"diseaseId=(.+)", folder_name)
        if not match:
            continue
        disease_id = match.group(1)

        # The folder itself is the parquet dataset
        initial_df = spark.read.parquet(folder_path)

        # Highest phase per symbol for the current disease
        max_phase_df = (
            symbol_drug_df
            .filter(col("diseaseId") == disease_id)
            .groupBy("approvedSymbol")
            .agg(spark_max("phase").alias("maxPhaseChEMBL"))
        )

        # Left joins keep every input row; unmatched symbols get NULLs.
        # select() restores the original column order and distinct() keeps the
        # one-row-per-distinct-input-row shape of the previous groupBy.
        result_df = (
            initial_df
            .join(target_agg_df, on="approvedSymbol", how="left")
            .join(max_phase_df, on="approvedSymbol", how="left")
            .select(*initial_df.columns, "targetId", "maxPhaseChEMBL")
            .distinct()
        )

        # Save under the same folder name inside output_dir
        output_path = os.path.join(output_dir, folder_name)
        result_df.write.mode("overwrite").parquet(output_path)


In [21]:
# Add targetId and maxPhaseChEMBL to every per-disease target-pathway table.
merge_files_with_target_drug_info(
    target_parquet="/Users/polina/Pathwaganda/data/target",
    known_drug_parquet="/Users/polina/Pathwaganda/data/known_drug",
    input_dir="/Users/polina/Pathwaganda/data/target-pathway_matrix_opt/Reactome_Pathways_2025_diy",
    output_dir="/Users/polina/Pathwaganda/data/target_metadata/known_drug_merge/Reactome_Pathways_2025_diy",
)

                                                                                

In [None]:
# Inspect one merged table.
# NOTE(review): this path differs from the output_dir passed to
# merge_files_with_target_drug_info above - confirm it is the intended location.
(
    spark.read
    .parquet("/Users/polina/Pathwaganda/diseaseId=EFO_0000094")
    .show(5)
)

+--------------+--------------------+---------------+--------------+
|approvedSymbol|                  ID|       targetId|maxPhaseChEMBL|
+--------------+--------------------+---------------+--------------+
|        ANGPT1|        R-HSA-109582|ENSG00000154188|          NULL|
|         APOOL|R-HSA-1592230,R-H...|ENSG00000155008|          NULL|
|         CCAR1|R-HSA-72203,R-HSA...|ENSG00000060339|          NULL|
|          CD96|        R-HSA-198933|ENSG00000153283|          NULL|
|         CDH24|R-HSA-9759476,R-H...|ENSG00000139880|          NULL|
+--------------+--------------------+---------------+--------------+
only showing top 5 rows


### Merge with genetic association scores from the Open Targets (OT) Platform

In [57]:
import os
from pyspark.sql import SparkSession

def merge_parquet_folders_with_na(spark, folder_1, folder_2, output_dir):
    """
    Attach a 'geneticScore' column to every parquet sub-folder of folder_1.

    Each sub-folder of folder_1 is paired with the sub-folder of the same
    name in folder_2:
    - when the partner exists, its columns '0' and '1' are renamed to
      'approvedSymbol' and 'geneticScore' and left-joined onto the folder_1
      table on 'approvedSymbol' (rows without a match get a null score);
    - when the partner is missing, the folder_1 table is copied unchanged.
    The result is written (overwrite mode) to output_dir/<sub-folder name>.
    """
    # Collect the sub-folder names up front; plain files are ignored
    subfolder_names = []
    for entry in os.listdir(folder_1):
        if os.path.isdir(os.path.join(folder_1, entry)):
            subfolder_names.append(entry)

    for name in subfolder_names:
        source_dir = os.path.join(folder_1, name)
        score_dir = os.path.join(folder_2, name)
        destination = os.path.join(output_dir, name)

        base_df = spark.read.parquet(source_dir)

        if os.path.exists(score_dir):
            # Give the positional columns their real names, keep only those two
            scores = (
                spark.read.parquet(score_dir)
                .withColumnRenamed("0", "approvedSymbol")
                .withColumnRenamed("1", "geneticScore")
                .select("approvedSymbol", "geneticScore")
            )

            # Left join: symbols without a score keep their row with a null
            with_scores = base_df.join(scores, on="approvedSymbol", how="left")
            with_scores.write.mode("overwrite").parquet(destination)
            print(f"Merged and written: {destination}")
        else:
            print(f"Folder {name} missing in folder_2. Writing original file from folder_1 as is.")
            base_df.write.mode("overwrite").parquet(destination)

In [58]:
# Get (or reuse) the Spark session for the genetic-score merge
session_builder = SparkSession.builder.appName("MergeParquets")
spark = session_builder.getOrCreate()

# Left-join the genetic scores onto the known-drug-annotated tables
merge_parquet_folders_with_na(
    spark,
    output_dir="/Users/polina/Pathwaganda/data/target_metadata/ge_merge/Reactome_Pathways_2025_diy",
    folder_1="/Users/polina/Pathwaganda/data/target_metadata/known_drug_merge/Reactome_Pathways_2025_diy",
    folder_2="/Users/polina/Pathwaganda/data/input_4_gsea",
)

Merged and written: /Users/polina/Pathwaganda/data/target_metadata/ge_merge/Reactome_Pathways_2025_diy/diseaseId=EFO_0000503
Merged and written: /Users/polina/Pathwaganda/data/target_metadata/ge_merge/Reactome_Pathways_2025_diy/diseaseId=EFO_0011015
Merged and written: /Users/polina/Pathwaganda/data/target_metadata/ge_merge/Reactome_Pathways_2025_diy/diseaseId=MONDO_0000569
Merged and written: /Users/polina/Pathwaganda/data/target_metadata/ge_merge/Reactome_Pathways_2025_diy/diseaseId=MONDO_0003916
Merged and written: /Users/polina/Pathwaganda/data/target_metadata/ge_merge/Reactome_Pathways_2025_diy/diseaseId=EFO_0004533
Merged and written: /Users/polina/Pathwaganda/data/target_metadata/ge_merge/Reactome_Pathways_2025_diy/diseaseId=MONDO_0002033
Merged and written: /Users/polina/Pathwaganda/data/target_metadata/ge_merge/Reactome_Pathways_2025_diy/diseaseId=MONDO_0017343
Merged and written: /Users/polina/Pathwaganda/data/target_metadata/ge_merge/Reactome_Pathways_2025_diy/diseaseId=MOND

In [55]:
# Count the rows of one merged table that did not receive a targetId
(
    spark.read
    .parquet("/Users/polina/Pathwaganda/data/target_metadata/ge_merge/test/diseaseId=MONDO_0045024")
    .filter(col("targetId").isNull())
    .count()
)

283

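**TODO:** implement a synonym-based symbol lookup. The rows counted above have no `targetId` because they were matched on `approvedSymbol` alone.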

# Prepare file with umap coordinates

## Case 1: the user has not specified a gene list (show only genes with genetic evidence)

Steps:
- Take the coordinate file and run UMAP and clustering
- Write the coordinates and clusters into the corresponding metadata file
- Run GSEA to assign labels to each pathway (optional)
- Filter out genes without genetic evidence

In [None]:
import pandas as pd
import numpy as np
import os
import umap
import hdbscan
from scipy.spatial.distance import pdist, squareform

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def compute_poincare_distance_matrix(embedding_matrix):
    """Vectorized computation of pairwise Poincaré distances.

    For two points u, v inside the unit ball the distance is

        d(u, v) = arccosh(1 + 2 * ||u - v||^2 / ((1 - ||u||^2) * (1 - ||v||^2)))

    Parameters
    ----------
    embedding_matrix : array-like of shape (n_points, n_dims)
        One embedding per row.

    Returns
    -------
    numpy.ndarray of shape (n_points, n_points)
        Symmetric distance matrix with a zero diagonal. Every pair in which
        at least one point has norm >= 1 (on or outside the ball boundary)
        is assigned ``inf``.
    """
    points = np.asarray(embedding_matrix, dtype=float)

    # 1 - ||x||^2 for every point; non-positive means the point is not strictly
    # inside the unit ball.
    ball_margin = 1.0 - np.sum(points * points, axis=1)
    outside = ball_margin <= 0

    # Flag a pair as soon as EITHER point is invalid. Checking only the sign of
    # the product would miss pairs where both factors are negative.
    invalid_pair = outside[:, None] | outside[None, :]

    squared_dists = squareform(pdist(points, metric="sqeuclidean"))
    denom = ball_margin[:, None] * ball_margin[None, :]

    # Invalid pairs may divide by zero or feed arccosh a value < 1; those
    # entries are overwritten below, so the warnings are silenced here.
    with np.errstate(divide="ignore", invalid="ignore"):
        distances = np.arccosh(1.0 + 2.0 * squared_dists / denom)

    distances[invalid_pair] = np.inf
    # A point is always at distance 0 from itself
    np.fill_diagonal(distances, 0.0)
    return distances

In [None]:
def perform_umap_clustering_parquet(
    metadata_parquet_dir,
    coordinates_parquet_dir,
    output_dir,
    n_neighbors=10,
    min_dist=0.5,
    min_cluster_size=12,
    umap_dimensions=2
):
    """
    Performs UMAP dimensionality reduction and HDBSCAN clustering using Poincaré distance.
    Aligns metadata and coordinates by 'approvedSymbol', saves final result as TSV.

    Parameters
    ----------
    metadata_parquet_dir : str
        Parquet dataset with at least 'approvedSymbol' and 'geneticScore';
        rows with a null 'geneticScore' are dropped.
    coordinates_parquet_dir : str
        Parquet dataset whose first column is the gene symbol and whose
        remaining columns are the embedding coordinates.
    output_dir : str
        Directory (created if needed) that receives
        'metadata_clusters_poincare_fast_ge.tsv'.
    n_neighbors, min_dist : UMAP parameters.
    min_cluster_size : HDBSCAN parameter.
    umap_dimensions : int
        Number of UMAP components; stored as columns 'UMAP 1', 'UMAP 2', ...

    Returns
    -------
    str
        Path of the written TSV file.

    Raises
    ------
    AssertionError
        If the metadata lacks 'approvedSymbol' or the coordinates have no
        dimension columns.
    ValueError
        If no symbol is shared between metadata and coordinates, or if an
        embedding has norm >= 1.
    """

    # Load metadata (only genes with a genetic score) and coordinates
    metadata = pd.read_parquet(metadata_parquet_dir).query("geneticScore.notnull()")
    coords_df = pd.read_parquet(coordinates_parquet_dir)

    # Sanity checks
    assert 'approvedSymbol' in metadata.columns, "Metadata must contain 'approvedSymbol'"
    assert coords_df.shape[1] > 1, "Coordinates must have approvedSymbol + at least one dimension"

    # The first coordinate column is treated as the gene symbol
    coords_df = coords_df.rename(columns={coords_df.columns[0]: 'approvedSymbol'})

    # Convert coordinate columns to float
    coord_columns = coords_df.columns[1:]
    coords_df[coord_columns] = coords_df[coord_columns].astype(float)

    # Inner merge keeps only genes present in both tables
    merged_df = pd.merge(metadata, coords_df, on='approvedSymbol', how='inner')
    print(f"Merged metadata and coordinates: {merged_df.shape[0]} entries.")
    if merged_df.empty:
        raise ValueError("No approvedSymbol is shared between metadata and coordinates.")

    # Extract embedding matrix (rows in the same order as merged_df)
    embedding_matrix = merged_df[coord_columns].values

    # Check that all embeddings lie within the unit ball
    norms = np.linalg.norm(embedding_matrix, axis=1)
    if np.any(norms >= 1):
        raise ValueError("Some embeddings lie outside the Poincaré ball (norm >= 1).")

    # Uses the notebook-level helper instead of a second, local copy of it
    distance_matrix = compute_poincare_distance_matrix(embedding_matrix)

    # UMAP dimensionality reduction on the precomputed distances
    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        n_components=umap_dimensions,
        metric='precomputed',
        random_state=42
    )
    embedding_umap = reducer.fit_transform(distance_matrix)

    # HDBSCAN clusters on the Poincaré distance matrix itself,
    # not on the UMAP coordinates
    clusterer = hdbscan.HDBSCAN(
        min_cluster_size=min_cluster_size,
        min_samples=1,
        metric='precomputed'
    )
    cluster_labels = clusterer.fit_predict(distance_matrix)

    # Add UMAP and cluster results to merged_df
    for dim in range(umap_dimensions):
        merged_df[f'UMAP {dim+1}'] = embedding_umap[:, dim]
    merged_df['cluster'] = cluster_labels

    # Drop original embedding dimensions before saving
    output_df = merged_df.drop(columns=coord_columns)

    # Output directory and file
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, 'metadata_clusters_poincare_fast_ge.tsv')
    output_df.to_csv(output_file, sep='\t', index=False)

    print(f"✅ Updated metadata with clusters saved to: {output_file}")
    return output_file

In [19]:
# Test run for a single disease (EFO_0000094); the call returns the TSV path
perform_umap_clustering_parquet(
    metadata_parquet_dir="/Users/polina/Pathwaganda/data/target_metadata/ge_merge/Reactome_Pathways_2025_diy/diseaseId=EFO_0000094/",
    coordinates_parquet_dir="/Users/polina/Pathwaganda/data/target_embeddings/Reactome_Pathways_2025_diy/diseaseId=EFO_0000094/",
    output_dir="/Users/polina/Pathwaganda/data/umap/test",
    # UMAP settings
    umap_dimensions=2,
    n_neighbors=5,
    min_dist=0.7,
    # HDBSCAN setting
    min_cluster_size=5,
)


Merged metadata and coordinates: 528 entries.


  warn("using precomputed metric; inverse_transform will be unavailable")
  warn(
failed. This is likely due to too small an eigengap. Consider
adding some noise or jitter to your data.

Falling back to random initialisation!
  warn(


✅ Updated metadata with clusters saved to: /Users/polina/Pathwaganda/data/umap/test/metadata_clusters_poincare_fast_ge.tsv


'/Users/polina/Pathwaganda/data/umap/test/metadata_clusters_poincare_fast_ge.tsv'