## Find Categories

How categories are found depends on the structure of the knowledge graph (KG) used. In the original paper, the KG was based on Wikipedia, which has a parallel hierarchy of categories.

In the KG used in this project, a single "magic" concept represents disease (CID: `2795416`), and its children are the high-level disease categories. The objective is therefore to navigate from each node detected via annotation or PPR up to one of the disease category concepts, where possible.

In this notebook, we will compute the categories from the original annotations and the top PPR predictions for each article.

In [1]:
import py2neo
import os

In [2]:
DATA_DIR = "../data"
KG_VERTICES_FILEPATH = os.path.join(DATA_DIR, "emmet-vertices.tsv")
ST_CONCEPTS_FILEPATH = os.path.join(DATA_DIR, "story-concepts.tsv")
PP_CONCEPTS_FILEPATH = os.path.join(DATA_DIR, "story-ppr-concepts.tsv")

NEO4J_CONN_URL = "bolt://localhost:7687"

# magic node id, below this are all the disease classes (categories)
DISEASE_CID = "2795416"

In [3]:
graph = py2neo.Graph(NEO4J_CONN_URL, auth=("neo4j", "graph"))

### Build CID to Name Mapping

The `build_cid_name_mapping` function produces a dictionary that maps the CID to the primary name for the concept. This is needed for display purposes.

In [4]:
def build_cid_name_mapping(vertices_filepath):
    cid2name = {}
    # column 0 is the CID, column 1 holds pipe-separated names (first is primary)
    with open(vertices_filepath, "r") as fvert:
        for line in fvert:
            cols = line.strip().split('\t')
            cid = cols[0]
            cname = cols[1].split('|')[0]
            cid2name[cid] = cname
    return cid2name

cid2name = build_cid_name_mapping(KG_VERTICES_FILEPATH)
print(DISEASE_CID, cid2name[DISEASE_CID])

2795416 disease
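
To make the expected file layout concrete, here is a minimal sketch of the same parsing logic applied to an in-memory sample. The helper name `parse_vertex_lines` and the sample rows are hypothetical (the synonym after `|` is invented for illustration); only the layout, with the CID in the first column and pipe-separated names in the second, is taken from the function above.

```python
import io

def parse_vertex_lines(lines):
    """Map each CID to the first of its pipe-separated names."""
    cid2name = {}
    for line in lines:
        cols = line.rstrip("\n").split("\t")
        if len(cols) < 2:
            continue  # skip blank or malformed lines
        cid2name[cols[0]] = cols[1].split("|")[0]
    return cid2name

# hypothetical sample rows; the synonym after '|' is made up
sample_vertices = io.StringIO(
    "2795416\tdisease|illness\n"
    "8120549\tdisorder characterized by pain\n"
    "\n"
)
sample_cid2name = parse_vertex_lines(sample_vertices)
print(sample_cid2name["2795416"])
```

Unlike the notebook function, this sketch skips lines with fewer than two columns instead of raising an `IndexError`, which may or may not be the behavior you want for a real vertices file.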


### Find Disease Category CIDs

The children of the `DISEASE_CID` concept are the disease categories we want to roll up to. This set is later used to pick the category concepts out of each node's ancestors.

In [5]:
def find_disease_category_nodes(graph, parent_cid):
    # pass parent_cid as a Cypher parameter instead of hardcoding the CID
    query = """
        MATCH (src:Concept)-[:REL]->(dst:Concept {cid: $parent_cid})
        RETURN src.cid AS src_cid
    """
    results = graph.run(query, parent_cid=parent_cid).data()
    return {result["src_cid"] for result in results}

disease_category_cids = find_disease_category_nodes(graph, DISEASE_CID)
print(list(disease_category_cids)[0:5])

['9092344', '8816100', '8183827', '8120549', '9792275']


### Extract CIDs from Annotation and PPR Prediction files

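A generic function to extract the CIDs for a given story from either file. Both files share the same tab-separated format: story ID, CID, and weight. At most the first 20 CIDs, in file order, are returned.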

In [6]:
def extract_source_cids(filename, story_id):
    cids = []
    with open(filename, "r") as fppv:
        for line in fppv:
            sid, cid, weight = line.strip().split('\t')
            if sid != story_id:
                continue
            cids.append(cid)
    # keep at most the first 20 CIDs for the story
    return cids[0:20]
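
Because `extract_source_cids` rereads the whole file for every story, a single-pass variant may be worth considering when there are many stories. The sketch below is one possible approach and not part of the original notebook; `group_cids_by_story`, its `max_cids` parameter, and the sample rows are hypothetical, and it assumes the same three-column tab-separated layout.

```python
import io
from collections import defaultdict

def group_cids_by_story(lines, max_cids=20):
    """Collect up to max_cids CIDs per story ID in one pass, in file order."""
    story2cids = defaultdict(list)
    for line in lines:
        line = line.strip()
        if not line:
            continue  # ignore blank lines
        sid, cid, _weight = line.split("\t")
        if len(story2cids[sid]) < max_cids:
            story2cids[sid].append(cid)
    return dict(story2cids)

# hypothetical sample rows: story ID, CID, weight
sample_concepts = io.StringIO(
    "s1\t111\t0.9\n"
    "s2\t222\t0.8\n"
    "s1\t333\t0.7\n"
)
story2cids = group_cids_by_story(sample_concepts)
print(story2cids)
```

With a mapping like this, the per-story lookup becomes `story2cids.get(story_id, [])` rather than a fresh scan of the file.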

### Compute path from each CID to category CIDs

The pair of functions below climbs the hierarchy recursively via the parent link `isChildOf` (the queries match it as the `REL` relationship). At each stage, a node can have 0, 1, or more parents. Recursion stops when a node has no parents or when the parent reached is a category CID.

The caller passes `find_categories` the list of CIDs corresponding to the annotations or PPR predictions. For each CID, the function initializes an ancestor list and calls the recursive helper `_find_parent_nodes`, then keeps the ancestors that are category CIDs.

In [7]:
def _find_parent_nodes(graph, child_cid, disease_category_cids, 
                       ancestor_list):
    # pass the CID as a query parameter rather than formatting it into the string
    query = """
        MATCH (src:Concept {cid: $child_cid})-[:REL]->(dst:Concept)
        RETURN dst.cid AS dst_cid
    """
    results = graph.run(query, child_cid=child_cid).data()
    # a node with no parents yields no results, which ends the recursion
    for result in results:
        parent_cid = result["dst_cid"]
        ancestor_list.append(parent_cid)
        if parent_cid in disease_category_cids:
            # we have reached the category nodes, stop climbing
            # (any remaining parents of this node are not visited)
            return
        _find_parent_nodes(graph, parent_cid, disease_category_cids, 
                           ancestor_list)


def find_categories(cids):
    doc_categories = set()
    for cid in cids:
        ancestor_list = [cid]
        _find_parent_nodes(graph, cid, disease_category_cids, ancestor_list)
        category_cids = set(ancestor_list).intersection(disease_category_cids)
        doc_categories.update(category_cids)
    return doc_categories
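
The recursive helper above has no record of visited nodes, so shared ancestors are queried again, and it returns as soon as one parent turns out to be a category. A possible alternative, sketched here over an in-memory parent map rather than Neo4j, is an iterative climb with a visited set that collects every reachable category. The `toy_parents` dictionary, the toy node names, and `find_categories_iterative` are hypothetical and only illustrate the traversal.

```python
from collections import deque

def find_categories_iterative(start_cids, parents, category_cids):
    """Climb parent links from each start CID; return the categories reached."""
    found = set()
    visited = set()
    queue = deque(start_cids)
    while queue:
        cid = queue.popleft()
        if cid in visited:
            continue  # already expanded via another path
        visited.add(cid)
        if cid in category_cids:
            found.add(cid)
            continue  # do not climb past a category node
        queue.extend(parents.get(cid, []))
    return found

# toy hierarchy: child -> list of parents (hypothetical names)
toy_parents = {
    "migraine": ["headache", "brain disorder"],
    "headache": ["pain disorder"],
    "brain disorder": ["body site disorder"],
    "orphan": [],
}
toy_categories = {"pain disorder", "body site disorder"}
toy_result = find_categories_iterative(["migraine", "orphan"], toy_parents, toy_categories)
print(sorted(toy_result))
```

To use this against the graph, the `parents.get(cid, [])` lookup would be replaced by the parent query from `_find_parent_nodes`. Note that this variant can return more categories than the notebook's version, since it does not stop at the first category parent.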

### Finding categories for all stories

In [8]:
for filename in os.listdir(DATA_DIR):
    if not filename.endswith(".story"):
        continue
    story_id = filename.split('.')[0]
    # extract cids
    orig_cids = extract_source_cids(ST_CONCEPTS_FILEPATH, story_id)
    ppr_cids = extract_source_cids(PP_CONCEPTS_FILEPATH, story_id)
    # find categories by tree climbing
    orig_categories = find_categories(orig_cids)
    ppr_categories = find_categories(ppr_cids)
    print("story_id: {:s}".format(story_id))
    print("original categories:", [(cid, cid2name[cid]) for cid in orig_categories])
    print("PPR categories:", [(cid, cid2name[cid]) for cid in ppr_categories])
    print("---")

story_id: 190823140729
original categories: [('8116566', 'disorder by body site'), ('8183829', 'neoplasm and/or hamartoma'), ('2795902', 'mental disorder'), ('9773914', 'physical disorder')]
PPR categories: [('8108157', 'inflammatory disorder'), ('9773914', 'physical disorder'), ('8120549', 'disorder characterized by pain'), ('8116566', 'disorder by body site')]
---
story_id: 190904194433
original categories: [('9105210', 'female genital and obstetric disorder'), ('8183829', 'neoplasm and/or hamartoma')]
PPR categories: [('8108157', 'inflammatory disorder'), ('9773914', 'physical disorder'), ('8183829', 'neoplasm and/or hamartoma'), ('8120549', 'disorder characterized by pain'), ('8116566', 'disorder by body site')]
---
story_id: 190909193211
original categories: [('8116566', 'disorder by body site'), ('2797067', 'nutrition and metabolism disorders'), ('8183829', 'neoplasm and/or hamartoma')]
PPR categories: [('8108157', 'inflammatory disorder'), ('9773914', 'physical disorder'), ('812