# Disease Norm

In this example, we'll be writing an application to extract *mentions of* diseases from PubMed abstracts, using annotations from the [BioCreative CDR Challenge](http://www.biocreative.org/resources/corpora/biocreative-v-cdr-corpus/).  This tutorial, which has 5 parts, walks through the process of constructing a model to classify _candidate_ disease mentions as either true (i.e., the candidate really is a mention of a disease) or false.

# Loading Candidates + Annotations

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
from snorkel import SnorkelSession
session = SnorkelSession()

from snorkel.models import candidate_subclass

Disease = candidate_subclass('Disease', ['disease'])

In [2]:
from snorkel.models import CandidateSet

# Look up the two stored candidate sets by name and report their sizes.
train = (
    session.query(CandidateSet)
    .filter(CandidateSet.name == 'CDR Training Candidates')
    .one()
)
print(len(train))
dev = (
    session.query(CandidateSet)
    .filter(CandidateSet.name == 'CDR Development Candidates')
    .one()
)
print(len(dev))

28087
27896


In [3]:
from snorkel.annotations import LabelManager

label_manager = LabelManager()

# Load the gold label matrix stored for each split and show its shape.
L_gold_train = label_manager.load(
    session, train, "CDR Training Label Set")
print(L_gold_train.shape)
L_gold_dev = label_manager.load(
    session, dev, "CDR Development Label Set")
print(L_gold_dev.shape)

(28087, 1)
(27896, 1)


# Process / Load Features

In [4]:
from snorkel.annotations import FeatureManager

feature_manager = FeatureManager()

Process:

In [None]:
%%time
# Reuse stored feature matrices when they load; otherwise build them. Both
# splits are loaded and built against the same 'Train Features' key set.
try:
    F_train = feature_manager.load(session, train, 'Train Features')
    F_dev   = feature_manager.load(session, dev, 'Train Features')
except Exception:
    # `except Exception` instead of a bare `except`: KeyboardInterrupt and
    # SystemExit now propagate, so interrupting a slow load no longer
    # silently falls through to a full rebuild.
    F_train = feature_manager.create(session, train, 'Train Features')
    # NOTE(review): the trailing False presumably stops `update` from adding
    # new feature keys for the dev split - confirm against FeatureManager.
    F_dev   = feature_manager.update(session, dev, 'Train Features', False)

# Create canonical dictionary

### Load the MESH ID -> CID mapping

In [None]:
from cPickle import load
# Read each pickle inside a `with` block so the file handle is closed as soon
# as loading finishes instead of being left open.
with open('MESH_to_CID.pkl', 'rb') as f:
    MESH_to_CID = load(f)
with open('diseases.pkl', 'rb') as f:
    diseases = load(f)

### Create a canonical dictionary (CD)

In [None]:
from utils import CanonicalDictionary
cd = CanonicalDictionary(MESH_to_CID)

### Add MESH to CD

In [None]:
# Load MESH
from utils import load_mesh_raw
mesh_entries = load_mesh_raw('data/desc2016.xml')

In [None]:
# Register every MESH term in the canonical dictionary under its MESH id,
# together with that entry's parsed tree paths
for entry in mesh_entries:
    mid, tree_numbers, terms = entry

    # Each tree number becomes [first character, dot-separated levels...]
    paths = []
    for tree_number in tree_numbers:
        paths.append([tree_number[0]] + tree_number[1:].split('.'))

    for term in terms:
        cd.add_term(term, mid, tree_paths=paths)

len(cd.term_to_sids)

### Add MEDIC to CD

Custom CTD diseases dictionary made from MESH category C + OMIM

In [None]:
from utils import load_MEDIC, load_mesh_raw
medic_entries, MEDIC_to_CID = load_MEDIC()

