In [None]:
!pip install -q evaluate bert_score rouge_score sacremoses


In [None]:
# =============================================================================
# DEPENDENCIES
# =============================================================================


import os
import re
import random
import logging
from typing import List

import torch
import numpy as np
import nltk
from tqdm import tqdm


from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
DataCollatorForSeq2Seq,
Seq2SeqTrainingArguments,
Seq2SeqTrainer,
GenerationConfig,
)
from peft import (
LoraConfig,
get_peft_model,
TaskType,
PeftModel,
PeftConfig,
)


import evaluate
from bert_score import score as bert_score


# =============================================================================
# GLOBAL CONFIGURATION
# =============================================================================


# Seed passed to set_seed() and to the train/validation fallback split.
SEED = 42
# Checkpoint identifier passed to the tokenizer, model and GenerationConfig loaders.
MODEL_NAME = "google/flan-t5-large"
# Dataset identifier passed to load_dataset().
DATASET_ID = "omi-health/medical-dialogue-to-soap-summary"
# Trainer output directory; training logs go under f"{OUTPUT_DIR}/logs".
OUTPUT_DIR = "./flan_t5_large_soap_lora"
# Directory the MAIN cell saves the trained model and tokenizer to.
# NOTE(review): looks like a Colab Google Drive mount path, but nothing in this
# notebook mounts Drive — confirm it is mounted before running MAIN.
DRIVE_PATH = "/content/drive/MyDrive/KG_Medical_SOAP_Model"


# Truncation limit (max_length) applied when tokenizing prompts.
MAX_INPUT_LENGTH = 1024
# Truncation limit for tokenized SOAP targets; also max_new_tokens in generate_soap().
MAX_TARGET_LENGTH = 512
# Number of test-split examples the MAIN cell evaluates on.
MAX_EVAL_SAMPLES = 100


# =============================================================================
# SETUP
# =============================================================================


# Configure logging at INFO so the logger.info(...) progress messages in MAIN are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Tokenizer data for nltk.sent_tokenize, which compute_metrics calls.
# NOTE(review): presumably both resource names are fetched to cover different
# NLTK versions — confirm which one the installed version actually needs.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)




