# Implant Complications


In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
import numba
import metal
import random
from brat import *
import numpy as np
import matplotlib
import pandas as pd

os.environ['SNORKELDB'] = "postgresql://inkfish@127.0.0.1:4554/inkfish"
from snorkel import SnorkelSession

#from snorkel.learning.disc_models.rnn import *
from snorkel.annotations import LabelAnnotator
from snorkel.annotations import load_gold_labels
from snorkel.models import candidate_subclass, Document, Sentence, Candidate, Span
from snorkel.learning import GenerativeModel

In [None]:
session = SnorkelSession()

# Define a candidate space. Re-running this cell in the same kernel can make
# candidate_subclass fail; in that case ImplantComplication keeps whatever value
# it already had in the kernel.
try:
    ImplantComplication = candidate_subclass('ImplantComplication', ['implant', 'complication'])
except Exception:
    # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
    print("candidate subclass already exists, skipping...")

### Load Candidates and Gold Labels

In [None]:
def get_doc_ids_by_name(session, doc_names):
    """Look up the ids of the documents whose ``name`` is in *doc_names*.

    Returns the raw query rows (one-element tuples), not bare ids.
    """
    id_query = session.query(Document.id)
    matching = id_query.filter(Document.name.in_(doc_names))
    return matching.all()

def get_cands_by_doc(session, doc_ids, candidate_class):
    """
    Fetch every candidate whose ``implant`` span lies in one of *doc_ids*.

    Chains subqueries: document ids -> sentence ids -> span ids -> candidate ids.
    The join is hard-coded on ``candidate_class.implant_id``, so the candidate
    class must have an ``implant`` argument.

    :param session: Snorkel/SQLAlchemy session
    :param doc_ids: iterable of ``Document.id`` values
    :param candidate_class: candidate subclass with an ``implant_id`` column
    :return: list of ``Candidate`` rows
    """
    sentence_ids = (
        session.query(Sentence.id)
        .filter(Sentence.document_id.in_(doc_ids))
        .subquery()
    )
    span_ids = (
        session.query(Span.id)
        .filter(Span.sentence_id.in_(sentence_ids))
        .subquery()
    )
    cand_ids = (
        session.query(candidate_class.id)
        .filter(candidate_class.implant_id.in_(span_ids))
        .subquery()
    )
    return session.query(Candidate).filter(Candidate.id.in_(cand_ids)).all()

def init_implant_complications_splits(session, num_training_docs=600):
    """
    Create Train/Dev/Test splits and store them on the candidates.

    Training candidates come from a random sample of documents that excludes the
    gold documents listed below; the gold documents are divided in half into
    dev (split=2) and test (split=3). Training candidates get split=1. Changes
    are committed through *session*.

    :param session: Snorkel session used for querying and committing
    :param num_training_docs: maximum number of non-gold documents used for training
    """
    # Seed both RNGs: random.sample below uses Python's ``random`` module, which
    # np.random.seed does not affect.
    np.random.seed(1234)
    random.seed(1234)

    gold_docs = {'13899510', '23094601', '12189519', '20555550', '15347915', '15727234', '19325557', '13753588',
                 '22090431', '14292702', '13881034', '19439270', '15147663', '10491490', '19655487', '11778639',
                 '13612323', '19848445', '14718096', '10403411', '20195826', '11486237', '17576952', '15490123',
                 '17579845', '1297680', '17410315', '18989394', '1012657', '19147933', '20860485', '12813059',
                 '20968892', '10593203', '23348587', '1606569', '1006811', '22920273', '14116822', '10734397',
                 '13645052', '12450938', '18029864', '15241311', '15854097', '21074206', '16490964', '1714600',
                 '2368721', '23511775'}

    # Sort ids so the seeded sampling and the dev/test halves do not depend on
    # the database's row order.
    gold_doc_ids = sorted(id_[0] for id_ in get_doc_ids_by_name(session, gold_docs))
    gold_id_set = set(gold_doc_ids)

    # build training set; only ids are needed, and never request more than exist
    doc_ids = sorted(row[0] for row in session.query(Document.id).all())
    sample_size = min(len(doc_ids), num_training_docs + len(gold_doc_ids))
    sampled_ids = random.sample(doc_ids, sample_size)

    # ensure there isn't any overlap between training docs and dev/test,
    # keeping the sampled order so the truncation stays random
    train_doc_ids = [d for d in sampled_ids if d not in gold_id_set][:num_training_docs]
    print("Training docs: {}".format(len(train_doc_ids)))

    train_cands = get_cands_by_doc(session, train_doc_ids, ImplantComplication)

    # split gold documents into dev/test halves
    split = int(len(gold_doc_ids) * 0.50)
    dev_cands = get_cands_by_doc(session, gold_doc_ids[split:], ImplantComplication)
    test_cands = get_cands_by_doc(session, gold_doc_ids[:split], ImplantComplication)

    # assign to splits
    for split_id, cands in ((1, train_cands), (2, dev_cands), (3, test_cands)):
        for c in cands:
            c.split = split_id

    session.commit()
    
def reset_split(candidates):
    """Clear split assignments: set ``split = 0`` on every candidate, then commit.

    Commits through the notebook-level ``session``.
    """
    for cand in candidates:
        cand.split = 0
    session.commit()


### Create or reset train/dev/test splits

In [None]:
# Split setup. init_implant_complications_splits assigns split = 1/2/3
# (train/dev/test) to candidates and commits; reset_split sets split = 0.
# All calls are commented out, so running this cell changes no splits.
#init_implant_complications_splits(session)
# reset_split(train_cands)
# reset_split(dev_cands)
# reset_split(test_cands)
# The triple-quoted block below is a string literal, so its contents are not run.
# It would load each split from the stored split labels and print the sizes.
'''
X_train = session.query(Candidate).filter(Candidate.split == 1).all()
X_dev   = session.query(Candidate).filter(Candidate.split == 2).all()
X_test  = session.query(Candidate).filter(Candidate.split == 3).all()

print(len(X_train))
print(len(X_dev))
print(len(X_test))
'''

In [None]:
# Training-set scale-up samples: number of documents -> TSV listing those documents.
# Each filename also records a candidate count for that sample. Timing notes from
# the original listing: 1000 docs ~17 s, 50000 docs ~14 min 6 s.
scaleup = {
    n_docs: '/data4/jfries/scaleup-samples/docs_{}.cands_{}.tsv'.format(n_docs, n_cands)
    for n_docs, n_cands in [
        (150, 798),
        (500, 2699),
        (1000, 5363),
        (5000, 28028),
        (10000, 55851),
        (20000, 111421),
        (30000, 168844),
        (40000, 224855),
        (50000, 282004),
        (60000, 337560),
        (70000, 394021),
        (80000, 449237),
        (90000, 504943),
        (100000, 560198),
    ]
}

def load_train_sample(fpath):
    """Load TSV defining documents that comprise our training set.

    Each non-blank line is ``<doc_name>\\t<int>[\\t...]``. Blank lines (for
    example a trailing newline) are skipped instead of raising IndexError.

    :param fpath: path to the TSV file
    :return: dict mapping document name -> int parsed from the second column
    """
    sample = {}
    # ``with`` closes the handle; the previous open(...).read() leaked it.
    with open(fpath, 'r') as fh:
        for line in fh.read().splitlines():
            row = line.strip()
            if not row:
                continue
            fields = row.split("\t")
            sample[fields[0]] = int(fields[1])
    return sample

