## Load the model trained in the companion notebook and use it for inference

In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import sys
sys.path.append("./loaders")

from datetime import datetime
from tqdm import tqdm
from IPython.display import display, HTML
import json
import pickle

from ModelTuner import ModelTuner
import spacy
nlp = spacy.load("en_core_web_lg")

model_name = "google/electra-base-discriminator"

# BIO tag set for the token classifier. label2id is derived from id2label so the two
# mappings can never drift apart; the comprehension keeps insertion order (ids 0..4),
# which matters wherever list(label2id.keys()) is used as the ordered label list.
id2label = {
    0: "O",
    1: "B-chemical",
    2: "I-chemical",
    3: "B-role",
    4: "I-role",
}
label2id = {label: label_id for label_id, label in id2label.items()}

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
# Rebuild the token-classification wrapper with the label mappings defined above
# (presumably these must match the training notebook's label set — confirm), then
# restore the fine-tuned weights from the local checkpoint and print the architecture.
tuner = ModelTuner(
    model_name,
    list(label2id),
    id2label=id2label,
    label2id=label2id,
)
tuner.load_model(f"/local/sps-local/cear-inferer/chemical_extract_{model_name.replace('/', '-')}")
tuner.model.summary()

All model checkpoint layers were used when initializing TFElectraForTokenClassification.

All the layers of TFElectraForTokenClassification were initialized from the model checkpoint at /local/sps-local/cear-inferer/chemical_extract_google-electra-base-discriminator.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFElectraForTokenClassification for predictions without further training.


Model: "tf_electra_for_token_classification"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 electra (TFElectraMainLaye  multiple                  108891648 
 r)                                                              
                                                                 
 dropout_37 (Dropout)        multiple                  0 (unused)
                                                                 
 classifier (Dense)          multiple                  3845      
                                                                 
