In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, explode, split, array_distinct, udf
from pyspark.ml.feature import MinHashLSH
from pyspark.sql.types import FloatType
import pyspark.sql.functions as F
import os
import numpy as np
import gcsfs
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql.functions import udf
from itertools import combinations

In [2]:
spark = SparkSession.builder.getOrCreate()

# Similarity matrix for propagation results

For future speed optiomisation use Minhash for similarity calculation.

## Jaccard

In [3]:
def calculate_jaccard_similarity_spark_exact(input_gcs_dir, output_gcs_dir, folders_to_process):
    """
    Process CSV files in specified GCS folders and calculate exact Jaccard similarity matrices.
    Creates one output folder per input CSV file.

    Args:
        input_gcs_dir (str): Input GCS directory path (gs://bucket/path/)
        output_gcs_dir (str): Output GCS directory path (gs://bucket/path/)
        folders_to_process (list): List of folder names to process
    """
    from pyspark.sql.functions import col, explode, split
    from pyspark.sql import functions as F

    input_gcs_dir = input_gcs_dir.rstrip("/")
    output_gcs_dir = output_gcs_dir.rstrip("/")

    def jaccard_similarity(set1, set2):
        if not set1 or not set2:
            return 0.0
        intersection = len(set(set1).intersection(set(set2)))
        union = len(set(set1).union(set(set2)))
        return float(intersection) / float(union)

    jaccard_udf = F.udf(jaccard_similarity, FloatType())

    for folder_name in folders_to_process:
        # Get list of CSV files using Spark's filesystem API
        csv_files = spark.sparkContext.wholeTextFiles(f"{input_gcs_dir}/{folder_name}/*.csv").keys().collect()
        
        for csv_file in csv_files:
            # Extract just the filename (last part of GCS path)
            csv_name = csv_file.split("/")[-1].replace(".csv", "")
            output_folder_path = f"{output_gcs_dir}/{folder_name}/{csv_name}"

            # Read single CSV file
            df = spark.read.option("header", True).csv(csv_file)

            # Check required columns
            expected_cols = {'propagated_edge', 'Term'}
            if not expected_cols.issubset(set(df.columns)):
                print(f"Skipping {csv_file}: missing required columns.")
                continue

            # Process data
            df_exploded = df.withColumn(
                "propagated_edge_exploded", 
                explode(split(col("propagated_edge"), ","))
            ).dropna(subset=["propagated_edge_exploded"])

            target_terms = df_exploded.groupBy("propagated_edge_exploded") \
                .agg(F.collect_set("Term").alias("terms"))

            # Cross join and filter
            target_terms_alias1 = target_terms.alias("targetA")
            target_terms_alias2 = target_terms.alias("targetB")
            pairs = target_terms_alias1.crossJoin(target_terms_alias2) \
                .filter("targetA.propagated_edge_exploded <= targetB.propagated_edge_exploded")

            # Calculate similarities
            similarity_df = pairs.select(
                col("targetA.propagated_edge_exploded").alias("targetA"),
                col("targetB.propagated_edge_exploded").alias("targetB"),
                jaccard_udf(col("targetA.terms"), col("targetB.terms")).alias("jaccardSim")
            )

            # Save results
            similarity_df.write.mode('overwrite').option("header", True).csv(output_folder_path)
            print(f"Output: {output_folder_path}")

In [4]:
gsea_dir = "gs://ot-team/polina/pathway_propagation_validation_v2/gsea_output"
output_dir = "gs://ot-team/polina/pathway_propagation_validation_v2/similarity_mtx/jaccard_spark"

# library = ["KEGG_2021_Human"]
library = ["Reactome_Pathways_2024"]
# library = ["test"]

calculate_jaccard_similarity_spark_exact(gsea_dir, output_dir, library)

                                                                                

Output: gs://ot-team/polina/pathway_propagation_validation_v2/similarity_mtx/jaccard_spark/Reactome_Pathways_2024/EFO_0000095_ge_mm_som_gsea_Reactome_Pathways_2024_pval0.05


                                                                                

Output: gs://ot-team/polina/pathway_propagation_validation_v2/similarity_mtx/jaccard_spark/Reactome_Pathways_2024/EFO_0000183_ge_mm_som_gsea_Reactome_Pathways_2024_pval0.05


                                                                                

Output: gs://ot-team/polina/pathway_propagation_validation_v2/similarity_mtx/jaccard_spark/Reactome_Pathways_2024/EFO_0000274_ge_mm_gsea_Reactome_Pathways_2024_pval0.05


                                                                                

Output: gs://ot-team/polina/pathway_propagation_validation_v2/similarity_mtx/jaccard_spark/Reactome_Pathways_2024/EFO_0000275_ge_mm_gsea_Reactome_Pathways_2024_pval0.05


                                                                                

Output: gs://ot-team/polina/pathway_propagation_validation_v2/similarity_mtx/jaccard_spark/Reactome_Pathways_2024/EFO_0000341_ge_mm_gsea_Reactome_Pathways_2024_pval0.05


                                                                                

Output: gs://ot-team/polina/pathway_propagation_validation_v2/similarity_mtx/jaccard_spark/Reactome_Pathways_2024/EFO_0000222_ge_mm_som_gsea_Reactome_Pathways_2024_pval0.05


                                                                                

Output: gs://ot-team/polina/pathway_propagation_validation_v2/similarity_mtx/jaccard_spark/Reactome_Pathways_2024/EFO_0000384_ge_mm_gsea_Reactome_Pathways_2024_pval0.05


ERROR:root:KeyboardInterrupt while sending command.                 (0 + 1) / 1]
Traceback (most recent call last):
  File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/miniconda3/lib/python3.11/socket.py", line 706, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt


KeyboardInterrupt: 

                                                                                

In [57]:
df_show = spark.read.csv("gs://ot-team/polina/pathway_propagation_validation_v2/similarity_mtx/jaccard_spark/KEGG_2021_Human/EFO_0000565_ge_mm_som_gsea_KEGG_2021_Human_pval0.05", header=True, inferSchema=True)

                                                                                

In [62]:
df_show.filter(col("targetA") == "CDK2").show(20)
# df_show.show(10)

[Stage 476:>                                                        (0 + 1) / 1]

+-------+----------+-----------+
|targetA|   targetB| jaccardSim|
+-------+----------+-----------+
|   CDK2|      CDK2|        1.0|
|   CDK2|      CDK4| 0.41935483|
|   CDK2|      CDK5|        0.0|
|   CDK2|    CDK5R1|        0.0|
|   CDK2|      CDK6| 0.41379312|
|   CDK2|      CDK7|0.055555556|
|   CDK2|      CDK9|        0.0|
|   CDK2|    CDKN1A|  0.3902439|
|   CDK2|    CDKN1B|        0.5|
|   CDK2|    CDKN1C|0.055555556|
|   CDK2|    CDKN2A| 0.25925925|
|   CDK2|    CDKN2B|        0.5|
|   CDK2|    CDKN2C| 0.15789473|
|   CDK2|    CDKN2D| 0.11111111|
|   CDK2|      CDX2|0.055555556|
|   CDK2|     CEBPA|       0.05|
|   CDK2|     CEBPB|        0.0|
|   CDK2|     CEBPE|        0.0|
|   CDK2|     CENPS|        0.0|
|   CDK2|CENPS-CORT|        0.0|
+-------+----------+-----------+
only showing top 20 rows



                                                                                

Ok it's too long. Lets return to hash-tables.