In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import json
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel

In [2]:
# Define NER BIO labels and mapping
label_to_id = {
    "O": 0,
    "B-Chemical": 1,
    "I-Chemical": 2,
    "B-Disease": 3,
    "I-Disease": 4,
}

# Define relation types and mapping
relation_label_to_id = {
    "CID": 0,
    "no_relation": 1,
}

relation_label_to_id = {
    "chem_disease:marker/mechanism": 0,
    "chem_disease:therapeutic": 1,
    "chem_gene:affects^activity": 2,
    "chem_gene:affects^binding": 3,
    "chem_gene:affects^expression": 4,
    "chem_gene:affects^localization": 5,
    "chem_gene:affects^metabolic_processing": 6,
    "chem_gene:affects^transport": 7,
    "chem_gene:decreases^activity": 8,
    "chem_gene:decreases^expression": 9,
    "chem_gene:decreases^metabolic_processing": 10,
    "chem_gene:decreases^transport": 11,
    "chem_gene:increases^activity": 12,
    "chem_gene:increases^expression": 13,
    "chem_gene:increases^metabolic_processing": 14,
    "chem_gene:increases^transport": 15,
    "gene_disease:marker/mechanism": 16,
    "gene_disease:therapeutic": 17,
}
num_rel_labels = len(relation_label_to_id)

## Load Your NER Model

In [3]:
class MultiTaskModelSciBERT(torch.nn.Module):
    def __init__(self, encoder_name, num_ner_labels, num_rel_labels):
        super(MultiTaskModelSciBERT, self).__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        self.ner_head = torch.nn.Linear(self.encoder.config.hidden_size, num_ner_labels)
        self.re_head = torch.nn.Linear(self.encoder.config.hidden_size, num_rel_labels)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state  # For NER
        pooled_output = outputs.pooler_output  # For Relation Extraction
        ner_logits = self.ner_head(sequence_output)
        re_logits = self.re_head(pooled_output)
        return ner_logits, re_logits

In [4]:
# Load the tokenizer and model for NER
tokenizer_ner = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_cased")
ner_model = MultiTaskModelSciBERT(
    "allenai/scibert_scivocab_cased",
    num_ner_labels=len(label_to_id),  # Number of NER labels
    num_rel_labels=len(relation_label_to_id)  # Number of RE labels
)
ner_model.load_state_dict(torch.load("/data/user/pperla/ondemand/NLP Project/multi_task_model_sci.pth"))
ner_model.eval()

# Define label mappings
id_to_label = {v: k for k, v in label_to_id.items()}

  return self.fget.__get__(instance, owner)()


## Load  RE Model

In [5]:
re_tokenizer = AutoTokenizer.from_pretrained("/data/user/pperla/ondemand/NLP Project/Aziz/re_tokenizer")
re_model = AutoModelForSequenceClassification.from_pretrained("/data/user/pperla/ondemand/NLP Project/Aziz/re_model")
re_model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

## NER Prediction Function

In [6]:
def predict_entities(input_text):
    # Tokenize input text
    inputs = tokenizer_ner(
        input_text.strip().split(),
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=128,
        is_split_into_words=True,
    )

    # Predict with NER model
    with torch.no_grad():
        ner_logits, _ = ner_model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

    # Decode predictions
    ner_preds = torch.argmax(ner_logits, dim=-1)
    tokens = tokenizer_ner.convert_ids_to_tokens(inputs["input_ids"][0])
    decoded_tags = [id_to_label[label_id] for label_id in ner_preds[0].tolist()]

    # Map tokens to tags
    result = list(zip(tokens, decoded_tags))
    result = [(token, tag) for token, tag in result if token not in ["[CLS]", "[SEP]", "[PAD]"]]

    # Extract entities
    extracted_entities = {"Chemical": [], "Disease": []}
    for token, tag in result:
        if tag.startswith("B-") or tag.startswith("I-"):
            entity_type = tag.split("-")[1]
            extracted_entities[entity_type].append(token)

    return extracted_entities

