This code is pasted from IBD_pathway_to_cell from similarity_mvp and should be rewritten to spark to process data more effeciently.

In [17]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, explode, split, array_distinct, udf
from pyspark.ml.feature import MinHashLSH
from pyspark.sql.types import FloatType
import pyspark.sql.functions as F
import os
import numpy as np
import gcsfs
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql.functions import udf
from itertools import combinations

In [6]:
spark = SparkSession.builder.getOrCreate()

# Similarity matrix for propagation results

## Misnhash

For MVP: 

presence or absence of target in pathways (0/1)

Jaccard for distance

In [None]:
def calculate_jaccard_similarity_spark_minhash(input_gcs_dir, output_gcs_dir, folders_to_process, num_hash_tables=5):
    """
    Process CSV files in specified folders within a GCS directory and calculate Jaccard similarity matrices using Spark MinHashLSH (approximate).

    Args:
        input_gcs_dir (str): Input GCS directory path.
        output_gcs_dir (str): Output GCS directory path.
        folders_to_process (list): List of folder names within the input directory to process.
        num_hash_tables (int): Number of hash tables for MinHash (higher = more accuracy, slower).

    Output:
        Saves approximate similarity results into output GCS directory.
    """

    input_gcs_dir = input_gcs_dir.rstrip("/")
    output_gcs_dir = output_gcs_dir.rstrip("/")

    for folder_name in folders_to_process:
        folder_path = f"{input_gcs_dir}/{folder_name}"
        output_folder_path = f"{output_gcs_dir}/{folder_name}"

        # Read all CSVs inside the folder
        df = spark.read.option("header", True).csv(f"{folder_path}/*.csv")

        # Check if required columns exist
        expected_cols = {'propagated_edge', 'Term'}
        if not expected_cols.issubset(set(df.columns)):
            print(f"Skipping folder {folder_name}: missing required columns.")
            continue

        # Explode propagated_edge column
        df_exploded = df.withColumn(
            "propagated_edge_exploded", 
            explode(split(col("propagated_edge"), ","))
        ).dropna(subset=["propagated_edge_exploded"])

        # Group by target and collect associated terms
        target_terms = df_exploded.groupBy("propagated_edge_exploded") \
            .agg(F.collect_set("Term").alias("terms"))

        # Index the terms (since MinHash needs fixed-size numeric vectors)
        # Create a dictionary: term -> index
        all_terms = df_exploded.select("Term").distinct().rdd.map(lambda r: r[0]).collect()
        term_to_index = {term: idx for idx, term in enumerate(all_terms)}

        # Transform terms into SparseVectors
        def terms_to_vector(terms):
            indices = [term_to_index[term] for term in terms if term in term_to_index]
            values = [1.0] * len(indices)
            size = len(term_to_index)
            return Vectors.sparse(size, list(zip(indices, values)))

        terms_to_vector_udf = udf(terms_to_vector, VectorUDT())


        target_terms = target_terms.withColumn("features", terms_to_vector_udf(col("terms")))

        # Initialize MinHashLSH model
        mh = MinHashLSH(inputCol="features", outputCol="hashes", numHashTables=num_hash_tables)
        model = mh.fit(target_terms)

        # Compute pairwise similarities
        similarities = model.approxSimilarityJoin(target_terms, target_terms, threshold=1.0, distCol="JaccardDistance")

        # Select relevant columns and format
        similarity_df = similarities.select(
            col("datasetA.propagated_edge_exploded").alias("target1"),
            col("datasetB.propagated_edge_exploded").alias("target2"),
            (1 - col("JaccardDistance")).alias("similarity")
        ).filter("target1 <= target2")  # Avoid duplicates (symmetry)

        # Save the similarity results
        similarity_df.write.mode('overwrite').option("header", True).csv(f"{output_folder_path}")

        print(f"Output: {output_folder_path}")

