In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os

import numpy as np
from snorkel import SnorkelSession

from utils import mesh_pairs_from_candidate

# Open the Snorkel session; it is passed to every query and annotation-manager call below.
session = SnorkelSession()

In [None]:
from snorkel.models import candidate_subclass

# Declare the candidate type with two arguments, `chemical` and `disease`;
# the labeling functions below read them as c.chemical / c.disease.
ChemicalDisease = candidate_subclass('ChemicalDisease', ['chemical', 'disease'])

In [None]:
from snorkel.models import CandidateSet

def _named_candidate_set(set_name):
    """Look up the CandidateSet stored under `set_name` (`.one()` expects exactly one match)."""
    matches = session.query(CandidateSet).filter(CandidateSet.name == set_name)
    return matches.one()


# Train / development / test candidate sets, looked up by their stored names.
train = _named_candidate_set('CDR Train Candidates')
dev = _named_candidate_set('CDR Development Candidates')
test = _named_candidate_set('CDR Test Candidates')

# Extract features

In [None]:
from snorkel.annotations import FeatureManager

# One FeatureManager is reused for every create / update / load call in this section.
feature_manager = FeatureManager()

### Get span features

In [None]:
from snorkel.features import get_span_feats

# Build span features (get_span_feats) for the training candidates under the name
# 'training span feats n'; %time reports how long the call takes.
# NOTE(review): the load cell further down reads 'training span feats' (no trailing
# ' n') — confirm which name is intended.
%time F_train_span = feature_manager.create(session, train, 'training span feats n', get_span_feats)

In [None]:
# Featurize the dev and test candidates under the same 'training span feats n' name
# used for the training set. The positional `False` presumably tells update() not to
# add new feature keys — TODO confirm against FeatureManager.update.
%time F_dev_span = feature_manager.update(session, dev, 'training span feats n', False, get_span_feats)
%time F_test_span = feature_manager.update(session, test, 'training span feats n', False, get_span_feats)

In [None]:
# Load span features for all three splits by name (rebinds the F_*_span variables).
# NOTE(review): these read 'training span feats', whereas the create/update cells
# above use 'training span feats n' — confirm the two names are meant to differ.
%time F_train_span = feature_manager.load(session, train, 'training span feats')
%time F_dev_span = feature_manager.load(session, dev, 'training span feats')
%time F_test_span = feature_manager.load(session, test, 'training span feats')

In [None]:
# Display the training span-feature matrix as the cell output.
F_train_span

### Get mention split features

In [None]:
from cdr_feats import get_span_splits

# Build the get_span_splits features for the training candidates under the name
# 'training span splits'.
%time F_train_splits = feature_manager.create(session, train, 'training span splits', get_span_splits)

In [None]:
# Featurize dev and test under the same 'training span splits' name
# (positional `False`: presumably "do not add new keys" — TODO confirm).
%time F_dev_splits = feature_manager.update(session, dev, 'training span splits', False, get_span_splits)
%time F_test_splits = feature_manager.update(session, test, 'training span splits', False, get_span_splits)

In [None]:
# Load the 'training span splits' features for all three splits by name.
%time F_train_splits = feature_manager.load(session, train, 'training span splits')
%time F_dev_splits = feature_manager.load(session, dev, 'training span splits')
%time F_test_splits = feature_manager.load(session, test, 'training span splits')

### Get key mention

In [None]:
from cdr_feats import get_is_key

# Build the get_is_key features for the training candidates under the name
# 'training key ents'. These matrices are later passed densified (.toarray())
# to the discriminative model.
%time F_train_keys = feature_manager.create(session, train, 'training key ents', get_is_key)

In [None]:
# Featurize dev and test under the same 'training key ents' name
# (positional `False`: presumably "do not add new keys" — TODO confirm).
%time F_dev_keys = feature_manager.update(session, dev, 'training key ents', False, get_is_key)
%time F_test_keys = feature_manager.update(session, test, 'training key ents', False, get_is_key)

In [None]:
# Load the 'training key ents' features for all three splits by name.
%time F_train_keys = feature_manager.load(session, train, 'training key ents')
%time F_dev_keys = feature_manager.load(session, dev, 'training key ents')
%time F_test_keys = feature_manager.load(session, test, 'training key ents')

### Get title span features

In [None]:
from cdr_feats import get_title_span_feats

# Build the get_title_span_feats features for the training candidates under the
# name 'training title span'.
%time F_train_title_span = feature_manager.create(session, train, 'training title span', get_title_span_feats)

