In [None]:
import os
from tqdm import tqdm
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import f1_score, roc_curve, auc
import glob

# ------------------ Model Setup ------------------
# Checkpoints that have been evaluated with this notebook; uncomment exactly one.
#model_path = "microsoft/Phi-4-mini-instruct"
#model_path = "meta-llama/Llama-2-7b-chat-hf"
#model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
#model_path = "Qwen/Qwen1.5-7B-Chat"
model_path = "Qwen/Qwen2.5-7B-Instruct"
#model_path = "Qwen/Qwen2.5-14B-Instruct"
#model_path = "mistralai/Mistral-7B-v0.3"
#model_path = "google/gemma-2b-it"
#model_path = "BioMistral/BioMistral-7B"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype="float16",
    trust_remote_code=True,
)
# Load the tokenizer with the same trust_remote_code setting as the model so
# checkpoints that ship custom code load their tokenizer consistently too.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Answer strings whose next-token probabilities are compared by
# score_yes_probability. Only the FIRST sub-token id of each is used there.
label_texts = [" yes", " no"]  # Leading space important
label_token_ids = [tokenizer.encode(lab, add_special_tokens=False) for lab in label_texts]
print("Label tokens:", {lab: tokenizer.convert_ids_to_tokens(toks)
                        for lab, toks in zip(label_texts, label_token_ids)})

# Because only the first sub-token is scored, a label that encodes to several
# tokens is only partially represented; warn so the printout above is checked.
for lab, toks in zip(label_texts, label_token_ids):
    if len(toks) != 1:
        print(f"Warning: label {lab!r} encodes to {len(toks)} tokens; only the first is scored.")

# If both labels start with the same token id (e.g. a tokenizer that emits a
# standalone whitespace piece first), the 2-way softmax in
# score_yes_probability is always exactly 0.5 and every score is meaningless.
if label_token_ids[0][0] == label_token_ids[1][0]:
    raise ValueError(
        f"Labels {label_texts!r} share their first token id ({label_token_ids[0][0]}) "
        f"for tokenizer {model_path!r}; yes/no scores would be constant. "
        "Choose label_texts that tokenize to distinct first tokens."
    )

# with open("ml/standardized_test_report_names.json") as f:
#     test_report_names = set(json.load(f))

# DATA_ROOT = "ml/CRF_individual"
# all_reports = sorted([
#     os.path.join(DATA_ROOT, d)
#     for d in os.listdir(DATA_ROOT)
#     if os.path.isdir(os.path.join(DATA_ROOT, d)) and d in test_report_names
# ])

def build_prompt(report_text: str, sentence: str) -> str:
    """Render the chat prompt asking whether one report sentence is brain/head-CT content.

    The system turn holds the task instructions plus a handful of few-shot
    sentences (all labelled "yes"). The user turn carries the full report as
    context followed by the single sentence to classify.

    Args:
        report_text: The whole radiology report as plain text.
        sentence: One sentence taken from ``report_text``.

    Returns:
        The prompt string produced by the global ``tokenizer``'s chat template,
        untokenized, with the assistant generation prompt appended.
    """
    # Few-shot sentences shown in the system turn; each is followed by "yes".
    shots = (
        "EXAM: Head CT",
        "INDICATION: Auditory and hallucinations. Confusion.",
        "There is generalized atrophy with changes of chronic white matter microvascular disease.",
        "No intracranial hemorrhage or signs of an acute infarction.",
        "Ventricles and CSF spaces: There is no evidence of obstructive hydrocephalus.",
        "The paranasal sinuses are clear.",
        "IMPRESSION: Atrophy and chronic white matter changes.",
    )
    shot_lines = "".join(f"Sentence: {shot}\nyes\n" for shot in shots)

    instructions = (
        "You are a medical AI assistant. For each sentence from a radiology report, respond with only one word: "
        "'yes' if the content is from a brain or head CT scan, and 'no' if it is from any other type of CT scan "
        "(such as facial bones, spine, neck, sinuses, etc). Respond only with 'yes' or 'no'.\n\n"
        "Example of a brain CT report:\n"
        + shot_lines
        + "\n"
        + "Now, I will paste the entire radiology report, and then specify one sentence from it."
    )

    question = (
        f"Full report:\n{report_text}\n\n"
        f"Sentence: {sentence}\n"
        "With the full report as context, answer 'yes' or 'no': Is this sentence brain-related?\n"
        "Answer:"
    )

    return tokenizer.apply_chat_template(
        [
            {"role": "system", "content": instructions},
            {"role": "user", "content": question},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )

def score_yes_probability(prompt_text: str) -> float:
    """Return P("yes") renormalized over the yes/no label tokens at the next position.

    Runs one forward pass of the global ``model`` and takes the logits at the
    last prompt position. It keeps only the first sub-token of each label in
    ``label_token_ids`` (yes first, no second) and softmaxes those two logits
    against each other.

    Args:
        prompt_text: A fully rendered chat prompt string (e.g. from
            ``build_prompt``), already containing the template's special tokens.

    Returns:
        A float in [0, 1]: the probability mass of "yes" relative to "no".
    """
    # The chat template has already inserted its special tokens. Letting the
    # tokenizer add them again (e.g. a second BOS) would change the input.
    enc = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).to(model.device)
    with torch.no_grad():
        out = model(**enc)

    # Logits for the first token the assistant would generate, i.e. right
    # after the chat template's generation prompt (which follows "Answer:").
    # NOTE(review): after the generation prompt the model may prefer "yes"/"Yes"
    # without a leading space over " yes"; the 2-way ratio is still computed,
    # but it is worth confirming the chosen ids carry meaningful mass.
    logits = out.logits[0, -1, :]
    yes_id, no_id = label_token_ids[0][0], label_token_ids[1][0]
    # Upcast before the softmax: the model is loaded in float16.
    pair = logits[[yes_id, no_id]].float()
    return torch.softmax(pair, dim=0)[0].item()


