# Chemical-Disease Relation (CDR) Tutorial

In this example, we'll write an application to extract *mentions of* **chemical-induced-disease relationships** from PubMed abstracts, as per the [BioCreative CDR Challenge](http://www.biocreative.org/resources/corpora/biocreative-v-cdr-corpus/). This tutorial shows off some of the more advanced features of Snorkel, so we'll assume you've followed the Intro tutorial.

Let's start by reloading from the last notebook.

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from snorkel import SnorkelSession

session = SnorkelSession()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from snorkel.models import candidate_subclass

ChemicalDisease = candidate_subclass('ChemicalDisease', ['chemical', 'disease'])

train = session.query(ChemicalDisease).filter(ChemicalDisease.split == 0).all()
dev = session.query(ChemicalDisease).filter(ChemicalDisease.split == 1).all()
test = session.query(ChemicalDisease).filter(ChemicalDisease.split == 2).all()

print('Training set:\t{0} candidates'.format(len(train)))
print('Dev set:\t{0} candidates'.format(len(dev)))
print('Test set:\t{0} candidates'.format(len(test)))

Training set:	8272 candidates
Dev set:	888 candidates
Test set:	4620 candidates


In [3]:
from snorkel.annotations import load_marginals
train_marginals = load_marginals(session, split=0)


In [4]:
from snorkel.annotations import FeatureAnnotator
featurizer = FeatureAnnotator()

In [5]:
%time F_train = featurizer.apply(split=0, parallelism=1)

Clearing existing...
Running UDF...
CPU times: user 8min 8s, sys: 5.86 s, total: 8min 14s
Wall time: 8min 13s


In [6]:
from snorkel.annotations import load_gold_labels
L_gold_train = load_gold_labels(session, annotator_name='gold', split=0)
L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)

In [7]:
from snorkel.learning import SparseLogisticRegression
disc_model = SparseLogisticRegression()

In [8]:
from snorkel.learning.utils import MentionScorer
from snorkel.learning import RandomSearch, ListParameter, RangeParameter

# Search over the learning rate and the L1/L2 penalties (powers of 10 from 1e-6 to 1e-2)
rate_param = RangeParameter('lr', 1e-6, 1e-2, step=1, log_base=10)
l1_param  = RangeParameter('l1_penalty', 1e-6, 1e-2, step=1, log_base=10)
l2_param  = RangeParameter('l2_penalty', 1e-6, 1e-2, step=1, log_base=10)

searcher = RandomSearch(session, disc_model, F_train, train_marginals, [rate_param, l1_param, l2_param], n=5)

Initialized RandomSearch search of size 5. Search space size = 125.
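The reported search space size can be reproduced by hand. The sketch below assumes that each `RangeParameter` expands to the five powers of ten from 1e-6 to 1e-2 (consistent with the size of 125 printed above) and that the search draws `n` configurations at random without replacement. The helper `log_range` and the variable names are hypothetical; this is not Snorkel's implementation.

```python
import itertools
import random

def log_range(low_exp, high_exp, base=10.0):
    """Powers of `base` from base**low_exp to base**high_exp, inclusive."""
    return [base ** e for e in range(low_exp, high_exp + 1)]

# Assumed expansion of RangeParameter(name, 1e-6, 1e-2, step=1, log_base=10)
grid = {
    'lr': log_range(-6, -2),
    'l1_penalty': log_range(-6, -2),
    'l2_penalty': log_range(-6, -2),
}

# Full Cartesian product of the three parameter grids
names = sorted(grid)
search_space = [
    dict(zip(names, values))
    for values in itertools.product(*(grid[name] for name in names))
]

# Random search evaluates only a small sample of the full grid
rng = random.Random(0)  # fixed seed so the sketch is repeatable
sampled = rng.sample(search_space, 5)

print('Search space size = {0}'.format(len(search_space)))
for config in sampled:
    print(config)
```

Sampling 5 of 125 configurations keeps the search cheap, at the cost of possibly missing the best region of the grid.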


In [9]:
%time F_dev = featurizer.apply_existing(split=1, parallelism=1)

Clearing existing...
Running UDF...
CPU times: user 48 s, sys: 476 ms, total: 48.5 s
Wall time: 48.3 s


In [10]:
%time searcher.fit(F_dev, L_gold_dev, n_epochs=50, rebalance=True, print_freq=25)

[1] Testing lr = 1.00e-05, l1_penalty = 1.00e-02, l2_penalty = 1.00e-02
[SparseLR] lr=1e-05 l1=0.01 l2=0.01
[SparseLR] Building model
[SparseLR] Training model
[SparseLR] #examples=3198  #epochs=50  batch size=100
[SparseLR] Epoch 0 (2.44s)	Avg. loss=0.791743	NNZ=122840
[SparseLR] Epoch 25 (61.33s)	Avg. loss=0.717843	NNZ=122840
[SparseLR] Epoch 49 (118.88s)	Avg. loss=0.678038	NNZ=122840
[SparseLR] Training done (118.88s)
[SparseLR] Model saved. To load, use name
		SparseLR_0
[2] Testing lr = 1.00e-03, l1_penalty = 1.00e-03, l2_penalty = 1.00e-02
[SparseLR] lr=0.001 l1=0.001 l2=0.01
[SparseLR] Building model
[SparseLR] Training model
[SparseLR] #examples=3198  #epochs=50  batch size=100
[SparseLR] Epoch 0 (2.50s)	Avg. loss=0.660500	NNZ=122840
[SparseLR] Epoch 25 (63.37s)	Avg. loss=0.361979	NNZ=122840
[SparseLR] Epoch 49 (122.16s)	Avg. loss=0.343909	NNZ=122840
[SparseLR] Training done (122.16s)
[SparseLR] Model saved. To load, use name
		SparseLR_1
[3] Testing lr = 1.00e-02, l1_penalty =

| Run | lr | l1_penalty | l2_penalty | Prec. | Rec. | F1 |
|---|---|---|---|---|---|---|
| 3 | 0.01 | 0.001 | 0.001 | 0.552632 | 0.496622 | 0.523132 |
| 1 | 0.001 | 0.001 | 0.01 | 0.607143 | 0.402027 | 0.48374 |
| 4 | 1e-06 | 0.0001 | 1e-05 | 0.330508 | 0.527027 | 0.40625 |
| 2 | 0.01 | 0.0001 | 0.001 | 0.643939 | 0.287162 | 0.397196 |
| 0 | 1e-05 | 0.01 | 0.01 | 0.714286 | 0.168919 | 0.273224 |
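The table above is sorted by F1 on the development set. As a minimal sketch of how a configuration could be selected from such results, the snippet below copies the five rows by hand (the first field is the index from the first column) and picks the one with the highest dev F1. The tuple layout and variable names are assumptions made for illustration.

```python
# (index, lr, l1_penalty, l2_penalty, precision, recall, f1), copied from the table above
results = [
    (3, 0.01,  0.001,  0.001, 0.552632, 0.496622, 0.523132),
    (1, 0.001, 0.001,  0.01,  0.607143, 0.402027, 0.48374),
    (4, 1e-06, 0.0001, 1e-05, 0.330508, 0.527027, 0.40625),
    (2, 0.01,  0.0001, 0.001, 0.643939, 0.287162, 0.397196),
    (0, 1e-05, 0.01,   0.01,  0.714286, 0.168919, 0.273224),
]

# Select the configuration with the highest F1 on the development set
best = max(results, key=lambda row: row[6])
best_params = {'lr': best[1], 'l1_penalty': best[2], 'l2_penalty': best[3]}

print('Best index: {0}, dev F1: {1}'.format(best[0], best[6]))
print(best_params)
```

Note that the row with the highest precision (index 0) has the lowest F1, because its recall is very low.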


In [11]:
disc_model.train(F_train, train_marginals, n_epochs=50, lr=0.01, l1_penalty=0.000001, l2_penalty=0.01, rebalance=True)

[SparseLR] lr=0.01 l1=1e-06 l2=0.01
[SparseLR] Building model
[SparseLR] Training model
[SparseLR] #examples=3198  #epochs=50  batch size=100
[SparseLR] Epoch 0 (2.46s)	Avg. loss=0.558532	NNZ=122840
[SparseLR] Epoch 5 (14.45s)	Avg. loss=0.338575	NNZ=122840
[SparseLR] Epoch 10 (26.83s)	Avg. loss=0.349029	NNZ=122840
[SparseLR] Epoch 15 (38.95s)	Avg. loss=0.418208	NNZ=122840
[SparseLR] Epoch 20 (51.07s)	Avg. loss=0.357322	NNZ=122840
[SparseLR] Epoch 25 (63.30s)	Avg. loss=0.357986	NNZ=122840
[SparseLR] Epoch 30 (75.55s)	Avg. loss=0.357713	NNZ=122840
[SparseLR] Epoch 35 (87.96s)	Avg. loss=0.359545	NNZ=122840
[SparseLR] Epoch 40 (100.36s)	Avg. loss=0.359664	NNZ=122840
[SparseLR] Epoch 45 (113.09s)	Avg. loss=0.373870	NNZ=122840
[SparseLR] Epoch 49 (123.02s)	Avg. loss=0.366724	NNZ=122840
[SparseLR] Training done (123.02s)


In [12]:
TP, FP, TN, FN = disc_model.score(session, F_dev, L_gold_dev)

Scores (Un-adjusted)
Pos. class accuracy: 0.507
Neg. class accuracy: 0.796
Precision            0.554
Recall               0.507
F1                   0.529
----------------------------------------
TP: 150 | FP: 121 | TN: 471 | FN: 146



In [13]:
from snorkel.learning.utils import scores_from_counts
p, r, f1 = scores_from_counts(TP, FP, TN, FN)
print(p)
print(r)
print(f1)

0.553505535055
0.506756756757
0.529100529101
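These three numbers follow directly from the counts reported by `score`. Here is a minimal sketch of the arithmetic, using a hypothetical helper `precision_recall_f1` that takes integer counts (it is not Snorkel's `scores_from_counts`, whose exact behavior is not shown in this notebook):

```python
def precision_recall_f1(tp, fp, fn):
    """Precision, recall and F1 from integer counts; 0.0 where undefined."""
    precision = tp / float(tp + fp) if tp + fp > 0 else 0.0
    recall = tp / float(tp + fn) if tp + fn > 0 else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    return precision, recall, f1

# Dev-set counts reported above: TP: 150 | FP: 121 | FN: 146
p, r, f1 = precision_recall_f1(tp=150, fp=121, fn=146)
print('{0:.3f} {1:.3f} {2:.3f}'.format(p, r, f1))  # prints "0.554 0.507 0.529"
```

True negatives do not enter any of the three scores, which is why they can look reasonable even when the classes are imbalanced.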



# Part V: Training an extraction model

In the intro tutorial, we automatically featurized the candidates and trained a linear model over these features. Here, we'll train a more complex model for relation extraction: an LSTM network. You can read more about LSTMs [here](https://en.wikipedia.org/wiki/Long_short-term_memory) or [here](http://colah.github.io/posts/2015-08-Understanding-LSTMs/). An LSTM is a type of recurrent neural network that automatically learns a numerical representation of each candidate from the sentence text, so there is no need to featurize explicitly as in the intro tutorial. LSTMs take longer to train, and Snorkel doesn't currently support hyperparameter searches for them. We'll train a single model here, but feel free to try out other parameter sets. Just make sure to use the development set, and not the test set, for model selection.
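To make the idea of a learned numerical representation concrete, here is a minimal sketch of a standard LSTM cell in plain Python. It reads a sequence of token vectors one at a time and returns the final hidden state as the representation of the sequence. The weights are random and untrained, and all function names, gate names, and dimensions are assumptions chosen for illustration; this is not Snorkel's `reLSTM` implementation.

```python
import math
import random

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def affine(W, U, b, x, h):
    """Return W*x + U*h + b for matrices stored as lists of rows."""
    return [
        sum(w * xi for w, xi in zip(w_row, x))
        + sum(u * hi for u, hi in zip(u_row, h))
        + bias
        for w_row, u_row, bias in zip(W, U, b)
    ]

def lstm_step(x, h_prev, c_prev, params):
    """One LSTM time step; params maps gate name -> (W, U, b)."""
    pre = {gate: affine(W, U, b, x, h_prev)
           for gate, (W, U, b) in params.items()}
    i = [sigmoid(v) for v in pre['input']]        # how much new information to write
    f = [sigmoid(v) for v in pre['forget']]       # how much old memory to keep
    o = [sigmoid(v) for v in pre['output']]       # how much memory to expose
    g = [math.tanh(v) for v in pre['candidate']]  # candidate memory content
    c = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c_prev, i, g)]
    h = [ok * math.tanh(ck) for ok, ck in zip(o, c)]
    return h, c

def encode(token_vectors, params, hidden_dim):
    """Run the cell over a sequence and return the final hidden state."""
    h = [0.0] * hidden_dim
    c = [0.0] * hidden_dim
    for x in token_vectors:
        h, c = lstm_step(x, h, c, params)
    return h

def random_matrix(rng, rows, cols, scale=0.1):
    return [[rng.uniform(-scale, scale) for _ in range(cols)]
            for _ in range(rows)]

rng = random.Random(0)
input_dim, hidden_dim = 8, 4  # toy sizes; the model in this tutorial uses dim=200
params = {
    gate: (random_matrix(rng, hidden_dim, input_dim),
           random_matrix(rng, hidden_dim, hidden_dim),
           [0.0] * hidden_dim)
    for gate in ('input', 'forget', 'output', 'candidate')
}

# A toy "sentence" of 6 tokens, each a random vector standing in for a word embedding
sentence = [[rng.uniform(-1, 1) for _ in range(input_dim)] for _ in range(6)]
representation = encode(sentence, params, hidden_dim)
print(len(representation))  # one number per hidden unit
```

In a trained model, the weights are fit so that this final vector is useful for classifying the candidate; a classification layer on top of it would produce the predicted probability.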

In [14]:
from snorkel.contrib.learning import reLSTM

lstm = reLSTM()
lstm.train(
    train, train_marginals, lr=0.005, dim=200, n_epochs=30,
    dropout_rate=0.5, rebalance=0.25, print_freq=5
)

[reLSTM] Dimension=200  LR=0.005
[reLSTM] Begin preprocessing
[reLSTM] Preprocessing done (25.01s)
[reLSTM] Training model
[reLSTM] #examples=6396  #epochs=30  batch size=100
[reLSTM] Epoch 0 (80.67s)	Average loss=0.689927
[reLSTM] Epoch 5 (468.00s)	Average loss=0.507944
[reLSTM] Epoch 10 (863.39s)	Average loss=0.477988
[reLSTM] Epoch 15 (1254.26s)	Average loss=0.471424
[reLSTM] Epoch 20 (1658.61s)	Average loss=0.468855
[reLSTM] Epoch 25 (2185.58s)	Average loss=0.467803
[reLSTM] Epoch 29 (2608.96s)	Average loss=0.466509
[reLSTM] Training done (2608.96s)


### Scoring on the test set

Finally, we'll evaluate our performance on the blind test set of 500 documents. We'll load the labels the same way we did for the development set, then use the `score` method of our extraction model to see how we did.
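As a rough sketch of what a scoring step like this computes, the snippet below thresholds predicted probabilities and compares them against gold labels to obtain the four confusion counts and the per-class accuracies. It assumes gold labels encoded as +1/-1 and a decision threshold of 0.5; the function name and toy data are hypothetical, and this is not the code behind `score`.

```python
def confusion_counts(marginals, gold, threshold=0.5):
    """Count TP, FP, TN, FN for probabilities against +1/-1 gold labels."""
    tp = fp = tn = fn = 0
    for prob, label in zip(marginals, gold):
        predicted_positive = prob > threshold
        if predicted_positive and label == 1:
            tp += 1
        elif predicted_positive and label == -1:
            fp += 1
        elif not predicted_positive and label == -1:
            tn += 1
        else:
            fn += 1
    return tp, fp, tn, fn

# Toy predictions and gold labels, made up for illustration
marginals = [0.9, 0.7, 0.2, 0.4, 0.6, 0.1]
gold = [1, -1, -1, 1, 1, -1]

tp, fp, tn, fn = confusion_counts(marginals, gold)
pos_class_accuracy = tp / float(tp + fn)  # same quantity as recall
neg_class_accuracy = tn / float(tn + fp)

print('TP: {0} | FP: {1} | TN: {2} | FN: {3}'.format(tp, fp, tn, fn))
print('Pos. class accuracy: {0:.3f}'.format(pos_class_accuracy))
print('Neg. class accuracy: {0:.3f}'.format(neg_class_accuracy))
```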

In [15]:
from load_external_annotations import load_external_labels
load_external_labels(session, ChemicalDisease, split=2, annotator='gold')

AnnotatorLabels created: 4620


In [16]:
from snorkel.annotations import load_gold_labels
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2)
L_gold_test

<4620x1 sparse matrix of type '<type 'numpy.float64'>'
	with 4620 stored elements in Compressed Sparse Row format>

In [17]:
_, _, _, _ = lstm.score(session, test, L_gold_test)

Scores (Un-adjusted)
Pos. class accuracy: 0.308
Neg. class accuracy: 0.907
Precision            0.615
Recall               0.308
F1                   0.411
----------------------------------------
TP: 464 | FP: 291 | TN: 2824 | FN: 1041