In [None]:
# Featurize dev and test under the same 'training title span' name
# (positional `False`: presumably "do not add new keys" — TODO confirm).
%time F_dev_title_span = feature_manager.update(session, dev, 'training title span', False, get_title_span_feats)
%time F_test_title_span = feature_manager.update(session, test, 'training title span', False, get_title_span_feats)

In [None]:
# Load the 'training title span' features for all three splits by name.
%time F_train_title_span = feature_manager.load(session, train, 'training title span')
%time F_dev_title_span = feature_manager.load(session, dev, 'training title span')
%time F_test_title_span = feature_manager.load(session, test, 'training title span')

In [None]:
# Print the repr of each training feature matrix, one per line.
for feature_matrix in (F_train_span, F_train_splits, F_train_keys, F_train_title_span):
    print(repr(feature_matrix))

# Build gold labels for the training candidates

In [None]:
import cPickle
# Per-split document relation dictionaries; the pickle holds a 3-item sequence
# (train, dev, test). Only unpickle files you trust.
with open('data/doc_relation_dict.pkl', 'rb') as relation_file:
    doc_dicts = cPickle.load(relation_file)
train_doc_dict, dev_doc_dict, test_doc_dict = doc_dicts

In [None]:
# Label per training candidate: +1 if any of its pairs is listed for its document
# in train_doc_dict, -1 if none is, and 0 (the initial value) when the document
# has no entry in train_doc_dict at all.
train_labels = np.zeros(len(train))

for idx in xrange(len(train)):
    pubmed_id, pairs = mesh_pairs_from_candidate(train[idx])
    if pubmed_id in train_doc_dict:
        matched = any((c, d) in train_doc_dict[pubmed_id] for c, d in pairs)
        train_labels[idx] = 1 if matched else -1

# Cache the labels on disk so a later cell can reload them without recomputing.
with open('taggerone-train-labels.pkl', 'wb') as labels_file:
    cPickle.dump(train_labels, labels_file)

In [None]:
# Reload the cached training labels from disk.
with open('taggerone-train-labels.pkl', 'rb') as labels_file:
    train_labels = cPickle.load(labels_file)

# Learn the generative model

In [None]:
import cPickle

# The CTD pickle holds three collections of pairs: unspecified, therapy and marker.
# Only unpickle files you trust.
with open('data/ctd.pkl', 'rb') as ctd_f:
    ctd_collections = cPickle.load(ctd_f)
ctd_unspecified, ctd_therapy, ctd_marker = ctd_collections

In [None]:
def cand_in_ctd_unspecified(c):
    """Return 1 if any pair of candidate `c` is a member of `ctd_unspecified`, else 0."""
    _, mesh_pairs = mesh_pairs_from_candidate(c)
    hits = [mesh_pair in ctd_unspecified for mesh_pair in mesh_pairs]
    if any(hits):
        return 1
    return 0

def cand_in_ctd_therapy(c):
    """Return 1 if any pair of candidate `c` is a member of `ctd_therapy`, else 0."""
    _, mesh_pairs = mesh_pairs_from_candidate(c)
    hits = [mesh_pair in ctd_therapy for mesh_pair in mesh_pairs]
    if any(hits):
        return 1
    return 0

def cand_in_ctd_marker(c):
    """Return 1 if any pair of candidate `c` is a member of `ctd_marker`, else 0."""
    _, mesh_pairs = mesh_pairs_from_candidate(c)
    hits = [mesh_pair in ctd_marker for mesh_pair in mesh_pairs]
    if any(hits):
        return 1
    return 0

In [None]:
from utils import (
    gen_LF_text_btw,
    gen_LF_span,
    gen_LF_regex,
    gen_LF_regex_AB,
    gen_LF_regex_BA,
    gen_LF_regex_A,
    gen_LF_regex_B,
    ltp,
)

In [None]:
from random import random
import re
from snorkel.lf_helpers import get_tagged_text, get_text_between

#####################################################################################
##################################### BASIC CTD #####################################
#####################################################################################

def LF_in_ctd_unspecified(c):
    """Vote -1 when the candidate has a pair in the CTD 'unspecified' set, else 0."""
    in_kb = cand_in_ctd_unspecified(c)
    return -1 * in_kb

def LF_in_ctd_therapy(c):
    """Vote -1 when the candidate has a pair in the CTD 'therapy' set, else 0."""
    in_kb = cand_in_ctd_therapy(c)
    return -1 * in_kb

def LF_in_ctd_marker(c):
    """Vote +1 when the candidate has a pair in the CTD 'marker' set, else 0."""
    in_kb = cand_in_ctd_marker(c)
    return in_kb