def load_train_candidates(session, fpath, candidate_class):
    """Load the candidates from the documents listed in a training-sample TSV.

    :param session: Snorkel session
    :param fpath: TSV path in the format read by ``load_train_sample``
    :param candidate_class: candidate subclass to fetch
    :return: list of candidates from those documents
    """
    doc_names = list(load_train_sample(fpath).keys())
    # get_doc_ids_by_name returns one-element rows; unwrap them to bare ids,
    # as init_implant_complications_splits does, before filtering on them.
    doc_ids = [row[0] for row in get_doc_ids_by_name(session, doc_names)]
    return get_cands_by_doc(session, doc_ids, candidate_class)

In [None]:
# Number of documents in the training sample; must be a key of ``scaleup``.
NUM_SAMPLES = 500

# Training candidates come from the scale-up sample file. Dev/test candidates
# come from the split labels already stored in the database (2 = dev, 3 = test).
X_train = load_train_candidates(session, scaleup[NUM_SAMPLES], ImplantComplication)
X_dev   = session.query(Candidate).filter(Candidate.split == 2).all()
X_test  = session.query(Candidate).filter(Candidate.split == 3).all()

### Load documents

In [None]:
def get_docs_by_cands(candidates):
    """Return the documents that contain *candidates*.

    A candidate's document name is the part of its first span's stable id
    before the first ':'. Queries the notebook-level ``session``.
    """
    doc_names = {cand[0].get_stable_id().split(":")[0] for cand in candidates}
    doc_query = session.query(Document)
    return doc_query.filter(Document.name.in_(doc_names)).all()
    
# Documents behind each split (each query returns one row per matching document).
train_docs = get_docs_by_cands(X_train)
dev_docs   = get_docs_by_cands(X_dev)
test_docs  = get_docs_by_cands(X_test)

# Combined list (passed to extract_doctimes later); a document present in more
# than one split would appear more than once.
documents = train_docs + dev_docs + test_docs

print(len(train_docs))
print(len(dev_docs))
print(len(test_docs))

In [None]:
# from metal.label_model.baselines import MajorityLabelVoter

# mv = MajorityLabelVoter(seed=123)
# scores = mv.score(L_dev, L_gold_dev, metric=['precision', 'recall', 'f1'])

### Load BRAT Gold Annotations

In [None]:
# Gold annotations are read from the BRAT annotation directory below.
gold = BratAnnotations("/data4/jfries/brat-iaa/")
gold.annotator_summary()

# String literal: the inter-annotator agreement call inside it is not run.
'''
gold.annotator_agreement(ignore_types=['Header',"Anatomy",'Indication'], 
                         relations_only=True, method='randolph')
'''

# Label encoding: 1 when the "Finding" carries the PresentPositive attribute,
# otherwise 2. init_labels is restricted to the 'Complication' type.
class_map = lambda c: 1 if c["Finding"].attribute("PresentPositive") else 2
gold.init_labels(class_map, ['Complication'], verbose=True)

# NOTE(review): neg_label=2 presumably labels candidates without a matching
# annotation as 2 (negative) — confirm in the brat module.
Y_dev  = gold.get_labels(X_dev, neg_label=2)
Y_test = gold.get_labels(X_test, neg_label=2)

# Class balance per split: count of label 1 (true) and label 2 (false).
print("[DEV]  T/F: {} {}".format(list(Y_dev).count(1), list(Y_dev).count(2)))
print("[TEST] T/F: {} {}".format(list(Y_test).count(1), list(Y_test).count(2)))


In [None]:
# Recompute the gold labels without neg_label (overwriting the previous cell's
# values) and map label 0 to the negative class 2.
# NOTE(review): presumably 0 marks candidates without a gold annotation — confirm.
Y_dev  = gold.get_labels(X_dev)
Y_test = gold.get_labels(X_test)

Y_dev[Y_dev == 0] = 2
Y_test[Y_test == 0] = 2

print(len(Y_dev))
print(len(Y_test))

## Weak Supervision Labeling Functions

In [None]:
import collections

# Group every train/dev/test candidate by its document. The document name is
# the part of the first span's stable id before the first ':'.
cands_by_doc = collections.defaultdict(list)
for c in X_train + X_dev + X_test:
    doc_name = c[0].get_stable_id().split(":")[0]
    cands_by_doc[doc_name].append(c)


## Extract clinical note primitives

In [None]:
def apply_taggers(documents, taggers, ngrams=6, stopwords=None):
    """ Apply taggers to documents and collect their output layers.

    :param documents: documents to tag; each needs a ``name`` attribute
    :param taggers: dict of name -> tagger exposing ``tag(doc, ngrams=, stopwords=)``
    :param ngrams: forwarded to every tagger
    :param stopwords: forwarded to every tagger; defaults to a fresh empty list
    :return: nested defaultdict ``markup[doc.name][layer] = tags[layer]``. If two
        taggers emit the same layer name, the later tagger's output wins.
    """
    # Local import: this cell runs before the one that imports defaultdict.
    from collections import defaultdict

    if stopwords is None:
        stopwords = []  # avoid a shared mutable default argument
    markup = defaultdict(lambda: defaultdict(list))
    for doc in documents:
        for tagger in taggers.values():
            tags = tagger.tag(doc, ngrams=ngrams, stopwords=stopwords)
            for layer in tags:
                markup[doc.name][layer] = tags[layer]
    return markup

In [None]:
from tagger import DatetimeTagger, UMLSTagger, SectionHeaderTagger

# target UMLS concepts
# target UMLS concepts: group name -> UMLS semantic types mapped to it.
# NOTE(review): the group names presumably become markup layer names (e.g.
# "BACTERIUM" is read later via get_sentence_markup) — confirm in UMLSTagger.
concepts = {
'ACTIVITY' : ['daily_or_recreational_activity'],
'CHEMICAL' : ['clinical_drug', 'antibiotic', 'pharmacologic_substance', 'vitamin'],
'DISORDER' : ['disease_or_syndrome', 'acquired_abnormality', 'sign_or_symptom', 
              'injury_or_poisoning', 'congenital_abnormality', 
              'anatomical_abnormality', 'pathologic_function'],
'ANATOMY'  : ['body_part,_organ,_or_organ_component', 'body_location_or_region', 
              'body_space_or_junction', 'body_system'],
'PROCEDURE': ['diagnostic_procedure', 'therapeutic_or_preventive_procedure', 
              'health_care_activity'],
'BACTERIUM': ['bacterium'],
'TEMPORAL' : ['temporal_concept']
}

# Taggers run over every document by apply_taggers (keys are tagger names).
taggers = {
"umls"   : UMLSTagger(concepts, data_root="../data/supervision/ontologies/UMLS_2014AB/data/"),
"date"   : DatetimeTagger(),
"header" : SectionHeaderTagger()
}

In [None]:
from rwe.utils import load_dict

# stopwords
# stopwords: terms from the dictionary file plus extra terms, forwarded to every tagger
sw = load_dict("../data/supervision/dicts/stopwords.txt")
sw = sw.union(set(['today', 'per', 'md', 'unknown', 'date', 'add', 'active', 'none', 
                   'report', 'doc', 'control', 'stopping', 'level', 'tomorrow', 
                   'ser', 'relief', 'air', 'new', 'take', 'weight', 'skin', 'edema']))