In [None]:
# Add MEDIC to cd
for entry in medic_entries:
    # Use the entry's own MESH id when it has one, otherwise fall back to the
    # MESH id of its first parent; anything else is treated as an error.
    if entry.id.startswith("MESH"):
        mid = entry.id.split(":")[1]
    elif len(entry.parent_ids) > 0 and entry.parent_ids[0].startswith("MESH"):
        mid = entry.parent_ids[0].split(":")[1]
    else:
        raise KeyError(entry)
    
    # Keep the part of each tree number before any "/", then split it into
    # [first character, dot-separated levels...]
    paths = []
    for p in entry.tree_nums:
        x = p.split("/")[0]
        paths.append([x[0]] + x[1:].split('.'))
    
    # The canonical name and all synonyms are registered under the same id
    terms = [entry.name] + entry.synonyms
    for term in terms:
        cd.add_term(term, mid, tree_paths=paths)
        
len(cd.term_to_sids)

### Add UMLS to CD

This file (provided by Jason) may or may not cover all of UMLS.

In [None]:
# Each line of the TSV holds exactly three tab-separated fields: term, CUI and
# MESH id. Only the term and the MESH id are added to the dictionary.
with open('cui2mesh.tsv', 'rb') as f:
    for line in f:
        fields = line.rstrip('\n').split('\t')
        term, cui, mid = fields
        cd.add_term(term, mid)

len(cd.term_to_sids)

In [None]:
from cPickle import dump
# Write through a context manager so the pickle is flushed and the handle
# closed as soon as the dump completes.
with open('cd.pkl', 'wb') as f:
    dump(cd, f)

# Writing some multinomial LFs

# NOTE: Beware of LF rollback bug!!!

## TYPE I LF: Subsets of MESH dictionary

In [None]:
# Imported here because no earlier cell brings `defaultdict` into scope.
from collections import defaultdict

# Shaped like the `seen_global` argument of the LF generators below
# (key -> set of candidate ids); no visible cell passes it to an LF.
SEEN_GLOBAL = defaultdict(set)

## MESH exact match

In [None]:
POS_DEPTH = 3
NEG_DEPTH = 3
def LFG_CD_match(c, p, key_mod=None, seen_global=None, max_paths_per_sid=1):
    """
    Generate (key, cid) labels for phrase `p` by exact lookup in the canonical dictionary.

    For each SID listed under `p` in `cd.term_to_sids`, and for each of the first
    `max_paths_per_sid` tree paths of that SID, the key is the path truncated to
    POS_DEPTH levels (when cid > 0) or NEG_DEPTH levels, joined with "-". SIDs
    with no entry in `cd.sid_to_cid` get cid -1. If `key_mod` is given it is
    appended to the emitted key after another "-".

    When `seen_global` (mapping key -> set of candidate ids) is supplied, a key
    already recorded for `c.id` is skipped, and every emitted key is recorded
    under its un-modified form. Nothing is yielded if `p` is not in the dictionary.
    """
    if p not in cd.term_to_sids:
        return

    for sid in cd.term_to_sids[p]:
        cid = cd.sid_to_cid[sid] if sid in cd.sid_to_cid else -1

        for path in cd.tree_paths[sid][:max_paths_per_sid]:
            # A path may have fewer levels than the depth limit when the term
            # sits high up in the tree (e.g. 'cancer', 'ischemia')
            depth = POS_DEPTH if cid > 0 else NEG_DEPTH
            key = "-".join(path[:depth])

            # Relaxations of an LF sharing one `seen_global` must not emit the
            # same key twice for the same candidate
            if seen_global is not None:
                if c.id in seen_global[key]:
                    continue
                seen_global[key].add(c.id)

            emitted_key = key + "-" + key_mod if key_mod else key
            yield emitted_key, cid

In [None]:
def LFG_MESH_exact(c):
    """Exact-match LF generator: look up the lower-cased candidate span via LFG_CD_match."""
    span_text = c.disease.get_span()
    return LFG_CD_match(c, span_text.lower())

In [None]:
%%time
# Load the 'Exact Match' label matrix if it can be loaded; otherwise create it
# by applying LFG_MESH_exact to the training candidates.
try:
    L_train_exact_match = label_manager.load(session, train, 'LF Training Labels -- Exact Match')