#####################################################################################
##################################### BASIC BTW #####################################
#####################################################################################

def LF_induce(c):
    """Vote +1 if 'induc' occurs between {{A}} and {{B}} in the tagged text,
    with at most 20 characters on each side of it (case-insensitive); else 0."""
    pattern = r'{{A}}.{0,20}induc.{0,20}{{B}}'
    if re.search(pattern, get_tagged_text(c), flags=re.I):
        return 1
    return 0

# Past-tense causal terms shared by the "B ... induced by ... A" labeling functions.
causal_past = ['induced', 'caused', 'due']

# The three functions below delegate to utils.gen_LF_regex_BA with label 1; that
# helper presumably matches the pattern between {{B}} and {{A}} — TODO confirm.

def LF_d_induced_by_c(c):
    """Causal term, then 'by'/'to' within 9 characters, with up to 50 characters of padding on each side."""
    pattern = ''.join(['.{0,50}', ltp(causal_past), '.{0,9}(by|to).{0,50}'])
    return gen_LF_regex_BA(c, pattern, 1)
def LF_d_induced_by_c_tight(c):
    """Causal term immediately followed by ' by ' or ' to ', with up to 50 characters before it."""
    pattern = ''.join(['.{0,50}', ltp(causal_past), ' (by|to) '])
    return gen_LF_regex_BA(c, pattern, 1)
def LF_d_augmented_by_c_tight(c):
    """The phrase 'augmented by ' preceded by up to 250 characters."""
    pattern = '.{0,250}augmented by '
    return gen_LF_regex_BA(c, pattern, 1)

def LF_induce_name(c):
    """Vote +1 if the chemical span text itself contains 'induc' (case-insensitive), else 0."""
    chemical_text = c.chemical.get_span().lower()
    if 'induc' in chemical_text:
        return 1
    return 0

# Causal verb patterns (regex fragments) used by LF_c_cause_d.
causal = ['cause[sd]?', 'induce[sd]?', 'associated with']
def LF_c_cause_d(c):
    """Vote +1 if a causal term links {{A}} to {{B}} and no 'not'/'no' shortly
    before the causal term matches the negated variant of the pattern; else 0."""
    causal_rgx = ltp(causal)
    asserted = r'{{A}}.{0,50} ' + causal_rgx + '.{0,50}{{B}}'
    if not re.search(asserted, get_tagged_text(c), re.I):
        return 0
    negated = '{{A}}.{0,50}(not|no).{0,20}' + causal_rgx + '.{0,50}{{B}}'
    if re.search(negated, get_tagged_text(c), re.I):
        return 0
    return 1

def LF_observe(c):
    """Vote +1 if 'observ' occurs between {{A}} and {{B}}, with at most 20
    characters on each side of it (case-insensitive); else 0."""
    pattern = r'{{A}}.{0,20}observ.{0,20}{{B}}'
    if re.search(pattern, get_tagged_text(c), flags=re.I):
        return 1
    return 0

# Treatment-related stems; every labeling function in this group votes -1.
treat = ['treat', 'effective', 'prevent', 'resistant', 'slow', 'promise', 'therap']

# These delegate to the utils.gen_LF_regex_* helpers; the suffix (BA / AB / B)
# presumably selects how the pattern is anchored to {{A}} / {{B}} — TODO confirm.

def LF_d_treat_c(c):
    """A treatment stem with up to 50 characters of padding on each side (BA variant)."""
    pattern = ''.join(['.{0,50}', ltp(treat), '.{0,50}'])
    return gen_LF_regex_BA(c, pattern, -1)
def LF_c_treat_d(c):
    """A treatment stem with up to 50 characters of padding on each side (AB variant)."""
    pattern = ''.join(['.{0,50}', ltp(treat), '.{0,50}'])
    return gen_LF_regex_AB(c, pattern, -1)
def LF_treat_d(c):
    """A treatment stem followed by up to 50 characters (B variant)."""
    pattern = ''.join([ltp(treat), '.{0,50}'])
    return gen_LF_regex_B(c, pattern, -1)
def LF_c_treat_d_wide(c):
    """A treatment stem with up to 200 characters of padding on each side (AB variant)."""
    pattern = ''.join(['.{0,200}', ltp(treat), '.{0,200}'])
    return gen_LF_regex_AB(c, pattern, -1)

# The alternation is grouped so that BOTH the {{A}} prefix and the {{B}} suffix are
# required. Ungrouped, `does|did not` split the whole pattern in two, so it fired
# on "{{A}} ... does" alone or on "did not ... {{B}}" alone.
DIDNOT_RGX = re.compile(r'{{A}}.{0,20}(?:does|did) not.{0,20}{{B}}', flags=re.I)

