In [1]:
import csv
import itertools
import re
import spacy
import pandas as pd
import networkx as nx
from collections import Counter
# Small English spaCy pipeline shared by every cell below; a custom
# 'prevent-sbd' component is added to it in the next cell.
nlp = spacy.load('en_core_web_sm')

In [2]:
# spaCy has default rules for splitting text into sentences. Our text is
# already split (one sentence per row), so that feature is disabled here.
def prevent_sentence_boundary_detection(doc):
    """Pipeline component that marks no token as a sentence start.

    Sets `is_sent_start = False` on every token of `doc`, which entirely
    disables spaCy's own sentence detection, and returns the same object so
    it can be chained in the pipeline.
    """
    for position in range(len(doc)):
        doc[position].is_sent_start = False
    return doc

# Insert the component before the parser so the sentence-start flags are set
# before parsing. NOTE(review): passing a bare function to add_pipe is the
# spaCy v2 API; spaCy v3 expects a registered component name — confirm the
# spaCy version this notebook is pinned to.
nlp.add_pipe(prevent_sentence_boundary_detection, name='prevent-sbd', before='parser')

In [26]:
def join_bad_splits(parsed):
    """Re-merge tokens that the default tokenizer split apart around '-'.

    Adding tokenizer rules did not work, so the parsed Doc is repaired instead:
      * a lone uppercase letter (A-Z) is merged onto the preceding token
        (presumably a fragment split off a hyphenated name — TODO confirm);
      * a bare '-' token is merged with its neighbours, e.g. 'CTLA' '-' '4'
        -> 'CTLA-4'. A trailing '-' is merged with its predecessor only.
    A token at index 0 is never merged, because it has no preceding token.

    Args:
        parsed: a spaCy Doc; it is modified in place via `Doc.retokenize()`.

    Returns:
        The same Doc object.
    """
    i = 0
    while i < len(parsed):
        text = parsed[i].text
        if i > 0 and re.fullmatch(r'[A-Z]', text) is not None:
            start, end = i - 1, i + 1
        elif i > 0 and text == '-':
            # Clip the span so a '-' at the very end merges only with its predecessor.
            start, end = i - 1, min(i + 2, len(parsed))
        else:
            i += 1
            continue
        with parsed.retokenize() as retokenizer:
            retokenizer.merge(parsed[start:end])
        # Merging removes tokens and shifts later indices, so resume the scan at
        # the merged token. Earlier tokens were already checked and are unchanged.
        i = start
    return parsed

In [27]:
def get_shortest_path(graph, pairs):
    """Shortest dependency-tree path length for each (source, target) pair.

    Args:
        graph: networkx graph built from the dependency arcs.
        pairs: iterable of (source_node, target_node) 2-tuples.

    Returns:
        List of path lengths, in the order of the pairs that could be resolved.
        Pairs that have no path, or whose node is missing from the graph, are
        skipped. Returns [-1] if no pair could be resolved.
    """
    path_lens = []
    for source, target in pairs:
        try:
            path_lens.append(nx.shortest_path_length(graph, source, target))
        except nx.NetworkXException:
            # Covers "no path" and "node not in graph" (e.g. a one-token sentence
            # has no arcs). Any other error is a real bug and should surface
            # instead of being swallowed by a bare except.
            continue
    return path_lens or [-1]

In [28]:
def find_ner_term(ner, token):
    """Return True if the entity name `ner` matches the token text `token`.

    Punctuation here means any of the characters . , + * / -
      * If `ner` itself contains punctuation, it matches when it is a substring
        of `token` (e.g. 'CTLA-4' within 'CTLA-4/CD28').
      * Otherwise `ner` must equal `token` or one of the punctuation-separated
        pieces of `token` (e.g. 'CHRONIC_HEPATITIS_B' vs 'CHRONIC_HEPATITIS_B.').
    An empty `ner` never matches.
    """
    if not ner:
        return False
    punct = r'[\.\,\+\*/-]'
    # Decide on punctuation in the entity name, not in the token. Otherwise a
    # plain name such as 'IL' would substring-match inside 'IL10-X'.
    if re.search(punct, ner) is not None:
        return ner in token
    return ner == token or ner in re.split(punct, token)