except Exception:
    # Narrower than a bare `except`: Ctrl-C / SystemExit now propagate
    # instead of triggering a rebuild.
    L_train_exact_match = label_manager.create(
        session, train, 'LF Training Labels -- Exact Match', LFG_MESH_exact)

### Drop leading modifiers

In [None]:
MOD_RGX = r'JJ.*|VB.*|RB.*'

def LFG_drop_leading_modifiers(c):
    """
    Relaxed dictionary match that strips leading modifier tokens one at a time.

    While the first remaining token has a POS tag matching MOD_RGX (a tag that
    starts with JJ, VB or RB), that token is dropped and the remaining words,
    joined with single spaces, are looked up via LFG_CD_match with key modifier
    "DJ". Every (key, cid) pair found at every step is yielded.

    Stripping stops once a single token is left, so the empty phrase is never
    looked up and a span consisting only of modifiers no longer raises
    IndexError.
    """
    # Function-scope import: this cell precedes the notebook's `import re`.
    import re

    words    = c.disease.get_attrib_tokens()
    pos_tags = c.disease.get_attrib_tokens('pos_tags')
    # NOTE(review): the words are not lower-cased here, unlike LFG_MESH_exact -
    # confirm whether that is intended.
    while len(pos_tags) > 1 and re.match(MOD_RGX, pos_tags[0]):
        words    = words[1:]
        pos_tags = pos_tags[1:]
        p = " ".join(words)

        # Iterate the match generator once; no `seen_global` is passed, so a
        # single pass yields exactly what the old check-then-rerun did.
        for key, cid in LFG_CD_match(c, p, key_mod="DJ"):
            yield key, cid

In [None]:
%%time
# Load the 'Drop Leading' label matrix if it can be loaded; otherwise create
# it by applying LFG_drop_leading_modifiers to the training candidates.
try:
    L_train_drop_leading = label_manager.load(session, train, 'LF Training Labels -- Drop Leading')
except Exception:
    # Narrower than a bare `except`: Ctrl-C / SystemExit now propagate
    # instead of triggering a rebuild.
    L_train_drop_leading = label_manager.create(
        session, train, 'LF Training Labels -- Drop Leading', LFG_drop_leading_modifiers)

### Remove common words

In [None]:
# Imported here because this cell precedes the notebook's `import re`.
import re

# Alternation of substrings deleted from a candidate span. The alternatives
# carry no word-boundary anchors, so they also match inside longer words
# (e.g. "with" inside "withdrawal"); only `s$` is tied to the end of the span.
REMOVE_COMMON = r'.*induced|patient.*|drug|inhibitor|\d+|human|mouse|mice|rats?|with|syndrome|famil.*|s$|low(er)?|upper|left|right|top|bottom|subjects?'
def remove_common(c):
    """
    Return the lower-cased candidate span with every REMOVE_COMMON match deleted.

    Runs of two or more whitespace characters left behind are collapsed to a
    single space, and leading/trailing whitespace is stripped.
    """
    span = c.disease.get_span().lower()
    without_common = re.sub(REMOVE_COMMON, '', span)
    return re.sub(r'\s\s+', ' ', without_common).strip()

In [None]:
def LFG_MESH_exact_remove_common(c):
    """Exact-match LF generator on the span after remove_common(); keys carry the "RC" modifier."""
    return LFG_CD_match(c, remove_common(c), key_mod="RC")

In [None]:
%%time
# Load the 'Remove Common' label matrix if it can be loaded; otherwise create
# it by applying LFG_MESH_exact_remove_common to the training candidates.
try:
    L_train_remove_common = label_manager.load(session, train, 'LF Training Labels -- Remove Common')
except Exception:
    # Narrower than a bare `except`: Ctrl-C / SystemExit now propagate
    # instead of triggering a rebuild.
    L_train_remove_common = label_manager.create(
        session, train, 'LF Training Labels -- Remove Common', LFG_MESH_exact_remove_common)

## MESH TF-IDF cosine match: POSITIVE terms