def LF_didnot(c):
    """Vote -1 if 'does not' / 'did not' occurs between {{A}} and {{B}}.

    At most 20 characters may separate the phrase from each placeholder;
    matching is case-insensitive. Returns 0 when the pattern does not match.
    """
    return -1 if DIDNOT_RGX.search(get_tagged_text(c)) else 0

def LF_close_CD(c):
    """Vote +1 if {{B}} follows {{A}} with 2 to 20 characters in between, else 0."""
    pattern = r'{{A}}.{2,20}{{B}}'
    if re.search(pattern, get_tagged_text(c), flags=re.I):
        return 1
    return 0

def LF_close_DC(c):
    """Vote +1 if {{A}} follows {{B}} with 2 to 20 characters in between, else 0."""
    pattern = r'{{B}}.{2,20}{{A}}'
    if re.search(pattern, get_tagged_text(c), flags=re.I):
        return 1
    return 0

def LF_c_d(c):
    """Vote +1 if {{A}} and {{B}} are adjacent, separated by a single space; else 0."""
    if '{{A}} {{B}}' in get_tagged_text(c):
        return 1
    return 0


def LF_c_induced_d(c):
    """Vote +1 if {{A}} and {{B}} are adjacent and the first span's text contains
    '-induc' or '-assoc' (case-insensitive); else 0."""
    if '{{A}} {{B}}' not in get_tagged_text(c):
        return 0
    if '-induc' in c[0].get_span().lower():
        return 1
    if '-assoc' in c[0].get_span().lower():
        return 1
    return 0

def LF_improve_before_disease(c):
    """Delegate to gen_LF_regex_B with pattern 'improv.*' and label -1
    (presumably: 'improv...' appears before {{B}} — TODO confirm helper semantics)."""
    pattern = 'improv.*'
    return gen_LF_regex_B(c, pattern, -1)

def LF_not_chemical(c):
    """Delegate to gen_LF_regex_A with pattern 'not.{0,3}' and label -1
    (presumably: 'not' just before {{A}} — TODO confirm helper semantics)."""
    pattern = 'not.{0,3}'
    return gen_LF_regex_A(c, pattern, -1)

# Both functions delegate to utils.gen_LF_regex_AB with label 1.

def LF_c_give_increases_d(c):
    """'giv' within 10 characters, then 'increas' within 25, then up to 25 more characters."""
    pattern = '.{0,10}giv.{0,25}increas.{0,25}'
    return gen_LF_regex_AB(c, pattern, 1)
def LF_c_increases_d(c):
    """'increas' with up to 25 characters of padding on each side."""
    pattern = '.{0,25}increas.{0,25}'
    return gen_LF_regex_AB(c, pattern, 1)

def LF_no_effect(c):
    """Vote -1 if 'no effect on' occurs at most 5 characters before {{B}}, else 0."""
    pattern = 'no effect on.{0,5}{{B}}'
    if re.search(pattern, get_tagged_text(c), flags=re.I):
        return -1
    return 0

# Phrases that must sit directly before {{B}}. Each one ends with a space because
# the regex appends '{{B}}' with no separator; the plural form previously lacked
# it, so (assuming ltp builds a plain alternation) "in patients with {{B}}" could
# not match.
pat_terms = ['in a patient with ', 'in patients with ']
def LF_in_patient_with(c):
    """Vote -1 if one of `pat_terms` immediately precedes {{B}} (case-insensitive), else 0."""
    pattern = ltp(pat_terms) + '{{B}}'
    return -1 if re.search(pattern, get_tagged_text(c), flags=re.I) else 0

# Stems treated as signs of an uncertain statement.
uncertain = ['combin', 'possible', 'unlikely']
def LF_uncertain(c):
    """Delegate to gen_LF_regex_A with an uncertainty stem followed by anything, label -1."""
    pattern = ''.join([ltp(uncertain), '.*'])
    return gen_LF_regex_A(c, pattern, -1)

def LF_induced_other(c):
    """Delegate to gen_LF_regex with label -1: '-induced {{B}}' occurring 20 to
    1000 characters after {{A}}."""
    pattern = '{{A}}.{20,1000}-induced {{B}}'
    return gen_LF_regex(c, pattern, -1)

# Distance pattern shared by both functions below: 100 to 5000 characters.