In [29]:
def tree_distance(gene, disease, parsed):
    """Dependency-tree and surface distance between a gene and a disease mention.

    Builds an undirected graph from the head->child arcs of `parsed`, finds the
    tokens matching `gene` / `disease` (via find_ner_term) and pairs every gene
    mention with every disease mention.

    Args:
        gene: normalized gene name (see process_ner).
        disease: normalized disease name.
        parsed: parsed spaCy Doc for the sentence.

    Returns:
        [tree_dist, word_dist] where
          tree_dist -- graph path length for the first pair that get_shortest_path
                       could resolve (-1 if none could),
          word_dist -- absolute token-index distance of the first pair.
        If the sentence has no gene/disease pair, the inputs are printed for
        inspection and [-1, -1] is returned, so the CSV row stays well-formed
        (previously None was returned, which csv.writer.writerow rejects).
    """
    # Token indices are unique within a Doc, so they serve directly as graph nodes.
    edges = [(token.i, child.i) for token in parsed for child in token.children]
    gene_idx = [token.i for token in parsed if find_ner_term(gene, token.text)]
    disease_idx = [token.i for token in parsed if find_ner_term(disease, token.text)]
    pairs = [(g, d) for g in gene_idx for d in disease_idx]
    if not pairs:
        print(gene, disease, [t.text for t in parsed])
        return [-1, -1]
    # get_shortest_path never returns an empty list ([-1] when nothing resolves).
    tree_dists = get_shortest_path(nx.Graph(edges), pairs)
    # Currently only 1 pair per sentence is expected, given the tags.
    return [tree_dists[0], abs(pairs[0][0] - pairs[0][1])]

In [30]:
def common_ancestor(gene, disease, doc):
    """Deepest shared dependency ancestor for each gene/disease mention pair.

    For every token matching `gene` and every token matching `disease` (via
    find_ner_term), the two ancestor chains are compared root-first until they
    diverge.

    Returns:
        A set of (ancestor_text, depth, gene_steps, disease_steps) tuples:
          ancestor_text -- text of the last shared ancestor ('' if none),
          depth         -- its 0-based position from the root (-1 if none),
          gene_steps    -- len(gene ancestor chain) - depth,
          disease_steps -- len(disease ancestor chain) - depth.
    """
    def root_first_ancestors(tok):
        # token.ancestors walks upward from the head; reverse it so the root comes first.
        return [(a.text, a.i) for a in tok.ancestors][::-1]

    gene_paths = [root_first_ancestors(t) for t in doc if find_ner_term(gene, t.text)]
    disease_paths = [root_first_ancestors(t) for t in doc if find_ner_term(disease, t.text)]

    found = set()
    for g_path in gene_paths:
        for d_path in disease_paths:
            name = ''
            shared = 0
            for g_anc, d_anc in zip(g_path, d_path):
                if g_anc != d_anc:
                    break  # the two trees diverge below this point
                name = d_anc[0]
                shared += 1
            depth = shared - 1
            found.add((name, depth, len(g_path) - depth, len(d_path) - depth))
    return found

In [31]:
# Corpus-wide POS-tag frequencies. Their key order (first-seen) defines the
# pos_dist feature columns and the POS CSV header.
# NOTE(review): reads `docs`, which is built by the parsing cell further down;
# on a fresh kernel run that cell first.
pos_counts = Counter(token.pos_ for doc in docs for token in doc)
          