# ------------------ Main Loop ------------------
# Accumulators consumed by the analysis cells below:
#   all_scores / all_labels     -> one entry per sentence, across all reports
#   per_report_acc              -> accuracy at cutoff 0.5 for each non-empty report
#   total_correct / total_count -> pooled counts for the overall accuracy
all_scores, all_labels, per_report_acc = [], [], []
total_correct, total_count = 0, 0

# Synthetic bundled reports: every *.csv directly inside SYNTHETIC_DIR.
# Sorted so the evaluation order (and the printed log) is reproducible;
# os.listdir returns entries in arbitrary order.
SYNTHETIC_DIR = "MIMIC/synthetic_bundled_reports"
synthetic_csvs = sorted(
    os.path.join(SYNTHETIC_DIR, name)
    for name in os.listdir(SYNTHETIC_DIR)
    if name.endswith(".csv")
)

# Held-out CRF reports: every *.csv inside each sub-directory of DATA_ROOT
# whose directory name appears in the standardized test split.
with open("ml/standardized_test_report_names.json") as fh:
    test_report_names = set(json.load(fh))

DATA_ROOT = "ml/CRF_individual"
all_reports = []
for d in sorted(os.listdir(DATA_ROOT)):
    full_dir = os.path.join(DATA_ROOT, d)
    if os.path.isdir(full_dir) and d in test_report_names:
        all_reports.extend(sorted(glob.glob(os.path.join(full_dir, "*.csv"))))

# Which collection to evaluate: synthetic_csvs or all_reports.
eval_csvs = synthetic_csvs

for csv_path in eval_csvs:
    print(csv_path)

    df = pd.read_csv(csv_path)
    sentences = df["Sentence"].tolist()
    # Assumes "Brain Related" holds 0/1 (or bool) labels — TODO confirm; the
    # analysis below casts them with dtype=int and compares them to 0/1 preds.
    labels = df["Brain Related"].tolist()

    if not sentences:
        # An empty report has no defined accuracy; recording 0.0 would
        # distort the per-report accuracy histogram.
        print(f"Skipping empty report: {csv_path}")
        continue

    # Full report as plain text, given to the model as context for every sentence.
    # NOTE(review): a missing (NaN) sentence cell would make this join fail.
    report_text = " ".join(sentences)

    scores = [score_yes_probability(build_prompt(report_text, sent)) for sent in sentences]
    preds = [1 if score > 0.5 else 0 for score in scores]
    all_scores.extend(scores)
    all_labels.extend(labels)

    correct = sum(int(p == y) for p, y in zip(preds, labels))
    count = len(labels)
    total_correct += correct
    total_count += count
    per_report_acc.append(correct / count)
    print(f"Report: {os.path.basename(csv_path)} | Accuracy: {correct}/{count}")

    torch.cuda.empty_cache()

# ------------------ Threshold sweep, ROC, per-report accuracy ------------------
all_scores_arr = np.asarray(all_scores, dtype=float)
all_labels_arr = np.asarray(all_labels, dtype=int)


def _sweep_thresholds(scores, labels, cutoffs):
    """Return (accuracies, class-1 F1 scores) for predictions ``score > t`` at each cutoff t."""
    accs_out, f1s_out = [], []
    for t in cutoffs:
        preds_t = (scores > t).astype(int)
        accs_out.append(float((preds_t == labels).mean()))
        # zero_division=0: at high cutoffs nothing is predicted positive; report
        # 0 explicitly (the sklearn default value) without UndefinedMetricWarning.
        f1s_out.append(f1_score(labels, preds_t, zero_division=0))
    return accs_out, f1s_out


def _plot_acc_f1(cutoffs, acc_values, f1_values):
    """Plot accuracy and F1 against threshold in a new figure."""
    plt.figure(figsize=(7, 4))
    plt.plot(cutoffs, acc_values, marker='o', label='Accuracy')
    plt.plot(cutoffs, f1_values, marker='s', label='F1 Score')
    plt.xlabel('Threshold')
    plt.ylabel('Score')
    plt.title('Accuracy & F1 Score vs Threshold')
    plt.legend()
    plt.grid(True, alpha=0.4)
    plt.show()