def LF_far_c_d(c):
    """Delegate to gen_LF_regex_AB with the 100-5000 character pattern and label -1."""
    pattern = '.{100,5000}'
    return gen_LF_regex_AB(c, pattern, -1)
def LF_far_d_c(c):
    """Delegate to gen_LF_regex_BA with the 100-5000 character pattern and label -1."""
    pattern = '.{100,5000}'
    return gen_LF_regex_BA(c, pattern, -1)

def LF_risk_d(c):
    """Delegate to gen_LF_regex_B with pattern 'risk of ' and label 1."""
    pattern = 'risk of '
    return gen_LF_regex_B(c, pattern, 1)

# Span texts for which LF_d_meaning votes -1 when followed by ' in' / ' of'.
other_meaning = ['depression']
def LF_d_meaning(c):
    """Vote -1 if the second span's lowercased text is listed in `other_meaning`
    and {{B}} is followed by ' in' or ' of'; else 0."""
    if c[1].get_span().lower() not in other_meaning:
        return 0
    if re.search(r'{{B}} (in|of)', get_tagged_text(c), flags=re.I):
        return -1
    return 0

def LF_develop_d_following_c(c):
    """Vote +1 for 'develop ... {{B}} ... following ... {{A}}', with at most 25
    characters in each gap (case-insensitive); else 0."""
    pattern = r'develop.{0,25}{{B}}.{0,25}following.{0,25}{{A}}'
    if re.search(pattern, get_tagged_text(c), flags=re.I):
        return 1
    return 0

# Word stems combined (via ltp) into the two patterns below.
procedure = ['inject', 'administrat']
occur = ['occur']
following = ['following']
def LF_c_d_occur(c):
    """Vote +1 for 'procedure ... {{A}} ... {{B}} ... occur', with at most 50
    characters in each gap (case-insensitive); else 0."""
    pattern = ''.join([ltp(procedure), '.{0,50}{{A}}.{0,50}{{B}}.{0,50}', ltp(occur)])
    if re.search(pattern, get_tagged_text(c), flags=re.I):
        return 1
    return 0
def LF_d_following_c(c):
    """Vote +1 for '{{B}} ... following ... {{A}} ... procedure' (gaps of at most
    50, 20 and 50 characters; case-insensitive); else 0."""
    pattern = ''.join(['{{B}}.{0,50}', ltp(following), '.{0,20}{{A}}.{0,50}', ltp(procedure)])
    if re.search(pattern, get_tagged_text(c), flags=re.I):
        return 1
    return 0

def LF_measure(c):
    """Vote -1 if 'measur' occurs at most 75 characters before {{A}}, else 0."""
    pattern = 'measur.{0,75}{{A}}'
    if re.search(pattern, get_tagged_text(c), flags=re.I):
        return -1
    return 0

def LF_level(c):
    """Vote -1 if ' level' occurs at most 25 characters after {{A}}, else 0."""
    pattern = '{{A}}.{0,25} level'
    if re.search(pattern, get_tagged_text(c), flags=re.I):
        return -1
    return 0

def LF_protein(c):
    """Vote -1 if 'protein' occurs at most 50 characters after {{A}}, else 0."""
    pattern = '{{A}}.{0,50}protein'
    if re.search(pattern, get_tagged_text(c), flags=re.I):
        return -1
    return 0

def LF_gene(c):
    """Vote -1 if ' gene' follows '{{A}} ' with at most 50 characters in between, else 0."""
    pattern = '{{A}} .{0,50} gene'
    if re.search(pattern, get_tagged_text(c), flags=re.I):
        return -1
    return 0

def LF_neg_d(c):
    """Vote -1 if 'none ', 'not ' or 'no ' occurs at most 25 characters before {{B}}, else 0."""
    pattern = '(none|not|no) .{0,25}{{B}}'
    if re.search(pattern, get_tagged_text(c), flags=re.I):
        return -1
    return 0

def LF_preexist(c):
    """Vote -1 if the tagged text contains the substring 'exist', else 0.

    Unlike the regex-based functions, this check is case-sensitive.
    """
    if 'exist' in get_tagged_text(c):
        return -1
    return 0

#####################################################################################
##################################### DEPEND CTD ####################################
#####################################################################################

def LF_ctd_marker_c_d(c):
    """LF_c_d's vote, kept only when the candidate is in the CTD 'marker' set (else 0)."""
    text_vote = LF_c_d(c)
    return text_vote * cand_in_ctd_marker(c)

def LF_ctd_marker_induce(c):
    """First non-zero vote of LF_c_induced_d / LF_d_induced_by_c_tight, kept only
    when the candidate is in the CTD 'marker' set (else 0)."""
    induce_vote = LF_c_induced_d(c) or LF_d_induced_by_c_tight(c)
    return induce_vote * cand_in_ctd_marker(c)

