In [None]:
%pip install numpy, transformers, torch, numpy, huggingface_hub, datasets

In [None]:
import re
from collections import defaultdict
from itertools import islice

import numpy as np
import spacy
import torch
from transformers import pipeline, AutoModelForTokenClassification, AutoModelForSequenceClassification, AutoTokenizer
from transformers import Trainer, TrainingArguments

In [None]:
from datasets import load_dataset

# Pull the PMC-Patients corpus from the Hugging Face Hub. The result is keyed by split name;
# later cells work with its "train" split and read each row's "title" and "patient" fields.
data = load_dataset(path="zhengyun21/PMC-Patients")

In [None]:
from generateSentiment import generateSentiment

# Add a preliminary sentiment label to the dataset.
# Drop a column left behind by a previous run of this cell first: `datasets` refuses
# duplicate column names, so without this guard re-running the cell raises.
if "sentiment_label" in data["train"].column_names:
    data["train"] = data["train"].remove_columns("sentiment_label")
sentiments = generateSentiment(data["train"])
data["train"] = data["train"].add_column("sentiment_label", sentiments)

In [None]:
from findEntities import findEntities

# Attach the entities found by the project helper as a new "entities" column.
# Drop a column left behind by a previous run of this cell first: `datasets` refuses
# duplicate column names, so without this guard re-running the cell raises.
if "entities" in data["train"].column_names:
    data["train"] = data["train"].remove_columns("entities")
entities = findEntities(data["train"])
data["train"] = data["train"].add_column("entities", entities)

In [None]:
# Train / validation / test fractions. test_split is derived, so the three always sum to 1
# (up to float rounding); what can actually go wrong is a non-positive held-out fraction.
train_split = 0.8
val_split = 0.1
test_split = 1 - train_split - val_split
assert min(train_split, val_split, test_split) > 0, "split fractions must all be positive"
split_seed = 0

# First carve off the held-out (validation + test) share, then divide it between validation
# and test. Both sizes are derived from the fractions above, so editing them takes effect.
held_out_fraction = val_split + test_split
# NOTE: this rebinds `data` to a split of its own "train" split, so re-running this cell
# splits the already-reduced data again — re-run from the loading cell instead.
data = data["train"].train_test_split(
    test_size=held_out_fraction, seed=split_seed
)
test_val_split = data["test"].train_test_split(
    test_size=test_split / held_out_fraction, seed=split_seed
)
# Combine splits into a single dataset
split_dataset = {
    "train": data["train"],
    "validation": test_val_split["train"],
    "test": test_val_split["test"],
}

In [None]:
from huggingface_hub import HfFolder
# Locally stored Hugging Face access token; falsy when no token is available (see the check below).
token = HfFolder.get_token()

# Candidate generation checkpoints: PMC_LLAMA 7B and Llama 3.2 1B.
pretrained_model_name = "chaoyi-wu/PMC_LLAMA_7B"
llama_model_name = "meta-llama/Llama-3.2-1B"
# Use Llama 3.2 only when a Hub token exists — presumably because that repository requires
# authentication (TODO confirm); otherwise fall back to the PMC_LLAMA checkpoint.
model_name = llama_model_name if token else pretrained_model_name
# To always use the PMC_LLAMA checkpoint regardless of the token, set model_name = pretrained_model_name here.
# Text-generation pipeline used by named_entity_recognition and sentiment_analysis below;
# runs on the first CUDA device when one is available, otherwise on CPU (device=-1).
pipe = pipeline("text-generation", model=model_name, device=0 if torch.cuda.is_available() else -1)

In [None]:
# TODO: Do the actual finetuning — `split_dataset` still holds raw text columns. It must be
# tokenized (token-level BIO labels for NER, class ids for sentiment) before the trainers
# below can run `.train()`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# LLaMA-style tokenizers often ship without a pad token, which padded batches require.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# BIO tags over the four entity groups requested by the NER prompt.
ner_entity_types = ["SYMPTOM", "DIAGNOSIS", "TREATMENT", "OUTCOME"]
ner_labels = ["O"] + [f"{tag}-{etype}" for etype in ner_entity_types for tag in ("B", "I")]
# Recovery-risk classes, lowercase to match the sentiment keys counted at the end of the notebook.
sentiment_labels = ["positive", "neutral", "negative"]


