In [None]:
# !pip uninstall -y torch torchvision torchaudio
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# !pip install --upgrade transformers accelerate torch

In [None]:
# from huggingface_hub import login
# login()

In [None]:
from warnings import filterwarnings
filterwarnings("ignore")

In [None]:
# prompt: mount g drive

from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
%cd 'drive/MyDrive/Colab Notebooks/assignament/cadec/'

/content/drive/MyDrive/Colab Notebooks/assignament/cadec


In [None]:
!ls

meddra	original  sct  text


In [None]:
# !unzip CADEC.v2.zip

In [None]:
!pip uninstall -y bitsandbytes
!pip install -U --no-cache-dir bitsandbytes
!pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install langgraph

In [None]:
import os
import pandas as pd

# Define folder paths
# Both are relative paths, so they resolve against the current working directory
# (the earlier %cd cell moves into the cadec/ folder that contains them).
text_folder = "text"  # one .txt file per document
annotation_folder = "original"  # matching .ann annotation files, looked up by file name

# Function to parse .ann files and organize entities into separate lists
def parse_ann(file_path):
    """
    Read one .ann annotation file and group its entity strings by tag type.

    Each usable line has exactly three tab-separated fields; the second field
    must split into at least three whitespace-separated tokens, the first of
    which is the tag type. Only the tag types "ADR", "Symptom" and "Drug" are
    kept; every other line is skipped silently.

    Args:
        file_path: Path of the .ann file to read (opened as UTF-8).

    Returns:
        A dict with the keys "ADR", "Symptom" and "Drug", each mapping to the
        list of entity strings found for that tag, in file order.
    """
    grouped = {label: [] for label in ("ADR", "Symptom", "Drug")}

    with open(file_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            fields = raw_line.strip().split("\t")
            if len(fields) == 3:
                _, descriptor, surface = fields
                tokens = descriptor.split()
                # Need the tag type plus at least two more tokens, and the
                # tag type must be one of the tracked categories.
                if len(tokens) >= 3 and tokens[0] in grouped:
                    grouped[tokens[0]].append(surface)

    return grouped

# Collect data
# One row per .txt document: [text, ADR entities, Symptom entities, Drug entities].
data = []

# os.listdir() returns entries in arbitrary order; sorting keeps the row order
# (and therefore the DataFrame index) identical on every run.
for text_file in sorted(os.listdir(text_folder)):
    if not text_file.endswith(".txt"):
        continue

    text_path = os.path.join(text_folder, text_file)
    # Swap only the extension; str.replace would also rewrite a ".txt" that
    # happens to occur in the middle of the file name.
    ann_file = os.path.splitext(text_file)[0] + ".ann"
    annotation_path = os.path.join(annotation_folder, ann_file)

    # Read text data
    with open(text_path, "r", encoding="utf-8") as f:
        text_data = f.read().strip()

    # Read annotation data; a document without an .ann file gets empty lists
    if os.path.exists(annotation_path):
        annotations = parse_ann(annotation_path)
    else:
        annotations = {"ADR": [], "Symptom": [], "Drug": []}

    # Append data as a row
    data.append([
        text_data,
        annotations["ADR"],
        annotations["Symptom"],
        annotations["Drug"]
    ])

# Convert to DataFrame (the "ADR" annotations are stored in the "ADE" column)
df = pd.DataFrame(data, columns=["text", "ADE", "Symptom", "Drug"])

In [None]:
df

In [None]:
import json
import logging
import nltk
import re
import torch
import requests
import pandas as pd

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from langgraph.graph import StateGraph
from typing import TypedDict, Dict, List

# Set up logging and NLTK
# Root logger at INFO with a timestamped format; the pipeline functions below
# report progress and failures through `logging`.
# NOTE(review): nltk is not referenced again in the visible cells — confirm the
# 'punkt' download is still needed.
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
nltk.download('punkt', quiet=True)

# Define 4-bit quantization config
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,         # Enable 4-bit quantization
    bnb_4bit_compute_dtype=torch.float16,  # Use FP16 computation
    bnb_4bit_use_double_quant=True,  # Extra compression for lower VRAM usage
)

# Load tokenizer
# NOTE(review): `model_name` is not assigned in any visible cell, so on a fresh
# kernel this line raises NameError. Define it (a Hugging Face model id or local
# path) before this point — the intended model cannot be determined from here.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load model with quantization
# `tokenizer` and `model` are module-level globals used by genrate_answer().
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",  # Automatically select GPU if available
    quantization_config=quant_config  # Apply quantization
)

# Define the Agent State Schema using dictionaries
class AgentState(TypedDict):
    """
    State dictionary passed between the nodes of the agent graph.

    process_dataframe() builds one instance per DataFrame row; each node
    function receives it, fills in its own field(s) and returns it.
    """
    text: str  # Raw input text taken from the row's "text" column
    preprocessed_text: str  # Written by expand_abbreviations (falls back to `text` on failure)
    extracted_entities: Dict[str, List[str]]  # Written by classify_detect_entities; {} on failure
    # Initialised to {} by process_dataframe; no node in this file writes it.
    standardized_entities: Dict[str, List[str]]
    gold_data: Dict    # Contains the gold standard values from the DF row
    verification_result: Dict   # Will store the verification result details