## RE Prediction Function

In [7]:
def predict_relations(text, entities, re_model, re_tokenizer):
    predictions = []
    # Generate all pairs of entities
    chemicals = entities["Chemical"]
    diseases = entities["Disease"]

    for chemical in chemicals:
        for disease in diseases:
            # Format text for RE model
            modified_text = text.replace(chemical, f"[CHEM] {chemical} [/CHEM]")
            modified_text = modified_text.replace(disease, f"[DISEASE] {disease} [/DISEASE]")

            # Tokenize and predict
            inputs = re_tokenizer(modified_text, return_tensors="pt", truncation=True, padding=True, max_length=128)
            with torch.no_grad():
                outputs = re_model(**inputs)
                logits = outputs.logits
                prediction = logits.argmax(dim=-1).item()

            # Append prediction
            relation = "CID" if prediction == 1 else "no relation"
            predictions.append({"chemical": chemical, "disease": disease, "relation": relation})

    return predictions

## Full Integration Pipeline

In [8]:
def integrated_pipeline(input_text):
    # Step 1: NER - Extract entities
    entities = predict_entities(input_text)
    print("Extracted Entities:", entities)

    # Step 2: RE - Predict relations
    relations = predict_relations(input_text, entities, re_model, re_tokenizer)
    print("Predicted Relations:", relations)

    # Save results to a JSON file
    output_data = {
        "input_text": input_text,
        "extracted_entities": entities,
        "predicted_relations": relations,
    }
    output_file = "/data/user/pperla/ondemand/NLP Project/integrated_results.json"
    with open(output_file, "w") as f:
        json.dump(output_data, f, indent=4)

    print(f"Results saved to {output_file}")

In [9]:
# **Step 6: Run the Pipeline**
input_text = """
Acetaminophen is used to treat mild to moderate pain and fever. It is associated with liver damage if taken in high doses.
Patients with hepatitis should avoid its usage.
"""
#Patient Medical Records:
input_text_2 = """The patient, a 45-year-old male, presented with complaints of persistent chest pain radiating to the left arm, exacerbated by physical activity, and relieved by rest. 
Past medical history includes hypertension and type 2 diabetes. No history of smoking or alcohol consumption was noted."""

#Pathology Reports:
input_text_3 = """Biopsy of the left breast mass reveals infiltrating ductal carcinoma, Grade II, with positive hormone receptor status. 
Margins are free of malignancy."""

#Pharmacovigilance Records:
input_text_4 = """A 30-year-old female reported severe dizziness and rash following the administration of drug X. 
Symptoms appeared within 1 hour of ingestion and resolved after discontinuation of the drug."""

integrated_pipeline(input_text) #input_text, input_text_2, input_text_3, input_text_4

Extracted Entities: {'Chemical': ['acet', '##amino', '##phen'], 'Disease': ['pain', 'fever', 'liver', 'damage', 'hepatitis']}
Predicted Relations: [{'chemical': 'acet', 'disease': 'pain', 'relation': 'no relation'}, {'chemical': 'acet', 'disease': 'fever', 'relation': 'no relation'}, {'chemical': 'acet', 'disease': 'liver', 'relation': 'no relation'}, {'chemical': 'acet', 'disease': 'damage', 'relation': 'no relation'}, {'chemical': 'acet', 'disease': 'hepatitis', 'relation': 'no relation'}, {'chemical': '##amino', 'disease': 'pain', 'relation': 'no relation'}, {'chemical': '##amino', 'disease': 'fever', 'relation': 'no relation'}, {'chemical': '##amino', 'disease': 'liver', 'relation': 'no relation'}, {'chemical': '##amino', 'disease': 'damage', 'relation': 'no relation'}, {'chemical': '##amino', 'disease': 'hepatitis', 'relation': 'no relation'}, {'chemical': '##phen', 'disease': 'pain', 'relation': 'no relation'}, {'chemical': '##phen', 'disease': 'fever', 'relation': 'no relation'}