# Tag all train/dev/test documents (ngrams=6); the result is the nested
# markup[doc_name][layer] structure built by apply_taggers.
tagged_sentences = apply_taggers(train_docs + dev_docs + test_docs, 
                              taggers, ngrams=6, stopwords=sw)

### Helper functions

In [None]:
#
# Document
#

def get_note_sign_date(markup, field='T', note_type=None):
    """
    Estimate a note's signing timestamp from its HEADER and DATETIME layers.

    Some notes include footer info of the form:
    D: 01/01/2001 08:00 A CT
    T: 01/01/2001 09:00 P CT / SPH

    others include
    Date:

    For every sentence whose first HEADER entity starts with ``<field>:`` and
    that has DATETIME entries, the first entry is a candidate; the one with the
    latest timestamp is returned. If there are none, the latest timestamped
    DATETIME entry anywhere in the document is returned. Entries whose last
    element (the timestamp) is empty/None are ignored in both cases.

    :param markup: one document's markup, ``markup[layer][sentence_idx] -> list``
    :param field: header label to look for, e.g. 'T' or 'D'
    :param note_type: unused
    :return: the selected DATETIME entry, or None when none is usable

    TODO: Dates seem to be note_type dependant
    TODO: Ask what D and T mean (start/close?)
    TODO: How does this date compare to the jittered timestamp in STRIDE?
    """
    import itertools
    import re

    headers = markup.get("HEADER", {})
    datetimes = markup.get("DATETIME", {})
    # raw string + re.escape: no invalid escape warnings, and field is matched literally
    header_rgx = re.compile(r"^\s*{}[:]".format(re.escape(field)))

    matches = []
    for sidx in headers:
        # no header found
        if not headers[sidx]:
            continue
        h1 = headers[sidx][0].get_span()
        dates = datetimes[sidx] if sidx in datetimes else []
        if dates and header_rgx.search(h1):
            matches.append(dates[0])

    # Drop entries without a parsed timestamp: comparing None with a datetime
    # raises TypeError, and callers subtract the returned timestamp.
    matches = [m for m in matches if m[-1]]

    # sometimes multiple dates appear on a single line due to sentence boundary errors
    if matches:
        return max(matches, key=lambda x: x[-1])

    # otherwise select the max date from all within-document dates
    ts = [m for m in itertools.chain.from_iterable(datetimes.values()) if m[-1]]
    return max(ts, key=lambda x: x[-1]) if ts else None

#
# Sentences
# 

def get_sentence_markup(sentence, layer, markup):
    """Return the *layer* entities tagged in *sentence*, ordered by ``char_start``.

    Reads ``markup[document name][layer][sentence.position]``. A missing level,
    or a stored ``None``, yields an empty list.
    """
    doc_name = sentence.document.name
    if doc_name not in markup:
        return []
    doc_layers = markup[doc_name]
    if layer not in doc_layers:
        return []
    by_position = doc_layers[layer]
    if sentence.position not in by_position:
        return []
    entities = by_position[sentence.position]
    if entities is None:
        return []
    return sorted(entities, key=lambda ent: ent.char_start)

def sentence_contains_list(c, threshold=4):
    """Heuristically decide whether sentence *c* contains an inline list.

    The non-blank words are joined with spaces and split on ';' and ','; the
    sentence counts as a list when that yields more than *threshold* pieces.

    :param c: sentence-like object with a ``words`` sequence
    :param threshold: number of pieces that must be exceeded
    :return: bool
    """
    import re

    # Use the argument; the previous body read an undefined name ``sent``.
    text = " ".join(w for w in c.words if w.strip()).strip()
    return len(re.split(r'[;,]', text)) > threshold

#
# Candidates
#

def is_in_list(c, threshold=4):
    """ If sentence contains a list (defined by number of commas), is this item in the list?

    Incomplete: returns False when the candidate's sentence is not list-like
    according to ``sentence_contains_list``; otherwise it falls off the end and
    returns None (falsy), because the in-list check is not implemented yet.
    """
    if not sentence_contains_list(c.get_parent(), threshold):
        return False
    # TODO: decide whether the candidate is one of the list items; until then the
    # function implicitly returns None here.

def is_list_item(s, threshold=4):
    """
    Identify sentences of the form:
        1. Patient underwent surgery.
        2. Recovery went well

    :param s: candidate whose parent sentence text is inspected
    :param threshold: unused; kept for interface compatibility
    :return: True if the sentence text starts with an enumerator such as "1." or
        "10)" or contains a bullet character anywhere, else False
    """
    import re

    # Read the parameter; the previous body referenced an undefined name ``c``.
    text = s.get_parent().text.strip()
    # [1-9][0-9]* also recognises multi-digit items such as "10."; the bullet
    # alternative is not anchored, so a bullet anywhere counts.
    return re.search(r"^[1-9][0-9]*[.)]|•", text) is not None

def is_past_tense(c, verb_window="left"):
    """True if a past-tense verb tag appears between the candidate's spans.

    Looks for VBD (verb, past tense) or VBN (verb, past participle) among the
    case-sensitive ``pos_tags`` between the spans. ``verb_window`` is unused.
    """
    between_tags = set(get_between_tokens(c, attrib='pos_tags', case_sensitive=True))
    return bool(between_tags & {'VBD', 'VBN'})


#print(get_sentence_markup(X_train[0].get_parent(), "DISORDER", markup))
#print(get_sentence_markup(X_train[0].get_parent(), "HEADER", markup))
#print(get_sentence_markup(X_train[0].get_parent(), "DATE", markup))

# Estimate a signing timestamp for every document in each split and report how
# many documents yielded one.
doc_ts = {}
for split_name, split_docs, label in (('train', train_docs, '[TRAIN]'),
                                      ('dev', dev_docs, '[DEV]'),
                                      ('test', test_docs, '[TEST]')):
    doc_ts[split_name] = {doc.name: get_note_sign_date(tagged_sentences[doc.name])
                          for doc in split_docs}
    total = len(doc_ts[split_name])
    n = total - list(doc_ts[split_name].values()).count(None)
    # an empty split would otherwise divide by zero
    pct = n / total * 100 if total else 0.0
    print("Extracted {:>5.1f}% ({:>4}/{:>4}) {:<7} document timestamps".format(pct, n, total, label))

# Merge the per-split lookups; a later split wins on a duplicate document name.
doc_ts_all = {}
for split_name in ('train', 'dev', 'test'):
    doc_ts_all.update(doc_ts[split_name])


In [None]:
from helpers import *

# extract_doctimes comes from the star import of ``helpers``. LF_complication
# indexes the result by document name. NOTE(review): presumably each value is a
# timestamp entry whose last element is a datetime — confirm in helpers.
doctimes = extract_doctimes(documents, tagged_sentences)

In [None]:
from collections import defaultdict

# Relation of a mention's time to the note's timestamp.
DOCTIME_OVERLAP = 1
DOCTIME_BEFORE  = 2

# Label values.
TRUE = 1
FALSE = 2