def set_seed(seed: int):
    """Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices).

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)




set_seed(SEED)

In [None]:
# =============================================================================
# DATA LOADING
# =============================================================================


# Load every split published for DATASET_ID.
raw_ds = load_dataset(DATASET_ID)


# Fallback for a dataset without a "validation" split: hold out 10% of train
# (seeded with SEED so the split is repeatable) and use it as validation.
if "validation" not in raw_ds:
    split = raw_ds["train"].train_test_split(test_size=0.1, seed=SEED)
    raw_ds["train"] = split["train"]
    raw_ds["validation"] = split["test"]

In [None]:
# =============================================================================
# TOKENIZER & MODEL
# =============================================================================


# Tokenizer for the base checkpoint; if it defines no pad token, reuse EOS.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


# Base seq2seq model loaded in bfloat16 (the training arguments also set bf16=True).
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.bfloat16
)


# LoRA adapter settings: rank 16, alpha 32, dropout 0.05, bias="none".
# NOTE(review): "q" and "v" presumably name the attention query/value projection
# modules of this architecture — confirm against the model's module names.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q", "v"],
)


# Wrap the base model with the adapter and report how many parameters are trainable.
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

In [None]:
# =============================================================================
# PREPROCESSING
# =============================================================================


def preprocess_function(batch):
    """Tokenize a batch of dialogue/SOAP pairs for seq2seq fine-tuning.

    Args:
        batch: Mapping with a "dialogue" list (source texts) and a "soap" list
            (target texts), as passed by `Dataset.map(..., batched=True)`.

    Returns:
        The tokenizer output for the prompted dialogues, with an added
        "labels" entry holding the tokenized SOAP notes. Inputs are truncated
        to MAX_INPUT_LENGTH tokens and targets to MAX_TARGET_LENGTH tokens.
    """
    prompts = [
        "Generate a structured medical SOAP note from this doctor-patient dialogue:\n\n" + d
        for d in batch["dialogue"]
    ]

    model_inputs = tokenizer(
        prompts, truncation=True, max_length=MAX_INPUT_LENGTH
    )

    # `text_target=` is the supported way to tokenize targets; the
    # `as_target_tokenizer()` context manager used previously is deprecated.
    labels = tokenizer(
        text_target=batch["soap"], truncation=True, max_length=MAX_TARGET_LENGTH
    )

    # Map pad ids to -100 (the value compute_metrics treats as "ignore").
    # No padding is requested above, so this is only a safeguard.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [(tok if tok != pad_id else -100) for tok in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs




# Tokenize every split in batches; the original columns are removed so only the
# values returned by preprocess_function remain.
# NOTE(review): the column list is taken from the train split — assumes every
# split has the same columns.
tokenized_ds = raw_ds.map(
    preprocess_function,
    batched=True,
    remove_columns=raw_ds["train"].column_names,
)

In [None]:
# =============================================================================
# MEDICAL ENTITY EXTRACTION (HALLUCINATION)
# =============================================================================

from transformers import pipeline

# Token-classification pipeline used by extract_medical_terms to find entities.
# NOTE(review): aggregation_strategy="simple" presumably merges word pieces into
# whole-entity spans — confirm against the transformers pipeline documentation.
ner_pipeline = pipeline(
    "ner",
    model="d4data/biomedical-ner-all",
    tokenizer="d4data/biomedical-ner-all",
    aggregation_strategy="simple",
)

def extract_medical_terms(text: str) -> set:
    """Return the lower-cased entity strings the NER pipeline finds in `text`.

    Only the first 2000 characters of `text` are passed to the pipeline.
    """
    detected = ner_pipeline(text[:2000])
    terms = set()
    for entity in detected:
        terms.add(entity["word"].lower())
    return terms


In [None]:
# =============================================================================
# METRICS
# =============================================================================


# ROUGE scorer loaded once at module level and reused by compute_metrics.
rouge_metric = evaluate.load("rouge")


def compute_metrics(eval_pred):
    """Compute ROUGE scores for generated sequences during Trainer evaluation.

    Args:
        eval_pred: `(predictions, label_ids)` pair of token-id arrays. The
            predictions may arrive wrapped in a tuple, and either array may
            contain -100 where positions were padded or masked.

    Returns:
        Dict mapping each ROUGE key to its score scaled to 0-100 and rounded
        to four decimals.
    """
    preds, labels = eval_pred

    # Some model outputs are tuples; the generated ids are the first element.
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is not a valid token id: swap it for the pad id before decoding,
    # for predictions as well as labels, otherwise batch_decode can fail.
    pad_id = tokenizer.pad_token_id
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # One sentence per line, as sentence-level ROUGE scoring expects.
    decoded_preds = ["\n".join(nltk.sent_tokenize(p.strip())) for p in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(l.strip())) for l in decoded_labels]

    scores = rouge_metric.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True,
    )

    return {k: round(v * 100, 4) for k, v in scores.items()}


def hallucination_score(dialogue: str, prediction: str) -> float:
    """Fraction of medical entities in the prediction that are absent from the dialogue.

    Entities come from `extract_medical_terms` and are compared as exact
    lower-cased strings.

    Args:
        dialogue: Source doctor-patient dialogue.
        prediction: Generated SOAP note.

    Returns:
        A value in [0, 1]; 0.0 when the prediction contains no entities.
    """
    # The notebook's extractor is `extract_medical_terms`; the previously
    # called `extract_entities` is not defined anywhere and raised NameError.
    pred_ents = extract_medical_terms(prediction)
    if not pred_ents:
        return 0.0

    input_ents = extract_medical_terms(dialogue)
    hallucinated = pred_ents - input_ents
    return len(hallucinated) / len(pred_ents)

In [None]:
# =============================================================================
# TRAINING
# =============================================================================


# Hyperparameters and evaluation/checkpoint schedule for the fine-tune.
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=1e-3,
    num_train_epochs=3,
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_strategy="steps",
    logging_steps=100,
    # Evaluation and checkpointing share the same 200-step cadence.
    # NOTE(review): presumably required so load_best_model_at_end can pair
    # each checkpoint with an evaluation — confirm against the Trainer docs.
    eval_strategy="steps",
    save_strategy="steps",
    eval_steps=200,
    save_steps=200,
    save_total_limit=2,
    bf16=True,
    # compute_metrics decodes predictions as token ids, so evaluation must
    # produce generated sequences.
    predict_with_generate=True,
    report_to="none",
    load_best_model_at_end=True,
    # NOTE(review): assumes compute_metrics returns a "rougeL" key — verify.
    metric_for_best_model="rougeL",
    # NOTE(review): max_new_tokens=512 repeats the value of MAX_TARGET_LENGTH
    # as a literal; consider referencing the constant.
    generation_config=GenerationConfig.from_pretrained(
        MODEL_NAME,
        max_new_tokens=512,
        repetition_penalty=2.5,
        no_repeat_ngram_size=3,
    ),
)


# Trainer wiring: LoRA-wrapped model, tokenized train/validation splits, a
# seq2seq collator, and ROUGE-based metrics.
# NOTE(review): newer transformers releases are understood to replace the
# `tokenizer=` argument with `processing_class=` — confirm against the
# installed version.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    compute_metrics=compute_metrics,
)

In [None]:
# =============================================================================
# GENERATION
# =============================================================================

def generate_soap(model, tokenizer, dialogue: str, device: str) -> str:
    """Generate a SOAP note for a single doctor-patient dialogue.

    Args:
        model: Model exposing `generate`.
        tokenizer: Tokenizer matching `model`.
        dialogue: Raw dialogue text; the prompt is truncated to
            MAX_INPUT_LENGTH tokens.
        device: Device the tokenized inputs are moved to (e.g. "cuda" or "cpu").

    Returns:
        The decoded text of the first generated sequence, without special tokens.
    """
    prompt = (
        "Generate a structured medical SOAP note from this doctor-patient dialogue:\n\n"
        + dialogue
    )

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_LENGTH,
    ).to(device)

    # generate() runs in whatever mode the module is in. Right after training
    # the model can still be in train mode, leaving dropout active during
    # decoding. Force eval mode here and restore the previous mode afterwards.
    was_training = getattr(model, "training", False)
    model.eval()
    try:
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=MAX_TARGET_LENGTH,
                num_beams=3,
                repetition_penalty=2.5,
                no_repeat_ngram_size=3,
            )
    finally:
        if was_training:
            model.train()

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

In [None]:
# =============================================================================
# EVALUATION CLASS
# =============================================================================


class SOAPEvaluator:
    """Bundle of text-quality metrics for generated SOAP notes.

    Provides lexical overlap (ROUGE, BLEU), BERTScore, and an entity-based
    hallucination rate. The ROUGE and BLEU scorers are loaded once in the
    constructor.
    """

    def __init__(self):
        self.rouge = evaluate.load("rouge")
        self.bleu = evaluate.load("bleu")

    @staticmethod
    def clean(text: str) -> str:
        """Lower-case `text`, strip its ends, and collapse whitespace runs to one space."""
        text = text.lower().strip()
        return re.sub(r"\s+", " ", text)

    def lexical(self, refs: List[str], preds: List[str]):
        """Return ROUGE-1/2/L and BLEU for cleaned predictions against cleaned references.

        Args:
            refs: Reference texts.
            preds: Predicted texts, aligned with `refs`.

        Returns:
            Dict with "rouge1", "rouge2", "rougeL" and "bleu", rounded to four decimals.
        """
        refs = [self.clean(r) for r in refs]
        preds = [self.clean(p) for p in preds]

        rouge = self.rouge.compute(predictions=preds, references=refs)
        # BLEU takes a list of reference lists: one single-element list per prediction.
        bleu = self.bleu.compute(predictions=preds, references=[[r] for r in refs])

        return {
            "rouge1": round(rouge["rouge1"], 4),
            "rouge2": round(rouge["rouge2"], 4),
            "rougeL": round(rouge["rougeL"], 4),
            "bleu": round(bleu["bleu"], 4),
        }

    def bert(self, refs: List[str], preds: List[str]):
        """Return mean BERTScore precision, recall and F1, rounded to four decimals."""
        P, R, F = bert_score(preds, refs, lang="en", verbose=False)
        return {
            "bertscore_precision": round(P.mean().item(), 4),
            "bertscore_recall": round(R.mean().item(), 4),
            "bertscore_f1": round(F.mean().item(), 4),
        }

    def hallucination(self, dialogues: List[str], preds: List[str]):
        """Return the mean per-sample hallucination score.

        Args:
            dialogues: Source dialogues.
            preds: Generated notes, aligned one-to-one with `dialogues`.

        Returns:
            Dict with "hallucination_rate" rounded to four decimals; NaN when
            there are no samples.

        Raises:
            ValueError: If `dialogues` and `preds` differ in length. zip()
                would otherwise drop the unmatched tail silently.
        """
        if len(dialogues) != len(preds):
            raise ValueError(
                f"dialogues ({len(dialogues)}) and preds ({len(preds)}) must have the same length"
            )

        # The mean of no samples is undefined: report NaN directly instead of
        # letting NumPy emit a RuntimeWarning for an empty mean.
        if not preds:
            return {"hallucination_rate": float("nan")}

        scores = [
            hallucination_score(d, p)
            for d, p in zip(dialogues, preds)
        ]
        return {
            "hallucination_rate": round(float(np.mean(scores)), 4)
        }


In [None]:
# Judge prompt template for fluency. It is filled with str.format and has a
# single placeholder, {generated_soap}. The judge is asked to reply with one
# integer from 0 to 10, which parse_score extracts.
FLUENCY_PROMPT = """
You are an experienced physician and medical editor.

