# Disease Tagging Tutorial

In this example, we'll be writing an application to extract *mentions of* diseases from PubMed abstracts, using annotations from the [BioCreative CDR Challenge](http://www.biocreative.org/resources/corpora/biocreative-v-cdr-corpus/). This five-part tutorial walks through the process of constructing a model to classify _candidate_ disease mentions as either true (the candidate really is a mention of a disease) or false.

## Part IV: Training a Model with Data Programming

In this part of the tutorial, we will train a statistical model to differentiate between true and false `Disease` mentions.

We will train this model using _data programming_, and we will **ignore** the training labels provided with the training data. This is a more realistic scenario; in the wild, hand-labeled training data is rare and expensive. Data programming enables us to train a model using only a modest amount of hand-labeled data for validation and testing. For more information on data programming, see the [NIPS 2016 paper](https://arxiv.org/abs/1605.07723).

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os


# Note: We run automated tests on this tutorial to make sure that it is always up to date! 
# However, certain interactive components cannot currently be tested automatically, and will 
# be skipped with if-then statements using the variable below
AUTOMATED_TESTING = os.environ.get('TESTING') is not None

import numpy as np
from snorkel import SnorkelSession
session = SnorkelSession()

We repeat our definition of the `Disease` `Candidate` subclass from Parts II and III.

In [None]:
from snorkel.models import candidate_subclass

Disease = candidate_subclass('Disease', ['disease'])

## Loading `CandidateSet` objects

We reload the training and development `CandidateSet` objects from the previous parts of the tutorial.

In [None]:
from snorkel.models import CandidateSet

train = session.query(CandidateSet).filter(CandidateSet.name == 'CDR Training Candidates').one()
dev = session.query(CandidateSet).filter(CandidateSet.name == 'CDR Development Candidates').one()

## Automatically Creating Features
Recall that our goal is to distinguish between true and false mentions of diseases. To train a model for this task, we first embed our `Disease` candidates in a feature space.

In [None]:
from snorkel.annotations import FeatureManager

feature_manager = FeatureManager()

We can create a new feature set:

In [None]:
%time F_train = feature_manager.create(session, train, 'Train Features')

**OR** if we've already created one, we can simply load as follows:

In [None]:
%time F_train = feature_manager.load(session, train, 'Train Features')

Note that the returned matrix is a subclass of `scipy.sparse.csr_matrix` with a few extra methods, such as looking up the candidate behind a row or the feature key behind a column, which we demonstrate below:

In [None]:
F_train

In [None]:
F_train.get_candidate(0)

In [None]:
F_train.get_key(0)

## Creating Labeling Functions
Labeling functions are a core tool of data programming. They are heuristic functions that aim to classify candidates correctly. Their outputs will be automatically combined and denoised to estimate the probabilities of training labels for the training data.

In [None]:
import re
from lf_terms import *
from snorkel.lf_helpers import get_left_tokens, get_right_tokens
TRUE,FALSE,ABSTAIN = 1,-1,0
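
To make the idea concrete, here is a minimal sketch that uses plain strings in place of Snorkel `Candidate` objects. The two toy functions and the sample phrases are hypothetical and exist only for illustration. Each function votes true, false, or abstains, and the votes are collected into a small label matrix with one row per candidate and one column per labeling function.

```python
import re

TRUE, FALSE, ABSTAIN = 1, -1, 0

def lf_ends_with_syndrome(text):
    """Vote TRUE when the phrase ends with 'syndrome'; otherwise abstain."""
    return TRUE if text.lower().endswith("syndrome") else ABSTAIN

def lf_contains_digit(text):
    """Vote FALSE when the phrase contains a digit; otherwise abstain."""
    return FALSE if re.search(r"\d", text) else ABSTAIN

# Hypothetical candidate phrases standing in for Candidate objects.
toy_candidates = ["Down syndrome", "BRCA1", "headache"]
toy_lfs = [lf_ends_with_syndrome, lf_contains_digit]

# One row per candidate, one column per labeling function.
label_matrix = [[lf(text) for lf in toy_lfs] for text in toy_candidates]
for text, row in zip(toy_candidates, label_matrix):
    print(text, row)
```

Note that `headache` receives no votes at all, which is why a useful set of labeling functions needs broad coverage.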

We also load some publicly-available biomedical dictionaries, which we will leverage in some of our LFs below as a source of weak supervision:

In [None]:
from utils import *

diseases               = load_disease_dictionary()
diseases.update(load_acronym_dictionary())
umls_dict              = load_umls_dictionary()
chemicals              = load_chemdner_dictionary()
abbrv2text, text2abbrv = load_specialist_abbreviations()

#### Document-Level Labeling Functions
We start with some labeling functions that label candidates based on document-level features.

In [None]:
from snorkel.lf_helpers import get_doc_candidate_spans

def LF_known_abbreviation(c):
    """
    Heuristic: if a phrase and its acronym (or expansion) both appear as candidates
    in the same document, they are likely true mentions. Uses the SPECIALIST acronym lexicon.
    """
    doc_spans = get_doc_candidate_spans(c)
    phrase = c[0].get_span().lower()
    mentions = set([s.get_span().lower() for s in doc_spans])  
    if phrase in abbrv2text and set(abbrv2text[phrase].keys()).intersection(mentions):
        return 1
    elif phrase in text2abbrv and set(text2abbrv[phrase].keys()).intersection(mentions):
        return 1
    return 0

def LF_undefined_abbreviation(c):
    '''Candidate is a known abbreviation, but no corresponding full name in document'''
    doc_spans = get_doc_candidate_spans(c)
    phrase = c[0].get_span().lower()
    mentions = set([s.get_span().lower() for s in doc_spans])
    if len(phrase) > 1 and phrase in abbrv2text and not set(abbrv2text[phrase].keys()).intersection(mentions):
        return -1
    return 0

#### Sentence-Level Labeling Functions
We also include some labeling functions that label candidates based on sentence-level features.

In [None]:
from snorkel.lf_helpers import get_sent_candidate_spans

def LF_contiguous_mentions(c):
    '''Contiguous candidates are likely wrong'''
    neighbor_spans = get_sent_candidate_spans(c)
    start, end = c[0].get_word_start(), c[0].get_word_end()
    for s in neighbor_spans:
        if s.get_word_end() + 1 == start or s.get_word_start() - 1 == end:
            return -1
    return 0
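
The adjacency test in `LF_contiguous_mentions` can be sketched on its own with plain tuples. This is an illustration only: it assumes spans are `(word_start, word_end)` pairs with inclusive word indices, and the helper name `is_contiguous` is hypothetical.

```python
def is_contiguous(span, others):
    """Return True if some span in `others` ends right before `span` starts
    or starts right after `span` ends (inclusive word indices assumed)."""
    start, end = span
    return any(o_end + 1 == start or o_start - 1 == end
               for o_start, o_end in others)

# Three hypothetical candidate spans in one sentence.
sentence_spans = [(0, 1), (2, 2), (5, 6)]
flags = [
    is_contiguous(s, [o for o in sentence_spans if o != s])
    for s in sentence_spans
]
print(flags)
```

The first two spans touch each other, so both would receive a negative vote, while the third is isolated and the function would abstain.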

#### Mention-Level Labeling Functions
We now define a number of labeling functions that label candidates based on attributes related to the mention.

In [None]:
def LF_tumors_growths(c):
    phrase = " ".join(c[0].get_attrib_tokens('lemmas'))
    return TRUE if re.search(r"^(\w* ){0,2}(['] )*(tumor|tumour|polyp|pilomatricoma|cyst|lipoma)$",phrase) else ABSTAIN

def LF_cancer(c):
    '''<TYPE> cancer'''
    phrase = " ".join(c[0].get_attrib_tokens('lemmas'))
    return TRUE if re.search(r"\w* cancer",phrase) else ABSTAIN

def LF_disease_syndrome(c):
    '''<TYPE> disease or <TYPE> syndrome'''
    phrase = " ".join(c[0].get_attrib_tokens('lemmas'))
    return TRUE if re.search(r"\w* (disease|syndrome)+",phrase) else ABSTAIN

def LF_indicators(c):
    '''Indicator words'''
    return TRUE if " ".join(c[0].get_attrib_tokens()).lower() in indicators else ABSTAIN

def LF_common_disease(c):
    '''Common disease'''
    return TRUE if " ".join(c[0].get_attrib_tokens()).lower() in common_disease else ABSTAIN

def LF_common_disease_acronyms(c):
    '''Common disease acronyms'''
    return TRUE if " ".join(c[0].get_attrib_tokens()) in common_disease_acronyms else ABSTAIN

def LF_deficiency_of(c):
    '''deficiency of <TYPE>'''
    phrase = " ".join(c[0].get_attrib_tokens()).lower()
    return TRUE if phrase.endswith('deficiency') or phrase.startswith('deficiency') or phrase.endswith('dysfunction') else ABSTAIN

def LF_come_with(c):
    phrase = ' '.join(c[0].get_attrib_tokens())
    w = ' '.join(right_window(c[0], window=1))
    return TRUE if phrase in ["APC", "PDS"] and w in ['gene', 'mutations', 'mutation'] else ABSTAIN

def LF_positive_indicator(c):
    flag = False
    for i in c[0].get_attrib_tokens():
        if i.lower() in positive_indicator:
            flag = True
            break
    return TRUE if flag else ABSTAIN

def LF_left_positive_argument(c):    
    phrase = " ".join(c[0].get_attrib_tokens('lemmas')).lower()
    pattern = r"(\w+ ){1,2}(infection|lesion|neoplasm|attack|defect|anomaly|abnormality|degeneration|carcinoma|lymphoma|tumor|tumour|deficiency|malignancy|hypoplasia|disorder|deafness|weakness|condition|dysfunction|dystrophy)$"
    return TRUE if re.search(pattern,phrase) else ABSTAIN

def LF_right_negative_argument(c):    
    phrase = " ".join(c[0].get_attrib_tokens('lemmas')).lower()
    pattern = r"^(history of|mitochondrial|amino acid)( \w+){1,2}"
    return FALSE if re.search(pattern,phrase) else ABSTAIN

def LF_medical_afixes(c):
    pattern = r"(\w+(pathy|stasis|trophy|plasia|itis|osis|oma|asis|asia)$|^(hyper|hypo)\w+)"
    phrase = " ".join(c[0].get_attrib_tokens('lemmas')).lower()
    return TRUE if re.search(pattern,phrase) else ABSTAIN

def LF_adj_diseases(c):
    adj_diseases = ['acromegalic', 'akinetic', 'allergic', 'arrhythmic', 'arteriopathic', 'asthmatic', 
                    'atherosclerotic', 'bradycardic', 'cardiotoxic', 'cataleptic', 'cholestatic', 
                    'cirrhotic', 'diabetic', 'dyskinetic', 'dystonic', 'eosinophilic', 'epileptic', 
                    'exencephalic', 'haemorrhagic', 'hemolytic', 'hemorrhagic', 'hemosiderotic', 'hepatotoxic',
                    'hyperalgesic', 'hyperammonemic', 'hypercalcemic', 'hypercapnic', 'hyperemic', 
                    'hyperkinetic', 'hypertrophic', 'hypomanic', 'hypothermic', 'ischaemic', 'ischemic', 
                    'leukemic', 'myelodysplastic', 'myopathic', 'necrotic', 'nephrotic', 'nephrotoxic', 
                    'neuropathic', 'neurotoxic', 'neutropenic', 'ototoxic', 'polyuric', 'proteinuric', 
                    'psoriatic', 'psychiatric', 'psychotic', 'quadriplegic', 'schizophrenic', 'teratogenic', 
                    'thromboembolic', 'thrombotic', 'traumatic', 'vasculitic']
    return TRUE if ' '.join(c[0].get_attrib_tokens()) in adj_diseases else ABSTAIN
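
Pattern-based labeling functions are easy to try out on plain strings before running them over the full candidate set. The sketch below reuses the affix pattern from `LF_medical_afixes`; the sample phrases are hypothetical and were picked to show both hits and a false positive.

```python
import re

# Same pattern as in LF_medical_afixes, written as a raw string.
affix_pattern = re.compile(
    r"(\w+(pathy|stasis|trophy|plasia|itis|osis|oma|asis|asia)$|^(hyper|hypo)\w+)"
)

# Hypothetical lemmatized, lowercased phrases.
samples = ["hepatitis", "hypertension", "aspirin", "aroma"]
matches = {s: bool(affix_pattern.search(s)) for s in samples}

for phrase, hit in matches.items():
    print(phrase, "TRUE" if hit else "ABSTAIN")
```

Here `aroma` matches the `oma` suffix even though it is not a disease. Such noise is expected: labeling functions only need to be right more often than not, and the generative model later estimates how much to trust each one.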

#### Dictionary Labeling Functions
We can use existing dictionaries for distant supervision.

In [None]:
def LF_SNOWMED_CT_sign_or_symptom(c):
    return TRUE if c[0].get_span() in umls_dict["snomedct"]["sign_or_symptom"] else ABSTAIN

def LF_SNOWMED_CT_disease_or_syndrome(c):
    return TRUE if c[0].get_span() in umls_dict["snomedct"]["disease_or_syndrome"] else ABSTAIN

def LF_MESH_disease_or_syndrome(c):
    return TRUE if c[0].get_span() in umls_dict["mesh"]["disease_or_syndrome"] else ABSTAIN

def LF_MESH_sign_or_symptom(c):
    return TRUE if c[0].get_span() in umls_dict["mesh"]["sign_or_symptom"] else ABSTAIN

#### Negative Labeling Functions
When writing labeling functions, it is important to provide negative supervision in addition to positive supervision.

In [None]:
def LF_organs(c):
    phrase = " ".join(c[0].get_attrib_tokens()).lower()
    return FALSE if phrase in organs else ABSTAIN      

def LF_chemical_name(c):
    phrase = " ".join(c[0].get_attrib_tokens())
    return FALSE if phrase in chemicals and not phrase.isupper() else ABSTAIN

def LF_bodysym(c):
    phrase = " ".join(c[0].get_attrib_tokens()).lower()
    return FALSE if phrase in bodysym else ABSTAIN  

def LF_protein_chemical_abbrv(c):
    '''Gene/protein/chemical name'''
    lemma = " ".join(c[0].get_attrib_tokens('lemmas'))
    return FALSE if re.search(r"\d+",lemma) else ABSTAIN

def LF_has_punctuation(c):
    return FALSE if re.search("[=%]+"," ".join(c[0].get_attrib_tokens())) else ABSTAIN

def LF_gene_abbrv(c):
    '''Gene/protein/chemical name'''
    lemma = " ".join(c[0].get_attrib_tokens('lemmas'))
    return FALSE if re.search(r"\d+",lemma) and lemma.isupper() else ABSTAIN

def LF_base_pair_seq(c): 
    lemma = " ".join(c[0].get_attrib_tokens('lemmas'))
    return FALSE if re.search("^[GACT]{2,}$",lemma) else ABSTAIN

def LF_too_vague(c):
    phrase = " ".join(c[0].get_attrib_tokens('lemmas')).lower()
    phrase_ = " ".join(c[0].get_attrib_tokens()).lower()
    return FALSE if phrase in vague or phrase_ in vague else ABSTAIN

def LF_negation(c):
    neg = set(["not","no"])
    return FALSE if neg.intersection(c[0].get_attrib_tokens('lemmas')) else ABSTAIN

def LF_neg_surfix(c):
    terms = ['deficiency', 'the', 'of', 'to', 'a']
    rw = get_right_tokens(c, window=1, attrib='lemmas')
    if len(rw) > 0 and rw[0].lower() in terms:
        return FALSE
    return ABSTAIN

def LF_non_common_disease(c):
    '''Non common diseases'''
    return FALSE if " ".join(c[0].get_attrib_tokens()).lower() in non_common_disease else ABSTAIN

def LF_non_disease_acronyms(c):
    '''Non common disease acronyms'''
    return FALSE if " ".join(c[0].get_attrib_tokens()) in non_disease_acronyms else ABSTAIN

def LF_pos_in(c):
    '''Candidates beginning with a preposition or subordinating conjunction'''
    poses = c[0].get_attrib_tokens('poses')
    return FALSE if "IN" in poses[0:1] else ABSTAIN

def LF_gene_chromosome_link(c):
    '''Mentions of the form "Huntington Disease gene"'''
    genetics_terms = set(["gene","chromosome"])
    diseases_terms = set(["disease","syndrome","disorder"])
    context = get_left_tokens(c,window=10, attrib='lemmas') + get_right_tokens(c,window=10, attrib='lemmas')
    # 1: contains a disease keyword or 2: in disease dictionaries
    is_disease = diseases_terms.intersection(map(lambda x:x.lower(), c[0].get_attrib_tokens()))
    is_disease = is_disease or " ".join(c[0].get_attrib_tokens()) in diseases
    is_gene = genetics_terms.intersection(context)    
    return FALSE if is_gene and not is_disease else ABSTAIN

def LF_right_window_incomplete(c):
    return FALSE if right_terms.intersection(get_right_tokens(c,window=2, attrib='lemmas')) else ABSTAIN

def LF_negative_indicator(c):
    flag = False
    for i in c[0].get_attrib_tokens():
        if i.lower() in negative_indicator:
            flag = True
            break
    return FALSE if flag else ABSTAIN

We maintain a list of all LFs for convenience.

In [None]:
LFs_doc = [LF_known_abbreviation,
           LF_undefined_abbreviation
          ]

LFs_sent = [LF_contiguous_mentions]

LFs_mention = [LF_tumors_growths,
               LF_cancer,
               LF_disease_syndrome,
               LF_indicators,
               LF_common_disease,
               LF_common_disease_acronyms,
               LF_deficiency_of,
               LF_come_with,
               LF_positive_indicator,
               LF_left_positive_argument,
               LF_right_negative_argument,
               LF_medical_afixes,
               LF_adj_diseases
              ]

LFs_dicts =  [LF_SNOWMED_CT_sign_or_symptom,
              LF_SNOWMED_CT_disease_or_syndrome,
              LF_MESH_disease_or_syndrome,
              LF_MESH_sign_or_symptom
            ]

LFs_false = [LF_chemical_name,
             LF_organs,
             LF_bodysym,
             LF_protein_chemical_abbrv,
             LF_has_punctuation,
             LF_gene_abbrv,
             LF_base_pair_seq,
             LF_too_vague,
             LF_negation,
             LF_neg_surfix,
             LF_non_common_disease,
             LF_non_disease_acronyms,
             LF_pos_in,
             LF_gene_chromosome_link,
             LF_right_window_incomplete,
             LF_negative_indicator
            ]

LFs = LFs_doc + LFs_sent + LFs_mention + LFs_dicts + LFs_false

## Applying Labeling Functions

First we construct a `LabelManager`.

In [None]:
from snorkel.annotations import LabelManager

label_manager = LabelManager()

Next we run the `LabelManager` to apply the labeling functions to the training `CandidateSet`.

In [None]:
%time L_train = label_manager.create(session, train, 'LF Labels', f=LFs)
L_train

**OR**, if we've already created the label matrix, we can load it:

In [None]:
%time L_train = label_manager.load(session, train, 'LF Labels')
L_train

Now say we want to add a new labeling function to our matrix:

In [None]:
import random
def LF_test_3(candidate):
    return -1 if random.random() < 0.2 else 0

We can add or rerun one or more labeling functions with the command below. Setting the argument `expand_key_set` to `True` allows the set of matrix columns to expand. This is a convenient way to test changes to the labeling functions.

In [None]:
L_train = label_manager.update(session, train, 'LF Labels', True, f=[LF_test_3])
L_train

We can view statistics about the resulting label matrix:

In [None]:
L_train.lf_stats()
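
As a rough illustration of the kind of per-function statistics that are useful here, the sketch below computes coverage, overlap, and conflict rates for a toy label matrix with NumPy. Which columns `lf_stats()` actually reports is an assumption on our part, and the matrix values are made up.

```python
import numpy as np

# Toy label matrix: rows are candidates, columns are labeling functions.
# Entries are 1 (true), -1 (false), or 0 (abstain).
L = np.array([
    [1,  0, 1],
    [1, -1, 0],
    [0,  0, 0],
    [1,  0, 0],
])

labeled = L != 0

# Coverage: fraction of candidates each LF labels.
coverage = labeled.mean(axis=0)

# Overlap: fraction of candidates an LF labels that some other LF also labels.
n_votes = labeled.sum(axis=1, keepdims=True)
overlaps = (labeled & (n_votes > 1)).mean(axis=0)

# Conflict: fraction of candidates an LF labels where another LF disagrees.
has_pos = (L == 1).any(axis=1, keepdims=True)
has_neg = (L == -1).any(axis=1, keepdims=True)
conflicts = (labeled & has_pos & has_neg).mean(axis=0)

print(coverage, overlaps, conflicts)
```

Low coverage suggests a labeling function rarely fires, while a high conflict rate flags functions worth inspecting in the viewer.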

## Fitting the Generative Model
We estimate the accuracies of the labeling functions without supervision. Specifically, we estimate the parameters of a `NaiveBayes` generative model.

In [None]:
from snorkel.learning import NaiveBayes

gen_model = NaiveBayes()
gen_model.train(L_train)

In [None]:
gen_model.w

In [None]:
gen_model.save(session, 'Generative Params')

In [None]:
gen_model.load(session, 'Generative Params')
gen_model.w

We now apply the generative model to the training candidates.

In [None]:
train_marginals = gen_model.marginals(L_train)
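
To give some intuition for what a marginal is, the following sketch combines labeling function votes under a naive Bayes assumption. It is not Snorkel's implementation: it takes the accuracies as given rather than estimating them, and it assumes independent labeling functions, a uniform class prior, and abstentions that carry no information. The function name and the numbers are hypothetical.

```python
import numpy as np

def naive_bayes_marginals(L, accuracies):
    """Return P(y = +1 | votes) for each row of L under the assumptions
    stated above. L holds votes in {1, -1, 0}; accuracies are in (0, 1)."""
    L = np.asarray(L, dtype=float)
    acc = np.asarray(accuracies, dtype=float)
    # Each non-abstaining vote shifts the log-odds by +/- log(acc / (1 - acc)).
    log_odds_per_lf = np.log(acc / (1.0 - acc))
    log_odds = L @ log_odds_per_lf
    return 1.0 / (1.0 + np.exp(-log_odds))

votes = np.array([
    [1,  1],   # both LFs vote true
    [1, -1],   # the LFs disagree
    [0,  0],   # both abstain
])
marginals = naive_bayes_marginals(votes, accuracies=[0.9, 0.6])
print(marginals)
```

When the two functions disagree, the more accurate one dominates, and a candidate with no votes stays at probability 0.5.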

## Training the Discriminative Model
We use the estimated probabilities to train a discriminative model that classifies each `Candidate` as a true or false mention.

In [None]:
from snorkel.learning import LogReg

disc_model = LogReg()
disc_model.train(F_train, train_marginals, n_iter=1500, rate=1e-5)
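
Training on probabilities rather than hard labels can be sketched as logistic regression with soft targets. The code below is a simplified stand-in, not Snorkel's `LogReg`: it uses plain gradient descent on the cross-entropy against the marginals, has no regularization, and uses a different learning rate scale from the call above. The function name and toy data are hypothetical.

```python
import numpy as np

def train_logreg_soft(X, marginals, n_iter=1500, rate=0.5):
    """Fit logistic regression weights to probabilistic labels by gradient
    descent on the cross-entropy between predictions and marginals."""
    X = np.asarray(X, dtype=float)
    p = np.asarray(marginals, dtype=float)
    w = np.zeros(X.shape[1])
    for _ in range(n_iter):
        pred = 1.0 / (1.0 + np.exp(-X @ w))
        grad = X.T @ (pred - p) / len(p)
        w -= rate * grad
    return w

# Two one-hot features; candidates with feature 0 tend to be true mentions.
X = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
soft_labels = np.array([0.9, 0.8, 0.2, 0.1])

w = train_logreg_soft(X, soft_labels)
fitted = 1.0 / (1.0 + np.exp(-X @ w))
print(w, fitted)
```

In this toy case, the fitted probability for each feature settles near the average marginal of the candidates that share it, which is how uncertainty from the generative model carries over into the discriminative one.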

In [None]:
disc_model.w.shape

In [None]:
%time disc_model.save(session, "Discriminative Params")

In [None]:
w_prev = disc_model.w
%time disc_model.load(session, "Discriminative Params")
np.all(disc_model.w == w_prev)

## Evaluating on the Development `CandidateSet`

First, we create features for the development set.

Note that we use the training feature set, because those are the only features for which we have learned parameters. Features that were not encountered during training, e.g., a token that does not appear in the training set, are ignored, because we do not have any information about them.

To do so with the `FeatureManager`, we call `update` with the new `CandidateSet`, the name of the training `AnnotationKeySet`, and the value `False` for the parameter `expand_key_set` to indicate that the `AnnotationKeySet` should not be expanded with new `Feature` keys encountered during processing.

In [None]:
%time F_dev = feature_manager.update(session, dev, 'Train Features', False)

**OR** if we've already created one, we can simply load as follows:

In [None]:
%time F_dev = feature_manager.load(session, dev, 'Train Features')
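
The effect of keeping the key set fixed can be shown with a small, self-contained sketch. It uses token counts as stand-in features and a plain dictionary as the key index. Both are simplifications for illustration, and the helper `featurize` is hypothetical.

```python
def featurize(tokens, key_index):
    """Map tokens to {column: count}, dropping any token that has no
    column in the fixed key index."""
    row = {}
    for tok in tokens:
        col = key_index.get(tok)
        if col is not None:
            row[col] = row.get(col, 0) + 1
    return row

# Build the key index from (hypothetical) training data only.
train_docs = [["renal", "failure"], ["liver", "failure"]]
key_index = {}
for doc in train_docs:
    for tok in doc:
        key_index.setdefault(tok, len(key_index))

# "acute" never appeared in training, so it is silently ignored.
dev_row = featurize(["acute", "renal", "failure"], key_index)
print(key_index, dev_row)
```

Because the development row only uses columns known from training, it lines up with the learned weight vector.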

Next, we load the development set labels and gold candidates we made in Part III.

In [None]:
L_dev = label_manager.load(session, dev, "CDR Development Labels -- Gold")

In [None]:
gold_dev_set = session.query(CandidateSet).filter(CandidateSet.name == 'CDR Development Candidates -- Gold').one()

Now we can evaluate the discriminative model on the development set.

In [None]:
tp, fp, tn, fn = disc_model.score(F_dev, L_dev, gold_dev_set, b=0.4)
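
The sketch below illustrates how a decision threshold turns marginals into the four outcome groups. It assumes that `b` acts as a cutoff on the predicted probability, which is our reading of the call above rather than something stated in this tutorial. It also returns counts, whereas the call above returns values that are later passed to the viewer. The helper names and numbers are hypothetical.

```python
def confusion_counts(marginals, gold, b=0.5):
    """Count (tp, fp, tn, fn) after predicting true whenever marginal > b.
    Gold labels are +1 (true mention) or -1 (false mention)."""
    tp = fp = tn = fn = 0
    for p, y in zip(marginals, gold):
        predicted_true = p > b
        if predicted_true and y == 1:
            tp += 1
        elif predicted_true:
            fp += 1
        elif y == 1:
            fn += 1
        else:
            tn += 1
    return tp, fp, tn, fn

def precision_recall_f1(tp, fp, fn):
    """Compute precision, recall, and F1, returning 0.0 for empty denominators."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

toy_marginals = [0.9, 0.45, 0.3, 0.6]
toy_gold = [1, 1, -1, -1]

counts_low = confusion_counts(toy_marginals, toy_gold, b=0.4)
counts_mid = confusion_counts(toy_marginals, toy_gold, b=0.5)
scores_low = precision_recall_f1(counts_low[0], counts_low[1], counts_low[3])
print(counts_low, counts_mid, scores_low)
```

Lowering the cutoff from 0.5 to 0.4 recovers the borderline true mention at 0.45, trading nothing in precision here but typically admitting more false positives on real data.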

## Viewing Examples
After evaluating on the development `CandidateSet`, the labeling functions can be modified. Try changing the labeling functions to improve performance. You can view the true positives, false positives, true negatives, and false negatives using the `Viewer`.

In [None]:
from snorkel.viewer import SentenceNgramViewer

# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook
# You should ignore this!
if not AUTOMATED_TESTING:
    sv = SentenceNgramViewer(tp, session, annotator_name="Tutorial Part IV User")
else:
    sv = None

In [None]:
sv

In [None]:
sv.get_selected()[0].parent

In [None]:
sv.g

Next, in Part V, we will test our model on the test `CandidateSet`.