# DyGIE++ co-occurrence entity characterization and grounding

In [1]:
import networkx as nx
from collections import defaultdict, Counter
import taxoniq
from tqdm.notebook import tqdm
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import rgb2hex
import json
import pandas as pd
import timeit
from functools import partial

## Read in data

In [None]:
# Load the co-occurrence graph that was exported as GraphML
graph = nx.read_graphml('../data/kg/all_drought_dt_co_occurrence_graph_02May2024.graphml')

## Basic characterization

In [None]:
# Bucket node names by the value of their 'ent_type' attribute
ents_by_type = defaultdict(list)
for node_name, node_attrs in graph.nodes.items():
    ents_by_type[node_attrs['ent_type']].append(node_name)

### Check database grounding potential for `Multi`- and `Unicellular_organism` types

In [None]:
# Try an exact scientific-name lookup in taxoniq for every multicellular
# organism mention; names that raise KeyError are left ungrounded
grounded_multicell = {}
for organism in ents_by_type['Multicellular_organism']:
    try:
        grounded_multicell[organism] = taxoniq.Taxon(scientific_name=organism).scientific_name
    except KeyError:
        pass
print(f'{len(grounded_multicell)} of {len(ents_by_type["Multicellular_organism"])} multicellular organisms could be grounded')

In [None]:
# Same exact-name lookup as above, this time for unicellular organism mentions;
# names that raise KeyError are left ungrounded
grounded_unicell = {}
for organism in ents_by_type['Unicellular_organism']:
    try:
        grounded_unicell[organism] = taxoniq.Taxon(scientific_name=organism).scientific_name
    except KeyError:
        pass
print(f'{len(grounded_unicell)} of {len(ents_by_type["Unicellular_organism"])} unicellular organisms could be grounded')

Grounding capacity directly using taxoniq is horrendous. What happens if we concatenate all of the species mentions together and use TaxoNERD, like we did for citation network classification? We're going to write a mini script and submit this as a job, because TaxoNERD grounding is exceedingly slow.

In [None]:
# with open('../data/kg/full_graph_multicellular_ents_02May2024.txt', 'w') as f:
#     f.write("\n".join(ents_by_type['Multicellular_organism']))

In [None]:
# Read the entity -> grounding mapping produced by the grounding job.
# The encoding is given explicitly so decoding does not depend on the
# platform's default locale.
with open('../data/kg/full_graph_multicellular_ents_GROUNDED_02May2024.json', encoding='utf-8') as f:
    grounded_ents = json.load(f)

In [None]:
# Compare the number of grounded entity strings with the number of distinct
# groundings they map to
print(f'There are {len(grounded_ents)} that received a grounding, and {len(set(grounded_ents.values()))} unique groundings.')

This is great news! It means that a bunch of entities got resolved, even though we didn't get groundings for many of the entities. Let's use this to tighten up our graph. Unfortunately, using the existing networkx function `contracted_nodes` is prohibitively slow on large graphs, so let's try something based on [this solution](https://stackoverflow.com/a/73762332/13340814) from Stack Overflow. In the Stack Overflow post, there are guaranteed to be edges between the nodes to be contracted, whereas in our case there are not, so we'll have to improvise a little.

In [None]:
# Invert the entity -> grounding mapping so that each grounding lists every
# entity string that resolved to it
groundings_to_ents = defaultdict(list)
for entity_name in grounded_ents:
    groundings_to_ents[grounded_ents[entity_name]].append(entity_name)

In [None]:
def _coalesce_group_edges(edges, n_list, formal_name):
    """
    Re-point every edge that touches a node in n_list at formal_name, drop the
    self-loops this creates, and merge edges that become duplicates.

    edges, pandas DataFrame: edgelist with 'source' and 'target' columns plus the
        columns first_year_mentioned, num_doc_mentions_all_time, is_drought,
        is_desiccation and uids_of_origin
    n_list, list of str: names of the nodes being contracted
    formal_name, str: name of the node that replaces them

    returns:
        pandas DataFrame: edgelist in which the touched edges have been replaced
            by one row per distinct neighbor of the contracted node
    """
    members = set(n_list)
    in_group = edges['source'].isin(members) | edges['target'].isin(members)
    if not in_group.any():
        return edges

    # Only rename the endpoint columns, so attribute values can never be rewritten
    touched = edges.loc[in_group]
    source = touched['source'].mask(touched['source'].isin(members), formal_name)
    target = touched['target'].mask(touched['target'].isin(members), formal_name)

    # Edges between two members of the group have collapsed into self-loops
    not_loop = source != target
    # The graph is undirected: put the endpoints in a fixed order so that
    # (a, b) and (b, a) fall into the same group below
    swap = source > target
    relabeled = touched.assign(
        source=source.where(~swap, target),
        target=target.where(~swap, source),
    ).loc[not_loop]

    untouched = edges.loc[~in_group]
    if relabeled.empty:
        return untouched.reset_index(drop=True)

    # One row per distinct edge: earliest year, summed mentions, OR-ed flags
    # and concatenated uids. Edges that were already unique pass through with
    # their values unchanged but with the new node name.
    merged = relabeled.groupby(['source', 'target'], sort=False, as_index=False).agg(
        first_year_mentioned=('first_year_mentioned', 'min'),
        num_doc_mentions_all_time=('num_doc_mentions_all_time', 'sum'),
        is_drought=('is_drought', 'any'),
        is_desiccation=('is_desiccation', 'any'),
        uids_of_origin=('uids_of_origin', lambda uids: ', '.join(uids)),
    )
    return pd.concat([untouched, merged], ignore_index=True)