Evaluate the FLUENCY of the following SOAP note.

Consider:
- grammatical correctness
- clarity and coherence
- professional medical tone
- readability for clinical documentation

Score on a scale from 0 to 10:
- 0 = incoherent, ungrammatical, unusable
- 10 = perfectly fluent, natural, indistinguishable from a human physician

SOAP NOTE:
<<<
{generated_soap}
>>>

Return ONLY a single integer score between 0 and 10.
"""



# Judge prompt template for consistency with the reference note. It is filled
# with str.format and has two placeholders, {reference_soap} and
# {generated_soap}. The judge is asked to reply with one integer from 0 to 10.
CONSISTENCY_PROMPT = """
You are an experienced physician.

Evaluate the CONSISTENCY of the generated SOAP note
with respect to the reference physician SOAP note.

Consider:
- whether key clinical facts match
- whether important information is missing
- whether incorrect or fabricated details are introduced
- whether assessment and plan align with the reference

Score on a scale from 0 to 10:
- 0 = completely inconsistent or incorrect
- 10 = fully consistent and clinically aligned

REFERENCE SOAP:
<<<
{reference_soap}
>>>

GENERATED SOAP:
<<<
{generated_soap}
>>>

Return ONLY a single integer score between 0 and 10.
"""


In [None]:
!pip install -q google-generativeai

import os
import re
# SECURITY: an API key was previously hardcoded on this line. Any key that has
# been committed or shared in a notebook should be treated as leaked and
# revoked. The key is now taken from the environment, with an interactive
# prompt (input not echoed) as a fallback.
if not os.environ.get("GEMINI_API_KEY"):
    from getpass import getpass

    os.environ["GEMINI_API_KEY"] = getpass("Enter your Gemini API key: ")

In [None]:
# =============================================================================
# GEMINI JUDGING
# =============================================================================


import google.generativeai as genai
import os

# Authenticate the Gemini client with the key stored in the environment.
genai.configure(api_key=os.environ["GEMINI_API_KEY"])

# Default judge: model name plus a shared client instance reused by the
# judge_* functions.
GEMINI_MODEL = "gemini-2.0-flash-lite"
judge_model = genai.GenerativeModel(GEMINI_MODEL)

def parse_score(text: str) -> int:
    """Extract the first standalone integer from 0 to 10 found in `text`.

    Args:
        text: Raw judge reply.

    Returns:
        The matched score, or 0 when no standalone 0-10 integer is present.
    """
    found = re.search(r"\b([0-9]|10)\b", text)
    # No match falls back to 0.
    return int(found.group(1)) if found else 0


def judge_fluency(generated_soap: str, model: str = GEMINI_MODEL) -> int:
    """Ask the Gemini judge to rate the fluency of a generated SOAP note.

    Args:
        generated_soap: SOAP note text to evaluate.
        model: Gemini model name. The shared `judge_model` client is reused
            for the default; any other name gets its own client. This argument
            was previously accepted but ignored.

    Returns:
        Integer score parsed from the reply by `parse_score` (0 when no score
        can be parsed).
    """
    prompt = FLUENCY_PROMPT.format(generated_soap=generated_soap)

    # Honor the requested model instead of always using the module-level client.
    judge = judge_model if model == GEMINI_MODEL else genai.GenerativeModel(model)

    response = judge.generate_content(
        prompt,
        generation_config={
            "temperature": 0.0,
            "max_output_tokens": 5,
        },
    )

    output = response.text.strip()
    return parse_score(output)


def judge_consistency(
    reference_soap: str,
    generated_soap: str,
    model: str = GEMINI_MODEL,
) -> int:
    """Ask the Gemini judge to rate a generated SOAP note against its reference.

    Args:
        reference_soap: Reference SOAP note.
        generated_soap: Generated SOAP note to evaluate.
        model: Gemini model name. The shared `judge_model` client is reused
            for the default; any other name gets its own client. This argument
            was previously accepted but ignored.

    Returns:
        Integer score parsed from the reply by `parse_score` (0 when no score
        can be parsed).
    """
    prompt = CONSISTENCY_PROMPT.format(
        reference_soap=reference_soap,
        generated_soap=generated_soap,
    )

    # Honor the requested model instead of always using the module-level client.
    judge = judge_model if model == GEMINI_MODEL else genai.GenerativeModel(model)

    response = judge.generate_content(
        prompt,
        generation_config={
            "temperature": 0.0,
            "max_output_tokens": 5,
        },
    )

    output = response.text.strip()
    return parse_score(output)


In [None]:
MAX_JUDGE_SAMPLES = 50


def run_llm_judge(refs, preds, max_samples=MAX_JUDGE_SAMPLES):
    """Score (reference, prediction) pairs with the LLM judge.

    Judging stops at the first failing sample; scores collected before the
    failure are kept. A pair contributes to both lists or to neither.

    Args:
        refs: Reference SOAP notes.
        preds: Generated SOAP notes, aligned with `refs`.
        max_samples: Maximum number of pairs to judge.

    Returns:
        `(fluency_scores, consistency_scores)` lists of equal length.
    """
    fluency, consistency = [], []
    pairs = list(zip(refs, preds))[:max_samples]
    for i, (ref, pred) in enumerate(pairs):
        try:
            fluency_i = judge_fluency(pred)
            consistency_i = judge_consistency(ref, pred)
        except Exception as e:
            print(f"[Judge failed at sample {i}]: {e}")
            break
        fluency.append(fluency_i)
        consistency.append(consistency_i)
    return fluency, consistency


def summarize_judge_scores(fluency_scores, consistency_scores):
    """Return mean and population std of both score lists, rounded to 3 decimals.

    An empty list yields NaN for its statistics, without NumPy's
    empty-mean RuntimeWarning.
    """

    def _mean_std(scores):
        if not scores:
            return float("nan"), float("nan")
        return round(float(np.mean(scores)), 3), round(float(np.std(scores)), 3)

    fluency_mean, fluency_std = _mean_std(fluency_scores)
    consistency_mean, consistency_std = _mean_std(consistency_scores)
    return {
        "fluency_mean": fluency_mean,
        "fluency_std": fluency_std,
        "consistency_mean": consistency_mean,
        "consistency_std": consistency_std,
    }


# `refs` and `preds` are produced by the MAIN cell further down. On a fresh
# top-to-bottom run they do not exist yet, so skip with a message instead of
# raising NameError and halting the run before training starts.
if "refs" in globals() and "preds" in globals():
    fluency_scores, consistency_scores = run_llm_judge(refs, preds)
    results = summarize_judge_scores(fluency_scores, consistency_scores)
    print(results)
else:
    fluency_scores, consistency_scores, results = [], [], {}
    print("[Judge skipped] `refs`/`preds` are not defined yet; run the MAIN cell, then re-run this cell.")

In [None]:
# =============================================================================
# MAIN
# =============================================================================


if __name__ == "__main__":
    # End-to-end run: train, save the model and tokenizer, then score
    # generations on a slice of the test split.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    logger.info("Starting training...")
    trainer.train()

    logger.info("Saving model...")
    model.save_pretrained(DRIVE_PATH)
    tokenizer.save_pretrained(DRIVE_PATH)

    logger.info("Running evaluation...")
    test_ds = load_dataset(DATASET_ID, split="test")
    # Clamp the sample count so a test split smaller than MAX_EVAL_SAMPLES
    # does not make select() raise an out-of-range error.
    eval_ds = test_ds.select(range(min(MAX_EVAL_SAMPLES, len(test_ds))))

    # The model can still be in train mode after training; switch to eval so
    # dropout is disabled while generating.
    model.eval()

    # Collect dialogues in the same pass as predictions and references so the
    # three lists stay aligned and the dataset is iterated only once.
    dialogues, preds, refs = [], [], []
    for sample in tqdm(eval_ds):
        dialogues.append(sample["dialogue"])
        preds.append(generate_soap(model, tokenizer, sample["dialogue"], device))
        refs.append(sample["soap"])

    evaluator = SOAPEvaluator()
    print(evaluator.lexical(refs, preds))
    print(evaluator.bert(refs, preds))
    print(evaluator.hallucination(dialogues, preds))