# Section headers, compared lower-cased, that mark historical text ...
historical_headers = {header.lower() for header in (
    'ADMITTING HISTORY', 'PAST SURGICAL HISTORY', 'CLINICAL HISTORY',
    'PAST MEDICAL HISTORY', 'Past Medical/Surgical History',
)}
# ... and text about the current illness.
present_illness_headers = {header.lower() for header in (
    'HISTORY OF PRESENT ILLNESS', 'IMPRESSION', 'DIAGNOSIS',
    'FINDINGS', 'ID/HPI', 'HPI', 'History of Present Illness',
)}

# Other headers noted but not assigned to a group:
#OPERATION PERFORMED
#PROCEDURE IN DETAIL
#INDICATIONS

#
# Helper Functions
#
def get_relation_head(c):
    """Return whichever of the pair's two spans starts first (ties go to ``c[1]``)."""
    first, second = c[0], c[1]
    if second.char_start <= first.char_start:
        return second
    return first

# def get_span_entity_type(span, markup):
#     """Given a span and sentence markup, determine entity type """
#     pass

def overlaps(c, span):
    """True when span-like *c* shares at least one position with *span*.

    :param c: object with ``char_start`` / ``char_end`` (inclusive)
    :param span: ``(char_start, char_end)`` tuple, inclusive
    :return: bool
    """
    char_start, char_end = span
    # Standard inclusive interval test. The previous endpoint-only test missed
    # the case where *c* strictly contains *span*.
    return c.char_start <= char_end and c.char_end >= char_start
    
    
def contained_by(c, span):
    """True when *c* lies entirely inside the inclusive char range *span* = (start, end)."""
    lo, hi = span
    return lo <= c.char_start and c.char_end <= hi


def get_span_entities(sentence, span, layer, markup):
    """Return the *layer* entities of *sentence* that fall entirely within *span*.

    Order follows get_sentence_markup (ascending ``char_start``).
    """
    return [ent for ent in get_sentence_markup(sentence, layer, markup)
            if contained_by(ent, span)]

def is_hypothetical(c, window=25):
    """ Hypothetical mention detection akin to NegEx.

    Joins the *window* tokens left of whichever candidate span starts first and
    returns True if any cue pattern ("if", "should you", "possible",
    "candidate for", ...) matches case-insensitively.
    """
    if c.implant.char_start < c.complication.char_start:
        first_span = c.implant
    else:
        first_span = c.complication
    left_text = " ".join(get_left_tokens(first_span, window=window))
    cue_patterns = (
        r'\b(if need be)\b',
        r'\b((if|should)\s+(you|she|he|be)|(she|he|you)\s+(might|could|may)\s*(be)*|if)\b',
        r'\b((possibility|potential|chance|need) (for|of)|potentially)\b',
        r'\b(possible)\b',
        r'\b(candidate for)\b',
        r'\b(assuming)\s+(you|she|he)\b',
    )
    return any(re.search(pattern, left_text, re.I) for pattern in cue_patterns)

def is_history_of(c, window=25):
    """Historical mention detection akin to NegEx.

    True if "h/o", "hx" or "history of" (word-bounded, case-insensitive) occurs
    in the *window* tokens left of whichever candidate span starts first.
    """
    first_span = c.implant if c.implant.char_start < c.complication.char_start else c.complication
    left_text = " ".join(get_left_tokens(first_span, window=window))
    return re.search(r'\b(h/o|hx|history of)\b', left_text, re.I) is not None

def get_section_header(c):
    """
    Return the section header text (with ':' removed) of the candidate's sentence.

    Uses the first HEADER entity of the sentence in the global ``tagged_sentences``
    markup; returns None when there is none.
    """
    headers = get_sentence_markup(c.get_parent(), "HEADER", tagged_sentences)
    if not headers:
        return None
    try:
        first_header = headers[0]
        if not first_header:
            return None
        return first_header.get_span().replace(":", "")
    except IndexError:
        return None

def get_doctime_class(c, doc_ts, threshold = 24 * 60 * 60):
    """
    Use DATETIME layer and document timestamp to heuristically determine if
    mention occurs during the note (DOCTIME_OVERLAP) or sometime
    in the past (DOCTIME_BEFORE). This uses in-sentence dates to determine overlap.

    The latest timestamp tagged in the candidate's sentence (global
    ``tagged_sentences`` markup) is compared with the document timestamp: it is
    DOCTIME_BEFORE when it precedes the document time by more than *threshold*
    seconds, otherwise DOCTIME_OVERLAP.

    :param c: candidate
    :param doc_ts: document timestamp entry whose last element is the timestamp
    :param threshold: maximum gap in seconds still treated as overlap
    :return: DOCTIME_OVERLAP, DOCTIME_BEFORE, or None when either date is unavailable

    TODO: cleanup
    """
    if not doc_ts:
        return None
    doc_time = doc_ts[-1]
    # An entry without a parsed timestamp cannot be compared (the subtraction
    # below would raise TypeError).
    if not doc_time:
        return None

    sent_entities = get_sentence_markup(c.get_parent(), "DATETIME", tagged_sentences)
    if not sent_entities:
        return None
    sent_times = [d for d in list(zip(*sent_entities))[-1] if d]
    if not sent_times:
        return None
    sent_time = max(sent_times)

    if doc_time == sent_time:
        return DOCTIME_OVERLAP

    tdelta = (doc_time - sent_time).total_seconds()
    return DOCTIME_BEFORE if tdelta > threshold else DOCTIME_OVERLAP

In [None]:
def misattached_entities2_v2(c):
    """
    Check whether the complication may be mis-attached to an implant mention after it.

    The docstring example comes from a pain-anatomy labeler: note text
    'chest pain, left leg also tender' with candidate (pain, left leg).
    Returns True when all of these hold: the complication ends before the
    implant starts; the right window from ``get_right_tokens(c, 10)`` contains a
    pain mention (``list_contains_pain_mention``); and the lower-cased text
    between the spans contains neither "of the" nor "due to".

    :param c: implant-complication candidate
    :return: boolean; True if candidate is misattached, False otherwise
    """
    right_window = get_right_tokens(c, 10)
    between_tokens = get_between_tokens(c)
    between_phrase = ' '.join(between_tokens).lower()
    
    #reduced accuracy of LF from 86 to 84, but improved F1
    # (the string literal below holds that disabled check; its contents are not run)
    '''
    complication_boolean = False
    if any(complication in between_phrase for complication in complications) and "and" in between_phrase:
        complication_boolean = True
    '''
    of_the_boolean = "of the" in between_phrase
    due_to_boolean = "due to" in between_phrase 
    
    b = c.complication.char_end < c.implant.char_start
    b &= list_contains_pain_mention(right_window)
    b &= not of_the_boolean
    b &= not due_to_boolean
    
    return True if b else False

def misattached_entities4(c):
    """
    Flag an implant-first candidate whose complication may belong to a later anatomy term.

    Returns True when the implant ends before the complication starts, the right
    window from ``get_right_tokens(c, 14)`` contains an anatomy mention, and the
    tokens between the spans do not (``list_contains_anatomy_mention``).
    """
    between_tokens = list(get_between_tokens(c))
    right_window = get_right_tokens(c, 14)

    b = c.implant.char_end < c.complication.char_start
    b &= list_contains_anatomy_mention(right_window)
    b &= not list_contains_anatomy_mention(between_tokens)

    return True if b else False