def LF_ctd_therapy_treat(c):
    """LF_c_treat_d_wide's vote, kept only when the candidate is in the CTD
    'therapy' set (else 0)."""
    treat_vote = LF_c_treat_d_wide(c)
    return treat_vote * cand_in_ctd_therapy(c)

def LF_ctd_unspecified_treat(c):
    """LF_c_treat_d_wide's vote, kept only when the candidate is in the CTD
    'unspecified' set (else 0)."""
    treat_vote = LF_c_treat_d_wide(c)
    return treat_vote * cand_in_ctd_unspecified(c)

def LF_ctd_unspecified_induce(c):
    """First non-zero vote of LF_c_induced_d / LF_d_induced_by_c_tight, kept only
    when the candidate is in the CTD 'unspecified' set (else 0)."""
    induce_vote = LF_c_induced_d(c) or LF_d_induced_by_c_tight(c)
    return induce_vote * cand_in_ctd_unspecified(c)
    
#####################################################################################
###################################### LOGICAL ######################################
#####################################################################################

# Phrases (regex fragments) that mark a sentence as a weak or non-committal assertion.
WEAK_PHRASES = [
    'none',
    'although',
    'was carried out',
    'was conducted',
    'seems',
    'suggests',
    'risk',
    'implicated',
    'the aim',
    'to (investigate|assess|study)',
]

# One alternation over all phrases; the fragments are joined unescaped.
WEAK_RGX = '|'.join(WEAK_PHRASES)

def LF_weak_assertions(c):
    """Vote -1 if any weak-assertion phrase occurs anywhere in the tagged text
    (case-insensitive); else 0."""
    if re.search(WEAK_RGX, get_tagged_text(c), flags=re.I):
        return -1
    return 0

#####################################################################################
###################################### ADVANCED #####################################
#####################################################################################

def _entity_gap(c):
    """Return the word-index gap between the chemical and disease spans of `c`."""
    chem_start, chem_end = c.chemical.get_word_start(), c.chemical.get_word_end()
    dis_start, dis_end = c.disease.get_word_start(), c.disease.get_word_end()
    if dis_start < chem_start:
        return chem_start - dis_end
    return dis_start - chem_end


def _competing_mention_nearby(sent, tag_prefix, own_tag, start, end, reach):
    """Return True if a differently tagged `tag_prefix` mention lies within `reach` words.

    Scans word indices [end, end + reach) and then [start - reach, start), both
    clipped to the sentence. A tag counts when it starts with `tag_prefix` and
    differs from `own_tag`.
    """
    window = list(range(end, min(len(sent.words), end + reach)))
    window += list(range(max(0, start - reach), start))
    for i in window:
        tag = sent.ner_tags[i]
        if tag.startswith(tag_prefix) and tag != own_tag:
            return True
    return False


def LF_closer_chem(c):
    """Vote -1 if another chemical mention is close to the disease span, else 0.

    "Close" means within half the chemical-disease gap (in words) on either side
    of the disease span. Floor division keeps the reach an integer on both
    Python 2 and Python 3.
    """
    sent = c.chemical.parent
    own_tag = sent.ner_tags[c.chemical.get_word_start()]
    reach = _entity_gap(c) // 2
    if _competing_mention_nearby(sent, 'Chemical', own_tag,
                                 c.disease.get_word_start(), c.disease.get_word_end(), reach):
        return -1
    return 0

def LF_closer_dis(c):
    """Vote -1 if another disease mention is close to the chemical span, else 0.

    "Close" means within one eighth of the chemical-disease gap (in words) on
    either side of the chemical span. Floor division keeps the reach an integer
    on both Python 2 and Python 3.
    """
    sent = c.chemical.parent
    own_tag = sent.ner_tags[c.disease.get_word_start()]
    reach = _entity_gap(c) // 8
    if _competing_mention_nearby(sent, 'Disease', own_tag,
                                 c.chemical.get_word_start(), c.chemical.get_word_end(), reach):
        return -1
    return 0

