# Intro. to Snorkel: Extracting Spouse Relations from the News

## Part V: Training our End Extraction Model

In this final part of the tutorial, we'll use the noisy training labels generated in the previous part to train our end extraction model.

For this tutorial, we will train a simple but fairly effective logistic regression model; the cells below also experiment with `reRNN`, an RNN-based model from `snorkel.contrib.rnn`. More generally, Snorkel plugs into many ML libraries, including [TensorFlow](https://www.tensorflow.org/), making it easy to use almost any state-of-the-art model as the end extractor!

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os

# TO USE A DATABASE OTHER THAN SQLITE, USE THIS LINE
# Note that this is necessary for parallel execution amongst other things...
# os.environ['SNORKELDB'] = 'postgres:///snorkel-intro'

from snorkel import SnorkelSession
session = SnorkelSession()

We repeat our definition of the `Spouse` `Candidate` subclass:

In [None]:
from snorkel.models import candidate_subclass

Spouse = candidate_subclass('Spouse', ['person1', 'person2'])

## 2. Training the Discriminative Model
We use the training marginals to train a discriminative model that classifies each `Candidate` as a true or false mention. We'll use a random hyperparameter search, evaluated on the development set labels, to find the best hyperparameters for our model. To run a hyperparameter search, we need labels for a development set. If they aren't already available, we can manually create labels using the Viewer.
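
To make "training on marginals" concrete, here is a minimal pure-Python sketch of one common approach: minimizing the expected log loss against probabilistic labels. It is an illustration under our own assumptions, not Snorkel's implementation; the function names and toy data are made up for this example.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def expected_log_loss(weights, features, marginals):
    """Mean cross-entropy between predictions and probabilistic labels."""
    total = 0.0
    for x, m in zip(features, marginals):
        p = sigmoid(sum(w * xi for w, xi in zip(weights, x)))
        total -= m * math.log(p) + (1.0 - m) * math.log(1.0 - p)
    return total / len(features)

def train_on_marginals(features, marginals, lr=0.5, n_epochs=200):
    """Plain gradient descent on the expected log loss."""
    weights = [0.0] * len(features[0])
    for _ in range(n_epochs):
        grad = [0.0] * len(weights)
        for x, m in zip(features, marginals):
            p = sigmoid(sum(w * xi for w, xi in zip(weights, x)))
            for j, xi in enumerate(x):
                grad[j] += (p - m) * xi / len(features)
        weights = [w - lr * g for w, g in zip(weights, grad)]
    return weights

# Toy data (hypothetical): each row is [bias, feature];
# marginals are soft labels in [0, 1] rather than hard 0/1 labels.
features = [[1.0, -2.0], [1.0, -1.0], [1.0, 0.5], [1.0, 1.5], [1.0, 2.5]]
marginals = [0.05, 0.20, 0.60, 0.85, 0.95]

weights = train_on_marginals(features, marginals)
loss_before = expected_log_loss([0.0, 0.0], features, marginals)
loss_after = expected_log_loss(weights, features, marginals)
print("Loss before: {:.4f}, after: {:.4f}".format(loss_before, loss_after))
```

The only difference from ordinary logistic regression is that the target `m` is a probability, so a candidate the labeling functions are unsure about contributes a weaker training signal.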

In [22]:
from snorkel.annotations import load_marginals

train_marginals = load_marginals(session, split=0)

In [None]:
train_cands = session.query(Spouse).filter(Spouse.split == 0).order_by(Spouse.id).all()
dev_cands   = session.query(Spouse).filter(Spouse.split == 1).order_by(Spouse.id).all()
test_cands  = session.query(Spouse).filter(Spouse.split == 2).order_by(Spouse.id).all()

In [None]:
from snorkel.annotations import load_gold_labels

L_gold_dev  = load_gold_labels(session, annotator_name='gold', split=1, load_as_array=True, zero_one=True)
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2, zero_one=True)

In [None]:
len(train_cands)

In [37]:
from snorkel.contrib.rnn import reRNN

train_kwargs = {
    'lr':         0.001,
    'dim':        100,
    'n_epochs':   50,
    'dropout':    0.5,
    # 'rebalance':  0.5,
    'print_freq': 1,
    'max_sentence_length': 100
}

lstm = reRNN(seed=1701, n_threads=None)
lstm.train(train_cands, train_marginals, dev_candidates=dev_cands, dev_labels=L_gold_dev, **train_kwargs)

[reRNN] Dimension=100  LR=0.001
[reRNN] Begin preprocessing
[reRNN] Loaded 2754 candidates for evaluation
[reRNN] Preprocessing done (6.09s)
[reRNN] Training model
[reRNN] #examples=17272  #epochs=50  batch size=256
[reRNN] Epoch 0 (56.82s)	Average loss=0.558467	Dev F1=17.92
[reRNN] Epoch 1 (109.72s)	Average loss=0.484166	Dev F1=37.27
[reRNN] Epoch 2 (162.73s)	Average loss=0.473554	Dev F1=36.04
[reRNN] Epoch 3 (215.70s)	Average loss=0.471496	Dev F1=37.39
[reRNN] Epoch 4 (268.65s)	Average loss=0.470444	Dev F1=39.29
[reRNN] Epoch 5 (321.58s)	Average loss=0.469278	Dev F1=40.25
[reRNN] Epoch 6 (374.47s)	Average loss=0.468272	Dev F1=39.02
[reRNN] Epoch 7 (427.36s)	Average loss=0.467444	Dev F1=39.51
[reRNN] Epoch 8 (480.40s)	Average loss=0.466709	Dev F1=42.46
[reRNN] Epoch 9 (533.63s)	Average loss=0.466255	Dev F1=42.13
[reRNN] Epoch 10 (586.59s)	Average loss=0.466110	Dev F1=43.94
[reRNN] Epoch 11 (639.59s)	Average loss=0.465791	Dev F1=43.16
[reRNN] Epoch 12 (692.50s)	Average loss=0.465056	De

In [41]:
tp, fp, tn, fn = lstm.score(session, test_cands, L_gold_test, b=.7)

Scores (Un-adjusted)
Pos. class accuracy: 0.461
Neg. class accuracy: 0.958
Precision            0.464
Recall               0.461
F1                   0.462
----------------------------------------
TP: 89 | FP: 103 | TN: 2368 | FN: 104
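
The summary figures in this report can be recomputed from the four confusion counts on its last line. The helper below is an illustrative sketch; `scores_from_counts` is a name introduced here, not a Snorkel function.

```python
def scores_from_counts(tp, fp, tn, fn):
    """Return precision, recall, F1 and negative-class accuracy."""
    precision = tp / (tp + fp)   # fraction of predicted positives that are correct
    recall = tp / (tp + fn)      # fraction of gold positives found (pos. class accuracy)
    f1 = 2 * precision * recall / (precision + recall)
    neg_accuracy = tn / (tn + fp)  # fraction of gold negatives correctly rejected
    return precision, recall, f1, neg_accuracy

# Counts copied from the report above
precision, recall, f1, neg_accuracy = scores_from_counts(tp=89, fp=103, tn=2368, fn=104)
print("Precision: {:.3f}  Recall: {:.3f}  F1: {:.3f}  Neg. class accuracy: {:.3f}".format(
    precision, recall, f1, neg_accuracy))
```

Rounded to three decimals, these values match the report: precision 0.464, recall 0.461, F1 0.462, and negative-class accuracy 0.958.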



In [49]:
from snorkel.contrib.rnn import reRNN

train_kwargs = {
    'lr':         0.0001,
    'dim':        100,
    'n_epochs':   50,
    'dropout':    0.5,
    # 'rebalance':  0.5,
    'print_freq': 1,
    'max_sentence_length': 100
}

lstm = reRNN(seed=1701, n_threads=3)
lstm.train(train_cands, train_marginals, dev_candidates=dev_cands, dev_labels=L_gold_dev, **train_kwargs)

[reRNN] Dimension=100  LR=0.0001
[reRNN] Begin preprocessing
[reRNN] Loaded 2754 candidates for evaluation
[reRNN] Preprocessing done (9.82s)
[reRNN] Training model
[reRNN] #examples=17272  #epochs=50  batch size=256
[reRNN] Epoch 0 (99.73s)	Average loss=0.647129	Dev F1=0.00
[reRNN] Epoch 1 (189.94s)	Average loss=0.562570	Dev F1=0.00
[reRNN] Epoch 2 (279.77s)	Average loss=0.548148	Dev F1=0.00
[reRNN] Epoch 3 (487.53s)	Average loss=0.528396	Dev F1=0.00
[reRNN] Epoch 4 (1167.04s)	Average loss=0.508069	Dev F1=13.43
[reRNN] Epoch 5 (11983.05s)	Average loss=0.492076	Dev F1=29.55
[reRNN] Epoch 6 (12076.61s)	Average loss=0.483190	Dev F1=31.96
[reRNN] Epoch 7 (12166.91s)	Average loss=0.478926	Dev F1=33.33
[reRNN] Epoch 8 (12256.24s)	Average loss=0.476748	Dev F1=33.18
[reRNN] Epoch 9 (12345.06s)	Average loss=0.475008	Dev F1=33.18
[reRNN] Epoch 10 (12436.70s)	Average loss=0.473900	Dev F1=33.85
[reRNN] Epoch 11 (12529.63s)	Average loss=0.472954	Dev F1=34.07
[reRNN] Epoch 12 (12624.44s)	Average lo

In [None]:
from snorkel.contrib.rnn import reRNN

train_kwargs = {
    'lr':         0.005,
    'dim':        100,
    'n_epochs':   10,
    'dropout':    0.5,
    'rebalance':  0.25,
    'print_freq': 5
}

lstm = reRNN(seed=1701, n_threads=None)
lstm.train(train_cands, train_marginals, dev_candidates=dev_cands, dev_labels=L_gold_dev, **train_kwargs)

Now we set up and run the hyperparameter search, training our model with different hyperparameters and keeping the best model configuration. We'll set the random seed to maintain reproducibility.

Note that we fit our model's parameters to the training set generated by our labeling functions, while we pick hyperparameters according to the score on the development set labels, which we created by hand.

In [None]:
from snorkel.learning.utils import (
    MentionScorer, RandomSearch, ListParameter, RangeParameter
)

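# Search over the learning rate and the L1/L2 penalties on a log10 scale
rate_param = RangeParameter('lr', 1e-6, 1e-2, step=1, log_base=10)
l1_param   = RangeParameter('l1_penalty', 1e-6, 1e-2, step=1, log_base=10)
l2_param   = RangeParameter('l2_penalty', 1e-6, 1e-2, step=1, log_base=10)

# `disc_model` and `F_train` are not defined in this notebook; they are assumed
# to be the discriminative model and training feature matrix from an earlier step.
searcher = RandomSearch(session, disc_model, F_train, train_marginals, [rate_param, l1_param, l2_param], n=20)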
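
To illustrate the idea behind a random search over log-spaced ranges, here is a small standalone sketch that does not use Snorkel. The helpers `log_grid`, `random_search`, and `toy_dev_score` are hypothetical names for this example, and whether Snorkel's `RangeParameter` includes the upper endpoint is an assumption we make here.

```python
import math
import random

def log_grid(low, high, log_base=10):
    """Powers of log_base from low to high, endpoints included (an assumption)."""
    lo = int(round(math.log(low, log_base)))
    hi = int(round(math.log(high, log_base)))
    return [float(log_base) ** e for e in range(lo, hi + 1)]

def random_search(grids, score_fn, n, seed):
    """Sample n random configurations and keep the best-scoring one."""
    rng = random.Random(seed)
    best_config, best_score = None, float('-inf')
    for _ in range(n):
        config = {name: rng.choice(values) for name, values in grids.items()}
        score = score_fn(config)
        if score > best_score:
            best_config, best_score = config, score
    return best_config, best_score

grids = {
    'lr':         log_grid(1e-6, 1e-2),
    'l1_penalty': log_grid(1e-6, 1e-2),
    'l2_penalty': log_grid(1e-6, 1e-2),
}

def toy_dev_score(config):
    # Hypothetical stand-in for "train on the training marginals,
    # then score against the hand-labeled development set".
    return -abs(math.log10(config['lr']) + 3.0)

best_config, best_score = random_search(grids, toy_dev_score, n=20, seed=1701)
print(best_config, best_score)
```

With five values per parameter there are 125 possible configurations, so sampling 20 of them covers only part of the grid; this is the usual trade-off of random search against an exhaustive grid search.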

Next, we'll load in our dev set labels. We will pick the optimal result from the hyperparameter search by testing against these labels:

In [None]:
from snorkel.annotations import load_gold_labels
L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)

Finally, we run the hyperparameter search / train the end extraction model:

In [None]:
import numpy as np

np.random.seed(1701)
searcher.fit(F_dev, L_gold_dev, n_epochs=50, rebalance=0.5, print_freq=25)

_Note that to train a model without tuning any hyperparameters (at your own risk), just use the `train` method of the discriminative model. For instance, to train for 20 epochs with a learning rate of 0.001, you could run:_
```
disc_model.train(F_train, train_marginals, n_epochs=20, lr=0.001)
```

We can analyze the learned model by examining the weights. For example, we can print out the features with the highest weights.

In [None]:
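# Weights of the trained model (the second return value is discarded)
w, _ = disc_model.get_weights()

# Indices of the five weights with the largest magnitude, largest first
largest_idxs = reversed(np.argsort(np.abs(w))[-5:])
for i in largest_idxs:
    print('Feature: {0: <70}Weight: {1:.6f}'.format(F_train.get_key(session, i).name, w[i]))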
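
The ranking step can be tried without a trained model. The sketch below is a plain-Python equivalent of the `np.argsort(np.abs(w))` idiom, run on made-up weights and hypothetical feature names:

```python
# Made-up weights; the sign shows the direction of a feature's influence,
# the magnitude shows how strongly the model relies on it.
weights = [0.10, -2.50, 0.03, 1.70, -0.40, 0.90, -1.20]
feature_names = ['feat_{}'.format(i) for i in range(len(weights))]  # hypothetical names

k = 3
# Sort indices by absolute weight, largest first, and keep the top k
top_idxs = sorted(range(len(weights)), key=lambda i: abs(weights[i]), reverse=True)[:k]

for i in top_idxs:
    print('Feature: {0: <10}Weight: {1:.6f}'.format(feature_names[i], weights[i]))
```

Ranking by absolute value matters here: `feat_1` and `feat_6` have negative weights, so sorting on the raw values would push them to the bottom even though they are among the most influential features.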

## 3. Evaluating on the Test Set

In this last section of the tutorial, we'll get the score we've been after: the performance of the extraction model on the blind test set (`split` 2). First, we load the test set labels and gold candidates we made in Part III.

In [None]:
from snorkel.annotations import load_gold_labels
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2)

Now, we get the precision, recall, and F1 score from the discriminative model:

In [None]:
p, r, f1 = disc_model.score(F_test, L_gold_test)
print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))

We can also get the candidates returned in sets (true positives, false positives, true negatives, false negatives) as well as a more detailed score report:

In [None]:
tp, fp, tn, fn = disc_model.error_analysis(session, F_test, L_gold_test)

Note that if this is the final test set on which you will report final numbers, you should not inspect individual results, to avoid biasing them. However, you can run the model on your _development set_ and, as we did in the previous part with the generative labeling function model, inspect examples to do error analysis.

##### More importantly, you completed the introduction to Snorkel! Give yourself a pat on the back!