Only cosine match with _positive_ (i.e. C, F03 MESH disease terms) here!

**Note: we want to avoid positive reinforcement amongst these, so track seen / only emit one:**

In [None]:
# Imported here because no earlier cell brings `defaultdict` into scope.
from collections import defaultdict

# key -> set of candidate ids; shared as `seen_global` by the positive cosine
# LFs so that only one of them emits a given key for a given candidate
SEEN_COSINE_POS = defaultdict(set)

In [None]:
%%time
from entity_norm import CanonDictVectorizer 

# Create a vectorizer from the canonical dictionary's term -> SIDs mapping
cd_vectorizer = CanonDictVectorizer(cd.term_to_sids, other_phrases=[])

# Vectorize the positive dictionary terms; the transpose is kept because
# LFG_CD_cosine_match multiplies a phrase vector by it
D_pos   = cd_vectorizer.vectorize_phrases(cd.pos_terms)
D_pos_t = D_pos.T
D_pos_t

In [None]:
POS_DEPTH = 3
NEG_DEPTH = 3
THRESH    = 0.75
def LFG_CD_cosine_match(c, p, vectorizer, D_t, terms, thresh=THRESH, seen_global=None, max_paths_per_sid=1, key_mod=None):
    """
    Generate (key, cid) labels for phrase `p` from its similarity to dictionary terms.

    `p` is vectorized with `vectorizer` and multiplied by `D_t`; column j of the
    product is matched to `terms[j]`. For every score above `thresh`, the
    matched term's SIDs and their first `max_paths_per_sid` tree paths are turned
    into keys (path truncated to POS_DEPTH levels when cid > 0, else NEG_DEPTH,
    joined with "-"), keeping only the highest-scoring cid per key.

    Each emitted key gets the suffix "-c" followed by `key_mod` verbatim (no
    separator is inserted, so callers pass e.g. "-RC"). When `seen_global`
    (mapping key -> set of candidate ids) is given, keys already recorded for
    `c.id` are skipped and newly emitted ones are recorded.

    NOTE(review): the scores are presumably cosine similarities, which assumes
    the vectorizer returns normalised rows - confirm.
    """
    # Function-scope import: no earlier cell brings `defaultdict` into scope.
    from collections import defaultdict

    cx = vectorizer.vectorize_phrases([p])
    m  = (cx * D_t).tocoo()

    # key -> (best score so far, cid that produced it)
    best_match = defaultdict(lambda: (0, None))
    for s, j in zip(m.data, m.col):
        if s <= thresh:
            continue
        for sid in cd.term_to_sids[terms[j]]:
            cid = cd.sid_to_cid[sid] if sid in cd.sid_to_cid else -1
            for path in cd.tree_paths[sid][:max_paths_per_sid]:
                key = "-".join(path[:POS_DEPTH]) if cid > 0 else "-".join(path[:NEG_DEPTH])
                if s > best_match[key][0]:
                    best_match[key] = (s, cid)

    # `.items()` exists on both Python 2 and 3, unlike `.iteritems()`
    for key, (s, cid) in best_match.items():

        # Relaxations of an LF sharing one `seen_global` must not emit the
        # same key twice for the same candidate
        if seen_global is not None:
            if c.id in seen_global[key]:
                continue
            seen_global[key].add(c.id)

        key += "-c"
        if key_mod:
            key += key_mod
        yield key, cid

In [None]:
def LFG_CD_cosine_match_pos(c):
    """Cosine-match the lower-cased span against the positive terms, de-duplicated through SEEN_COSINE_POS."""
    return LFG_CD_cosine_match(
        c,
        c.disease.get_span().lower(),
        cd_vectorizer,
        D_pos_t,
        cd.pos_terms,
        seen_global=SEEN_COSINE_POS
    )

In [None]:
%%time
# Load the 'TF-IDF Pos Terms' label matrix if it can be loaded; otherwise
# create it by applying LFG_CD_cosine_match_pos to the training candidates.
try:
    L_train_cosine_pos = label_manager.load(session, train, 'LF Training Labels -- TF-IDF Pos Terms')
