## Final Hybrid code

In [1]:
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer, util
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import pandas as pd
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from sklearn.metrics import classification_report
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

In [None]:
val_test_2009_data=pd.read_csv("val_data_pred.csv")

In [None]:
true_labels = val_test_2009_data["output"].tolist()

In [3]:
from torch.nn.functional import softmax

# Define your label set in the exact form the model outputs
label_list = ["Death", "Injury", "Device Malfunction"]

In [4]:
# %% [Load Quantized Base Model + Tokenizer]
model_name = "microsoft/Phi-3-mini-4k-instruct"
compute_dtype = torch.float16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map="auto",
    quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.
Loading checkpoint shards: 100%|██████████| 2/2 [00:13<00:00,  6.60s/it]


In [5]:
model = PeftModel.from_pretrained(base_model, "trained-model")
model.eval()
model.config.use_cache = False  # <- THIS IS CRUCIAL

In [None]:
preds = []
probs_list = []

for prompt in tqdm(val_test_2009_data["prompt"].tolist(), desc="Generating Predictions"):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=5,  # small since classification
            do_sample=False,
            use_cache=False,
            pad_token_id=tokenizer.pad_token_id,
            return_dict_in_generate=True,
            output_scores=True
        )

    # Decode full prediction
    decoded = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    prediction = decoded.split("Answer:")[-1].strip().split("\n")[0]
    preds.append(prediction)

    # Get token-level probabilities
    # We take the first generated token after the input context
    transition_scores = outputs.scores  # list of logits per generation step
    if transition_scores:  
        step_logits = transition_scores[0]  # first step logits
        step_probs = softmax(step_logits, dim=-1)
        label_probs = {}
        for label in label_list:
            label_token_id = tokenizer(label, add_special_tokens=False).input_ids[0]
            label_probs[label] = step_probs[0, label_token_id].item()
        probs_list.append(label_probs)
    else:
        probs_list.append({label: None for label in label_list})
    del inputs, outputs, step_logits, step_probs
    torch.cuda.empty_cache()

In [None]:
val_test_2009_data["new_pred_label"]=preds
val_test_2009_data["predicted_probability"]=probs_list

## Keybert extraction on predicted label 

