<a href="https://colab.research.google.com/github/wikipathways/BioThings_Explorer_PFOCR_clustering/blob/main/bte_clustering_AA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Environment setup

In [None]:
#install the required packages
!pip install numpy pandas requests requests_cache SetSimilaritySearch fisher

In [1]:
#load packages
from urllib.request import urlopen
import requests
import json
from copy import copy,deepcopy
import pandas as pd
from fisher import pvalue_npy

## User Inputs

In [2]:
# which node ID from message.query_graph.nodes should be used for grouping. 
# The entities mapped to that node ID are the "result entities"
# If all then use all the node ids in query
# If specific node id set (e.g. {'MESH:D000068877', 'MESH:D001249'}) then use only these node ids
node_ID_for_grouping = "all"

# n corresponds to the number of results desired
n = 15

# user-specified node IDs for the subject, the object, both, or neither that must be included in the figures
# If all then use all the node ids in query
# If specific node id set (e.g. {'MESH:D000068877', 'MESH:D001249'}) then use only these node ids
required_curies = "all"

## Get BTE TRAPI Results

Master notebook: https://github.com/wikipathways/pathway-figure-ocr/blob/master/notebooks/bte_clustering.ipynb

This notebook reads in the TRAPI results from a URL for the query below.

### Query
Imatinib - [Gene] - Asthma

### Results URL
https://arax.ncats.io/api/arax/v1.3/response/7b14f961-9066-41f7-9e3b-d76b2b4a7fac (83kB, 7 results); "results" in n1

In [3]:
# store the URL in url as parameter for urlopen
trapi_results_url = "https://arax.ncats.io/api/arax/v1.3/response/7b14f961-9066-41f7-9e3b-d76b2b4a7fac"
   
# storing the JSON response from url
with urlopen(trapi_results_url) as url:
    trapi_results = json.load(url)

# print the json response
print(trapi_results)



In [4]:
# get the TRAPI results and knowledge graph from the jason response
trapi_message = trapi_results['message']

# print the TRAPI results and knowledge graph
print(trapi_message)

