Estimate tmChem's and ChemExt's performance

#### Read test data and annotations; mask ID entities

In [1]:
# Import all the necessary packages

from typing import Tuple, Optional, List
from itertools import chain, repeat, starmap, groupby
import operator as op
import json
import glob
import time
import re

from sklearn.metrics import f1_score, precision_score, recall_score
from fn import F
import numpy as np
import tqdm
import joblib
import becas

becas.email = "ilia.korvigo@gmail.com"

from sciner import util, intervals
from sciner.corpora import corpus, chemdner
from sciner.preprocessing import encoding, preprocessing, sampling, parsing
from sciner.util import oldmap

In [2]:
# Entity-class names and their integer labels, given as "NAME:label" specs
# (presumably split into a name -> label dict by corpus.parse_mapping; the
# result is indexed by class name right below)
mapping = corpus.parse_mapping(
    ["ABBREVIATION:1",
     "FAMILY:2",
     "FORMULA:3",
     "MULTIPLE:4",
     "TRIVIAL:5",
     "SYSTEMATIC:6",
     "IDENTIFIER:7"]
)

# Label of the IDENTIFIER class. Positions carrying this label are masked out
# of both the references and the predictions, because chempred doesn't target
# identifiers at all
ID = mapping["IDENTIFIER"]

In [3]:
def process_abstracts(tokeniser, abstracts, mapping=None):
    """
    Tokenise abstracts, sample tokens within sentences and, if a mapping is
    given, encode word-level and entity-border annotations for every sample

    :param tokeniser: callable applied to each raw text
    :param abstracts: abstracts accepted by `corpus.flatten_abstract`
    :param mapping: optional mapping from entity-class names to integer labels;
                    when None, annotations are replaced by endless `None` fillers
    :return: `zip(*util.flatzip(...))`; callers in this file unpack it into six
             sequences: ids, sources, samples, tokens, word annotations and
             entity-border annotations
    """

    def flatten(arr):
        # Reduce each row of a 2D array to one integer: a randomly chosen
        # positive index among the row's nonzero positions, or 0 when there is
        # none (presumably columns are class labels and 0 is "no entity" —
        # TODO confirm against preprocessing.annotate_sample)
        def f(x):
            pos = x.nonzero()[-1]
            # NOTE: draws from numpy's global RNG, so results are reproducible
            # only if the caller seeds it
            return np.random.choice(pos[pos > 0]) if pos.any() else 0
        return np.apply_along_axis(f, 1, arr)

    flat_abstracts = map(corpus.flatten_abstract, abstracts)
    ids, srcs, texts, annotations, borders = zip(*chain.from_iterable(flat_abstracts))
    # parse texts and sample tokens within sentences
    parsed_texts = list(map(tokeniser, texts))
    samples = list(starmap(sampling.sample_sentences, zip(borders, parsed_texts)))
    # for every text: a lazy iterator over its samples, each unloaded into a list
    tokens = (F(map, F(map, intervals.unload) >> F(map, list)) >> list)(samples)
    # make annotations if necessary
    if mapping is not None:
        # number of distinct labels, with label 0 always counted
        nlabels = len(set(mapping.values()) | {0})
        anno_encoder = F(encoding.encode_annotation, mapping)
        # same encoder, but restricted to entity starts
        border_encoder = F(encoding.encode_annotation, mapping, start_only=True)
        enc_annotations = list(starmap(anno_encoder, zip(annotations, map(len, texts))))
        enc_borders = list(starmap(border_encoder, zip(annotations, map(len, texts))))
        sample_annotations = [[flatten(preprocessing.annotate_sample(nlabels, anno, s)) for s in samples_]
                              for anno, samples_ in zip(enc_annotations, samples)]
        entity_borders = [[flatten(preprocessing.annotate_sample(nlabels, b_anno, s)) for s in samples_]
                           for b_anno, samples_ in zip(enc_borders, samples)]
    else:
        # endless fillers, so that the final zip is bounded by the other inputs
        sample_annotations = repeat(repeat(None))
        entity_borders = repeat(repeat(None))
    return zip(*util.flatzip([ids, srcs], [samples, tokens, sample_annotations, entity_borders]))