In [None]:
DICT_TERMS = {
    "Device Malfunction": [
        # General device failure
        "device malfunction", "equipment malfunction", "malfunction during use",
        "device not working", "device failed", "failure during operation",
        "failure to function", "malfunction on deployment", "malfunction after use",

        # Deployment / operational issues
        "deployment failure", "incomplete deployment", "unable to deploy",
        "failed to deploy", "partial deployment", "deployment interrupted",
        "failed to lock", "locking issue", "activation failure",
        "device misfire", "delivery issue", "misalignment of device",

        # Structural / mechanical failures
        "device damage", "component breakage", "handle breakage",
        "material integrity issue", "broken tip", "torn sleeve",
        "kinked sheath", "catheter kink", "component dislodged",
        "sealant exposed", "loose seal", "sealant leakage", "sealant failure",

        # Blockages and leaks
        "balloon leak", "balloon burst", "balloon deflation",
        "unable to inflate", "leak in system", "fluid leakage",

        # Obstruction / entrapment
        "device stuck", "entrapment issue", "obstructed device",
        "retained device", "unable to remove device",

        # Electronics / software
        "system error", "software error", "firmware bug",
        "power failure", "battery failure", "low battery alert",
        "display malfunction", "false alarm", "incorrect reading",


        "cuff miss", "quality issue", "reported difficulty", "device returned", 
        "device behavior", "device performance", "reported malfunction", 
        "returned for analysis", "proglide failure", "mechanical issue",
        "clicking sound", "double click", "clicks on activation",
        "device click", "unexpected click", "plunger click",
        "audible click", "strange sound", "device noise", 
        "mechanical feedback", "anchor failed", "footplate broken", 
        "collagen plug issue", "plug misplacement", "clip misfire", 
        "clip dislodged", "anchor not seated", "prostar failure",
        "device mispositioned", "poor seal", "incomplete closure", 
        "deployment drift", "suture break", "clip snapped", "cord rupture", 
        "vcd jammed", "advancer stuck", "release failure", 
        "footing not deployed", "clip ejected", "seal not achieved", "device knot",
        "deployment issue", "catheter jammed", "stuck device"
    ],

    "Injury": [
        # Bleeding / hematoma
        "bleeding", "excessive bleeding", "control of bleeding",
        "inability to stop bleeding", "hemostasis failure", "manual compression",
        "hematoma", "hematoma formation", "site hematoma",

        # Pain & discomfort
        "pain", "site pain", "localized pain", "post-procedure pain",
        "site discomfort", "discomfort", "local irritation",

        # Infections & inflammation
        "infection", "site infection", "wound infection",
        "inflammatory reaction", "swelling", "redness",
        "bruising", "seroma", "delayed healing", "wound complication",

        # Vascular / tissue injuries
        "vessel injury", "vascular injury", "arterial injury",
        "tissue damage", "tissue necrosis", "necrosis",
        "pseudoaneurysm", "extravasation", "vascular spasm",
        "ischemia", "limb ischemia", "nerve injury", "limb numbness", "numbness",

        # Hypersensitivity / allergy
        "hypersensitivity reaction", "allergic reaction", "skin reaction",
        "contact dermatitis", "rash", "burns",

        # Extra from your extended set
        "hematosis", "occlusion", "retroperitoneal bleed", "av fistula", 
        "femoral pseudoaneurysm", "vessel perforation", "artery dissection", 
        "groin pain", "groin hematoma", "thrombus formation", 
        "deep vein thrombosis", "vascular laceration", "prolonged bleeding", 
        "puncture site complication", "arterial embolism","skin reaction"
    ],

    "Death": [
        # Direct death reports
        "patient died", "patient death", "death reported",
        "death occurred", "fatal event", "fatal outcome",
        "loss of life", "mortality", "fatality", "lethal event",
        "procedure-related death", "unexpected death", "sudden death",
        "death following procedure", "fatal complication",

        # Common MAUDE phrasings
        "patient demise", "passed away", "found deceased",
        "collapsed and died", "expired", "death confirmed",
        "pronounced dead", "declared dead", "cause of death",
        "autopsy revealed", "post mortem", "post-mortem",
        "death certificate", "no signs of life", "end of life",

        # Indirect death phrasing
        "died suddenly", "died during procedure", "died after procedure",


        "cardiac arrest during procedure", "exsanguination", "massive bleed",
        "vascular collapse", "procedure fatality", "died from complication",
        "bleed-out", "femoral rupture", "death due to AV fistula", 
        "fatal stroke", "multi-organ failure", "cardiovascular collapse"
    ]
}


# ===== Models =====
bert_model = SentenceTransformer("all-MiniLM-L6-v2")
kw_model = KeyBERT(model=bert_model)

In [19]:
import re, json, numpy as np, pandas as pd
from sentence_transformers import util

# ------------------- Config -------------------
LABELS = ["Death", "Injury", "Device Malfunction"]

# thresholds (tune on your val set)
HI_THRESH = 0.70
MARGIN_THRESH = 0.10
ALPHA_DEFAULT = 0.60    # normal hybrid
ALPHA_LOW = 0.35        # when margin small / confidence low

EVIDENCE_TOPK = 12
EVIDENCE_MIN_COS = 0.30

# Death-specific safeguards
ENFORCE_DEATH_CONFIRMATION = True
DEATH_EVID_MIN = 0.55
DEATH_COS_MIN = 0.58
EVIDENCE_CAP_NO_STRICT_DEATH = 0.35

# Strict death cues
DEATH_STRICT = [
    r"\bpatient (?:died|expired)\b", r"\bpronounced dead\b", r"\bdeclared dead\b",
    r"\bfound deceased\b", r"\bpassed away\b",
    r"\bfatal (?:event|outcome|complication)\b",
    r"\bdeath (?:occurred|reported|confirmed)\b",
    r"\bmortality\b", r"\bfatality\b"
]
NEGATION_PAT = r"(no|not|without|never|denies?|rule[sd]?\s*out)\s+(?:any\s+)?(death|died|deceased|fatal|expired)"

# ------------------- Helpers -------------------
def keybert_terms(text, top_k=15):
    kws = kw_model.extract_keywords(
        text, keyphrase_ngram_range=(1,3), stop_words='english', top_n=top_k
    )
    return [w for w,_ in kws]