def contract_groups(graph, groups):
    """
    Contract groups of nodes that may or may not be connected. Combines the number
    of per-doc mentions for edges and nodes, and keeps the oldest year as the
    first_year_mentioned. Every edge of a contracted node is moved to the new node;
    edges between members of the same group are dropped.

    graph, networkx Graph: undirected network containing the nodes in groups. Nodes
        must have the attributes first_year_mentioned, num_doc_mentions_all_time
        and uids_of_origin; edges must additionally have is_drought and
        is_desiccation
    groups, dict of list: keys are groundings (passed to taxoniq as int(key)),
        values are the names of the nodes to coalesce

    returns:
        new_graph, networkx Graph: graph rebuilt from the contracted edgelist. The
            new nodes only carry first_year_mentioned, num_doc_mentions_all_time,
            uids_of_origin and ent_type; untouched nodes keep all of their attributes
    """
    # Convert original graph to node and edgelist
    nodes = graph.nodes(data=True)
    edges = nx.to_pandas_edgelist(graph)

    # Make sure edge years are integers for later use
    edges = edges.astype({'first_year_mentioned': 'int32'})

    # Go through the groups
    nodes_to_add = []
    nodes_to_remove = set()
    for grounding, n_list in tqdm(groups.items()):

        # Nothing to contract for an empty group
        if not n_list:
            continue

        # Get the oldest year for node mentions (years may be stored as strings)
        oldest_node_year = min(int(nodes[n]['first_year_mentioned']) for n in n_list)

        # Get total node mentions
        total_node_mentions = sum(nodes[n]['num_doc_mentions_all_time'] for n in n_list)

        # Get all uids of origin
        combined_node_uids = ', '.join(nodes[n]['uids_of_origin'] for n in n_list)

        # Get the formal name that we want to keep for the node, falling back on
        # the first member of the group if taxoniq doesn't know the ID
        try:
            formal_name = taxoniq.Taxon(int(grounding)).scientific_name
        except KeyError:
            formal_name = n_list[0]

        # Move all of the group's edges onto the formal name and merge duplicates
        edges = _coalesce_group_edges(edges, n_list, formal_name)

        # And finally, save the formal name of the new node and its attrs to use
        # later, and add the nodes to remove. The type is stored under 'ent_type',
        # the attribute name used by the rest of the graph.
        nodes_to_add.append((formal_name,
                             {'first_year_mentioned': oldest_node_year,
                              'num_doc_mentions_all_time': total_node_mentions,
                              'uids_of_origin': combined_node_uids,
                              'ent_type': 'Multicellular_organism'})) # Since this was all we could ground
        nodes_to_remove.update(n_list)

    # Remove old nodes and add new ones
    nodes_processed = [(n, attrs) for n, attrs in nodes if n not in nodes_to_remove]
    nodes_processed.extend(nodes_to_add)

    # Make new graph from edgelist and nodelist
    new_graph = nx.from_pandas_edgelist(edges, edge_attr=['first_year_mentioned',
                    'num_doc_mentions_all_time', 'is_drought', 'is_desiccation',
                    'uids_of_origin'])
    new_graph.add_nodes_from(nodes_processed)

    return new_graph

Test on a small subset:

In [None]:
# Take the rows labelled 0 through 5 of the edgelist (label slicing is inclusive)
# and overwrite a few endpoints by hand to build a small fixture
sesame_test = nx.to_pandas_edgelist(graph).loc[:5, :]
sesame_test.at[1, 'target'] = 'sesame seed'
sesame_test.at[5, 'source'] = 'sesame plant'
sesame_test.at[5, 'target'] = 'peg-induced drought tolerance'

