# Building Training Sets with Weak Supervision
In this tutorial, we'll build a training set for the `Pain-At` relation using weakly supervised methods. This notebook covers:
- Loading pre-processed documents
- Generating relational candidates 
- Applying labeling functions
- Training a Snorkel Label Model

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
sys.path.insert(0,'../../ehr-rwe/')

import glob
import collections
import numpy as np
import pandas as pd


## 1. Load MIMIC-III Documents

This notebook assumes documents have already been preprocessed and dumped into JSON format. We created a small annotated dataset using MIMIC-III patient notes. See the `XXXX.ipynb` notebook for instructions on creating the required JSON files.

You will need access to MIMIC-III data to run this notebook using our tutorial annotations. See https://mimic.physionet.org/gettingstarted/access/ for access instructions.

See `preprocessing/README.md` for details.

In [2]:
from rwe import dataloader

inputdir = '/Users/fries/Desktop/foobar/output/'
 
corpus = [
    dataloader([f'{inputdir}/mimic_gold.json']), 
    dataloader([f'{inputdir}/mimic_unlabeled.json'])
]

for split in corpus:
    print(f'Loaded {len(split)} documents')


Loaded 55 documents
Loaded 1322 documents


## 2. Generate Candidates

This is an example pipeline for generating `Pain-At` relation candidates. Relations are defined as a tuple of $k$ entity spans. For simplicity's sake, we consider binary relations between all `Anatomy` and `Pain` entity pairs found within the same sentence. Entities can be tagged using a clinical named entity recognition (NER) model if one is available; here we use a dictionary-based method to tag our initial `Anatomy` and `Pain` entities.
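To make candidate generation concrete, here is a minimal, self-contained sketch of dictionary tagging followed by pairing every `pain` span with every `anatomy` span in the same sentence. The toy dictionaries and the `tag_terms` / `pain_at_candidates` helpers are hypothetical stand-ins, not the `rwe` `DictionaryTagger` or `RelationTagger` implementations, and sentence splitting is assumed to have already happened.

```python
import itertools
import re

# Hypothetical toy dictionaries; the tutorial loads real ones from ../data/dicts/
PAIN_TERMS = {"pain", "tenderness", "ache"}
ANATOMY_TERMS = {"chest", "abdomen", "lower back", "left knee"}

def tag_terms(sentence, terms, label):
    """Return (label, start, end, text) spans for dictionary terms in a sentence.

    Longer terms are matched first; shorter matches overlapping them are skipped."""
    spans, taken = [], set()
    for term in sorted(terms, key=len, reverse=True):
        for m in re.finditer(r"\b" + re.escape(term) + r"\b", sentence, re.I):
            if taken.isdisjoint(range(m.start(), m.end())):
                spans.append((label, m.start(), m.end(), m.group()))
                taken.update(range(m.start(), m.end()))
    return sorted(spans, key=lambda s: s[1])

def pain_at_candidates(sentence):
    """All (pain, anatomy) span pairs within one sentence (a Cartesian product)."""
    pain = tag_terms(sentence, PAIN_TERMS, "pain")
    anat = tag_terms(sentence, ANATOMY_TERMS, "anatomy")
    return list(itertools.product(pain, anat))

sent = "Patient reports chest pain and tenderness in the left knee."
for p, a in pain_at_candidates(sent):
    print(p[3], "->", a[3])
```

This all-pairs design favors recall: a sentence with several pain and anatomy mentions yields many candidates, and it is left to the labeling functions to decide which pairs are true `Pain-At` relations.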

### Clinical Text Markup
When writing labeling functions, it's convenient to have access to document markup and other metadata. For example, we might want to know which document section we are currently in (e.g., Past Medical History), or, if we have temporal information about an event such as a date of occurrence, we might want to incorporate that information into our labeling heuristics.
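As an illustration of this kind of markup, the sketch below finds section headers with a simple regular expression and looks up which section a character offset falls in. The header pattern and helper names are assumptions for illustration only; the pipeline below uses `rwe`'s `SectionHeaderTagger` and `ParentSectionTagger`, whose rules may differ. A pattern this loose will also match lines such as `Pain: 5/10`, so treat it as a starting point.

```python
import bisect
import re

# Hypothetical header pattern: a line starting with a capitalized phrase followed by ':'
HEADER_RGX = re.compile(r"^([A-Z][A-Za-z ]+):", re.M)

def index_sections(text):
    """Return parallel lists of header start offsets and lowercased header names."""
    starts, names = [], []
    for m in HEADER_RGX.finditer(text):
        starts.append(m.start())
        names.append(m.group(1).strip().lower())
    return starts, names

def section_at(offset, starts, names):
    """Name of the closest header at or before `offset`, or None if there is none."""
    i = bisect.bisect_right(starts, offset) - 1
    return names[i] if i >= 0 else None

note = "Chief Complaint: chest pain\nPast Medical History: knee pain in 2010\n"
starts, names = index_sections(note)
print(section_at(note.index("chest"), starts, names))  # chief complaint
print(section_at(note.index("knee"), starts, names))   # past medical history
```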

### Timing Benchmarks 

- 50,000 MIMIC-III documents
- 4-core MacBook Pro (2.5 GHz, mid-2015)

| N Documents   | N Cores | Time |
|---------------|---------|----------------|
| 299           | 4       | 17 seconds |
| 10,000        | 4       | 4 minutes  |
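The two benchmark rows work out to roughly 17.6 and 41.7 documents per second, which is consistent with fixed startup costs being amortized over a larger run. The sketch below does that arithmetic and, assuming throughput stays at the larger run's rate, extrapolates to 50,000 documents; real scaling may not be linear.

```python
# (documents, wall-clock seconds) from the benchmark table, 4 cores
benchmarks = [(299, 17), (10_000, 4 * 60)]

for n_docs, seconds in benchmarks:
    print(f"{n_docs:>6,} docs: {n_docs / seconds:.1f} docs/sec")

# Assumption: throughput holds at the larger run's rate (linear scaling)
n_docs, seconds = benchmarks[-1]
docs_per_sec = n_docs / seconds
est_minutes = 50_000 / docs_per_sec / 60
print(f"Estimated wall time for 50,000 docs: {est_minutes:.0f} minutes")
```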


In [3]:
from rwe.utils import load_dict
from rwe.labelers.taggers import (
    ResetTags, RelationTagger, 
    DictionaryTagger, NegExTagger, HypotheticalTagger, HistoricalTagger,
    SectionHeaderTagger, ParentSectionTagger,
    DocTimeTagger, MappedDocTimeTagger, 
    Timex3Tagger, Timex3NormalizerTagger, TimeDeltaTagger,
)

dict_pain = load_dict('../data/dicts/pain/pain.txt')
dict_anat = load_dict('../data/dicts/anatomy/anat.bz2')

target_entities = ['pain']

# NOTE: Pipelines are *order dependent*: normalizers and attribute taggers assume
# the existence of certain concept targets (e.g., Timex3Normalizer requires timex3 entities)
pipeline = {
    # 1. Clear any previous runs
    "reset"        : ResetTags(),
    
    # 2. Clinical concepts
    "concepts"  : DictionaryTagger({'pain': dict_pain, 'anatomy': dict_anat}),
    "headers"   : SectionHeaderTagger(),
    "timex3"    : Timex3Tagger(),
    
    # 3. Normalize datetimes
    "doctimes"  : DocTimeTagger(prop='CHARTDATE'),
    "normalize" : Timex3NormalizerTagger(),
    
    # 4. Concept attributes
    "section"      : ParentSectionTagger(targets=target_entities),
    "tdelta"       : TimeDeltaTagger(targets=target_entities),
    "negation"     : NegExTagger(targets=target_entities, data_root="../data/dicts/negex/"),
    "hypothetical" : HypotheticalTagger(targets=target_entities),
    'historical'   : HistoricalTagger(targets=target_entities),
    
    # 5. Extract relation candidates
    "pain-at"      : RelationTagger('pain-at', ('pain', 'anatomy'))
}
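Python dictionaries preserve insertion order (guaranteed since Python 3.7), which is what lets a plain `dict` define an ordered pipeline. A sketch of checking that order, assuming each step declares which annotation layers it requires and provides; the layer names and the `PIPELINE_DEPS` table are hypothetical, and `rwe` taggers do not necessarily expose this information.

```python
# Hypothetical layer names: each step maps to (layers it requires, layers it provides)
PIPELINE_DEPS = {
    "reset":     ((), ()),
    "concepts":  ((), ("pain", "anatomy")),
    "headers":   ((), ("section_header",)),
    "timex3":    ((), ("timex3",)),
    "doctimes":  ((), ("doctime",)),
    "normalize": (("timex3",), ("timex3_normalized",)),
    "section":   (("pain", "section_header"), ()),
    "negation":  (("pain",), ()),
    "pain-at":   (("pain", "anatomy"), ("pain-at",)),
}

def check_order(deps):
    """Raise ValueError if a step runs before the step that provides one of its inputs."""
    available = set()
    for name, (requires, provides) in deps.items():  # iterates in insertion order
        missing = set(requires) - available
        if missing:
            raise ValueError(f"step {name!r} is missing {sorted(missing)}")
        available.update(provides)
    return True

print(check_order(PIPELINE_DEPS))  # True
```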


In [4]:
%%time
from rwe.labelers import TaggerPipelineServer

tagger = TaggerPipelineServer(num_workers=4)
documents = tagger.apply(pipeline, corpus)


auto block size=345
Partitioned into 4 blocks, [342 345] sizes
CPU times: user 2.38 s, sys: 394 ms, total: 2.77 s
Wall time: 1min 5s
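The log lines above are consistent with splitting all 1,377 loaded documents (55 + 1,322) into `num_workers` blocks of `ceil(1377 / 4) = 345` documents, with the remainder (342) in the last block. A sketch of that partitioning, offered as a guess at the behavior rather than `TaggerPipelineServer`'s actual code; each block would then be handed to a worker process.

```python
import math

def partition(items, num_workers):
    """Split `items` into contiguous blocks of ceil(len(items) / num_workers) elements."""
    block_size = max(1, math.ceil(len(items) / num_workers))
    return [items[i:i + block_size] for i in range(0, len(items), block_size)]

# 55 gold + 1,322 unlabeled documents, as loaded in section 1
blocks = partition(list(range(55 + 1322)), num_workers=4)
print(len(blocks), [len(b) for b in blocks])  # 4 [345, 345, 345, 342]
```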


In [None]:
# Optional: inspect the tags produced by the pipeline for one document.
# Prints each `pain` span's attributes, the `pain-at` candidates, and normalized TIMEX3 values.
doc = next(iter(documents[0]))
print(doc.name)

for sent_i in doc.annotations:
    for layer_name in doc.annotations[sent_i]:
        if layer_name == 'pain':
            for span in doc.annotations[sent_i][layer_name]:
                print(span.props)
        elif layer_name == 'pain-at':
            print(doc.annotations[sent_i][layer_name])
        elif layer_name == 'TIMEX3':
            for span in doc.annotations[sent_i][layer_name]:
                print(span.text, span.normalized)

In [5]:
from rwe.utils import build_candidate_set

Xs_pain_at = build_candidate_set(documents[0], "pain-at")


In [6]:
def collapse_relation_args(relations):
    """Return the set of unique argument spans across all relation candidates."""
    return {span for rela in relations for span in rela}

# print some summary stats about candidates
pain_at_spans = collapse_relation_args(Xs_pain_at)

# index candidate argument spans by document name
doc_span_index = collections.defaultdict(set)
for s in pain_at_spans:
    doc_span_index[s.sentence.document.name].add(s)

# unique argument spans per document, followed by the total
n = 0
for doc_name in doc_span_index:
    print(doc_name, len(doc_span_index[doc_name]))
    n += len(doc_span_index[doc_name])
print(n)
    

27718_11314_111136 4
27162_29346_185828 41
24347_63_195961 2
27695_13917_192671 13
22934_11162_113700 35
271_61898_170625 19
1843_57753_109617 24
23500_7383_156659 11
27673_26993_194796 4
1880_13330_193396 19
26845_31054_119104 13
24498_57342_113484 28
19430_7400_190280 13
25243_3811_187890 12
25035_51203_193364 17
1711_67619_132244 11
18971_19649_132894 6
17242_9616_194603 5
21790_18230_127941 4
1885_24242_185235 9
23714_4574_133306 4
2604_68251_150426 16
28492_21867_175298 10
21984_21916_174115 8
26559_19286_195185 12
1757_29581_122726 9
24237_23792_134550 5
28875_7319_103943 11
16932_67853_178133 7
19420_91074_106110 34
28895_10758_183338 6
24867_2008_184705 6
25496_25519_145528 10
16877_22927_197951 9
1946_88857_124316 24
23025_29965_154849 8
19532_44156_181689 6
20705_15862_114219 3
20086_16055_177220 4
28381_68704_193105 6
25944_78995_124007 7
25018_7921_177363 2
1679_6989_197945 10
1960_59291_109058 8
2364_44605_191210 7
21850_21834_143794 4
18751_8799_198888 15
24274_9082_14300

In [None]:
import os

outdir = '/Users/fries/Desktop/mimic-sample-brat/'
etype = 'Concept'
os.makedirs(outdir, exist_ok=True)

# Export each document's candidate argument spans as a BRAT standoff (.ann) file
for doc_name in doc_span_index:
    outfname = os.path.join(outdir, f'{doc_name}.ann')
    with open(outfname, 'w') as fp:
        items = set()
        for s in doc_span_index[doc_name]:
            # BRAT text-bound annotation with a discontinuous span, e.g.:
            # T8	Concept 468 479;480 489	right lower extremity
            if '\n' in s.text:
                # split spans that cross a newline into one fragment per line
                start, fragments = s.abs_char_start, []
                for t in s.text.split("\n"):
                    fragments.append(f'{start} {start + len(t)}')
                    start += len(t) + 1
                offsets = ";".join(fragments)
            else:
                # the +1 assumes abs_char_end is inclusive; BRAT end offsets are exclusive
                offsets = f'{s.abs_char_start} {s.abs_char_end + 1}'
            items.add((s.abs_char_start, offsets, s.text.replace("\n", " ")))

        # number annotations T1, T2, ... in document order (by start offset)
        for i, (_, offsets, text) in enumerate(sorted(items), start=1):
            anno = f'T{i}\t{etype} {offsets}\t{text}'
            print(anno)
            fp.write(anno + '\n')
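BRAT standoff files give each text-bound annotation as `T<id>\t<type> <start> <end>\t<text>`, with exclusive end offsets and discontinuous spans written as `;`-separated fragments. The helper below isolates the offset logic from the cell above so it can be checked on its own; `brat_offsets` is a hypothetical name, and it assumes a span is described by its start offset plus its raw text.

```python
def brat_offsets(start, text):
    """BRAT offset string for a span beginning at `start`.

    Text containing newlines becomes a discontinuous span with one
    "begin end" fragment per line (end offsets are exclusive)."""
    fragments = []
    for line in text.split("\n"):
        fragments.append(f"{start} {start + len(line)}")
        start += len(line) + 1  # skip over the newline character
    return ";".join(fragments)

print(brat_offsets(468, "right lower\nextremity"))  # 468 479;480 489
print(brat_offsets(10, "chest pain"))                # 10 20
```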


In [None]:
from rwe.utils import build_candidate_set, load_gold

fpath = "../data/annotations/mimic.gold.final.tsv"
gold = load_gold(fpath, documents[0], ('part-at', ('pain','anatomy')))

# Gold labels: 1 = positive; unannotated or non-positive candidates become 2 (negative)
Xs = build_candidate_set(documents[0], "pain-at")
Ys = [gold[x] if x in gold else 0 for x in Xs]
Ys = [y if y == 1 else 2 for y in Ys]

print("Class Balance")
print(f'Positive: {Ys.count(1)} ({Ys.count(1)/len(Ys)*100:.1f}%)')
print(f'Negative: {Ys.count(2)} ({Ys.count(2)/len(Ys)*100:.1f}%)')

## 3. Apply Labeling Functions

In [None]:
import re
from rwe.labelers.taggers import get_left_span, get_right_span, get_between_span

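def dict_matches(span, dictionary):
    """Return every n-gram of the span's tokens that appears in `dictionary`."""
    matches = []
    toks = span.get_attrib_tokens('words')
    for i in range(len(toks)):
        # j is an exclusive slice bound, so go to len(toks) + 1 to include the last token
        for j in range(i+1, len(toks)+1):
            term = ' '.join(toks[i:j]).lower()
            if term in dictionary:
                matches.append(term)
    return matches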

ABSTAIN  = 0
NEGATIVE = 2
POSITIVE = 1

neg_rgx = re.compile(
    r'''\b(insensitivity|paresthesias|paresthesia|sensitivity|tenderness|discomfort|heaviness|sensitive|itchiness|tightness|throbbing|numbness|tingling|cramping|coldness|soreness|painful|hurting|itching|burning|tender|buring|aching|hurts|aches|pains|hurt|pain|ache|achy|numb)\b''',
    re.I
)

def LF_is_negated(x):
    return NEGATIVE if 'negated' in x.pain.props else ABSTAIN

def LF_is_hypothetical(x):
    return NEGATIVE if 'hypothetical' in x.pain.props else ABSTAIN

def LF_section_headers(x):
    sections = {
        'past medical history': NEGATIVE,
        'chief complaint': POSITIVE,
        'discharge instructions': NEGATIVE,
        'discharge condition': NEGATIVE
    }
    header = x.pain.props['section'].text.lower() if x.pain.props['section'] else None
    return ABSTAIN if header not in sections else sections[header]

def LF_contiguous_args(x):
    """Label positive if nothing separates the arguments and the pain mention is asserted."""
    v = not get_between_span(x.pain, x.anatomy)
    v &= 'negated' not in x.pain.props
    v &= 'hypothetical' not in x.pain.props
    return POSITIVE if v else ABSTAIN

def LF_distant_args(x, max_toks=10):
    """Reject candidate if the arguments occur too far apart (in token distance)."""
    span = get_between_span(x.pain, x.anatomy)
    n_toks = len(span.get_attrib_tokens('words')) if span else 0
    return NEGATIVE if n_toks > max_toks else ABSTAIN
    
def LF_between_terms(x):
    """Reject if some key terms occur between arguments."""
    span = get_between_span(x.pain, x.anatomy)
    if not span:
        return ABSTAIN
    # another pain/symptom term (e.g., 'tenderness', 'numbness')
    flag = neg_rgx.search(span.text) is not None
    # another anatomical term
    flag |= len(dict_matches(span, dict_anat)) > 0
    return NEGATIVE if flag else ABSTAIN

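complains_of_rgx = re.compile(r'\bcomplain(s|ing)? of\b', re.I)

def LF_complains_of(x):
    """Label positive if 'complain(s|ing) of' precedes a non-negated pain mention."""
    is_negated = 'negated' in x.pain.props
    # flags are set at compile time; Pattern.search's 2nd positional arg is a start position
    is_complains_of = complains_of_rgx.search(get_left_span(x.pain).text) is not None
    return POSITIVE if not is_negated and is_complains_of else ABSTAIN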
    

lfs = [
    LF_is_negated,
    LF_is_hypothetical,
    LF_section_headers,
    LF_contiguous_args,
    LF_distant_args,
    LF_between_terms,
    LF_complains_of
]
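Conceptually, applying labeling functions produces a label matrix with one row per candidate and one column per LF. `LabelingServer.apply` appears to return sparse matrices (they are converted with `.toarray()` in section 4) and runs in parallel; the serial, dense sketch below only shows the idea. The toy candidates are plain dicts standing in for relation candidates, and `apply_lfs` is a hypothetical helper.

```python
import numpy as np

ABSTAIN, POSITIVE, NEGATIVE = 0, 1, 2  # same label convention as the LFs above

def apply_lfs(lfs, candidates):
    """Dense label matrix L with L[i, j] = output of LF j on candidate i."""
    L = np.zeros((len(candidates), len(lfs)), dtype=int)
    for i, x in enumerate(candidates):
        for j, lf in enumerate(lfs):
            L[i, j] = lf(x)
    return L

# Toy candidates: dicts standing in for the tutorial's relation objects
toy = [{"negated": True, "gap": 0}, {"negated": False, "gap": 0}, {"negated": False, "gap": 15}]
toy_lfs = [
    lambda x: NEGATIVE if x["negated"] else ABSTAIN,
    lambda x: POSITIVE if x["gap"] == 0 and not x["negated"] else ABSTAIN,
    lambda x: NEGATIVE if x["gap"] > 10 else ABSTAIN,
]
print(apply_lfs(toy_lfs, toy))
```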



In [None]:
%%time
from rwe.labelers import LabelingServer

labeler = LabelingServer(num_workers=4)
Ls = labeler.apply(lfs, [Xs])


In [None]:
from rwe.analysis import lf_summary

lf_summary(Ls[0], Y=Ys, lf_names=[lf.__name__ for lf in lfs])
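Two of the most useful per-LF statistics are coverage (the fraction of candidates an LF labels) and, when gold labels are available, empirical accuracy on the candidates it labels. A small NumPy sketch of both, using the same 0 = abstain convention; it approximates the kind of numbers `lf_summary` reports rather than reproducing its implementation.

```python
import numpy as np

ABSTAIN = 0

def lf_stats(L, Y):
    """Per-LF coverage and empirical accuracy against gold labels Y (same label space as L).

    Accuracy is NaN for an LF that never votes."""
    L, Y = np.asarray(L), np.asarray(Y)
    voted = L != ABSTAIN
    coverage = voted.mean(axis=0)
    correct = (L == Y[:, None]) & voted
    n_voted = voted.sum(axis=0)
    accuracy = np.divide(correct.sum(axis=0), n_voted,
                         out=np.full(L.shape[1], np.nan), where=n_voted > 0)
    return coverage, accuracy

L = np.array([[2, 0, 0], [0, 1, 0], [0, 0, 2], [1, 1, 0]])
Y = np.array([2, 1, 1, 1])
print(lf_stats(L, Y))
```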

In [None]:
from rwe.visualization.analysis import view_conflicts, view_label_matrix, view_overlaps

view_overlaps(Ls[0], normalize=False)
view_label_matrix(Ls[0])
view_conflicts(Ls[0], normalize=False)
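The overlap and conflict views summarize how LFs interact. A sketch of the underlying pairwise counts, assuming raw counts as with `normalize=False` (how `rwe` computes or normalizes these is not shown here): an overlap is a candidate that both LFs label, and a conflict is an overlap where they disagree.

```python
import numpy as np

ABSTAIN = 0

def pairwise_overlaps_conflicts(L):
    """n_lfs x n_lfs count matrices of overlaps and conflicts for each LF pair."""
    L = np.asarray(L)
    voted = (L != ABSTAIN).astype(int)
    overlaps = voted.T @ voted  # diagonal holds each LF's own number of votes
    n_lfs = L.shape[1]
    conflicts = np.zeros((n_lfs, n_lfs), dtype=int)
    for i in range(n_lfs):
        for j in range(n_lfs):
            both = (L[:, i] != ABSTAIN) & (L[:, j] != ABSTAIN)
            conflicts[i, j] = np.sum(both & (L[:, i] != L[:, j]))
    return overlaps, conflicts

overlaps, conflicts = pairwise_overlaps_conflicts([[2, 1, 0], [0, 1, 1], [1, 1, 2]])
print(overlaps)
print(conflicts)
```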

## 4. Train Snorkel Label Model 

In [None]:
# convert sparse matrix to new Snorkel format

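def convert_label_matrix(L):
    """Map rwe labels (0 = abstain, 1 = positive, 2 = negative) to the
    Snorkel v0.9 convention (-1 = abstain, 1 = positive, 0 = negative)."""
    L = L.toarray().copy()
    L[L == 0] = -1  # must run before the next line, which creates new zeros
    L[L == 2] = 0
    return L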


Ls_hat = [
    convert_label_matrix(Ls[0])
]

Ys_hat = [
    np.array([0 if y == 2 else 1 for y in Ys])
]
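Before fitting the label model, it helps to have a simple baseline to compare against. A majority vote over the converted matrix (-1 = abstain, 0 = negative, 1 = positive) is the usual choice; Snorkel ships a `MajorityLabelVoter` for this, but the NumPy sketch below keeps it dependency-free. Tie handling (predict -1 and leave those rows out of the accuracy) is our assumption.

```python
import numpy as np

def majority_vote(L, tie=-1):
    """Majority label per row of a Snorkel-style matrix (-1 = abstain, 0 = neg, 1 = pos).

    Rows with no votes or tied votes get `tie`."""
    L = np.asarray(L)
    pos = (L == 1).sum(axis=1)
    neg = (L == 0).sum(axis=1)
    return np.where(pos > neg, 1, np.where(neg > pos, 0, tie))

def covered_accuracy(preds, Y, tie=-1):
    """Accuracy over rows with a non-tie prediction, plus the fraction of such rows."""
    preds, Y = np.asarray(preds), np.asarray(Y)
    mask = preds != tie
    acc = float((preds[mask] == Y[mask]).mean()) if mask.any() else float("nan")
    return acc, float(mask.mean())

L = np.array([[1, 1, -1], [0, -1, -1], [1, 0, -1], [-1, -1, -1]])
print(majority_vote(L))
```

In this notebook it would be called as `covered_accuracy(majority_vote(Ls_hat[0]), Ys_hat[0])`.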


In [None]:
from snorkel.labeling import LabelModel

lr = 0.01
l2 = 0.001
prec_init = 0.9

label_model = LabelModel(cardinality=2, device='cpu', verbose=True)
label_model.fit(L_train=Ls_hat[0], 
                n_epochs=1000, 
                lr=lr,
                l2=l2,
                prec_init=prec_init,
                optimizer='adam',
                log_freq=100)

metrics = ['accuracy', 'precision', 'recall', 'f1']
label_model.score(L=Ls_hat[0], Y=Ys_hat[0], metrics=metrics)
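The label model learns how accurate each LF is from the patterns of agreement and disagreement in the label matrix, without using gold labels. Once accuracies are known, and if the LFs are assumed conditionally independent given the true label and equally accurate on both classes, combining votes reduces to a weighted vote in log-odds space. The sketch below shows only that combination step with hypothetical accuracies; it is a simplification, not Snorkel's implementation, and in practice you would call `label_model.predict_proba(L=Ls_hat[0])` to get probabilistic labels.

```python
import numpy as np

def posterior_positive(L, accuracies, class_prior=0.5):
    """P(y = 1 | votes) per row of a Snorkel-style matrix (-1 = abstain, 0 = neg, 1 = pos).

    Assumes conditionally independent LFs with class-symmetric accuracies, so each
    non-abstain vote adds +/- log(a / (1 - a)) to the prior log-odds."""
    L = np.asarray(L)
    acc = np.clip(np.asarray(accuracies, dtype=float), 1e-6, 1 - 1e-6)
    weights = np.log(acc / (1 - acc))
    # +1 for a positive vote, -1 for a negative vote, 0 for an abstain
    signs = np.where(L == 1, 1.0, np.where(L == 0, -1.0, 0.0))
    log_odds = np.log(class_prior / (1 - class_prior)) + signs @ weights
    return 1.0 / (1.0 + np.exp(-log_odds))

# Hypothetical accuracies for two LFs
L = np.array([[1, 1], [0, -1], [-1, -1]])
print(posterior_positive(L, accuracies=[0.9, 0.6]))
```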
