<img align="left" src="imgs/logo.jpg" width="50px" style="margin-right:10px">
# Snorkel Workshop: Extracting Spouse Relations <br> from the News
## Part 3: Training the Generative Model

Now, we'll train a model of the LFs to estimate their accuracies. Once the model is trained, we can combine the outputs of the LFs into a single, noise-aware training label set for our extractor. Intuitively, we'll model the LFs by observing how they overlap and conflict with each other.
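To make "overlap" and "conflict" concrete, here is a minimal sketch that computes per-LF coverage, overlap, and conflict rates. It is not Snorkel's implementation. It uses a tiny hypothetical label matrix and assumes the convention that each LF outputs +1 (true), -1 (false), or 0 (abstain).

```python
import numpy as np

# Hypothetical toy label matrix: rows are candidates, columns are LFs.
# Assumed convention: +1 = true, -1 = false, 0 = abstain.
L = np.array([
    [ 1,  1,  0],
    [ 1, -1,  0],
    [ 0,  0, -1],
    [-1, -1, -1],
    [ 0,  1,  0],
])

def lf_summary(L):
    """Per-LF coverage, overlap, and conflict rates for a {-1, 0, 1} label matrix."""
    labeled = L != 0                      # where each LF voted
    n_votes = labeled.sum(axis=1)         # number of LFs voting on each candidate
    coverage = labeled.mean(axis=0)
    # Overlap: this LF voted and at least one other LF also voted.
    overlaps = (labeled & (n_votes[:, None] > 1)).mean(axis=0)
    # Conflict: this LF voted and some other LF voted the opposite label.
    pos = (L == 1).sum(axis=1)
    neg = (L == -1).sum(axis=1)
    has_opposite = np.where(L == 1, neg[:, None] > 0,
                            np.where(L == -1, pos[:, None] > 0, False))
    conflicts = has_opposite.mean(axis=0)
    return coverage, overlaps, conflicts

coverage, overlaps, conflicts = lf_summary(L)
print(coverage, overlaps, conflicts)
```

The generative model uses exactly this kind of agreement and disagreement signal: LFs that agree with many others tend to be estimated as more accurate than ones that are frequently contradicted.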

```python
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import re
import numpy as np

# Connect to the database backend and initialize a Snorkel session
from lib.init import *
from snorkel.models import candidate_subclass
from snorkel.annotations import load_gold_labels

from snorkel.lf_helpers import (
    get_left_tokens, get_right_tokens, get_between_tokens,
    get_text_between, get_tagged_text,
)

Spouse = candidate_subclass('Spouse', ['person1', 'person2'])

L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)
```

# I. Loading Labeling Matrices

First, we'll load the label matrices we created in Part 2.

```python
from snorkel.annotations import LabelAnnotator

labeler = LabelAnnotator()
L_train = labeler.load_matrix(session, split=0)
L_dev   = labeler.load_matrix(session, split=1)
```

Next, we set up and run a hyperparameter search: we train the model with several different hyperparameter settings and keep the best-performing configuration. We'll set the random seed to maintain reproducibility.

Note that we fit the model's parameters to the training set labeled by our labeling functions, while we pick hyperparameters based on the score over the development set, whose labels we created by hand.

# II. Unifying Supervision

## A. Majority Vote
The simplest way to unify the outputs of all your LFs is to compute their _unweighted majority vote_.

```python
from lib.scoring import *

majority_vote_score(L_dev, L_gold_dev)
```

```
pos/neg    190:2621 6.8%/93.2%
precision  44.17
recall     37.89
f1         40.79
```
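An unweighted majority vote simply counts each LF's vote equally. A minimal sketch, assuming +1/-1 labels with 0 for abstain: because every non-abstaining vote is ±1, the sign of each row sum is the majority label. How ties are scored by `majority_vote_score` is not shown here; this sketch maps ties and all-abstain rows to 0.

```python
import numpy as np

def majority_vote(L):
    """Unweighted majority vote over a {-1, 0, 1} label matrix.

    Abstentions (0) contribute nothing; ties and all-abstain rows yield 0.
    """
    return np.sign(L.sum(axis=1))

# Hypothetical toy label matrix (rows = candidates, columns = LFs).
L = np.array([
    [ 1,  1, -1],   # two votes for +1, one for -1 -> +1
    [ 1, -1,  0],   # tie -> 0
    [ 0,  0, -1],   # single -1 vote -> -1
    [ 0,  0,  0],   # every LF abstains -> 0
])
votes = majority_vote(L).tolist()
print(votes)  # [1, 0, -1, 0]
```

The weakness is visible in the second row: a highly accurate LF and a poor one cancel out, which is what the generative model below is meant to address.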


## B. Generative Model
In data programming, we use a more sophisticated model to unify our labeling functions. We know that these labeling functions will not be perfect, and some may be quite low-quality, so we will _model_ their accuracies with a generative model, which Snorkel will help us easily apply.

This will ultimately produce a single set of **noise-aware training labels**, which we will then use to train an end extraction model in the next notebook.  For more technical details of this overall approach, see our [NIPS 2016 paper](https://arxiv.org/abs/1605.07723).
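To see why modeling accuracies yields *probabilistic* (noise-aware) labels, consider a deliberately simplified case: conditionally independent LFs whose accuracies are already known. This is a sketch under stated assumptions, not Snorkel's `GenerativeModel`, which learns the accuracies from the label matrix alone and can include additional factors. We assume each non-abstaining LF $j$ outputs the true label with probability `acc[j]`, that abstaining carries no information about the label, and a uniform class prior. Under these assumptions, the posterior log-odds of a positive label is the sum of each vote times $\log\frac{a_j}{1-a_j}$.

```python
import numpy as np

def marginals_from_accuracies(L, acc, prior_pos=0.5):
    """P(y = +1 | LF votes) for conditionally independent LFs with known accuracies.

    Assumes LF j outputs the true label with probability acc[j] when it votes,
    and that abstaining (0) is independent of the true label.
    """
    weights = np.log(acc / (1.0 - acc))             # log-odds weight per LF
    log_odds = np.log(prior_pos / (1.0 - prior_pos)) + L.dot(weights)
    return 1.0 / (1.0 + np.exp(-log_odds))          # sigmoid -> probability

# Hypothetical accuracies and votes for three LFs.
acc = np.array([0.9, 0.6, 0.75])
L = np.array([
    [ 1,  1,  0],
    [ 1, -1,  0],   # accurate LF outvotes the weak one
    [ 0,  0,  0],   # no votes -> prior (0.5)
    [-1,  0, -1],
])
marginals = marginals_from_accuracies(L, acc)
print(marginals)
```

Unlike the majority vote, the second row is no longer a tie: the 90%-accurate LF outweighs the 60%-accurate one, giving a marginal of about 0.86 rather than an undecided label.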

### 1. Training the Model
When training the generative model, we'll tune our hyperparameters using a simple search over a small grid of candidate values.

```python
from snorkel.learning import GenerativeModel
from snorkel.learning import RandomSearch, ListParameter, RangeParameter

# use grid search to optimize the generative model
step_size_param     = ListParameter('step_size', [0.1 / L_train.shape[0], 1e-5])
decay_param         = ListParameter('decay', [0.9, 0.95])
epochs_param        = ListParameter('epochs', [50])

# search for the best model
param_grid = [step_size_param, decay_param, epochs_param]
searcher = RandomSearch(GenerativeModel, param_grid, L_train, n=4, lf_propensity=False)
%time gen_model, run_stats = searcher.fit(L_dev, L_gold_dev, deps=set())

run_stats
```

```
Initialized RandomSearch search of size 4. Search space size = 4.

TypeError: __init__() got an unexpected keyword argument 'deps'
```
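The selection procedure used in this section (fit on the LF-labeled training set, keep the configuration that scores best on the hand-labeled dev set) can be sketched as a generic exhaustive grid search. This is not how Snorkel's `RandomSearch` is implemented; `fit` and `score` are hypothetical stand-ins for "train a generative model with these hyperparameters" and "evaluate it on the dev set".

```python
import itertools

def grid_search(fit, score, grid):
    """Exhaustive grid search.

    fit(**params) trains a model (e.g. on the LF-labeled training set);
    score(model) evaluates it (e.g. F1 on the hand-labeled dev set).
    Returns the best (score, params, model) triple.
    """
    names = sorted(grid)
    best = None
    for values in itertools.product(*(grid[n] for n in names)):
        params = dict(zip(names, values))
        model = fit(**params)
        s = score(model)
        if best is None or s > best[0]:
            best = (s, params, model)
    return best

# Hypothetical stand-ins: the "model" is just its params, scored by a toy rule.
grid = {'step_size': [1e-3, 1e-5], 'decay': [0.9, 0.95], 'epochs': [50]}
toy_fit = lambda **params: params
toy_score = lambda model: model['decay'] - 1000 * model['step_size']

best_score, best_params, _ = grid_search(toy_fit, toy_score, grid)
print(best_params)
```

With only four configurations, trying all of them is cheap; random search becomes preferable when the grid is too large to enumerate.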

### 2. Model Accuracies
Below are the statistics for each LF on the development set, including the accuracies learned by the generative model:

```python
L_dev.lf_stats(session, L_gold_dev, gen_model.learned_lf_stats()['Accuracy'])
```
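A useful way to read learned accuracies is to compare them with each LF's *empirical* accuracy on the hand-labeled dev set. A minimal sketch of that empirical number, assuming dense NumPy arrays with +1/-1/0 labels (Snorkel's own matrices are sparse, and its reported columns may be computed differently): it counts, per LF, the fraction of non-abstaining votes that match the gold label.

```python
import numpy as np

def empirical_accuracy(L, gold):
    """Per-LF fraction of non-abstaining votes that match the gold label.

    Returns NaN for an LF that abstains on every candidate.
    """
    voted = L != 0
    correct = (L == gold[:, None]) & voted
    n_voted = voted.sum(axis=0).astype(float)
    with np.errstate(invalid='ignore', divide='ignore'):
        return correct.sum(axis=0) / n_voted

# Hypothetical gold labels and LF votes for four dev candidates.
gold = np.array([1, -1, 1, -1])
L = np.array([
    [ 1, -1, 0],
    [ 1, -1, 0],
    [ 1,  0, 0],
    [-1,  1, 0],
])
acc = empirical_accuracy(L, gold)
print(acc)
```

A large gap between an LF's empirical and learned accuracy can point to an LF that is correlated with others (see Structure Learning below) or to a small dev sample.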

```python
train_marginals = gen_model.marginals(L_train)
```

### 3. Plotting Marginal Probabilities
One immediate sanity check you can perform using the generative model is to visually examine the distribution of predicted training marginals. Ideally, we should see a bimodal distribution with a large separation between the two peaks.

```python
import matplotlib.pyplot as plt
plt.hist(train_marginals, bins=20, range=(0.0, 1.0))
plt.show()
```
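The histogram check can also be summarized numerically, which is handy when comparing runs. A minimal sketch: the cutoffs 0.1 and 0.9 below are arbitrary assumptions, not Snorkel defaults. A well-separated, bimodal distribution puts most of its mass in the first two fractions.

```python
import numpy as np

def marginal_summary(marginals, low=0.1, high=0.9):
    """Fractions of marginals near 0, near 1, and in the uncertain middle.

    The low/high cutoffs are illustrative choices.
    """
    m = np.asarray(marginals, dtype=float).ravel()
    near_neg = np.mean(m <= low)
    near_pos = np.mean(m >= high)
    return near_neg, near_pos, 1.0 - near_neg - near_pos

# Hypothetical marginals for five training candidates.
near_neg, near_pos, middle = marginal_summary([0.02, 0.05, 0.5, 0.95, 0.6])
print(near_neg, near_pos, middle)
```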

### 4. Generative Model Metrics

```python
dev_marginals = gen_model.marginals(L_dev)
_, _, _, _ = gen_model.error_analysis(session, L_dev, L_gold_dev)
```
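The precision, recall, and F1 figures reported for the majority vote above follow the standard definitions. A minimal sketch of computing them from dev-set marginals, assuming a decision threshold of 0.5 and gold labels of +1/-1 (Snorkel's `error_analysis` may threshold or break ties differently):

```python
import numpy as np

def prf1(marginals, gold, threshold=0.5):
    """Precision, recall, and F1 (as percentages) for thresholded marginals.

    gold uses +1 for positive and -1 for negative candidates.
    """
    pred_pos = np.asarray(marginals) > threshold
    gold_pos = np.asarray(gold) == 1
    tp = np.sum(pred_pos & gold_pos)
    fp = np.sum(pred_pos & ~gold_pos)
    fn = np.sum(~pred_pos & gold_pos)
    precision = tp / float(tp + fp) if tp + fp else 0.0
    recall = tp / float(tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return 100 * precision, 100 * recall, 100 * f1

# Hypothetical dev marginals and gold labels.
p, r, f = prf1([0.9, 0.8, 0.3, 0.7, 0.1], [1, -1, 1, 1, -1])
print(p, r, f)
```

Here two of the three predicted positives are correct and two of the three true positives are found, so all three scores are about 66.67.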

### 5. Saving our training labels

Finally, we'll save the `train_marginals`, which are our **"noise-aware training labels"**, so that we can use them in the next notebook to train our end extraction model:

```python
from snorkel.annotations import save_marginals
%time save_marginals(session, L_train, train_marginals)
```

# III. Advanced Generative Model Features

## A. Structure Learning

We may also want to include the dependencies between our LFs when training the generative model. Snorkel makes it easy to do this! `DependencySelector` runs a fast structure learning algorithm over the matrix of LF outputs to identify a set of likely dependencies. 

```python
from snorkel.learning.structure import DependencySelector
ds = DependencySelector()
deps = ds.select(L_train, threshold=0.3)
print(len(deps))
```
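To build intuition for what "dependent LFs" look like, here is a crude heuristic that flags pairs of LFs whose output columns are strongly correlated. This is a swapped-in technique for illustration only: it is *not* the algorithm `DependencySelector` implements, its `threshold` is not the same quantity as the `threshold=0.3` passed to `ds.select`, and it returns plain index pairs rather than the format `select` returns.

```python
import itertools
import numpy as np

def correlated_lf_pairs(L, threshold=0.3):
    """Pairs (j, k) of LFs whose output columns have |Pearson correlation| > threshold.

    A crude stand-in for structure learning: it only looks at marginal
    correlations, not at the conditional dependencies a real method targets.
    """
    L = np.asarray(L, dtype=float)
    pairs = set()
    for j, k in itertools.combinations(range(L.shape[1]), 2):
        a, b = L[:, j], L[:, k]
        if a.std() == 0 or b.std() == 0:
            continue  # correlation is undefined for a constant column
        r = np.corrcoef(a, b)[0, 1]
        if abs(r) > threshold:
            pairs.add((j, k))
    return pairs

# Hypothetical label matrix where LF 1 duplicates LF 0 and LF 2 is unrelated.
L = np.array([
    [ 1,  1,  1],
    [ 1,  1, -1],
    [-1, -1,  0],
    [ 0,  0,  0],
    [ 1,  1,  0],
])
pairs = correlated_lf_pairs(L)
print(pairs)  # {(0, 1)}
```

Duplicated or near-duplicated LFs like LFs 0 and 1 are the typical case structure learning targets: treated as independent, they would be double-counted as two separate pieces of evidence.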

To train the generative model with these dependencies, we just pass the set above as the `deps` argument to the search's `fit` method:

```python
searcher = RandomSearch(GenerativeModel, param_grid, L_train, n=4, lf_propensity=False)
gen_model, run_stats = searcher.fit(L_dev, L_gold_dev, deps=deps)
run_stats
```