def has_strict_death(text: str) -> bool:
    if re.search(NEGATION_PAT, text, flags=re.I):
        return False
    return any(re.search(pat, text, flags=re.I) for pat in DEATH_STRICT)

def parse_prob_dict(x):
    if isinstance(x, dict): 
        return x
    if isinstance(x, str):
        try:
            return json.loads(x.replace("'", '"'))
        except Exception:
            return {}
    return {}
def evidence_for_label(text, label):
    if label not in DICT_TERMS:
        print("here -----")
        return 0.0, []

    ext_terms = keybert_terms(text, top_k=EVIDENCE_TOPK)
    print("ext ---",ext_terms,label)
    if not ext_terms:
        print("here -----------")
        return 0.0, []

    dict_terms = DICT_TERMS[label]
    dict_emb = bert_model.encode(dict_terms, convert_to_tensor=True, normalize_embeddings=True)
    ext_emb  = bert_model.encode(ext_terms,  convert_to_tensor=True, normalize_embeddings=True)

    cos = util.cos_sim(ext_emb, dict_emb).cpu().numpy()

    matches = []
    min_cos = DEATH_COS_MIN if label == "Death" else EVIDENCE_MIN_COS

    for i, term in enumerate(ext_terms):
        j = int(cos[i].argmax())
        s = float(cos[i][j])
        print(s,min_cos,term,dict_terms[j])
        if s >= min_cos:
            matches.append({
                "extracted_term": term,
                "matched_dict_term": dict_terms[j],
                "score": round(s, 3)
            })

    # Always include fallback matches for Injury and Device Malfunction
    if label in ["Injury", "Device Malfunction"]:
        for i, term in enumerate(ext_terms):
            j = int(cos[i].argmax())
            s = float(cos[i][j])
            print("-----------")
            print(s,min_cos,term,dict_terms[j])
            if s >= 0.20 and not any(m["extracted_term"] == term for m in matches):
                matches.append({
                    "extracted_term": term,
                    "matched_dict_term": dict_terms[j],
                    "score": round(s, 3)
                })

    if not matches:
        if label == "Death" and has_strict_death(text):
            return 0.9, []
        return 0.0, []

    top_scores = sorted([m["score"] for m in matches], reverse=True)[:5]
    evid = float(np.mean(top_scores))

    if label == "Death":
        if has_strict_death(text):
            evid = max(evid, 0.9)
        else:
            evid = min(evid, EVIDENCE_CAP_NO_STRICT_DEATH)

    return evid, sorted(matches, key=lambda x: x["score"], reverse=True)



def clean_keyword_list(matches, topn=8, with_scores=False):
    seen, cleaned = set(), []
    for m in sorted(matches, key=lambda x: x["score"], reverse=True):
        term = m["matched_dict_term"]
        if term in seen:
            continue
        seen.add(term)
        if with_scores:
            cleaned.append(f"{term}:{m['score']}")
        else:
            cleaned.append(term)
        if len(cleaned) >= topn:
            break
    return cleaned

def pick_final_with_death_guard(text, combined, evidence):
    cand = max(combined.items(), key=lambda x: x[1])[0]
    if not ENFORCE_DEATH_CONFIRMATION or cand != "Death":
        return cand, "hybrid_fallback"
    strict = has_strict_death(text)
    evid_ok = evidence.get("Death", 0.0) >= DEATH_EVID_MIN
    if strict and evid_ok:
        return "Death", "hybrid_fallback_confirmed_death"
    non_death = {k: v for k, v in combined.items() if k != "Death"}
    alt = max(non_death.items(), key=lambda x: x[1])[0]
    return alt, "death_demoted_insufficient_evidence"