In [4]:
# Second argument of every `util.join` call below
nsteps = 200
# A token is either a run of word characters or a single non-space, non-word
# character. The pattern is a raw string: "\w" and "\s" are invalid escape
# sequences in an ordinary literal and trigger a warning on modern Python
tokeniser = F(parsing.tokenise, [re.compile(r"\w+|[^\s\w]")])

# CHEMDNER test set: abstract texts, sentence borders and reference annotations
texts = chemdner.parse_abstracts("chemdner_corpus/testing.abstracts.txt")
sborders = chemdner.parse_borders("chemdner_corpus/testing.borders.tsv")
refanno = chemdner.parse_annotations("chemdner_corpus/testing.annotations.txt")

abstracts_ref = list(chemdner.align_abstracts(texts, refanno, sborders))

# w_anno: word-level labels, b_anno: entity-border (start-only) labels
ids, srcs, samples, ws, w_anno, b_anno = process_abstracts(tokeniser, abstracts_ref, mapping)

# anno_mask is later used to index both references and predictions
wanno_ref, anno_mask = util.join(w_anno, nsteps, trim=True)
banno_ref, _ = util.join(b_anno, nsteps, trim=True)

In [5]:
# Mask ID entities: zero out every position the reference labels as ID, then
# collapse the remaining class labels into a binary entity / non-entity signal
entity_filter = wanno_ref != ID
wanno_ref = np.where(entity_filter, wanno_ref, 0).clip(0, 1)
banno_ref = np.where(entity_filter, banno_ref, 0).clip(0, 1)

#### Benchmark ChemExtract annotations

In [6]:
def wrap_spans(spans) -> List[intervals.Interval]:
    """
    Convert extracted spans into intervals of the single class "ANY"

    :param spans: iterable of objects exposing `start` and `end` attributes
    :return: one `intervals.Interval(start, end, "ANY")` per span, in input order
    """
    return [intervals.Interval(span.start, span.end, "ANY") for span in spans]

# Load the dumped ChemDataExtractor mentions (records of id, title spans, body
# spans) and wrap them into abstract annotations with the single class "ANY"
anno_chemext = [corpus.AbstractAnnotation(id_, wrap_spans(title), wrap_spans(body)) 
                for id_, title, body in joblib.load("bench/chemdataextractor/cems.joblib")]

abstracts_chemext = list(chemdner.align_abstracts(texts, anno_chemext, sborders))

# Only the annotation outputs are needed; {"ANY": 1} yields binary labels.
# NOTE: process_abstracts draws from numpy's global RNG, so cell order matters
_, _, _, _, w_anno_chemext, b_anno_chemext = process_abstracts(tokeniser, abstracts_chemext, {"ANY": 1})

wanno_chemext, _ = util.join(w_anno_chemext, nsteps, trim=True)
banno_chemext, _ = util.join(b_anno_chemext, nsteps, trim=True)

# Drop predictions at positions the reference labels as ID entities
wanno_chemext_masked = np.where(entity_filter, wanno_chemext, 0)
banno_chemext_masked = np.where(entity_filter, banno_chemext, 0)

In [7]:
# Estimate precision, recall and F1 (one line each, in that order);
# every line holds the word-level score followed by the entity-border score
for score in (precision_score, recall_score, f1_score):
    print(score(wanno_ref[anno_mask], wanno_chemext_masked[anno_mask]),
          score(banno_ref[anno_mask], banno_chemext_masked[anno_mask]))

0.914208110893 0.901354664952
0.920796936022 0.90473692705
0.917490694439 0.903042629025


#### Benchmark Becas annotations

In [8]:
# Both constants are passed to `becas.export_text` in `run_becas_on_abstract`
GROUPS = {"CHED": True}  # We only want chemical entities
FORMAT = "a1"  # requested export format; the output is parsed by `parse_a1_anno`