def misattached_any(c):
    """True if any misattached_entities* heuristic flags the candidate.

    All five checks are always evaluated (``|=`` does not short-circuit).
    NOTE(review): misattached_entities, misattached_entities2 and
    misattached_entities3 are not defined in this section; presumably they come
    from a star import (e.g. ``rwe.labelers``) — confirm.
    """

    b = misattached_entities(c)
    b |= misattached_entities2(c)
    b |= misattached_entities2_v2(c)
    b |= misattached_entities3(c)
    b |= misattached_entities4(c)
    
    return True if b else False

## New labeling functions

In [None]:
# Dictionaries used by the labeling functions below.
anatomy_dict = load_dict("../data/supervision/dicts/implant_types/anatomy_shc.txt")
implant_dict = load_dict("../data/supervision/dicts/implant_types.filtered.txt")
complications_all = load_dict("../data/supervision/dicts/implant_complications.all.txt")

# Indication terms (used by LF_indication / LF_implant_indication to vote negative).
indications = {'deep vein thrombosis', 'dvt', 'degenerative joint disease', 'narrowing',
               'fracture', 'osteoarthritis', 'avascular necrosis'}

# Complication lemmas accepted by LF_complication.
complications = {'infection', 'infected', 'wear', 'osteolysis', 'lucency', 'lucencies',
                 'revision', 'migration'}

def LF_anatomy_mention(c):
    """Mention is in anatomy dictionary.

    Votes -1 when the (lower-cased) implant mention is an ``anatomy_dict`` term,
    unless the complication mentions "revision"; otherwise abstains (0).
    """
    # Lower-case like the other anatomy_dict lookups (LF_anatomy_as_implant,
    # LF_indication, ...) so capitalised mentions such as "Hip" also match.
    mention = c.implant.get_span().lower()
    complication = c.complication.get_span().lower()

    is_revision = "revision" in complication

    v = mention in anatomy_dict
    v &= not is_revision

    return -1 if v else 0
    
def LF_anatomy_as_implant(c):
    """Anatomy term referring to an implant + clear complication.

    Votes -1 when the lower-cased implant mention is in ``anatomy_dict`` and the
    complication's lemma string is one of a fixed keyword set; otherwise abstains.
    """
    if c.implant.get_span().lower() not in anatomy_dict:
        return 0
    trigger_lemmas = {'replacement', 'wear', 'heterotopic ossification',
                      'mechanical failure', 'migration'}
    lemma_text = " ".join(tok for tok in c.complication.get_attrib_tokens('lemmas') if tok.strip())
    if lemma_text in trigger_lemmas:
        return -1
    return 0
    
def LF_anatomy_revision(c):
    """Anatomy term referring to an implant + revision.

    Votes +1 when fewer than 3 tokens separate the spans, the lower-cased implant
    mention is in ``anatomy_dict`` and the complication lemma string is exactly
    "revision"; otherwise abstains.
    """
    gap = len(list(get_between_tokens(c)))

    if c.implant.get_span().lower() not in anatomy_dict:
        return 0
    lemma_text = " ".join(tok for tok in c.complication.get_attrib_tokens('lemmas') if tok.strip())

    return 1 if (gap < 3 and lemma_text == 'revision') else 0

def LF_anatomy_pain_implant(c):
    """Hip pain due to infected prosthesis.

    Votes +1 when the right window (``get_right_tokens(c, window=15)``) contains
    "due to" and either "prosthesis" or an ``implant_dict`` term; otherwise
    abstains. The window text is lower-cased before matching, like the other
    implant_dict checks.
    """
    right_window_text = ' '.join(get_right_tokens(c, window=15)).lower()

    # NOTE(review): anatomy-mention and "pain" checks were disabled in this LF;
    # re-enable them if the LF should be restricted to anatomy + pain candidates.
    if 'due to' not in right_window_text:
        return 0

    implant_boolean = 'prosthesis' in right_window_text or any(
        implant_term in right_window_text for implant_term in implant_dict)

    return 1 if implant_boolean else 0

def LF_indication(c):
    """Anatomy term + common indication.

    Votes -1 when the lower-cased implant mention is in ``anatomy_dict`` and the
    lower-cased complication lemma string is in ``indications``; otherwise abstains.
    """
    if c.implant.get_span().lower() not in anatomy_dict:
        return 0
    lemma_text = " ".join(
        tok.lower() for tok in c.complication.get_attrib_tokens('lemmas') if tok.strip())
    return -1 if lemma_text in indications else 0

def LF_implant_indication(c):
    """Implant term + common indication.

    Votes -1 when an ``implant_dict`` term occurs inside the lower-cased implant
    mention and the lower-cased complication lemma string is in ``indications``.
    """
    implant_text = c.implant.get_span().lower()
    has_implant_term = any(term in implant_text for term in implant_dict)
    lemma_text = " ".join(
        tok.lower() for tok in c.complication.get_attrib_tokens('lemmas') if tok.strip())

    if has_implant_term and lemma_text in indications:
        return -1
    return 0
 
def LF_hypothetical(c):
    """Vote -1 for hypothetical mentions.

    Fires when the lower-cased text of the sentences behind the candidate's
    spans (joined without a separator) contains "prevent", or when
    ``is_hypothetical`` flags the candidate; otherwise abstains.
    """
    spans = get_sent_candidate_spans(c)
    sentence_text = ''.join(
        ' '.join(span.get_parent()._asdict()['words']) for span in spans
    ).lower()

    hypothetical = is_hypothetical(c)
    return -1 if ('prevent' in sentence_text or hypothetical) else 0

def count_implant_mentions_in_doc(c):
    """Count candidates in c's document whose implant mention contains an ``implant_dict`` term.

    Uses the global ``cands_by_doc`` index; the document name is the prefix of
    the first span's stable id.
    """
    doc_name = c[0].get_stable_id().split(":")[0]
    return sum(
        1 for sibling in cands_by_doc[doc_name]
        if any(term in sibling.implant.get_span().lower() for term in implant_dict)
    )

def count_implant_laterality_mentions_in_doc(c, laterality):
    """Count implant mentions in c's document that contain *laterality* (e.g. 'left').

    Only lower-cased implant mentions containing an ``implant_dict`` term are
    considered; *laterality* is matched as a substring.
    """
    doc_name = c[0].get_stable_id().split(":")[0]
    total = 0
    for sibling in cands_by_doc[doc_name]:
        text = sibling.implant.get_span().lower()
        if not any(term in text for term in implant_dict):
            continue
        if laterality in text:
            total += 1
    return total

def LF_misattached_entities2_v2(c):
    """
    Labeling-function wrapper around ``misattached_entities2_v2``.

    :param c: implant-complication candidate
    :return: -1 if the candidate is flagged as misattached, 0 (abstain) otherwise
    """
    if misattached_entities2_v2(c):
        return -1
    return 0

def LF_anatomy_implant(c):
    """Vote +1 for an anatomy-term mention whose side matches several implant mentions.

    Applies only when the lower-cased implant mention is in ``anatomy_dict``.
    Fires if the mention contains "left" (resp. "right") and more than one
    candidate in the same document has an implant mention with an
    ``implant_dict`` term that also contains "left" (resp. "right").
    """
    doc_name = c[0].get_stable_id().split(":")[0]
    implant_mention = c.implant.get_span().lower()
    sibling_implants = [cs.implant for cs in cands_by_doc[doc_name]]

    if implant_mention not in anatomy_dict:
        return 0

    side_counts = {'left': 0, 'right': 0}
    for implant in sibling_implants:
        text = implant.get_span().lower()
        if not any(term in text for term in implant_dict):
            continue
        for side in side_counts:
            if side in text:
                side_counts[side] += 1

    fires = any(side in implant_mention and side_counts[side] > 1
                for side in ('left', 'right'))
    return 1 if fires else 0

