# Measuring Semantic Similarity of MeSH Terms

As of May 2019, there are 29,350 MeSH terms. MeSH terms can be represented on a directed acyclic graph in which terms higher on the graph tend to be more ambiguous and have children of increasing specificity. MeSH is frequently referred to as a tree structure, and although it shares many similarities with trees, it is not a tree: most terms occur at multiple places on the graph (some dozens of times) and thus have multiple parent nodes. Because documents are annotated manually, and because the number of potential terms is so large, human bias is infused into the term selection for any given document. Different individuals would likely annotate a single document with the same general terms but might differ on the more specific ones, so it might be beneficial to have semantic similarity values for all pairs of terms. By measuring the semantic similarity of terms based on the entire PubMed corpus, I aim to capture some of this behavior and potentially incorporate it into my models.
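To make the multiple-parent point concrete, here is a minimal sketch using a hand-built mapping from term IDs to tree numbers. The IDs and tree numbers are made up for illustration and are not real MeSH entries. The sketch assumes that the parent of a tree number is found by dropping its last dot-separated component.

```python
# Toy mapping of term ID -> tree numbers (hypothetical values, not real MeSH data)
toy_term_trees = {
    "T1": ["A01"],
    "T2": ["B02"],
    "T3": ["A01.100", "B02.200"],  # appears at two positions on the graph
}

# Reverse lookup: tree number -> term ID
toy_tree_to_term = {
    tree: uid for uid, trees in toy_term_trees.items() for tree in trees
}


def parents(uid):
    """Return the IDs of a term's direct parents, one per tree position."""
    result = []
    for tree in toy_term_trees[uid]:
        if "." in tree:
            # Drop the last component to get the parent's tree number
            parent_tree = tree.rsplit(".", 1)[0]
            result.append(toy_tree_to_term[parent_tree])
    return sorted(set(result))


# T3 has two distinct parents, which is why the structure is a DAG, not a tree
print(parents("T3"))
```

A term with a single top-level tree number, such as `T1` here, has no parents at all under this representation.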

I use [Song, Li, Srimani, et al.'s](https://www.ncbi.nlm.nih.gov/pubmed/26356015) method, which was originally applied to measuring the semantic similarity of Gene Ontology terms. This method is based on the graph structure of MeSH and uses aggregate information content to measure the similarity of any two terms based on the terms' frequencies in the corpus, their shared ancestors' frequencies in the corpus, and their children's frequencies in the corpus.
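For reference, the quantities computed by the code in this notebook can be summarized as follows. The symbols are notation chosen here for illustration: $c(t)$ is the raw corpus count of term $t$, and $A(t)$ is the set of ancestors of $t$, which includes $t$ itself in this implementation. The formulas mirror the notebook's code rather than quoting the paper, so treat them as a sketch of this implementation.

```latex
\begin{align}
\mathrm{freq}(t) &= c(t) + \sum_{u \in \mathrm{children}(t)} \mathrm{freq}(u) \\
p(t) &= \frac{\mathrm{freq}(t)}{\sum_{u} \mathrm{freq}(u)} \\
IC(t) &= -\ln p(t) \\
K(t) &= \frac{1}{IC(t)} \\
SW(t) &= \frac{1}{1 + e^{-K(t)}} \\
SV(t) &= \sum_{a \in A(t)} SW(a) \\
\mathrm{sim}(t_1, t_2) &= \frac{2 \sum_{a \in A(t_1) \cap A(t_2)} SW(a)}{SV(t_1) + SV(t_2)}
\end{align}
```

Here $K$ is the "knowledge" of a term, $SW$ its semantic weight, and $SV$ its semantic value, matching the variable names `knowledge`, `sws`, and `svs` used in the code.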

I also implemented multiprocessing to compute semantic similarities in parallel and save time. That architecture required quite a bit of additional code, and I am not sure how well it would work in a notebook, so please see the [Python code](https://github.com/wigasper/FUSE/blob/master/compute_semantic_similarity.py) that this notebook is based on if you are interested. The multiprocessing implementation did not change the calculation logic.

## Imports and function definitions

In [None]:
import os
import re
import math
import json
import time
import logging
import traceback
from itertools import combinations

import numpy as np

# Gets a list of children for a term. Because we don't actually have a graph
# to traverse, it is done by searching according to the written representation
# of the term's position on the graph (its tree numbers)
def get_children(uid, term_trees):
    # Return empty list for terms (like 'D005260' - 'Female') that aren't
    # actually part of any trees
    if len(term_trees[uid][0]) == 0:
        return []
    
    children = []

    for tree in term_trees[uid]:
        parent_depth = len(tree.split("."))
        for key, vals in term_trees.items():
            for val in vals:
                child_depth = len(val.split("."))
                # Match on the tree-number prefix rather than any substring
                if val.startswith(tree + ".") and uid != key and child_depth == parent_depth + 1:
                    children.append(key)
    
    return list(dict.fromkeys(children))

# Recursively computes the frequency according to Song et al. by adding
# the term's count to the sum of the frequencies of all its children
def freq(uid, term_counts, term_freqs, term_trees):
    # Check to see if term has already been computed, avoid recomputation
    if term_freqs[uid] != -1:
        return term_freqs[uid]
    total = term_counts[uid]
    # get_children is expensive, so call it only once per term
    for child in get_children(uid, term_trees):
        total += freq(child, term_counts, term_freqs, term_trees)
    # Cache the result so that shared descendants are only computed once
    term_freqs[uid] = total
    return total

# Get all ancestors of a term
def get_ancestors(uid, term_trees, term_trees_rev):
    ancestors = [tree for tree in term_trees[uid]]
    # Remove empty strings if they exist
    ancestors = [ancestor for ancestor in ancestors if ancestor]
    idx = 0
    while idx < len(ancestors):
        ancestors.extend([".".join(tree.split(".")[:-1]) for tree in term_trees[term_trees_rev[ancestors[idx]]]])
        ancestors = [ancestor for ancestor in ancestors if ancestor]
        ancestors = list(dict.fromkeys(ancestors))
        idx += 1
    ancestors = [term_trees_rev[ancestor] for ancestor in ancestors]
    ancestors = list(dict.fromkeys(ancestors))
    return ancestors

# Compute semantic similarity for 2 terms
def semantic_similarity(uid1, uid2, sws, svs):
    uid1_ancs = get_ancestors(uid1, term_trees, term_trees_rev)
    uid2_ancs = get_ancestors(uid2, term_trees, term_trees_rev)
    intersection = [anc for anc in uid1_ancs if anc in uid2_ancs]
    num = sum([(2 * sws[term]) for term in intersection])
    denom = svs[uid1] + svs[uid2]
    
    # Use value comparisons: "is" tests identity and is unreliable for numbers
    return 0 if math.isnan(num) or denom == 0 else num / denom

## Setup

When I started this project, I was relatively new to Python. I more or less logged everything manually, but I have since moved on to using the logging module. I can't emphasize enough how valuable this has been. Several times, logging has revealed considerable mistakes that I otherwise would not have noticed. For example, it turns out that occasionally, but not often, BeautifulSoup will truncate very large files without giving any indication to the user, and I only noticed this because of the timestamps in the logs. For this reason, I switched from BeautifulSoup to regular expressions. BeautifulSoup was certainly more elegant and readable, but aside from the data loss issue (the major dealbreaker), it also turned out to be much slower than regular expressions and to use an order of magnitude more memory for each document.

In [None]:
# Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.FileHandler("./logs/compute_semantic_similarity.log")
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)

uids = []
names = []
trees = []

with open("./data/mesh_data.tab", "r") as handle:
    for line in handle:
        line = line.strip("\n").split("\t")
        uids.append(line[0])
        names.append(line[1])
        trees.append(line[4].split(","))

docs = os.listdir("./pubmed_bulk")

# Create term_trees dict and reverse for quick and easy lookup later
term_trees = {uids[idx]:trees[idx] for idx in range(len(uids))}
term_trees_rev = {tree:uids[idx] for idx in range(len(uids)) for tree in trees[idx]}

term_counts = {uid:0 for uid in uids}

# Compile regexes for counting MeSH terms
mesh_list_start = re.compile(r"\s*<MeshHeadingList>")
mesh_list_stop = re.compile(r"\s*</MeshHeadingList>")
mesh_term_id = re.compile(r'\s*<DescriptorName UI="(D\d+)".*>')

## Count MeSH terms

Next, I count the occurrences of each MeSH term in the entire PubMed corpus, approximately 29 million documents.
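As a small illustration of the counting approach, the sketch below applies the same three regular expressions to an in-memory XML snippet. The snippet, its descriptor IDs, and its term names are placeholders written for this example, and `count_terms` is a hypothetical helper rather than part of the original code. It assumes one XML element per line, which is what the patterns expect.

```python
import io
import re
from collections import Counter

# Same patterns as in the setup cell
mesh_list_start = re.compile(r"\s*<MeshHeadingList>")
mesh_list_stop = re.compile(r"\s*</MeshHeadingList>")
mesh_term_id = re.compile(r'\s*<DescriptorName UI="(D\d+)".*>')

# Placeholder snippet: two articles, one DescriptorName outside any heading list
sample = io.StringIO("""\
<PubmedArticle>
  <MeshHeadingList>
    <DescriptorName UI="D000001" MajorTopicYN="N">Term A</DescriptorName>
    <DescriptorName UI="D000002" MajorTopicYN="Y">Term B</DescriptorName>
  </MeshHeadingList>
</PubmedArticle>
<PubmedArticle>
  <DescriptorName UI="D000003" MajorTopicYN="N">Outside the list</DescriptorName>
  <MeshHeadingList>
    <DescriptorName UI="D000001" MajorTopicYN="N">Term A</DescriptorName>
  </MeshHeadingList>
</PubmedArticle>
""")


def count_terms(handle):
    """Count descriptor IDs that appear inside <MeshHeadingList> blocks."""
    counts = Counter()
    in_list = False
    for line in handle:
        if mesh_list_start.search(line):
            in_list = True
        elif mesh_list_stop.search(line):
            in_list = False
        elif in_list:
            match = mesh_term_id.search(line)
            if match:
                counts[match.group(1)] += 1
    return counts


counts = count_terms(sample)
print(counts)
```

Only descriptors between the start and stop tags are counted, so the `D000003` line that sits outside a heading list is ignored.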

In [None]:
for doc in docs:
    try:
        with open("./pubmed_bulk/{}".format(doc), "r") as handle:
            start_time = time.perf_counter()

            line = handle.readline()
            while line:
                if mesh_list_start.search(line):
                    # Also stop at end of file so that a truncated document
                    # cannot cause an infinite loop
                    while line and not mesh_list_stop.search(line):
                        match = mesh_term_id.search(line)
                        if match:
                            term_counts[match.group(1)] += 1
                        line = handle.readline()
                line = handle.readline()

            # Get elapsed time and truncate for log
            elapsed_time = int((time.perf_counter() - start_time) * 10) / 10.0
            logger.info(f"{doc} MeSH term counts completed in {elapsed_time} seconds")
    except Exception as e:
        trace = traceback.format_exc()
        logger.error(repr(e))
        logger.critical(trace)

## Compute Semantic Similarity

Semantic similarity is computed step by step, with a separate data structure for each step, to keep things readable. As previously stated, I use [Song, Li, Srimani, et al.'s](https://www.ncbi.nlm.nih.gov/pubmed/26356015) method here.
The process can be done much faster with multiprocessing; please see the [original code](https://github.com/wigasper/FUSE/blob/master/compute_semantic_similarity.py) that this notebook is based on if you are interested.
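To show how the steps chain together, here is a minimal sketch that runs the same sequence (probability, information content, knowledge, semantic weight, semantic value, similarity) on a four-term toy hierarchy. The term names, ancestor lists, and frequencies are invented for illustration, and `toy_similarity` is a hypothetical helper. As in the cell that follows, probabilities are normalized by the sum of all term frequencies.

```python
import math

# Hypothetical toy hierarchy: each term lists its ancestors, itself included
toy_ancestors = {
    "root": ["root"],
    "a": ["a", "root"],
    "b": ["b", "root"],
    "a1": ["a1", "a", "root"],
}

# Hypothetical recursive frequencies (own count plus descendants' counts)
toy_freqs = {"root": 100, "a": 60, "b": 30, "a1": 20}

# Probability -> information content -> knowledge -> semantic weight
total = sum(toy_freqs.values())
toy_ics = {t: -math.log(f / total) for t, f in toy_freqs.items()}
toy_knowledge = {t: 1 / ic for t, ic in toy_ics.items()}
toy_sws = {t: 1 / (1 + math.exp(-k)) for t, k in toy_knowledge.items()}

# Semantic value: sum of the semantic weights of a term's ancestors
toy_svs = {t: sum(toy_sws[a] for a in toy_ancestors[t]) for t in toy_freqs}


def toy_similarity(t1, t2):
    """Twice the shared ancestors' weight over the two terms' semantic values."""
    shared = set(toy_ancestors[t1]) & set(toy_ancestors[t2])
    denom = toy_svs[t1] + toy_svs[t2]
    return 0 if denom == 0 else 2 * sum(toy_sws[a] for a in shared) / denom


print(toy_similarity("a1", "a"))  # parent and child share most ancestors
print(toy_similarity("a1", "b"))  # these two only share the root
```

In this toy setup a term and its parent score higher than two terms that meet only at the root, which is the behavior the measure is meant to capture.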

In [None]:
# Get term frequencies (counts) recursively as described by
# Song et al
start_time = time.perf_counter()
term_freqs = {uid:-1 for uid in uids}
for term in term_freqs.keys():
    term_freqs[term] = freq(term, term_counts, term_freqs, term_trees)
# Get elapsed time and truncate for log
elapsed_time = int((time.perf_counter() - start_time) * 10) / 10.0
logger.info(f"Term freqs calculated in {elapsed_time} seconds")

root_freq = sum(term_freqs.values())
            
# Get term probs
term_probs = {uid:np.nan for uid in uids}
for term in term_probs:
    term_probs[term] = term_freqs[term] / root_freq

# Compute IC values
ics = {uid:np.nan for uid in uids}
for term in ics:
    try:
        ics[term] = -1 * math.log(term_probs[term])
    # math.log raises ValueError for a probability of 0; this should not
    # happen with the full corpus
    except ValueError:
        logger.error(f"ValueError (zero probability) for {term}")

# Compute knowledge for each term
knowledge = {uid:np.nan for uid in uids}
for term in knowledge:
    knowledge[term] = 1 / ics[term]
        
# Compute semantic weight for each term
sws = {uid:np.nan for uid in uids}
for term in sws:
    sws[term] = 1 / (1 + math.exp(-1 * knowledge[term]))
    
# Compute semantic value for each term by adding the semantic weights
# of all its ancestors
svs = {uid:np.nan for uid in uids}
for term in svs:
    sv = 0
    ancestors = get_ancestors(term, term_trees, term_trees_rev)
    for ancestor in ancestors:
        sv += sws[ancestor]
    svs[term] = sv

# Compute semantic similarity for each pair
start_time = time.perf_counter()
# Open the output file once rather than once per pair
with open("./data/semantic_similarities_rev1.csv", "a") as out:
    for pair in combinations(uids, 2):
        try:
            sim = semantic_similarity(pair[0], pair[1], sws, svs)
            out.write(f"{pair[0]},{pair[1]},{sim}\n")
        except Exception as e:
            trace = traceback.format_exc()
            logger.error(repr(e))
            logger.critical(trace)

# Get elapsed time and truncate for log
elapsed_time = int((time.perf_counter() - start_time) * 10) / 10.0
logger.info(f"Semantic similarities calculated in {elapsed_time} seconds")

In [None]:
# show some of the most similar terms here ....