Total params: 108895493 (415.40 MB)
Trainable params: 108895493 (415.40 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [3]:
# Smoke test: run the tagger on a short passage copied from a PDF (note the hyphenated
# line breaks and extraction artefacts are left in) and render the tagger's HTML output.
text = """Biomolecules in microbes related to CO2 -sensitive pathways or acting as a CO2 trans-
ducer have been proposed as appealing targets for medicines, since they control cell devel-
opment and the subsequent synthesis of chemicals, enhancing the pathogen persistence
in the host [26,27]. In this context, a crucial role is played by a superfamily of molecules
known as carbonic anhydrases (CAs, EC 4.2.1.1). CAs can be thought as molecules that,
rather than instantly detecting a change in CO2 , serve as CO2 transducers, adjusting its
levels [23,28]. With their activity, the CAs encoded by the bacterial genome of pathogenic
and non-pathogenic bacteria provide the indispensable CO2 and HCO3 − /protons to micro-
bial biosynthetic pathways, catalyzing the reversible reaction of CO2 hydration to HCO3 −
and H+(CO2+H2OHCO3−+H+)"""

highlighted_html = tuner.infer_html(text)
HTML(highlighted_html)

In [4]:
CACHED_ARTICLES_DIR = "/local/sps-local/docs"

def get_json_from_file(json_file):
    """Read a JSON document from disk and return the parsed object.

    The file is decoded as UTF-8 explicitly: JSON text is UTF-8 (RFC 8259), whereas
    open()'s default encoding depends on the platform locale and can mis-decode
    non-ASCII characters (e.g. the U+2212 minus signs common in extracted chemistry text).
    """
    with open(json_file, "r", encoding="utf-8") as f:
        return json.load(f)
    
def recursively_collect_files(root=None):
    """Return the sorted paths of all cached article JSON files under `root`.

    Files ending in "-cear.json" are excluded; every other "*.json" file found by a
    recursive walk is included.

    Args:
        root: directory to scan; defaults to CACHED_ARTICLES_DIR.

    Returns:
        A lexicographically sorted list of file paths.
    """
    if root is None:
        root = CACHED_ARTICLES_DIR
    filepaths = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for filename in filenames:
            if filename.endswith(".json") and not filename.endswith("-cear.json"):
                filepaths.append(os.path.join(dirpath, filename))
    # os.walk yields entries in arbitrary, filesystem-dependent order. Sorting makes
    # positional slicing (offset/limit batches) select the same files on every run.
    filepaths.sort()
    return filepaths

def collect_relevant_sentences(filepath):
    """
    Collect sentences which have at least one chemical and one role.

    Each page's text is split into sentences with the module-level spaCy pipeline
    `nlp`, and every sentence is tagged with `tuner.infer`.

    Args:
        filepath: path to a cached article JSON whose "pages" entries carry
            "text" and "pageNumber".

    Returns:
        A list of ``(filepath, page_number, start_char, specials, sentence_text)``
        tuples, where ``start_char`` is the sentence's character offset within the
        page text and ``specials`` is the raw output of ``tuner.infer``.
    """
    relevant = []
    json_data = get_json_from_file(filepath)

    for page in json_data["pages"]:
        # Invariant for every sentence on this page, so convert it once.
        page_number = int(page["pageNumber"])
        doc = nlp(page["text"])
        for sentence in doc.sents:
            specials = tuner.infer(sentence.text)
            contains_chem = False
            contains_role = False
            for tag in specials:
                # NOTE(review): assumes tag[0] is a label id from id2label — confirm
                # against ModelTuner.infer.
                label_id = tag[0]
                if label_id == 1:  # B-chemical
                    contains_chem = True
                elif label_id > 1 and label_id % 2 == 1:  # odd ids above 1: B-role here
                    contains_role = True
                if contains_chem and contains_role:
                    break  # both found; the remaining tags cannot change the outcome
            if contains_chem and contains_role:
                relevant.append((filepath, page_number, sentence.start_char, specials, sentence.text))
    return relevant
     

def load_relevant_sentences(path="/local/sps-local/ner-role-extraction/relevant_sentences.pkl"):
    """Return previously checkpointed sentences from `path`, or [] if none exist yet.

    Errors from a corrupt or truncated pickle are deliberately propagated rather than
    treated as "no checkpoint": returning [] would let the next save overwrite the
    existing results.

    Only load checkpoints produced by this notebook — unpickling untrusted data can
    execute arbitrary code.
    """
    if not os.path.isfile(path):
        return []
    with open(path, "rb") as f:
        return pickle.load(f)

def pickle_relevant_sentences(data=None, path="/local/sps-local/ner-role-extraction/relevant_sentences.pkl"):
    """Checkpoint collected sentences to `path` without risking a truncated file.

    Args:
        data: object to pickle; defaults to the notebook-global `sentences`.
        path: destination pickle file.

    The payload is written to a sibling temp file, flushed to disk, and then moved
    over `path` with os.replace. If the process dies mid-write, the previous
    checkpoint stays intact instead of being left half-written.
    """
    if data is None:
        data = sentences
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(data, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp_path, path)

## Batch inference over the cached articles (offset and limit)

Processing roughly 1000 files takes about 10 hours on an **NVIDIA RTX A5000** with **24 GB** of memory.

In [None]:
filepaths = recursively_collect_files()

offset = 0
limit = 1000
checkpoint_every = 10  # pickle the accumulated sentences after this many files

sentences = load_relevant_sentences()
# Every stored tuple starts with the filepath it came from, and a file's sentences are
# appended in one extend() call, so any file present here was fully processed by an
# earlier run. Skipping those makes re-running this cell (e.g. after an interruption)
# idempotent instead of appending duplicates. Files that yielded no sentences are not
# recorded and are simply re-scanned.
already_processed = {entry[0] for entry in sentences}

with open("/local/sps-local/ner-role-extraction/ner-role-inferer.log", "w") as log:

    def log_event(message):
        """Append a timestamped line to the run log and flush it immediately."""
        log.write(f"{datetime.now()}: {message}\n")
        log.flush()

    log_event(f"starting at offset {offset} and stopping at {offset+limit}")
    try:
        # `position` is the 1-based index of the file within the full `filepaths` list.
        for position, filepath in enumerate(tqdm(filepaths[offset:offset+limit]), start=offset + 1):
            if filepath not in already_processed:
                sentences.extend(collect_relevant_sentences(filepath))
            # Checkpoint *after* extending so that "done N" really covers the first N files.
            if position % checkpoint_every == 0:
                log_event(f"attempting saving {position}")
                pickle_relevant_sentences()
                log_event(f"done {position}")
    finally:
        # Persist every completed file even if the run is interrupted or a document
        # raises; extend() only runs after a file finishes, so no partial file is saved.
        pickle_relevant_sentences()

  0%|                                                                                                                                                                                                                                 | 0/693 [00:00<?, ?it/s]