# COVID-19 Complex Symptom Classification
## Building Training Sets with Weak Supervision
In this tutorial, we'll build a weakly supervised sentence classifier for identifying evidence of recent international and domestic travel by patients.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
sys.path.insert(0,'../../ehr-rwe')

import glob
import collections
import numpy as np
import pandas as pd

In [None]:
import snorkel

print(f'Python version {sys.version}')
print(f'Snorkel v{snorkel.__version__}')
print(f'NumPy v{np.__version__}')

## 1. Load EHR Documents

This notebook assumes documents have already been preprocessed and dumped into JSON format. See `tutorials/README.md` for details and the `preprocess.py` scripts in `preprocessing/` to create the required JSON files.


In [None]:
from rwe import dataloader

# these notes are restricted to edits made within the first 1 hour of ED admission
inputdir = '' # preprocessed JSON doc directory

corpus = dataloader(glob.glob(f'{inputdir}/*.json'))


## 2. Load Concept Dictionaries
For travel, we need some definition of geographic location. Luckily, there are lots of resources for this. In our experience, off-the-shelf NER models such as spaCy's perform quite poorly on EHR text, so we'll just load some U.S. gazetteer data (USGS GNIS place names) along with GeoNames country names and manually tag potential geopolitical named entities.

In [None]:
%%time
import re
import collections

# https://www.usgs.gov/core-science-systems/ngp/board-on-geographic-names/download-gnis-data
def load_usgs_gnis_geonames(fpath, sw=None):
    sw = {} if not sw else sw  # default to no stopwords so `term not in sw` is safe
    data = collections.defaultdict(set)
    df = pd.read_csv(fpath,sep='|', chunksize=10000)
    for block in df:
        for row in block.itertuples():
            for term in [row.FEATURE_NAME, f'{row.FEATURE_NAME}, {row.STATE_ALPHA}']:
                if term and term not in sw and term.lower() not in sw:
                    data[row.FEATURE_CLASS].add(term)
            
    return dict(data)

def load_geonames_countries(fpath, sw=None):
    sw = {} if not sw else sw
    data = set()
    with open(fpath,'r') as fp:
        for line in fp:
            row = line.strip()
            if not row or row[0] == '#':  # skip blank lines and comments
                continue
            row = row.split('\t')
            country = row[4].strip().lower() 
            capital_city = row[5].strip().lower()
            for term in [country, capital_city]:
                if term and term not in sw and term.lower() not in sw:
                    data.add(term)
    return data


# setup dictionaries and entity typing (for attaching modifiers)
dict_root ='../data/supervision/dicts/'

# remove names of hospitals and other stopwords
# TODO -- remove common first names and surnames (as defined by US Census data) since these
# cause a lot of false positive matches in patient notes
stopwords = {
    'male', 'well', 'unknown', 'likely', 'non', 
    'rash', 'el camino', 'camino',  'lima',
    'mark', 'social', 'felt', 'post', 'sun', 
}
gnis_dict = load_usgs_gnis_geonames(f'{dict_root}/GPE/NationalFile_20200301.txt', sw=stopwords)
country_dict = load_geonames_countries(f'{dict_root}/GPE/countryInfo.txt', sw=stopwords)

gpe_dict = set(list(gnis_dict['Populated Place']) + list(country_dict))
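To make "manually tagging" with a dictionary concrete, here is a minimal sketch of greedy longest-match lookup over tokens. It is not the `DictionaryTagger` implementation used later in this notebook; `tag_gpe`, `max_len`, and the toy dictionary are hypothetical names introduced only for illustration.

```python
def tag_gpe(tokens, gpe_terms, max_len=3):
    """Return (start, end, text) token spans that match a dictionary term.

    Greedy, case-insensitive longest match. Hypothetical helper, for
    illustration only.
    """
    terms = {t.lower() for t in gpe_terms}
    spans = []
    i = 0
    while i < len(tokens):
        match = None
        # try the longest window first so "new york city" beats "new york"
        for n in range(min(max_len, len(tokens) - i), 0, -1):
            candidate = ' '.join(tokens[i:i + n])
            if candidate.lower() in terms:
                match = (i, i + n, candidate)
                break
        if match:
            spans.append(match)
            i = match[1]  # continue after the matched span
        else:
            i += 1
    return spans

toy_dict = {'italy', 'new york', 'new york city'}
tokens = 'Attended a conference in New York City last month'.split()
print(tag_gpe(tokens, toy_dict))
```

A lookup like this has no notion of context, which is why common words that double as place names end up in the stopword list above.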



## 3. Generate Candidates

Some clinical tasks don't fit neatly into an entity tagging framework. Complex symptoms such as travel history cover a range of statements that vary in travel location, the subject of the travel (patient, family member, coworker, etc.), and other nuances. Consider these examples and their class labels.

- Patient denies recent travel [NEGATIVE]
- Returned 1.5 weeks ago from a trip in Italy. [POSITIVE]
- Attended a conference in New York City last month. [POSITIVE]
- Patient's father recently returned from Italy. [NEGATIVE]
- Last week attended a work event and talked with several visitors from China [NEGATIVE]

We formulate this task as a sentence classification problem.
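As a rough sketch of that formulation, the examples above can be written as (sentence, label) pairs. The integer codes below mirror the `TRAVEL` and `NO_TRAVEL` constants defined in Section 4 of this notebook; the `examples` list itself is just the bullet list restated.

```python
from collections import Counter

TRAVEL, NO_TRAVEL = 1, 2  # same codes as the labeling functions in Section 4

examples = [
    ("Patient denies recent travel", NO_TRAVEL),
    ("Returned 1.5 weeks ago from a trip in Italy.", TRAVEL),
    ("Attended a conference in New York City last month.", TRAVEL),
    ("Patient's father recently returned from Italy.", NO_TRAVEL),
    ("Last week attended a work event and talked with several visitors from China", NO_TRAVEL),
]

# class balance of the toy set
counts = Counter(label for _, label in examples)
print(counts)
```

Note that the label describes the patient's own travel: a sentence can mention a location and a trip and still be negative.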

### Clinical Text Markup
When writing labeling functions, it's helpful to have access to document markup and other metadata. For example, we might want to know which document section we are currently in (e.g., Past Medical History), or, if we have temporal information about an event, such as a date of occurrence, we might want to incorporate that information into our labeling heuristics.
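The section-tracking idea can be sketched with a few lines of regex. This is only an illustration under a strong assumption (a known header at the start of a line, followed by a colon); it is not how `SectionHeaderTagger` or `ParentSectionTagger` work, and `sections_by_line` is a hypothetical helper.

```python
import re

# Assumed convention for this sketch: known header + colon at line start.
HEADERS = ['Past Medical History', 'Social History', 'HPI']
header_rgx = re.compile(r'^(%s):' % '|'.join(map(re.escape, HEADERS)), re.I)

def sections_by_line(text):
    """Yield (section, line) pairs, carrying the last seen header forward."""
    current = None
    for line in text.splitlines():
        line = line.strip()
        m = header_rgx.match(line)
        if m:
            current = m.group(1)
        yield current, line

note = """HPI: Returned from Italy 1 week ago.
Past Medical History: Asthma.
Traveled to Peru in 2015."""

for section, line in sections_by_line(note):
    print(f'{section!r:25} {line}')
```

With the section attached, a heuristic could treat a travel mention under Past Medical History differently from one in the HPI.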

### Timing Benchmarks 

- 4-core MacBook Pro, 2.5 GHz, mid-2015

| N Sentences   | N Cores | Time |
|---------------|---------|----------------|
| 2115          | 4       | 1 minute 10 secs |



In [None]:
from rwe.labelers.taggers import (
    ResetTags, DocTimeTagger,
    DictionaryTagger, HypotheticalTagger, HistoricalTagger,
    SectionHeaderTagger, ParentSectionTagger,
    Timex3Tagger, Timex3NormalizerTagger, TimeDeltaTagger,
    FamilyTagger, PolarityTagger, TextFieldDocTimeTagger
)

# These are largely note type and institution specific. 
# TODO: Train a proper header tagger
def get_header_dict():
    return set([
        'Allergen Reactions', 'Attending Attestations', 'Attending Attestions', 
        'Chief Complaint', 'Clinical Decision Rules', 'Critical Care and Sepsis', 
        'Critical Care and Sepsis Alert', 'Diagnosis Code', 
        'ED Course, Data Review & Interpretation', 'ED Treatment', 
        'Family History', 'HPI', 'History & Physical', 'History From Shared Lists', 
        'Labs & Imaging', 'Labs ordered', 'Medical Decision Making', 'Medications', 
        'New Prescriptions', 'Occupational History', 'Past Medical History', 
        'Patient Active Problem List', 'Patient Active Problem List', 
        'Physical Exam', 'Prior to Admission Medications', 'Procedures', 
        'Reason for Hospitalization', 'Recent Labs', 'Review of Systems', 
        'Social History', 'Substance and Sexual Activity', 'Summary of assessment', 
        'Tobacco Use', 'Ultrasounds & Procedures'
    ])

# Major headers are typically sections, minor headers are key/value pairs
# Physical Exam
# EYES: ....
# ABD: ....
def get_major_section_headers():
    return set([
        'Clinical Decision Rules', 'Diagnosis Code', 
        'ED Course, Data Review & Interpretation', 'ED Treatment', 'Family History', 
        'HPI', 'History & Physical', 'Labs & Imaging', 'Medical Decision Making', 
        'New Prescriptions', 'Past Medical History', 'Physical Exam', 
        'Prior to Admission Medications', 'Procedures', 'Review of Systems', 
        'Social History', 'Summary of assessment'
    ])

target_entities = ['GPE']

# Entity Pipeline
pipeline = {
    "reset"     : ResetTags(),
        
    # 2. Clinical concepts
    #"headers"   : SectionHeaderTagger(header_dict=get_header_dict(), 
    #                                stop_headers={}),
    "concepts"  : DictionaryTagger({'GPE': gpe_dict}),
    "timex3"    : Timex3Tagger(),

    # Normalize Datetimes
    "doctimes"  : DocTimeTagger(prop='CREATED_AT', format='%Y-%m-%d %H:%M:%S'),
    "normalize" : Timex3NormalizerTagger(),

    # Concept Modifiers
    #"section"   : ParentSectionTagger(targets=target_entities + ['TIMEX3'], 
    #                                 major_headers=None),
    #"tdelta"    : TimeDeltaTagger(targets=target_entities),
    #"polarity"  : PolarityTagger(targets=target_entities,
    #                          data_root=f"{dict_root}/negex/"),
}


In [None]:
%%time
from rwe.labelers import TaggerPipelineServer

tagger = TaggerPipelineServer(num_workers=4)
documents = tagger.apply(pipeline, [corpus])


### Rank and Select Candidate Sentences
For simplicity and computational efficiency, we represent each document by its top-k sentences for a query.
Here we use a simple boolean query and keep every matching sentence, but something more sophisticated (embeddings, a search engine) works fine too.

In [None]:
import itertools
import collections

query_terms = [
    'travel', 'returned', 'vacation', 'trip',  'fly', 
    'flying', 'flew', 'flight', 'airplane', 'plane', 'cruise'
] + list(country_dict)
query_terms = set(query_terms)
print(f'Query Terms: {len(query_terms)}')

doc_sent_idx = collections.defaultdict(list)
candidates = []

for i, doc in enumerate(documents[0]):
    for sent in doc.sentences:
        terms = list(map(lambda x:x.lower(), sent.words))
        if query_terms.intersection(terms):
            doc_sent_idx[doc.name].append(len(candidates)) 
            candidates.append(sent)
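The boolean query above boils down to a set intersection between query terms and sentence tokens. A minimal sketch on plain strings follows; `select_candidates` and the regex tokenizer are assumptions made for this example, whereas the notebook itself relies on `sent.words` from the `rwe` pipeline.

```python
import re

def select_candidates(sentences, query_terms):
    """Keep sentences sharing at least one lowercase token with the query."""
    query = {t.lower() for t in query_terms}
    keep = []
    for s in sentences:
        tokens = set(re.findall(r'[a-z0-9]+', s.lower()))
        if query & tokens:
            keep.append(s)
    return keep

sents = [
    'Patient denies recent travel.',
    'Returned from Italy last week.',
    'Complains of cough and fever.',
    'Visited South Korea in January.',
]
print(select_candidates(sents, {'travel', 'italy', 'south korea'}))
```

One consequence of token-level matching is visible in the last sentence: a multi-word query term such as `south korea` never equals a single token, so it cannot match. If `sent.words` holds single tokens, the same caveat applies to the multi-word entries of `country_dict` in the cell above.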


In [None]:
len(candidates)

### Export Sentence Corpus for Manual Annotation
We need *some* hand-labeled data to verify our model's output, so we export all of our candidate sentences for manual labeling. In practice, loading these into Excel and distributing them to 2-3 clinical experts works fine (and is much faster than a heavyweight annotation system like BRAT).

In [None]:
def export_annotation_tsv(candidates, fpath):
    data = ['\t'.join(['','Y','DOC_NAME', 'SENT_IDX', 'TEXT'])]
    for i, x in enumerate(candidates):
        text = x.text.strip().replace('\n',' ')
        row = [i, '', x.document.name, x.i, text]
        row = list(map(str, row))
        data.append('\t'.join(row))
    with open(fpath, 'w') as fp:
        fp.write('\n'.join(data))
        
#export_annotation_tsv(candidates, 'travel.gold.ANNOTATOR.tsv')

In [None]:
import codecs
import pandas as pd

fpath = 'travel.gold.ANNOTATOR.tsv'
gold = pd.read_csv(fpath, sep='\t', encoding='latin-1')

# build gold index by keys to maintain ordering
gold = {(doc_name,sent_i):y for doc_name,sent_i,y in zip(gold.DOC_NAME, gold.SENT_IDX, gold.Y)}

Y_gold = []
for x in candidates:
    key = (x.document.name, x.i)
    if key not in gold:
        Y_gold.append(-1)
    else:
        y = int(gold[key]) if not np.isnan(gold[key]) else -1
        Y_gold.append(y)

Y_gold = np.array(Y_gold)
Y_gold[Y_gold == 0] = 2
Y_gold[Y_gold == -1] = 0

print(f'Neg: {len(Y_gold[Y_gold==2])}')
print(f'Pos: {len(Y_gold[Y_gold==1])}')
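The two in-place assignments on `Y_gold` above are easy to misread, so here is the same remapping spelled out as a small function. The constant names match those defined in Section 4; `encode_gold` is a hypothetical helper, and `None` stands in for a missing or NaN annotation.

```python
ABSTAIN, TRAVEL, NO_TRAVEL = 0, 1, 2

def encode_gold(y):
    """Map a raw annotation (1, 0, or None if unlabeled) to the internal code."""
    if y is None:
        return ABSTAIN          # no gold label: treated as unlabeled
    return TRAVEL if y == 1 else NO_TRAVEL

raw = [1, 0, None, 1]
print([encode_gold(y) for y in raw])
```

So annotators write 1 or 0 in the spreadsheet, while internally 0 is reserved for "no label" and negatives become 2.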

In [None]:
doc_index = {doc.name:doc for doc in documents[0]}

## 4. Define and Apply Labeling Functions

In [None]:
import datetime
from rwe.labelers.taggers.negex import NegEx 
from rwe.helpers import (
    match_regex, token_distance, 
    get_left_span, get_right_span
)

ABSTAIN   = 0
TRAVEL    = 1
NO_TRAVEL = 2

negex = NegEx(f'{dict_root}/negex/')

#
# Helper Functions
#

def is_negated(span, window=None):
    left = get_left_span(span, window=window)
    trigger = match_regex(negex.rgxs['definite']['left'], left)
    return True if trigger else False

def no_recent_travel(sent):
    """Explicit statement of no travel"""
    rgx = r'''\b(travel(s|ed|ing)*|vacation|trip)\b'''
    trigger = match_regex(rgx, sent)
    return True if trigger and is_negated(trigger) else False

#
# Labeling Functions
#

def LF_travel_mode(s):
    """Mode of transportation indicating long(er) distance travel"""
    rgx = r'''\b(flight|fly(ing)*|flew|airplane)\b'''
    return TRAVEL if re.search(rgx, s.text, re.I) else ABSTAIN

def LF_cruise_ships(s):
    """
    Cruise ships are more complicated because there are secondary 
    exposures from clean-up crews and others.
    """
    rgx = r'''((ruby|coral|grand|diamond) princess|celebrity eclipse|(princess|carnival) cruise(s)*)'''
    v = re.search(rgx, s.text, re.I) is not None
    rgx = r'''(visitor(s)*|friend(s)*|roomm?ate|father|mother|coworker)'''
    v &= re.search(rgx, s.text, re.I) is None
    return TRAVEL if v else ABSTAIN

def LF_old_travel(s):
    """Ignore travel from time windows outside the COVID-19 pandemic."""
    rgx = r'''\b(travel(s|ed|ing)*|vacation|trip)\b'''
    trigger = match_regex(rgx, s)
    if not trigger or is_negated(trigger):
        return ABSTAIN
    
    rgx = r'''[0-9]+ (year|week|day)[s]* (ago|prior)'''
    right_trigger = match_regex(rgx, get_right_span(trigger))
    if not right_trigger:
        return ABSTAIN
    
    return NO_TRAVEL if 'year' in s.text else TRAVEL
    
def LF_travel(s):
    """Check for travel terms + simple negation"""
    rgx = r'''\b(travel(s|ed|ing)*|vacation|trip)\b'''
    trigger = match_regex(rgx, s)
    if not trigger:
        return ABSTAIN
    return TRAVEL if not is_negated(trigger) else NO_TRAVEL

def LF_geo_terms(s, time_window_days=120):
    """
    Some mention of travel to a geolocation with 
    a date in our time window of interest
    """
    # explicit mention of no travel
    if no_recent_travel(s):
        return NO_TRAVEL
    
    # how far back do we want to register travel as a risk factor?
    anchor_ts = datetime.datetime.now() - datetime.timedelta(days=time_window_days)
    # filter out connections (my father returned from China)
    rgx = r'''(visitor(s)*|friend(s)*|roomm?ate|father|mother|colleague|co[-]*worker)'''
    if re.search(rgx, s.text, re.I):
        return ABSTAIN
    
    # extract all named entities in this sentence
    entities = doc_index[s.document.name].annotations[s.i]
    if 'GPE' in entities and 'TIMEX3' in entities:
        # filter out dates that don't fall in our time window
        ts = [t.normalized for t in entities['TIMEX3'] if t.normalized]
        if ts and anchor_ts > max(ts):
            return NO_TRAVEL
        return TRAVEL

    return ABSTAIN
    
def LF_returned(s):
    rgx = r''' returned to ([A-Za-z0-9]+\s*){1,3}hospital'''
    if re.search(rgx, s.text, re.I):
        return NO_TRAVEL
    rgx = r'''returned from'''
    return TRAVEL if re.search(rgx, s.text, re.I) else ABSTAIN

def LF_motion_sickness(s):
    """
    Discussions of air travel also include prescriptions of medication
    for motion sickness
    """
    # TODO: populate with motion sickness medication terms; abstain until then
    terms = {}
    return ABSTAIN
    
    
def LF_boilerplate(s):
    """
    Catch some copy-n-pasted responses that use 'returned' in the non-trip sense
    'I have asked the patient to remain in home isolation until results have returned.'
    """
    # ignore extraneous whitespace
    text = re.sub(r'''[\n\s]+''', ' ', s.text)
    rgx = r'''(result(s)* ([A-Za-z0-9]+\s*){1,3}returned)'''
    if re.search(rgx, text, re.I):
        return NO_TRAVEL  
    cdc = {'China', 'Iran', 'Italy', 'Japan', 'South Korea'}
    if all(t in text for t in cdc):
        return NO_TRAVEL
    return ABSTAIN

# TODO add more labeling functions 
lfs = [
    LF_travel_mode,
    LF_cruise_ships,
    LF_old_travel,
    LF_travel,
    LF_geo_terms,
    LF_boilerplate,
    LF_returned
]

# split into train/test sets
train_idxs = np.where(Y_gold == 0)
test_idxs = np.where(Y_gold != 0)

Xs = [
    np.array(candidates)[train_idxs], 
    np.array(candidates)[test_idxs]
]

Ys = [
    Y_gold[train_idxs],
    Y_gold[test_idxs],
]

Let's apply these LFs to our training data.

In [None]:
%%time
from rwe.labelers import LabelingServer

labeler = LabelingServer(num_workers=4)
Ls = labeler.apply(lfs, Xs)
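Conceptually, applying labeling functions yields a label matrix with one row per sentence and one column per LF. The sketch below shows this on plain strings with two toy LFs; it does not reproduce `LabelingServer` (which, judging by the later `.toarray()` call, returns a sparse matrix), and `lf_returned_from`, `lf_denies_travel`, and `apply_lfs` are hypothetical.

```python
import re

ABSTAIN, TRAVEL, NO_TRAVEL = 0, 1, 2

def lf_returned_from(text):
    """Toy LF: 'returned from' suggests travel."""
    return TRAVEL if re.search(r'returned from', text, re.I) else ABSTAIN

def lf_denies_travel(text):
    """Toy LF: a negation cue before 'travel' suggests no travel."""
    rgx = r'\b(denies|no)\b.*\btravel'
    return NO_TRAVEL if re.search(rgx, text, re.I) else ABSTAIN

def apply_lfs(lfs, texts):
    """Rows are sentences, columns are labeling functions."""
    return [[lf(t) for lf in lfs] for t in texts]

texts = [
    'Returned from Italy last week.',
    'Patient denies recent travel.',
    'Complains of cough and fever.',
]
L = apply_lfs([lf_returned_from, lf_denies_travel], texts)
for row in L:
    print(row)
```

Most entries are `ABSTAIN`, which is typical: each LF is a narrow heuristic and is expected to stay silent on most sentences.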


We can then examine the accuracy of each of our LFs using the gold labeled data as ground truth.

In [None]:
from rwe.analysis import lf_summary

lf_summary(Ls[1], Y=Ys[1], lf_names=[lf.__name__ for lf in lfs])
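For intuition about what such a summary measures, coverage and empirical accuracy can be computed by hand from a dense label matrix. This is a sketch under the notebook's convention that 0 means abstain; `lf_stats` is a hypothetical helper, and the columns actually reported by `lf_summary` may differ.

```python
def lf_stats(L, Y, abstain=0):
    """Per-LF coverage and empirical accuracy from a dense label matrix."""
    n, m = len(L), len(L[0])
    stats = []
    for j in range(m):
        # keep only the rows where LF j actually voted
        votes = [(row[j], y) for row, y in zip(L, Y) if row[j] != abstain]
        coverage = len(votes) / n
        accuracy = sum(v == y for v, y in votes) / len(votes) if votes else None
        stats.append({'coverage': coverage, 'accuracy': accuracy})
    return stats

L_toy = [[1, 0], [0, 2], [1, 2], [0, 0]]  # 4 sentences, 2 LFs
Y_toy = [1, 2, 2, 1]                      # gold labels (1=TRAVEL, 2=NO_TRAVEL)
for j, s in enumerate(lf_stats(L_toy, Y_toy)):
    print(j, s)
```

Accuracy here is computed only over the rows where an LF votes, so a high-accuracy LF can still have very low coverage.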


In [None]:
from rwe.visualization.analysis import view_conflicts, view_label_matrix, view_overlaps

view_overlaps(Ls[1], normalize=False)
view_conflicts(Ls[1], normalize=False)
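The quantities behind these plots are simple pairwise counts. As a sketch (not the `rwe.visualization` implementation), two LFs overlap on a sentence when both vote, and conflict when both vote but disagree; `overlap_conflict_counts` is a hypothetical helper using 0 as the abstain value.

```python
def overlap_conflict_counts(L, abstain=0):
    """Pairwise counts of rows where two LFs both vote / vote differently."""
    m = len(L[0])
    overlaps = [[0] * m for _ in range(m)]
    conflicts = [[0] * m for _ in range(m)]
    for row in L:
        for i in range(m):
            for j in range(m):
                if i == j or row[i] == abstain or row[j] == abstain:
                    continue
                overlaps[i][j] += 1
                if row[i] != row[j]:
                    conflicts[i][j] += 1
    return overlaps, conflicts

L_toy = [[1, 0], [0, 2], [1, 2], [1, 1]]
overlaps, conflicts = overlap_conflict_counts(L_toy)
print(overlaps)
print(conflicts)
```

Conflicts are what give the label model signal about which LFs to trust, so a matrix with no overlaps at all leaves little to learn from.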

## 4. Train Snorkel Label Model 

The next step is to train a Snorkel label model using the data labeled with our labeling functions.

Note: we aren't doing any hyperparameter tuning here at all!

In [None]:
# convert the sparse label matrix to the convention expected by snorkel.labeling.LabelModel:
# ABSTAIN 0 -> -1, NO_TRAVEL 2 -> 0, TRAVEL stays 1

def convert_label_matrix(L):
    L = L.toarray().copy()
    L[L == 0] = -1
    L[L == 2] = 0
    return L


Ls_hat = [
    convert_label_matrix(Ls[0]),
    convert_label_matrix(Ls[1])
]

Ys_hat = [
    [],
    np.array([0 if y == 2 else 1 for y in Ys[1]]),
]


In [None]:
from snorkel.labeling import LabelModel

n = Ls_hat[1].shape[0]
lr = 0.001
l2 = 0.001
prec_init = 0.8
mu_eps = 1 / 10 ** np.ceil(np.log10(n))
              

label_model = LabelModel(cardinality=2, device='cpu', verbose=True)
label_model.fit(L_train=Ls_hat[1], 
                n_epochs=100, 
                lr=lr,
                l2=l2,
                prec_init=prec_init,
                optimizer='adamax',
                log_freq=100,
                mu_eps=mu_eps)

metrics = ['accuracy', 'precision', 'recall', 'f1']
label_model.score(L=Ls_hat[1], Y=Ys_hat[1], metrics=metrics)
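The `mu_eps` expression in the cell above is worth unpacking with a little arithmetic. It rounds `n` up to the next power of ten and takes the reciprocal, so the value shrinks as the label matrix grows. As far as the Snorkel documentation describes it, `mu_eps` restricts the learned conditional probabilities to the interval [mu_eps, 1 - mu_eps]. The helper name `mu_eps_for` is hypothetical.

```python
import math

def mu_eps_for(n):
    """Reproduce 1 / 10 ** ceil(log10(n)) from the cell above."""
    return 1 / 10 ** math.ceil(math.log10(n))

for n in (50, 2115, 10000):
    print(n, mu_eps_for(n))
```

For the 2,115 sentences mentioned in the timing benchmark, `ceil(log10(2115))` is 4, giving a `mu_eps` of 1e-4.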


In [None]:
from sklearn.metrics import classification_report

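def mv(L, break_ties):
    """Simple majority vote; ties and all-abstain rows fall back to `break_ties`"""
    from collections import Counter
    y_hat = []
    for row in L:
        # count non-abstain votes and keep the two most common labels
        votes = Counter(row[row != -1].tolist()).most_common(2)
        if not votes or (len(votes) == 2 and votes[0][1] == votes[1][1]):
            l = break_ties
        else:
            l = votes[0][0]
        y_hat.append(l)
    # np.int was removed in NumPy 1.24; the builtin int is equivalent here
    return np.array(y_hat).astype(int)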

mv_pred = mv(Ls_hat[1], 0)
y_gold = [1 if y == 1 else 0 for y in Ys[1]]
print(classification_report(y_gold, mv_pred))

y_pred = label_model.predict(Ls_hat[1])
y_pred[y_pred == -1] = 0

print(classification_report(y_gold, y_pred))

## 5. Create Probabilistic Labels for Unlabeled Data

The last step in this tutorial is to use our label model to predict on new data.

You can then load these labels into your end model (LSTM, BERT, etc.).

In [None]:
# Ls_hat[0] is the unlabeled (training) split; Ls_hat[1] is the gold-labeled split
Y_proba = label_model.predict_proba(Ls_hat[0])
Y_proba
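One common way to feed probabilistic labels into an end model is to drop rows where the label model is uninformative and train against the remaining soft targets with cross-entropy. The sketch below shows both steps in plain Python on toy numbers; `filter_confident`, `soft_cross_entropy`, and `Y_proba_toy` are hypothetical and not part of Snorkel or `rwe`.

```python
import math

def filter_confident(probs, margin=0.0):
    """Indices of rows whose two class probabilities differ by more than margin."""
    return [i for i, p in enumerate(probs) if abs(p[1] - p[0]) > margin]

def soft_cross_entropy(pred, target, eps=1e-12):
    """Cross-entropy of a predicted distribution against a soft target."""
    return -sum(t * math.log(max(q, eps)) for q, t in zip(pred, target))

# toy label-model output: columns are P(NO_TRAVEL), P(TRAVEL)
Y_proba_toy = [[0.1, 0.9], [0.5, 0.5], [0.8, 0.2]]

keep = filter_confident(Y_proba_toy)
print(keep)

# loss of an uninformed end-model prediction against the first soft label
print(soft_cross_entropy([0.5, 0.5], Y_proba_toy[0]))
```

Training on soft targets lets the end model weight confident and uncertain sentences differently instead of treating every weak label as ground truth.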