except Exception:
    # Narrower than a bare `except`: Ctrl-C / SystemExit now propagate
    # instead of triggering a rebuild.
    L_train_cosine_pos = label_manager.create(
        session, train, 'LF Training Labels -- TF-IDF Pos Terms', LFG_CD_cosine_match_pos)

## Transform (T) -> pos. cosine match

In [None]:
def LFG_CD_cosine_match_RC(c):
    """
    Positive cosine match on the span after remove_common().

    Yields nothing when remove_common() leaves the lower-cased span unchanged.
    Keys carry the "-RC" modifier and are de-duplicated through SEEN_COSINE_POS.
    """
    p = remove_common(c)
    if p == c.disease.get_span().lower():
        return

    matches = LFG_CD_cosine_match(c, p, cd_vectorizer, D_pos_t, cd.pos_terms, key_mod="-RC", seen_global=SEEN_COSINE_POS)
    for key, cid in matches:
        yield key, cid

In [None]:
%%time
# Load the 'TF-IDF Pos Terms T1' label matrix if it can be loaded; otherwise
# create it by applying LFG_CD_cosine_match_RC to the training candidates.
try:
    L_train_cosine_pos_T1 = label_manager.load(session, train, 'LF Training Labels -- TF-IDF Pos Terms T1')
except Exception:
    # Narrower than a bare `except`: Ctrl-C / SystemExit now propagate
    # instead of triggering a rebuild.
    L_train_cosine_pos_T1 = label_manager.create(
        session, train, 'LF Training Labels -- TF-IDF Pos Terms T1', LFG_CD_cosine_match_RC)

## Neg cosine match

In [None]:
# Vectorize the negative dictionary terms with the same vectorizer; the
# transpose is what LFG_CD_cosine_match_neg passes on as D_t
D_neg   = cd_vectorizer.vectorize_phrases(cd.neg_terms)
D_neg_t = D_neg.T
D_neg_t

In [None]:
def LFG_CD_cosine_match_neg(c):
    """Cosine-match the lower-cased span against the negative terms with a 0.85 threshold (no seen-set)."""
    span = c.disease.get_span().lower()
    return LFG_CD_cosine_match(c, span, cd_vectorizer, D_neg_t, cd.neg_terms, thresh=0.85)

In [None]:
%%time
# Load the 'TF-IDF Neg Terms' label matrix if it can be loaded; otherwise
# create it by applying LFG_CD_cosine_match_neg to the training candidates.
try:
    L_train_cosine_neg = label_manager.load(session, train, 'LF Training Labels -- TF-IDF Neg Terms')
except Exception:
    # Narrower than a bare `except`: Ctrl-C / SystemExit now propagate
    # instead of triggering a rebuild.
    L_train_cosine_neg = label_manager.create(
        session, train, 'LF Training Labels -- TF-IDF Neg Terms', LFG_CD_cosine_match_neg)

# Putting in some negative LFs

In [None]:
import re
from lf_terms import *
from snorkel.lf_helpers import get_left_tokens, get_right_tokens
from utils import *
from Disease_Tagging_Tutorial_LFs import *
chemicals = load_chemdner_dictionary()

def LF_organs(c):
    """Vote -1 when the lower-cased, space-joined span tokens are in `organs`; otherwise abstain (0)."""
    tokens = c[0].get_attrib_tokens()
    if " ".join(tokens).lower() in organs:
        return -1
    return 0

def LF_chemical_name(c):
    """
    Vote -1 for a span that is a known chemical name and not a dictionary term.

    The space-joined tokens must be in `chemicals`, must not be all upper-case,
    and their lower-cased form must not be in `cd.term_to_sids`; otherwise 0.
    """
    phrase = " ".join(c[0].get_attrib_tokens())
    if phrase not in chemicals:
        return 0
    if phrase.isupper():
        return 0
    if phrase.lower() in cd.term_to_sids:
        return 0
    return -1