def parse_a1_anno(a1):
    """
    Parse an a1-formatted annotation export into entity spans

    Every annotation line must hold three tab-separated fields. The second
    field is split on whitespace, its first token is dropped and the remaining
    tokens are read as integer offsets; the third field is the entity text.
    Lines starting with "#" and blank lines are skipped.

    :param a1: the export as a single string
    :return: a list of `(*offsets, entity)` tuples, or None if any line is
             malformed
    """
    try:
        # Blank lines carry no annotation; without this guard a single one
        # would fail the 3-field unpacking below and discard the whole export
        parsed = [line.split("\t") for line in a1.splitlines()
                  if line.strip() and not line.startswith("#")]
        return [(*map(int, span.split()[1:]), entity) for _, span, entity in parsed]
    except (TypeError, ValueError):
        return None


def run_becas_on_abstract(abstract: corpus.Abstract) \
        -> Tuple[int, List[Tuple[int, int, str]], List[Tuple[int, int, str]]]:
    """
    Annotate the title and the body of an abstract with Becas

    :param abstract: an abstract exposing `id`, `title` and `body`
    :return: the abstract's id, the title annotations and the body annotations,
             the latter two as produced by `parse_a1_anno`
    """
    def annotate(text):
        # request an export of `text` and parse it into entity spans
        return parse_a1_anno(becas.export_text(text, FORMAT, GROUPS))

    def offsets_match(text, anno):
        # every reported entity must equal the text slice it points at
        return all(text[start:stop] == entity for start, stop, entity in anno)

    title_text, body_text = abstract.title, abstract.body
    title_spans = annotate(title_text)
    body_spans = annotate(body_text)
    assert offsets_match(title_text, title_spans)
    assert offsets_match(body_text, body_spans)
    time.sleep(1)  # Letting the server rest to avoid blockage
    return abstract.id, title_spans, body_spans

In [9]:
# The Becas annotations were collected once with the commented-out lines below
# and dumped to disk; they are only reloaded here
# becas_pred = list(map(run_becas_on_abstract, tqdm.tqdm(texts)))
# ! mkdir -p bench/becas
# joblib.dump(becas_pred, "bench/becas/annotation.joblib", 1)

becas_pred = joblib.load("bench/becas/annotation.joblib")
# Keep the first two fields of each record (the offsets) and label them "ANY"
tointervals = F(map, lambda x: intervals.Interval(*x[:2], "ANY")) >> list
anno_becas = [corpus.AbstractAnnotation(pmid, tointervals(title), tointervals(body))
              for pmid, title, body in becas_pred]

abstracts_becas = list(chemdner.align_abstracts(texts, anno_becas, sborders))

# Only the annotation outputs are needed; {"ANY": 1} yields binary labels.
# NOTE: process_abstracts draws from numpy's global RNG, so cell order matters
_, _, _, _, w_anno_becas, b_anno_becas = process_abstracts(tokeniser, abstracts_becas, {"ANY": 1})

wanno_becas, _ = util.join(w_anno_becas, nsteps, trim=True)
banno_becas, _ = util.join(b_anno_becas, nsteps, trim=True)

# Drop predictions at positions the reference labels as ID entities
wanno_becas_masked = np.where(entity_filter, wanno_becas, 0)
banno_becas_masked = np.where(entity_filter, banno_becas, 0)

In [10]:
# Estimate precision, recall and F1 (one line each, in that order);
# every line holds the word-level score followed by the entity-border score
for score in (precision_score, recall_score, f1_score):
    print(score(wanno_ref[anno_mask], wanno_becas_masked[anno_mask]),
          score(banno_ref[anno_mask], banno_becas_masked[anno_mask]))

0.586151907199 0.478987628671
0.429956794898 0.61394448031
0.496049280599 0.538133720005


#### Transform tmChem annotations

First, we need to get the annotations. Out of all options we've tried, raw-text CURL requests have proved to be the most fruitful. Since the step takes quite a long while, we've dumped the results. 

In [11]:
# from itertools import chain
# import subprocess
# import time
# import glob

# from fn import F
# from multiprocess import Pool

# from sciner.corpora.chemdner import parse_abstracts


# REQ = "curl -d {text} https://www.ncbi.nlm.nih.gov/CBBresearch/Lu/Demo/RESTful/tmTool.cgi/tmChem/Submit/"
# RES = "curl https://www.ncbi.nlm.nih.gov/CBBresearch/Lu/Demo/RESTful/tmTool.cgi/{sessid}/Receive/"


