# Intro. to Snorkel: Extracting Spouse Relations from the News

## Part III: Training our End Extraction Model

In this final part of the tutorial, we use the noisy training labels generated in Part II to train our end extraction model.

For this tutorial, we train an LSTM-based recurrent neural network, Snorkel's `reRNN` model. More generally, Snorkel works with many ML libraries, including [TensorFlow](https://www.tensorflow.org/), making it easy to use almost any state-of-the-art model as the end extractor!

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os

# TO USE A DATABASE OTHER THAN SQLITE, USE THIS LINE
# Note that this is necessary for parallel execution amongst other things...
# os.environ['SNORKELDB'] = 'postgres:///snorkel-intro'

from snorkel import SnorkelSession
session = SnorkelSession()

We repeat our definition of the `Spouse` `Candidate` subclass:

In [2]:
from snorkel.models import candidate_subclass

Spouse = candidate_subclass('Spouse', ['person1', 'person2'])

We use the training marginals (the probabilistic labels produced by the generative model in Part II) to train a discriminative model that classifies each `Candidate` as a true or false mention. Tuning such a model, for example with a random hyperparameter search, requires labels for a development set; if they aren't already available, we can create them manually using the Viewer. Here we train with a single fixed set of hyperparameters and use the development set labels to monitor performance during training and keep the best-scoring checkpoint.
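A random hyperparameter search simply samples configurations, trains one model per configuration, and keeps the one with the best development-set score. The sketch below shows that loop in plain Python. It is not Snorkel's own search utility: `random_search`, the grid, and `toy_dev_f1` are all hypothetical, and in practice `evaluate` would train a model with the given settings and return its dev F1.

```python
import random
from typing import Any, Callable, Dict, List, Optional, Tuple

Params = Dict[str, Any]


def random_search(
    param_grid: Dict[str, List[Any]],
    evaluate: Callable[[Params], float],
    n_trials: int = 10,
    seed: int = 1701,
) -> Tuple[Optional[Params], float, List[Tuple[Params, float]]]:
    """Sample n_trials configurations from param_grid, score each one with
    `evaluate` (e.g. dev-set F1 of a model trained with those settings),
    and return the best configuration, its score, and the full history."""
    rng = random.Random(seed)  # seeded so the search is reproducible
    history: List[Tuple[Params, float]] = []
    best_params: Optional[Params] = None
    best_score = float("-inf")
    for _ in range(n_trials):
        # Draw each hyperparameter independently and uniformly from its list.
        params = {name: rng.choice(values) for name, values in param_grid.items()}
        score = evaluate(params)
        history.append((params, score))
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score, history


# Hypothetical grid; the keys mirror a few of the reRNN training arguments.
grid = {"lr": [0.01, 0.001, 0.0001], "dim": [50, 100], "dropout": [0.25, 0.5]}


def toy_dev_f1(params: Params) -> float:
    # Stand-in for "train with these settings, return dev F1"; purely illustrative.
    return 1.0 - abs(params["lr"] - 0.001) * 10 + params["dim"] / 1000 + params["dropout"] / 10


best_params, best_score, history = random_search(grid, toy_dev_f1, n_trials=20)
print(best_params, round(best_score, 4))
```

Sampling at random rather than enumerating the full grid keeps the cost fixed at `n_trials` training runs, which matters when each run takes minutes, as the training below does.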

In [3]:
from snorkel.annotations import load_marginals

train_marginals = load_marginals(session, split=0)
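Because `train_marginals` are probabilities rather than hard 0/1 labels, one common way to use them is as soft targets: minimize the cross-entropy between the model's predicted probability and each marginal. The sketch below illustrates that idea with plain logistic regression in pure Python. It is an illustration of the general technique under that assumption, not Snorkel's `reRNN` code; the function names and toy data are hypothetical.

```python
import math
from typing import Sequence


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def predict_proba(x: Sequence[float], w: Sequence[float], b: float) -> float:
    return sigmoid(sum(wj * xj for wj, xj in zip(w, x)) + b)


def noise_aware_loss(X, marginals, w, b) -> float:
    """Mean cross-entropy against soft labels p_i = P(y_i = 1)."""
    eps = 1e-12  # guards against log(0)
    total = 0.0
    for x, p in zip(X, marginals):
        q = predict_proba(x, w, b)
        total -= p * math.log(q + eps) + (1.0 - p) * math.log(1.0 - q + eps)
    return total / len(X)


def train_noise_aware_logreg(X, marginals, lr=0.5, n_epochs=500):
    """Full-batch gradient descent on noise_aware_loss."""
    n, d = len(X), len(X[0])
    w, b = [0.0] * d, 0.0
    for _ in range(n_epochs):
        grad_w, grad_b = [0.0] * d, 0.0
        for x, p in zip(X, marginals):
            # Gradient of the soft-label cross-entropy w.r.t. the logit is (q - p).
            err = predict_proba(x, w, b) - p
            for j in range(d):
                grad_w[j] += err * x[j]
            grad_b += err
        w = [wj - lr * gj / n for wj, gj in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


# Toy data: one feature, marginals that grow with the feature value.
X = [[-2.0], [-1.0], [1.0], [2.0]]
marginals = [0.1, 0.3, 0.7, 0.9]
w, b = train_noise_aware_logreg(X, marginals)
print(w, b, noise_aware_loss(X, marginals, w, b))
```

With hard labels the targets would be exactly 0 or 1; with marginals, an uncertain example (say p = 0.6) pulls the model only weakly toward the positive class.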

In [4]:
train_cands = session.query(Spouse).filter(Spouse.split == 0).order_by(Spouse.id).all()
dev_cands   = session.query(Spouse).filter(Spouse.split == 1).order_by(Spouse.id).all()
test_cands  = session.query(Spouse).filter(Spouse.split == 2).order_by(Spouse.id).all()

In [5]:
from snorkel.annotations import load_gold_labels

L_gold_dev  = load_gold_labels(session, annotator_name='gold', split=1, load_as_array=True, zero_one=True)
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2, zero_one=True)

In [6]:
from snorkel.learning.disc_models.rnn import reRNN

train_kwargs = {
    'lr':         0.001,
    'dim':        100,
    'n_epochs':   10,
    'dropout':    0.5,
    'print_freq': 1,
    'max_sentence_length': 100
}

lstm = reRNN(seed=1701, n_threads=None)
lstm.train(train_cands, train_marginals, X_dev=dev_cands, Y_dev=L_gold_dev, **train_kwargs)

[reRNN] Training model
[reRNN] n_train=17337  #epochs=10  batch size=256
[reRNN] Epoch 0 (56.22s)	Average loss=0.579711	Dev F1=8.23
[reRNN] Epoch 1 (111.87s)	Average loss=0.540559	Dev F1=31.23
[reRNN] Epoch 2 (170.00s)	Average loss=0.536027	Dev F1=36.19
[reRNN] Epoch 3 (227.51s)	Average loss=0.535496	Dev F1=37.56
[reRNN] Epoch 4 (282.67s)	Average loss=0.534999	Dev F1=37.14
[reRNN] Epoch 5 (338.99s)	Average loss=0.534686	Dev F1=38.61
[reRNN] Epoch 6 (400.96s)	Average loss=0.534425	Dev F1=38.30
[reRNN] Epoch 7 (464.35s)	Average loss=0.534329	Dev F1=38.68
[reRNN] Model saved as <reRNN>
[reRNN] Epoch 8 (530.23s)	Average loss=0.534315	Dev F1=40.00
[reRNN] Model saved as <reRNN>
[reRNN] Epoch 9 (596.07s)	Average loss=0.534055	Dev F1=41.85
[reRNN] Model saved as <reRNN>
[reRNN] Training done (601.11s)
INFO:tensorflow:Restoring parameters from checkpoints/reRNN/reRNN-9
[reRNN] Loaded model <reRNN>
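The log shows the model being saved only when dev F1 reaches a new high, with the last saved checkpoint (epoch 9) restored at the end. Dev F1 also improved in several earlier epochs without a save, which suggests checkpointing only starts partway through training. The sketch below reproduces that selection rule from the logged values; `checkpoint_epochs` is hypothetical, and `min_epoch` is an assumed stand-in for the delay, not a documented `reRNN` argument.

```python
from typing import List, Optional, Sequence, Tuple


def checkpoint_epochs(
    dev_scores: Sequence[float], min_epoch: int = 0
) -> Tuple[List[int], Optional[int], float]:
    """Return the epochs at which a checkpoint would be saved (each one a new
    best dev score at or after min_epoch), plus the best epoch and its score."""
    saved: List[int] = []
    best_epoch: Optional[int] = None
    best_score = float("-inf")
    for epoch, score in enumerate(dev_scores):
        if epoch < min_epoch:
            continue  # too early: do not consider this epoch for checkpointing
        if score > best_score:
            best_epoch, best_score = epoch, score
            saved.append(epoch)
    return saved, best_epoch, best_score


# Dev F1 per epoch, copied from the training log above.
dev_f1 = [8.23, 31.23, 36.19, 37.56, 37.14, 38.61, 38.30, 38.68, 40.00, 41.85]

print(checkpoint_epochs(dev_f1, min_epoch=7))  # ([7, 8, 9], 9, 41.85)
print(checkpoint_epochs(dev_f1))               # ([0, 1, 2, 3, 5, 7, 8, 9], 9, 41.85)
```

Selecting on dev F1 rather than training loss matters here: the average loss barely moves after epoch 3, while dev F1 keeps climbing.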


Now we compute the precision, recall, and F1 score of the discriminative model on the test set:

In [9]:
p, r, f1 = lstm.score(test_cands, L_gold_test)
print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))

Prec: 0.318, Recall: 0.487, F1 Score: 0.384
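These three numbers follow directly from the confusion counts that `error_analysis` reports for the same test set (TP: 94, FP: 202, FN: 99): precision is TP/(TP+FP), recall is TP/(TP+FN), and F1 is their harmonic mean, 2PR/(P+R) = 2TP/(2TP+FP+FN). A minimal sketch of the arithmetic, using a hypothetical helper `prf1`:

```python
def prf1(tp: int, fp: int, fn: int):
    """Precision, recall and F1 from confusion counts (0.0 when undefined)."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


p, r, f1 = prf1(tp=94, fp=202, fn=99)
print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))
# Prec: 0.318, Recall: 0.487, F1 Score: 0.384
```

Note that true negatives do not enter any of these scores, which is why a large TN count (2269 here) does not inflate F1.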


We can also get the candidates grouped into sets (true positives, false positives, true negatives, false negatives), along with a more detailed score report:

In [8]:
tp, fp, tn, fn = lstm.error_analysis(session, test_cands, L_gold_test)

Scores (Un-adjusted)
Pos. class accuracy: 0.487
Neg. class accuracy: 0.918
Precision            0.318
Recall               0.487
F1                   0.384
----------------------------------------
TP: 94 | FP: 202 | TN: 2269 | FN: 99
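`error_analysis` groups the candidates themselves by outcome. The idea can be sketched with plain Python sets, assuming predictions and gold labels are both 0/1 (the gold labels here were loaded with `zero_one=True`). The helper names and toy data are hypothetical; this is not Snorkel's implementation. The class accuracies in the report are simple ratios: positive-class accuracy is TP/(TP+FN), which equals recall, and negative-class accuracy is TN/(TN+FP).

```python
from typing import Hashable, Sequence, Set, Tuple


def partition_by_outcome(
    ids: Sequence[Hashable], predicted: Sequence[int], gold: Sequence[int]
) -> Tuple[Set[Hashable], Set[Hashable], Set[Hashable], Set[Hashable]]:
    """Split candidate ids into (TP, FP, TN, FN) sets; labels are 0/1."""
    tp, fp, tn, fn = set(), set(), set(), set()
    for cid, pred, true in zip(ids, predicted, gold):
        if pred == 1:
            (tp if true == 1 else fp).add(cid)
        else:
            (fn if true == 1 else tn).add(cid)
    return tp, fp, tn, fn


def class_accuracies(n_tp: int, n_fp: int, n_tn: int, n_fn: int) -> Tuple[float, float]:
    """Positive-class accuracy (= recall) and negative-class accuracy."""
    return n_tp / (n_tp + n_fn), n_tn / (n_tn + n_fp)


# Toy example with six hypothetical candidate ids.
tp, fp, tn, fn = partition_by_outcome(
    ids=[0, 1, 2, 3, 4, 5], predicted=[1, 1, 0, 0, 1, 0], gold=[1, 0, 0, 1, 1, 0]
)
print(tp, fp, tn, fn)  # {0, 4} {1} {2, 5} {3}

# Counts from the report above reproduce its class accuracies.
pos_acc, neg_acc = class_accuracies(94, 202, 2269, 99)
print("{0:.3f} {1:.3f}".format(pos_acc, neg_acc))  # 0.487 0.918
```

Keeping the sets (rather than just counts) is what makes error analysis possible: you can pull individual false positives or false negatives and read their sentences.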



Note that if this is the test set on which you will report final numbers, you should not inspect its individual results, since doing so would bias them. Instead, run the model on your _development set_ and, as we did in the previous part with the generative labeling function model, inspect examples there to do error analysis.

##### Congratulations, you have completed the introduction to Snorkel! Give yourself a pat on the back!