In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import re
import numpy as np

from lib.init import *
from snorkel.models import candidate_subclass
from snorkel.annotations import load_gold_labels

# Candidate type for the spouse relation: a pair of person mentions.
Spouse = candidate_subclass('Spouse', ['person1', 'person2'])

# Dev gold labels, loaded with the default options, for the majority-vote
# baseline and the generative-model search. Test gold labels are only needed
# when the discriminative model is evaluated, and they are loaded there (with
# zero_one=True). Loading them here as well was a dead query whose result was
# always overwritten before use.
L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)

## Load the Latest Labeling Function (LF) Snapshot

In [None]:
import glob
import time
import cPickle
import datetime

# LF snapshots are named "label_fn.<timestamp>.pkl". <timestamp> is a run of
# underscore-separated integers passed straight to datetime.datetime(...), so
# sorting on the parsed datetime finds the newest snapshot.
snapshot_files = glob.glob("label_fn.*.pkl")
if not snapshot_files:
    # Fail with a clear message instead of an IndexError on snapshots[0] below.
    raise IOError("No LF snapshots matching 'label_fn.*.pkl' found in the working directory")

# (timestamp, filename) pairs, newest first.
ts = [(datetime.datetime(*map(int, fn.split(".")[1].split("_"))), fn)
      for fn in snapshot_files]
snapshots = sorted(ts, reverse=True)

# NOTE: unpickling can execute arbitrary code. Only load snapshots you wrote yourself.
# `with` guarantees the file handle is closed once the LFs are read.
with open(snapshots[0][-1], "rb") as snapshot_fh:
    LFs = cPickle.load(snapshot_fh)

print("Loaded {} LFs".format(len(LFs)))

In [None]:
from snorkel.annotations import LabelAnnotator
# Annotator used in the next cell to apply the loaded LFs to the train and dev splits.
labeler = LabelAnnotator()

In [None]:
np.random.seed(1701)

%time L_train = labeler.apply(split=0, lfs=LFs, parallelism=4)
print L_train.shape

%time L_dev = labeler.apply_existing(split=1, lfs=LFs, parallelism=4)
print L_dev.shape

## Majority Vote

In [None]:
# NOTE(review): wildcard import. majority_vote_score presumably comes from
# lib.scoring; switch to an explicit import once confirmed.
from lib.scoring import *

# Baseline: score the dev-split LF outputs (L_dev) against the dev gold labels,
# presumably by majority vote over the LFs.
majority_vote_score(L_dev, L_gold_dev)

## Generative Model

In [None]:
from snorkel.learning import GenerativeModel
from snorkel.learning import RandomSearch, ListParameter, RangeParameter

# use grid search to optimize the generative model
step_size_param     = ListParameter('step_size', [1e-3, 1e-5])
decay_param         = ListParameter('decay', [0.95])
epochs_param        = ListParameter('epochs', [10, 50])
reg_param           = ListParameter('reg_param', [1e-3, 1e-6])
prior_param         = ListParameter('LF_acc_prior_weight_default', [1.0, 0.9, 0.8])

# search for the best model
param_grid = [step_size_param, decay_param, epochs_param, reg_param, prior_param]
searcher = RandomSearch(GenerativeModel, param_grid, L_train, n=24, lf_propensity=False)
%time gen_model, run_stats = searcher.fit(L_dev, L_gold_dev, deps=set())

run_stats

In [None]:
# Probabilistic training labels from the selected generative model, computed
# from the training label matrix (presumably one value per row of L_train).
# They serve as the training targets for the discriminative model below.
train_marginals = gen_model.marginals(L_train)

## Discriminative Model

In [None]:
from snorkel.annotations import load_gold_labels


def candidates_in_split(split):
    """Return every Spouse candidate in `split` (0=train, 1=dev, 2=test), ordered by id."""
    return (session.query(Spouse)
            .filter(Spouse.split == split)
            .order_by(Spouse.id)
            .all())


train_cands = candidates_in_split(0)
dev_cands   = candidates_in_split(1)
test_cands  = candidates_in_split(2)

# train_cands and train_marginals are handed to the LSTM search together, so
# presumably they are paired by position; at minimum their counts must agree.
# NOTE(review): this also assumes the label-matrix rows follow candidate-id order.
assert len(train_cands) == len(train_marginals), \
    "train candidates ({}) and train marginals ({}) are misaligned".format(
        len(train_cands), len(train_marginals))

# NOTE: this rebinds L_gold_dev. The majority-vote and generative-model cells
# above used a load of it without load_as_array/zero_one, so re-running those
# cells after this one would feed them this different representation.
L_gold_dev  = load_gold_labels(session, annotator_name='gold', split=1, load_as_array=True, zero_one=True)
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2, zero_one=True)

In [None]:
from snorkel.learning.disc_models.rnn import reRNN

# Hyperparameter space for the LSTM relation extractor.
# NOTE(review): the RangeParameters presumably expand to lr in {1e-4, 1e-3, 1e-2}
# (log-spaced) and dropout in {0.0, 0.25, 0.5}. Confirm against RangeParameter.
batch_size_param = ListParameter('batch_size', [32, 128])
rate_param       = RangeParameter('lr', 1e-4, 1e-2, step=1, log_base=10)
dropout_param    = RangeParameter('dropout', 0.0, 0.5, step=0.25)
balance_param    = ListParameter('rebalance', [0.5])
b_param          = ListParameter('b', [0.5, 0.6])
dim_param        = ListParameter('dim', [100])

# Distinct names (disc_*) so the generative-model search's objects stay intact.
disc_param_grid = [rate_param, dropout_param, dim_param,
                   batch_size_param, balance_param, b_param]

# Train on the generative model's probabilistic labels, drawing n=4 random configurations.
# NOTE(review): n_threads is passed both to the RandomSearch constructor (7) and
# to fit() (1). Confirm which controls per-model training and which controls
# parallel search runs.
disc_searcher = RandomSearch(reRNN, disc_param_grid, train_cands, train_marginals, n=4, n_threads=7)

# Select against the dev split. X_dev/Y_dev pass the same dev data again,
# presumably for per-epoch progress reporting.
lstm, disc_run_stats = disc_searcher.fit(dev_cands, L_gold_dev, n_epochs=25, print_freq=1,
                                         n_threads=1, X_dev=dev_cands, Y_dev=L_gold_dev)

# Last expression: rich display, matching the generative-model cell.
disc_run_stats

In [None]:
# Held-out evaluation of the selected LSTM on the test split.
p, r, f1 = lstm.score(test_cands, L_gold_test)
summary = "Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1)
print(summary)

In [None]:
# Break the test-set predictions into true/false positives/negatives,
# presumably returned as the candidate groups for each outcome.
tp, fp, tn, fn = lstm.error_analysis(session, test_cands, L_gold_test)