In [None]:
# Show the hand-edited fixture edgelist
sesame_test

In [None]:
# Build the fixture graph from the edited edgelist
sesame_graph = nx.from_pandas_edgelist(sesame_test, edge_attr=[
    'is_drought',
    'uids_of_origin',
    'first_year_mentioned',
    'num_doc_mentions_all_time',
    'is_desiccation'
])
# Collect the endpoints once, instead of rebuilding two lists for every node of
# the full graph
sesame_endpoints = set(sesame_test.source) | set(sesame_test.target)
# Copy over the real attributes of the nodes that exist in the full graph
actual_nodes = [(n, attr) for n, attr in graph.nodes(data=True) if n in sesame_endpoints]
# Add one isolated node by hand; its type is stored under 'ent_type', the
# attribute name the nodes of the full graph use
_ = sesame_graph.add_nodes_from(actual_nodes + [('sesame seeds', {'uids_of_origin': 'WOS:000623658100043',
                                                   'ent_type': 'Multicellular_organism',
                                                   'first_year_mentioned': '2017',
                                                   'num_doc_mentions_all_time': 3})])

In [None]:
# Show the edgelist of the fixture graph before contraction
nx.to_pandas_edgelist(sesame_graph)

In [None]:
# One group to contract: the key is the grounding passed to taxoniq by
# contract_groups, the value lists the node names to merge
sesame_groups = {'4182': ['sesame', 'sesame seed', 'sesame plant', 'sesame seeds']}

In [None]:
# Expected values to compare against the contracted output
print(
    'Correct number total mentions for combined entity:',
    sum(
        node_attrs['num_doc_mentions_all_time']
        for node, node_attrs in sesame_graph.nodes(data=True)
        if node in sesame_groups['4182']
    ),
)
print('Correct number of total mentions for the combined edge: 3')

In [None]:
# Run the contraction on the fixture graph
test_output = contract_groups(sesame_graph, sesame_groups)

In [None]:
# Show the edgelist of the contracted fixture graph
nx.to_pandas_edgelist(test_output)

In [None]:
# Show the first year and total mentions stored on the combined node
test_output.nodes(data=True)['Sesamum indicum']['first_year_mentioned'], test_output.nodes(data=True)['Sesamum indicum']['num_doc_mentions_all_time']

Test speed against using the contraction function from networkx:

In [None]:
def contraction(graph, groundings_to_ents):
    """
    Contract groups of nodes with networkx's contracted_nodes, then rename the
    surviving node of each group to its formal taxonomic name.

    graph, networkx Graph: network containing the nodes in groundings_to_ents;
        it is not modified
    groundings_to_ents, dict of list: keys are groundings (passed to taxoniq as
        int(key)), values are the names of the nodes to coalesce

    returns:
        graph, networkx Graph: contracted copy of the input graph
    """
    # Copy once up front and then edit that copy in place. With the default
    # copy=True, contracted_nodes duplicates the whole graph on every call.
    graph = graph.copy()
    for grd, nodes in tqdm(groundings_to_ents.items()):
        # Nothing to contract for an empty group
        if not nodes:
            continue
        first_node = nodes[0]
        # Fall back on the first member of the group if taxoniq doesn't know the ID
        try:
            formal_name = taxoniq.Taxon(int(grd)).scientific_name
        except KeyError:
            formal_name = first_node
        for n in nodes[1:]:
            nx.contracted_nodes(graph, first_node, n, copy=False)
        # relabel_nodes returns a new graph unless copy=False is passed, so the
        # rename has to be done in place (or its result kept) to have any effect
        if formal_name != first_node:
            nx.relabel_nodes(graph, {first_node: formal_name}, copy=False)
    return graph

In [None]:
# Total time in seconds for five runs of the networkx-based contraction
t1 = timeit.Timer(
    partial(contraction, sesame_graph, sesame_groups)
)
t1.timeit(number=5)

In [None]:
# Total time in seconds for five runs of the edgelist-based contraction
t2 = timeit.Timer(
    partial(contract_groups, sesame_graph, sesame_groups)
)
t2.timeit(number=5)

In [None]:
# contracted_graph = contract_groups(graph, groundings_to_ents)

In [None]:
# contracted_graph = contraction(graph, groundings_to_ents)

While the networkx builtin is faster on this small subset, it does not scale at all with size. When I tried to run the `contraction` function above on my graph, it hadn't even finished the first iteration after several minutes, while my own code only took ~7 seconds per iteration. 7 seconds per iteration is still not ideal, because it'll take 2 days to run on the entire graph; however, that is infinitely preferable to the alternative. Therefore, I'm going to make and submit a job to do the contraction, and will read in the results here. 