# def format_abstract(abstract):
#     id_ = abstract.id
#     title = abstract.title
#     body = abstract.body
#     return ((id_, '\'{{"sourcedb":"PubMed","sourceid":"{}T","text":"{}"}}\''.format(id_, title)),
#             (id_, '\'{{"sourcedb":"PubMed","sourceid":"{}A","text":"{}"}}\''.format(id_, body)))


# def process_text(text):
#     id_, text = text
#     try:
#         # send request
#         sessid = subprocess.check_output(REQ.format(text=text), shell=True).decode()
#         # get output
#         ready = False
#         attempts = 0
#         while not ready:
#             if attempts >= 5:
#                 break 
#             attempts += 1
#             time.sleep(10)
#             output = subprocess.check_output(RES.format(sessid=sessid), shell=True)
#             ready = b'[Warning]' not in output
#         return id_, text, output
#     except subprocess.CalledProcessError:
#         return id_, text, None

    
# def iscomplete(out): 
#     """
#     The request is completed
#     """
#     return out[-1] is not None and b'Warning' not in out[-1]


# abstracts = parse_abstracts("chemdner_corpus/testing.abstracts.txt")
# texts = (F(map, format_abstract) >> chain.from_iterable >> list)(abstracts)
# with Pool(5) as workers:
#     curlout = workers.map(process_text, texts)
# joblib.dump(curlout, "bench/tmchem/curlout.joblib", 1)

In [12]:
def extract_annotations(j):
    """
    Collect the annotated entities of a PubTator-generated JSON document

    :param j: parsed JSON with the keys "sourceid", "text" and "denotations";
              each denotation holds a "span" with "begin" and "end" offsets
    :return: a list of `(begin, end, text[begin:end])` tuples in document order
    :raises KeyError: if any of the expected keys is missing
    """
    _ = j["sourceid"]  # unused, but a document without an id must still fail
    text = j["text"]
    spans = []
    for denotation in j["denotations"]:
        span = denotation["span"]
        spans.append((span["begin"], span["end"]))
    return [(begin, end, text[begin:end]) for begin, end in spans]


def parse_curlout(out):
    """
    Parse the output of `process_text`

    :param out: a `(pmid, request, response)` triple: the request is a
                single-quoted JSON string, the response holds raw bytes
                (or None if the request failed)
    :return: `(pmid, istitle, annotations)`, or None if the response is
             missing, is not valid JSON, lacks an expected key or doesn't
             agree with the request text
    """
    def istitle(id_):
        # title requests are submitted with a "T" suffix in their source id
        return id_.endswith("T")

    pmid, request, response = out
    try:
        reqj = json.loads(request.strip("'"))
        resj = json.loads(response.decode())
        reftext = reqj["text"]
        if reftext != resj["text"]:
            raise ValueError("the response text differs from the request text")
        annotations = extract_annotations(resj)
        # validate entity strings; raised explicitly (not asserted) so that
        # the check survives `python -O`
        if any(reftext[start:stop] != entity
               for start, stop, entity in annotations):
            raise ValueError("entity strings disagree with the request text")
        return pmid, istitle(reqj["sourceid"]), annotations
    except (json.JSONDecodeError, AttributeError, ValueError, KeyError):
        return None

    
def toannotation(group) -> corpus.AbstractAnnotation:
    """
    Merge the parsed title and body records of one abstract into an annotation

    :param group: `(pmid, istitle, annotations)` records as returned by
                  `parse_curlout`, all belonging to the same abstract
    :return: an annotation whose title and body intervals carry the class "ANY"
    :raises ValueError: if the group lacks a title record or a body record
    """
    group = list(group)
    # A bare `next()` would raise StopIteration on a missing record; inside
    # `map`/`list` that reads as "iterator exhausted" and silently drops every
    # remaining abstract, hence the explicit default and error
    title = next(((pmid, anno) for pmid, istitle, anno in group if istitle), None)
    body = next(((pmid, anno) for pmid, istitle, anno in group if not istitle), None)
    if title is None or body is None:
        raise ValueError("an abstract needs both a title and a body record")
    tid, tanno = title
    bid, banno = body
    assert tid == bid
    return corpus.AbstractAnnotation(
        tid,
        [intervals.Interval(start, stop, data="ANY") for start, stop, _ in tanno],
        [intervals.Interval(start, stop, data="ANY") for start, stop, _ in banno]
    )

