In [1]:
# Reload edited Python modules automatically before executing each cell.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

from snorkel import SnorkelSession

# Database session passed to every query and FeatureManager call below.
session = SnorkelSession()

In [2]:
from snorkel.models import candidate_subclass

# Define the ChemicalDisease candidate type with two argument fields,
# 'chemical' and 'disease'. Presumably this must run before the candidate
# sets below can be loaded -- TODO confirm.
ChemicalDisease = candidate_subclass('ChemicalDisease', ['chemical', 'disease'])

In [3]:
from snorkel.models import CandidateSet

# Load the three candidate splits by name. `dev` and `test` are used by the
# feature update/load, hyperparameter-search and scoring cells below, so they
# must be defined here for the notebook to survive Restart & Run All.
# .one() raises if a set name is missing or ambiguous.
train = session.query(CandidateSet).filter(CandidateSet.name == 'CDR Train Candidates').one()
dev = session.query(CandidateSet).filter(CandidateSet.name == 'CDR Development Candidates').one()
test = session.query(CandidateSet).filter(CandidateSet.name == 'CDR Test Candidates').one()

# Get features

In [4]:
from snorkel.annotations import FeatureManager

# Single manager reused to create, update and load every feature matrix below.
feature_manager = FeatureManager()

### Get span features

In [5]:
from snorkel.features import get_span_feats

# Build the training span feature matrix from get_span_feats under the name
# 'training span feats'; the dev/test update and load cells reuse that name.
# Recorded wall time for this cell: ~7 min.
%time F_train_span = feature_manager.create(session, train, 'training span feats', get_span_feats)


Loading sparse Feature matrix...
CPU times: user 6min 58s, sys: 2.35 s, total: 7min 1s
Wall time: 7min


In [None]:
# Featurize dev and test against the 'training span feats' set built on train.
# NOTE(review): the positional False presumably stops update() from adding new
# feature keys, so dev/test share the training key space -- TODO confirm.
%time F_dev_span = feature_manager.update(session, dev, 'training span feats', False, get_span_feats)
%time F_test_span = feature_manager.update(session, test, 'training span feats', False, get_span_feats)

In [None]:
# Reload the previously stored span feature matrices for all three splits
# (presumably from the database behind `session`), so later cells can run
# without re-featurizing.
%time F_train_span = feature_manager.load(session, train, 'training span feats')
%time F_dev_span = feature_manager.load(session, dev, 'training span feats')
%time F_test_span = feature_manager.load(session, test, 'training span feats')

In [None]:
# Display the training span feature matrix.
F_train_span

### Get mention split features

In [6]:
from snorkel.features import get_span_splits

# Build the training split feature matrix from get_span_splits under the name
# 'training span splits'.
# NOTE(review): the recorded output of this cell is
# "NameError: global name 're' is not defined". This cell never uses `re`, so
# the error presumably comes from inside get_span_splits (its module is likely
# missing `import re`) -- TODO fix upstream. Until then, the dependent
# F_*_splits cells below will likely fail too.
%time F_train_splits = feature_manager.create(session, train, 'training span splits', get_span_splits)

[=                                       ] 0%

NameError: global name 're' is not defined

In [None]:
# Featurize dev and test against the 'training span splits' set built on train
# (False presumably means no new keys are added -- TODO confirm).
%time F_dev_splits = feature_manager.update(session, dev, 'training span splits', False, get_span_splits)
%time F_test_splits = feature_manager.update(session, test, 'training span splits', False, get_span_splits)

In [None]:
# Reload the previously stored split feature matrices for all three splits.
%time F_train_splits = feature_manager.load(session, train, 'training span splits')
%time F_dev_splits = feature_manager.load(session, dev, 'training span splits')
%time F_test_splits = feature_manager.load(session, test, 'training span splits')

### Get key-mention features

In [None]:
from cdr_feats import get_is_key

# Build the training key-mention feature matrix from the project-local
# get_is_key. Later cells densify F_*_keys with .toarray() and pass the
# result to the search and training calls.
%time F_train_keys = feature_manager.create(session, train, 'training key ents', get_is_key)

In [None]:
# Featurize dev and test against the 'training key ents' set built on train
# (False presumably means no new keys are added -- TODO confirm).
%time F_dev_keys = feature_manager.update(session, dev, 'training key ents', False, get_is_key)
%time F_test_keys = feature_manager.update(session, test, 'training key ents', False, get_is_key)

In [None]:
# Reload the previously stored key-mention feature matrices for all three splits.
%time F_train_keys = feature_manager.load(session, train, 'training key ents')
%time F_dev_keys = feature_manager.load(session, dev, 'training key ents')
%time F_test_keys = feature_manager.load(session, test, 'training key ents')

### Get title span features

In [None]:
from cdr_feats import get_title_span_feats

# Build the training title-span feature matrix from the project-local
# get_title_span_feats under the name 'training title span'.
%time F_train_title_span = feature_manager.create(session, train, 'training title span', get_title_span_feats)

In [None]:
# Featurize dev and test against the 'training title span' set built on train
# (False presumably means no new keys are added -- TODO confirm).
%time F_dev_title_span = feature_manager.update(session, dev, 'training title span', False, get_title_span_feats)
%time F_test_title_span = feature_manager.update(session, test, 'training title span', False, get_title_span_feats)

In [None]:
# Reload the previously stored title-span feature matrices for all three splits.
%time F_train_title_span = feature_manager.load(session, train, 'training title span')
%time F_dev_title_span = feature_manager.load(session, dev, 'training title span')
%time F_test_title_span = feature_manager.load(session, test, 'training title span')

In [None]:
# Summarize each training feature matrix (one repr per line), in order:
# span, splits, key mentions, title span.
print(repr(F_train_span))
print(repr(F_train_splits))
print(repr(F_train_keys))
print(repr(F_train_title_span))

