# II. Weakly Supervised Named Entity Recognition (NER)

We'll use the public [BioCreative V Chemical Disease Relation](https://biocreative.bioinformatics.udel.edu/tasks/biocreative-v/track-3-cdr/) (BC5CDR) dataset, focusing on Chemical entities. 

See `../applications/BC5CDR/` for the complete labeling function set used in our paper. 

## Installation Instructions

- Trove requires access to the [Unified Medical Language System (UMLS)](https://www.nlm.nih.gov/research/umls/licensedcontent/umlsknowledgesources.html), which is freely available after signing up for an account with the National Library of Medicine. Visit the link above, download the latest "UMLS Metathesaurus Files" release [2020AB](https://download.nlm.nih.gov/umls/kss/2020AB/umls-2020AB-metathesaurus.zip), and run our UMLS install script (`umls_install.sh`).
- Unzip the preprocessed BioCreative V CDR chemical dataset `bc5cdr.zip`.

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0,'../../trove')


## 1. Load Unlabeled Data & Define Entity Classes

### A. Load Preprocessed Documents
This notebook assumes documents have already been preprocessed for sentence boundary detection and dumped into JSON format. See `preprocessing/README.md` for details.


In [2]:
%%time
import transformers
from trove.dataloaders import load_json_dataset

tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

data_dir = "data/bc5cdr/"
dataset = {
    split : load_json_dataset(f'{data_dir}/{split}.cdr.chemical.json', tokenizer)
    for split in ['train', 'dev', 'test']
}

Tagged Entities: 5203


Tokenization Error: Token is not a head token Annotation[Chemical](Cl|1240-1242) 19692487
Tokenization Error: Token is not a head token Annotation[Chemical](Cl|1579-1581) 15075188
Errors: Span Alignment: 2/5347 (0.0%)


Tagged Entities: 5345
Tagged Entities: 5385
CPU times: user 29.8 s, sys: 688 ms, total: 30.5 s
Wall time: 32.4 s


### B. Define Entity Categories
In popular biomedical annotators such as [NCBO BioPortal](https://bioportal.bioontology.org/annotator), we configure the annotator by selecting a set of semantic categories which define our entity class and a corresponding set of ontologies mapped to those types.  

Trove uses a similar style of interface in API form. For `CHEMICAL` tagging, we define an entity class consisting of [UMLS Semantic Network](https://semanticnetwork.nlm.nih.gov/) types mapped to $\{0,1\}$. The semantic network defines 127 concept categories called _Semantic Types_ (e.g., Disease or Syndrome, Medical Device), which are mappable to 15 coarser-grained _Semantic Groups_ (e.g., Anatomy, Chemicals & Drugs, Disorders).

We use the _Chemicals & Drugs_ (CHEM) semantic group as the basis of our positive class label $1$, removing some categories (e.g., Gene or Genome) that do not match the definition of chemical as outlined in the BC5CDR annotation guidelines. Non-chemical STYs define our negative class label $0$.

In [3]:
import pandas as pd

# load the chemical entity definition
entity_def = pd.read_csv('data/chemical_semantic_types.tsv', sep='\t')
class_map = {row.TUI:row.LABEL for row in entity_def.itertuples() if row.LABEL != -1}
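
To make the mapping concrete, here is a small self-contained sketch of the same construction on an in-memory table. The `TUI`/`LABEL` columns mirror the cell above; the three rows, the `STY` column, and the use of `-1` for an excluded type are illustrative assumptions, not the actual contents of `chemical_semantic_types.tsv`.

```python
import csv
import io

# Hypothetical excerpt: TUI = UMLS semantic type id,
# LABEL = 1 (chemical), 0 (non-chemical), -1 (excluded from the definition).
toy_tsv = (
    "TUI\tSTY\tLABEL\n"
    "T121\tPharmacologic Substance\t1\n"
    "T047\tDisease or Syndrome\t0\n"
    "T028\tGene or Genome\t-1\n"
)

rows = csv.DictReader(io.StringIO(toy_tsv), delimiter="\t")
# Same filter as the notebook cell: drop rows labeled -1.
toy_class_map = {r["TUI"]: int(r["LABEL"]) for r in rows if int(r["LABEL"]) != -1}
print(toy_class_map)
```

Types labeled `-1` never enter the map, so any term that only carries such types is later treated as an abstain rather than as a negative example.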


## 2. Load Ontology Labeling Sources
### A. Unified Medical Language System (UMLS) Metathesaurus
The UMLS Metathesaurus is a convenient source for deriving labels, since it provides over 200 source vocabularies (terminologies) with consistent entity categorization provided by the UMLS Semantic Network.

The first time this cell is run, Trove requires access to the UMLS installation zip (see the Installation Instructions above).


In [4]:
%%time
from trove.labelers.umls import UMLS

# initialize UMLS
backend = 'pandas'
if not UMLS.is_initalized(backend=backend):
    print('Please initialize the UMLS before running this notebook. See `umls_install.sh`')
    

CPU times: user 2.85 ms, sys: 3.17 ms, total: 6.03 ms
Wall time: 11.8 ms


We apply some minimal preprocessing to each source vocabulary's term set, as outlined in the Trove paper. The most important settings are:
- `SmartLowercase()`, a string matching heuristic for preserving likely abbreviations and acronyms
- `min_char_len`, `filter_rgx`, filters for terms that are single characters or numbers  

Other choices are largely for speed purposes, such as restricting the max token length used for string matching. 
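
The effect of these filters can be sketched on a handful of terms. This is an illustration only: `toy_smart_lowercase` is a hypothetical stand-in for the idea behind `SmartLowercase()` (it is not Trove's implementation), and `keep_term` simply re-applies the `min_char_len`, `max_tok_len`, stopword, and `filter_rgx` settings from the configuration below.

```python
import re

# Same number pattern as the `filter_rgx` option in the config.
NUMBER_RGX = re.compile(r"^[-+]*[0-9]+([.][0-9]+)*$")

def toy_smart_lowercase(term):
    """Lowercase a term unless it looks like an acronym (hypothetical heuristic)."""
    # Assumption: short, all-uppercase strings are likely abbreviations.
    if term.isupper() and len(term) <= 5:
        return term
    return term.lower()

def keep_term(term, min_char_len=2, max_tok_len=8, stopwords=frozenset()):
    """Apply the length, stopword, and number filters."""
    return (
        len(term) >= min_char_len
        and len(term.split()) <= max_tok_len
        and term not in stopwords
        and NUMBER_RGX.search(term) is None
    )

raw_terms = ["Aspirin", "ALS", "K", "3.5", "Sodium Chloride"]
cleaned = [toy_smart_lowercase(t) for t in raw_terms]
kept = [t for t in cleaned if keep_term(t)]
print(kept)
```

Here the single character `K` and the number `3.5` are dropped, while `ALS` keeps its casing so it is not later confused with the lowercase word "als".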


In [5]:
%%time
from trove.labelers.umls import UMLS
from trove.transforms import SmartLowercase

# English stopwords, plus capitalized variants (e.g., "the" -> "The")
with open('data/stopwords.txt', 'r') as fp:
    stopwords = set(fp.read().splitlines())
stopwords = stopwords.union({t[0].upper() + t[1:] for t in stopwords if t})

# options for filtering terms
config = {
    "type_mapping"  : "TUI",  # TUI = semantic types, CUI = concept ids
    'min_char_len'  : 2,
    'max_tok_len'   : 8,
    'min_dict_size' : 500,
    'stopwords'     : stopwords,
    'transforms'    : [SmartLowercase()],
    'languages'     : {"ENG"},
    'filter_sabs'   : {"SNOMEDCT_VET"},
    'filter_rgx'    : r'''^[-+]*[0-9]+([.][0-9]+)*$'''  # filter numbers
}

umls = UMLS(backend=backend, **config)


CPU times: user 1min 38s, sys: 8.19 s, total: 1min 47s
Wall time: 1min 43s


In [6]:
%%time
import numpy as np

def map_entity_classes(dictionary, class_map):
    """
    Given a dictionary, create the term entity class probabilities
    """
    k = len([y for y in set(class_map.values()) if y != -1])
    ontology = {}
    for term in dictionary:
        proba = np.zeros(shape=k).astype(np.float32)
        for cls in dictionary[term]:
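            # ignore abstains (types that are not part of the entity definition)
            idx = class_map[cls] if cls in class_map else -1
            if idx != -1:
                # label 1 -> index 0; label 0 -> index -1 (the last slot)
                proba[idx - 1] += 1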
        # don't include terms that don't map to any classes
        if np.sum(proba) > 0:
            ontology[term] = proba / np.sum(proba)
    return ontology

# A subset of the top-ranked ontologies, ranked by term overlap with the BC5CDR training set
terminologies = ['CHV', 'SNOMEDCT_US', 'NCI', 'MSH']

ontologies = {
    sab : map_entity_classes(umls.terminologies[sab], class_map)
    for sab in terminologies
}


CPU times: user 57.1 s, sys: 1.48 s, total: 58.6 s
Wall time: 57.9 s
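
As a worked example of what `map_entity_classes` computes for a single term, the sketch below re-implements the counting step with plain lists. It assumes two classes and the `[positive, negative]` ordering implied by the `idx - 1` indexing above; the toy class map and the unmapped id `T999` are hypothetical.

```python
def term_class_proba(tuis, class_map, k=2):
    """Turn a term's set of semantic types into a class probability vector."""
    counts = [0.0] * k
    for tui in tuis:
        label = class_map.get(tui, -1)
        if label != -1:
            # label 1 -> index 0; label 0 -> index -1 (the last slot)
            counts[label - 1] += 1
    total = sum(counts)
    # Terms with no mapped types are dropped (None here).
    return [c / total for c in counts] if total > 0 else None

toy_class_map = {"T109": 1, "T121": 1, "T047": 0}

print(term_class_proba({"T109", "T121"}, toy_class_map))  # only chemical types
print(term_class_proba({"T121", "T047"}, toy_class_map))  # ambiguous term
print(term_class_proba({"T999"}, toy_class_map))          # no mapped types
```

A term tagged with both a chemical and a non-chemical type ends up with an even split, which is how ambiguity in the source vocabulary is passed on to the labeling function.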


In [7]:
%%time

# create dictionaries for our Schwartz-Hearst abbreviation detection labelers
positive, negative = set(), set()

for sab in umls.terminologies:
    for term in umls.terminologies[sab]:
        for tui in umls.terminologies[sab][term]:
            if tui in class_map and class_map[tui] == 1:
                positive.add(term)
            elif tui in class_map and class_map[tui] == 0:
                negative.add(term)


CPU times: user 9.06 s, sys: 445 ms, total: 9.5 s
Wall time: 9.51 s


### B. Additional Ontologies: ChEBI Database
We also want to use non-UMLS ontologies. External databases such as ChEBI or CTD typically don't include rich mappings to Semantic Network types, so we treat each one as a dictionary that maps to a single class label.

In [8]:
from ontologies import ChebiDatabase

config = {
    'min_char_len'  : 2,
    'max_tok_len'   : 8,
    'min_dict_size' : 1,
    'stopwords'     : stopwords,
    'transforms'    : [SmartLowercase()],
    'languages'     : None,
    'filter_sources': None,
    'filter_rgx'    : r'''^[-+]*[0-9]+([.][0-9]+)*$'''  # filter numbers
}
chebi = ChebiDatabase(cache_path=None, **config)

In [10]:
import os
from tqdm import tqdm
from urllib.request import urlretrieve


class ProgressBar(tqdm):
    """
    Based on https://gist.github.com/leimao/37ff6e990b3226c2c9670a2cd1e4a6f5
    """
    def update_to(self, b=1, bsize=1, tsize=None):
        """
        b  : int, optional
            Number of blocks transferred so far [default: 1].
        bsize  : int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize  : int, optional
            Total size (in tqdm units). If [default: None] remains unchanged.
        """
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)  # will also set self.n = b * bsize

def downloader(url, save_dir):
    fname = url.split('/')[-1]
    with ProgressBar(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=fname) as t:
        # save under the file name parsed from the URL
        urlretrieve(url, filename=os.path.join(save_dir, fname), reporthook=t.update_to)


downloader('ftp://ftp.ebi.ac.uk/pub/databases/chebi/Flat_file_tab_delimited/names.tsv.gz', ".")

In [20]:
import os
import requests
import re
import numpy as np
import pandas as pd
from tqdm import tqdm
from abc import ABCMeta, abstractmethod



# def download(url, outdir, block_size=1024):
#     """
#     See https://stackoverflow.com/a/37573701
#     """
#     print(url)
#     fname = url.split('/')[-1]
#     response = requests.get(url, stream=True)
#     total_bytes= int(response.headers.get('content-length', 0))
#     progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
#     with open(fname, 'wb') as file:
#         for data in response.iter_content(block_size):
#             progress_bar.update(len(data))
#             file.write(data)
#     progress_bar.close()
#     if total_bytes != 0 and progress_bar.n != total_bytes:
#         print("ERROR downloading file")

from tqdm import tqdm
import urllib.request


class ProgressBar(tqdm):
    """
    Based on https://gist.github.com/leimao/37ff6e990b3226c2c9670a2cd1e4a6f5
    """
    def update_to(self, b=1, bsize=1, tsize=None):
        """
        b  : int, optional
            Number of blocks transferred so far [default: 1].
        bsize  : int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize  : int, optional
            Total size (in tqdm units). If [default: None] remains unchanged.
        """
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)  # will also set self.n = b * bsize

def download(url, save_dir):
    fname = url.split('/')[-1]
    opener = urllib.request.build_opener()
    opener.addheaders = [("User-agent", "Mozilla/5.0")]
    urllib.request.install_opener(opener)
        
    with ProgressBar(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=fname) as t:
        urllib.request.urlretrieve(url, filename=os.path.join(save_dir, fname), reporthook=t.update_to)

def apply_transforms(term, transforms):
    for tf in transforms:
        term = tf(term.strip())
        if not term:
            return None
    return term
        
        
class KnowledgeBase(metaclass=ABCMeta):
    """
    We use Knowledge Base to loosely refer to a structured resource
    that contains terminology information. We are interested in the 
    following properties:
    
    - term typing
    - synonymy
    
    When source information is available, we store the above info mapped to source.
    
    """
    _cache_path = "cache/"
    
    def __init__(self, cache_path, files, force_download=False):
        
        self.cache_path = cache_path
        self.files = files
        
        if not self._check_cache() or force_download:
            self._download()
        else:
            print(f"Using files at {self.cache_path}")
        
    def _download(self):
        
        for fname,url in self.files.items():
            download(url, self.cache_path)
  
    def _check_cache(self):
        """
        Confirm all file dependencies exist in the cache.
        """
        if not os.path.exists(self.cache_path):
            os.makedirs(self.cache_path)
            return False
        
        for fname in self.files:
            if not os.path.exists(f"{self.cache_path}{fname}"):
                return False
        return True
        
    @abstractmethod
    def name(self):
        ...
        
    @abstractmethod
    def manifest(self):
        ...    
        
    @abstractmethod
    def _load(self, **kwargs):
        ...
    
    @abstractmethod
    def get_source_terms(self):
        ...
    
    @abstractmethod
    def get_source_synsets(self):
        ...
    
##'names.tsv.gz':'ftp://ftp.ebi.ac.uk/pub/databases/chebi/Flat_file_tab_delimited/names.tsv.gz'
              

class CtdDatabase(KnowledgeBase):
    """
    TODO: CTD contains additional entity type information we can encode as an Ontology LF
    """
    def __init__(self, cache_path=None, **kwargs):
        
        cache_root = cache_path if cache_path else KnowledgeBase._cache_path
        force_download = kwargs['force_download'] if 'force_download' in kwargs else False
        
        super().__init__(
            cache_path = f"{cache_root}{self.name}/",
            files = self.manifest,
            force_download = force_download
        )
        
        self.terms = {}
        self.data = self._load()
        
        for name,key in {'disease':'DiseaseName', 'chemical':'ChemicalName'}.items():
            self.terms[name] = self._collapse_terms(self.data[name], key)
            self.terms[name] = self._transform_terminologies(self.terms[name], **kwargs)
            
        # TODO
        self.synset = {}
        
    @property
    def name(self):
        return 'ctd'
        
    @property
    def manifest(self):
        return {
            'CTD_diseases.csv.gz' : 'http://ctdbase.org/reports/CTD_diseases.csv.gz',
            'CTD_chemicals.csv.gz' : 'http://ctdbase.org/reports/CTD_chemicals.csv.gz'
        }      
    
    def _collapse_terms(self, df, key):
        """
        CTD includes ID: terms -> synonyms. We just collapse 
        all terms into a single entity dictionary.
        """
        terms = set()
        for row in df.itertuples():
            if not pd.isnull(getattr(row, key)):
                terms.add(getattr(row, key))
            if not pd.isnull(row.Synonyms):
                for term in row.Synonyms.split("|"):
                    terms.add(term)
        return terms

    
    def _load_disease_data(self):
        
        columns = [
            'DiseaseName',
            'DiseaseID',
            'AltDiseaseIDs',
            'Definition',
            'ParentIDs',
            'TreeNumbers',
            'ParentTreeNumbers',
            'Synonyms',
            'SlimMappings'
        ]
        
        fpath = f"{self.cache_path}/CTD_diseases.csv.gz"
        return pd.read_csv(
            fpath, 
            comment='#', 
            sep=',', 
            names=columns,
            dtype=str
        )
    
    def _load_chemical_data(self):
        
        columns = [
            'ChemicalName',
            'ChemicalID',
            'CasRN',
            'Definition',
            'ParentIDs',
            'TreeNumbers',
            'ParentTreeNumbers',
            'Synonyms'
        ]
        
        fpath = f"{self.cache_path}/CTD_chemicals.csv.gz"
        return pd.read_csv(
            fpath, 
            comment='#', 
            sep=',', 
            names=columns,
            dtype=str
        )
    
    def _transform_terminologies(self,
                            terms,
                            min_char_len=2,
                            max_tok_len=100,
                            transforms=None,
                            filter_rgx=r'''^[0-9]$''',
                            stopwords=None,
                            **kwargs):
        
        transforms = [] if not transforms else transforms
        filter_rgx = re.compile(filter_rgx) if filter_rgx else None
        stopwords = {} if not stopwords else stopwords

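        def include(t):
            # keep a term only if it passes the length, stopword, and regex filters;
            # when no filter regex is given, the regex check is skipped
            return t and len(t) >= min_char_len and \
                   t.count(' ') <= max_tok_len - 1 and \
                   t not in stopwords and \
                   (not filter_rgx or not filter_rgx.search(t))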
    
        tmp = set()
        for term in terms:
            term = apply_transforms(term, transforms)
            if include(term):
                tmp.add(term)
        return tmp
    
    def _load(self):
        
        return {
            "disease" : self._load_disease_data(),
            "chemical" : self._load_chemical_data()
        }
      
    def get_source_terms(self, source):
        assert source in self.data
        return self.terms[source]
    
    def get_source_synsets(self, source):
        pass
        
ctd = CtdDatabase(**config)




Using files at cache/ctd/


In [23]:

#ctd_terms = ctd.get_source_terms('disease')


In [24]:

#class BioPortalDatabase(KnowledgeBase):
#    pass
    
        
# class CtdDatabase(KnowledgeBase):
#     pass

# class SpecialistLexicon(KnowledgeBase):
#     pass



# def load_ctd_dictionary(filename, stopwords=None):
#     '''Comparative Toxicogenomics Database'''
#     stopwords = stopwords if stopwords else {}
    
#     d = {}
#     header = ['DiseaseName', 'DiseaseID', 'AltDiseaseIDs', 'Definition', 
#               'ParentIDs', 'TreeNumbers', 'ParentTreeNumbers', 'Synonyms', 
#               'SlimMappings']
        
#     synonyms = {}
#     dnames = {}
#     with open(filename,"r") as fp:
#         for i,line in enumerate(fp):
#             line = line.strip()
#             if line[0] == "#":
#                 continue
#             row = line.split("\t")
#             if len(row) != 9:
#                 continue
#             row = dict(zip(header,row))
            
#             synset = row["Synonyms"].strip().split("|")
#             if synset:
#                 synonyms.update(dict.fromkeys(synset))
#             term = row["DiseaseName"].strip()
#             if term:
#                 dnames[term] = 1
    
#     terms = {lowercase(t) for t in set(list(synonyms.keys()) + list(dnames.keys())) if t}
#     # filter out stopwords 
#     return {t for t in terms if t not in stopwords and not re.search(r'''^[0-9]$''',t)}

# class AdamDictionary

#     def get_url(self) -> str:
#         return (
#             "http://arrowsmith.psych.uic.edu/arrowsmith_uic/download/adam.tar"
#         )


# class CTD:
#     """
#     Comparative Toxicogenomics Database
#     """
    
#     _cfg = {
#         'url': 'ftp://ftp.ebi.ac.uk/pub/databases/chebi/Flat_file_tab_delimited/names.tsv.gz'
#     }
#     _cache_path = "cache/chebi/"
    
#     def __init__(self, cache_path, **kwargs):
#         self.cache_path = cache_path
#         self.df = self._load_terminologies(**kwargs)
        
#     def terms(self, filter_sources=None):
        
#         filter_sources = filter_sources if filter_sources else {}
#         terms = set()
#         for source in self.terminologies:
#             if source in filter_sources:
#                 continue
#             terms = terms.union(self.terminologies[source])
#         return terms
        
#     def _load_terminologies(self,
#                             min_char_len=2,
#                             max_tok_len=100,
#                             min_dict_size=1,
#                             languages=None,
#                             transforms=None,
#                             filter_sources=None,
#                             filter_rgx=None,
#                             stopwords=None):
           
#         # defaults
#         languages = languages if languages else {}
#         transforms = [] if not transforms else transforms
#         filter_sources = filter_sources if filter_sources else {}
#         filter_rgx = re.compile(filter_rgx) if filter_rgx else None
#         stopwords = {} if not stopwords else stopwords

#         def include(t):
#             return t and len(t) >= min_char_len and \
#                    t.count(' ') <= max_tok_len - 1 and \
#                    t not in stopwords and \
#                    (filter_rgx and not filter_rgx.search(t))
    
#         df = pd.read_csv('/users/fries/downloads/names.tsv', 
#                          sep='\t', 
#                          na_filter=False, 
#                          dtype={'NAME':'object', 'COMPOUND_ID':'object'})
        
#         self.terminologies = {}
#         if languages:
#             df = df[df.LANGUAGE.isin(languages)]
            
#         for source, data in df.groupby(['SOURCE']):
#             if source in filter_sources or len(data) < min_dict_size:
#                 continue
#             self.terminologies[source] = set()
            
#             for term in data.NAME:
#                 term = apply_transforms(term, transforms)
#                 if include(term):
#                     self.terminologies[source].add(term)
#         self.data = df


### C. ADAM Biomedical Abbreviations

In [25]:
# TBD

## 3. Create Sequence Labeling Functions
### A. Guideline Labeling Functions

Annotation guidelines (the instructions provided to domain experts when labeling training data) can have a big impact on the generalizability of named entity classifiers. These instructions include seemingly simple choices, such as whether to include determiners in entity spans ("the XXX"), and more complex tagging choices, like not labeling negated mentions of drugs. These choices are baked into the dataset and expensive to change. 

With weak supervision, many of these annotation assumptions can be encoded as labeling functions, making training set changes faster, more flexible, and lower cost. For our `Chemical` labeling functions, we use the instructions provided [here](https://biocreative.bioinformatics.udel.edu/media/store/files/2015/bc5_CDR_data_guidelines.pdf) (pages 5-6) to create small dictionaries encoding some of these guidelines. Note that these can be easily expanded on, and in some cases complex rules (e.g., not annotating polypeptides with more than 15 amino acids) can be coupled with richer structured resources to create more sophisticated rules. 

We also find it useful to include labeling functions that exclude numbers and punctuation tokens, another common flag in online biomedical annotators. 
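
To give a feel for what a dictionary-style labeling function does, here is a minimal sketch in plain Python. It is not Trove's `DictionaryLabelingFunction`; the function name, the greedy longest-match strategy, and the lowercase matching are assumptions made for illustration. The output format (a dict from token index to label, with missing indices meaning "abstain") follows how labeling function outputs are consumed later in this notebook, with `1` for Chemical and `2` for non-chemical.

```python
def toy_dictionary_lf(words, terms, label, max_ngram=3):
    """Vote `label` on every token covered by a dictionary match."""
    votes = {}
    i = 0
    while i < len(words):
        matched = False
        # Try the longest n-gram first so multi-word terms win.
        for n in range(min(max_ngram, len(words) - i), 0, -1):
            if " ".join(words[i:i + n]).lower() in terms:
                for j in range(i, i + n):
                    votes[j] = label
                i += n
                matched = True
                break
        if not matched:
            i += 1  # no match starts here: abstain on this token
    return votes

words = "Patients received sodium chloride and aspirin daily".split()
chem_votes = toy_dictionary_lf(words, {"sodium chloride", "aspirin"}, label=1)
stop_votes = toy_dictionary_lf(words, {"and"}, label=2)
print(chem_votes, stop_votes)
```

Each function only votes where it has evidence, so the positive dictionary and the stopword list above cover disjoint tokens and leave the rest to other labeling functions.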


In [26]:
from trove.labelers.labeling import (
    OntologyLabelingFunction,
    DictionaryLabelingFunction, 
    RegexEachLabelingFunction
)

# load our guideline dictionaries
df = pd.read_csv('data/bc5cdr_guidelines.tsv', sep='\t',)
guidelines = {
    t:np.array([1.,0.]) if y==1 else np.array([0.,1.]) 
    for t,y in zip(df.TERM, df.LABEL)
}

# use guideline negative examples as an additional stopword list
guideline_stopwords = {t:2 for t in df[df.LABEL==0].TERM}
stopwords = {t:2 for t in stopwords}

guideline_lfs = [
    OntologyLabelingFunction('guidelines', guidelines),
    DictionaryLabelingFunction('stopwords', stopwords, 2),
    DictionaryLabelingFunction('punctuation', set('!"#$%&*+,./:;<=>?@[\\]^_`{|}~'), 2),
    RegexEachLabelingFunction('numbers', [r'''^[-]*[1-9]+[0-9]*([.][0-9]+)*$'''], 2)
]


### B. Semantic Type Labeling Functions

The bulk of our supervision comes from structured medical ontologies. 

In [27]:
%%time
from trove.labelers.abbreviations import SchwartzHearstLabelingFunction

ontology_lfs = [
    OntologyLabelingFunction(
        f'UMLS_{name}', 
        ontologies[name], 
        stopwords=guideline_stopwords 
    )
    for name in ontologies
]

ontology_lfs += [
    SchwartzHearstLabelingFunction('UMLS_schwartz_hearst_1', positive, 1, stopwords=guideline_stopwords),
    SchwartzHearstLabelingFunction('UMLS_schwartz_hearst_2', negative, 2)
]


CPU times: user 16.1 s, sys: 205 ms, total: 16.3 s
Wall time: 16.2 s


In [29]:
ext_ontology_lfs = [
    DictionaryLabelingFunction('CHEBI', chebi.terms(), 1, stopwords=guideline_stopwords),
    
    #DictionaryLabelingFunction('DOID', doid.terms(), 1, stopwords=guideline_stopwords),
    #DictionaryLabelingFunction('HP', hp.terms(), 1, stopwords=guideline_stopwords),
    #DictionaryLabelingFunction('AutoNER', autoner.terms(), 1, stopwords=guideline_stopwords)
    
    DictionaryLabelingFunction('CTD_chemical', ctd.get_source_terms('chemical'), 1, stopwords=guideline_stopwords),
    DictionaryLabelingFunction('CTD_disease', ctd.get_source_terms('disease'), 2, stopwords=guideline_stopwords)
    
    
    
]


### C. SynSet Labeling Functions

For biomedical concepts, abbreviations and acronyms (more generally "short forms") are a large source of ambiguity. 
These can be ambiguous to human readers as well, so authors of PubMed abstracts typically define ambiguous terms when they are introduced in text. We can take advantage of this redundancy to both handle ambiguous mentions and identify out-of-ontology short forms using classic text mining techniques such as the [Schwartz-Hearst algorithm](https://psb.stanford.edu/psb-online/proceedings/psb03/schwartz.pdf).
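
The core of the Schwartz-Hearst approach is a right-to-left character match between a parenthesized short form and the words that precede it. The sketch below is a simplified Python rendering of that matching step, not the code behind Trove's `SchwartzHearstLabelingFunction`; it handles only the "long form (short form)" pattern and omits most of the candidate validity checks from the original algorithm. The helper names are hypothetical.

```python
import re

def find_best_long_form(short_form, candidate):
    """Match short-form characters right-to-left against the candidate text."""
    s_index = len(short_form) - 1
    l_index = len(candidate) - 1
    while s_index >= 0:
        curr = short_form[s_index].lower()
        if not curr.isalnum():
            s_index -= 1
            continue
        # Move left until the character matches; the first short-form
        # character must additionally align with the start of a word.
        while (l_index >= 0 and candidate[l_index].lower() != curr) or (
            s_index == 0 and l_index > 0 and candidate[l_index - 1].isalnum()
        ):
            l_index -= 1
        if l_index < 0:
            return None
        l_index -= 1
        s_index -= 1
    start = candidate.rfind(" ", 0, l_index + 1) + 1
    return candidate[start:]

def extract_pairs(sentence):
    """Find (short form, long form) pairs of the shape 'long form (SF)'."""
    pairs = []
    for m in re.finditer(r"\(([^()]+)\)", sentence):
        short_form = m.group(1).strip()
        if not any(c.isalpha() for c in short_form):
            continue  # skip things like "(10%)"
        words = sentence[:m.start()].split()
        # Search window size from the paper: min(|SF| + 5, |SF| * 2) words.
        window = min(len(short_form) + 5, len(short_form) * 2)
        long_form = find_best_long_form(short_form, " ".join(words[-window:]))
        if long_form:
            pairs.append((short_form, long_form))
    return pairs

pairs = extract_pairs("Patients were treated with methotrexate (MTX) for six weeks.")
print(pairs)
```

Once a pair such as `MTX` / `methotrexate` is found, the class of the long form (looked up in the positive or negative term sets built earlier) can be propagated to every later mention of the short form in the same document.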

In [31]:
# #
# # TBD
# #

# synset_lfs = [
#     SynSetLabelingFunction('SPECIALIST_synsets'),
#     SynSetLabelingFunction('ADAM_synsets'),
# ]

### D. Task-specific Labeling Functions

Ontology-based labeling functions can do surprisingly well on their own, but we can get more performance gains by adding custom labeling functions. For this demo, we focus on simple rules that are easy to create via data exploration, but any existing rule-based model can be transformed into a labeling function. 

In [32]:
import re
from trove.labelers.labeling import RegexLabelingFunction

task_specific_lfs = []

# We noticed parentheses were causing errors so this labeling function 
# identifies negative examples, e.g. (n=100), (10%)
parens_rgxs = [
    r'''[(](p|n)\s*([><=]+|(less|great)(er)*)|(ml|mg|kg|g|(year|day|month)[s]*)[)]|[(][0-9]+[%][)]'''
]
# case insensitive 
parens_rgxs = [re.compile(rgx, re.I) for rgx in parens_rgxs]
task_specific_lfs.append(RegexLabelingFunction('LF_parentheses', parens_rgxs, 2))
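
A quick way to sanity-check a pattern like this before wrapping it in a labeling function is to run it over a few hand-picked strings. The snippet below uses only the standard `re` module and the same pattern as the cell above; the example strings are made up for illustration.

```python
import re

parens_rgx = re.compile(
    r'''[(](p|n)\s*([><=]+|(less|great)(er)*)|(ml|mg|kg|g|(year|day|month)[s]*)[)]|[(][0-9]+[%][)]''',
    re.I,
)

examples = ["(n=100)", "(10%)", "(P < 0.05)", "(50 mg)", "aspirin (ASA)"]
hits = {s: parens_rgx.search(s) is not None for s in examples}
for text, hit in hits.items():
    print(f"{text!r:18} -> {'match' if hit else 'no match'}")
```

The statistical and unit expressions match, while a parenthesized abbreviation such as `(ASA)` does not, so the rule should not collide with the abbreviation labeling functions.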



In [33]:
lfs = guideline_lfs + ontology_lfs + ext_ontology_lfs #+ task_specific_lfs 

## 4. Construct the Label Matrix $\Lambda$
### A. Apply Sequence Labeling Functions

In [34]:
%%time
import itertools
from trove.labelers.core import SequenceLabelingServer

X_sents = [
    dataset['train'].sentences,
    dataset['dev'].sentences,
    dataset['test'].sentences,
]

labeler = SequenceLabelingServer(num_workers=4)
L_sents = labeler.apply(lfs, X_sents)


# Wall time: 1min 48s

Parallel(n_jobs=4)
auto block size=3495
Partitioned into 4 blocks, [3494 3495] sizes
CPU times: user 20.8 s, sys: 4.06 s, total: 24.8 s
Wall time: 1min 28s


In [35]:
import itertools

splits = ['train', 'dev', 'test']
tag2idx = {'O':2, 'I-Chemical':1}

X_words = [
    np.array(list(itertools.chain.from_iterable([s.words for s in X_sents[i]]))) 
    for i,name in enumerate(splits)
]

X_seq_lens = [
    np.array([len(s.words) for s in X_sents[i]])
    for i,name in enumerate(splits)
]

X_doc_seq_lens = [  
    np.array([len(doc.sentences) for doc in dataset[name].documents]) 
    for i,name in enumerate(splits)
]

Y_words = [
    [dataset['train'].tagged(i)[-1] for i in range(len(dataset['train']))],
    [dataset['dev'].tagged(i)[-1] for i in range(len(dataset['dev']))],
    [dataset['test'].tagged(i)[-1] for i in range(len(dataset['test']))],
]

Y_words[0] = np.array([tag2idx[t] for t in list(itertools.chain.from_iterable(Y_words[0]))])
Y_words[1] = np.array([tag2idx[t] for t in list(itertools.chain.from_iterable(Y_words[1]))])
Y_words[2] = np.array([tag2idx[t] for t in list(itertools.chain.from_iterable(Y_words[2]))])
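
The arrays above are flattened across sentences, and the per-sentence lengths in `X_seq_lens` are what allow sentence boundaries to be recovered later (for example, for sequence-level evaluation). The sketch below shows that regrouping on toy data, together with a simple span extractor for IO-style tags. The helper names are hypothetical, and the tag ids follow `tag2idx` (`1` = `I-Chemical`, `2` = `O`).

```python
from itertools import accumulate

def split_by_lengths(flat, lengths):
    """Regroup a flat sequence into sentences using per-sentence lengths."""
    ends = list(accumulate(lengths))
    starts = [0] + ends[:-1]
    return [flat[s:e] for s, e in zip(starts, ends)]

def io_spans(tags, positive=1):
    """Return (start, end) spans of consecutive positive tags, end exclusive."""
    spans, start = [], None
    for i, t in enumerate(tags):
        if t == positive and start is None:
            start = i
        elif t != positive and start is not None:
            spans.append((start, i))
            start = None
    if start is not None:
        spans.append((start, len(tags)))
    return spans

flat_tags = [2, 1, 1, 2, 1, 2, 2, 1]  # two sentences, flattened
seq_lens = [5, 3]

sentences = split_by_lengths(flat_tags, seq_lens)
spans = [io_spans(s) for s in sentences]
print(sentences, spans)
```

Splitting before extracting spans matters: without it, a positive tag at the end of one sentence and another at the start of the next would be merged into a single entity.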


### B. Build the Label Matrix

In [36]:
%%time
from scipy.sparse import dok_matrix, vstack, csr_matrix

def create_word_lf_mat(Xs, Ls, num_lfs):
    """
    Create word-level LF matrix from LFs indexed by sentence/word
    0 words X lfs
    1 words X lfs
    2 words X lfs
    ...
    
    """
    Yws = []
    for sent_i in range(len(Xs)):
        ys = dok_matrix((len(Xs[sent_i].words), num_lfs))
        for lf_i in range(num_lfs):
            for word_i,y in Ls[sent_i][lf_i].items():
                ys[word_i, lf_i] = y
        Yws.append(ys)
    return csr_matrix(vstack(Yws))

L_words = [
    create_word_lf_mat(X_sents[0], L_sents[0], len(lfs)),
    create_word_lf_mat(X_sents[1], L_sents[1], len(lfs)),
    create_word_lf_mat(X_sents[2], L_sents[2], len(lfs)),
]


CPU times: user 15.6 s, sys: 182 ms, total: 15.8 s
Wall time: 16.1 s
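
The shape of the resulting matrix is easier to see on a toy input. The sketch below mirrors `create_word_lf_mat` with dense Python lists instead of SciPy sparse matrices. It assumes, as the cell above implies, that each sentence's labeler output is indexed by labeling function and holds a dict from word index to label, with `0` meaning abstain.

```python
def toy_word_lf_matrix(sent_lengths, L_sents, num_lfs):
    """Stack per-sentence LF votes into one (words x LFs) matrix."""
    rows = []
    for n_words, lf_votes in zip(sent_lengths, L_sents):
        block = [[0] * num_lfs for _ in range(n_words)]  # 0 = abstain
        for lf_i in range(num_lfs):
            for word_i, y in lf_votes[lf_i].items():
                block[word_i][lf_i] = y
        rows.extend(block)
    return rows

toy_L_sents = [
    [{1: 1}, {0: 2, 1: 1}],  # sentence 0 (3 words): votes from LF 0 and LF 1
    [{}, {1: 2}],            # sentence 1 (2 words)
]
toy_matrix = toy_word_lf_matrix([3, 2], toy_L_sents, num_lfs=2)
for row in toy_matrix:
    print(row)
```

Every row is one token and every column one labeling function, so the five tokens of the two toy sentences produce a 5 x 2 matrix.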


### C. Inspect Labeling Function Performance

In [37]:
from trove.metrics.analysis import lf_summary

lf_summary(L_words[0], Y=Y_words[0], lf_names=[lf.name for lf in lfs])


| LF | j | Polarity | Coverage% | Overlaps% | Conflicts% | Coverage | Correct | Incorrect | Emp. Acc. |
|---|---|---|---|---|---|---|---|---|---|
| guidelines | 0 | [1.0, 2.0] | 0.006085 | 0.004694 | 0.001478 | 704 | 678 | 26 | 0.963068 |
| stopwords | 1 | 2 | 0.282796 | 0.021618 | 0.00083 | 32717 | 32649 | 68 | 0.997922 |
| punctuation | 2 | 2 | 0.099489 | 0.004279 | 0.000251 | 11510 | 11425 | 85 | 0.992615 |
| numbers | 3 | 2 | 0.035387 | 0.002809 | 0.001737 | 4094 | 3790 | 304 | 0.925745 |
| UMLS_CHV | 4 | [1.0, 2.0] | 0.352145 | 0.339888 | 0.0174 | 40740 | 39696 | 1044 | 0.974374 |
| UMLS_SNOMEDCT_US | 5 | [1.0, 2.0] | 0.334633 | 0.329619 | 0.017771 | 38714 | 37829 | 885 | 0.97714 |
| UMLS_NCI | 6 | [1.0, 2.0] | 0.397032 | 0.351661 | 0.020295 | 45933 | 45115 | 818 | 0.982191 |
| UMLS_MSH | 7 | [1.0, 2.0] | 0.181172 | 0.179504 | 0.011315 | 20960 | 20427 | 533 | 0.974571 |
| UMLS_schwartz_hearst_1 | 8 | 1 | 0.006033 | 0.005394 | 0.003371 | 698 | 649 | 49 | 0.929799 |
| UMLS_schwartz_hearst_2 | 9 | 2 | 0.01313 | 0.012058 | 0.004858 | 1519 | 1207 | 312 | 0.794602 |
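
The coverage, overlap, and conflict columns can be reproduced by hand on a small matrix. The sketch below uses the commonly used definitions (coverage: the LF votes; overlap: the LF votes and at least one other LF also votes; conflict: the LF votes and at least one other LF votes differently). That `lf_summary` follows exactly these definitions is an assumption here. Note that the values in the summary above are fractions of all tokens, despite the `%` in the column names.

```python
def lf_stats(L, j, abstain=0):
    """Coverage, overlap, and conflict rates for labeling function j."""
    n = len(L)
    covered = overlaps = conflicts = 0
    for row in L:
        if row[j] == abstain:
            continue
        covered += 1
        others = [v for k, v in enumerate(row) if k != j and v != abstain]
        if others:
            overlaps += 1
        if any(v != row[j] for v in others):
            conflicts += 1
    return covered / n, overlaps / n, conflicts / n

# Toy label matrix: 4 tokens x 3 LFs (0 = abstain, 1 = positive, 2 = negative).
toy_L = [
    [1, 1, 0],
    [1, 2, 0],
    [0, 2, 2],
    [0, 0, 0],
]
for j in range(3):
    print(j, lf_stats(toy_L, j))
```

In the toy matrix, LF 0 votes on half of the tokens, always alongside LF 1, and disagrees with it on one of the four tokens.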


## 5. Train the Label Model

In [38]:
# The labeling functions and the Snorkel label model use different encodings:
#   label matrix above: 0 = abstain, 1 = positive, 2 = negative
#   Snorkel LabelModel: -1 = abstain, 1 = positive, 0 = negative
def convert_label_matrix(L):
    L = L.toarray().copy()
    L[L == 0] = -1
    L[L == 2] = 0
    return L

L_words_hat = [
    convert_label_matrix(L_words[0]),
    convert_label_matrix(L_words[1]),
    convert_label_matrix(L_words[2])
]

Y_words_hat = [
    np.array([0 if y == 2 else 1 for y in Y_words[0]]),
    np.array([0 if y == 2 else 1 for y in Y_words[1]]),
    np.array([0 if y == 2 else 1 for y in Y_words[2]])
]


In [39]:
import functools
from trove.models.model_search import grid_search
from snorkel.labeling.model.label_model import LabelModel

np.random.seed(1234)

n = L_words_hat[0].shape[0]

param_grid = {
    'lr': [0.01, 0.005, 0.001, 0.0001],
    'l2': [0.001, 0.0001],
    'n_epochs': [50, 100, 200, 600, 700, 1000],
    'prec_init': [0.6, 0.7, 0.8, 0.9],
    'optimizer': ["adamax"], 
    'lr_scheduler': ['constant'],

#     'seed': list(np.random.randint(0,10000, 400)),
#     'mu_eps': [1 / 10 ** np.ceil(np.log10(n*100)), 
#             1 / 10 ** np.ceil(np.log10(n*10)),
#             1 / 10 ** np.ceil(np.log10(n))]
}

model_class_init = {
    'cardinality': 2, 
    'verbose': True
}

n_model_search = 25
num_hyperparams = functools.reduce(lambda x,y:x*y, [len(x) for x in param_grid.values()])
print("Hyperparameter Search Space:", num_hyperparams)


L_train      = L_words_hat[0]
Y_train      = Y_words_hat[0]
L_dev        = L_words_hat[1]
Y_dev        = Y_words_hat[1]

label_model, best_config = grid_search(LabelModel, 
                                       model_class_init, 
                                       param_grid,
                                       train = (L_train, Y_train, X_seq_lens[0]),
                                       dev = (L_dev, Y_dev, X_seq_lens[1]),
                                       n_model_search=n_model_search, 
                                       val_metric='f1', 
                                       seq_eval=True,
                                       seed=1234,
                                       tag_fmt_ckpnt='IO')

Hyperparameter Search Space: 192
Using SEQUENCE dev checkpointing
Using IO dev checkpointing
Grid search over 25 configs
[0] Label Model
[1] Label Model
[2] Label Model
[3] Label Model
[4] Label Model
[5] Label Model
[6] Label Model
[7] Label Model
[8] Label Model
{'lr': 0.0001, 'l2': 0.0001, 'n_epochs': 600, 'prec_init': 0.6, 'optimizer': 'adamax', 'lr_scheduler': 'constant'}
[TRAIN] accuracy: 97.96 | precision: 84.63 | recall: 84.34 | f1: 84.48
[DEV]   accuracy: 98.19 | precision: 86.60 | recall: 86.49 | f1: 86.55
----------------------------------------------------------------------------------------
[9] Label Model
[10] Label Model
[11] Label Model
{'lr': 0.001, 'l2': 0.001, 'n_epochs': 100, 'prec_init': 0.6, 'optimizer': 'adamax', 'lr_scheduler': 'constant'}
[TRAIN] accuracy: 98.01 | precision: 84.38 | recall: 85.34 | f1: 84.85
[DEV]   accuracy: 98.22 | precision: 86.59 | recall: 87.69 | f1: 87.13
------------------------------------------------------------------------------------
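
The relationship between the 192-point search space and the 25 evaluated configurations can be sketched with the standard library. This illustrates random search over a grid in general; how Trove's `grid_search` actually picks its configurations is not shown in this notebook, so the sampling step below is an assumption.

```python
import itertools
import random

grid = {
    "lr": [0.01, 0.005, 0.001, 0.0001],
    "l2": [0.001, 0.0001],
    "n_epochs": [50, 100, 200, 600, 700, 1000],
    "prec_init": [0.6, 0.7, 0.8, 0.9],
    "optimizer": ["adamax"],
    "lr_scheduler": ["constant"],
}

# Enumerate the full Cartesian product: 4 * 2 * 6 * 4 * 1 * 1 = 192 configs.
keys = list(grid)
all_configs = [dict(zip(keys, values)) for values in itertools.product(*grid.values())]

# Assumption: evaluate a fixed-size random subset, seeded for reproducibility.
rng = random.Random(1234)
sampled = rng.sample(all_configs, 25)

print(len(all_configs), len(sampled))
```

Sampling roughly an eighth of the grid keeps the search cheap, at the cost of possibly missing the best configuration.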

In [40]:
from trove.metrics import eval_label_model # get_coverage, 

for i in range(3):
    #get_coverage(L_words_hat[i], Y_words_hat[i])
    #print("IO Tag Format")
    #eval_label_model(label_model, L_words_hat[i], Y_words_hat[i], X_seq_lens[i])
    print("BIO Tag Format")
    eval_label_model(label_model, L_words_hat[i], Y_words_hat[i], X_seq_lens[i])
    print('-' * 80)


BIO Tag Format
[Label Model]   accuracy: 98.01 | precision: 84.38 | recall: 85.34 | f1: 84.85
[Majority Vote] accuracy: 97.57 | precision: 76.15 | recall: 83.12 | f1: 79.48
--------------------------------------------------------------------------------
BIO Tag Format
[Label Model]   accuracy: 98.22 | precision: 86.59 | recall: 87.69 | f1: 87.13
[Majority Vote] accuracy: 97.83 | precision: 78.47 | recall: 85.42 | f1: 81.80
--------------------------------------------------------------------------------
BIO Tag Format
[Label Model]   accuracy: 98.39 | precision: 86.10 | recall: 87.23 | f1: 86.66
[Majority Vote] accuracy: 97.85 | precision: 76.23 | recall: 84.52 | f1: 80.16
--------------------------------------------------------------------------------
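
The Majority Vote rows above serve as an unweighted baseline for the label model. A minimal version of such a baseline is sketched below, using the converted encoding (`-1` = abstain, `0` = negative, `1` = positive). The tie-breaking rule and the default class for fully abstained tokens are assumptions, and the baseline inside `eval_label_model` may handle them differently.

```python
from collections import Counter

def majority_vote(L, abstain=-1, default=0):
    """Predict the most frequent non-abstain vote for each token."""
    preds = []
    for row in L:
        votes = Counter(v for v in row if v != abstain)
        if not votes:
            preds.append(default)  # no LF voted: fall back to the default class
            continue
        ranked = votes.most_common()
        if len(ranked) > 1 and ranked[0][1] == ranked[1][1]:
            preds.append(default)  # tie: fall back to the default class
        else:
            preds.append(ranked[0][0])
    return preds

toy_L = [
    [1, 1, 0],     # two positive votes beat one negative
    [-1, -1, -1],  # all abstain
    [1, 0, -1],    # tie
    [0, -1, 0],    # two negative votes
]
preds = majority_vote(toy_L)
print(preds)
```

Unlike this baseline, the label model estimates an accuracy for each labeling function and weights votes accordingly, which is consistent with its higher precision in the results above.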


## 6. Export Probabilistic Labels (CoNLL)

In [None]:
#
# TBD
#