In [37]:
from typing import Optional

import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType

In [38]:
# Attach to the running Spark session, or start one with default settings.
# compute_poincare_similarities reads its input parquet through this handle.
spark = (
    SparkSession
    .builder
    .getOrCreate()
)

# Similarity calculation

In [44]:
from pyspark.sql import functions as F, types as T

def compute_poincare_similarities(
    input_parquet_path: str,
    output_parquet_path: Optional[str] = None,
    similarity_mode: str = "inverse_distance",
    epsilon: float = 1e-9
):
    """Score every pair of targets within a disease by Poincaré-ball similarity.

    The input parquet must provide ``diseaseId``, ``targetId``,
    ``approvedSymbol``, a comma-separated ``pathwayIds`` string and one or more
    ``dim_<int>`` coordinate columns. Rows are paired only with rows of the same
    ``diseaseId``; each unordered pair appears once (``targetIdA < targetIdB``),
    and self-pairs are excluded.

    Distance and similarity are built from native Spark SQL expressions
    (higher-order array functions), not Python UDFs, so rows never leave the JVM.

    Args:
        input_parquet_path: Path of the parquet dataset to read.
        output_parquet_path: If not None, the result is also written here
            (overwrite mode, partitioned by ``diseaseId``).
        similarity_mode: How a distance ``d`` is turned into ``similarScore``:
            ``"inverse_distance"`` -> ``1 / (1 + d)``,
            ``"exp_decay"`` -> ``exp(-d)``,
            ``"negative_distance"`` -> ``-d``.
        epsilon: Margin used to pull points with norm >= 1 back inside the unit
            ball (radius ``1 - epsilon``). Also the lower bound of the distance
            denominator.

    Returns:
        DataFrame with columns diseaseId, targetIdA, approvedSymbolA, targetIdB,
        approvedSymbolB, similarScore, commonPathwayIds (shared pathway IDs
        joined with ",").

    Raises:
        ValueError: if ``similarity_mode`` is not one of the modes above, or if
            the input has no ``dim_*`` columns.
    """
    # Distance -> score transforms, all expressed on Spark Columns.
    # Validated before any I/O so a typo fails loudly instead of silently
    # falling back to inverse_distance.
    similarity_transforms = {
        "inverse_distance": lambda d: F.lit(1.0) / (F.lit(1.0) + d),
        "exp_decay": lambda d: F.exp(-d),
        "negative_distance": lambda d: -d,
    }
    if similarity_mode not in similarity_transforms:
        raise ValueError(
            f"Unknown similarity_mode {similarity_mode!r}; "
            f"expected one of {sorted(similarity_transforms)}"
        )
    to_similarity = similarity_transforms[similarity_mode]

    def sq_norm(arr):
        # Squared Euclidean norm of an array<double> column, summed left to right.
        return F.aggregate(arr, F.lit(0.0), lambda acc, x: acc + x * x)

    # Load
    df = spark.read.parquet(input_parquet_path)

    # pathwayIds is a comma-separated string; split it so pairs can be intersected.
    df = df.withColumn("pathwayIdsArr", F.split(F.col("pathwayIds"), ","))

    # Find hyperbolic coordinate columns and order them numerically (dim_2 < dim_10).
    dim_cols = [c for c in df.columns if c.startswith("dim_")]
    if not dim_cols:
        raise ValueError("No dim_* columns found.")
    dim_cols = sorted(dim_cols, key=lambda c: int(c.split("_")[1]))

    # A row with any missing coordinate cannot be placed in the ball.
    df = df.dropna(subset=dim_cols)

    # One double-typed coordinate vector per row.
    dfv = df.select(
        "diseaseId",
        "targetId",
        "approvedSymbol",
        "pathwayIdsArr",
        F.array(*[F.col(c).cast("double") for c in dim_cols]).alias("vec"),
    )

    # Rescale points on/outside the unit sphere to radius 1 - epsilon, then keep
    # their squared norms. This depends on one vector only, so doing it once per
    # row here replaces doing it twice per pair after the self-join.
    dfv = (
        dfv
        .withColumn("rawNorm2", sq_norm(F.col("vec")))
        .withColumn(
            "vec",
            F.when(
                F.col("rawNorm2") >= 1.0,
                F.transform(
                    F.col("vec"),
                    lambda x: x * (
                        F.lit(1.0 - epsilon)
                        / F.greatest(F.sqrt(F.col("rawNorm2")), F.lit(1.0))
                    ),
                ),
            ).otherwise(F.col("vec")),
        )
        .withColumn("norm2", sq_norm(F.col("vec")))
        .drop("rawNorm2")
    )

    # Co-locate each disease's rows before the self-join.
    dfv = dfv.repartition("diseaseId")

    # Self-join within disease, upper triangle only (each unordered pair once).
    a = dfv.alias("a")
    b = dfv.alias("b")
    pairs = (
        a.join(b, on="diseaseId")
         .where(F.col("a.targetId") < F.col("b.targetId"))
    )

    # Poincaré distance:
    #   d(u, v) = acosh(1 + 2 * |u - v|^2 / ((1 - |u|^2) * (1 - |v|^2)))
    diff2 = F.aggregate(
        F.zip_with(F.col("a.vec"), F.col("b.vec"), lambda x, y: (x - y) * (x - y)),
        F.lit(0.0),
        lambda acc, x: acc + x,
    )
    # Floor the denominator so points very close to the boundary cannot divide by ~0.
    denom = F.greatest(
        (F.lit(1.0) - F.col("a.norm2")) * (F.lit(1.0) - F.col("b.norm2")),
        F.lit(epsilon),
    )
    # Clamp at 1 so rounding can never push acosh outside its domain.
    z = F.greatest(F.lit(1.0) + (F.lit(2.0) * diff2) / denom, F.lit(1.0))
    distance = F.acosh(z)

    # Scores + shared pathways (joined back into a comma-separated string).
    scored = pairs.select(
        F.col("a.diseaseId").alias("diseaseId"),
        F.col("a.targetId").alias("targetIdA"),
        F.col("a.approvedSymbol").alias("approvedSymbolA"),
        F.col("b.targetId").alias("targetIdB"),
        F.col("b.approvedSymbol").alias("approvedSymbolB"),
        to_similarity(distance).alias("similarScore"),
        F.array_join(
            F.array_intersect(F.col("a.pathwayIdsArr"), F.col("b.pathwayIdsArr")),
            ",",
        ).alias("commonPathwayIds"),
    )

    # Optional write
    if output_parquet_path is not None:
        (scored
         .repartition("diseaseId")
         .write
         .mode("overwrite")
         .partitionBy("diseaseId")
         .parquet(output_parquet_path))

    return scored