# ------------------- Core classify -------------------
def classify_row(row):
    text = row["FOI_TEXT"]
    prob = parse_prob_dict(row["predicted_probability"])
    prob = {lab: float(prob.get(lab, 0.0)) for lab in LABELS}

    sorted_labs = sorted(LABELS, key=lambda l: prob[l], reverse=True)
    top1, top2 = sorted_labs[0], sorted_labs[1]
    p1, p2 = prob[top1], prob[top2]
    margin = p1 - p2

    high_conf_and_clear = (p1 >= HI_THRESH) and (margin > MARGIN_THRESH)
    labels_to_extract = [top1] if high_conf_and_clear else LABELS

    evidence = {lab: 0.0 for lab in LABELS}
    matches  = {lab: []  for lab in LABELS}
    for lab in labels_to_extract:
        e, m = evidence_for_label(text, lab)
        print(e,m)
        evidence[lab], matches[lab] = e, m

    if high_conf_and_clear:
        final_label = top1
        final_score = p1
        decision = "model_confident"
    else:
        alpha = ALPHA_LOW if margin <= MARGIN_THRESH else ALPHA_DEFAULT
        combined = {lab: alpha*prob[lab] + (1-alpha)*evidence[lab] for lab in LABELS}
        final_label, decision = pick_final_with_death_guard(text, combined, evidence)
        final_score = round(combined[final_label], 3)

    flat_matches = {f"{lab.lower().replace(' ', '_')}_matches": json.dumps(matches[lab], ensure_ascii=False) for lab in LABELS}

    clean_cols = {}
    for lab in LABELS:
        terms_only = clean_keyword_list(matches.get(lab, []), topn=8, with_scores=False)
        terms_with_scores = clean_keyword_list(matches.get(lab, []), topn=8, with_scores=True)
        clean_cols[f"{lab.lower().replace(' ', '_')}_keywords"] = "; ".join(terms_only)
        clean_cols[f"{lab.lower().replace(' ', '_')}_keywords_scored"] = "; ".join(terms_with_scores)

    return {
        "final_label": final_label,
        "final_score": final_score,
        "decision": decision,
        "p_model": json.dumps({lab: round(prob[lab],3) for lab in LABELS}),
        "evidence": json.dumps({lab: round(evidence[lab],3) for lab in LABELS}),
        **flat_matches,
        **clean_cols
    }

# ------------------- Run on CSV -------------------
# df = pd.read_csv("test_data_rag.csv")

# hyb = df.apply(classify_row, axis=1)

# for col in hyb.iloc[0].keys():
#     df[col] = hyb.apply(lambda x: x[col])

## apply logic on single report 

In [8]:

report="""A patient underwent an arteriotomy closure procedure using a Prostyle device after a percutaneous transluminal angioplasty. Despite two attempts, a cuff miss occurred due to difficulties with device deployment and positioning. The manufacturer's guidelines were not followed for the larger sheath size used (18F), which may have contributed to the issue. The investigation is ongoing, but preliminary analysis suggests an interaction between the device and patient anatomy as the likely cause of the problem, rather than a product quality issue."""

In [None]:
def classify_row(text,predicted_probability,):
    text=text
    prob = parse_prob_dict(predicted_probability)
    prob = {lab: float(prob.get(lab, 0.0)) for lab in LABELS}

    sorted_labs = sorted(LABELS, key=lambda l: prob[l], reverse=True)
    top1, top2 = sorted_labs[0], sorted_labs[1]
    p1, p2 = prob[top1], prob[top2]
    margin = p1 - p2

    high_conf_and_clear = (p1 >= HI_THRESH) and (margin > MARGIN_THRESH)
    labels_to_extract = [top1] if high_conf_and_clear else LABELS

    evidence = {lab: 0.0 for lab in LABELS}
    matches  = {lab: []  for lab in LABELS}
    for lab in labels_to_extract:
        e, m = evidence_for_label(text, lab)
        print(e,m)
        evidence[lab], matches[lab] = e, m

    if high_conf_and_clear:
        final_label = top1
        final_score = p1
        decision = "model_confident"
    else:
        alpha = ALPHA_LOW if margin <= MARGIN_THRESH else ALPHA_DEFAULT
        combined = {lab: alpha*prob[lab] + (1-alpha)*evidence[lab] for lab in LABELS}
        final_label, decision = pick_final_with_death_guard(text, combined, evidence)
        final_score = round(combined[final_label], 3)

    flat_matches = {f"{lab.lower().replace(' ', '_')}_matches": json.dumps(matches[lab], ensure_ascii=False) for lab in LABELS}

    clean_cols = {}
    for lab in LABELS:
        terms_only = clean_keyword_list(matches.get(lab, []), topn=8, with_scores=False)
        terms_with_scores = clean_keyword_list(matches.get(lab, []), topn=8, with_scores=True)
        clean_cols[f"{lab.lower().replace(' ', '_')}_keywords"] = "; ".join(terms_only)
        clean_cols[f"{lab.lower().replace(' ', '_')}_keywords_scored"] = "; ".join(terms_with_scores)

    return {
        "final_label": final_label,
        "final_score": final_score,
        "decision": decision,
        "p_model": json.dumps({lab: round(prob[lab],3) for lab in LABELS}),
        "evidence": json.dumps({lab: round(evidence[lab],3) for lab in LABELS}),
        **flat_matches,
        **clean_cols
    }

