In [1]:
import os
import sys
from tqdm import tqdm
import pickle
import spacy
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

sys.path.append("./loaders")
from ModelTuner import ModelTuner
from ChebiLoader import ChebiLoader
from CraftLoader import CraftLoader
from BC5CDRLoader import BC5CDRLoader
from NLMChemLoader import NLMChemLoader

# Load spaCy's "en_core_web_lg" pipeline into the notebook namespace.
# NOTE(review): `nlp` is not referenced by any of the visible cells below;
# confirm it is still needed before keeping this load.
nlp = spacy.load("en_core_web_lg")

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
from IPython.display import display, HTML

# Tag inventory handed to ModelTuner.
# NOTE(review): label_list spells the chemical tags "B-Chemical"/"I-Chemical"
# while id2label/label2id use lowercase "B-chemical"/"I-chemical". Confirm which
# spelling ModelTuner and the saved checkpoint expect and align the two.
label_list = ["O", "B-Chemical", "I-Chemical", "B-role", "I-role"]

# id -> tag name, numbered in list order starting at 0.
id2label = dict(enumerate(["O", "B-chemical", "I-chemical", "B-role", "I-role"]))

# tag name -> id, the exact inverse of id2label.
label2id = {name: idx for idx, name in id2label.items()}

# Build the tuner for the ELECTRA base discriminator and load the fine-tuned
# weights from the local model directory.
tuner = ModelTuner("google/electra-base-discriminator", label_list, id2label, label2id)
tuner.load_model("./model/chemical_extract_google-electra-base-discriminator")

All model checkpoint layers were used when initializing TFElectraForTokenClassification.

All the layers of TFElectraForTokenClassification were initialized from the model checkpoint at ./model/chemical_extract_google-electra-base-discriminator.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFElectraForTokenClassification for predictions without further training.


In [3]:
# Input locations for the evaluation run, gathered in one place so they can be
# changed without hunting through the loader calls.
# The ChEBI ontology lives at a machine-local absolute path; set the
# CHEBI_OWL_PATH environment variable to point at a different copy.
CHEBI_OWL_PATH = os.environ.get("CHEBI_OWL_PATH", "/local/sps-local/chebi/chebi.owl")
CRAFT_TEST_DIR = "./assets/test/CRAFT"
NLM_CHEM_TEST_DIR = "./assets/test/NLM_Chem_corpus/"
BC5CDR_TEST_DIR = "./assets/test/BC5CDR/"

chebi = ChebiLoader(CHEBI_OWL_PATH)
# Every corpus loader is given the same ChebiLoader instance.
craft = CraftLoader(CRAFT_TEST_DIR, chebi)
nlm = NLMChemLoader(NLM_CHEM_TEST_DIR, chebi)
cdr = BC5CDRLoader(BC5CDR_TEST_DIR, chebi)

loading chebi from: /local/sps-local/chebi/chebi.owl

loading chemicals and their synonyms
loading roles and their synonyms

found 409625 chemicals and 14176 roles.
Memory usage of ChebiLoader: 5207.80859375 MB
loading CRAFT from: ./assets/test/CRAFT
collecting chemical entities...
collecting text...
cutting text into spans and labeling them.


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:12<00:00,  1.51it/s]


loading NLMChem dataset from: ./assets/test/NLM_Chem_corpus/
loading Chebi to add roles which can be lexically found in the text snippets...
using 4 as a minimum character length for roles to mark them in the text
adding special tokenizer rules for chemical roles in Chebi


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14176/14176 [00:03<00:00, 3565.95it/s]


collecting entities
NLMChem has 30 files to parse


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [01:31<00:00,  3.04s/it]


loaded 6915 spans and the according labels.
loading BC5CDR dataset from: ./assets/test/BC5CDR/
loading Chebi to add roles which can be lexically found in the text snippets...
using 4 as a minimum character length for roles to mark them in the text
adding special tokenizer rules for chemical roles in Chebi


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14176/14176 [00:03<00:00, 3584.24it/s]


collecting entities
BC5CDR has 1 files to parse


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [01:00<00:00, 60.02s/it]

loaded 5214 spans and the according labels.