In [45]:
# Level-0 embeddings. No output path is passed, so nothing is written to disk;
# scores are only computed when an action such as .show() runs.
sim_level0 = compute_poincare_similarities(
    "/Users/polina/Pathwaganda/data/tem_levels/0"
)

In [46]:
# Level-1 embeddings. No output path is passed, so nothing is written to disk;
# scores are only computed when an action such as .show() runs.
sim_level1 = compute_poincare_similarities(
    "/Users/polina/Pathwaganda/data/tem_levels/1"
)

In [47]:
# Level-2 embeddings. No output path is passed, so nothing is written to disk;
# scores are only computed when an action such as .show() runs.
sim_level2 = compute_poincare_similarities(
    "/Users/polina/Pathwaganda/data/tem_levels/2"
)

In [None]:
# Peek at the first 25 level-2 target pairs (this triggers the Spark job).
sim_level2.show(n=25)

[Stage 46:>                                                         (0 + 1) / 1]

+-----------+---------------+---------------+---------------+---------------+-------------------+----------------+
|  diseaseId|      targetIdA|approvedSymbolA|      targetIdB|approvedSymbolB|       similarScore|commonPathwayIds|
+-----------+---------------+---------------+---------------+---------------+-------------------+----------------+
|EFO_0005803|ENSG00000096264|           NCR2|ENSG00000127418|         FGFRL1|0.23141577391904944|                |
|EFO_0005803|ENSG00000096264|           NCR2|ENSG00000167772|        ANGPTL4|0.19362613348678467|                |
|EFO_0005803|ENSG00000096264|           NCR2|ENSG00000177302|          TOP3A|  0.282482787962711|                |
|EFO_0005803|ENSG00000096264|           NCR2|ENSG00000211677|          IGLC2|0.42901729706589103|                |
|EFO_0005803|ENSG00000096264|           NCR2|ENSG00000125398|           SOX9|0.22785641904153134|                |
+-----------+---------------+---------------+---------------+---------------+---

                                                                                