LABELS = ["Death", "Injury", "Device Malfunction"]


def build_prompt(text):
    instruction="Classify the type of adverse event as Death, Injury, or Device Malfunction."
    return f"""{instruction}

Event: {text}

Answer:"""

preds = None
probs_list = []


prompt=build_prompt(report)

inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=5,  # small since classification
        do_sample=False,
        use_cache=False,
        pad_token_id=tokenizer.pad_token_id,
        return_dict_in_generate=True,
        output_scores=True
    )

# Decode full prediction
decoded = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
prediction = decoded.split("Answer:")[-1].strip().split("\n")[0]
preds=prediction

# Get token-level probabilities
# We take the first generated token after the input context
transition_scores = outputs.scores  # list of logits per generation step
if transition_scores:  
    step_logits = transition_scores[0]  # first step logits
    step_probs = softmax(step_logits, dim=-1)
    label_probs = {}
    for label in label_list:
        label_token_id = tokenizer(label, add_special_tokens=False).input_ids[0]
        label_probs[label] = step_probs[0, label_token_id].item()
    probs_list.append(label_probs)
else:
    probs_list.append({label: None for label in label_list})
del inputs, outputs, step_logits, step_probs
torch.cuda.empty_cache()

classify_row(report,probs_list)

In [14]:
preds,probs_list

('Injury',
 [{'Death': 7.019381155259907e-05,
   'Injury': 0.9377700090408325,
   'Device Malfunction': 0.06185264140367508}])

ext --- ['arteriotomy closure procedure', 'underwent arteriotomy closure', 'arteriotomy closure', 'prostyle device percutaneous', 'device patient anatomy', 'device percutaneous', 'percutaneous transluminal angioplasty', 'patient underwent arteriotomy', 'device percutaneous transluminal', 'transluminal angioplasty despite', 'transluminal angioplasty', 'underwent arteriotomy'] Death
0.5378475189208984 0.58 arteriotomy closure procedure vascular collapse
0.5793464183807373 0.58 underwent arteriotomy closure vascular collapse
0.5548417568206787 0.58 arteriotomy closure vascular collapse
0.31864237785339355 0.58 prostyle device percutaneous died after procedure
0.4245884120464325 0.58 device patient anatomy patient demise
0.35251253843307495 0.58 device percutaneous died after procedure
0.38477978110313416 0.58 percutaneous transluminal angioplasty vascular collapse
0.4776418209075928 0.58 patient underwent arteriotomy vascular collapse
0.2845582664012909 0.58 device percutaneous translumin

{'final_label': 'Injury',
 'final_score': 0.376,
 'decision': 'hybrid_fallback',
 'p_model': '{"Death": 0.0, "Injury": 0.0, "Device Malfunction": 0.0}',
 'evidence': '{"Death": 0.0, "Injury": 0.578, "Device Malfunction": 0.466}',
 'death_matches': '[]',
 'injury_matches': '[{"extracted_term": "underwent arteriotomy", "matched_dict_term": "vascular laceration", "score": 0.62}, {"extracted_term": "patient underwent arteriotomy", "matched_dict_term": "arterial embolism", "score": 0.595}, {"extracted_term": "underwent arteriotomy closure", "matched_dict_term": "vascular laceration", "score": 0.574}, {"extracted_term": "arteriotomy closure procedure", "matched_dict_term": "vascular laceration", "score": 0.567}, {"extracted_term": "percutaneous transluminal angioplasty", "matched_dict_term": "vascular laceration", "score": 0.535}, {"extracted_term": "arteriotomy closure", "matched_dict_term": "vascular laceration", "score": 0.529}, {"extracted_term": "transluminal angioplasty", "matched_dict