In [4]:
# Opening/closing markup per entity label id (the first element of a
# `(label, start, end)` tuple): 1 is rendered bold, 3 bold-italic blue.
html_elems = {
    1: ("<b style='font-size:1.5em;'>", "</b>"),
    3: ("<b style='color:blue; font-size:1.5em;'><i>", "</i></b>"),
}


def render_as_html(text, specials):
    """Display `text` in the notebook with the given entity spans highlighted.

    Args:
        text: the plain text to render.
        specials: iterable of `(label, start, end)` tuples; `start`/`end` are
            character offsets into `text` (end exclusive) and `label` must be a
            key of `html_elems`.

    Raises:
        KeyError: if a span carries a label that has no entry in `html_elems`.
    """
    from html import escape

    opening = {}
    closing = {}
    for special in specials:
        open_tag, close_tag = html_elems[special[0]]
        opening[special[1]] = open_tag
        closing[special[2]] = close_tag

    parts = []
    for pos, char in enumerate(text):
        # Close a span ending here *before* opening one that starts here, so
        # adjacent spans yield "...</b><b>..." rather than "<b></b>".
        close_tag = closing.get(pos)
        if close_tag:
            parts.append(close_tag)
        open_tag = opening.get(pos)
        if open_tag:
            parts.append(open_tag)
        # Escape the character so '<', '>' and '&' in the text cannot be
        # mistaken for markup.
        parts.append(escape(char, quote=False))

    # A span that runs to the very end of the text has its end offset equal to
    # len(text), which the loop above never visits.
    tail = closing.get(len(text))
    if tail:
        parts.append(tail)

    display(HTML("".join(parts)))

In [5]:
class Stats:
    """Accumulates true-positive / false-positive / false-negative counts.

    Besides the three counters, the surface text of every false positive and
    false negative is tallied in `fp_dict` / `fn_dict` (text -> occurrences).
    """

    def __init__(self):
        self.tp = 0
        self.fp = 0
        self.fn = 0
        self.fp_dict = dict()
        self.fn_dict = dict()

    @staticmethod
    def _ratio(numerator, denominator):
        """Return numerator / denominator, or 0.0 when the denominator is zero."""
        return numerator / denominator if denominator else 0.0

    def inc_tp(self, c):
        """Add `c` to the true-positive counter."""
        self.tp += c

    def inc_fp(self, c):
        """Add `c` to the false-positive counter."""
        self.fp += c

    def inc_fn(self, c):
        """Add `c` to the false-negative counter."""
        self.fn += c

    def precision(self):
        """tp / (tp + fp); 0.0 when nothing was predicted."""
        return self._ratio(self.tp, self.tp + self.fp)

    def recall(self):
        """tp / (tp + fn); 0.0 when there was nothing to find."""
        return self._ratio(self.tp, self.tp + self.fn)

    def f_measure(self):
        """Harmonic mean of precision and recall; 0.0 when both are zero."""
        p = self.precision()
        r = self.recall()
        return self._ratio(2 * (p * r), p + r)

    def add_fp(self, fp_entity):
        """Record one more occurrence of `fp_entity` as a false positive."""
        self.fp_dict[fp_entity] = self.fp_dict.get(fp_entity, 0) + 1

    def add_fn(self, fn_entity):
        """Record one more occurrence of `fn_entity` as a false negative."""
        self.fn_dict[fn_entity] = self.fn_dict.get(fn_entity, 0) + 1

    def clear(self):
        """Reset all counters and forget the recorded entities."""
        self.tp = 0
        self.fp = 0
        self.fn = 0
        self.fp_dict.clear()
        self.fn_dict.clear()

    def __str__(self):
        # Safe on an empty accumulator because the metrics fall back to 0.0
        # instead of dividing by zero.
        return f"tp: {self.tp}, fp: {self.fp}, fn: {self.fn} | precision: {self.precision()}, recall: {self.recall()}, f-measure: {self.f_measure()}"

def count_stats(xset, l):
    """Return how many items of `xset` have `l` as their first element."""
    return sum(1 for item in xset if item[0] == l)
    