In [None]:
# NOTE(review): the path is an empty placeholder, so this cell cannot load
# anything yet. TODO: fill in the path of the GraphML file written by the
# contraction job.
contracted_graph = nx.read_graphml('')

### Appearance of study organisms over time
We'll get the top twenty species in the graph based on both the number of mentions of each node, as well as the number of different nodes with the same Taxonomy ID, and then look at when they appear over time. For the moment, we'll ignore any nodes that didn't get grounded.

In [None]:
## TODO account for groundings

In [None]:
# All-time document mention counts for every node, then restricted to the
# multicellular organism nodes
ent_mentions = nx.get_node_attributes(graph, 'num_doc_mentions_all_time')
multi_mentions = {}
for organism_name in ents_by_type['Multicellular_organism']:
    multi_mentions[organism_name] = ent_mentions[organism_name]

In [None]:
# Twenty most-mentioned organisms, highest count first (ties keep their
# original order)
top_twenty_organisms = dict(Counter(multi_mentions).most_common(20))

In [None]:
# For each top organism, collect {year: value} from its 'num_mentions_<year>'
# node attributes, leaving out 2023
organism_year_mentions = {}
for organism in top_twenty_organisms:
    per_year = {}
    for attr_name, count in graph.nodes[organism].items():
        if 'num_mentions_' not in attr_name or attr_name == 'num_doc_mentions_all_time':
            continue
        # The year is whatever follows the last underscore
        year_label = attr_name.rsplit('_', 1)[-1]
        if year_label == '2023':
            continue
        per_year[int(year_label)] = count
    organism_year_mentions[organism] = per_year

In [None]:
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap is available in both old and new releases
cmap = plt.get_cmap('tab20c')
# Assign each top organism the i-th color of the colormap as a hex string
organism_colors = {organism: rgb2hex(cmap(i)) for i, organism in enumerate(top_twenty_organisms)}

In [None]:
# One line per organism: mentions per year, in chronological order
for org in top_twenty_organisms:
    yearly = organism_year_mentions[org]
    years = sorted(yearly)
    counts = [yearly[yr] for yr in years]
    plt.plot(
        years,
        counts,
        color=organism_colors[org],
        marker='o',
        label=f'{org} ({top_twenty_organisms[org]} total mentions)',
    )
plt.xlabel('Year')
plt.ylabel('Mentions')
# Place the legend to the right of the axes
plt.legend(loc=(1.1,0.0))

### Genes and proteins directly connected to an organism mention
One way we can identify genes/proteins that belong to various species in our graph is to check whether they are directly connected to an organism mention. The original intention of the is-in relation was to perform this kind of linking. Since we couldn't use typed relations and instead had to rely on co-occurrence, we'll treat any link as a possible is-in link.

In [None]:
# For every multicellular organism node, list its direct neighbors that are
# genes (DNA or RNA) and those that are proteins
ent_types = nx.get_node_attributes(graph, 'ent_type')
genes_by_organism = {}
proteins_by_organism = {}
for organism in ents_by_type['Multicellular_organism']:
    adjacent = list(graph.neighbors(organism))
    genes_by_organism[organism] = [
        nbr for nbr in adjacent if ent_types[nbr] in ('DNA', 'RNA')
    ]
    proteins_by_organism[organism] = [
        nbr for nbr in adjacent if ent_types[nbr] == 'Protein'
    ]

If we use the oversimplification that any true Arabidopsis genes would start with `at`, then we can check what percentage of the connections are "correct":

In [None]:
# Genes linked to the 'arabidopsis' node whose name begins with 'at'
correct_arabidopsis = [
    gene for gene in genes_by_organism['arabidopsis'] if gene.startswith('at')
]
print(f'Assuming correct genes will start with At, {(len(correct_arabidopsis)/len(genes_by_organism["arabidopsis"]))*100:.2f} percent of genes directly connected to Arabidopsis are correct.')

In [None]:
# Share of the edges of the 'arabidopsis' node that lead to genes and to
# proteins; both percentages use the node's degree as the denominator
print(f'{len(genes_by_organism["arabidopsis"])} of {graph.degree("arabidopsis")} edges ({(len(genes_by_organism["arabidopsis"])/graph.degree("arabidopsis"))*100:.2f}%) from the node "arabidopsis" are to genes.')
print(f'{len(proteins_by_organism["arabidopsis"])} of {graph.degree("arabidopsis")} edges ({(len(proteins_by_organism["arabidopsis"])/graph.degree("arabidopsis"))*100:.2f}%) from the node "arabidopsis" are to proteins.')