# Chemical-Disease Relation (CDR) Tutorial

In this example, we'll be writing an application to extract *mentions of* **chemical-induced-disease relationships** from PubMed abstracts, as per the [BioCreative CDR Challenge](http://www.biocreative.org/resources/corpora/biocreative-v-cdr-corpus/). At its core, we will be constructing a model to classify _candidate_ CDR mentions as either true or false.

## Part IV: Training a Model with Data Programming

In this part of the tutorial, we will train a statistical model to differentiate between true and false `ChemicalDisease` mentions.

We will train this model using _data programming_, and we will **ignore** the training labels provided with the training data. This is a more realistic scenario; in the wild, hand-labeled training data is rare and expensive. Data programming enables us to train a model using only a modest amount of hand-labeled data for validation and testing. For more information on data programming, see the [NIPS 2016 paper](https://arxiv.org/abs/1605.07723).

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os

# Note: We run automated tests on this tutorial to make sure that it is always up to date! 
# However, certain interactive components cannot currently be tested automatically, and will 
# be skipped with if-then statements using the variable below
AUTOMATED_TESTING = os.environ.get('TESTING') is not None

import numpy as np
from snorkel import SnorkelSession
session = SnorkelSession()



We repeat our definition of the `ChemicalDisease` `Candidate` subclass from Parts II and III.

In [2]:
from snorkel.models import candidate_subclass

ChemicalDisease = candidate_subclass('ChemicalDisease', ['chemical', 'disease'])

## Loading `CandidateSet` objects

We reload the `CandidateSet` objects from the previous parts of the tutorial. Note that we will now process all three (training, validation, and test) as we go, because each plays a distinct role in Parts IV and V.

In [3]:
from snorkel.models import CandidateSet

train = session.query(CandidateSet).filter(CandidateSet.name == 'CDR Training Candidates').one()
dev = session.query(CandidateSet).filter(CandidateSet.name == 'CDR Development Candidates').one()
test = session.query(CandidateSet).filter(CandidateSet.name == 'CDR Test Candidates').one()

candidate_sets = [train, dev, test]

## Creating Labeling Functions
Labeling functions are a core tool of data programming. They are heuristic functions that aim to classify candidates correctly. Their outputs will be automatically combined and denoised to estimate the probabilities of training labels for the training data.
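To make the interface concrete, here is a minimal sketch on toy data. The `ToyCandidate` tuple and both functions are hypothetical stand-ins and not part of Snorkel; the only convention taken from this tutorial is that a labeling function returns 1 (true), -1 (false), or 0 (abstain).

```python
import re
from collections import namedtuple

# Hypothetical stand-in for a candidate: two entity strings plus the sentence text.
ToyCandidate = namedtuple('ToyCandidate', ['chemical', 'disease', 'sentence'])

def lf_induced(c):
    """Vote 1 if the sentence says the chemical induced the disease, else abstain."""
    pattern = re.escape(c.chemical) + r'.{0,20}induced.{0,20}' + re.escape(c.disease)
    return 1 if re.search(pattern, c.sentence, flags=re.I) else 0

def lf_negation(c):
    """Vote -1 if the sentence contains an explicit negation, else abstain."""
    return -1 if re.search(r'\b(did|does) not\b', c.sentence, flags=re.I) else 0

c1 = ToyCandidate('scopolamine', 'amnesia', 'Scopolamine induced amnesia in rats.')
c2 = ToyCandidate('aspirin', 'ulcers', 'Aspirin did not cause ulcers in this cohort.')

# One row of votes per candidate, one column per labeling function.
votes = [[lf(c) for lf in (lf_induced, lf_negation)] for c in (c1, c2)]
```

Each function is allowed to be noisy and to abstain on most candidates; the denoising step is what reconciles their votes.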

## Inspecting some examples in the training set

We start by browsing the gold-labeled training candidates to come up with ideas for labeling functions (LFs):

In [4]:
train_gold = session.query(CandidateSet).filter(CandidateSet.name == 'CDR Training Candidates -- Gold').one()
len(train_gold)

1745

In [5]:
from snorkel.viewer import SentenceNgramViewer

# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook
# You should ignore this!
if not AUTOMATED_TESTING:
    sv = SentenceNgramViewer(train_gold, session, annotator_name="Tutorial Part IV User")
else:
    sv = None


In [6]:
sv

In [19]:
c = sv.get_selected()
c

ChemicalDisease(Span("scopolamine", parent=521, chars=[157,167], words=[22,22]), Span("overdosage", parent=521, chars=[180,189], words=[26,26]))

## Traditional "distant supervision" as a single LF

See http://ctdbase.org/downloads/;jsessionid=5B8E7F187A4772BB9478B6B3D9FCA5D1#cd

**TODO: Make download script**

In [7]:
DATA_ROOT = os.environ['SNORKELHOME'] + '/tutorial/data/'

CTD_lower = set()
with open(DATA_ROOT + 'dicts/CTD_chemicals_diseases.tsv', 'rb') as f:
    for line in f:
        if not line.startswith("#"):
            chem, chem_id, _, disease, disease_id, evidence, _, _, _, _ = line.split('\t')
            CTD_lower.add((chem.lower(), disease.lower()))
print len(CTD_lower)

1914160
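The lookup idea can be sketched without the real file. The two rows below are made up, and the 10-column tab-separated layout (chemical name in column 0, disease name in column 3) is an assumption inferred from the unpacking in the cell above. `load_pairs` and `lf_in_kb` are hypothetical helper names.

```python
import csv
import io

# Made-up rows in the assumed 10-column layout; lines starting with '#' are comments.
sample = (
    "# ChemicalName\tChemicalID\t...\n"
    "Scopolamine\tCHEM_1\t\tAmnesia\tDIS_1\tmarker/mechanism\t\t\t\t\n"
    "Cisplatin\tCHEM_2\t\tKidney Diseases\tDIS_2\ttherapeutic\t\t\t\t\n"
)

def load_pairs(lines):
    """Return the set of lowercased (chemical, disease) name pairs."""
    pairs = set()
    data_lines = (line for line in lines if not line.startswith('#'))
    for row in csv.reader(data_lines, delimiter='\t'):
        if len(row) < 4:
            continue  # skip malformed rows instead of raising
        pairs.add((row[0].lower(), row[3].lower()))
    return pairs

kb = load_pairs(io.StringIO(sample))

def lf_in_kb(chemical, disease):
    """Distant supervision: vote 1 when the pair is in the knowledge base, else abstain."""
    return 1 if (chemical.lower(), disease.lower()) in kb else 0
```

Skipping short rows rather than unpacking exactly ten fields keeps the loader from failing on a malformed line.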


## LFs

In [28]:
import re  # used by the regex-based LFs below
from random import random
from snorkel.lf_helpers import get_text_between
from snorkel.models import split_stable_id
from snorkel.lf_helpers import get_tagged_text

def LF_in_CDT(c, p_neg=0.1):
    """Match against the CTD KB, with random negative supervision as well"""
    chem    = c.chemical.get_span().lower()
    disease = c.disease.get_span().lower()
    if (chem, disease) in CTD_lower:
        return 1
    else:
        return -1 if random() < p_neg else 0
    
def LF_in_CDT_filtered_neg(c, p_neg=0.1):
    """Match against the CTD KB outside the title, with random negative supervision as well"""
    chem    = c.chemical.get_span().lower()
    disease = c.disease.get_span().lower()
    
    # Don't look in the title
    if split_stable_id(c.chemical.parent.stable_id)[2] == 0:
        return -1 if random() < p_neg else 0
    
    # Filter by some basic heuristics
    #text = c.chemical.parent.text
    #FILTER_RGXS = r'in order to|none of|no|(did|does) not|associated|protective|against|risk of|was calculated|study|seems|suggest|is reported|patient|rat|mouse'
    
    if (chem, disease) in CTD_lower:
        return 1
    else:
        return -1 if random() < p_neg else 0
    
def LF_in_CDT_filtered_pos(c, p_neg=0.1):
    """Match against the CTD KB when a causal phrase links the pair, with random negative supervision as well"""
    chem    = c.chemical.get_span().lower()
    disease = c.disease.get_span().lower()
    
    # Filter by some basic heuristics
    POS_RGX = r'{{A}}.{0,10}(caused|secondary to|induced).{0,10}{{B}}'
    
    if (chem, disease) in CTD_lower and re.search(POS_RGX, get_tagged_text(c)):
        return 1
    else:
        return -1 if random() < p_neg else 0

WEAK_PHRASES = ['none', 'although', 'was carried out', 'was conducted',
                'seems', 'suggests', 'risk', 'associated with', 'implicated',
               'the aim', 'to (investigate|assess|study)']

WEAK_RGX = r'|'.join(WEAK_PHRASES)

def LF_weak_assertions(c):
    return -1 if re.search(WEAK_RGX, get_tagged_text(c), flags=re.I) else 0


ANIMAL_RGX = r'mouse|mice'

def LF_animal(c):
    return -1 if re.search(ANIMAL_RGX, get_tagged_text(c), flags=re.I) else 0

def LF_cause(c):
    return 1 if re.search(r'cause', get_text_between(c), flags=re.I) else 0

def LF_in_CDT_causes(c):
    """Match against the CTD KB, requiring 'causes' between the two mentions"""
    chem    = c.chemical.get_span().lower()
    disease = c.disease.get_span().lower()
    if (chem, disease) in CTD_lower:
        if re.search(r'causes', get_text_between(c), flags=re.I):
            return 1
    return 0

def LF_in_CDT_induced(c):
    chem    = c.chemical.get_span().lower()
    disease = c.disease.get_span().lower()
    if (chem, disease) in CTD_lower and re.search(r'{{A}}.{0,20}induce.{0,20}{{B}}', get_tagged_text(c), flags=re.I):
        return 1
    return 0

def LF_in_CDT_exposure(c):
    chem    = c.chemical.get_span().lower()
    disease = c.disease.get_span().lower()
    if (chem, disease) in CTD_lower and re.search(r'{{B}}.{0,20}expos.{0,20}{{A}}', get_tagged_text(c), flags=re.I):
        return 1
    return 0

def LF_in_CDT_caused_by(c):
    chem    = c.chemical.get_span().lower()
    disease = c.disease.get_span().lower()
    if (chem, disease) in CTD_lower and re.search(r'{{B}}.{0,20}caused by.{0,20}{{A}}', get_tagged_text(c), flags=re.I):
        return 1
    return 0

def LF_in_CDT_caused_by_2(c):
    chem    = c.chemical.get_span().lower()
    disease = c.disease.get_span().lower()
    if (chem, disease) in CTD_lower and re.search(r'{{B}}.*(caused|induced) by.*{{A}}', get_tagged_text(c), flags=re.I):
        return 1
    return 0
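Several of the LFs above match regexes against text in which the two mentions appear as `{{A}}` (chemical) and `{{B}}` (disease). The sketch below shows one way such tagged text could be produced. `tag_text` is a hypothetical helper written for illustration and is not Snorkel's `get_tagged_text`; the inclusive character offsets are an assumption.

```python
import re

def tag_text(sentence, span_a, span_b):
    """Replace two character spans (start, end inclusive) with {{A}} and {{B}}."""
    # Substitute the right-most span first so the earlier offsets stay valid.
    spans = sorted([(span_a, '{{A}}'), (span_b, '{{B}}')],
                   key=lambda s: s[0][0], reverse=True)
    for (start, end), tag in spans:
        sentence = sentence[:start] + tag + sentence[end + 1:]
    return sentence

sentence = 'Amnesia was caused by scopolamine in all subjects.'
chem_start = sentence.index('scopolamine')
chem_span = (chem_start, chem_start + len('scopolamine') - 1)
dis_span = (0, len('Amnesia') - 1)

tagged = tag_text(sentence, chem_span, dis_span)

# Same pattern shape as LF_in_CDT_caused_by: disease, then "caused by", then chemical.
CAUSED_BY = r'{{B}}.{0,20}caused by.{0,20}{{A}}'
match = re.search(CAUSED_BY, tagged, flags=re.I)
```

Matching on placeholders instead of the raw entity strings lets one pattern cover every chemical-disease pair and makes the direction of the relation explicit.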

In [29]:
from snorkel.lf_helpers import get_matches

matches = get_matches(LF_in_CDT_caused_by_2, train)
print len(set(matches).intersection(train_gold))

56 matches
35


## Collecting LFs...

In [30]:
LFs = [
    LF_weak_assertions,
    LF_animal,
    LF_in_CDT_causes,
    LF_in_CDT_induced,
    LF_in_CDT_caused_by_2
]

# _Notes:_

* LF application / testing is slow
* Need to be able to get empirical accuracy estimates quickly
    - Specifically over single LFs, without touching the DB

## Applying Labeling Functions

First we construct a `LabelManager`.

In [31]:
from snorkel.annotations import LabelManager

label_manager = LabelManager()

Next we run the `LabelManager` to apply the labeling functions to the training `CandidateSet`.

In [32]:
%time L = label_manager.create(session, train, 'LF Labels 3', f=LFs)
L


CPU times: user 2min 31s, sys: 14.8 s, total: 2min 46s
Wall time: 2min 35s


<27002x5 sparse matrix of type '<type 'numpy.float64'>'
	with 4837 stored elements in Compressed Sparse Row format>

We can view statistics about the resulting label matrix:

In [33]:
L.lf_stats()

|  | conflicts | coverage | j | overlaps |
|---|---|---|---|---|
| LF_weak_assertions | 0.000444 | 0.119769 | 0 | 0.005 |
| LF_animal | 0.000222 | 0.055996 | 1 | 0.004777 |
| LF_in_CDT_causes | 0.000111 | 0.000111 | 2 | 0.000111 |
| LF_in_CDT_induced | 0.000259 | 0.001185 | 3 | 0.000259 |
| LF_in_CDT_caused_by_2 | 0.000259 | 0.002074 | 4 | 0.000259 |
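For intuition about these columns, the sketch below computes the three statistics on a tiny dense label matrix. The definitions are assumptions, since the tutorial does not spell them out: coverage is the fraction of candidates an LF labels, overlaps the fraction it labels together with at least one other LF, and conflicts the fraction where another LF gives a different non-zero label. The `lf_stats` function here is a hypothetical re-implementation, not the method called above.

```python
def lf_stats(L):
    """Compute coverage, overlaps, and conflicts per labeling function.

    L is a list of rows (one per candidate); each row holds one label
    in {-1, 0, 1} per labeling function.
    """
    n, m = len(L), len(L[0])
    stats = []
    for j in range(m):
        covered = overlap = conflict = 0
        for row in L:
            if row[j] == 0:
                continue  # LF j abstained on this candidate
            covered += 1
            others = [row[k] for k in range(m) if k != j and row[k] != 0]
            if others:
                overlap += 1
            if any(v != row[j] for v in others):
                conflict += 1
        stats.append({'coverage': covered / float(n),
                      'overlaps': overlap / float(n),
                      'conflicts': conflict / float(n)})
    return stats

L_toy = [
    [1,  0,  0],
    [1, -1,  0],
    [0, -1, -1],
    [0,  0,  0],
]
toy_stats = lf_stats(L_toy)
```

Under these definitions conflicts can never exceed overlaps, and overlaps can never exceed coverage, which matches the ordering of the values in the table above.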


In [34]:
len(train)

27002

## Fitting the Generative Model
We estimate the accuracies of the labeling functions without supervision. Specifically, we estimate the parameters of a `NaiveBayes` generative model.

In [35]:
from snorkel.learning import NaiveBayes

gen_model = NaiveBayes()
gen_model.train(L)




Training marginals (!= 0.5):	27002
Features:			5
Begin training for rate=0.01, mu=1e-06
	Learning epoch = 0	Gradient mag. = 0.023791
	Learning epoch = 250	Gradient mag. = 0.025721
	Learning epoch = 500	Gradient mag. = 0.026371
	Learning epoch = 750	Gradient mag. = 0.026994
Final gradient magnitude for rate=0.01, mu=1e-06: 0.028


In [36]:
gen_model.w

array([ 1.17745996,  1.19277661,  0.99321592,  0.98242971,  0.97965408])

In [None]:
gen_model.save(session, 'Generative Params 2')

In [None]:
gen_model.load(session, 'Generative Params 2')
gen_model.w

We now apply the generative model to the training candidates.

In [None]:
train_marginals = gen_model.marginals(L)
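As a rough illustration of what a marginal is, the sketch below turns weighted LF votes into a probability with a sigmoid. This is one simple model form assumed for illustration; it is not a description of how `gen_model.marginals` is implemented, and the weights are made-up values rather than learned ones.

```python
import math

def toy_marginals(L, w):
    """Return P(y = 1) per candidate as the sigmoid of the weighted vote sum."""
    probs = []
    for row in L:
        score = sum(w_j * l_ij for w_j, l_ij in zip(w, row))
        probs.append(1.0 / (1.0 + math.exp(-score)))
    return probs

w_toy = [1.0, 1.0, 2.0]  # illustrative weights, not learned values
L_toy = [
    [0, 0, 0],    # every LF abstains
    [-1, 0, 0],   # one weak negative vote
    [-1, 0, 1],   # a negative vote outweighed by a stronger positive vote
    [1, 1, 1],    # unanimous positive votes
]
p_toy = toy_marginals(L_toy, w_toy)
```

A candidate on which every LF abstains gets probability 0.5, which is consistent with the `Training marginals (!= 0.5)` line in the training log treating 0.5 as "no information".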

## Training the Discriminative Model
We use the estimated probabilities to train a discriminative model that classifies each `Candidate` as a true or false mention.

In [None]:
from snorkel.learning import LogReg

disc_model = LogReg()
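# NOTE: `F` (the feature matrix for the training candidates) is not defined in this part;
# it must be available before this cell is run.
disc_model.train(F, train_marginals, n_iter=1500, rate=1e-5)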
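Training on probabilities rather than hard labels can be sketched as logistic regression with a cross-entropy loss against soft targets. This is a minimal pure-Python illustration under that assumption, not Snorkel's `LogReg`; the feature rows, soft labels, and hyperparameters are made up for the toy problem.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def train_soft_logreg(X, p, n_iter=2000, rate=0.5):
    """Fit weights by gradient descent on cross-entropy against soft labels p."""
    w = [0.0] * len(X[0])
    for _ in range(n_iter):
        grad = [0.0] * len(w)
        for x, p_i in zip(X, p):
            # Gradient of the cross-entropy wrt the logit is (prediction - target).
            err = sigmoid(sum(wj * xj for wj, xj in zip(w, x))) - p_i
            for j, xj in enumerate(x):
                grad[j] += err * xj
        w = [wj - rate * g / len(X) for wj, g in zip(w, grad)]
    return w

# First column is a constant bias feature; the other two are binary features.
X_toy = [[1, 1, 0], [1, 0, 1], [1, 1, 1], [1, 0, 0]]
p_soft = [0.9, 0.2, 0.7, 0.5]  # made-up training marginals

w_fit = train_soft_logreg(X_toy, p_soft)
fitted = [sigmoid(sum(wj * xj for wj, xj in zip(w_fit, x))) for x in X_toy]
```

Because the targets are probabilities, a candidate with a marginal near 0.5 contributes almost no push in either direction, while confident marginals dominate the fit.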

In [None]:
disc_model.w.shape

In [None]:
%time disc_model.save(session, "Discriminative Params 1")

In [None]:
w_prev = disc_model.w
%time disc_model.load(session, "Discriminative Params 1")
np.all(disc_model.w == w_prev)

## Evaluating on the Development `CandidateSet`

In [None]:
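# NOTE: `sorted_test_candidates`, `gold_candidate_set`, `pred`, `test_marginals`,
# and `score` are not defined earlier in this part.
# Gold label is 1 if the candidate is in the gold set, -1 otherwise.
test_labels = np.asarray(
    [1 if candidate in gold_candidate_set else -1 for candidate in sorted_test_candidates]
)

score(sorted_test_candidates, test_labels, pred, gold_candidate_set,
      train_marginals=train_marginals, test_marginals=test_marginals)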

After evaluating on the development `CandidateSet`, the labeling functions can be modified. Try changing the labeling functions to improve performance. You can view the true positives, false positives, true negatives, and false negatives using the `Viewer`.
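When comparing LF revisions, precision, recall, and F1 on the development set are the usual numbers to track. The sketch below computes them from labels in {-1, 1}. `prf1` is a hypothetical helper, the marginals and gold labels are made up, and the 0.5 decision threshold is an assumption; this does not describe what the `score` call above reports.

```python
def prf1(pred, gold):
    """Return (precision, recall, F1) for labels in {-1, 1}."""
    tp = sum(1 for p, g in zip(pred, gold) if p == 1 and g == 1)
    fp = sum(1 for p, g in zip(pred, gold) if p == 1 and g == -1)
    fn = sum(1 for p, g in zip(pred, gold) if p == -1 and g == 1)
    precision = tp / float(tp + fp) if tp + fp else 0.0
    recall = tp / float(tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1

dev_marginals = [0.9, 0.8, 0.3, 0.6, 0.1]   # made-up model outputs
dev_gold = [1, -1, 1, 1, -1]                # made-up gold labels

# Threshold the marginals at 0.5 to obtain hard predictions.
dev_pred = [1 if m > 0.5 else -1 for m in dev_marginals]
precision, recall, f1 = prf1(dev_pred, dev_gold)
```

Listing the false positives and false negatives behind these counts is usually the quickest way to find the next labeling function to write or fix.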

## Saving the Discriminative Model's Parameters
We save the model's parameters for use in Part V.

Next, in Part V, we will test our model on the test `CandidateSet`.