In [8]:
gsea_dir = "gs://ot-team/polina/pathway_propagation_validation_v2/gsea_output"
output_dir = "gs://ot-team/polina/pathway_propagation_validation_v2/similarity_mtx/jaccard_spark"

# library = ["KEGG_2021_Human"]
library = ["test"]

calculate_jaccard_similarity_spark_minhash(gsea_dir, output_dir, library)

                                                                                

Processed (MinHash) and uploaded: gs://ot-team/polina/pathway_propagation_validation_v2/similarity_mtx/jaccard_spark/test


In [9]:
df = spark.read.csv("gs://ot-team/polina/pathway_propagation_validation_v2/similarity_mtx/jaccard_spark/test", header=True, inferSchema=True)

                                                                                

In [10]:
df.show(5)

+-------+-------+-------------------+
|target1|target2|         similarity|
+-------+-------+-------------------+
|  ABCA1|  MYD88|0.06666666666666665|
|  ACTG1|  MED16| 0.1428571428571428|
|  ACTN1|   CDK6|0.04166666666666663|
|  ACTN4| GTF2A1|                0.5|
|  ACTN4| ARPC5L|0.33333333333333326|
+-------+-------+-------------------+
only showing top 5 rows



In [11]:
df.count()

                                                                                

513150

### MinHashLSH vs Exact Jaccard

In [24]:
def jaccard_similarity(set1, set2):
    if not set1 or not set2:
        return 0.0
    intersection = len(set(set1).intersection(set(set2)))
    union = len(set(set1).union(set(set2)))
    return float(intersection) / float(union)

def compare_jaccard_exact_vs_minhash(df_input, num_hash_tables=5):
    spark = SparkSession.builder.getOrCreate()

    # Explode propagated_edge
    df = df_input.withColumn(
        "propagated_edge_exploded",
        explode(split(col("propagated_edge"), ","))
    ).dropna(subset=["propagated_edge_exploded"])

    # Group: each node → list of terms
    target_terms = df.groupBy("propagated_edge_exploded") \
        .agg(F.collect_set("Term").alias("terms"))

    # Get all unique terms and assign indices
    all_terms = df.select("Term").distinct().rdd.map(lambda r: r[0]).collect()
    term_to_index = {term: idx for idx, term in enumerate(all_terms)}

    def terms_to_sparse_vector(terms):
        indices = [term_to_index[t] for t in terms if t in term_to_index]
        return Vectors.sparse(len(term_to_index), [(i, 1.0) for i in indices])

    terms_to_vector_udf = udf(terms_to_sparse_vector, VectorUDT())

    target_terms = target_terms.withColumn("features", terms_to_vector_udf(col("terms")))

    # Fit MinHashLSH
    mh = MinHashLSH(inputCol="features", outputCol="hashes", numHashTables=num_hash_tables)
    model = mh.fit(target_terms)

    # Compute MinHash similarities
    approx_sim = model.approxSimilarityJoin(target_terms, target_terms, threshold=1, distCol="JaccardDistance") \
        .filter("datasetA.propagated_edge_exploded <= datasetB.propagated_edge_exploded") \
        .select(
            col("datasetA.propagated_edge_exploded").alias("target1"),
            col("datasetA.terms").alias("terms1"),
            col("datasetB.propagated_edge_exploded").alias("target2"),
            col("datasetB.terms").alias("terms2"),
            (1 - col("JaccardDistance")).alias("approx_jaccard")
        )

    # Compute exact Jaccard
    jaccard_udf = udf(jaccard_similarity, FloatType())
    with_exact = approx_sim.withColumn("exact_jaccard", jaccard_udf(col("terms1"), col("terms2")))

    # Calculate % error
    with_error = with_exact.withColumn(
        "percent_error",
        F.when(col("exact_jaccard") > 0,
               F.abs(col("approx_jaccard") - col("exact_jaccard")) / col("exact_jaccard") * 100)
         .otherwise(None)
    )

    # Clean output
    final_result = with_error.select("target1", "target2", "approx_jaccard", "exact_jaccard", "percent_error")

    return final_result

