In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql import functions as F

In [2]:
# Reuse the already-running SparkSession if there is one; otherwise start a new one.
spark = SparkSession.builder.getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/09/11 11:11:59 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Load the target search-facet table from a local parquet directory.
facets = spark.read.parquet("/Users/polina/Pathwaganda/data/search_facet_target")

                                                                                

In [4]:
# Peek at the first ten facet rows to check the schema.
facets.show(n=10)

+--------------------+-------------+-----------------+------------+
|               label|     category|        entityIds|datasourceId|
+--------------------+-------------+-----------------+------------+
|2'-5'-oligoadenyl...|Approved Name|[ENSG00000111335]|        NULL|
|3-hydroxybutyrate...|Approved Name|[ENSG00000161267]|        NULL|
|5'-aminolevulinat...|Approved Name|[ENSG00000158578]|        NULL|
|A-kinase anchorin...|Approved Name|[ENSG00000170776]|        NULL|
|A-kinase interact...|Approved Name|[ENSG00000166452]|        NULL|
| ABI family member 3|Approved Name|[ENSG00000108798]|        NULL|
|ACBD3 antisense R...|Approved Name|[ENSG00000234478]|        NULL|
|  ACCSL pseudogene 1|Approved Name|[ENSG00000230570]|        NULL|
|ACD shelterin com...|Approved Name|[ENSG00000102977]|        NULL|
|  ACTB pseudogene 12|Approved Name|[ENSG00000233125]|        NULL|
+--------------------+-------------+-----------------+------------+
only showing top 10 rows


In [5]:
# List every distinct facet category, without truncating long names.
distinct_categories = facets.select("category").distinct()
distinct_categories.show(truncate=False)

+-----------------------------+
|category                     |
+-----------------------------+
|Approved Name                |
|Approved Symbol              |
|Target ID                    |
|GO:CC                        |
|GO:MF                        |
|GO:BP                        |
|Subcellular Location         |
|Tractability PROTAC          |
|Tractability Other Modalities|
|Tractability Antibody        |
|Reactome                     |
|Tractability Small Molecule  |
|ChEMBL Target Class          |
+-----------------------------+



In [6]:
# Preview the rows that belong to the GO:CC category.
is_go_cc = col("category") == "GO:CC"
facets.filter(is_go_cc).show()