def load_classifier(model_cls, labels):
    """Load `model_name` with a freshly initialised classification head sized for `labels`.

    Args:
        model_cls: A transformers Auto* class (token or sequence classification).
        labels: Ordered label names; index i becomes class id i.

    Returns:
        The loaded model, with id<->label maps stored in its config and `pad_token_id`
        aligned with `tokenizer` (LLaMA-style sequence classifiers need it to find the last
        non-pad token in a padded batch).
    """
    model = model_cls.from_pretrained(
        model_name,
        num_labels=len(labels),
        id2label=dict(enumerate(labels)),
        label2id={label: idx for idx, label in enumerate(labels)},
    )
    model.config.pad_token_id = tokenizer.pad_token_id
    return model


# NOTE: each call loads another full copy of the `model_name` weights (in addition to `pipe`).
ner_model = load_classifier(AutoModelForTokenClassification, ner_labels)
sentiment_model = load_classifier(AutoModelForSequenceClassification, sentiment_labels)


def make_training_args(output_dir):
    """Return the TrainingArguments shared by both runs; only `output_dir` differs."""
    return TrainingArguments(
        output_dir=output_dir,
        eval_strategy="epoch",
        learning_rate=1e-4,
        per_device_train_batch_size=8,
        num_train_epochs=3,
        weight_decay=0.01,
        save_total_limit=2,
    )


# Separate output directories so the two runs don't overwrite each other's checkpoints.
ner_training_args = make_training_args("./results/ner")
sentiment_training_args = make_training_args("./results/sentiment")

ner_trainer = Trainer(
    model=ner_model,
    args=ner_training_args,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["validation"],
    processing_class=tokenizer,
)

sentiment_trainer = Trainer(
    model=sentiment_model,
    args=sentiment_training_args,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["validation"],
    processing_class=tokenizer,
)

In [None]:
# Prompt-based entity extraction: symptoms, diagnoses, treatments/medications, and outcomes.
def named_entity_recognition(note, max_new_tokens=256, generator=None):
    """Extract categorized medical entities from a clinical note with the text-generation model.

    Args:
        note: Free-text clinical note.
        max_new_tokens: Upper bound on the number of tokens generated for the answer.
        generator: Callable with the transformers text-generation pipeline call signature
            (prompt, **generate_kwargs) -> [{"generated_text": str}]. Defaults to the
            notebook-level `pipe`.

    Returns:
        The model's answer text (one "Category: a, b" line per group as requested by the
        prompt), with any echoed "Entities:" cue removed and surrounding whitespace stripped.
    """
    generator = pipe if generator is None else generator
    # The note goes last, right before the "Entities:" cue, so the continuation is the
    # answer for THIS note rather than a continuation of the worked example.
    prompt = f"""You are a specialized medical language model designed to extract critical information from clinical notes. Given the clinical note, identify and extract the following entities:

Symptoms: Physical or psychological conditions reported by the patient.
Diagnoses: Medical conditions or diseases identified by the clinician.
Treatments/Medications: Procedures, therapies, or drugs mentioned.
Outcomes: Observations or indications of the patient's response to treatment or prognosis.
Return the extracted entities categorized into the corresponding groups. Use accurate medical terminology, and only include entities explicitly or implicitly mentioned in the text.

Example Input:
"Patient reports severe fatigue and joint pain. Diagnosed with rheumatoid arthritis. Prescribed methotrexate. Follow-up shows improved joint mobility but persistent mild fatigue. Recent ESR levels have decreased but are still elevated."
Expected Output:
Entities:
Symptoms: severe fatigue, joint pain
Diagnoses: rheumatoid arthritis
Treatments/Medications: methotrexate
Outcomes: improved joint mobility, persistent mild fatigue

Input:
"{note}"
Output:
Entities:
"""
    # return_full_text=False returns only the continuation, so the example's "Entities:"
    # block in the prompt can never be mistaken for the model's answer.
    outputs = generator(
        prompt,
        max_new_tokens=max_new_tokens,
        return_full_text=False,
        do_sample=False,
    )
    answer = outputs[0]["generated_text"]
    # If the model repeats the cue, keep only what follows its last occurrence.
    return answer.split("Entities:")[-1].strip()

