# II. Weakly Supervised Named Entity Recognition (NER)

We'll use the public [BioCreative V Chemical Disease Relation](https://biocreative.bioinformatics.udel.edu/tasks/biocreative-v/track-3-cdr/) (BC5CDR) dataset, focusing on Chemical entities. 

See `../applications/BC5CDR/` for the complete labeling function set used in our paper. 

## Installation Instructions

- Trove requires access to the [Unified Medical Language System](https://www.nlm.nih.gov/research/umls/licensedcontent/umlsknowledgesources.html) (UMLS), which is available after signing up for an account. If you already have a UMLS database instance, you can extract the RRF files by running `dump_umls_rrfs.sh`.
- Unzip the preprocessed BioCreative V CDR chemical dataset `bc5cdr.zip`

In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0,'../trove')


## 1. Load Unlabeled Data & Define Entity Classes

### A. Load Preprocessed Documents
This notebook assumes documents have already been preprocessed for sentence boundary detection and dumped into JSON format. See `preprocessing/README.md` for details.


In [None]:
%%time
import transformers
from trove.dataloaders import load_json_dataset

tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

data_dir = "bc5cdr/"
dataset = {
    split : load_json_dataset(f'{data_dir}/{split}.cdr.chemical.json', tokenizer)
    for split in ['train', 'dev', 'test']
}

### B. Define Entity Categories
In popular biomedical annotators such as [NCBO BioPortal](https://bioportal.bioontology.org/annotator), we configure the annotator by defining our entity classes.
Trove uses this same style of interface in an API form. For `CHEMICAL` tagging, we define an entity class consisting of [UMLS Semantic Network](https://semanticnetwork.nlm.nih.gov/) types mapped to $\{0,1\}$. The semantic network defines 127 concept categories called _Semantic Types_ (e.g., Disease or Syndrome, Medical Device), which can be mapped to 15 coarser-grained _Semantic Groups_ (e.g., Anatomy, Chemicals & Drugs, Disorders).

We use the _Chemicals & Drugs_ (CHEM) semantic group as the basis of our positive class label $1$, removing some categories (e.g., Gene or Genome) that do not match the definition of a chemical as outlined in the BC5CDR annotation guidelines. Non-chemical semantic types (STYs) define our negative class label $0$.
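
As a rough illustration of the mapping described above, the sketch below builds a small class map from an in-memory TSV. The `TUI` and `LABEL` column names follow the definition file loaded in the next cell; the `STY` column and the four rows are assumptions chosen for illustration and are not the contents of `chemical_semantic_types.tsv`.

```python
import csv
import io

# Hypothetical excerpt of an entity definition file (illustration only).
# LABEL 1 = chemical (positive class), LABEL 0 = non-chemical (negative class).
toy_tsv = (
    "TUI\tSTY\tLABEL\n"
    "T109\tOrganic Chemical\t1\n"
    "T121\tPharmacologic Substance\t1\n"
    "T047\tDisease or Syndrome\t0\n"
    "T074\tMedical Device\t0\n"
)

# Map each semantic type identifier (TUI) to its binary class label.
reader = csv.DictReader(io.StringIO(toy_tsv), delimiter="\t")
toy_class_map = {row["TUI"]: int(row["LABEL"]) for row in reader}
print(toy_class_map)
```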

In [None]:
import pandas as pd

# load the chemical entity definition
entity_def = pd.read_csv('chemical_semantic_types.tsv', sep='\t')
class_map = {row.TUI:row.LABEL for row in entity_def.itertuples()}


## 2. Load Labeling Sources
### A. Unified Medical Language System (UMLS) Metathesaurus
The UMLS Metathesaurus is a convenient source for deriving labels, since it provides over 200 source vocabularies (terminologies) with consistent entity categorization provided by the UMLS Semantic Network.

The first time this is run, Trove requires access to the source RRF files `{MRSTY.RRF, MRSAB.RRF, MRCONSO.RRF}` originally used to create the database instance. 


In [None]:
from trove.labelers.umls import UMLS

# uncommenting this line will reset the Trove UMLS cache
#UMLS.reset()

# initialize UMLS
backend = 'pandas'
if not UMLS.is_initalized(backend=backend):
    UMLS.init_from_rrfs(indir="/users/fries/desktop/RRFs", backend=backend)
    

We apply some minimal preprocessing to each source vocabulary's term set, as outlined in the Trove paper. The most important settings are:
- `SmartLowercase()`, a string matching heuristic for preserving likely abbreviations and acronyms
- `min_char_len` and `filter_rgx`, filters that remove terms that are single characters or numbers

Other choices are largely for speed purposes, such as restricting the max token length used for string matching. 
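
To make these settings concrete, here is a minimal sketch of the kind of filtering they describe. The `smart_lowercase` and `preprocess_terms` helpers are hypothetical stand-ins, not Trove's implementation: the acronym rule (short, all-uppercase tokens are kept as-is) and the exact threshold semantics are assumptions. Only the number pattern is taken from the `filter_rgx` value used in this notebook.

```python
import re

# Same pattern as the filter_rgx setting: matches plain numbers such as "12.5".
NUMBER_RGX = re.compile(r"^[-+]*[0-9]+([.][0-9]+)*$")

def smart_lowercase(term, max_acronym_len=5):
    """Hypothetical stand-in for SmartLowercase: lowercase every token
    except short all-uppercase ones, which are likely acronyms."""
    tokens = []
    for tok in term.split():
        is_acronym = tok.isupper() and len(tok) <= max_acronym_len
        tokens.append(tok if is_acronym else tok.lower())
    return " ".join(tokens)

def preprocess_terms(terms, min_char_len=2, max_tok_len=6):
    """Drop very short terms, numbers, and very long terms, then normalize case."""
    kept = set()
    for term in terms:
        if len(term) < min_char_len:          # single characters
            continue
        if NUMBER_RGX.match(term):            # plain numbers
            continue
        if len(term.split()) > max_tok_len:   # overly long terms (speed)
            continue
        kept.add(smart_lowercase(term))
    return kept

terms = ["Aspirin", "DNA", "A", "12.5", "Acetylsalicylic Acid"]
cleaned = preprocess_terms(terms)
print(sorted(cleaned))
```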


In [None]:
%%time
from trove.transforms import SmartLowercase

# options for filtering terms
config = {
    "type_mapping"  : "TUI",  # TUI = semantic types, CUI = concept ids
    'min_char_len'  : 2,
    'max_tok_len'   : 6,
    'min_dict_size' : 500,
    'stopwords'     : None,
    'transforms'    : [SmartLowercase()],
    'languages'     : {"ENG"},
    'filter_sabs'   : {"SNOMEDCT_VET"},
    'filter_rgx'    : r'''^[-+]*[0-9]+([.][0-9]+)*$'''  # filter numbers
}
   
umls = UMLS(backend=backend, **config)


### B. ADAM Biomedical Abbreviations

In [None]:
# TBD

In [None]:
%%time
import numpy as np

def map_entity_classes(dictionary, class_map):
    """
    Given a dictionary mapping each term to its semantic types, create a
    per-term probability distribution over entity classes.
    """
    k = len([y for y in set(class_map.values()) if y != -1])
    ontology = {}
    for term in dictionary:
        proba = np.zeros(shape=k).astype(np.float32)
        for cls in dictionary[term]:
            idx = class_map[cls]
            proba[idx] += 1
        ontology[term] = proba / np.sum(proba)
    return ontology

# Top ontologies as ranked by term overlap with the BC5CDR training set
terminologies = ['CHV', 'SNOMEDCT_US', 'NCI', 'MSH']

ontologies = {
    sab : map_entity_classes(umls.terminologies[sab], class_map)
    for sab in terminologies
}
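
To show what `map_entity_classes` produces, the following pure-Python sketch applies the same vote-and-normalize idea to a toy dictionary. The terms, their semantic types, and the `class_distribution` helper are hypothetical; the real dictionaries come from `umls.terminologies`.

```python
def class_distribution(semantic_types, class_map, num_classes=2):
    """Count one vote per semantic type, then normalize to a distribution."""
    votes = [0.0] * num_classes
    for sty in semantic_types:
        votes[class_map[sty]] += 1.0
    total = sum(votes)
    return [v / total for v in votes]

# Hypothetical class map and term dictionary (illustration only).
toy_class_map = {"T109": 1, "T121": 1, "T047": 0}
toy_dictionary = {
    "term_a": {"T109", "T121"},  # both types map to class 1
    "term_b": {"T109", "T047"},  # ambiguous: one vote for each class
}

toy_ontology = {
    term: class_distribution(stys, toy_class_map)
    for term, stys in toy_dictionary.items()
}
print(toy_ontology)
```

A term whose semantic types disagree ends up with a soft label (here `[0.5, 0.5]`), which is the information the label model can later weigh against other sources.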


## 3. Create Sequence Labeling Functions
### A. Guideline Labeling Functions

Annotation guidelines, the instructions provided to domain experts when labeling training data, can have a big impact on the generalizability of named entity classifiers. These instructions include seemingly simple choices, such as whether to include determiners in entity spans ("the XXX"), and more complex tagging choices, like not labeling negated mentions of drugs. These choices are baked into the dataset and expensive to change.

With weak supervision, many of these annotation assumptions can be encoded as labeling functions, making training set changes faster, more flexible, and lower cost. For our `Chemical` labeling functions, we use the instructions provided [here](https://biocreative.bioinformatics.udel.edu/media/store/files/2015/bc5_CDR_data_guidelines.pdf) (pages 5-6) to create small dictionaries encoding some of these guidelines. Note that these can be easily expanded, and in some cases complex rules (e.g., not annotating polypeptides with more than 15 amino acids) can be coupled with richer structured resources to create more sophisticated rules.

We also find it useful to include labeling functions that exclude numbers and punctuation tokens, another common flag in online biomedical annotators.
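
The number and punctuation filters can be pictured as a function from a tokenized sentence to a sparse `{word index: label}` dictionary. The sketch below is a simplified stand-in rather than Trove's `DictionaryLabelingFunction` or `RegexEachLabelingFunction`; the `label_numbers_and_punct` helper is hypothetical. It follows this notebook's convention in which `2` marks the negative class and unlabeled words are abstains.

```python
import re

NEGATIVE = 2  # negative-class label used by the labeling functions in this notebook
PUNCT = set('!"#$%&*+,./:;<=>?@[\\]^_`{|}~')
NUMBER_RGX = re.compile(r"^[-]*[1-9]+[0-9]*([.][0-9]+)*$")

def label_numbers_and_punct(words):
    """Return {word index: NEGATIVE} for number and punctuation tokens.
    Words that do not match are left out, i.e. the function abstains."""
    labels = {}
    for i, word in enumerate(words):
        if word in PUNCT or NUMBER_RGX.match(word):
            labels[i] = NEGATIVE
    return labels

words = ["Patients", "received", "20", "mg", "of", "aspirin", "."]
labels = label_numbers_and_punct(words)
print(labels)  # {2: 2, 6: 2}
```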


In [None]:
from trove.labelers.labeling import (
    OntologyLabelingFunction,
    DictionaryLabelingFunction, 
    RegexEachLabelingFunction
)

MAX_NGRAMS = 8

# load our guideline dictionaries
df = pd.read_csv('data/bc5cdr_guidelines.tsv', sep='\t')
guidelines = {
    t:np.array([0.,1.]) if y==1 else np.array([1.,0.]) 
    for t,y in zip(df.TERM, df.LABEL)
}

# use guideline negation examples as an additional stopword list
stopwords = {t:0 for t in df[df.LABEL==0].TERM}

guideline_lfs = [
    OntologyLabelingFunction('LF_guidelines', guidelines, max_ngrams=MAX_NGRAMS),
    DictionaryLabelingFunction('LF_punct', set('!"#$%&*+,./:;<=>?@[\\]^_`{|}~'), 2),
    RegexEachLabelingFunction('LF_numbers_rgx', [r'''^[-]*[1-9]+[0-9]*([.][0-9]+)*$'''], 2)
]


### B. Semantic Type Labeling Functions

The bulk of our supervision comes from structured knowledge sources such as medical ontologies.

In [None]:
ontology_lfs = [
    OntologyLabelingFunction(
        f'LF_{name}', 
        ontologies[name], 
        max_ngrams=MAX_NGRAMS, 
        stopwords=stopwords
    )
    for name in ontologies
]

### C. SynSet Labeling Functions

For biomedical concepts, abbreviations and acronyms (more generally "short forms") are a large source of ambiguity.
These can be ambiguous to human readers as well, so authors of PubMed abstracts typically define ambiguous terms when they are introduced in text. We can take advantage of this redundancy to both handle ambiguous mentions and identify out-of-ontology short forms using classic text mining techniques such as the [Schwartz-Hearst algorithm](https://psb.stanford.edu/psb-online/proceedings/psb03/schwartz.pdf).
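
The core of the Schwartz-Hearst approach can be sketched in a few lines. The version below is a simplified illustration and not the labeling function used by Trove: it only handles the "long form (SF)" pattern, applies a reduced set of short-form validity checks, and the function names are hypothetical. The right-to-left character matching follows the procedure described in the linked paper.

```python
import re

def best_long_form(short, candidate):
    """Match short-form characters right to left inside the candidate text.
    The first short-form character must start a word in the long form."""
    s_idx = len(short) - 1
    l_idx = len(candidate) - 1
    while s_idx >= 0:
        ch = short[s_idx].lower()
        if not ch.isalnum():          # skip punctuation inside the short form
            s_idx -= 1
            continue
        while l_idx >= 0 and (
            candidate[l_idx].lower() != ch
            or (s_idx == 0 and l_idx > 0 and candidate[l_idx - 1].isalnum())
        ):
            l_idx -= 1
        if l_idx < 0:
            return None               # ran out of text before matching everything
        l_idx -= 1
        s_idx -= 1
    start = candidate.rfind(" ", 0, l_idx + 1) + 1
    return candidate[start:]

def extract_abbreviations(sentence):
    """Return {short form: long form} for 'long form (SF)' definitions."""
    pairs = {}
    for m in re.finditer(r"\(([^()]+)\)", sentence):
        short = m.group(1).strip()
        if not (2 <= len(short) <= 10) or not short[0].isalnum():
            continue
        if not any(c.isalpha() for c in short):
            continue
        # Search window: at most min(|SF| + 5, |SF| * 2) words before the parenthesis.
        words = sentence[: m.start()].split()
        max_words = min(len(short) + 5, len(short) * 2)
        long_form = best_long_form(short, " ".join(words[-max_words:]))
        if long_form:
            pairs[short] = long_form
    return pairs

print(extract_abbreviations("Patients were given methotrexate (MTX) weekly."))
```

Once a short form is linked to its long form within a document, the class of the long form can be propagated to later mentions of the short form.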

In [None]:
# TBD
synset_lfs = []  # placeholder so later cells run; no SynSet labeling functions are defined yet

# synset_lfs = [
#     SchwartzHearstLabelingFunction('LF_umls_abbrvs_1', class_dictionaries[1], 1, stopwords=stopwords),
#     SchwartzHearstLabelingFunction('LF_umls_abbrvs_2', class_dictionaries[2], 2),
#     SynSetLabelingFunction('LF_specialist_synset_1', specialist_1, 1, stopwords=stopwords),
#     SynSetLabelingFunction('LF_specialist_synset_2', specialist_2, 2)
# ]


### D. Task-specific Labeling Functions

Ontology-based labeling functions can do surprisingly well on their own, but we can get more performance gains by adding custom labeling functions. For this demo, we focus on simple rules that are easy to create via data exploration, but any existing rule-based model can be transformed into a labeling function.
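
As a sketch of how a rule can be turned into word-level labels, the example below projects regex matches onto whitespace tokens. The `regex_rule_to_labels` helper and the whitespace tokenization are assumptions made for illustration; Trove's `RegexLabelingFunction` works on its own sentence objects. The pattern is the parentheses rule used in the next cell.

```python
import re

NEGATIVE = 2  # negative-class label used by the labeling functions in this notebook
PARENS_RGX = re.compile(
    r"[(](P|p|n)\s*([><=]+|(less|great)(er)*)"
    r"|(ml|mg|kg|g|(year|day|month)[s]*)[)]"
    r"|[(][0-9]+[%][)]",
    re.I,
)

def regex_rule_to_labels(text, pattern, label):
    """Assign `label` to every whitespace token that overlaps a regex match."""
    tokens = [(m.start(), m.end()) for m in re.finditer(r"\S+", text)]
    labels = {}
    for match in pattern.finditer(text):
        for i, (start, end) in enumerate(tokens):
            if start < match.end() and end > match.start():
                labels[i] = label
    return labels

text = "Mean dose was higher (p < 0.05) in 12 patients (10%)."
labels = regex_rule_to_labels(text, PARENS_RGX, NEGATIVE)
print(labels)  # {4: 2, 5: 2, 10: 2}
```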

In [None]:
import re
from trove.labelers.labeling import RegexLabelingFunction

task_specific_lfs = []

# We noticed parentheses were causing errors, so this identifies negative examples, e.g. (n=100), (10%)
parens_rgxs = [
    r'''[(](P|p|n)\s*([><=]+|(less|great)(er)*)|(ml|mg|kg|g|(year|day|month)[s]*)[)]|[(][0-9]+[%][)]'''
]
parens_rgxs = [re.compile(rgx, re.I) for rgx in parens_rgxs]
task_specific_lfs.append(RegexLabelingFunction('LF_parentheses_rgx', parens_rgxs, 2))



In [None]:
lfs = guideline_lfs + ontology_lfs + synset_lfs + task_specific_lfs

## 4. Construct the Label Matrix $\Lambda$
### A. Apply Sequence Labeling Functions

In [None]:
%%time
import itertools
from trove.labelers.core import SequenceLabelingServer

X_sents = [
    dataset['train'].sentences,
    dataset['dev'].sentences,
    dataset['test'].sentences,
]

labeler = SequenceLabelingServer(num_workers=4)
L_sents = labeler.apply(lfs, X_sents)


In [None]:
import itertools

splits = ['train', 'dev', 'test']
tag2idx = {'O':2, 'I-Chemical':1}

X_words = [
    np.array(list(itertools.chain.from_iterable([s.words for s in X_sents[i]]))) 
    for i,name in enumerate(splits)
]

X_seq_lens = [
    np.array([len(s.words) for s in X_sents[i]])
    for i,name in enumerate(splits)
]

X_doc_seq_lens = [  
    np.array([len(doc.sentences) for doc in dataset[name].documents]) 
    for i,name in enumerate(splits)
]

Y_words = [
    [dataset['train'].tagged(i)[-1] for i in range(len(dataset['train']))],
    [dataset['dev'].tagged(i)[-1] for i in range(len(dataset['dev']))],
    [dataset['test'].tagged(i)[-1] for i in range(len(dataset['test']))],
]

Y_words[0] = np.array([tag2idx[t] for t in list(itertools.chain.from_iterable(Y_words[0]))])
Y_words[1] = np.array([tag2idx[t] for t in list(itertools.chain.from_iterable(Y_words[1]))])
Y_words[2] = np.array([tag2idx[t] for t in list(itertools.chain.from_iterable(Y_words[2]))])


### B. Build the Label Matrix

In [None]:
%%time
from scipy.sparse import dok_matrix, vstack, csr_matrix

def create_word_lf_mat(Xs, Ls, num_lfs):
    """
    Create word-level LF matrix from LFs indexed by sentence/word
    0 words X lfs
    1 words X lfs
    2 words X lfs
    ...
    
    """
    Yws = []
    for sent_i in range(len(Xs)):
        ys = dok_matrix((len(Xs[sent_i].words), num_lfs))
        for lf_i in range(num_lfs):
            for word_i,y in Ls[sent_i][lf_i].items():
                ys[word_i, lf_i] = y
        Yws.append(ys)
    return csr_matrix(vstack(Yws))

L_words = [
    create_word_lf_mat(X_sents[0], L_sents[0], len(lfs)),
    create_word_lf_mat(X_sents[1], L_sents[1], len(lfs)),
    create_word_lf_mat(X_sents[2], L_sents[2], len(lfs)),
]
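
The stacking performed by `create_word_lf_mat` can be seen on a toy example. This sketch uses plain Python lists instead of SciPy sparse matrices, and the `stack_lf_outputs` helper, sentences, and labeling function outputs are made up for illustration. Each sentence contributes one row per word, and each labeling function fills one column.

```python
def stack_lf_outputs(sentences, lf_outputs, num_lfs, abstain=0):
    """Build a dense (total words) x (num LFs) matrix from per-sentence,
    per-LF dictionaries of {word index: label}."""
    rows = []
    for words, per_lf in zip(sentences, lf_outputs):
        block = [[abstain] * num_lfs for _ in words]
        for lf_i in range(num_lfs):
            for word_i, y in per_lf[lf_i].items():
                block[word_i][lf_i] = y
        rows.extend(block)
    return rows

# Two toy sentences and two toy labeling functions (illustration only).
sentences = [["Aspirin", "reduced", "pain"], ["(", "n", "=", "10", ")"]]
lf_outputs = [
    {0: {0: 1}, 1: {}},      # sentence 0: LF 0 labels word 0 as positive
    {0: {}, 1: {3: 2}},      # sentence 1: LF 1 labels word 3 as negative
]

L = stack_lf_outputs(sentences, lf_outputs, num_lfs=2)
for row in L:
    print(row)
```

Because most entries are abstains, the notebook stores the same structure sparsely with `dok_matrix` and `csr_matrix`.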


### C. Inspect Labeling Function Performance

In [None]:
from trove.metrics.analysis import lf_summary

# one name per label matrix column, in the same order as `lfs`
lf_summary(L_words[0], Y=Y_words[0], lf_names=[lf.name for lf in lfs])


## 5. Train the Label Model

In [None]:
# Trove uses a different internal mapping for labeling function abstains
def convert_label_matrix(L):
    # abstain is -1
    # negative is 0
    L = L.toarray().copy()
    L[L == 0] = -1
    L[L == 2] = 0
    return L

L_words_hat = [
    convert_label_matrix(L_words[0]),
    convert_label_matrix(L_words[1]),
    convert_label_matrix(L_words[2])
]
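
The remapping in `convert_label_matrix` is small but order-sensitive: abstains (`0`) must become `-1` before negatives (`2`) become `0`, otherwise the negatives would be turned into abstains as well. A lookup table makes the intended mapping explicit. This is a pure-Python sketch with a hypothetical `convert_row` helper and toy rows, not a replacement for the NumPy version above.

```python
# Labeling function convention -> label model convention
#   0 (abstain)  -> -1
#   2 (negative) ->  0
#   1 (positive) ->  1
REMAP = {0: -1, 2: 0, 1: 1}

def convert_row(row):
    """Remap one row of labeling function outputs in a single pass."""
    return [REMAP[int(y)] for y in row]

L = [[1, 0], [0, 2], [2, 2]]
L_hat = [convert_row(row) for row in L]
print(L_hat)  # [[1, -1], [-1, 0], [0, 0]]
```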

Y_words_hat = [
    np.array([0 if y == 2 else 1 for y in Y_words[0]]),
    np.array([0 if y == 2 else 1 for y in Y_words[1]]),
    np.array([0 if y == 2 else 1 for y in Y_words[2]])
]


In [None]:
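import functools

# Assumed import paths (not shown elsewhere in this notebook); adjust to your install
from snorkel.labeling.model import LabelModel
from trove.models.model_search import grid_search

np.random.seed(1234)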

param_grid = {
    'lr': [0.01, 0.005, 0.001, 0.0001],
    'l2': [0.001, 0.0001],
    'n_epochs': [50, 100, 200, 600, 700, 1000],
    'prec_init': [0.6, 0.7, 0.8, 0.9],
    'optimizer': ["adamax"], 
    'lr_scheduler': ['constant']
}

model_class_init = {
    'cardinality': 2, 
    'verbose': True
}

n_model_search = 50
num_hyperparams = functools.reduce(lambda x,y:x*y, [len(x) for x in param_grid.values()])
print("Hyperparameter Search Space:", num_hyperparams)

label_model, best_config = grid_search(LabelModel, 
                                       model_class_init, 
                                       param_grid,
                                       train = (L_words_hat[0], Y_words_hat[0], X_seq_lens[0]),
                                       dev = (L_words_hat[1], Y_words_hat[1], X_seq_lens[1]),
                                       n_model_search=n_model_search, 
                                       val_metric='f1', 
                                       seq_eval=True,
                                       seed=1234)
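
For a sense of scale, the grid above contains 4 × 2 × 6 × 4 × 1 × 1 = 192 configurations, of which `n_model_search = 50` are tried. The sketch below enumerates the grid and draws a random subset. Whether `grid_search` samples configurations this way is an assumption, since its internals are not shown in this notebook.

```python
import itertools
import random

param_grid = {
    'lr': [0.01, 0.005, 0.001, 0.0001],
    'l2': [0.001, 0.0001],
    'n_epochs': [50, 100, 200, 600, 700, 1000],
    'prec_init': [0.6, 0.7, 0.8, 0.9],
    'optimizer': ["adamax"],
    'lr_scheduler': ['constant'],
}

# Enumerate every combination of hyperparameter values.
keys = sorted(param_grid)
configs = [
    dict(zip(keys, values))
    for values in itertools.product(*(param_grid[k] for k in keys))
]
print(len(configs))  # 192

# Draw 50 distinct configurations with a fixed seed (assumed sampling scheme).
rng = random.Random(1234)
sampled = rng.sample(configs, 50)
```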

In [None]:
from trove.analysis.error_analysis import get_coverage, eval_label_model

for i in range(3):
    get_coverage(L_words_hat[i], Y_words_hat[i])
    print("IO Tag Format")
    eval_label_model(label_model, L_words_hat[i], Y_words_hat[i], X_seq_lens[i])
    print("BIO Tag Format")
    # NOTE: Y_words_gold_hat (gold labels in BIO format) is not defined earlier in this notebook
    eval_label_model(label_model, L_words_hat[i], Y_words_gold_hat[i], X_seq_lens[i])
    print('------')
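
The quantities reported here can be approximated with simple token-level computations. The helpers below are hypothetical and are not the implementations of `get_coverage` or `eval_label_model`, which are not shown in this notebook. In particular, sequence-level (entity span) evaluation would generally give different numbers than these per-token metrics.

```python
def coverage(L, abstain=-1):
    """Fraction of words that receive at least one non-abstain label."""
    covered = sum(1 for row in L if any(y != abstain for y in row))
    return covered / len(L)

def token_prf(y_true, y_pred, positive=1):
    """Token-level precision, recall, and F1 for the positive class."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p == positive)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != positive and p == positive)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p != positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Toy label matrix (abstain = -1) and toy predictions (illustration only).
L_hat = [[1, -1], [-1, -1], [0, 0], [-1, 1]]
y_true = [1, 0, 0, 1]
y_pred = [1, 0, 0, 0]

print(coverage(L_hat))            # 0.75
print(token_prf(y_true, y_pred))  # precision 1.0, recall 0.5
```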


## 6. Export Probabilistic Labels (CoNLL)