# Intro. to Snorkel: Extracting Spouse Relations from the News

## Part 3: Training the Generative Model

Now, we'll train a model of the LFs to estimate their accuracies. Once the model is trained, we can combine the outputs of the LFs into a single, noise-aware training label set for our extractor. Intuitively, we'll model the LFs by observing how they overlap and conflict with each other.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os

import numpy as np
from snorkel import SnorkelSession
session = SnorkelSession()

We repeat our definition of the `Spouse` `Candidate` subclass:

In [None]:
from snorkel.models import candidate_subclass
Spouse = candidate_subclass('Spouse', ['person1', 'person2'])

Load development set gold labels

In [None]:
from snorkel.annotations import load_gold_labels
L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)

Helper functions

In [None]:
import re
from snorkel.lf_helpers import (
    get_left_tokens, get_right_tokens, get_between_tokens,
    get_text_between, get_tagged_text,
)

## 1. Labeling Functions

Add all your labeling functions (and their dependencies) here. These will be used to train our generative model.
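As a reminder of the contract a labeling function follows, here is a minimal sketch. It assumes the usual convention of returning `1` (true), `-1` (false), or `0` (abstain). The function names and keyword lists are hypothetical, and the functions take a plain string rather than a candidate so that the example runs on its own; in the notebook, that string would come from a helper such as `get_text_between`.

```python
import re

# Hypothetical LFs operating on the text between two person mentions.
# Real LFs in this notebook would take a candidate `c` instead of a string.
def lf_spouse_keyword(text_between):
    """Vote +1 if a spouse keyword appears, otherwise abstain (0)."""
    pattern = r'\b(wife|husband|spouse)\b'
    return 1 if re.search(pattern, text_between, re.I) else 0

def lf_family_keyword(text_between):
    """Vote -1 if a non-spouse family keyword appears, otherwise abstain (0)."""
    pattern = r'\b(brother|sister|mother|father)\b'
    return -1 if re.search(pattern, text_between, re.I) else 0

# Illustrative snippets, not taken from the corpus.
examples = [" and his wife, ", " met with her brother ", " spoke to "]
votes = [(lf_spouse_keyword(t), lf_family_keyword(t)) for t in examples]
print(votes)
```

Each LF votes only when its pattern fires and abstains otherwise, which is why most LFs label only a small part of the data.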

Now we set up and run the hyperparameter search, training our model with different hyperparameters and picking the best model configuration to keep. We'll set the random seed to maintain reproducibility.

Note that we are fitting our model's parameters to the training set generated by our labeling functions, while we are picking hyperparameters with respect to the score over the development set labels, which we created by hand.

In [None]:
LFs = [
]

In [None]:
from snorkel.annotations import LabelAnnotator
labeler = LabelAnnotator(lfs=LFs)

np.random.seed(1701)
%time L_train = labeler.apply(split=0)
L_train.shape
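Since the generative model learns from how the LFs overlap and conflict, it can help to measure those quantities directly on the label matrix. The sketch below is illustrative only: it uses a small dense NumPy array as a stand-in for `L_train` (which Snorkel stores as a sparse matrix) and assumes labels of `1`, `-1`, and `0` for abstain.

```python
import numpy as np

# Toy label matrix: rows are candidates, columns are LFs.
# Values: +1 (true), -1 (false), 0 (abstain). Illustrative stand-in only.
L = np.array([
    [ 1,  1,  0],
    [ 1, -1,  0],
    [ 0,  0, -1],
    [ 0,  0,  0],
])

n_votes = (L != 0).sum(axis=1)         # non-abstaining LFs per candidate
has_pos = (L == 1).any(axis=1)         # at least one positive vote
has_neg = (L == -1).any(axis=1)        # at least one negative vote

coverage = np.mean(n_votes > 0)        # labeled by at least one LF
overlap = np.mean(n_votes > 1)         # labeled by two or more LFs
conflict = np.mean(has_pos & has_neg)  # LFs disagree on the label
print(coverage, overlap, conflict)
```

Candidates with no overlap carry no information about relative LF accuracy, so very low overlap is worth noticing before training.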

## 2. Unifying Supervision

### Majority Vote
The simplest way to unify the output of all your LFs is to compute the _unweighted majority vote_.

In [None]:
L_dev = labeler.apply_existing(split=1)

In [None]:
from utils import *

majority_vote_score(L_dev, L_gold_dev)
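To make the idea concrete, here is a minimal sketch of an unweighted majority vote on a toy dense label matrix. It is not necessarily how `majority_vote_score` in `utils` is implemented; it assumes labels of `1`, `-1`, and `0` for abstain.

```python
import numpy as np

# Toy label matrix: rows are candidates, columns are LFs (illustrative only).
L = np.array([
    [ 1,  1, -1],
    [-1,  0, -1],
    [ 1, -1,  0],
    [ 0,  0,  0],
])

# Summing the votes and taking the sign gives the majority label.
# A result of 0 means the vote is tied or every LF abstained.
vote_sum = L.sum(axis=1)
majority = np.sign(vote_sum)
print(majority)
```

Every LF counts equally here, so one accurate LF and one noisy LF that disagree simply cancel out.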

### Generative Model
In data programming, we use a more sophisticated model to unify our labeling functions. We know that these labeling functions will not be perfect, and some may be quite low-quality, so we will _model_ their accuracies with a generative model, which Snorkel will help us easily apply.

This will ultimately produce a single set of **noise-aware training labels**, which we will then use to train an end extraction model in the next notebook.  For more technical details of this overall approach, see our [NIPS 2016 paper](https://arxiv.org/abs/1605.07723).

In [None]:
from snorkel.learning import GenerativeModel

gen_model = GenerativeModel()
gen_model.train(L_train, epochs=500, decay=0.95, step_size=0.1/L_train.shape[0], reg_param=1e-4)

In [None]:
train_marginals = gen_model.marginals(L_train)
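For intuition about what the marginals represent, the sketch below shows how known LF accuracies would turn votes into probabilities. It is a simplification rather than Snorkel's implementation: it assumes the accuracies are given instead of learned, that LFs are conditionally independent given the true label, that classes are balanced, and that abstaining carries no information. Under those assumptions, each vote contributes the log-odds of its LF's accuracy.

```python
import numpy as np

# Toy label matrix: rows are candidates, columns are LFs (illustrative only).
L = np.array([
    [ 1,  1, -1],
    [-1,  0, -1],
    [ 1, -1,  0],
    [ 0,  0,  0],
])

# Assumed (not learned) accuracy of each LF, for illustration only.
acc = np.array([0.9, 0.6, 0.7])

w = np.log(acc / (1.0 - acc))                # log-odds weight per LF
log_odds = L.dot(w)                          # accuracy-weighted vote
marginals = 1.0 / (1.0 + np.exp(-log_odds))  # P(y = 1 | votes)
print(np.round(marginals, 3))
```

On the third row an unweighted vote would tie, but the weighted vote leans positive because the first LF is assumed to be more accurate than the second. A candidate with no votes stays at 0.5.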

### Plotting Marginal Probabilities
One immediate sanity check you can perform using the generative model is to visually examine the distribution of predicted training marginals. Ideally, you should see a bimodal distribution with large separation between the two peaks.

In [None]:
import matplotlib.pyplot as plt
plt.hist(train_marginals, bins=20)
plt.show()
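Alongside the histogram, a rough numeric summary can show how much probability mass sits near the extremes versus near 0.5. The sketch below uses hypothetical values in place of `train_marginals`, and the cutoffs (0.2, 0.8, 0.4, 0.6) are arbitrary choices for illustration.

```python
import numpy as np

# Hypothetical marginals standing in for train_marginals.
marginals = np.array([0.02, 0.05, 0.10, 0.48, 0.52, 0.91, 0.95, 0.99])

# Fraction of candidates the model is confident about (near 0 or 1).
confident = np.mean((marginals < 0.2) | (marginals > 0.8))
# Fraction of candidates the model is unsure about (near 0.5).
uncertain = np.mean((marginals >= 0.4) & (marginals <= 0.6))
print(confident, uncertain)
```

A large uncertain fraction often means many candidates receive no votes, or only conflicting ones.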

In [None]:
dev_marginals = gen_model.marginals(L_dev)
_, _, _, _ = gen_model.score(session, L_dev, L_gold_dev)
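The kind of score reported here can be reproduced by hand from marginals and gold labels. The sketch below is illustrative and makes no claim about how `gen_model.score` works internally: it assumes a decision threshold of 0.5, gold labels in `{1, -1}`, and made-up data.

```python
import numpy as np

# Hypothetical marginals and gold labels, for illustration only.
marginals = np.array([0.9, 0.8, 0.3, 0.6, 0.1])
gold = np.array([1, -1, 1, 1, -1])

# Threshold the marginals at 0.5 to get hard predictions.
pred = np.where(marginals > 0.5, 1, -1)

tp = np.sum((pred == 1) & (gold == 1))   # true positives
fp = np.sum((pred == 1) & (gold == -1))  # false positives
fn = np.sum((pred == -1) & (gold == 1))  # false negatives

precision = tp / float(tp + fp)
recall = tp / float(tp + fn)
f1 = 2 * precision * recall / (precision + recall)
print(precision, recall, f1)
```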

## 3. Evaluating on the Test Set

In this section of the tutorial, we'll get the score we've been after: the performance of the generative model on the blind test set (`split` 2). First, we load the test set labels and gold candidates we made in Part III.

In [None]:
from snorkel.annotations import load_gold_labels
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2)

Now, we score using the generative model:

In [None]:
L_test = labeler.apply_existing(split=2)

In [None]:
test_marginals = gen_model.marginals(L_test)
_, _, _, _ = gen_model.score(session, L_test, L_gold_test)

Note that if this is the final test set on which you will report final numbers, you should not inspect individual results, to avoid biasing them. However, you can run the model on your _development set_ and, as we did in the previous part with the generative labeling function model, inspect examples to do error analysis.

### Saving our training labels

Finally, we'll save the `train_marginals`, which are our **"noise-aware training labels"**, so that we can use them in the next tutorial to train our end extraction model:

In [None]:
from snorkel.annotations import save_marginals
%time save_marginals(session, L_train, train_marginals)

## 4. Structure Learning

We want to include the dependencies between our LFs when training the generative model. Snorkel makes it easy to do this! `DependencySelector` runs a fast structure learning algorithm over the matrix of LF outputs to identify a set of likely dependencies. 

In [None]:
from snorkel.learning.structure import DependencySelector
ds = DependencySelector()
deps = ds.select(L_train, threshold=0.3)
print(len(deps))
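To see why dependencies matter, consider two LFs that almost always vote the same way: treating them as independent would count the same evidence twice. The sketch below computes pairwise agreement rates on a toy matrix as a crude indicator. It is not the algorithm `DependencySelector` uses, and high agreement alone does not prove dependence, since two accurate, independent LFs also agree often.

```python
import numpy as np

# Toy label matrix: rows are candidates, columns are LFs (illustrative only).
# LF 0 and LF 1 are exact duplicates; LF 2 behaves differently.
L = np.array([
    [ 1,  1,  0],
    [ 1,  1, -1],
    [-1, -1,  1],
    [ 0,  0,  1],
    [ 1,  1,  1],
])

n_lfs = L.shape[1]
agreement = np.zeros((n_lfs, n_lfs))
for j in range(n_lfs):
    for k in range(n_lfs):
        # Only compare candidates where both LFs cast a vote.
        both = (L[:, j] != 0) & (L[:, k] != 0)
        if both.any():
            agreement[j, k] = np.mean(L[both, j] == L[both, k])
        else:
            agreement[j, k] = np.nan
print(np.round(agreement, 2))
```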

Now we'll train the generative model, using the `deps` argument to account for the learned dependencies. We'll also model LF propensity here, unlike the intro tutorial. In addition to learning the accuracies of the LFs, this also learns their likelihood of labeling an example.

In [None]:
gen_model = GenerativeModel(lf_propensity=True)
gen_model.train(
    L_train, deps=deps, epochs=500, decay=0.95, step_size=0.1/L_train.shape[0], reg_param=1e-4
)
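As a rough empirical counterpart to propensity, you can measure how often each LF votes at all. The sketch below computes that per-LF labeling rate on a toy dense matrix. It is only a descriptive statistic under the `1` / `-1` / `0` label convention, not the parameter the generative model learns.

```python
import numpy as np

# Toy label matrix: rows are candidates, columns are LFs (illustrative only).
L = np.array([
    [ 1,  0,  0],
    [ 1, -1,  0],
    [-1,  0,  0],
    [ 1,  1,  1],
])

# Fraction of candidates on which each LF does not abstain.
labeling_rate = np.mean(L != 0, axis=0)
print(labeling_rate)
```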

In [None]:
train_marginals = gen_model.marginals(L_train)

In [None]:
plt.hist(train_marginals, bins=20)
plt.show()

In [None]:
dev_marginals = gen_model.marginals(L_dev)
_, _, _, _ = gen_model.score(session, L_dev, L_gold_dev)