def LF_bodypart(c):
    """Vote -1 when the lower-cased span, minus one trailing "s", is in `bodypart`; otherwise abstain (0)."""
    lowered = " ".join(c[0].get_attrib_tokens()).lower()
    singular = re.sub(r's$', '', lowered)
    if singular in bodypart:
        return -1
    return 0

def LF_protein_chemical_abbrv(c):
    """
    Vote -1 when the span's space-joined lemmas contain a digit; otherwise abstain (0).

    Intended to catch gene / protein / chemical names.
    """
    lemma = " ".join(c[0].get_attrib_tokens('lemmas'))
    # Raw string: "\d" in a plain string literal is an invalid escape sequence
    return -1 if re.search(r"\d+", lemma) else 0

def LF_base_pair_seq(c):
    """Vote -1 when the space-joined lemmas consist solely of two or more of the letters G, A, C, T; else 0."""
    lemma_text = " ".join(c[0].get_attrib_tokens('lemmas'))
    if re.search("^[GACT]{2,}$", lemma_text):
        return -1
    return 0

# First batch of negative LFs. The entries not defined in this cell
# (LF_non_common_disease, LF_non_disease_acronyms, LF_gene_chromosome_link,
# LF_right_window_incomplete) presumably come from the star imports above.
# Commented-out entries are disabled.
LFs_false = [LF_chemical_name,
             LF_organs,
             LF_bodypart,
             LF_protein_chemical_abbrv,
             LF_base_pair_seq,
             #LF_too_vague,
             #LF_neg_surfix,
             LF_non_common_disease,
             LF_non_disease_acronyms,
             #LF_pos_in,
             LF_gene_chromosome_link,
             LF_right_window_incomplete,
             #LF_negative_indicator
            ]

In [None]:
%%time
# Load the 'False 1' label matrix if it can be loaded; otherwise create it by
# applying the LFs in LFs_false to the training candidates.
try:
    L_train_false_1 = label_manager.load(session, train, 'LF Training Labels -- False 1')
except Exception:
    # Narrower than a bare `except`: Ctrl-C / SystemExit now propagate
    # instead of triggering a rebuild.
    L_train_false_1 = label_manager.create(
        session, train, 'LF Training Labels -- False 1', LFs_false)

### More neg. LFs

In [None]:
# Generic words that mark a span as a non-disease. The original misspelling
# "defecit" is still accepted and the correct spelling "deficit" now matches too.
NEG_COMMON_RGX = r'(finding|disease|syndrome|marker|def(?:ec|ic)it|.*event|mean|median|mg)s?'
def LF_common_neg_phrases(c):
    """
    Vote -1 when the lower-cased span starts with a NEG_COMMON_RGX match; otherwise abstain (0).

    `re.match` anchors only at the start of the span, so longer spans that
    merely begin with one of these words (e.g. "disease progression") are
    voted -1 as well.
    """
    p = c.disease.get_span().lower()
    return -1 if re.match(NEG_COMMON_RGX, p) else 0

NEG_AFTER_WORDS = frozenset(['of', 'to'])
def LF_neg_after(c):
    """
    Vote -1 when the first right-context lemma is in NEG_AFTER_WORDS and the
    lower-cased span is not a dictionary term; otherwise abstain (0).
    """
    rw = get_right_tokens(c, window=1, attrib='lemmas')
    if len(rw) == 0 or rw[0] not in NEG_AFTER_WORDS:
        return 0
    if c.disease.get_span().lower() in cd.term_to_sids:
        return 0
    return -1

def LF_after_num(c):
    """Vote -1 when the first left-context lemma starts with a digit; otherwise abstain (0)."""
    lw = get_left_tokens(c, window=1, attrib='lemmas')
    if len(lw) > 0 and re.match(r'\d+', lw[0]):
        return -1
    return 0

def LF_too_short(c):
    """Vote -1 for spans shorter than three characters; otherwise abstain (0)."""
    span = c.disease.get_span().lower()
    if len(span) < 3:
        return -1
    return 0