+--------------------+--------+--------------------+------------+
|               label|category|           entityIds|datasourceId|
+--------------------+--------+--------------------+------------+
|           7SK snRNP|   GO:CC|[ENSG00000174720,...|  GO:0120259|
|              A band|   GO:CC|[ENSG00000134571,...|  GO:0031672|
|    CCR4-NOT complex|   GO:CC|[ENSG00000138767,...|  GO:0030014|
|    ESC/E(Z) complex|   GO:CC|[ENSG00000108799,...|  GO:0035098|
|IgM immunoglobuli...|   GO:CC|[ENSG00000211899,...|  GO:0071753|
|L-type voltage-ga...|   GO:CC|[ENSG00000075461,...|  GO:1990454|
|Rad51B-Rad51C-Rad...|   GO:CC|[ENSG00000196584,...|  GO:0033063|
|       TORC2 complex|   GO:CC|[ENSG00000149212,...|  GO:0031932|
|Toll-like recepto...|   GO:CC|[ENSG00000174125,...|  GO:0035354|
|  ciliary basal body|   GO:CC|[ENSG00000118965,...|  GO:0036064|
|cis-Golgi network...|   GO:CC|[ENSG00000099246,...|  GO:0033106|
|collagen type V t...|   GO:CC|[ENSG00000130635,...|  GO:0005588|
|collagen 

# Prepare GMT files

In [9]:
from pyspark.sql import functions as F

def export_gmt_files(spark_df, categories, output_dir, target_parquet):
    """Write one GMT-style file per facet category, with approved gene symbols per term.

    Every output line has the form ``label{datasourceId}<TAB>symbol<TAB>symbol...``.
    The lines are written verbatim (no CSV quoting), so labels that contain
    commas or double quotes stay intact.

    Args:
        spark_df: Spark DataFrame providing the columns ``label``, ``category``,
            ``entityIds`` (exploded into one gene id per row) and ``datasourceId``.
        categories: Iterable of ``category`` values; each one is written to
            ``<output_dir>/<category>.gmt``.
        output_dir: Destination directory; created when it does not exist.
        target_parquet: Path of a parquet dataset with the columns ``id`` and
            ``approvedSymbol``, used to translate gene ids into symbols.

    Returns:
        None. The result is the set of files written to ``output_dir``.
    """
    # Local import keeps this cell runnable without relying on a later cell's imports.
    import os

    os.makedirs(output_dir, exist_ok=True)

    # Use the session that owns the input frame instead of a notebook-global `spark`.
    target_df = (
        spark_df.sparkSession.read.parquet(target_parquet)
        .select(F.col("id").alias("geneId"), "approvedSymbol")
    )

    # Term = label followed by the datasource id in braces, e.g. "A band{GO:0031672}".
    df = spark_df.withColumn(
        "Term",
        F.concat_ws("", F.col("label"), F.lit("{"), F.col("datasourceId"), F.lit("}"))
    )

    for cat in categories:
        df_cat = (
            df.filter(F.col("category") == cat)
              # One row per (term, gene id)
              .withColumn("geneId", F.explode("entityIds"))
              # Inner join: gene ids missing from the target table are dropped
              .join(target_df, on="geneId", how="inner")
        )

        # A gene set is a set: de-duplicate the symbols and sort them so that
        # repeated runs produce byte-identical files.
        df_grouped = (
            df_cat.groupBy("Term")
                  .agg(F.sort_array(F.collect_set("approvedSymbol")).alias("symbols"))
                  .orderBy("Term")
        )

        output_path = f"{output_dir}/{cat}.gmt"

        # Write the lines directly. pandas' to_csv would apply CSV quoting and
        # wrap any line whose label contains a comma or a quote in '"'.
        with open(output_path, "w", encoding="utf-8", newline="\n") as out:
            for row in df_grouped.collect():
                symbols = row["symbols"]
                if not symbols:
                    # A term without any resolvable symbol is not a usable gene set
                    continue
                out.write(row["Term"] + "\t" + "\t".join(symbols) + "\n")

In [10]:
# Inputs: the target annotation table and the folder that receives the GMT files
target_parquet = "/Users/polina/Pathwaganda/data/target"
output_dir = "/Users/polina/Pathwaganda/data/gmt_pathway_files_prep/from_facets"

# Facet categories to export, one GMT file each
categories = ["Reactome", "GO:BP", "GO:MF", "GO:CC"]

# Run the export
export_gmt_files(
    spark_df=facets,
    categories=categories,
    output_dir=output_dir,
    target_parquet=target_parquet,
)


## Propagate gene annotations to ancestor terms

In [14]:
import json
import re
from collections import defaultdict

def propagate_targets_to_gmt(go_json_path, gmt_path, output_path, namespace="biological_process"):
    """Propagate the genes of each GO term in a GMT file to all of its ancestors.

    Reads a GO ontology in OBO Graphs JSON form and a GMT-style file whose first
    column looks like ``label{GO:XXXXXXX}`` and whose remaining tab-separated
    columns are genes. Every gene of a term is also added to each ancestor
    reachable through ``is_a`` / ``part_of`` (``BFO_0000050``) edges. Terms and
    ancestors that are deprecated or outside ``namespace`` are ignored.

    Args:
        go_json_path: Path to the ontology JSON; ``graphs[0]`` is used.
        gmt_path: Path to the input GMT-style file.
        output_path: Path of the propagated GMT file; its parent directory is
            created when missing.
        namespace: OBO namespace to keep, e.g. ``"biological_process"``.

    Returns:
        None. The result is written to ``output_path`` and a message is printed.
    """
    # Local import: this cell only imports json, re and defaultdict.
    import os

    # --- Load GO JSON ---
    with open(go_json_path, "r", encoding="utf-8") as f:
        go_data = json.load(f)
    graph = go_data["graphs"][0]

    # Map GO:XXXXXXX -> {label, namespace, deprecated}
    go_info = {}
    for node in graph["nodes"]:
        # .get(): a node without a "type" key is skipped instead of raising KeyError
        if node.get("type") != "CLASS":
            continue
        go_id = "GO:" + node["id"].split("_")[-1]  # .../GO_0000014 -> GO:0000014
        meta = node.get("meta", {})
        ns = None
        for bpv in meta.get("basicPropertyValues", []):
            if bpv["pred"].endswith("hasOBONamespace"):
                ns = bpv["val"]
        go_info[go_id] = {
            "label": node.get("lbl", go_id),
            "namespace": ns,
            "deprecated": meta.get("deprecated", False),
        }

    # --- Build child -> parents relations (is_a and part_of only) ---
    child_to_parents = defaultdict(set)
    for edge in graph.get("edges", []):
        if edge["pred"].endswith(("is_a", "BFO_0000050")):
            child_to_parents[edge["sub"]].add(edge["obj"])

    def get_ancestors(go_uri):
        """Return every node reachable from ``go_uri`` by following parent edges."""
        ancestors = set()
        stack = [go_uri]
        while stack:
            cur = stack.pop()
            for parent in child_to_parents.get(cur, ()):
                if parent not in ancestors:
                    ancestors.add(parent)
                    stack.append(parent)
        return ancestors

    # --- Parse GMT ---
    propagated = defaultdict(set)
    term_labels = {}  # GO id -> first column ("label{GO:XXXX}") used when writing

    with open(gmt_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # A GMT written through a CSV writer has lines containing a comma or
            # a quote wrapped in '"' (inner quotes doubled); undo that so the
            # label and the last gene do not carry stray quote characters.
            if len(line) >= 2 and line[0] == '"' and line[-1] == '"':
                line = line[1:-1].replace('""', '"')

            parts = line.split("\t")
            if len(parts) < 2:
                continue
            desc = parts[0]
            genes = [g for g in parts[1:] if g]  # drop empty fields from doubled tabs
            if not genes:
                continue

            m = re.search(r"\{(GO:\d+)\}", desc)
            if not m:
                continue
            go_term = m.group(1)

            # Skip terms that are unknown, in another namespace, or deprecated
            term_info = go_info.get(go_term, {})
            if term_info.get("namespace") != namespace or term_info.get("deprecated"):
                continue

            # Genes of the term itself
            propagated[go_term].update(genes)
            term_labels[go_term] = desc  # keep the label exactly as given in the input

            # Genes propagated to every ancestor
            go_uri = f"http://purl.obolibrary.org/obo/{go_term.replace(':', '_')}"
            for ancestor in get_ancestors(go_uri):
                ancestor_go = "GO:" + ancestor.split("_")[-1]
                info = go_info.get(ancestor_go, {})
                if info.get("namespace") == namespace and not info.get("deprecated"):
                    propagated[ancestor_go].update(genes)
                    # Ancestors absent from the input get a label built from the ontology
                    if ancestor_go not in term_labels:
                        term_labels[ancestor_go] = f"{info.get('label', ancestor_go)}{{{ancestor_go}}}"

    # --- Write GMT ---
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)  # e.g. the "propagated/" sub-folder

    with open(output_path, "w", encoding="utf-8") as out:
        for go_id, genes in propagated.items():
            label = term_labels.get(go_id, f"{go_id}{{{go_id}}}")
            out.write(label + "\t" + "\t".join(sorted(genes)) + "\n")

    print(f"Propagated GMT saved to {output_path}")


In [17]:
# GO biological process: unpropagated input GMT, propagated output GMT, GO ontology graph
go_gmt = "/Users/polina/Pathwaganda/data/gmt_pathway_files_prep/from_facets/GO:BP.gmt"
output_path = "/Users/polina/Pathwaganda/data/gmt_pathway_files_prep/from_facets/propagated/GO:BP.gmt"
go_hirarchies = "/Users/polina/Pathwaganda/data/gmt_pathway_files_prep/hierarchy_files/go-basic.json"

result = propagate_targets_to_gmt(
    go_json_path=go_hirarchies,
    gmt_path=go_gmt,
    output_path=output_path,
    namespace="biological_process",
)

Propagated GMT saved to /Users/polina/Pathwaganda/data/gmt_pathway_files_prep/from_facets/propagated/GO:BP.gmt


In [18]:
# GO cellular component: unpropagated input GMT, propagated output GMT, GO ontology graph
go_gmt = "/Users/polina/Pathwaganda/data/gmt_pathway_files_prep/from_facets/GO:CC.gmt"
output_path = "/Users/polina/Pathwaganda/data/gmt_pathway_files_prep/from_facets/propagated/GO:CC.gmt"
go_hirarchies = "/Users/polina/Pathwaganda/data/gmt_pathway_files_prep/hierarchy_files/go-basic.json"

result = propagate_targets_to_gmt(
    go_json_path=go_hirarchies,
    gmt_path=go_gmt,
    output_path=output_path,
    namespace="cellular_component",
)

Propagated GMT saved to /Users/polina/Pathwaganda/data/gmt_pathway_files_prep/from_facets/propagated/GO:CC.gmt


In [None]:
# GO molecular function: unpropagated input GMT, propagated output GMT, GO ontology graph
go_gmt = "/Users/polina/Pathwaganda/data/gmt_pathway_files_prep/from_facets/GO:MF.gmt"
output_path = "/Users/polina/Pathwaganda/data/gmt_pathway_files_prep/from_facets/propagated/GO:MF.gmt"
go_hirarchies = "/Users/polina/Pathwaganda/data/gmt_pathway_files_prep/hierarchy_files/go-basic.json"

result = propagate_targets_to_gmt(
    go_json_path=go_hirarchies,
    gmt_path=go_gmt,
    output_path=output_path,
    namespace="molecular_function",
)

Propagated GMT saved to /Users/polina/Pathwaganda/data/gmt_pathway_files_prep/from_facets/propagated/GO:MF.gmt


25/09/11 13:41:21 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 904253 ms exceeds timeout 120000 ms
25/09/11 13:41:21 WARN SparkContext: Killing executors is not supported by current scheduler.
25/09/11 13:41:25 WARN Executor: Issue communicating with driver in heartbeater
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:53)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:342)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:101)
	at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:85)
	at org.apache.spark.storage.BlockManagerMaster.registerBlockManager(BlockManagerMaster.scala:81)
	at org.apache.spark.storage.BlockManager.reregister(BlockManager.scala:669)
	at org.apache.spark.executor.Executor.reportHeartBeat(Executor.scala:1296)
	at o

## Prepare hierarchy files from the GO terms JSON

Save in the Reactome format: a tab-separated text file with two columns, parent and child.

In [5]:
import json
import os

def extract_go_relationships(json_file, output_dir):
    """Export GO parent/child relationships as one two-column file per namespace.

    Reads ``graphs[0]`` of an ontology JSON file and writes ``go_bp.txt``,
    ``go_cc.txt`` and ``go_mf.txt`` into ``output_dir``. Each line is
    ``parent<TAB>child``, where both ids are the last path segment of the node
    URI (e.g. ``GO_0008150``). Only ``is_a`` and ``part_of`` edges are kept, and
    an edge is assigned to the namespace of its child. Edges whose child is
    unknown, deprecated/obsolete, or outside the three namespaces are skipped.

    Args:
        json_file: Path to the ontology JSON file.
        output_dir: Directory for the output files; created when missing.

    Returns:
        None. The three files are written (possibly empty) and a message is printed.
    """
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    graph = data["graphs"][0]
    nodes = graph["nodes"]
    all_edges = graph["edges"]

    # Map namespace -> output filename
    ns_to_file = {
        "biological_process": "go_bp.txt",
        "cellular_component": "go_cc.txt",
        "molecular_function": "go_mf.txt"
    }

    # Predicates are matched by suffix so that part_of is recognised both as the
    # short name and as the full URI ending in BFO_0000050. An exact comparison
    # against "part_of" would silently drop the URI form.
    hierarchy_pred_suffixes = ("is_a", "part_of", "BFO_0000050")

    # Build node lookup: id -> namespace and deprecation flag
    node_info = {}
    for node in nodes:
        meta = node.get("meta", {})
        label = node.get("lbl", "")

        namespace = None
        for bpv in meta.get("basicPropertyValues", []):
            if bpv["pred"].endswith("hasOBONamespace"):
                namespace = bpv["val"]
                break

        node_info[node["id"]] = {
            "namespace": namespace,
            # Treat nodes labelled "obsolete ..." as deprecated even without the flag
            "deprecated": meta.get("deprecated", False) or label.startswith("obsolete"),
        }

    # Collect (parent, child) pairs per namespace; sets remove duplicate edges
    category_edges = {ns: set() for ns in ns_to_file}

    for edge in all_edges:
        if not edge["pred"].endswith(hierarchy_pred_suffixes):
            continue

        child = edge["sub"]
        parent = edge["obj"]

        child_info = node_info.get(child)
        if child_info is None:
            continue

        ns = child_info["namespace"]
        if ns not in ns_to_file or child_info["deprecated"]:
            continue

        category_edges[ns].add((parent.split("/")[-1], child.split("/")[-1]))

    # Save results
    os.makedirs(output_dir, exist_ok=True)

    for ns, ns_edges in category_edges.items():
        file_path = os.path.join(output_dir, ns_to_file[ns])
        with open(file_path, "w", encoding="utf-8") as f:
            for parent, child in sorted(ns_edges):
                f.write(f"{parent}\t{child}\n")

    print(f"Saved {len(ns_to_file)} files in {output_dir}")

In [6]:
# Destination folder for the per-namespace parent/child files, and the GO ontology graph to read
output_hier = "/Users/polina/Pathwaganda/data/gmt_pathway_files_prep/from_facets/hierarchies"
go_hirarchies = "/Users/polina/Pathwaganda/data/gmt_pathway_files_prep/hierarchy_files/go-basic.json"

extract_go_relationships(json_file=go_hirarchies, output_dir=output_hier)

Saved 3 files in /Users/polina/Pathwaganda/data/gmt_pathway_files_prep/from_facets/hierarchies