def pos_dist(doc, tags=None):
    """Normalized POS-tag distribution of `doc` (values sum to 1 when every token is counted).

    Args:
        doc: iterable of tokens exposing `.pos_` and supporting len() (e.g. a spaCy Doc).
        tags: ordered column labels. Defaults to the keys of the global
            `pos_counts`, i.e. the header of annotated_cap_pos_dist.csv.

    Returns:
        One float per tag: that tag's token count divided by len(doc).
        A tag that is not a column is counted under 'X' (the "other" POS tag)
        when 'X' is a column, and dropped otherwise instead of raising KeyError.
        An empty doc yields all zeros instead of dividing by zero.
    """
    if tags is None:
        tags = list(pos_counts.keys())
    counts = dict.fromkeys(tags, 0)
    for token in doc:
        tag = token.pos_
        if tag not in counts:
            if 'X' not in counts:
                continue
            tag = 'X'
        counts[tag] += 1
    n_tokens = len(doc)
    if n_tokens == 0:
        return [0.0] * len(counts)
    # Normalize counts by the number of tokens.
    return [c / n_tokens for c in counts.values()]

In [32]:
# Frequency of each noun-chunk root lemma across the corpus; its 100 most common
# lemmas are the chunk_root_normalized feature columns.
# NOTE(review): reads `docs`, which is built by the parsing cell further down.
chunk_roots = Counter(chunk.root.lemma_ for doc in docs for chunk in doc.noun_chunks)
          
def chunk_root_normalized(doc, vocab=None):
    """Normalized counts of noun-chunk root lemmas in `doc`.

    Args:
        doc: parsed spaCy Doc (anything with a `noun_chunks` iterable whose
            items expose `.root.lemma_`).
        vocab: ordered lemma columns. Defaults to the 100 most common lemmas in
            the global `chunk_roots` counter.

    Returns:
        One float per column: the number of chunks in `doc` whose root has that
        lemma, divided by the total number of noun chunks in `doc`. A doc with
        no noun chunks yields all zeros instead of dividing by zero.
    """
    if vocab is None:
        vocab = [lemma for lemma, _ in chunk_roots.most_common(100)]
    counts = dict.fromkeys(vocab, 0)
    n_chunks = 0
    for chunk in doc.noun_chunks:
        n_chunks += 1
        lemma = chunk.root.lemma_
        if lemma in counts:
            counts[lemma] += 1
    if n_chunks == 0:
        return [0.0] * len(counts)
    # Normalize counts by the number of chunks.
    return [c / n_chunks for c in counts.values()]

In [33]:
# Frequency of the head lemma that each noun chunk's root attaches to; its 100
# most common lemmas are the chunk_head_normalized feature columns.
# NOTE(review): reads `docs`, which is built by the parsing cell further down.
chunk_heads = Counter(chunk.root.head.lemma_ for doc in docs for chunk in doc.noun_chunks)
          
def chunk_head_normalized(doc, vocab=None):
    """Normalized counts of noun-chunk head lemmas (`chunk.root.head.lemma_`) in `doc`.

    Args:
        doc: parsed spaCy Doc (anything with a `noun_chunks` iterable whose
            items expose `.root.head.lemma_`).
        vocab: ordered lemma columns. Defaults to the 100 most common lemmas in
            the global `chunk_heads` counter.

    Returns:
        One float per column: the number of chunks in `doc` whose root's head
        has that lemma, divided by the total number of noun chunks in `doc`.
        A doc with no noun chunks yields all zeros instead of dividing by zero.
    """
    if vocab is None:
        vocab = [lemma for lemma, _ in chunk_heads.most_common(100)]
    counts = dict.fromkeys(vocab, 0)
    n_chunks = 0
    for chunk in doc.noun_chunks:
        n_chunks += 1
        # Count the head lemma, matching how `chunk_heads` was built; the
        # previous copy of chunk_root_normalized counted the root lemma here.
        lemma = chunk.root.head.lemma_
        if lemma in counts:
            counts[lemma] += 1
    if n_chunks == 0:
        return [0.0] * len(counts)
    # Normalize counts by the number of chunks.
    return [c / n_chunks for c in counts.values()]

In [34]:
def process_ner(x):
    """Normalize an entity name: upper-case it and turn every space into '_'.

    e.g. 'chronic hepatitis b' -> 'CHRONIC_HEPATITIS_B'
    """
    upper = x.upper()
    return '_'.join(upper.split(' '))

