<img align="left" src="imgs/logo.jpg" width="50px" style="margin-right:10px">
# Snorkel Workshop: Extracting Spouse Relations <br> from the News
## Part 4: Training our End Extraction Model

In this final section of the tutorial, we'll use the noisy training labels we generated in the previous part to train our end extraction model.

For this tutorial, we will be training a fairly effective deep learning model. More generally, however, Snorkel integrates with many ML libraries, including [TensorFlow](https://www.tensorflow.org/), making it easy to use almost any state-of-the-art model as the end extractor!
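The main difference from ordinary supervised learning is that the end model trains on probabilistic labels (marginals) rather than hard 0/1 labels. One standard way to do this, in the spirit of Snorkel's noise-aware training, is to minimize the expected cross-entropy against each marginal. The plain-Python sketch below illustrates that loss. It is not Snorkel's implementation, and the function name `noise_aware_log_loss` is hypothetical.

```python
import math

def noise_aware_log_loss(marginals, pred_probs, eps=1e-12):
    """Mean binary cross-entropy against probabilistic (soft) labels.

    marginals:  P(y = 1) for each candidate, e.g. from the generative model
    pred_probs: the end model's predicted P(y = 1) for the same candidates
    """
    total = 0.0
    for y, p in zip(marginals, pred_probs):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        # Expected log loss when the true label is 1 with probability y
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    return total / len(marginals)

# A model that always predicts 0.5 scores ln(2), whatever the marginals are
print(round(noise_aware_log_loss([0.9, 0.2, 0.6], [0.5, 0.5, 0.5]), 4))  # 0.6931
```

Under this loss, ln 2 ≈ 0.693 is the value for a maximally uncertain model. Assuming the LSTM reports a comparable cross-entropy, that is a useful reference point when reading the `Average loss` column in the training log.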

In [2]:
%load_ext autoreload
%autoreload 
%matplotlib inline
import os
import numpy as np

# Connect to the database backend and initialize a Snorkel session
from lib.init import *
from snorkel.annotations import load_marginals
from snorkel.models import candidate_subclass

Spouse = candidate_subclass('Spouse', ['person1', 'person2'])



## I. Loading Candidates and Gold Labels


In [25]:
from snorkel.annotations import load_gold_labels

train_cands = session.query(Spouse).filter(Spouse.split == 0).order_by(Spouse.id).all()
dev_cands   = session.query(Spouse).filter(Spouse.split == 1).order_by(Spouse.id).all()
test_cands  = session.query(Spouse).filter(Spouse.split == 2).order_by(Spouse.id).all()

L_gold_train = load_gold_labels(session, annotator_name='gold', split=0, zero_one=True)
L_gold_dev  = load_gold_labels(session, annotator_name='gold', split=1, load_as_array=True, zero_one=True)
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2, zero_one=True)

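# Probabilistic training labels (marginals) generated in the previous part; required by lstm.train below
train_marginals = load_marginals(session, split=0)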

## II. Training a _Long Short-term Memory_ (LSTM) Neural Network

[LSTMs](https://en.wikipedia.org/wiki/Long_short-term_memory) can achieve state-of-the-art performance on many text classification tasks. We'll train a simple LSTM model below.

In deep learning, hyperparameter tuning is a very important and computationally expensive step in training models. For the purposes of this tutorial, we've pre-selected some settings so that you can train a model in under 10 minutes. Advanced users can look at our [Grid Search Tutorial](https://github.com/HazyResearch/snorkel/blob/master/tutorials/advanced/Hyperparameter_Search.ipynb) for more details on choosing these parameters.
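As a rough illustration of what hyperparameter search involves, the sketch below samples settings from a small grid and keeps the one with the best dev-set score. It is a generic random search, not the procedure from the linked Grid Search Tutorial. `search_space`, `sample_configs`, `pick_best`, and `toy_score` are all hypothetical. In practice `score_fn` would train a model with the given settings and return its dev F1.

```python
import itertools
import random

# Hypothetical search space over some of the parameters described in the table
search_space = {
    'lr':      [1e-2, 1e-3, 1e-4],
    'dim':     [50, 100],
    'dropout': [0.0, 0.25, 0.5],
}

def sample_configs(space, n, seed=0):
    """Draw up to n distinct configurations from the full grid."""
    keys = list(space)
    grid = [dict(zip(keys, values))
            for values in itertools.product(*(space[k] for k in keys))]
    rng = random.Random(seed)
    return rng.sample(grid, min(n, len(grid)))

def pick_best(configs, score_fn):
    """Return (best_score, best_config) according to score_fn, e.g. dev F1."""
    scored = [(score_fn(cfg), cfg) for cfg in configs]
    return max(scored, key=lambda pair: pair[0])

# Toy stand-in for "train with cfg, return dev F1"; replace with real training
def toy_score(cfg):
    return -abs(cfg['lr'] - 1e-3) - abs(cfg['dropout'] - 0.25)

configs = sample_configs(search_space, n=18)
print(pick_best(configs, toy_score))
```

Random sampling from the grid lets you cap the number of expensive training runs (`n`) while still covering the space more evenly than tuning one parameter at a time.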

| Parameter           | Definition                                                                                              |
|---------------------|---------------------------------------------------------------------------------------------------------|
| n_epochs            | Number of epochs, i.e., full passes through the training set                                            |
| dim                 | Dimension of the vector embeddings (i.e., learned representations)                                      |
| lr                  | Learning rate: the step size used to update model weights after computing the gradient                  |
| dropout             | Dropout rate, a neural network regularization technique; a probability in [0.0, 1.0]                     |
| print_freq          | Print progress updates every k epochs                                                                   |
| batch_size          | Number of samples k used to estimate each gradient. Larger batches run faster but may perform worse     |
| max_sentence_length | Maximum length of an input sequence. Setting this too large can slow training down substantially        |
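To make `batch_size`, `n_epochs`, and `max_sentence_length` concrete: with the `n_train=3115` training candidates reported in the training log below and `batch_size=128`, each epoch takes ceil(3115 / 128) = 25 gradient updates (the last batch holds only 43 examples), so 10 epochs make 250 updates. The sketch below shows that arithmetic plus one simple way to truncate or pad token sequences to a fixed length. The helper names and the `<PAD>` token are illustrative assumptions, not Snorkel's internals.

```python
import math

def updates_per_epoch(n_train, batch_size):
    """Number of minibatches (gradient steps) needed to cover the training set once."""
    return math.ceil(n_train / batch_size)

def minibatches(items, batch_size):
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

def fit_length(tokens, max_len, pad='<PAD>'):
    """Truncate a token list to max_len, or right-pad it with a pad token."""
    return tokens[:max_len] + [pad] * max(0, max_len - len(tokens))

n_train, batch_size, n_epochs = 3115, 128, 10
steps = updates_per_epoch(n_train, batch_size)
print(steps, steps * n_epochs)  # 25 250
print([len(b) for b in minibatches(list(range(n_train)), batch_size)][-1])  # 43
print(fit_length(['Barack', 'and', 'Michelle'], 5))
```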
                
### Please Note
With the hyperparameters provided below, your model should train in under 10 minutes (the run shown below finished in under 3 minutes; timing depends on your hardware).

In [3]:
from snorkel.learning.pytorch.rnn import LSTM

train_kwargs = {
    'lr':         0.001,
    'dim':        100,
    'n_epochs':   10,
    'dropout':    0.25,
    'print_freq': 1,
    'batch_size': 128,
    'max_sentence_length': 100
}

lstm = LSTM(n_threads=1)
lstm.train(train_cands, train_marginals, X_dev=dev_cands, Y_dev=L_gold_dev, **train_kwargs)

[LSTM] Training model
[LSTM] n_train=3115  #epochs=10  batch size=128




[LSTM] Epoch 1 (12.32s)	Average loss=0.693709	Dev F1=12.37
[LSTM] Epoch 2 (29.49s)	Average loss=0.693280	Dev F1=12.06
[LSTM] Epoch 3 (45.57s)	Average loss=0.693159	Dev F1=11.80
[LSTM] Epoch 4 (61.72s)	Average loss=0.693085	Dev F1=13.12
[LSTM] Epoch 5 (78.21s)	Average loss=0.693019	Dev F1=12.70
[LSTM] Epoch 6 (95.07s)	Average loss=0.692990	Dev F1=12.59
[LSTM] Epoch 7 (112.05s)	Average loss=0.692936	Dev F1=12.41
[LSTM] Epoch 8 (129.04s)	Average loss=0.692899	Dev F1=12.63
[LSTM] Epoch 9 (144.90s)	Average loss=0.692852	Dev F1=12.53
[LSTM] Model saved as <LSTM>
[LSTM] Epoch 10 (162.11s)	Average loss=0.692813	Dev F1=12.80
[LSTM] Model saved as <LSTM>
[LSTM] Training done (166.52s)
[LSTM] Loaded model <LSTM>


Now, we get the precision, recall, and F1 score from the discriminative model:

In [4]:
p, r, f1 = lstm.score(test_cands, L_gold_test)
print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))

Prec: 0.088, Recall: 0.661, F1 Score: 0.156


We can also get the candidates returned in sets (true positives, false positives, true negatives, false negatives) as well as a more detailed score report:

In [5]:
tp, fp, tn, fn = lstm.error_analysis(session, test_cands, L_gold_test)

Scores (Un-adjusted)
Pos. class accuracy: 0.661
Neg. class accuracy: 0.401
Precision            0.0882
Recall               0.661
F1                   0.156
----------------------------------------
TP: 144 | FP: 1488 | TN: 995 | FN: 74
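The scores in this report follow directly from the four counts. The sketch below recomputes them from TP/FP/TN/FN and shows one way to partition candidates into the four sets. It is not Snorkel's `error_analysis` code; the `confusion_sets` and `scores` helpers and the dict-of-labels input format are hypothetical.

```python
def confusion_sets(pred, gold):
    """Split candidate ids into (tp, fp, tn, fn) sets.

    pred, gold: dicts mapping candidate id -> 0/1 label (hypothetical format)
    """
    tp, fp, tn, fn = set(), set(), set(), set()
    for cid, g in gold.items():
        p = pred[cid]
        if p == 1:
            (tp if g == 1 else fp).add(cid)
        else:
            (fn if g == 1 else tn).add(cid)
    return tp, fp, tn, fn

def scores(n_tp, n_fp, n_tn, n_fn):
    """Precision, recall, F1, and negative-class accuracy from raw counts."""
    precision = n_tp / (n_tp + n_fp)
    recall = n_tp / (n_tp + n_fn)  # also the positive-class accuracy
    f1 = 2 * precision * recall / (precision + recall)
    neg_acc = n_tn / (n_tn + n_fp)
    return precision, recall, f1, neg_acc

# Counts from the test-set report above
p, r, f1, neg = scores(144, 1488, 995, 74)
print("Prec: {0:.3f}, Recall: {1:.3f}, F1: {2:.3f}, Neg. acc: {3:.3f}".format(p, r, f1, neg))
# Prec: 0.088, Recall: 0.661, F1: 0.156, Neg. acc: 0.401
```

Seen this way, the low precision comes from the 1,488 false positives: the model labels far more pairs as spouses than the 218 true spouse pairs in the test set.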



Finally, let's save our model for later use. 

In [6]:
lstm.save("spouse.lstm")

[LSTM] Model saved as <spouse.lstm>


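## III. Pandas Conversion

Here we export every table in the Snorkel SQLite database (`snorkel.db`) to CSV files and load each one into a pandas DataFrame for inspection.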

In [3]:
import sqlite3
import pandas as pd

In [26]:
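def to_csv(db_path='snorkel.db'):
    """Export every table in the Snorkel SQLite database to <table>.csv.

    Returns a dict mapping table name -> pandas DataFrame.
    """
    db = sqlite3.connect(db_path)
    dfs = {}
    try:
        # List all tables defined in the database schema
        tables = db.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
        for (table_name,) in tables:
            # Quote the identifier so the query works for any table name
            table = pd.read_sql_query('SELECT * FROM "%s"' % table_name, db)
            table.to_csv(table_name + '.csv', index_label='index')
            dfs[table_name] = table
    finally:
        db.close()  # release the connection even if an export fails
    return dfs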

dfs = to_csv()
print("Table names in the dictionary:", dfs.keys())

Table names in the dictionary: dict_keys(['feature_key', 'candidate', 'label_key', 'gold_label_key', 'context', 'prediction_key', 'stable_label', 'gold_label', 'marginal', 'label', 'prediction', 'document', 'feature', 'sentence', 'span', 'spouse'])


In [34]:
import pickle

# Save all tables as a single pickle; the with-block closes the file handle
with open("list_of_tables.pkl", "wb") as f:
    pickle.dump(dfs, f)
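To reuse the exported tables in a later session without re-querying the database, the pickle can be loaded back. The sketch below uses a toy dict standing in for `dfs` and a temporary file so it runs on its own; with the real file you would open `list_of_tables.pkl` instead. The `save_tables`/`load_tables` helpers are hypothetical. Only unpickle files you created yourself, since loading a pickle can execute arbitrary code.

```python
import os
import pickle
import tempfile

def save_tables(tables, path):
    """Pickle a dict of tables, closing the file handle afterwards."""
    with open(path, 'wb') as f:
        pickle.dump(tables, f)

def load_tables(path):
    """Load a dict of tables previously written by save_tables."""
    with open(path, 'rb') as f:
        return pickle.load(f)

# Toy stand-in for dfs: table name -> rows (the real values are DataFrames)
tables = {'candidate': ['row 1', 'row 2'], 'span': ['row 1']}
path = os.path.join(tempfile.mkdtemp(), 'list_of_tables.pkl')
save_tables(tables, path)
print(sorted(load_tables(path)))  # ['candidate', 'span']
```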

In [80]:
dfs['sentence'].count()

id                  67820
document_id         67820
position            67820
text                67820
words               67820
char_offsets        67820
abs_char_offsets    67820
lemmas              67820
pos_tags            67820
ner_tags            67820
dep_parents         67820
dep_labels          67820
entity_cids         67820
entity_types        67820
dtype: int64

In [78]:
dfs['candidate'].count()

id       27766
type     27766
split    27766
dtype: int64
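`count()` reports the number of non-null values per column. To see how the 27,766 candidates divide across the train/dev/test splits (0, 1, 2, as used when loading candidates earlier), you can group by the `split` column; with the DataFrame this would be `dfs['candidate'].groupby('split').size()`. The self-contained sketch below does the same with a SQL `GROUP BY` against an in-memory SQLite table holding made-up rows in the same `id`, `type`, `split` layout. The helper `candidates_per_split` is hypothetical.

```python
import sqlite3

def candidates_per_split(db):
    """Return {split: number_of_candidates} from a 'candidate' table."""
    rows = db.execute(
        "SELECT split, COUNT(*) FROM candidate GROUP BY split ORDER BY split"
    ).fetchall()
    return dict(rows)

# Made-up rows in the same (id, type, split) layout as the candidate table
db = sqlite3.connect(':memory:')
db.execute("CREATE TABLE candidate (id INTEGER, type TEXT, split INTEGER)")
db.executemany("INSERT INTO candidate VALUES (?, ?, ?)",
               [(1, 'spouse', 0), (2, 'spouse', 0), (3, 'spouse', 1), (4, 'spouse', 2)])
print(candidates_per_split(db))  # {0: 2, 1: 1, 2: 1}
```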

In [46]:
dfs['span'].loc[0]

id             70412
sentence_id    28836
char_start        69
char_end          73
meta            None
Name: 0, dtype: object

In [44]:
dfs['span'].loc[0]['sentence_id']

28836

In [67]:
dfs['sentence'][dfs['sentence']['id'] == 28836]['text']

26244    While many women saw the less-than-perfect ima...
Name: text, dtype: object
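The three lookups above (span → `sentence_id` → sentence text) can be done in one step with a join. In pandas that would be something like `dfs['span'].merge(dfs['sentence'], left_on='sentence_id', right_on='id', suffixes=('_span', '_sentence'))`. The self-contained sketch below does the same in SQL against tiny in-memory `span` and `sentence` tables that use only the columns shown in the outputs above, with made-up text, and also slices out the span's characters. It assumes `char_end` is an inclusive offset; that is an assumption to check against your Snorkel version. The helper `span_with_sentence` is hypothetical.

```python
import sqlite3

def span_with_sentence(db, span_id):
    """Return (sentence_text, span_text) for one span via a span-sentence join."""
    row = db.execute(
        """SELECT sentence.text, span.char_start, span.char_end
           FROM span JOIN sentence ON span.sentence_id = sentence.id
           WHERE span.id = ?""",
        (span_id,),
    ).fetchone()
    if row is None:
        return None
    text, start, end = row
    return text, text[start:end + 1]  # assumes char_end is inclusive

# Toy tables using the column names shown above; the text is made up
db = sqlite3.connect(':memory:')
db.execute("CREATE TABLE sentence (id INTEGER, text TEXT)")
db.execute("CREATE TABLE span (id INTEGER, sentence_id INTEGER, "
           "char_start INTEGER, char_end INTEGER)")
db.execute("INSERT INTO sentence VALUES (1, 'Alice married Bob in 2001.')")
db.execute("INSERT INTO span VALUES (10, 1, 0, 4)")
print(span_with_sentence(db, 10))  # ('Alice married Bob in 2001.', 'Alice')
```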