# Learn the generative model

In [None]:
%time L_train = label_manager.load(session, train, 'LF Labels')
L_train

In [None]:
from snorkel.learning.structure import DependencySelector
# Select dependencies from the label matrix (presumably pairwise dependencies
# between labeling functions above the 0.3 threshold -- TODO confirm).
# NOTE(review): `deps` is only displayed below; it is never passed to
# gen_model.train in this notebook -- TODO confirm whether that is intended.
ds = DependencySelector()
deps = ds.select(L_train, threshold=0.3)

In [None]:
# Inspect the selected dependencies.
deps

In [None]:
from snorkel.learning import GenerativeModel
from snorkel.learning.constants import *

# Generative model with the lf_prior and lf_propensity options enabled,
# trained on the LF label matrix with L2 regularization (reg_type=2).
gen_model = GenerativeModel(lf_propensity=True, lf_prior=True)
gen_model.train(
    L_train,
    epochs=15,
    # Base step of 0.1 scaled down by the number of rows in L_train.
    step_size=0.1 / L_train.shape[0],
    decay=1.0,
    reg_type=2,
    reg_param=0.00001,
)

In [None]:
# Training marginals from the generative model; used below as the training
# targets for the hyperparameter search and the discriminative model.
train_marginals = gen_model.marginals(L_train)

In [None]:
import matplotlib.pyplot as plt

# Histogram of the generative model's training marginals. Uses an explicit
# Axes with a title and labels so the figure stands on its own. The inline
# backend is already enabled in the first cell, so the magic is not repeated.
fig, ax = plt.subplots()
ax.hist(train_marginals, bins=30)
ax.set(title='Generative model marginals (train)',
       xlabel='Marginal', ylabel='Count')
plt.show()

# Train the discriminative model (hyperparameter search and final training)

In [None]:
from snorkel.learning.utils import ListParameter, RangeParameter

In [None]:
from snorkel.learning.fastmulticontext import get_matrix_keys
# Combine the span and title-span training feature matrices into the input
# expected by the fastmulticontext model (presumably one feature-key context
# per matrix -- TODO confirm against get_matrix_keys).
train_embed_xs = get_matrix_keys([F_train_span, F_train_title_span])

In [None]:
from itertools import product
from utils import CDRFMCT, CDRRandomSearch

# Hyperparameter search space. Each parameter is named after the keyword it
# corresponds to in disc_model.train (epoch, lr, dim, lambda_l2, min_ct; see
# the final training cell). Presumably the searcher passes them by name --
# TODO confirm against CDRRandomSearch.
epoch_param = RangeParameter('epoch', 20, 200, step=20)
lr_param = RangeParameter('lr', 1e-5, 0.1, step=0.5, log_base=10)
# Previously named 'lr', which collided with lr_param above; the L2 strength
# is the `lambda_l2` argument of disc_model.train.
lambda_param = RangeParameter('lambda_l2', 1e-5, 10, step=1, log_base=10)
dim_param = RangeParameter('dim', 25, 150, step=25)
minct_opts = [1, 2, 3, 5, 7, 10, 12, 15]
minct_param = ListParameter('min_ct', minct_opts)

disc_model = CDRFMCT()

# The positional 20 is presumably the number of sampled configurations -- TODO confirm.
searcher = CDRRandomSearch(disc_model, train_marginals, train_embed_xs, 20,
                           epoch_param, lr_param, dim_param, lambda_param, minct_param)

In [None]:
from snorkel.learning.fastmulticontext import get_matrix_keys
# Dev inputs, built the same way as train_embed_xs (the re-import is harmless).
dev_embed_xs = get_matrix_keys([F_dev_span, F_dev_title_span])

In [None]:
from snorkel.models import Corpus
dev_corpus = session.query(Corpus).filter(Corpus.name == 'CDR Development').one()

# Run the random hyperparameter search, presumably scoring each sampled
# configuration on the dev candidates/corpus.
# NOTE(review): dev_doc_dict is not defined in any cell of this notebook --
# TODO build it before running this cell, otherwise it raises NameError.
D = searcher.fit(dev_embed_xs, F_dev_keys.toarray(), dev_doc_dict, dev, dev_corpus, b=0.5,
                 raw_xs=F_train_keys.toarray(), n_threads=4, n_print=10000)

In [None]:
# Display the results returned by searcher.fit.
D

In [None]:
# Final discriminative training run with fixed hyperparameters and seed.
# NOTE(review): these values are hard-coded rather than read from the search
# results D -- TODO confirm they reflect the best configuration found.
disc_model.train(
    train_marginals,
    train_embed_xs,
    raw_xs=F_train_keys.toarray(),
    epoch=160,
    dim=100,
    lr=0.001,
    min_ct=[10, 5],
    lambda_l2=0.01,
    seed=1701,
    n_threads=4,
)

In [None]:
# Test inputs, built the same way as train_embed_xs and dev_embed_xs
# (get_matrix_keys is imported in an earlier cell).
test_embed_xs = get_matrix_keys([F_test_span, F_test_title_span])

In [None]:
# Import Corpus here as well, so scoring on test does not depend on having run
# the (expensive) dev hyperparameter-search cell first.
from snorkel.models import Corpus

test_corpus = session.query(Corpus).filter(Corpus.name == 'CDR Test').one()
# Score the trained discriminative model on the test candidates.
# NOTE(review): test_doc_dict is not defined in any cell of this notebook --
# TODO build it before running this cell, otherwise it raises NameError.
buckets = disc_model.score(test_embed_xs, F_test_keys.toarray(), test_doc_dict, test, test_corpus, b=0.5)