def preprocess_model_output(output: str, split_marker: str) -> Dict:
    """
    Extract the first JSON object that follows `split_marker` in the model output.

    The text after the first occurrence of the marker is scanned for "{"
    characters, and each one is tried as the start of a JSON document. The
    first candidate that decodes to a JSON object is returned. Unlike a
    non-greedy single-line regex, this handles objects that span several
    lines, nested objects, and stray non-JSON braces before the real payload.

    Args:
        output: Full decoded model output (normally the prompt followed by the
            generated continuation).
        split_marker: Text after which the JSON answer is expected.

    Returns:
        The decoded JSON object as a dict, or {} when `output` is not a string,
        the marker is absent, or no JSON object can be decoded after it.
    """
    if not isinstance(output, str):
        logging.error("Error parsing JSON: expected str output, got %s", type(output).__name__)
        return {}

    _, marker_found, tail = output.partition(split_marker)
    if not marker_found:
        return {}

    decoder = json.JSONDecoder()
    last_error = None
    start = tail.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON document and ignores trailing text.
            candidate, _ = decoder.raw_decode(tail[start:])
        except json.JSONDecodeError as e:
            last_error = e
        else:
            if isinstance(candidate, dict):
                return candidate
        start = tail.find("{", start + 1)

    if last_error is not None:
        logging.error("Error parsing JSON: %s", last_error)
    return {}

def genrate_answer(prompt: str, max_new_tokens: int = 500) -> str:
    """
    Generate a response from the language model.

    Uses the module-level `tokenizer` and `model`.

    Args:
        prompt: Text to feed to the model.
        max_new_tokens: Upper bound on the number of generated tokens
            (defaults to 500, the previously hard-coded value).

    Returns:
        The decoded first output sequence with special tokens removed.
        For decoder-only models this normally contains the prompt followed by
        the continuation; the callers split on a marker taken from the prompt.
    """
    # Send the inputs to the device the model actually lives on. The model is
    # loaded with device_map="auto", so hard-coding "cuda" fails on CPU-only
    # runtimes and can mismatch the model's placement.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def expand_abbreviations(state: AgentState) -> AgentState:
    """
    Graph node: ask the model to spell out medical abbreviations in state["text"].

    Stores the expanded text in state["preprocessed_text"]. When the model's
    answer yields no usable "expanded_text" value, the original text is stored
    instead and an error is logged. Returns the same (mutated) state.
    """
    input_text = state["text"]
    prompt = f"""
You are a medical expert. Identify and expand any medical abbreviations in the text into their full forms.
Replace each abbreviation with its expanded form while preserving the sentence structure.
Return only a JSON object without any extra text.

Example:
Input: "He has COPD and a history of MI."
Output:
{{ "expanded_text": "He has Chronic Obstructive Pulmonary Disease and a history of Myocardial Infarction." }}

Process the following text and return only the JSON response:
"{input_text}"
"""
    # The marker is the last instruction line of the prompt, so everything
    # after it is the quoted input followed by the model's answer.
    parsed = preprocess_model_output(
        genrate_answer(prompt),
        "Process the following text and return only the JSON response:",
    )
    expanded = parsed.get("expanded_text")

    if not expanded:
        state["preprocessed_text"] = state["text"]
        logging.error("Problem in abbreviation expansion; using original text.")
        return state

    state["preprocessed_text"] = expanded
    logging.info("Abbreviation expansion completed.")
    return state

def classify_detect_entities(state: AgentState) -> AgentState:
    """
    Extract and classify entities (Drug, ADE, Symptom/Disease) from the input text.

    Builds an instruction prompt, sends it to the model and parses the JSON
    object that follows the final instruction line. The result is stored in
    state["extracted_entities"]; on failure an empty dict is stored and an
    error is logged. Returns the same (mutated) state.
    """
    # NOTE(review): this reads the raw text, so the output of the abbreviation
    # expansion step (state["preprocessed_text"]) is never used. Left unchanged
    # because extracted strings are compared against gold annotations of the
    # raw text — confirm which input is intended.
    input_text = state["text"]

    # The marker must appear verbatim in the prompt: the model output is split
    # on it to locate the answer. Previously the prompt ended with a period
    # while the split used a colon, so the marker was never found and every
    # extraction came back empty. Placing it last also keeps the format
    # example above it out of the parsed region.
    response_marker = "Return only the JSON response:"
    prompt = f"""
You are an expert medical NLP model. Identify and extract entities from the following text and classify each entity into one of these categories:
- Drug: Any medication or pharmaceutical compound.
- ADE: Any adverse drug event or side effect.
- Symptom/Disease: Any medical condition, illness, or symptom.

Text: "{input_text}"

Return your response as a JSON object in the format:
{{
  "Drug": ["drug1", "drug2"],
  "ADE": ["ade1", "ade2"],
  "Symptom/Disease": ["symptom1", "symptom2"]
}}

If no entities are found for a category, return an empty list.
{response_marker}
"""
    response = genrate_answer(prompt)
    processed = preprocess_model_output(response, response_marker)
    if processed:
        state["extracted_entities"] = processed
        logging.info("Entity detection completed.")
    else:
        logging.error("Problem in entity detection; no entities extracted.")
        state["extracted_entities"] = {}
    return state