def eval_stats(spans, labels, chemstats, rolestats, print_diff=False, infer=None):
    """Score the model's predictions for one example and update both accumulators.

    The example text is the concatenation of `spans`. Every span with a label
    greater than 0 becomes a gold tuple `(label, start, end)` of character
    offsets; these are compared by exact match against the tuples returned by
    `infer(text)`.

    Args:
        spans: sequence of text pieces that together form the example.
        labels: one integer label per span; 0 marks a non-entity span.
        chemstats: accumulator receiving label-1 counts and the text of every
            mismatched tuple whose label is below 3.
        rolestats: accumulator receiving label-3 counts and the text of every
            other mismatched tuple.
        print_diff: when true, render gold and predicted highlighting for
            examples with more than one false positive or more than one false
            negative.
        infer: callable mapping the text to an iterable of hashable
            `(label, start, end)` tuples. Defaults to the notebook-level
            `tuner.infer`, which keeps existing callers unchanged.

    Examples with at most one label are skipped without calling `infer`.
    """
    if len(labels) <= 1:
        return

    if infer is None:
        # Resolved at call time so the default follows the current `tuner`.
        infer = tuner.infer

    text = ''.join(spans)
    pred_specials = set(infer(text))

    # Walk the spans once, tracking the running character offset.
    true_specials = set()
    offset = 0
    for span, label in zip(spans, labels):
        span_end = offset + len(span)
        if label > 0:
            true_specials.add((label, offset, span_end))
        offset = span_end

    tp = pred_specials & true_specials
    fp = pred_specials - true_specials
    fn = true_specials - pred_specials

    if print_diff and (len(fp) > 1 or len(fn) > 1):
        render_as_html(text, true_specials)
        print("")
        render_as_html(text, pred_specials)
        print("\n--------------------------------------------------------------------\n")

    # Tally the surface text of each miss under chemicals (label < 3) or roles.
    for x in fp:
        target = chemstats if x[0] < 3 else rolestats
        target.add_fp(text[x[1]:x[2]])
    for x in fn:
        target = chemstats if x[0] < 3 else rolestats
        target.add_fn(text[x[1]:x[2]])

    # Only tuples labelled exactly 1 (chemical) or 3 (role) enter the counters.
    chemstats.inc_tp(count_stats(tp, 1))
    chemstats.inc_fp(count_stats(fp, 1))
    chemstats.inc_fn(count_stats(fn, 1))

    rolestats.inc_tp(count_stats(tp, 3))
    rolestats.inc_fp(count_stats(fp, 3))
    rolestats.inc_fn(count_stats(fn, 3))
    

def print_stats(dataset_spans, dataset_labels, top_n=10):
    """Evaluate every example of a dataset and print a summary.

    Args:
        dataset_spans: per-example span sequences.
        dataset_labels: per-example label sequences, aligned with `dataset_spans`.
        top_n: how many of the most frequent false-positive / false-negative
            texts to list per category (default 10).

    Prints the number of examples, the chemical and role statistics, and the
    `top_n` most frequent false positives / negatives for each category.
    """
    def most_frequent(counts):
        # Highest count first; ties keep their insertion order.
        ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)
        return ranked[:top_n]

    # Fresh accumulators start at zero, so no explicit reset is needed.
    chemstats = Stats()
    rolestats = Stats()

    for spans, labels in tqdm(zip(dataset_spans, dataset_labels), total=len(dataset_labels)):
        eval_stats(spans, labels, chemstats, rolestats)

    print("label/span count:", len(dataset_spans))
    print("chems stats:", chemstats)
    print("roles stats:", rolestats)
    print("fp-chems:", most_frequent(chemstats.fp_dict))
    print("fn-chems:", most_frequent(chemstats.fn_dict))
    print()
    print("fp-roles:", most_frequent(rolestats.fp_dict))
    print("fn-roles:", most_frequent(rolestats.fn_dict))


In [6]:
# Evaluate each test corpus in turn: a banner, the statistics, then two blank
# lines as a separator before the next corpus.
for banner, corpus in (
    ("=== BC5CDR-RESULTS ===", cdr),
    ("=== NLM-RESULTS ===", nlm),
    ("=== CRAFT-RESULTS ===", craft),
):
    print(banner)
    print_stats(corpus.spans, corpus.labels)
    print()
    print()

=== BC5CDR-RESULTS ===


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5214/5214 [05:34<00:00, 15.58it/s]


