<img align="left" src="imgs/logo.jpg" width="50px" style="margin-right:10px">
# Snorkel Workshop: Extracting Spouse Relations <br> from the News
## Part 4: Training our End Extraction Model

In this final section of the tutorial, we'll use the noisy training labels we generated in the previous part to train our end extraction model.

For this tutorial, we will be training a fairly effective deep learning model. More generally, however, Snorkel plugs into many ML libraries, including [TensorFlow](https://www.tensorflow.org/), making it easy to use almost any state-of-the-art model as the end extractor!

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import numpy as np

# Connect to the database backend and initialize a Snorkel session
from lib.init import *
from snorkel.annotations import load_marginals
from snorkel.models import candidate_subclass

Spouse = candidate_subclass('Spouse', ['person1', 'person2'])

## I. Loading Candidates and Gold Labels


In [3]:
train_cands = session.query(Spouse).filter(Spouse.split == 0).order_by(Spouse.id).all()
dev_cands   = session.query(Spouse).filter(Spouse.split == 1).order_by(Spouse.id).all()
test_cands  = session.query(Spouse).filter(Spouse.split == 2).order_by(Spouse.id).all()

train_marginals = load_marginals(session, split=0)

In [4]:
from snorkel.annotations import load_gold_labels

L_gold_dev  = load_gold_labels(session, annotator_name='gold', split=1, load_as_array=True, zero_one=True)
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2, zero_one=True)

## II. Training a _Long Short-term Memory_ (LSTM) Neural Network

[LSTMs](https://en.wikipedia.org/wiki/Long_short-term_memory) can achieve state-of-the-art performance on many text classification tasks. We'll train a simple LSTM model below.

In deep learning, hyperparameter tuning is a very important and computationally expensive step in training models. For the purposes of this tutorial, we've selected some pre-tuned settings so that you can train a model in under 10 minutes. Advanced users can look at the grid search tutorial under the advanced notebooks section of this workshop.

In [None]:
# from snorkel.learning.disc_models.rnn import reRNN
# from snorkel.learning import RandomSearch, ListParameter, RangeParameter

# batch_size_param  = ListParameter('batch_size', [32, 64, 128, 256])
# rate_param        = RangeParameter('lr', 1e-4, 1e-2, step=1, log_base=10)
# dropout_param     = RangeParameter('dropout', 0.0, 0.5, step=0.25)
# balance_param     = ListParameter('rebalance', [0.0, 0.3, 0.5])
# b_param           = ListParameter('b', [0.5, 0.6, 0.7])
# dim_param         = ListParameter('dim', [100])

# param_grid        = [rate_param, dropout_param, dim_param, batch_size_param, balance_param, b_param]

# np.random.seed(1701)
# searcher = RandomSearch(reRNN, param_grid, train_cands, train_marginals, n=4, n_threads=1)
# %time lstm, run_stats = searcher.fit(dev_cands, L_gold_dev, n_epochs=10, max_sentence_length=100,
#                                      print_freq=5, n_threads=4)

# run_stats
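
The commented-out cell above relies on Snorkel's own search utilities. To make the general idea of a random hyperparameter search concrete, here is a minimal, library-free sketch. The names `search_space`, `sample_configs`, `random_search`, and `toy_score` are hypothetical and not part of Snorkel, the candidate values are only loosely based on the cell above, and `toy_score` is a stand-in for "train a model and measure its dev-set F1".

```python
import random

# Hypothetical search space; values loosely based on the commented-out cell above.
search_space = {
    "lr": [1e-4, 1e-3, 1e-2],
    "dropout": [0.0, 0.25, 0.5],
    "batch_size": [32, 64, 128, 256],
}

def sample_configs(space, n, seed=1701):
    """Draw n random configurations, picking one value per hyperparameter."""
    rng = random.Random(seed)
    return [
        {name: rng.choice(values) for name, values in space.items()}
        for _ in range(n)
    ]

def random_search(space, n, score_fn, seed=1701):
    """Return the sampled configuration that score_fn rates highest."""
    configs = sample_configs(space, n, seed)
    return max(configs, key=score_fn)

def toy_score(config):
    """Placeholder scorer (an assumption): in practice this would train a
    model with `config` and return its score on the development set."""
    return -abs(config["lr"] - 1e-3) - abs(config["dropout"] - 0.5)

best = random_search(search_space, n=4, score_fn=toy_score)
print(best)
```

Sampling only a handful of configurations (here `n=4`) keeps the cost bounded, which matters when each evaluation means training a neural network.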

In [None]:
from snorkel.learning.disc_models.rnn import reRNN

train_kwargs = {
    'lr':         0.001,
    'dim':        100,
    'n_epochs':   20,
    'dropout':    0.5,
    'print_freq': 1,
    'batch_size': 32,
    'max_sentence_length': 100
}

lstm = reRNN(seed=1701, n_threads=4)
lstm.train(train_cands, train_marginals, X_dev=dev_cands, Y_dev=L_gold_dev, **train_kwargs)

Now we compute the precision, recall, and F1 score of the discriminative model on the test set:

In [None]:
p, r, f1 = lstm.score(test_cands, L_gold_test)
print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))
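
For reference, the three metrics can be computed by hand from predicted and gold labels. The sketch below is a plain-Python illustration, not Snorkel's implementation; the function name `precision_recall_f1` and the toy label lists are made up for this example, and labels are assumed to be binary with `1` marking the positive class.

```python
def precision_recall_f1(predicted, gold):
    """Compute precision, recall, and F1 for binary labels (1 = positive)."""
    pairs = list(zip(predicted, gold))
    tp = sum(1 for p, g in pairs if p == 1 and g == 1)  # correct positives
    fp = sum(1 for p, g in pairs if p == 1 and g != 1)  # spurious positives
    fn = sum(1 for p, g in pairs if p != 1 and g == 1)  # missed positives
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return precision, recall, f1

# Toy labels, invented for illustration only.
toy_predicted = [1, 1, 0, 0, 1]
toy_gold      = [1, 0, 0, 1, 1]

toy_p, toy_r, toy_f1 = precision_recall_f1(toy_predicted, toy_gold)
print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(toy_p, toy_r, toy_f1))
# prints "Prec: 0.667, Recall: 0.667, F1 Score: 0.667"
```

F1 is the harmonic mean of precision and recall, so it is only high when both are high.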

We can also get the candidates grouped into sets (true positives, false positives, true negatives, false negatives), along with a more detailed score report:

In [None]:
tp, fp, tn, fn = lstm.error_analysis(session, test_cands, L_gold_test)
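
The four sets correspond to the four possible combinations of predicted and gold label. As a rough illustration of that bucketing (not Snorkel's implementation), the sketch below uses a hypothetical helper `bucket_errors`, string IDs in place of real candidate objects, and assumes binary labels with `1` as the positive class.

```python
def bucket_errors(candidates, predicted, gold):
    """Group candidates into TP, FP, TN, and FN sets by comparing labels."""
    tp, fp, tn, fn = set(), set(), set(), set()
    for cand, p, g in zip(candidates, predicted, gold):
        if p == 1 and g == 1:
            tp.add(cand)   # predicted positive, actually positive
        elif p == 1:
            fp.add(cand)   # predicted positive, actually negative
        elif g == 1:
            fn.add(cand)   # predicted negative, actually positive
        else:
            tn.add(cand)   # predicted negative, actually negative
    return tp, fp, tn, fn

# Toy data, invented for illustration only.
toy_cands     = ["c1", "c2", "c3", "c4", "c5"]
toy_predicted = [1, 1, 0, 0, 1]
toy_gold      = [1, 0, 0, 1, 1]

toy_tp, toy_fp, toy_tn, toy_fn = bucket_errors(toy_cands, toy_predicted, toy_gold)
print(sorted(toy_fp), sorted(toy_fn))
# prints "['c2'] ['c4']"
```

Inspecting the false positive and false negative sets is usually the most informative part, since they show which kinds of candidates the model gets wrong.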

Finally, let's save our model for later use. 

In [None]:
lstm.save("spouse.lstm")