def verify_entities(state: AgentState) -> AgentState:
    """
    Verification step: Compare extracted entities with the gold data from the DataFrame.
    For each category, determine:
      - 'correct': extracted entities that match the gold data.
      - 'incorrect': entities extracted that are not in the gold data.
      - 'missing': entities that are in the gold data but were not extracted.
    The gold data from the DF is assumed to have keys:
      "Drug", "ADE", and "Symptom" (for symptoms, matching extraction's "Symptom/Disease").

    Comparison is exact-match after stripping whitespace and lower-casing.
    Each result list is sorted so the output is identical from run to run.
    The result is stored in state["verification_result"], keyed by the
    extraction category names, and the same (mutated) state is returned.
    """
    def as_items(value) -> list:
        # The model's JSON is not guaranteed to hold lists: a bare string must
        # count as one entity (not be iterated character by character) and
        # null must count as "nothing extracted".
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        try:
            return list(value)
        except TypeError:
            return [value]

    def normalize(values) -> set:
        # Case-insensitive, whitespace-insensitive comparison; blank entries
        # carry no information and are dropped.
        cleaned = {str(v).strip().lower() for v in as_items(values)}
        cleaned.discard("")
        return cleaned

    def verify_category(extracted: List[str], gold: List[str]) -> Dict[str, List[str]]:
        extracted_set = normalize(extracted)
        gold_set = normalize(gold)
        return {
            "correct": sorted(extracted_set & gold_set),
            "incorrect": sorted(extracted_set - gold_set),
            "missing": sorted(gold_set - extracted_set),
        }

    # Define a mapping between extracted entity keys and gold_data keys.
    gold_keys_mapping = {
        "Drug": "Drug",
        "ADE": "ADE",
        "Symptom/Disease": "Symptom"
    }

    extracted_entities = state.get("extracted_entities") or {}
    gold_data = state.get("gold_data") or {}

    verification_result = {}
    for extracted_category, gold_category in gold_keys_mapping.items():
        verification_result[extracted_category] = verify_category(
            extracted_entities.get(extracted_category, []),
            gold_data.get(gold_category, []),
        )

    state["verification_result"] = verification_result
    logging.info("Verification using gold data completed.")
    return state

def create_agent_graph():
    """
    Create the agent graph with sequential nodes.

    Pipeline: "preprocessing" (expand_abbreviations) ->
    "standardization" (classify_detect_entities) ->
    "verify_step" (verify_entities).

    Returns:
        The StateGraph builder (not compiled); call .compile() on it to get a
        runnable graph.
    """
    workflow = StateGraph(AgentState)
    workflow.add_node("preprocessing", expand_abbreviations)
    workflow.add_node("standardization", classify_detect_entities)
    workflow.add_node("verify_step", verify_entities)

    workflow.set_entry_point("preprocessing")
    workflow.add_edge("preprocessing", "standardization")
    workflow.add_edge("standardization", "verify_step")
    # Declare where the graph ends, so the last node is not left without an
    # outgoing edge when the graph is compiled.
    workflow.set_finish_point("verify_step")

    return workflow

# Create the agent graph
# Module-level graph object; process_dataframe() uses it for every row.
agent_graph = create_agent_graph()

def process_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Iterate over DataFrame rows, convert each row into an AgentState dictionary,
    process it through the agent graph pipeline, and return a DataFrame with the results.
    The 'gold_data' field is populated with the original row data.

    Args:
        df: Input frame; each row must provide "text" and, for verification,
            the gold columns "Drug", "ADE" and "Symptom".

    Returns:
        A DataFrame with one row per input row, built from the final state
        returned by the graph.
    """
    # A StateGraph is only a builder and has no run() method: it has to be
    # compiled into a runnable, which is then executed with invoke().
    # Compile once, outside the loop; accept an already-compiled graph too.
    app = agent_graph if hasattr(agent_graph, "invoke") else agent_graph.compile()

    results = []
    for _, row in df.iterrows():
        state: AgentState = {
            "text": row["text"],
            "preprocessed_text": "",
            "extracted_entities": {},
            "standardized_entities": {},
            "gold_data": row.to_dict(),  # Gold data from the DataFrame row (must contain "Drug", "ADE", "Symptom")
            "verification_result": {}
        }
        results.append(app.invoke(state))
    return pd.DataFrame(results)

if __name__ == "__main__":
    # Run the full pipeline over `df` (the DataFrame built earlier in this
    # notebook from the text/ and original/ folders) and print the per-row results.
    processed_df = process_dataframe(df)
    print(processed_df)