label/span count: 5214
chems stats: tp: 4868, fp: 346, fn: 526 | precision: 0.9336401994629843, recall: 0.9024842417500927, f-measure: 0.9177978883861236
roles stats: tp: 709, fp: 66, fn: 62 | precision: 0.9148387096774193, recall: 0.9195849546044098, f-measure: 0.9172056921086674
fp-chems: [('lipid', 16), ('GEM-P', 8), ('ROS', 8), ('calcium', 7), ('TBPS', 6), ('NS-718', 6), ('lovastatin', 6), ('35S', 5), ('dl-sotalol', 5), ('antidepressants', 4)]
fn-chems: [('CCK-8', 11), ('BS', 11), ('contrast', 9), ('AVP', 9), ('GEM', 8), ('APAP', 8), ('LNNA', 8), ('VGB', 8), ('K', 7), ('OCs', 7)]

fp-roles: [('neuroleptic', 4), ('androgen', 4), ('progestagens', 3), ('COX-2 inhibitors', 3), ('antagonist', 2), ('diuretic hormone', 2), ('drugs', 2), ('estrogen', 2), ('inhibitor', 2), ('inhibitors', 2)]
fn-roles: [('antidepressant', 6), ('antagonist', 5), ('antidepressants', 4), ('inhibitors', 3), ('protective agent', 2), ('hormone', 2), ('diuretic', 2), ('drug', 2), ('anticholinesterases', 2), ('free 

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6915/6915 [06:25<00:00, 17.93it/s]


label/span count: 6915
chems stats: tp: 6273, fp: 1088, fn: 1821 | precision: 0.8521939953810623, recall: 0.7750185322461083, f-measure: 0.8117761242316404
roles stats: tp: 806, fp: 67, fn: 52 | precision: 0.9232531500572738, recall: 0.9393939393939394, f-measure: 0.9312536106296939
fp-chems: [('DEX', 66), ('glucose', 40), ('GDP', 37), ('Em', 30), ('3H', 28), ('gemcitabine', 23), ('blood glucose', 23), ('', 22), ('fat', 18), ('2H', 16)]
fn-chems: [('PTX', 148), ('BAK', 126), ('CKC', 112), ('Vam3', 94), ('DEX-IND', 78), ('AEATP', 49), ('FCM', 43), ('KS', 40), ('DCP', 37), ('GDP-glucose', 35)]

fp-roles: [('biocides', 10), ('Syk inhibitors', 8), ('catalysts', 5), ('', 3), ('buffer', 3), ('neoadjuvant', 3), ('EGFR tyrosine kinase inhibitors', 3), ('anti-infectives', 2), ('preservative', 2), ('acid', 2)]
fn-roles: [('inhibitors', 8), ('antidiabetic', 7), ('tyrosine kinase inhibitors', 4), ('Reagents', 3), ('antibiotic', 2), ('antihyperglycemic', 2), ('anti-obesity agent', 2), ('anti-inflam

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5497/5497 [02:36<00:00, 35.09it/s]

label/span count: 5497
chems stats: tp: 932, fp: 436, fn: 1405 | precision: 0.6812865497076024, recall: 0.39880188275566963, f-measure: 0.5031039136302294
roles stats: tp: 205, fp: 53, fn: 64 | precision: 0.7945736434108527, recall: 0.7620817843866171, f-measure: 0.7779886148007591
fp-chems: [('PBS', 68), ('huntingtin', 25), ('fat', 14), ('tet', 12), ('polyglutamine', 10), ('paraffin', 8), ('Alcian blue', 8), ('pachytene', 7), ('dextran-FITC', 7), ('FIAU', 7)]
fn-chems: [('protein', 285), ('DNA', 113), ('Aβ', 112), ('proteins', 108), ('b', 79), ('RNA', 64), ('mRNA', 47), ('solution', 41), ('molecules', 30), ('peptide', 26)]

fp-roles: [('acid', 17), ('agonist', 4), ('inhibitors', 3), ('BMP antagonists', 2), ('BMP antagonist', 2), ('activator', 2), ('inhibitor', 2), ('acids', 2), ('secret', 2), ('eosinophiles', 2)]
fn-roles: [('dye', 10), ('chow', 10), ('acidic', 4), ('pigment', 4), ('pigmented', 4), ('PPARδ agonist', 4), ('epitopes', 3), ('toxin', 2), ('antagonists', 2), ('antagonist',


