# Snorkel Workshop: Extracting Spouse Relations from the News

## Part 4: Training our End Extraction Model

In this final part of the tutorial, we'll use the noisy training labels generated in the previous part to train our end extraction model.

For this tutorial, we will train a fairly effective LSTM-based recurrent neural network, Snorkel's `reRNN` model. More generally, however, Snorkel integrates with many ML libraries, including [TensorFlow](https://www.tensorflow.org/), making it easy to use almost any state-of-the-art model as the end extractor!

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os

# Connect to the database backend and initialize a Snorkel session
from lib.init import *

We repeat our definition of the `Spouse` `Candidate` subclass:

In [2]:
from snorkel.models import candidate_subclass

Spouse = candidate_subclass('Spouse', ['person1', 'person2'])

We use the training marginals to train a discriminative model that classifies each `Candidate` as a true or false mention. Choosing hyperparameters, for example with a random hyperparameter search, requires labels for a development set so that each configuration can be scored. Here we train with a single fixed configuration and report the development-set F1 score after every epoch. If development labels aren't already available, we can create them manually using the Viewer.
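
A random hyperparameter search of the kind mentioned above can be sketched in plain Python. This is a generic illustration, not Snorkel's tuning API: `SEARCH_SPACE`, `sample_configs`, `random_search`, and `toy_evaluate` are hypothetical names, the value ranges are assumptions, and in practice the evaluation function would train a model with the sampled settings and return its development-set F1.

```python
import random

# Hypothetical search space; the keys mirror some of the train_kwargs used
# for reRNN in this notebook, but the candidate values are illustrative only.
SEARCH_SPACE = {
    'lr':      [1e-2, 1e-3, 1e-4],
    'dim':     [50, 100, 200],
    'dropout': [0.25, 0.5],
}

def sample_configs(space, n, seed=0):
    """Draw n configurations, picking one value per hyperparameter uniformly at random."""
    rng = random.Random(seed)  # seeded so the search is reproducible
    return [{name: rng.choice(values) for name, values in space.items()}
            for _ in range(n)]

def random_search(space, n, evaluate, seed=0):
    """Evaluate n sampled configurations and return (best_config, best_score)."""
    best_cfg, best_score = None, float('-inf')
    for cfg in sample_configs(space, n, seed):
        score = evaluate(cfg)  # in practice: train a model, return its dev F1
        if score > best_score:
            best_cfg, best_score = cfg, score
    return best_cfg, best_score

# Toy stand-in for "train and score on the dev set" (hypothetical).
def toy_evaluate(cfg):
    return cfg['dim'] / 1000 - abs(cfg['lr'] - 1e-3) - abs(cfg['dropout'] - 0.5)

best_cfg, best_score = random_search(SEARCH_SPACE, n=10, evaluate=toy_evaluate)
print(best_cfg, round(best_score, 4))
```

Random search spends a fixed budget of `n` training runs instead of trying every combination, which matters when each run takes minutes, as the training log for this model shows.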

In [3]:
from snorkel.annotations import load_marginals

train_marginals = load_marginals(session, split=0)
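
Because `train_marginals` holds probabilities rather than hard 0/1 labels, the end model is trained against soft targets. One standard way to do this is to minimize the expected cross-entropy with respect to the marginals, often called a noise-aware loss. The sketch below applies that idea to plain logistic regression in pure Python; it is an illustration under that assumption, not Snorkel's implementation, and the feature vectors and marginals are toy values.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def predict(w, x):
    """P(y = 1 | x) under logistic regression with weights w."""
    return sigmoid(sum(wj * xj for wj, xj in zip(w, x)))

def noise_aware_loss(w, X, marginals):
    """Mean cross-entropy against soft labels p_i = P(y_i = 1), not hard labels."""
    total = 0.0
    for x, p in zip(X, marginals):
        q = predict(w, x)
        total -= p * math.log(q) + (1 - p) * math.log(1 - q)
    return total / len(X)

def train_logreg(X, marginals, lr=0.5, n_epochs=200):
    """Full-batch gradient descent; the per-example loss gradient is (q - p) * x."""
    w = [0.0] * len(X[0])
    for _ in range(n_epochs):
        grad = [0.0] * len(w)
        for x, p in zip(X, marginals):
            err = predict(w, x) - p
            for j, xj in enumerate(x):
                grad[j] += err * xj / len(X)
        w = [wj - lr * gj for wj, gj in zip(w, grad)]
    return w

# Toy data (hypothetical): the first feature is a constant bias term.
X = [[1.0, 2.0], [1.0, 1.0], [1.0, -1.0], [1.0, -2.0]]
marginals = [0.9, 0.7, 0.3, 0.1]

w = train_logreg(X, marginals)
print(noise_aware_loss([0.0, 0.0], X, marginals), noise_aware_loss(w, X, marginals))
```

With soft targets, a candidate whose marginal is 0.7 pulls the model's prediction toward 0.7 rather than all the way to 1, so uncertain training labels carry less weight than confident ones.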

In [4]:
train_cands = session.query(Spouse).filter(Spouse.split == 0).order_by(Spouse.id).all()
dev_cands   = session.query(Spouse).filter(Spouse.split == 1).order_by(Spouse.id).all()
test_cands  = session.query(Spouse).filter(Spouse.split == 2).order_by(Spouse.id).all()

In [5]:
from snorkel.annotations import load_gold_labels

L_gold_dev  = load_gold_labels(session, annotator_name='gold', split=1, load_as_array=True, zero_one=True)
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2, zero_one=True)

In [6]:
from snorkel.learning.disc_models.rnn import reRNN

train_kwargs = {
    'lr':         0.001,
    'dim':        100,
    'n_epochs':   20,
    'dropout':    0.5,
    'print_freq': 1,
    'max_sentence_length': 100
}

lstm = reRNN(seed=1701, n_threads=None)
lstm.train(train_cands, train_marginals, X_dev=dev_cands, Y_dev=L_gold_dev, **train_kwargs)

[reRNN] Training model
[reRNN] n_train=17217  #epochs=20  batch size=256
[reRNN] Epoch 0 (40.25s)	Average loss=0.582597	Dev F1=13.55
[reRNN] Epoch 1 (84.93s)	Average loss=0.542717	Dev F1=33.33
[reRNN] Epoch 2 (132.68s)	Average loss=0.537067	Dev F1=35.32
[reRNN] Epoch 3 (178.95s)	Average loss=0.536432	Dev F1=37.55
[reRNN] Epoch 4 (223.45s)	Average loss=0.535999	Dev F1=34.89
[reRNN] Epoch 5 (268.81s)	Average loss=0.535783	Dev F1=37.27
[reRNN] Epoch 6 (315.81s)	Average loss=0.535106	Dev F1=37.78
[reRNN] Epoch 7 (360.98s)	Average loss=0.535300	Dev F1=37.94
[reRNN] Epoch 8 (407.18s)	Average loss=0.535028	Dev F1=38.80
[reRNN] Epoch 9 (454.38s)	Average loss=0.535006	Dev F1=39.16
[reRNN] Epoch 10 (500.35s)	Average loss=0.534756	Dev F1=39.64
[reRNN] Epoch 11 (549.25s)	Average loss=0.534703	Dev F1=40.47
[reRNN] Epoch 12 (597.67s)	Average loss=0.534757	Dev F1=39.57
[reRNN] Epoch 13 (651.31s)	Average loss=0.534573	Dev F1=40.76
[reRNN] Model saved as <reRNN>
[reRNN] Epoch 14 (702.29s)	Average loss=
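
The development-set F1 in this log climbs quickly and then fluctuates, which is why a held-out dev set is useful for model selection. One common rule is to keep the epoch with the best dev score and, optionally, stop once `patience` consecutive epochs fail to improve on it. The sketch below applies that generic rule to the dev F1 values printed above (epochs 0 to 13; the epoch 14 line is cut off). `select_epoch` and `patience` are hypothetical names, and this is not a description of how `reRNN` decides when to save.

```python
# Dev F1 (%) for epochs 0-13, copied from the training log above.
DEV_F1 = [13.55, 33.33, 35.32, 37.55, 34.89, 37.27, 37.78,
          37.94, 38.80, 39.16, 39.64, 40.47, 39.57, 40.76]

def select_epoch(scores, patience=None):
    """Return (best_epoch, best_score, last_epoch_run).

    With patience=None every epoch is run; otherwise training stops once
    `patience` consecutive epochs fail to beat the best score so far.
    """
    best_epoch, best = 0, scores[0]
    stale = 0
    for epoch, s in enumerate(scores[1:], start=1):
        if s > best:
            best_epoch, best, stale = epoch, s, 0
        else:
            stale += 1
            if patience is not None and stale >= patience:
                return best_epoch, best, epoch
    return best_epoch, best, len(scores) - 1

print(select_epoch(DEV_F1))              # → (13, 40.76, 13)
print(select_epoch(DEV_F1, patience=2))  # → (3, 37.55, 5)
print(select_epoch(DEV_F1, patience=3))  # → (13, 40.76, 13)
```

With `patience=2`, training would have stopped at epoch 5 and kept epoch 3 (dev F1 37.55), well below the 40.76 reached at epoch 13, so a small patience can cut training short on a noisy curve like this one.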

Now, we compute the precision, recall, and F1 score of the discriminative model on the test set:

In [7]:
p, r, f1 = lstm.score(test_cands, L_gold_test)
print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))

Prec: 0.384, Recall: 0.606, F1 Score: 0.470
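
Since the discriminative model outputs a probability for each candidate, those probabilities must be turned into hard predictions before they can be compared with the gold labels. A minimal sketch of that step, assuming a 0.5 decision threshold and 0/1 gold labels (which is what `zero_one=True` appears to request); `confusion_counts` is a hypothetical helper, not part of Snorkel, and the inputs are toy values.

```python
def confusion_counts(marginals, gold, threshold=0.5):
    """Threshold P(y = 1) into hard 0/1 predictions and tally (TP, FP, TN, FN)."""
    tp = fp = tn = fn = 0
    for p, y in zip(marginals, gold):
        pred = 1 if p > threshold else 0
        if pred == 1 and y == 1:
            tp += 1
        elif pred == 1:
            fp += 1  # predicted a spouse pair, gold says no
        elif y == 1:
            fn += 1  # missed a true spouse pair
        else:
            tn += 1
    return tp, fp, tn, fn

toy_marginals = [0.9, 0.8, 0.6, 0.4, 0.2]
toy_gold = [1, 0, 1, 1, 0]
print(confusion_counts(toy_marginals, toy_gold))                 # → (2, 1, 1, 1)
print(confusion_counts(toy_marginals, toy_gold, threshold=0.7))  # → (1, 1, 1, 2)
```

Raising the threshold converts borderline positives into negatives, which generally trades recall for precision; with 212 false positives on the test set, this is one knob worth examining.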


We can also retrieve the test candidates grouped into sets (true positives, false positives, true negatives, and false negatives), along with a more detailed score report:

In [8]:
tp, fp, tn, fn = lstm.error_analysis(session, test_cands, L_gold_test)

Scores (Un-adjusted)
Pos. class accuracy: 0.606
Neg. class accuracy: 0.915
Precision            0.384
Recall               0.606
F1                   0.47
----------------------------------------
TP: 132 | FP: 212 | TN: 2271 | FN: 86
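
Every number in this report follows from the four counts on its last line, so recomputing them with the standard definitions is a quick sanity check. Note that "Pos. class accuracy" is the same quantity as recall, and "Neg. class accuracy" is TN / (TN + FP). `score_report` is a hypothetical helper written only for this check.

```python
def score_report(tp, fp, tn, fn):
    """Recompute the error-analysis scores from raw confusion counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return {
        'pos_class_accuracy': tp / (tp + fn),  # identical to recall
        'neg_class_accuracy': tn / (tn + fp),
        'precision': precision,
        'recall': recall,
        'f1': 2 * precision * recall / (precision + recall),
    }

# Counts from the test-set report above.
report = score_report(tp=132, fp=212, tn=2271, fn=86)
for name, value in report.items():
    print(f"{name:<20} {value:.3f}")
# prints 0.606, 0.915, 0.384, 0.606, 0.470 (one metric per line)
```

The low precision (212 false positives against 132 true positives) is the main drag on F1 here, which suggests starting error analysis with the `fp` set.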



Finally, let's save our model for later use. 

In [11]:
lstm.save("spouse.lstm")

[reRNN] Model saved as <best.lstm>