In [25]:
df_compare = spark.read.option("header", True).csv("gs://ot-team/polina/pathway_propagation_validation_v2/gsea_output/test/*.csv")
result_df_compare = compare_jaccard_exact_vs_minhash(df_compare, num_hash_tables=5)

In [26]:
# Show comparison
result_df_compare.orderBy("percent_error", ascending=False).show(50, truncate=False)

# # Save
# result_df_compare.write.mode("overwrite").option("header", True).csv("gs://ot-team/polina/pathway_propagation_validation_v2/similarity_mtx/jaccard_spark/compare_exact_vs_minhash")



+--------+-------+--------------------+-------------+--------------------+
|target1 |target2|approx_jaccard      |exact_jaccard|percent_error       |
+--------+-------+--------------------+-------------+--------------------+
|CALCOCO2|HRAS   |0.015873015873015817|0.015873017  |5.774199984473481E-6|
|KRAS    |SNW1   |0.015873015873015817|0.015873017  |5.774199984473481E-6|
|CALCOCO2|KRAS   |0.015873015873015817|0.015873017  |5.774199984473481E-6|
|HRAS    |SNW1   |0.015873015873015817|0.015873017  |5.774199984473481E-6|
|HDAC1   |NRAS   |0.12698412698412698 |0.12698413   |5.774199634753249E-6|
|KRAS    |RBPJ   |0.031746031746031744|0.031746034  |5.774199634753249E-6|
|HLA-E   |KRAS   |0.12698412698412698 |0.12698413   |5.774199634753249E-6|
|HDAC2   |NRAS   |0.12698412698412698 |0.12698413   |5.774199634753249E-6|
|NRAS    |SKP2   |0.06349206349206349 |0.06349207   |5.774199634753249E-6|
|CAMK2A  |PIK3R1 |0.12698412698412698 |0.12698413   |5.774199634753249E-6|
|HLA-A   |HRAS   |0.12698

                                                                                

In [27]:
result_df_compare.count()

                                                                                

513150

In [31]:
result_df_compare.filter(col("percent_error") != 5.774199634753249E-6).show()

[Stage 155:>                                                        (0 + 1) / 1]

+-------+-------+-------------------+-------------+--------------------+
|target1|target2|     approx_jaccard|exact_jaccard|       percent_error|
+-------+-------+-------------------+-------------+--------------------+
|  ABCA1|  MYD88|0.06666666666666665|   0.06666667|5.215406168046513E-6|
|  ACTG1|  MED16| 0.1428571428571428|   0.14285715| 4.47034820272308E-6|
|  ACTN1|   CDK6|0.04166666666666663|  0.041666668|2.980232238769531...|
|  ACTN4| GTF2A1|                0.5|          0.5|                 0.0|
|  ACTN4| ARPC5L|0.33333333333333326|   0.33333334|2.980232172156152E-6|
| ACTR10|  DCTN6|                1.0|          1.0|                 0.0|
| ACTR1A|   GNAQ|0.05882352941176472|   0.05882353|3.725290062539522...|
|  ACTR2|S100A10|                0.5|          0.5|                 0.0|
| ACTR3B|   IRF3|0.06666666666666665|   0.06666667|5.215406168046513E-6|
| ADAM17| NOTCH4|0.16666666666666663|   0.16666667|2.980232172156152E-6|
|  ADCY4|  PDGFA|0.19999999999999996|          0.2|

                                                                                

Results seem identical so we can preliminarly use Minhash for similarity calculation. But lets just check if without its too slow.

## Jaccard