BAD_ENDINGS_RGX = r'(type|trait|cell)s?$'
def LF_bad_endings(c):
    """Vote -1 when the lower-cased span ends in type(s), trait(s) or cell(s); otherwise abstain (0)."""
    span = c.disease.get_span().lower()
    if re.search(BAD_ENDINGS_RGX, span):
        return -1
    return 0

BAD_MESH_TERMS = frozenset(['disease', 'diseases', 'conversion'])
def LF_bad_MESH_entries(c):
    """Vote -1 when the lower-cased span is exactly one of BAD_MESH_TERMS; otherwise abstain (0)."""
    span = c.disease.get_span().lower()
    if span in BAD_MESH_TERMS:
        return -1
    return 0

# Second batch of negative LFs, all defined in this cell
LFs_false_2 = [
    LF_common_neg_phrases,
    LF_neg_after,
    LF_after_num,
    LF_too_short,
    LF_bad_endings,
    LF_bad_MESH_entries
]

In [None]:
%%time
# Load the 'False 2' label matrix if it can be loaded; otherwise create it by
# applying the LFs in LFs_false_2 to the training candidates.
try:
    L_train_false_2 = label_manager.load(session, train, 'LF Training Labels -- False 2')
except Exception:
    # Narrower than a bare `except`: Ctrl-C / SystemExit now propagate
    # instead of triggering a rebuild.
    L_train_false_2 = label_manager.create(
        session, train, 'LF Training Labels -- False 2', LFs_false_2)

# Combine all the LFs

Also form the binarized version of the LF matrix for doing DISEASE vs. OTHER tagging

In [None]:
from utils import binarize_LF_matrix, get_binarized_score
from snorkel.annotations import merge_annotations

# Merge the per-LF-group label matrices, in this order, into one training matrix
L_train = merge_annotations([
    L_train_exact_match,
    L_train_drop_leading,
    L_train_remove_common,
    L_train_cosine_pos,
    L_train_cosine_pos_T1,
    L_train_cosine_neg,
    L_train_false_1,
    L_train_false_2,
])

# Binarized version of the merged matrix, used for DISEASE vs. OTHER tagging
L_train_b = binarize_LF_matrix(L_train)
L_train_b

# Run the generative model

For DISEASE vs. OTHER tagging

In [None]:
from snorkel.learning import NaiveBayes

# Fit the generative model on the binarized LF matrix; `%time` is an IPython
# magic that reports how long the call takes
gen_model_train = NaiveBayes()
%time gen_model_train.train(L_train_b, n_iter=10000, rate=1e-1, verbose=True)

In [None]:
# Generative-model predictions on the training candidates, scored against the gold labels
yp_gt_train = gen_model_train.predict(L_train_b, b=0.5)
get_binarized_score(yp_gt_train, L_gold_train)

### Printing LF stats

In [None]:
# Print LF stats...
from snorkel.learning import odds_to_prob
# Use the weights of the model fitted above (`gen_model_train`); `gen_model`
# is not defined until the multinomial section at the end of the notebook.
# NOTE(review): assumes the columns of L_train_b line up with those of L_train - confirm.
lfs = L_train.lf_stats(labels=L_gold_train, est_accs=odds_to_prob(gen_model_train.w))
lfs.nlargest(50, "coverage")

# Error analysis

First, collect error buckets:

In [None]:
from random import shuffle
N_train = L_gold_train.shape[0]

# The original cell read an undefined name `yp`. The only training-set
# predictions available at this point are the generative model's.
yp = yp_gt_train

# Row indices bucketed by error type:
#   fps    - predicted positive, gold negative
#   fns    - predicted negative, gold positive
#   fns_na - prediction exactly 0, gold positive
fps    = []
fns    = []
fns_na = []
for i in range(N_train):
    if yp[i] > 0 and L_gold_train[i] < 0:
        fps.append(i)
    elif yp[i] <= 0 and L_gold_train[i] > 0:
        if yp[i] == 0:
            fns_na.append(i)
        else:
            fns.append(i)

# Shuffle so that slicing a prefix later gives a random sample
shuffle(fps)
shuffle(fns)
shuffle(fns_na)