{'knowledge_graph': {'edges': {'02217ffa55351b0c3c0a32431158ca27': {'attributes': [{'attribute_type_id': 'biolink:aggregator_knowledge_source', 'value': ['infores:biothings-explorer'], 'value_type_id': 'biolink:InformationResource'}, {'attribute_source': 'infores:automat-ctd', 'attribute_type_id': 'biolink:original_knowledge_source', 'attributes': None, 'description': None, 'original_attribute_name': 'biolink:original_knowledge_source', 'value': ['infores:ctd'], 'value_type_id': 'biolink:InformationResource', 'value_url': None}, {'attribute_source': 'infores:automat-ctd', 'attribute_type_id': 'biolink:aggregator_knowledge_source', 'attributes': None, 'description': None, 'original_attribute_name': 'biolink:aggregator_knowledge_source', 'value': ['infores:automat-ctd'], 'value_type_id': 'biolink:InformationResource', 'value_url': None}, {'attribute_source': None, 'attribute_type_id': 'biolink:relation', 'attributes': None, 'description': None, 'original_attribute_name': 'relation', 'val

## Create a dataframe with all TRAPI results
  
The final dataframe is called ```trapi_results_df```

In [5]:
# NOTE: we're including categories specified in the query template
# in the next cell, and in a cell further down, we're also including
# any categories from the TRAPI results.
#
# TODO: Some categories are supersets of others. Should we handle
# this systematically?

# get the query graph from the TRAPI results
query = trapi_message["query_graph"]

curie_categories = set()
for v in query["nodes"].values():
    if "categories" in v:
        for category in v["categories"]:
            curie_categories.add(category)

In [6]:
# for genes/gene products, chemicals and diseases
# we need to resolve the CURIES as PFOCR data has only mesh ids and NCBI gene ids
preferred_prefixes = set(["NCBIGene", "MESH"])
trapi_results_unified_curies = set()
unification_failed_curies = set()
unified_prefixes = set()
all_prefixes = set()
curie_to_name = dict()
curie_to_categories = dict()
curie_to_unified_curie = dict()
for k, v in trapi_message["knowledge_graph"]["nodes"].items():
    name = v["name"]
    categories = v["categories"]
    curie_categories |= set(categories)

    for a in v["attributes"]:
        if a["attribute_type_id"] == "biolink:xref":
            curies = a["value"]

            # k should always be one of the curies
            if not k in curies:
                raise Exception(f"key {k} not in {curies}")

            unified_curie = None
            intersecting_trapi_results_unified_curies = trapi_results_unified_curies.intersection(
                set(curies)
            )
            if len(intersecting_trapi_results_unified_curies) > 1:
                multiple_matches = list(intersecting_trapi_results_unified_curies)
                raise Exception(f"matching multiple: {k} to {multiple_matches}")
            elif len(intersecting_trapi_results_unified_curies) == 1:
                unified_curie = list(intersecting_trapi_results_unified_curies)[0]
            else:
                # get curie for preferred prefix. usually this is k, but not always.
                for curie in curies:
                    [prefix, identifier] = curie.split(":")
                    if prefix in preferred_prefixes:
                        unified_curie = curie
                        trapi_results_unified_curies.add(unified_curie)
                        break

            if not unified_curie:
                if k in curie_to_unified_curie:
                    unified_curie = curie_to_unified_curie[k]
                else:
                    unification_failed_curies.add(k)
                    break

                #raise Exception(f"failed to find a unified curie for {k} in {curies}")

            [unified_prefix, unified_identifier] = unified_curie.split(":")
            unified_prefixes.add(unified_prefix)

            for curie in curies:
                [prefix, identifier] = curie.split(":")
                all_prefixes.add(prefix)
                if not curie in curie_to_unified_curie:
                    curie_to_unified_curie[curie] = unified_curie
                if not curie in curie_to_name:
                    curie_to_name[curie] = name
                    curie_to_categories[curie] = categories
                elif curie_to_name[curie] != name:
                    print(f"curie {curie} has multiple primary names: {curie_to_name[curie]} and {name}")
                    #raise Exception(f"curie{curie} has multiple names: {curie_to_name[curie]} and {name}")


#print(f'curie_categories: {curie_categories}')
#print("")
#print(f"curie_to_name key count: {len(curie_to_name.keys())}")
#print(f"curie_to_unified_curie key count: {len(curie_to_unified_curie.keys())}")
#print("")
print("all CURIE prefixes found:")
print(all_prefixes)
print("")
print("unified CURIE prefixes found:")
print(unified_prefixes)
print("")
print(f"failed to unify {len(unification_failed_curies)} CURIEs")

all CURIE prefixes found:
{'UMLS', 'HP', 'NCIT', 'OMIM', 'PUBCHEM.COMPOUND', 'CHEMBL.COMPOUND', 'HGNC', 'ENSEMBL', 'MESH', 'UniProtKB', 'SNOMEDCT', 'PR', 'CAS', 'CHEBI', 'INCHIKEY', 'NCBIGene', 'MEDDRA', 'UNII'}

unified CURIE prefixes found:
{'NCBIGene', 'MESH'}

failed to unify 0 CURIEs


In [7]:
columns = []
q_node_id_keys = set(["object", "subject"])
q_node_ids = []
q_edge_ids = []
for q_edge_id,edge_v in query["edges"].items():
    q_edge_ids.append(q_edge_id)
    q_node_id_found = False
    for k,v in edge_v.items():
        if (k in q_node_id_keys) and (type(v) is str):
            if v not in columns:
                q_node_ids.append(v)
                columns.append(v)
            if not q_node_id_found:
                q_node_id_found = True
                columns.append(q_edge_id)

query_nodes_with_ids = set()
for k, v in query['nodes'].items():
    node_ids = v.get('ids', [])
    if node_ids:
        query_nodes_with_ids.add(k)

for q_node_id in q_node_ids:
    columns.append(f"{q_node_id}_original_curie")
    columns.append(f"{q_node_id}_unified_curie")

columns.append("unified_curie_set")

trapi_result_columns = []
for i in range(len(q_node_ids)):
    trapi_result_columns.append(q_node_ids[i])
    if i < len(q_edge_ids):
        trapi_result_columns.append(q_edge_ids[i])

unified_curie_columns = []
for q_node_id in q_node_ids:
    unified_curie_columns.append(f"{q_node_id}_unified_curie")

In [8]:
trapi_results = trapi_message["results"]

#side exporation
#additonal edge information from the knowledge graph might be explored for filtering later
result_row_data = []
for trapi_result in trapi_results:
    curie_to_qnode_ids = dict()
    for qnode_id, entries in trapi_result["node_bindings"].items():
        for entry in entries:
            curie = entry["id"]
            if curie not in curie_to_qnode_ids:
                curie_to_qnode_ids[curie] = []
            curie_to_qnode_ids[curie].append(qnode_id)

    row_data_template = dict()
    q_edge_id_to_predicates = dict()
    trapi_result_curie_set = set()
    for qedge_id, entries in trapi_result["edge_bindings"].items():
        for entry in entries:
            curie = entry["id"]
            kg_entry = trapi_message["knowledge_graph"]["edges"][curie]
            subject_curie = kg_entry["subject"]
            object_curie = kg_entry["object"]
            predicate_curie = kg_entry["predicate"]
            [predicate_prefix, predicate_identifier] = predicate_curie.split(":")

            if qedge_id not in q_edge_id_to_predicates:
                q_edge_id_to_predicates[qedge_id] = set()
            q_edge_id_to_predicates[qedge_id].add(predicate_identifier)

            for curie in [subject_curie, object_curie]:
                for qnode_id in curie_to_qnode_ids[curie]:
                    if curie in curie_to_unified_curie:
                        unified_curie = curie_to_unified_curie[curie]
                    else:
                        break

                    name = curie_to_name[curie]
                    row_data_template[qnode_id] = name

                    trapi_result_curie_set.add(unified_curie)
                    row_data_template[qnode_id + "_original_curie"] = curie
                    row_data_template[qnode_id + "_unified_curie"] = unified_curie

    if len(trapi_result_curie_set) != len(q_node_ids):
        #print(f'skipping {list(curie_to_qnode_ids.keys())}')
        continue

    row_data_template["unified_curie_set"] = trapi_result_curie_set
    q_edge_ids_processed = set()
    row_datas = [row_data_template]
    for q_edge_id,predicates in q_edge_id_to_predicates.items():
        next_row_datas = []
        for row_data in row_datas:
            for predicate in predicates:
                next_row_data = deepcopy(row_data)
                next_row_data[q_edge_id] = predicate
                next_row_datas.append(
                    next_row_data 
                )
        row_datas = next_row_datas
    result_row_data += row_datas

print("warning: predicate direction(s) may be switched")
trapi_results_df = pd.DataFrame.from_records(result_row_data, columns=columns)
trapi_results_df



Unnamed: 0,n1,e0,n0,n2,e1,n1_original_curie,n1_unified_curie,n0_original_curie,n0_unified_curie,n2_original_curie,n2_unified_curie,unified_curie_set
0,VEGFA,decreases_secretion_of,Imatinib mesylate,Asthma,gene_associated_with_condition,NCBIGene:7422,NCBIGene:7422,PUBCHEM.COMPOUND:123596,MESH:D000068877,HP:0002099,MESH:D001249,"{NCBIGene:7422, MESH:D001249, MESH:D000068877}"
1,VEGFA,decreases_secretion_of,Imatinib mesylate,Asthma,contributes_to,NCBIGene:7422,NCBIGene:7422,PUBCHEM.COMPOUND:123596,MESH:D000068877,HP:0002099,MESH:D001249,"{NCBIGene:7422, MESH:D001249, MESH:D000068877}"
2,VEGFA,decreases_secretion_of,Imatinib mesylate,Asthma,associated_with,NCBIGene:7422,NCBIGene:7422,PUBCHEM.COMPOUND:123596,MESH:D000068877,HP:0002099,MESH:D001249,"{NCBIGene:7422, MESH:D001249, MESH:D000068877}"
3,VEGFA,increases_secretion_of,Imatinib mesylate,Asthma,gene_associated_with_condition,NCBIGene:7422,NCBIGene:7422,PUBCHEM.COMPOUND:123596,MESH:D000068877,HP:0002099,MESH:D001249,"{NCBIGene:7422, MESH:D001249, MESH:D000068877}"
4,VEGFA,increases_secretion_of,Imatinib mesylate,Asthma,contributes_to,NCBIGene:7422,NCBIGene:7422,PUBCHEM.COMPOUND:123596,MESH:D000068877,HP:0002099,MESH:D001249,"{NCBIGene:7422, MESH:D001249, MESH:D000068877}"
5,VEGFA,increases_secretion_of,Imatinib mesylate,Asthma,associated_with,NCBIGene:7422,NCBIGene:7422,PUBCHEM.COMPOUND:123596,MESH:D000068877,HP:0002099,MESH:D001249,"{NCBIGene:7422, MESH:D001249, MESH:D000068877}"
6,CASP8,increases_activity_of,Imatinib mesylate,Asthma,gene_associated_with_condition,NCBIGene:841,NCBIGene:841,PUBCHEM.COMPOUND:123596,MESH:D000068877,HP:0002099,MESH:D001249,"{NCBIGene:841, MESH:D001249, MESH:D000068877}"
7,CASP8,increases_activity_of,Imatinib mesylate,Asthma,associated_with,NCBIGene:841,NCBIGene:841,PUBCHEM.COMPOUND:123596,MESH:D000068877,HP:0002099,MESH:D001249,"{NCBIGene:841, MESH:D001249, MESH:D000068877}"
8,LYN,response_affected_by,Imatinib mesylate,Asthma,associated_with,NCBIGene:4067,NCBIGene:4067,PUBCHEM.COMPOUND:123596,MESH:D000068877,HP:0002099,MESH:D001249,"{NCBIGene:4067, MESH:D001249, MESH:D000068877}"
9,LYN,response_decreased_by,Imatinib mesylate,Asthma,associated_with,NCBIGene:4067,NCBIGene:4067,PUBCHEM.COMPOUND:123596,MESH:D000068877,HP:0002099,MESH:D001249,"{NCBIGene:4067, MESH:D001249, MESH:D000068877}"


## Identify result entities using user specified parameters

```node_ID_for_grouping``` variable defines which node ID(s) from message.query_graph.nodes should be used for grouping. The entities mapped to that node ID are the "result entities". If "all" then use all the node ids in query are used. If specific node id set (e.g. {'MESH:D000068877', 'MESH:D001249'}), then use only these node ids are used.

In [9]:
#get all nodes with ids from the query 
user_specified_ids = set()
for query_node in query_nodes_with_ids:
    user_specified_ids.update(trapi_results_df[f'{query_node}_unified_curie'].tolist())
user_specified_ids

{'MESH:D000068877', 'MESH:D001249'}

In [13]:
# identify node ids for identifying result entities
if(node_ID_for_grouping != "all"):
    user_specified_ids = user_specified_ids.intersection(node_ID_for_grouping)

user_specified_ids

{'MESH:D000068877', 'MESH:D001249'}

In [14]:
print("Unique CURIE count per query node:")
for q_node_id in q_node_ids:
    print(f'{q_node_id}: {len(set(trapi_results_df[q_node_id]))}')

Unique CURIE count per query node:
n1: 7
n0: 1
n2: 1


## Get PFOCR Data

Get PFOCR data from Dropbox.

In [29]:
#create a data frame for all the figure data from PFOCR
#expected size of final data frame (as in Jan, 2023): 77719 rows × 4 columns
figures_df = pd.read_csv('https://www.dropbox.com/s/3eukom49hqskp88/figures_dataframe.csv?dl=1')
figures_df

Unnamed: 0,figure_id,chemical_curie_set,disease_curie_set,gene_curie_set
0,PMC5732092__cshperspect-CYT-028522_F2.jpg,set(),set(),"{'NCBIGene:5595', 'NCBIGene:2919', 'NCBIGene:1..."
1,PMC5793760__cshperspect-TGF-022210_F4.jpg,set(),set(),"{'NCBIGene:7042', 'NCBIGene:151449', 'NCBIGene..."
2,PMC5793761__cshperspect-TGF-031989_F1.jpg,set(),set(),"{'NCBIGene:4093', 'NCBIGene:5595', 'NCBIGene:1..."
3,PMC5830892__cshperspect-CEL-027961_F2.jpg,set(),set(),"{'NCBIGene:72', 'NCBIGene:59', 'NCBIGene:55249..."
4,PMC5830900__cshperspect-TGF-031997_F1.jpg,{'MESH:D011374'},set(),"{'NCBIGene:4093', 'NCBIGene:7042', 'NCBIGene:1..."
...,...,...,...,...
77714,PMC2804790__253_2009_2262_Fig1_HTML.jpg,"{'MESH:D009243', 'MESH:C031105', 'MESH:D000447...",set(),"{'NCBIGene:125', 'NCBIGene:137872', 'NCBIGene:..."
77715,PMC6332787__thnov09p0126g006.jpg,set(),set(),"{'NCBIGene:6387', 'NCBIGene:8295', 'NCBIGene:2..."
77716,PMC5807036__IJO-52-03-0787-g01.jpg,set(),set(),"{'NCBIGene:1977', 'NCBIGene:4254', 'NCBIGene:4..."
77717,PMC6770832__cancers-11-01236-g005.jpg,set(),set(),"{'NCBIGene:5595', 'NCBIGene:4093', 'NCBIGene:7..."


In [30]:
# read in figure metadata for all PFOCR figures
figure_metadata_df = pd.read_csv('https://www.dropbox.com/s/9jorbzpq2k8n5tr/figures_metadata.csv?dl=1', index_col=0)
figure_metadata_df

Unnamed: 0_level_0,figure_url,figure_title
figure_id,Unnamed: 1_level_1,Unnamed: 2_level_1
PMC5732092__cshperspect-CYT-028522_F2.jpg,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5...,Interleukin (IL)-17RA/RC signaling pathways
PMC5793760__cshperspect-TGF-022210_F4.jpg,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5...,TGFB signaling pathways
PMC5793761__cshperspect-TGF-031989_F1.jpg,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5...,Bone morphogenetic protein (BMP) signaling pat...
PMC5830892__cshperspect-CEL-027961_F2.jpg,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5...,FERM-binding partners of Crumbs3
PMC5830900__cshperspect-TGF-031997_F1.jpg,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5...,Role of the TGFB family in mammary gland devel...
...,...,...
PMC2804790__253_2009_2262_Fig1_HTML.jpg,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2...,Ehrlich degradation pathway from amino acid to...
PMC6332787__thnov09p0126g006.jpg,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6...,Illustration of the mechanisms underlying LPD ...
PMC5807036__IJO-52-03-0787-g01.jpg,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5...,Tumorigenic proteins significantly upregulated...
PMC6770832__cancers-11-01236-g005.jpg,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6...,Simplified TGFB pathway leading to p21 expression


In [31]:
# read in figure CURIES for all PFOCR figures
figure_curie_df = pd.read_csv('https://www.dropbox.com/s/y3ks74obs3nrjjz/figures_curie.csv?dl=1')
figure_curie_df

Unnamed: 0,figure_id,curie,category
0,PMC5732092__cshperspect-CYT-028522_F2.jpg,NCBIGene:5601,gene
1,PMC5732092__cshperspect-CYT-028522_F2.jpg,NCBIGene:5595,gene
2,PMC5732092__cshperspect-CYT-028522_F2.jpg,NCBIGene:1051,gene
3,PMC5732092__cshperspect-CYT-028522_F2.jpg,NCBIGene:10131,gene
4,PMC5732092__cshperspect-CYT-028522_F2.jpg,NCBIGene:23765,gene
...,...,...,...
1665596,PMC4609065__12885_2015_1721_Fig5_HTML.jpg,NCBIGene:1956,gene
1665597,PMC4609065__12885_2015_1721_Fig5_HTML.jpg,NCBIGene:4893,gene
1665598,PMC4609065__12885_2015_1721_Fig5_HTML.jpg,NCBIGene:2064,gene
1665599,PMC4609065__12885_2015_1721_Fig5_HTML.jpg,NCBIGene:3845,gene


## Enrichment/grouping of TRAPI results using PFOCR figures

From discussion with Andrew, we want to first identify a set of figures that can serve to cluster the TRAPI results. We can call these the "cluster figures".

**Identify Cluster Figures**

We identify the top cluster figures by iteratively running Fisher's Exact Test to identify PFOCR figures that are most similar to all the CURIEs from the BTE TRAPI results. To get each cluster figure, we do the following:

1. Run Fisher's Exact Test for the CURIEs from each figure vs. the set of the unique CURIEs found from all the TRAPI results, limiting each  set of CURIEs to those with categories found in both PFOCR and BTE TRAPI, e.g., "disease" is found in PFOCR and "biolink:Disease" is found in BTE TRAPI
2. Select the figure with the lowest p-value as a cluster figure
3. Exclude that figure and the CURIEs in it
4. Repeat to get the next cluster figure, continuing until:
    - we run out of BTE TRAPI results
    - we run out of PFOCR figures
    - we've repeated steps 1-3 a user-defined number of times (varibale ```n```)

**Cluster BTE TRAPI Results by Selected PFOCR Figures (Cluster Figures)**

Once we have a set of cluster figures, we use them to cluster the BTE TRAPI results. We do this by running Fisher's Exact Test to identify which Cluster Figure each BTE TRAPI result is most similar to. We don't actually cluster the TRAPI results by figure, however, because if a result is highly relevant to two figures, it can show up in both. It's more accurate to think of the figure as being a different view of the TRAPI results.

**View TRAPI Results by PFOCR Figure**

With BTE TRAPI results "clustered" (not actually clustered, because we show the same result multiple times for different figures if it's closely related to both) by PFOCR figure, we display the top figures, each with their most closely related BTE TRAPI results.

**Latest from Andrew**
Implementation details:
The enrichment analysis will be iterative in design. On the first iteration, all result entities will be grouped in one list, and an enrichment analysis will be run relative to all PFOCR figures. The most significantly enriched PFOCR figure will be set aside as PFOCR Figure 1.
The result entities that occurred in PFOCR Figure 1 will be removed from the result entity list, and a new enrichment analysis will be performed. The most significantly-enriched PFOCR figure will be set aside as PFOCR Figure 2.
The process will be repeated a total of n times.
The basic flow has been worked out in https://github.com/wikipathways/pathway-figure-ocr/blob/master/notebooks/bte_clustering.ipynb, but that notebook needs to be simplified for broader use.

In [None]:
# find the corresponding PFOCR category for a biolink category
biolink_category_to_pfocr_category = {
    "biolink:Drug": "chemical",
    "biolink:ChemicalEntity": "chemical",
    "biolink:SmallMolecule": "chemical",
    "biolink:Disease": "disease",
    "biolink:DiseaseOrPhenotypicFeature": "disease",
    "biolink:Gene": "gene",
    "biolink:GeneProduct": "gene",
    "biolink:Protein": "gene",
}

# limit to categories returned from query
relevant_biolink_categories = set(biolink_category_to_pfocr_category.keys()).intersection(curie_categories)
relevant_pfocr_categories = set([
    biolink_category_to_pfocr_category[
        biolink_category
    ] for biolink_category in relevant_biolink_categories
])

In [None]:
# Importing reduce for 
# rolling computations
from functools import reduce
  
def set_union(series):
    return reduce(lambda x, y: x | y, series)

def create_set(series):
    return set(series.tolist())


# TODO: remove cluster_figure_data and its associated DF.
# It's just a temporary thing to show what's happening inside.
cluster_figure_data = []
def get_next_cluster_figures(
    category_matched_pfocr_curie_count,
    remaining_figures_df,
    remaining_trapi_curie_count,
    required_curies,
    iteration_limit=n, #this is the user-defined parameter for the number of figures
    i=0,
    cluster_figure_ids=[],
):
    remaining_figure_count = len(
        remaining_figures_df
    )

    if (
        i >= iteration_limit
    ) or (
        len(remaining_trapi_curie_count) == 0
    ) or (
        remaining_figure_count == 0
    ):
        return cluster_figure_ids
    
    remaining_figures_df["figure_curie_count"] = remaining_figures_df["curie_set"].map(len)

    remaining_figures_df["figure_curies_in_trapi_results"] = remaining_figures_df["curie_set"].map(
        lambda curie_set: remaining_trapi_curie_count.intersection(
            set(curie_set)
        )
    )

    # number of CURIEs in both TRAPI results and figure
    remaining_figures_df["yes_trapi_results_yes_figure"] = remaining_figures_df[
        "figure_curies_in_trapi_results"
    ].map(len)

    # number of CURIEs in TRAPI results but not in figure
    remaining_figures_df["yes_trapi_results_no_figure"] = (
        len(remaining_trapi_curie_count) -
        remaining_figures_df["yes_trapi_results_yes_figure"]
    )

    # number of CURIEs not in TRAPI results but in figure
    remaining_figures_df["no_trapi_results_yes_figure"] = (
        remaining_figures_df["figure_curie_count"] -
        remaining_figures_df["yes_trapi_results_yes_figure"]
    )

    # number of CURIEs not in TRAPI results and not in figure,
    # ie., number of unique CURIEs only in other figures
    remaining_figures_df["no_trapi_results_no_figure"] = (
        category_matched_pfocr_curie_count -
        len(remaining_trapi_curie_count) -
        remaining_figures_df["no_trapi_results_yes_figure"]
    )

    # see https://stackoverflow.com/a/58661068/5354298
    _, _, twosided = pvalue_npy(
        remaining_figures_df["yes_trapi_results_yes_figure"].to_numpy('uint'),
        remaining_figures_df["yes_trapi_results_no_figure"].to_numpy('uint'),
        remaining_figures_df["no_trapi_results_yes_figure"].to_numpy('uint'),
        remaining_figures_df["no_trapi_results_no_figure"].to_numpy('uint'),
    )

    remaining_figures_df["p_value"] = pd.Series(twosided, index=remaining_figures_df.index)
    remaining_figures_df["p_value"].sort_values()
    
    min_df = remaining_figures_df[
        remaining_figures_df["p_value"] == remaining_figures_df["p_value"].min()
    ]
    
    if required_curies:
        # TODO: it might make sense to keep prioritizing the required CURIEs, but
        #       it's not clear whether this is always what we want.
        #       In some cases, we could obscure a different set of results that may
        #       be valuable but not because of the overlap with the required CURIEs.
        curies_to_exclude = set(
            min_df["figure_curies_in_trapi_results"].agg(set_union)
        ) - required_curies
    else:
        curies_to_exclude = set(min_df["figure_curies_in_trapi_results"].agg(set_union))
        
    if len(curies_to_exclude - required_curies) == 0:
        # On this iteration, we failed to match any new CURIE(s), other than
        # possibly just the require CURIEs.
        return cluster_figure_ids
        
    new_cluster_figure_ids = min_df.index.tolist()
    cluster_figure_ids.extend(new_cluster_figure_ids)
    
    next_min_df_score = remaining_figures_df[
        remaining_figures_df["p_value"] > remaining_figures_df["p_value"].min()
    ]["p_value"].min()
    next_min_df = remaining_figures_df[remaining_figures_df["p_value"] == next_min_df_score]
    
    for cluster_figure_id, min_row in min_df.iterrows():
        cluster_figure_data.append({
            "iteration": i,
            "figure_id": cluster_figure_id,
            "p_value": min_row["p_value"],
            "next_best_figure_ids": next_min_df.index.tolist(),
            "next_best_p_value": next_min_df["p_value"].min(),
            "yes_trapi_results_yes_figure": min_row["yes_trapi_results_yes_figure"],
            "yes_trapi_results_no_figure": min_row["yes_trapi_results_no_figure"],
            "no_trapi_results_yes_figure": min_row["no_trapi_results_yes_figure"],
            "no_trapi_results_no_figure": min_row["no_trapi_results_no_figure"],
            "figure_curies_in_trapi_results": min_row["figure_curies_in_trapi_results"],
            "curie_set": min_row["curie_set"],
            "curies_to_exclude": curies_to_exclude,
            "len_curies_to_exclude": len(curies_to_exclude),
            "remaining_figure_count": remaining_figure_count,
            "remaining_trapi_curie_count": len(remaining_trapi_curie_count),
            "cluster_figure_id_count": len(cluster_figure_ids),
        })
    
    # for next iteration, exclude figures we've already added to cluster figures
    next_remaining_figures_df = remaining_figures_df[
        (~remaining_figures_df.index.isin(
            cluster_figure_ids
        ))
    ]
    
    print(
        f'iteration {i:>2}: figures: {len(new_cluster_figure_ids):>2}, CURIEs: {len(curies_to_exclude):>2}'
    )
    
    return get_next_cluster_figures(
        category_matched_pfocr_curie_count,
        next_remaining_figures_df.reset_index().set_index("figure_id"),
        remaining_trapi_curie_count - curies_to_exclude,
        required_curies,
        iteration_limit,
        i + 1,
        cluster_figure_ids,
    )

category_matched_figure_curie_df = figure_curie_df[figure_curie_df["category"].isin(
    list(relevant_pfocr_categories)
)]
category_matched_pfocr_curie_count = len(set(
    category_matched_figure_curie_df["curie"].drop_duplicates()
))
print(f'category_matched_pfocr_curie_count: {category_matched_pfocr_curie_count}')

category_matched_figures_df = category_matched_figure_curie_df[
    ["figure_id", "curie"]
].groupby("figure_id").agg(
    create_set
).rename(columns={
    "curie": "curie_set"
})

category_matched_figures_df["figure_curies_in_trapi_results_count"] = category_matched_figures_df[
    "curie_set"
].map(
    lambda curie_set: len(trapi_results_unified_curies.intersection(
        set(curie_set)
    ))
)

overlapping_category_matched_figures_df = category_matched_figures_df[
    category_matched_figures_df["figure_curies_in_trapi_results_count"] > 0
]

# Require that cluster figures have specific user-specified CURIEs
if(required_curies == "all"):
    required_curies = user_specified_ids
if required_curies:
    overlapping_category_matched_figures_df = overlapping_category_matched_figures_df[
        overlapping_category_matched_figures_df["curie_set"].map(
            # require all of them
            # lambda curie_set: len(curie_set.intersection(required_curies)) == len(required_curies)
            # require at least one of them
            lambda curie_set: len(curie_set.intersection(required_curies)) > 0
        )
    ]

cluster_figure_ids_out = get_next_cluster_figures(
    category_matched_pfocr_curie_count,
    overlapping_category_matched_figures_df.reset_index().set_index("figure_id"),
    trapi_results_unified_curies,
    required_curies,
)
print(f'cluster figure count: {len(cluster_figure_ids_out)}')
cluster_figure_df = pd.DataFrame.from_records(cluster_figure_data)
cluster_figure_df

## View TRAPI Results by PFOCR Figure

In [None]:
figures_df = category_matched_figure_curie_df[
    category_matched_figure_curie_df["figure_id"].isin(set(cluster_figure_df["figure_id"]))
][["figure_id", "curie"]].groupby("figure_id").agg(
    create_set
).rename(columns={
    "curie": "figure_curie_set"
}).join(
    figure_metadata_df
).reset_index()

figures_df["figure_curie_count"] = figures_df["figure_curie_set"].map(len)

figures_df['crossjoin_key'] = 0
trapi_results_df['crossjoin_key'] = 0

trapi_results_with_figures_df = trapi_results_df.merge(
    figures_df, on='crossjoin_key', how='outer'
).rename(columns={
    "unified_curie_set": "trapi_result_curie_set"
})

trapi_results_with_figures_df["common_curie_set"] = trapi_results_with_figures_df.apply(
    lambda r: r["trapi_result_curie_set"] & r["figure_curie_set"], axis=1
)
trapi_results_with_figures_df
# number of CURIEs in both TRAPI result and figure
# trapi_curies_in_figure
# x
# Selected & Having the property
trapi_results_with_figures_df["yes_trapi_result_yes_figure"] = (
    trapi_results_with_figures_df["common_curie_set"].map(len)
)

# number of CURIEs in TRAPI result but not in figure
# q_node_id_count - trapi_curies_in_figure
# n - x
# Selected & Not Having the property
trapi_results_with_figures_df["yes_trapi_result_no_figure"] = (
    len(q_node_ids) - trapi_results_with_figures_df["yes_trapi_result_yes_figure"]
)

# number of CURIEs not in TRAPI result but in figure
# curies_in_figure - trapi_curies_in_figure
# N - x
# Not Selected & Having the property
trapi_results_with_figures_df["no_trapi_result_yes_figure"] = (
    trapi_results_with_figures_df["figure_curie_set"].map(len) - 
    trapi_results_with_figures_df["yes_trapi_result_yes_figure"]
)

# number of CURIEs not in TRAPI result and not in figure,
# ie., number of unique CURIEs only in other figures
# 28735 - curies_in_figure - q_node_id_count + trapi_curies_in_figure
# M - (n + N) + x
# Not Selected & Not Having the property
trapi_results_with_figures_df["no_trapi_result_no_figure"] = (
    category_matched_pfocr_curie_count -
    trapi_results_with_figures_df["figure_curie_set"].map(len) -
    len(q_node_ids) +
    trapi_results_with_figures_df["yes_trapi_result_yes_figure"]
)

# see https://stackoverflow.com/a/58661068/5354298
_, _, twosided = pvalue_npy(
    trapi_results_with_figures_df["yes_trapi_result_yes_figure"].to_numpy('uint'),
    trapi_results_with_figures_df["yes_trapi_result_no_figure"].to_numpy('uint'),
    trapi_results_with_figures_df["no_trapi_result_yes_figure"].to_numpy('uint'),
    trapi_results_with_figures_df["no_trapi_result_no_figure"].to_numpy('uint'),
)

trapi_results_with_figures_df["p_value"] = pd.Series(twosided, index=trapi_results_with_figures_df.index)
trapi_results_with_figures_df["p_value"].sort_values()

# Needed in order to use the TRAPI result CURIEs as a key. A Python Set is not hashable.
trapi_results_with_figures_df["trapi_result_curies_key"] = trapi_results_with_figures_df[
    "trapi_result_curie_set"
].map(
    lambda x: tuple(sorted(x))
)

In [None]:
trapi_results_with_figures_df

In [None]:
# This isn't really a significance threshold. It's more just intended to
# exclude results where the p-value is 1 or very close to it.
threshold = 0.9
significant_trapi_results_with_figures_df = trapi_results_with_figures_df[
    trapi_results_with_figures_df["p_value"] < 0.9
]

# all
cooccurrence_df = significant_trapi_results_with_figures_df.sort_values(
    "p_value"
)

### first and any other query node
##q_node_id_combination = (q_node_ids[0],)

## first and second query nodes
##q_node_id_combination = (q_node_ids[0], q_node_ids[1])

### first and last query nodes
##q_node_id_combination = (q_node_ids[0], q_node_ids[-1])

### second and last query nodes
##q_node_id_combination = (q_node_ids[1], q_node_ids[-1])

### every query node with 'ids' specified
##q_node_id_combination = query_nodes_with_ids

#print(f'{" & ".join(q_node_id_combination)}')
#s = pd.Series(
#    [True] * len(significant_trapi_results_with_figures_df),
#    index=significant_trapi_results_with_figures_dfindex
#)
#for q_node_id in q_node_id_combination:
#    s = s & significant_trapi_results_with_figures_df.apply(
#        lambda r: r[f'{q_node_id}_unified_curie'] in r["common_curie_set"], axis=1
#    )
#cooccurrence_df = significant_trapi_results_with_figures_df[s == True].sort_values(
#    "p_value"
#)

cooccurrence_df