In [None]:
# Determine the patient's recovery risk sentiment
# Positive (low risk), neutral (medium risk), or negative (high risk)
def sentiment_analysis(note, entities, max_new_tokens=16, generator=None):
    """Classify a clinical note's recovery-risk sentiment with the text-generation model.

    Args:
        note: Free-text clinical note.
        entities: Entity text previously extracted from the note (inserted into the prompt).
        max_new_tokens: Upper bound on the number of tokens generated for the answer.
        generator: Callable with the transformers text-generation pipeline call signature
            (prompt, **generate_kwargs) -> [{"generated_text": str}]. Defaults to the
            notebook-level `pipe`.

    Returns:
        "positive", "neutral" or "negative" (lowercase) for the first such word found in the
        model's answer; if none is found, the raw stripped answer text so the caller can
        inspect it.
    """
    generator = pipe if generator is None else generator
    # The note and entities come last, right before the "Sentiment:" cue, so the
    # continuation is the label for THIS note rather than for the worked example.
    prompt = f"""You are an expert medical language model tasked with analyzing clinical notes to determine patient recovery outcomes. Given a clinical note and extracted entities, assess the sentiment of the note with respect to the patient's recovery risk.
Positive: Indicators of improvement or a high likelihood of recovery.
Neutral: Indicators of stability or uncertain outcomes.
Negative: Indicators of deterioration or a low likelihood of recovery.
Answer with exactly one word: Positive, Neutral, or Negative.

Example Input:
"Patient presents with severe dyspnea and elevated BNP levels. Treatment initiated with diuretics shows mild improvement. However, recurring chest pain persists, and cardiac markers remain elevated."
Expected Output:
Sentiment: Neutral

Clinical Note: {note}
Entities: {entities}
Sentiment:"""
    # return_full_text=False returns only the continuation, so the example's
    # "Sentiment: Neutral" in the prompt can never be mistaken for the answer.
    outputs = generator(
        prompt,
        max_new_tokens=max_new_tokens,
        return_full_text=False,
        do_sample=False,
    )
    answer = outputs[0]["generated_text"].split("Sentiment:")[-1].strip()
    # Normalize to the lowercase label vocabulary used when aggregating sentiment counts.
    match = re.search(r"\b(positive|neutral|negative)\b", answer, flags=re.IGNORECASE)
    return match.group(1).lower() if match else answer

In [None]:
# Number of training notes to label with the LLM. Each note costs two generation calls,
# so keep this small while iterating; set to None to label the whole training split.
N_NOTES_TO_LABEL = 1

labeled_notes = []
for entry in islice(data["train"], N_NOTES_TO_LABEL):
    title = entry["title"]
    note = entry["patient"]
    entities = named_entity_recognition(note)
    sentiment = sentiment_analysis(note, entities)
    labeled_notes.append({"title": title, "note": note, "entities": entities, "sentiment": sentiment})

assert labeled_notes, "no notes were labeled — is the training split empty?"
print(labeled_notes[0]["entities"])
print(labeled_notes[0]["sentiment"])

In [None]:
SENTIMENT_LABELS = ("positive", "neutral", "negative")


def parse_entity_terms(entities):
    """Split extracted entities into individual terms.

    Args:
        entities: Either the text returned by `named_entity_recognition` (lines of the form
            "Category: term, term") or an iterable of already-separated terms.

    Returns:
        A list of non-empty, whitespace-stripped terms. A line without a "Category:" prefix
        is treated as a bare comma-separated list.
    """
    if not isinstance(entities, str):
        return [str(term).strip() for term in entities if str(term).strip()]
    terms = []
    for line in entities.splitlines():
        category, sep, values = line.partition(":")
        if not sep:
            values = category
        terms.extend(value.strip() for value in values.split(","))
    return [term for term in terms if term]


# For each entity term, count how many labeled notes of each sentiment mention it.
entity_sentiment_map = defaultdict(lambda: dict.fromkeys(SENTIMENT_LABELS, 0))
unrecognized_sentiments = []

# Distinct loop names avoid clobbering the notebook-level `entities` / `sentiment` variables.
for labeled in labeled_notes:
    note_sentiment = str(labeled["sentiment"]).strip().lower()
    if note_sentiment not in SENTIMENT_LABELS:
        # Free-text answers that are not one of the three labels can't be counted.
        unrecognized_sentiments.append(labeled["sentiment"])
        continue
    for entity in parse_entity_terms(labeled["entities"]):
        entity_sentiment_map[entity][note_sentiment] += 1

if unrecognized_sentiments:
    print(f"Skipped {len(unrecognized_sentiments)} note(s) with unrecognized sentiment: {unrecognized_sentiments[:5]}")