def LF_date(c):
    """Vote -1 when a d/m/y-style date (e.g. 1/2/03, 01/02/2003) follows the candidate.

    Searches the right window from ``get_right_tokens(c, window=15)``, joined with spaces.
    """
    window_text = ' '.join(get_right_tokens(c, window=15))
    date_rgx = re.compile('[0-9]{1,2}/[0-9]{1,2}/[0-9]{2,4}', flags=re.I)
    if date_rgx.search(window_text) is None:
        return 0
    return -1

def LF_complication(c):
    """Implant term + common complication.

    Votes +1 when all of these hold, otherwise abstains:
      - the lower-cased complication lemma string is in ``complications``
      - fewer than 8 tokens lie between the spans
      - no misattachment heuristic fires (``misattached_any``)
      - the implant mention contains an ``implant_dict`` term
      - the candidate's sentence text contains neither "without evidence" nor "no evidence"
      - ``is_history_of`` does not flag the candidate
      - the sentence is not DOCTIME_BEFORE relative to the document timestamp
    """
    distance = len(list(get_between_tokens(c)))

    sent = ''.join(' '.join(span.get_parent()._asdict()['words'])
                   for span in get_sent_candidate_spans(c)).lower()
    negated_boolean = "without evidence" in sent or "no evidence" in sent

    implant_mention = c.implant.get_span().lower()
    implant_boolean = any(implant_term in implant_mention for implant_term in implant_dict)

    lemma = " ".join([w.lower() for w in c.complication.get_attrib_tokens('lemmas') if w.strip()])

    history_of = is_history_of(c)

    misattached = misattached_any(c)

    doc = c.get_parent().document
    doc_ts = doctimes[doc.name]
    # get_doctime_class reads tagged_sentences itself; its third positional
    # parameter is the threshold in seconds. Passing the markup there made the
    # ``tdelta > threshold`` comparison raise TypeError.
    doctime = get_doctime_class(c, doc_ts)

    # mention occurs before note doctime
    before = doctime == DOCTIME_BEFORE

    v = lemma in complications
    v &= distance < 8
    v &= not misattached
    v &= implant_boolean
    v &= not negated_boolean
    v &= not history_of
    v &= not before

    return 1 if v else 0

def LF_misattached_entities4(c):
    """Labeling-function wrapper: -1 if ``misattached_entities4`` flags *c*, else 0."""
    if misattached_entities4(c):
        return -1
    return 0

def LF_misattached_any(c):
    """Labeling-function wrapper: -1 if any misattachment heuristic flags *c*, else 0."""
    if misattached_any(c):
        return -1
    return 0
    
def LF_contiguous_right_finding(c):
    """
    Check for implant-complication candidates that are compound mentions
    e.g 'hip revision', 'implant infection'
    where the complication mention directly follows the implant mention
    and where the mention is not negated (using Negex)

    Also abstains when the complication mentions a fracture/nonunion or a
    removal, or when the mention is historical or hypothetical.

    :param c: implant-complication candidate
    :return: 1 if True, 0 otherwise
    """
    # Negex window: longest forward "definite" trigger plus 2 tokens.
    possible_terms = [x['term'].split(' ') for x in negex.dictionary['definite'] if x['direction'] == 'forward']
    longest = len(max(possible_terms, key=len))
    window = longest + 2
    distance = len(list(get_between_tokens(c)))

    history_of = is_history_of(c)
    hypothetical = is_hypothetical(c)

    implant_text = c.implant.get_span().lower()
    complication_text = c.complication.get_span().lower()

    implant_boolean = any(implant_term in implant_text for implant_term in implant_dict)
    fracture_boolean = any(term in complication_text for term in ["fracture", "nonunion", "non-union"])
    # Removal terms get their own flag (the earlier code set fracture_boolean
    # here); the candidate is still excluded when either flag is set.
    removal_boolean = any(term in complication_text for term in ["removal", "remove", "removed"])

    v = distance < 1
    v &= implant_boolean
    v &= not fracture_boolean
    v &= not removal_boolean
    v &= not history_of
    v &= not hypothetical
    v &= c.complication.char_end > c.implant.char_end
    v &= not negex.is_negated(c.complication, 'definite', 'left', window)

    return 1 if v else 0

def LF_contiguous_left_finding(c):
    """
    Check for complication-implant candidates that are compound mentions,
    e.g. 'infection hip': no tokens separate the spans, the complication ends
    before the implant mention ends, and no misattachment heuristic fires.
    (No negation check is applied.)

    :param c: implant-complication candidate
    :return: 1 if True, 0 otherwise
    """
    misattached = misattached_any(c)
    adjacent = not list(get_between_tokens(c))
    complication_first = c.complication.char_end < c.implant.char_end

    if adjacent and not misattached and complication_first:
        return 1
    return 0

def LF_near_contiguous_right_finding(c):
    """
    Check for implant-complication candidates where the complication follows the
    implant closely within an enumeration (',' or 'and' between the spans).

    Abstains when 10 or more tokens separate the spans; the implant is negated
    per Negex; a forward Negex trigger appears in the words from the implant to
    the complication start; a ``complications_all`` term appears between the
    spans or in the left window; ``candidate_in_list`` fires; a misattachment
    heuristic fires; the mention is hypothetical; or the spanned words contain
    "flexion", "regimen" or "raise".

    :param c: implant-complication candidate
    :return: 1 if True, 0 otherwise
    """
    possible_terms = [x['term'].split(' ') for x in negex.dictionary['definite'] if x['direction'] == 'forward']
    longest = len(max(possible_terms, key=len))
    left_window_length = longest + 2
    left_window = get_left_tokens(c, window=left_window_length)
    between_tokens = list(get_between_tokens(c))

    f = (lambda w: w.lower())
    between_terms = [ngram for ngram in tokens_to_ngrams(
        map(f, c.complication.get_parent()._asdict()['words'][c.implant.get_word_start():c.complication.get_word_start() + 1]),
        n_max=1)]
    between_phrase = ' '.join(between_terms) + ' '

    negated_in_between = False

    for pt in possible_terms:
        # raw strings: '\s' in a plain literal is an invalid escape sequence
        # (DeprecationWarning, SyntaxWarning since Python 3.12); value is unchanged
        pattern = r'\s'.join(pt) + r'[-\s]*'
        negated_in_text = regex_in_text(pattern, between_phrase)
        if negated_in_text:
            negated_in_between = True

    odd_list = ["flexion", "regimen", "raise"]
    odd = False
    for oe in odd_list:
        if oe in between_phrase:
            odd = True

    # any complication-dictionary term between the spans or to the left
    pain_between = False
    for nterm in complications_all:
        if nterm in between_tokens or nterm in left_window:
            pain_between = True

    misattached = misattached_any(c)

    hypothetical = is_hypothetical(c)

    v = len(between_tokens) < 10
    v &= ',' in between_tokens or 'and' in between_tokens
    v &= c.complication.char_start > c.implant.char_end
    v &= not negex.is_negated(c.implant, 'definite', 'left', left_window_length)
    v &= not negated_in_between
    v &= not pain_between
    v &= not candidate_in_list(c)
    v &= not misattached
    v &= not hypothetical
    v &= not odd

    return 1 if v else 0