In [13]:
curlout = joblib.load("bench/tmchem/curlout.joblib")
# Failed or inconsistent responses are parsed into None
curlout_parsed = list(map(parse_curlout, curlout))
# Group titles and bodies back together and remove incomplete abstracts
getpmid = op.itemgetter(0)
tmchem_complete = (
    F(filter, bool) >>
    # groupby only merges adjacent records, hence the sort by the same key
    F(sorted, key=getpmid) >>
    (lambda x: groupby(x, getpmid)) >>
    (map, lambda x: list(x[1])) >>
    # A complete abstract has exactly one title and one body record. Checking
    # the length alone would let two titles (or two bodies) through, and the
    # StopIteration this triggers in `toannotation` would silently cut the
    # resulting list short
    (filter, lambda x: len(x) == 2 and x[0][1] != x[1][1]) >>
    (map, toannotation) >>
    list
)(curlout_parsed)

# joblib.dump(tmchem_complete, "bench/tmchem/complete.joblib")

In [14]:
# (number of successfully parsed responses, number of complete abstracts)
len(list(filter(bool, curlout_parsed))), len(tmchem_complete)

(4710, 1855)

Since here we only have 1855 complete abstracts out of 3000, we'll have to recalculate the references

In [16]:
# Restrict the benchmark to the abstracts tmChem has fully annotated
tmchem_ids = {anno.id for anno in tmchem_complete}

texts_subset = [text for text in texts if text.id in tmchem_ids]
# NOTE(review): refanno and sborders still cover the whole test set, so this
# presumably relies on align_abstracts matching records by id — confirm
abstracts_ref_subset = list(chemdner.align_abstracts(texts_subset, refanno, sborders))

# Recompute the references for the subset (only the annotation outputs are kept).
# NOTE: process_abstracts draws from numpy's global RNG, so cell order matters
*_, w_anno_subset, b_anno_subset = process_abstracts(tokeniser, abstracts_ref_subset, mapping)

wanno_ref_subset, anno_mask_subset = util.join(w_anno_subset, nsteps, trim=True)
banno_ref_subset, _ = util.join(b_anno_subset, nsteps, trim=True)
# True wherever the reference is not an ID entity
entity_filter_subset = wanno_ref_subset != ID

# Mask ID entities and collapse the remaining labels into a binary signal
wanno_ref_subset = np.clip(np.where(entity_filter_subset, wanno_ref_subset, 0), 0, 1)
banno_ref_subset = np.clip(np.where(entity_filter_subset, banno_ref_subset, 0), 0, 1)

abstracts_tmchem = list(chemdner.align_abstracts(texts_subset, tmchem_complete, sborders))

# {"ANY": 1} yields binary labels
_, _, _, _, w_anno_tmchem, b_anno_tmchem = process_abstracts(tokeniser, abstracts_tmchem, {"ANY": 1})

wanno_tmchem, _ = util.join(w_anno_tmchem, nsteps, trim=True)
banno_tmchem, _ = util.join(b_anno_tmchem, nsteps, trim=True)

# Drop predictions at positions the reference labels as ID entities
wanno_tmchem_masked = np.where(entity_filter_subset, wanno_tmchem, 0)
banno_tmchem_masked = np.where(entity_filter_subset, banno_tmchem, 0)

In [17]:
# Estimate precision, recall and F1 (one line each, in that order);
# every line holds the word-level score followed by the entity-border score
for score in (precision_score, recall_score, f1_score):
    print(score(wanno_ref_subset[anno_mask_subset], wanno_tmchem_masked[anno_mask_subset]),
          score(banno_ref_subset[anno_mask_subset], banno_tmchem_masked[anno_mask_subset]))

0.803731546056 0.692954867117
0.591442282058 0.596642068919
0.681435976195 0.64120193532