# 0.00, 0.01, ..., 1.00 rounded to two decimals. np.arange accumulates float
# error (e.g. 0.5700000000000001), which would otherwise land in the CSV.
thresholds = np.round(np.linspace(0.0, 1.0, 101), 2)
accs, f1s = _sweep_thresholds(all_scores_arr, all_labels_arr, thresholds)

results_list = list(zip(thresholds, f1s, accs))  # (threshold, f1, accuracy)
for thresh, f1, acc in results_list:
    print(f"Threshold {thresh:.2f}: Accuracy={acc:.4f}, F1={f1:.4f}")

df_thresholds = pd.DataFrame(results_list, columns=["threshold", "F1", "accuracy"])
# NOTE: the macro-F1 sweep later in this notebook writes the same filename
# (with more columns) and overwrites this file if it runs afterwards.
df_thresholds.to_csv("slm_threshold_f1_acc.csv", index=False)
print("Saved threshold/F1/accuracy results to slm_threshold_f1_acc.csv")

_plot_acc_f1(thresholds, accs, f1s)

# Same curves restricted to thresholds 0.00-0.98, reusing the values computed
# above instead of re-running the sweep.
zoom = thresholds <= 0.98
_plot_acc_f1(thresholds[zoom], np.asarray(accs)[zoom], np.asarray(f1s)[zoom])

# ROC needs both classes present; with a single class FPR/TPR are undefined.
if np.unique(all_labels_arr).size == 2:
    fpr, tpr, _ = roc_curve(all_labels_arr, all_scores_arr)
    roc_auc = auc(fpr, tpr)
    plt.figure(figsize=(6, 5))
    plt.plot(fpr, tpr, label=f"ROC curve (AUC={roc_auc:.4f})")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve: Brain vs Not Brain (Yes/No)")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
else:
    roc_auc = float("nan")
    print("Skipping ROC curve: labels do not contain both classes.")

plt.figure(figsize=(8, 5))
plt.hist([a * 100 for a in per_report_acc], bins=10, edgecolor='black')
plt.xlabel('Per-report Accuracy (%) with cutoff=0.5')
plt.ylabel('Number of Reports')
plt.title('Distribution of Per-report Accuracies')
plt.grid(True, axis='y', alpha=0.5)
plt.show()

if total_count:
    print(f"\nOverall Accuracy: {total_correct / total_count:.4f} ({total_correct}/{total_count})")
else:
    print("\nOverall Accuracy: n/a (no sentences were scored)")

from sklearn.metrics import f1_score

# ------------------ Threshold sweep: class-1 F1 vs macro F1 ------------------
# Each row: (threshold, F1 for class 1, macro F1, accuracy).
results_list = []
# 0.00, 0.01, ..., 1.00 rounded to two decimals. np.arange accumulates float
# error (e.g. 0.5700000000000001), which would otherwise land in the CSV.
thresholds = np.round(np.linspace(0.0, 1.0, 101), 2)

accs, f1_cls1_list, macro_f1_list = [], [], []

for thresh in thresholds:
    preds = (all_scores_arr > thresh).astype(int)

    acc = (preds == all_labels_arr).mean()
    accs.append(acc)

    # F1 of the positive ("brain-related") class; 0 when nothing is predicted positive.
    f1_cls1 = f1_score(all_labels_arr, preds, pos_label=1, zero_division=0)
    f1_cls1_list.append(f1_cls1)

    # Unweighted mean of the per-class F1 scores (class 0 and class 1).
    macro_f1 = f1_score(all_labels_arr, preds, average="macro", zero_division=0)
    macro_f1_list.append(macro_f1)

    results_list.append((thresh, f1_cls1, macro_f1, acc))
    print(f"Threshold {thresh:.2f}: Acc={acc:.4f}, F1_1={f1_cls1:.4f}, MacroF1={macro_f1:.4f}")

df_thresholds = pd.DataFrame(results_list, columns=["threshold", "F1_class1", "MacroF1", "accuracy"])
# Same filename as the earlier accuracy/F1 sweep: this (superset) table overwrites it.
df_thresholds.to_csv("slm_threshold_f1_acc.csv", index=False)
print("Saved threshold metrics to slm_threshold_f1_acc.csv")

plt.figure(figsize=(7, 4))
plt.plot(thresholds, accs,            marker='o', label='Accuracy')
plt.plot(thresholds, f1_cls1_list,    marker='s', label='F1 (Class 1)')
plt.plot(thresholds, macro_f1_list,   marker='^', label='Macro F1')
plt.xlabel('Threshold')
plt.ylabel('Score')
plt.title('Accuracy, F1 (Class 1), Macro F1 vs Threshold')
plt.legend()
plt.grid(True, alpha=0.4)
plt.show()