In [49]:
# CDK2/CDK4 similarity and shared pathways, per disease.
# Pairs are ordered by targetId (targetIdA < targetIdB), which puts CDK2 on the A side here.
# Uses F.col: a bare `col` is never imported in this notebook.
sim_level2.filter(
    (F.col("approvedSymbolA") == "CDK2") & (F.col("approvedSymbolB") == "CDK4")
).show(5)

+-----------+---------------+---------------+---------------+---------------+-------------------+--------------------+
|  diseaseId|      targetIdA|approvedSymbolA|      targetIdB|approvedSymbolB|       similarScore|    commonPathwayIds|
+-----------+---------------+---------------+---------------+---------------+-------------------+--------------------+
|EFO_0000313|ENSG00000123374|           CDK2|ENSG00000135446|           CDK4| 0.6970239959175557|R-HSA-212436,R-HS...|
|EFO_0000319|ENSG00000123374|           CDK2|ENSG00000135446|           CDK4| 0.5500472420363286|R-HSA-8848021,R-H...|
|EFO_0000508|ENSG00000123374|           CDK2|ENSG00000135446|           CDK4| 0.6498712628357501|R-HSA-912446,R-HS...|
|EFO_0000616|ENSG00000123374|           CDK2|ENSG00000135446|           CDK4| 0.6878854869170765|R-HSA-2559583,R-H...|
|EFO_0000618|ENSG00000123374|           CDK2|ENSG00000135446|           CDK4|0.37127843065302685|        R-HSA-212436|
+-----------+---------------+---------------+---

                                                                                

In [None]:
# Within disease EFO_0000313, rank target pairs that share no pathway
# (empty commonPathwayIds) by similarity, most similar first.
# Uses F.col: a bare `col` is never imported in this notebook.
sim_level2_filt = (
    sim_level2
    .filter(F.col("diseaseId") == "EFO_0000313")
    .filter(F.col("commonPathwayIds") == "")
)
sim_level2_filt.sort("similarScore", ascending=False).show(10)

[Stage 60:>                                                         (0 + 1) / 1]

+-----------+--------------------+---------------+--------------------+---------------+------------------+----------------+
|  diseaseId|           targetIdA|approvedSymbolA|           targetIdB|approvedSymbolB|      similarScore|commonPathwayIds|
+-----------+--------------------+---------------+--------------------+---------------+------------------+----------------+
|EFO_0000313|ENSG00000109618,E...|            SLA|     ENSG00000119535|          CSF3R| 0.995733782468857|                |
|EFO_0000313|     ENSG00000101082|           SLA2|     ENSG00000119535|          CSF3R| 0.995733782468857|                |
|EFO_0000313|     ENSG00000134852|          CLOCK|     ENSG00000149196|        HIKESHI|0.9945789799730248|                |
|EFO_0000313|     ENSG00000104499|            GML|     ENSG00000134852|          CLOCK|0.9945789799730248|                |
|EFO_0000313|     ENSG00000164070|         HSPA4L|ENSG00000258436,E...|           RAI1|0.9945789799730248|                |
|EFO_000

                                                                                

25/08/22 19:41:49 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 982357 ms exceeds timeout 120000 ms
25/08/22 19:41:49 WARN SparkContext: Killing executors is not supported by current scheduler.
25/08/22 19:44:36 WARN Executor: Issue communicating with driver in heartbeater
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:53)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:342)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:101)
	at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:85)
	at org.apache.spark.storage.BlockManagerMaster.registerBlockManager(BlockManagerMaster.scala:81)
	at org.apache.spark.storage.BlockManager.reregister(BlockManager.scala:669)
	at org.apache.spark.executor.Executor.reportHeartBeat(Executor.scala:1296)
	at o

It seems we should propagate scores only for those targets for which … (TODO: unfinished note — complete the condition.)

Let's use the similarities for score propagation straight away, rather than materializing all pairwise similarities first, so as not to consume a lot of memory.

# Score propagation