In [None]:
# Labeling functions applied to every candidate; the list is passed to
# label_manager.create(..., f=LFs) below.
# Defined above but NOT included here: LF_in_ctd_unspecified, LF_didnot,
# LF_close_CD, LF_close_DC, LF_c_give_increases_d, LF_c_d_occur, LF_gene.
LFs = [
    LF_in_ctd_therapy,
    LF_in_ctd_marker,
    LF_ctd_marker_c_d,
    LF_ctd_marker_induce,
    LF_ctd_therapy_treat,
    LF_ctd_unspecified_treat,
    LF_ctd_unspecified_induce,
    LF_induce,
    LF_d_induced_by_c,
    LF_d_induced_by_c_tight,
    LF_d_augmented_by_c_tight,
    LF_induce_name,
    LF_c_cause_d,
    LF_observe,
    LF_d_treat_c,
    LF_c_treat_d,
    LF_treat_d,
    LF_c_treat_d_wide,
    LF_c_d,
    LF_c_induced_d,
    LF_improve_before_disease,
    LF_not_chemical,
    LF_c_increases_d,
    LF_no_effect,
    LF_in_patient_with,
    LF_uncertain,
    LF_induced_other,
    LF_far_c_d,
    LF_far_d_c,
    LF_risk_d,
    LF_d_meaning,
    LF_develop_d_following_c,
    LF_d_following_c,
    LF_weak_assertions,
    LF_measure,
    LF_level,
    LF_protein,
    LF_neg_d,
    LF_preexist,
    LF_closer_chem,
    LF_closer_dis,
]

In [None]:
from snorkel.annotations import LabelManager

# Manager used to create and load the labeling-function output matrix.
label_manager = LabelManager()

In [None]:
# Apply every function in LFs to the training candidates and keep the result
# under the name 'LF Labels'; the bare `L_train` displays the matrix.
%time L_train = label_manager.create(session, train, 'LF Labels', f=LFs)
L_train

In [None]:
# Load the 'LF Labels' matrix by name (rebinds L_train) and display it.
%time L_train = label_manager.load(session, train, 'LF Labels')
L_train

In [None]:
# Per-LF statistics computed against train_labels, ordered by the 'accuracy' column.
# NOTE(review): `.sort('accuracy')` looks like the legacy pandas DataFrame.sort API
# (superseded by sort_values) — confirm against the pandas version in use.
L_train.lf_stats(train_labels).sort('accuracy')

In [None]:
from snorkel.learning.structure import DependencySelector
# Select dependencies between labeling functions from the label matrix.
# NOTE(review): `deps` is only displayed below; it is not passed to
# gen_model.train() in this notebook — confirm that is intended.
ds = DependencySelector()
deps = ds.select(L_train, threshold=0.3)

In [None]:
# Display the selected dependencies.
deps

In [None]:
from snorkel.learning import GenerativeModel
from snorkel.learning.constants import *

# Fit the generative model on the LF label matrix. The step size is scaled by
# the number of rows in L_train (0.1 / n_candidates).
gen_model = GenerativeModel(lf_prior=True, lf_propensity=True)
gen_model.train(L_train, step_size=0.1/L_train.shape[0], reg_type=2, epochs=15, decay=1.0, reg_param=0.00001)

In [None]:
# Marginals from the generative model, one per training candidate; they are
# thresholded at 0.5 below and later used as targets for the discriminative model.
train_marginals = gen_model.marginals(L_train)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

# Histogram of the training marginals (30 bins).
plt.hist(train_marginals, bins=30)
plt.show()

In [None]:
# Indices of candidates whose marginal differs from 0.5 by more than 1e-6.
cov = np.where(np.abs(train_marginals - 0.5) > 1e-6)[0]
cov_pct = 100 * float(len(cov)) / len(train_marginals)
print("Non-0.5 examples: {0} ({1:.2f}%)".format(len(cov), cov_pct))

# Thresholded model predictions vs. gold labels mapped from {-1, 1} to {0, 1}.
gen_labels = (train_marginals[cov] > 0.5)
cov_gold = (1+train_labels[cov]) / 2
overall_acc = 100 * np.mean(gen_labels == cov_gold)
print("Non-0.5 accuracy wrt train labels: {0:.3f}%".format(overall_acc))

# Accuracy restricted to gold-positive and gold-negative candidates.
pos = np.where(cov_gold > 0.5)[0]
pos_acc = 100 * np.mean(gen_labels[pos] == cov_gold[pos])
print("Positive class accuracy: {0:.3f}%".format(pos_acc))
neg = np.where(cov_gold < 0.5)[0]
neg_acc = 100 * np.mean(gen_labels[neg] == cov_gold[neg])
print("Negative class accuracy: {0:.3f}%".format(neg_acc))

In [None]:
# Decision threshold applied to both the marginals and the gold labels.
b = 0.5

# Boolean masks for predicted / gold positives and negatives.
pred_pos, pred_neg = train_marginals > b, train_marginals <= b
gold_pos, gold_neg = train_labels > b, train_labels <= b

