<img align="left" src="imgs/logo.jpg" width="50px" style="margin-right:10px">
# Snorkel Workshop: Extracting Spouse Relations <br> from the News
## Part 4: Training our End Extraction Model

In this final section of the tutorial, we'll use the noisy training labels we generated in the previous part to train our end extraction model.

For this tutorial, we will be training a fairly effective deep learning model. More generally, however, Snorkel plugs into many ML libraries, including [TensorFlow](https://www.tensorflow.org/), making it easy to use almost any state-of-the-art model as the end extractor!

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import numpy as np

# Connect to the database backend and initialize a Snorkel session
from lib.init import *
from snorkel.annotations import load_marginals
from snorkel.models import candidate_subclass

Spouse = candidate_subclass('Spouse', ['person1', 'person2'])

## I. Loading Candidates and Gold Labels


In [2]:
from snorkel.annotations import load_gold_labels

train_cands = session.query(Spouse).filter(Spouse.split == 0).order_by(Spouse.id).all()
dev_cands   = session.query(Spouse).filter(Spouse.split == 1).order_by(Spouse.id).all()
test_cands  = session.query(Spouse).filter(Spouse.split == 2).order_by(Spouse.id).all()

L_gold_dev  = load_gold_labels(session, annotator_name='gold', split=1, load_as_array=True, zero_one=True)
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2, zero_one=True)

train_marginals = load_marginals(session, split=0)
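
The `train_marginals` loaded above are probabilities rather than hard 0/1 labels, so the end model is trained against soft targets. The sketch below shows one common way to do that: sigmoid cross-entropy computed against probabilistic labels. The helper name `noise_aware_loss` and the toy numbers are assumptions made for illustration; this is not necessarily the exact loss that `reRNN` uses internally.

```python
import math

def noise_aware_loss(logits, marginals):
    """Mean sigmoid cross-entropy against probabilistic (soft) labels.

    Hypothetical helper for illustration; not part of the Snorkel API.
    """
    total = 0.0
    for x, p in zip(logits, marginals):
        # Numerically stable form of
        #   -p * log(sigmoid(x)) - (1 - p) * log(1 - sigmoid(x))
        total += max(x, 0.0) - x * p + math.log1p(math.exp(-abs(x)))
    return total / len(logits)

# Toy values (made up): raw model scores and the marginals they are scored against
toy_logits = [2.0, -1.0, 0.5]
toy_marginals = [0.9, 0.2, 0.5]

loss = noise_aware_loss(toy_logits, toy_marginals)
print("Noise-aware loss: {0:.4f}".format(loss))
```

With marginals of exactly 0 or 1 this reduces to ordinary binary cross-entropy, so hard labels are just a special case of the same loss.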

## II. Training a _Long Short-term Memory_ (LSTM) Neural Network

[LSTMs](https://en.wikipedia.org/wiki/Long_short-term_memory) can achieve state-of-the-art performance on many text classification tasks. We'll train a simple LSTM model below.

In deep learning, hyperparameter tuning is a very important and computationally expensive step in training models. For the purposes of this tutorial, we've selected some pre-tuned settings so that you can train a model in under 10 minutes. Advanced users can look at our [grid search tutorial](https://github.com/HazyResearch/snorkel/blob/master/tutorials/advanced/Hyperparameter_Search.ipynb).

**Parameter Definitions**
    
    n_epochs    Number of epochs to train for. One epoch is a single pass
                through all the data in your training set
    dim         Vector embedding (i.e., learned representation) dimension
    lr          The learning rate by which we update model weights after
                computing the gradient
    dropout     Dropout rate [0.0 - 1.0]. Dropout is a neural network
                regularization technique
    print_freq  Print updates every k epochs
    batch_size  Estimate the gradient using k samples. Larger batch sizes run faster,
                but may perform worse
    max_sentence_length
                The max length of an input sequence. Setting this too large
                can slow your training down substantially
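
To make the relationship between `n_epochs` and `batch_size` concrete, the short sketch below counts how many gradient updates a training run performs. It uses the values from this tutorial (`n_train=5356` is the figure reported in the training log) and assumes that the final, smaller batch of each epoch is still used rather than dropped, which may not match what `reRNN` does internally.

```python
import math

n_train = 5356     # training set size, as reported in the training log
batch_size = 128   # samples used for each gradient estimate
n_epochs = 20      # full passes over the training set

# Assumption: the last partial batch of each epoch is kept, so we round up.
batches_per_epoch = int(math.ceil(n_train / float(batch_size)))
total_updates = batches_per_epoch * n_epochs

print("Updates per epoch: {0}".format(batches_per_epoch))
print("Total updates:     {0}".format(total_updates))
```

Halving `batch_size` would roughly double the number of updates per epoch, which is one reason smaller batches take longer to train.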
                
### Please Note
With the hyperparameters provided below, your model should train in about 9.5 minutes. Don't change them!

In [17]:
from snorkel.learning.disc_models.rnn import reRNN

train_kwargs = {
    'lr':         0.001,
    'dim':        100,
    'n_epochs':   20,
    'dropout':    0.25,
    'print_freq': 1,
    'batch_size': 128,
    'max_sentence_length': 100
}

lstm = reRNN(seed=1701, n_threads=7)
lstm.train(train_cands, train_marginals, X_dev=dev_cands, Y_dev=L_gold_dev, **train_kwargs)

[reRNN] Training model
[reRNN] n_train=5356  #epochs=20  batch size=128
[reRNN] Epoch 0 (12.89s)	Average loss=0.678806	Dev F1=17.20
[reRNN] Epoch 1 (29.11s)	Average loss=0.663771	Dev F1=17.26
[reRNN] Epoch 2 (45.69s)	Average loss=0.660834	Dev F1=17.18
[reRNN] Epoch 3 (62.46s)	Average loss=0.659845	Dev F1=18.39
[reRNN] Epoch 4 (80.11s)	Average loss=0.659381	Dev F1=17.16
[reRNN] Epoch 5 (98.66s)	Average loss=0.658705	Dev F1=17.26
[reRNN] Epoch 6 (118.06s)	Average loss=0.658434	Dev F1=17.49
[reRNN] Epoch 7 (139.20s)	Average loss=0.657881	Dev F1=16.49
[reRNN] Epoch 8 (161.58s)	Average loss=0.657722	Dev F1=17.99
[reRNN] Epoch 9 (185.26s)	Average loss=0.657645	Dev F1=17.98
[reRNN] Epoch 10 (207.66s)	Average loss=0.657276	Dev F1=16.79
[reRNN] Epoch 11 (230.90s)	Average loss=0.657069	Dev F1=16.64
[reRNN] Epoch 12 (254.99s)	Average loss=0.657084	Dev F1=16.05
[reRNN] Epoch 13 (279.11s)	Average loss=0.656984	Dev F1=16.55
[reRNN] Epoch 14 (299.94s)	Average loss=0.656826	Dev F1=16.49
[reRNN] Epoch 

Now, we get the precision, recall, and F1 score from the discriminative model:

In [18]:
p, r, f1 = lstm.score(test_cands, L_gold_test)
print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))

Prec: 0.117, Recall: 0.798, F1 Score: 0.205


We can also get the candidates returned in sets (true positives, false positives, true negatives, false negatives) as well as a more detailed score report:

In [19]:
tp, fp, tn, fn = lstm.error_analysis(session, test_cands, L_gold_test)

Scores (Un-adjusted)
Pos. class accuracy: 0.798
Neg. class accuracy: 0.473
Precision            0.117
Recall               0.798
F1                   0.205
----------------------------------------
TP: 174 | FP: 1308 | TN: 1175 | FN: 44
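
The figures in this report follow directly from the four counts on its last line. The sketch below recomputes them in plain Python as a sanity check. The helper `prf1` is a hypothetical name used only for illustration, and the counts are copied from the report as integers (in the notebook itself, `tp`, `fp`, `tn`, and `fn` are sets of candidates, not numbers).

```python
def prf1(n_tp, n_fp, n_fn):
    """Precision, recall, and F1 from raw counts (hypothetical helper)."""
    precision = n_tp / float(n_tp + n_fp) if (n_tp + n_fp) else 0.0
    recall = n_tp / float(n_tp + n_fn) if (n_tp + n_fn) else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# Counts copied from the report above
n_tp, n_fp, n_tn, n_fn = 174, 1308, 1175, 44

prec, rec, f1_score = prf1(n_tp, n_fp, n_fn)
neg_acc = n_tn / float(n_tn + n_fp)   # accuracy on the negative class

print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(prec, rec, f1_score))
# prints: Prec: 0.117, Recall: 0.798, F1 Score: 0.205
print("Neg. class accuracy: {0:.3f}".format(neg_acc))
# prints: Neg. class accuracy: 0.473
```

The high recall and low precision both come from the large number of false positives: the model labels well over half of the test candidates as spouses.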



Finally, let's save our model for later use. 

In [20]:
lstm.save("spouse.lstm")

[reRNN] Model saved as <spouse.lstm>