print(len(fps))
print(len(fns))
print(len(fns_na))

Next, visualize in the `Viewer`:

In [None]:
from snorkel.viewer import SentenceNgramViewer
# Collect the candidates behind (up to) the first 100 shuffled false positives
fp_cands = []
for row_idx in fps[:100]:
    fp_cands.append(L_train.get_candidate(row_idx))

svp = SentenceNgramViewer(fp_cands, session)
svp

In [None]:
c = svp.get_selected()
c

Get all the associated labels:

In [None]:
from snorkel.models import Label
session.query(Label).filter(Label.candidate == c).all()

In [None]:
from snorkel.learning.gen_learning import odds_to_prob
i = L_train.get_row_index(c)

for j in L_train.getrow(i).nonzero()[1]:
    print L_train.get_key(j), odds_to_prob(gen_model.w[j]), int(L_train[i,j])

# Training the Discriminative Model

In [None]:
from snorkel.learning import LogReg

# Marginals from the generative model serve as training targets for the
# discriminative model fitted on the feature matrix
train_marginals = gen_model_train.marginals(L_train_b)

disc_model = LogReg()
disc_model.train(F_train, train_marginals, n_iter=2000, rate=1e-3, mu=1e-6)

In [None]:
# Discriminative-model predictions on the training features, scored against gold
yp_d_train = disc_model.predict(F_train)
get_binarized_score(yp_d_train, L_gold_train)

In [None]:
# Discriminative-model predictions on the dev features, scored against gold.
# NOTE(review): `b=0.5` is passed here but not in the training-set call - confirm it is the default.
yp_d_dev = disc_model.predict(F_dev, b=0.5)
get_binarized_score(yp_d_dev, L_gold_dev)

## What happens if we override all _unambiguous_ direct matches?

In [None]:
# Try overriding with any exact matches...
yp_d_dev_um = np.zeros(L_gold_dev.shape[0])
for i, c in enumerate(dev):
    # Count positive and non-positive labels from the three dictionary-match
    # LF generators, applied in the same order as before
    pos = 0
    neg = 0
    for lfg in (LFG_MESH_exact, LFG_drop_leading_modifiers, LFG_MESH_exact_remove_common):
        for lf_name, label in lfg(c):
            if label > 0:
                pos += 1
            else:
                neg += 1

    # Override only when the matches agree; otherwise keep the
    # discriminative model's prediction
    if pos == 0 and neg > 0:
        yp_d_dev_um[i] = -1
    elif neg == 0 and pos > 0:
        yp_d_dev_um[i] = 1
    else:
        yp_d_dev_um[i] = yp_d_dev[i]

In [None]:
get_binarized_score(yp_d_dev_um, L_gold_dev)

# We can also train the multinomial generative model...

In [None]:
from snorkel.learning.learning_mn import assemble_mn_format, LogReg

# Get data in the correct format
# Convert the (non-binarized) LF matrix into the multinomial model's input format
Xs, mn_maps, mn_inv_maps, nz_idxs = assemble_mn_format(L_train)

# Run multinomial model
# NOTE: `LogReg` here is the class imported from snorkel.learning.learning_mn
# in this cell, which rebinds the name imported earlier from snorkel.learning.
# Initial weights are all ones, one per column of L_train.
gen_model = LogReg()
gen_model.train(Xs, n_iter=100, rate=1e-2, w0=np.ones(L_train.shape[1]))

In [None]:
from utils import get_mn_score

# Marginals of the multinomial model (this rebinds `train_marginals`)
train_marginals = gen_model.marginals(Xs)

# Number of gold-positive training candidates
N_pos_train = sum(1 for i in range(L_gold_train.shape[0]) if L_gold_train[i,0] > 0)

# For each row of marginals, map the arg-max index back through that row's inverse map
predicted = [mn_inv_maps[i][np.argmax(m)] for i, m in enumerate(train_marginals)]
get_mn_score(predicted, L_gold_train[nz_idxs], N_total_pos=N_pos_train)