In [33]:
def calculate_jaccard_similarity_spark_exact(input_gcs_dir, output_gcs_dir, folders_to_process):
    """
    Process CSV files in specified folders within a GCS directory and calculate exact Jaccard similarity matrices.

    Args:
        input_gcs_dir (str): Input GCS directory path.
        output_gcs_dir (str): Output GCS directory path.
        folders_to_process (list): List of folder names within the input directory to process.

    Output:
        Saves exact similarity results into output GCS directory.
    """

    input_gcs_dir = input_gcs_dir.rstrip("/")
    output_gcs_dir = output_gcs_dir.rstrip("/")

    def jaccard_similarity(set1, set2):
        if not set1 or not set2:
            return 0.0
        intersection = len(set(set1).intersection(set(set2)))
        union = len(set(set1).union(set(set2)))
        return float(intersection) / float(union)

    jaccard_udf = udf(jaccard_similarity, FloatType())

    for folder_name in folders_to_process:
        folder_path = f"{input_gcs_dir}/{folder_name}"
        output_folder_path = f"{output_gcs_dir}/{folder_name}"

        # Read all CSVs inside the folder
        df = spark.read.option("header", True).csv(f"{folder_path}/*.csv")

        # Check if required columns exist
        expected_cols = {'propagated_edge', 'Term'}
        if not expected_cols.issubset(set(df.columns)):
            print(f"Skipping folder {folder_name}: missing required columns.")
            continue

        # Explode propagated_edge column
        df_exploded = df.withColumn(
            "propagated_edge_exploded", 
            explode(split(col("propagated_edge"), ","))
        ).dropna(subset=["propagated_edge_exploded"])

        # Group by target and collect associated terms
        target_terms = df_exploded.groupBy("propagated_edge_exploded") \
            .agg(F.collect_set("Term").alias("terms"))

        # Create a DataFrame with all pairs of targets
        target_terms_alias1 = target_terms.alias("target1")
        target_terms_alias2 = target_terms.alias("target2")
        
        # Cross join to get all pairs, then filter to avoid duplicates
        pairs = target_terms_alias1.crossJoin(target_terms_alias2) \
            .filter("target1.propagated_edge_exploded <= target2.propagated_edge_exploded")

        # Calculate exact Jaccard similarity for each pair
        similarity_df = pairs.select(
            col("target1.propagated_edge_exploded").alias("target1"),
            col("target2.propagated_edge_exploded").alias("target2"),
            jaccard_udf(col("target1.terms"), col("target2.terms")).alias("similarity")
        )

        # Save the similarity results
        similarity_df.write.mode('overwrite').option("header", True).csv(f"{output_folder_path}")

        print(f"Output: {output_folder_path}")

In [34]:
gsea_dir = "gs://ot-team/polina/pathway_propagation_validation_v2/gsea_output"
output_dir = "gs://ot-team/polina/pathway_propagation_validation_v2/similarity_mtx/jaccard_spark"

# library = ["KEGG_2021_Human"]
library = ["test"]

calculate_jaccard_similarity_spark_exact(gsea_dir, output_dir, library)

                                                                                

Output: gs://ot-team/polina/pathway_propagation_validation_v2/similarity_mtx/jaccard_spark/test


In [35]:
df_show = spark.read.csv("gs://ot-team/polina/pathway_propagation_validation_v2/similarity_mtx/jaccard_spark/test", header=True, inferSchema=True)

                                                                                

In [44]:
df_show.filter(col("target1") == "CDK4").show(10)

[Stage 174:>                                                        (0 + 1) / 1]

+-------+-------+----------+
|target1|target2|similarity|
+-------+-------+----------+
|   CDK4|   CDK4|       1.0|
|   CDK4|   CDK5|       0.0|
|   CDK4| CDK5R1|       0.0|
|   CDK4|   CDK6|0.84615386|
|   CDK4|   CDK7|      0.04|
|   CDK4| CDKN1A| 0.5945946|
|   CDK4| CDKN1B| 0.3548387|
|   CDK4| CDKN1C|      0.04|
|   CDK4| CDKN2A| 0.5769231|
|   CDK4| CDKN2B|0.25925925|
+-------+-------+----------+
only showing top 10 rows



                                                                                