def LF_near_contiguous_left_finding(c):
    """
    Check for complication-implant candidates that are non-contiguous but close,
    e.g., 'pain in the hip', 'tenderness of the right side'.

    Fires when fewer than 3 tokens (none of them '(' or ',') separate a
    complication that ends before the implant starts, the candidate is not
    negated per Negex (8-token left window), no pain mention is in the 3-token
    right window, and no misattachment heuristic fires.

    :param c: implant-complication candidate
    :return: 1 if True, 0 otherwise
    """
    between = list(get_between_tokens(c))
    negated = negex.is_negated(c, 'definite', 'left', 8)
    right_tokens = get_right_tokens(c, window=3)
    pain_on_right = list_contains_pain_mention(right_tokens)
    misattached = misattached_any(c)

    checks = (
        len(between) < 3,
        c.complication.char_end < c.implant.char_start,
        not misattached,
        not pain_on_right,
        not negated,
        '(' not in between,
        ',' not in between,
    )
    return 1 if all(checks) else 0

def LF_long_distance_left_finding(c):
    """
    Check for complication-implant candidates that are longer-distance mentions,
    e.g., 'infection in the left THA', with the complication before the implant.

    Fires when fewer than 10 tokens separate the spans, the complication ends
    before the implant starts, and none of the exclusions apply: pain mention in
    the 3-token right window or between the spans, anatomy mention in the
    5-token left window, Negex negation, a date between the spans, candidate in
    a list, ``left_pain_multiple_anatomy``, hypothetical, or misattached.

    :param c: implant-complication candidate
    :return: 1 if True, 0 otherwise
    """
    between = list(get_between_tokens(c))
    right_tokens = get_right_tokens(c, 3)
    left_len = 5
    left_tokens = get_left_tokens(c, left_len)

    hypothetical = is_hypothetical(c)
    misattached = misattached_any(c)

    checks = (
        len(between) < 10,
        c.complication.char_end < c.implant.char_start,
        not list_contains_pain_mention(right_tokens),
        not list_contains_pain_mention(between),
        not list_contains_anatomy_mention(left_tokens),
        not negex.is_negated(c, 'definite', 'left', left_len),
        not date_between(c),
        not candidate_in_list(c),
        not left_pain_multiple_anatomy(c),
        not hypothetical,
        not misattached,
    )
    return 1 if all(checks) else 0

def LF_right_finding_causing(c):
    """
    Check for implant-finding candidates where the implant mention comes
    first and the token "causing" appears between the two mentions,
    e.g., '<implant> ... causing ... <finding>'.

    :param c: finding-implant candidate
    :return: 1 if True, 0 otherwise
    """
    # Removed dead code from the original: a left-window string and a
    # `complications` lookup were computed but never used. The lookup could
    # raise NameError if `complications` was not defined.
    between_tokens = list(get_between_tokens(c))

    v = "causing" in between_tokens
    # implant must end before the finding (complication) starts
    v &= c.implant.char_end < c.complication.char_start

    return 1 if v else 0

def LF_icd_complication(c):
    """
    Vote positive when the complication mention's text contains '996'.

    NOTE(review): presumably targets ICD-9 996.x codes (complications of
    implanted devices); confirm against the annotation guidelines.

    :param c: finding-implant candidate
    :return: 1 if the complication span contains '996', 0 otherwise
    """
    span_text = c.complication.get_span().lower()
    return int('996' in span_text)

def LF_implant_revision(c):
    """
    Vote positive when the complication mention contains "revision", the
    implant mention contains one of the terms in implant_dict, and
    misattached_any does not flag the candidate.

    :param c: finding-implant candidate
    :return: 1 if True, 0 otherwise
    """
    finding_text = c.complication.get_span().lower()
    device_text = c.implant.get_span().lower()

    mentions_revision = "revision" in finding_text
    known_implant = any(term in device_text for term in implant_dict)
    misattached = misattached_any(c)

    return int(mentions_revision and known_implant and not misattached)

def LF_implant_bacteria(c):
    """
    Vote positive when a BACTERIUM markup tag in the candidate's sentence
    overlaps the complication span, the implant mention contains one of
    the terms in implant_dict, and misattached_any does not flag the
    candidate.

    :param c: finding-implant candidate
    :return: 1 if True, 0 otherwise
    """
    implant = c.implant.get_span().lower()
    complication = c.complication

    # BACTERIUM tags for this sentence; the result may be empty or None.
    bacterium = get_sentence_markup(c.get_parent(), "BACTERIUM", tagged_sentences)

    implant_boolean = any(implant_term in implant for implant_term in implant_dict)

    bacteria_boolean = any(
        overlaps(complication, (bacteria.char_start, bacteria.char_end))
        for bacteria in (bacterium or [])
    )

    misattached = misattached_any(c)

    # The original computed LF_date(c) here and never used it; that call has
    # been dropped.
    v = bacteria_boolean
    v &= implant_boolean
    v &= not misattached

    return 1 if v else 0

def LF_complication_explant(c):
    """
    Vote negative (-1) when the complication mention is exactly the word
    'explant' (case-insensitive); abstain (0) otherwise.

    :param c: finding-implant candidate
    :return: -1 if the complication is 'explant', 0 otherwise
    """
    if c.complication.get_span().lower() != "explant":
        return 0
    return -1

def LF_nonunion(c):
    """
    Vote negative (-1) when the complication mention is exactly 'nonunion'
    or 'non-union' (case-insensitive); abstain (0) otherwise.

    :param c: finding-implant candidate
    :return: -1 for a nonunion mention, 0 otherwise
    """
    span_text = c.complication.get_span().lower()
    return -1 if span_text in ("nonunion", "non-union") else 0

### Load previous labeling functions

In [None]:
from rwe.labelers import *

# Start from the shared pain/anatomy relation LFs in rwe.labelers.
lfs = get_labeling_functions("pain_anatomy")

# Pain-specific LFs excluded by name (presumably superseded by the
# finding/implant LFs defined in this notebook).
rm = ['LF_contiguous_left_pain',
      'LF_contiguous_right_pain',
      'LF_near_contiguous_right_pain',
      'LF_long_distance_left_pain',
      'LF_misattached_entities',
      'LF_misattached_entities2',
      'LF_misattached_entities3',
      'LF_left_pain_anatomy_between']
lfs = list(filter(lambda lf: lf.__name__ not in rm, lfs))

print("Loaded {} labeling functions\n".format(len(lfs)))
for lf in lfs:
    print(lf.__name__)
    

### Add new labeling functions

In [None]:
# Append the implant/complication LFs defined in this notebook.
# LF_long_distance_left_finding and LF_implant_bacteria are currently disabled.
lfs.extend([
    LF_anatomy_mention,
    LF_indication,
    LF_hypothetical,
    LF_complication,
    LF_implant_indication,
    LF_anatomy_revision,
    LF_anatomy_pain_implant,
    LF_date,
    LF_contiguous_right_finding,
    LF_contiguous_left_finding,
    # LF_long_distance_left_finding,
    LF_right_finding_causing,
    LF_icd_complication,
    LF_implant_revision,
    # LF_implant_bacteria,
    LF_complication_explant,
    LF_nonunion,
])
print("Labeling Functions", len(lfs))