# Columns 6, 9 and 11 of the annotated GAD CSV are loaded as gene, disease and
# sentence. The file's own header row is skipped and replaced by `names`. Both
# entity columns are normalized with process_ner so they match the tokens.
data = (
    pd.read_csv('../dataset/GAD_Y_N_wPubmedID_annotated_cap.csv',
                usecols=[6, 9, 11], skiprows=[0], header=None,
                names=['gene', 'disease', 'sentence'])
      .assign(gene=lambda frame: frame['gene'].apply(process_ner),
              disease=lambda frame: frame['disease'].apply(process_ner))
)

In [35]:
# Parse every sentence once and repair the hyphen splits; all feature cells reuse `docs`.
# NOTE(review): the pos_counts / chunk_roots / chunk_heads cells above also read
# `docs`, so on a fresh kernel this cell must run before them.
docs = [join_bad_splits(nlp(sentence)) for sentence in data.sentence]

In [36]:
# One row per sentence: [dependency-tree distance, token distance] from tree_distance.
with open('annotated_cap_dist_features.csv', 'w') as out_file:
    dist_writer = csv.writer(out_file, lineterminator='\n')
    dist_writer.writerows(
        tree_distance(gene, disease, doc)
        for gene, disease, doc in zip(data.gene, data.disease, docs)
    )

In [37]:
# One row per sentence: the set of common-ancestor tuples from common_ancestor.
with open('annotated_cap_common_word.csv', 'w') as out_file:
    ancestor_writer = csv.writer(out_file, lineterminator='\n')
    ancestor_writer.writerows(
        common_ancestor(gene, disease, doc)
        for doc, gene, disease in zip(docs, data.gene, data.disease)
    )

In [38]:
# Header row of POS tags (pos_counts key order), then one normalized distribution per doc.
with open('annotated_cap_pos_dist.csv', 'w') as out_file:
    pos_writer = csv.writer(out_file, lineterminator='\n')
    pos_writer.writerow(list(pos_counts))
    pos_writer.writerows(pos_dist(doc) for doc in docs)

In [39]:
# The header must be the same top-100 root lemmas, in the same order, that
# chunk_root_normalized uses as its columns. Writing every chunk_roots key
# produced a header longer than, and misaligned with, each data row.
chunk_root_columns = [lemma for lemma, _ in chunk_roots.most_common(100)]
with open('annotated_cap_chunk_roots.csv', 'w') as f:
    writer = csv.writer(f, lineterminator='\n')
    writer.writerow(chunk_root_columns)
    for doc in docs:
        writer.writerow(chunk_root_normalized(doc))

In [40]:
# The header must be the same top-100 head lemmas, in the same order, that
# chunk_head_normalized uses as its columns. Writing every chunk_heads key
# produced a header longer than, and misaligned with, each data row.
chunk_head_columns = [lemma for lemma, _ in chunk_heads.most_common(100)]
with open('annotated_cap_chunk_heads.csv', 'w') as f:
    writer = csv.writer(f, lineterminator='\n')
    writer.writerow(chunk_head_columns)
    for doc in docs:
        writer.writerow(chunk_head_normalized(doc))

In [25]:
# Sanity check: report every sentence where the gene or the disease is not
# matched by exactly one token. `test` keeps the last doc with a disease
# mismatch, for inspection in the next cell.
for doc, gene, disease, sentence in zip(docs, data.gene, data.disease, data.sentence):
    gene_hits = sum(find_ner_term(gene, tok.text) for tok in doc)
    if gene_hits != 1:
        print(gene, sentence)
    disease_hits = sum(find_ner_term(disease, tok.text) for tok in doc)
    if disease_hits != 1:
        print(disease, sentence)
        test = doc

In [23]:
# Show the tokens of the last doc flagged by the sanity check above.
list(test)

[our,
 results,
 suggest,
 that,
 CTLA-4,
 gene,
 polymorphisms,
 may,
 partially,
 be,
 involved,
 in,
 the,
 susceptibility,
 to,
 CHRONIC_HEPATITIS_B.]