tp_idxs = np.where(pred_pos * gold_pos)[0]
fp_idxs = np.where(pred_pos * gold_neg)[0]
fn_idxs = np.where(pred_neg * gold_pos)[0]
tn_idxs = np.where(pred_neg * gold_neg)[0]

n_tp = len(tp_idxs)
n_fp = len(fp_idxs)
n_fn = len(fn_idxs)
n_tn = len(tn_idxs)

# Precision / recall / F1, each defined as 0 when it would otherwise divide by zero.
p = float(n_tp) / (n_tp + n_fp) if n_tp > 0 else 0
r = float(n_tp) / (n_tp + n_fn) if n_tp > 0 else 0
f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0

print('#tp = {0}\n#fp = {1}\n#fn = {2}\n#tn = {3}'.format(n_tp, n_fp, n_fn, n_tn))
print('precision = {0:.3f}\nrecall = {1:.3f}\nf1 = {2:.3f}'.format(p, r, f1))

# Sanity check: the four counts add up to the number of training candidates.
print(n_tp + n_fp + n_fn + n_tn)

# Train the discriminative model (random hyperparameter search, then final fit)

In [None]:
from snorkel.learning.utils import ListParameter, RangeParameter

In [None]:
from snorkel.learning.fastmulticontext import get_matrix_keys
# Training inputs for the discriminative model, derived from the span and
# title-span feature matrices via get_matrix_keys.
train_embed_xs = get_matrix_keys([F_train_span, F_train_title_span])

In [None]:
from itertools import product
from utils import CDRFMCT, CDRRandomSearch

# Hyperparameter search space for the discriminative model.
epoch_param = RangeParameter('epoch', 20, 200, step=20)
lr_param = RangeParameter('lr', 1e-5, 0.1, step=0.5, log_base=10)
# Fixed: this parameter was also named 'lr' (copy-paste of the line above), which
# gave the search space two 'lr' entries and none for the L2 penalty. 'lambda_l2'
# is the keyword disc_model.train() is called with later in this notebook.
lambda_param = RangeParameter('lambda_l2', 1e-5, 10, step=1, log_base=10)
dim_param = RangeParameter('dim', 25, 150, step=25)
minct_opts = [1, 2, 3, 5, 7, 10, 12, 15]
minct_param = ListParameter('min_ct', minct_opts)

disc_model = CDRFMCT()

# Random search over the parameters above; the positional 20 is presumably the
# number of configurations to try — TODO confirm against CDRRandomSearch.
searcher = CDRRandomSearch(disc_model, train_marginals, train_embed_xs, 20,
                           epoch_param, lr_param, dim_param, lambda_param, minct_param)

In [None]:
from snorkel.learning.fastmulticontext import get_matrix_keys
# Development-set counterpart of train_embed_xs (same two feature types).
dev_embed_xs = get_matrix_keys([F_dev_span, F_dev_title_span])

In [None]:
from snorkel.models import Corpus
# Development corpus, looked up by name (`.one()` expects exactly one match).
dev_corpus = session.query(Corpus).filter(Corpus.name == 'CDR Development').one()

# Run the hyperparameter search, scoring on the development set at threshold
# b=0.5; the key-mention features are passed as dense arrays.
D = searcher.fit(dev_embed_xs, F_dev_keys.toarray(), dev_doc_dict, dev, dev_corpus, b=0.5,
                 raw_xs=F_train_keys.toarray(), n_threads=4, n_print=10000)

In [None]:
# Display the value returned by searcher.fit.
D

In [None]:
# Final fit with hard-coded hyperparameters and a fixed seed.
# NOTE(review): presumably these values were picked from the search results
# above — confirm.
disc_model.train(train_marginals, train_embed_xs, raw_xs=F_train_keys.toarray(),
                 epoch=160, dim=100, lr = 0.001, min_ct = [10,5], lambda_l2 = 0.01,
                 seed=1701, n_threads=4)

In [None]:
# Test-set counterpart of train_embed_xs (same two feature types).
test_embed_xs = get_matrix_keys([F_test_span, F_test_title_span])

In [None]:
# Test corpus, looked up by name (`Corpus` is imported in the development-set
# cell above, so that cell must run first).
test_corpus = session.query(Corpus).filter(Corpus.name == 'CDR Test').one()
# Score the trained model on the test set at threshold b=0.5.
buckets = disc_model.score(test_embed_xs, F_test_keys.toarray(), test_doc_dict, test, test_corpus, b=0.5)