In [None]:
# Debug: show every dev candidate LF_implant_bacteria votes on (non-zero),
# along with the vote and the candidate's full sentence text.
for c in X_dev:
    v = LF_implant_bacteria(c)
    if not v:
        continue
    rule = '-' * 30
    print(rule, v, rule, c, rule, c.get_parent().text, rule, sep='\n')

### Applying Labeling Functions

In [None]:
from mp_lf import mp_apply_lfs

# Apply all LFs to each split.
# NOTE(review): the trailing 1 is passed straight through to mp_apply_lfs,
# presumably as a worker/process count; confirm in mp_lf.
L_train, L_dev, L_test = (mp_apply_lfs(lfs, X_split, 1)
                          for X_split in (X_train, X_dev, X_test))

print("Label Matrix [TRAIN]", L_train.shape)
print("Label Matrix [DEV]  ", L_dev.shape)
print("Label Matrix [TEST] ", L_test.shape)

In [None]:
# Fix label assignments: remap negative votes from -1 to 2 in place.
# NOTE(review): presumably because the MeTaL LabelModel(k=2) used below treats
# 0 as abstain and classes as 1..k. Consumers that expect -1 must map 2 back.
for L_split in (L_train, L_dev, L_test):
    L_split[L_split == -1] = 2

In [None]:
from metal.analysis import lf_summary, view_label_matrix, view_overlaps

# Per-LF summary table for the dev split, named by LF function, compared
# against Y_dev (presumably the gold dev labels).
lf_summary(L_dev, Y_dev, lf_names = [lf.__name__ for lf in lfs])

In [None]:
# Plot the dev label matrix (one column per LF).
view_label_matrix(L_dev)

In [None]:
from metal.analysis import view_conflicts
# Plot pairwise LF conflicts on the train split as raw counts (normalize=False).
view_conflicts(L_train, normalize=False)

## Majority Vote Models

In [None]:
import copy

def majority_vote(L):
    """Collapse each row of a label matrix into a hard 0/1 prediction.

    Votes in a row are summed. A positive total becomes 1 and a negative
    total becomes 0. A zero total (tie or all-abstain) stays 0, so ties
    count as negative.
    """
    totals = L.sum(axis=1)
    totals[totals > 0] = 1
    totals[totals < 0] = 0
    return totals

# majority_vote treats negative totals as class 0, so undo the -1 -> 2 remap
# on a copy; L_dev itself keeps the MeTaL encoding.
L_dev_hat = copy.deepcopy(L_dev)
L_dev_hat[L_dev_hat==2] = -1
Y_dev_pred = majority_vote(L_dev_hat)

# Compare majority-vote predictions with the gold dev annotations; the result
# is indexed by error type below (errors['fn']).
errors = gold.score(X_dev, Y_dev_pred, ignore_attributes=True)

In [None]:
from snorkel.viewer import SentenceNgramViewer

# Browse the majority vote's dev errors under errors['fn'] (presumably the
# false negatives), one candidate per page.
sv = SentenceNgramViewer(errors['fn'], session=session, n_per_page=1, height=225, annotator_name='gold')
sv

In [None]:
selected = sv.get_selected()

# Name of the document the selected candidate comes from.
print(selected.get_parent().document.name)

# Print the vote and name of every LF that does not abstain (non-zero) on it.
for lf in lfs:
    v = lf(selected)
    if v == 0:
        continue
    print(v, lf.__name__)

In [None]:
# Print the section header that get_section_header reports for the selected candidate.
print(get_section_header(selected))

## Train MeTaL Generative Model
IGNORE FOR NOW

This uses the new Snorkel MeTaL code for learning LF accuracies. 

In [None]:
from metal.label_model import LabelModel
# Binary (k=2) MeTaL label model with a fixed seed for reproducibility.
label_model = LabelModel(k=2, seed=123)

In [None]:
%%time
# Learn LF accuracies from the train label matrix. Y_dev is passed for
# evaluation during training (presumably reported every print_every=25 epochs).
label_model.train(L_train, Y_dev=Y_dev, n_epochs=100, print_every=25)

In [None]:
# Inspect the learned parameter matrix mu: shape first, then values.
print(label_model.mu.shape)
print(label_model.mu)

In [None]:
# noise-aware label model: score its dev-set predictions against Y_dev
scores = label_model.score(L_dev, Y_dev, metric=['precision', 'recall', 'f1'])

In [None]:
from metal.label_model.baselines import MajorityLabelVoter
# Majority-vote baseline. It needs no training; k and seed mirror the
# LabelModel above. `mv` was previously used without being created (NameError).
mv = MajorityLabelVoter(k=2, seed=123)
scores = mv.score(L_dev, Y_dev, metric=['precision', 'recall', 'f1'])

### Labeling Function Empirical Accuracy Statistics
Since we have labeled development data, we can examine empirical statistics for labeling function performance. Good labeling function design requires that any heuristic be correct with probability better than random chance.

In [None]:
#L_dev.lf_stats(session, labels=L_gold_dev.toarray().ravel())

## Training the Generative Model
Random search over generative model hyperparameters for tuning.

In [None]:
import glob
from scipy.sparse import load_npz
from snorkel.learning import GenerativeModel
from snorkel.learning import RandomSearch

# use random search to optimize the generative model
param_ranges = {
    'step_size' : [1e-3, 1e-4, 1e-5, 1e-6],
    'decay'     : [0.9, 0.95],
    'epochs'    : [100, 500],
    'reg_param' : [1e-4],
}
model_class_params = {'lf_propensity' : True}

searcher = RandomSearch(GenerativeModel, param_ranges, L_train, n=5, model_class_params=model_class_params)
%time gen_model, run_stats = searcher.fit(L_dev, L_gold_dev)
run_stats


In [None]:
# from snorkel.annotations import load_marginals, save_marginals

# train_marginals = gen_model.marginals(L_train)
# save_marginals(session, L_train, train_marginals, training=True)

###  Labeling Function Accuracy Weights
These are the accuracy factor weights learned by the generative model during training.

In [None]:
# Pair each LF's name with its learned accuracy factor weight.
lf_accs = [
    {"LF-NAME": lf.__name__, "Acc. Factor Weight": acc}
    for lf, acc in zip(lfs, gen_model.weights.lf_accuracy)
]
pd.DataFrame(lf_accs)

In [None]:
import matplotlib.pyplot as plt

# train_marginals is otherwise only assigned in the commented-out cell above,
# so this cell raised NameError. Compute it from the trained generative model,
# first undoing the MeTaL -1 -> 2 remap on a copy of L_train, because
# Snorkel's binary model expects -1 for the negative class.
L_train_snorkel = L_train.copy()
L_train_snorkel[L_train_snorkel == 2] = -1
train_marginals = gen_model.marginals(L_train_snorkel)

fig, ax = plt.subplots(figsize=(18, 6))
df = pd.DataFrame(data=train_marginals, columns=['marginals'])
pd.DataFrame.hist(df, range=(0.0, 1.0), bins=20, ax=ax)