In [8]:
# explain_suicide_captum_all.py
# Single sample and batch explainability for the suicidal class

from pathlib import Path
import json
import random
from collections import Counter
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from captum.attr import LayerIntegratedGradients

In [9]:
# =========================
# Path helpers and config
# =========================

def find_data_warehouse(start: Path) -> Path:
    """Search upward from ``start`` for a ``Data_Warehouse`` entry.

    ``start`` itself is checked first, then each of its parents from nearest
    to farthest; the first existing ``<dir>/Data_Warehouse`` path is returned.

    Raises:
        FileNotFoundError: when no level of the hierarchy contains it.
    """
    candidates = (folder / "Data_Warehouse" for folder in (start, *start.parents))
    found = next((candidate for candidate in candidates if candidate.exists()), None)
    if found is None:
        raise FileNotFoundError("Could not locate a folder named Data_Warehouse. Set DATA_WAREHOUSE manually.")
    return found


# Pick the directory the upward search starts from: the folder containing this
# file when __file__ is defined, otherwise (NameError) the current working
# directory.
try:
    SCRIPT_DIR = Path(__file__).resolve().parent
except NameError:
    SCRIPT_DIR = Path.cwd()

# Every path below is derived from the located Data_Warehouse folder.
DATA_WAREHOUSE = find_data_warehouse(SCRIPT_DIR)
SPLIT_DIR = DATA_WAREHOUSE / "mental_health_splits_no_stress"
MODEL_BASE = SPLIT_DIR / "all_roberta_large_v1_multiclass"
BEST_DIR = MODEL_BASE / "best"
OUTPUT_DIR = MODEL_BASE / "XAI"
# Load from the "best" subfolder when it exists, otherwise from the model base folder.
MODEL_DIR = BEST_DIR if BEST_DIR.exists() else MODEL_BASE

MAX_LEN = 384        # passed as the tokenizer's max_length (inputs are truncated to it)
N_STEPS_IG = 30      # passed as n_steps to LayerIntegratedGradients.attribute
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RANDOM_STATE = 42    # seed shared by the three RNGs below and by sample_indices

# Seed the Python, NumPy and PyTorch global generators.
random.seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)


<torch._C.Generator at 0x215a3a9c9d0>

In [10]:
# =========================
# Data and label mapping
# =========================

def load_test_and_mapping():
    """Load the test split and build the class-name <-> class-id mappings.

    ``test.csv`` is read from ``SPLIT_DIR``. When ``label_classes.csv`` exists
    next to it, it is read without a header row and its last two columns are
    taken as (class name, class id); names are stripped and lower-cased.
    Otherwise the mapping is derived from the sorted, lower-cased labels found
    in the test frame (excluding "none"), and "none" is assigned id 4.

    Returns:
        tuple: ``(df_test, class_to_id, id_to_class)``.

    Raises:
        FileNotFoundError: if ``test.csv`` is missing.
        ValueError: if the resulting mapping has no "suicide" class.
    """
    test_path = SPLIT_DIR / "test.csv"
    if not test_path.exists():
        raise FileNotFoundError(f"Missing test.csv at {test_path}")
    df_test = pd.read_csv(test_path)

    label_map_path = SPLIT_DIR / "label_classes.csv"
    if label_map_path.exists():
        df_map = pd.read_csv(label_map_path, header=None)
        # The last two columns hold (name, id). For a two-column file these are
        # columns 0 and 1, so one expression covers both layouts.
        names = df_map.iloc[:, -2]
        ids = df_map.iloc[:, -1]
        class_to_id = {str(name).strip().lower(): int(cid) for name, cid in zip(names, ids)}
    else:
        uniq = sorted(lbl for lbl in df_test["label"].astype(str).str.lower().unique() if lbl != "none")
        class_to_id = {lbl: i for i, lbl in enumerate(uniq)}
        # NOTE(review): the fixed id 4 presumably mirrors the training-time
        # mapping (four other classes); it would collide with an enumerated id
        # if the split had five or more non-"none" labels - confirm.
        class_to_id["none"] = 4

    if "suicide" not in class_to_id:
        # The message names the key that is actually checked ("suicide").
        raise ValueError(f"Label mapping does not contain 'suicide'. Found {list(class_to_id.keys())}")

    id_to_class = {v: k for k, v in class_to_id.items()}
    return df_test, class_to_id, id_to_class


In [11]:

# =========================
# Model and tokenizer
# =========================

def load_model_and_tokenizer():
    """Load the fast tokenizer and the sequence classifier from ``MODEL_DIR``.

    The classifier is moved to ``DEVICE`` and switched to eval mode before
    being returned.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)

    classifier = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    classifier.to(DEVICE)
    classifier.eval()

    return tokenizer, classifier


# =========================
# Core utilities
# =========================

def forward_fn(model, input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Call ``model`` with the ids/mask pair and return the output's ``logits``."""
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    return outputs.logits

def encode_one(tokenizer, text: str):
    """Tokenize one text (no padding, truncated to ``MAX_LEN``) and move it to ``DEVICE``.

    Returns:
        dict: every entry of the tokenizer output, as a tensor on ``DEVICE``.
    """
    tokenizer_options = dict(
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
        padding=False,
        add_special_tokens=True,
    )
    encoded = tokenizer(text, **tokenizer_options)

    on_device = {}
    for key, tensor in encoded.items():
        on_device[key] = tensor.to(DEVICE)
    return on_device

def build_baseline_like(tokenizer, input_ids: torch.Tensor):
    """Build an Integrated Gradients baseline shaped like ``input_ids``.

    Every position is filled with the tokenizer's pad id, falling back to the
    EOS id and finally to 1 when those ids are not defined.

    The fallbacks test ``is None`` rather than truthiness: a pad id of 0 is a
    valid id and must not be skipped in favour of the EOS id.

    Returns:
        torch.Tensor: same shape, dtype and device as ``input_ids``.
    """
    fill_id = tokenizer.pad_token_id
    if fill_id is None:
        fill_id = tokenizer.eos_token_id
    if fill_id is None:
        fill_id = 1
    return torch.full_like(input_ids, fill_value=fill_id)

def tokens_from_ids(tokenizer, ids: torch.Tensor):
    """Convert a tensor of token ids into the tokenizer's token strings."""
    id_list = ids.detach().cpu().tolist()
    return tokenizer.convert_ids_to_tokens(id_list)

def merge_roberta_pieces(tokenizer, tokens, scores, drop_special=True):
    """Merge sub-word tokens into whole words, summing their scores.

    A token that starts with the word-start marker U+0120 opens a new word
    (the marker itself is removed); any other token is appended to the word
    currently being built. When ``drop_special`` is true, the tokenizer's
    CLS/SEP/PAD/EOS/BOS tokens are skipped and also close the current word.

    Args:
        tokenizer: object exposing ``cls_token``, ``sep_token``, ``pad_token``,
            ``eos_token`` and ``bos_token``.
        tokens: token strings, aligned with ``scores``.
        scores: one numeric score per token.
        drop_special: skip special tokens instead of merging them into words.

    Returns:
        tuple: ``(words, word_scores)`` where ``word_scores`` is a float
        ``np.ndarray`` holding the summed score of each word.
    """
    # Byte-level BPE vocabularies prefix word-initial pieces with U+0120
    # (rendered as "Ġ"). It is written as an escape so the literal cannot be
    # corrupted by an encoding round-trip of this file, which would make the
    # check silently never match and glue the whole text into one "word".
    word_start = "\u0120"

    special_set = {
        tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token,
        tokenizer.eos_token, tokenizer.bos_token
    }
    words = []
    word_scores = []
    cur_word = ""
    cur_score = 0.0
    for tok, sc in zip(tokens, scores):
        if drop_special and tok in special_set:
            # A special token ends whatever word was in progress.
            if cur_word:
                words.append(cur_word)
                word_scores.append(cur_score)
                cur_word, cur_score = "", 0.0
            continue
        if tok.startswith(word_start):
            if cur_word:
                words.append(cur_word)
                word_scores.append(cur_score)
            cur_word = tok[len(word_start):]
            cur_score = float(sc)
        else:
            # Continuation piece (or the very first piece, which has no marker).
            cur_word += tok
            cur_score += float(sc)
    if cur_word:
        words.append(cur_word)
        word_scores.append(cur_score)
    return words, np.array(word_scores, dtype=float)

def normalize_scores(scores: np.ndarray):
    """Mean-center ``scores`` and scale them so the absolute values sum to 1.

    An empty array is returned unchanged. If every centered value is zero the
    centered array is returned without scaling.

    NOTE(review): centering shifts the zero point, so a small positive score
    can come out negative (and vice versa) - confirm this is intended where
    the sign is read as "toward"/"away".
    """
    if scores.size == 0:
        return scores

    centered = scores - scores.mean()
    total_abs = np.abs(centered).sum()
    if total_abs > 0:
        return centered / total_abs
    return centered

def save_html_heatmap(words, scores, out_path: Path):
    """Write a minimal HTML page that colours each word by its score.

    Scores >= 0 are drawn in red, negative scores in blue; the opacity is the
    score's magnitude relative to the largest magnitude in ``scores``.

    Args:
        words: word strings, aligned with ``scores``.
        scores: one numeric score per word (sequence or array; may be empty).
        out_path: destination file, written as UTF-8.
    """
    # Words come straight from dataset text, so they are escaped before being
    # embedded in markup.
    from html import escape

    magnitudes = np.abs(np.asarray(scores, dtype=float))
    # Empty input has no maximum; fall back to 1.0 exactly as for all-zero scores.
    peak = float(magnitudes.max()) if magnitudes.size else 0.0
    max_abs = peak if peak > 0 else 1.0

    spans = []
    for w, sc in zip(words, scores):
        alpha = abs(float(sc)) / max_abs
        color = f"rgba(255,0,0,{alpha:.3f})" if sc >= 0 else f"rgba(0,0,255,{alpha:.3f})"
        spans.append(
            f"<span style='background-color:{color}; padding:2px; margin:1px; border-radius:3px'>"
            f"{escape(str(w))}</span>"
        )
    page = f"<html><body><div style='line-height:2.0'>{' '.join(spans)}</div></body></html>"
    out_path.write_text(page, encoding="utf-8")


# =========================
# Integrated Gradients
# =========================

def explain_one_text(tokenizer, model, text: str, target_id: int, n_steps: int = N_STEPS_IG):
    """Attribute the ``target_id`` logit of one text to its words.

    Layer Integrated Gradients is run on ``model.roberta.embeddings`` against
    the baseline from ``build_baseline_like``. Attributions are summed over the
    last dimension (presumably the embedding dimension), merged from sub-word
    pieces into words with special tokens dropped, and normalized.

    Returns:
        tuple: ``(words, word_scores, target_prob, delta)`` where
        ``target_prob`` is the softmax probability of ``target_id`` and
        ``delta`` is the convergence delta reported by Captum.
    """
    encoded = encode_one(tokenizer, text)
    ids, mask = encoded["input_ids"], encoded["attention_mask"]
    reference_ids = build_baseline_like(tokenizer, ids)

    def _logits(batch_ids, batch_mask):
        # Bind the model so Captum only has to supply ids and mask.
        return forward_fn(model, batch_ids, batch_mask)

    explainer = LayerIntegratedGradients(_logits, model.roberta.embeddings)
    attributions, delta = explainer.attribute(
        inputs=ids,
        baselines=reference_ids,
        additional_forward_args=(mask,),
        target=target_id,
        n_steps=n_steps,
        return_convergence_delta=True,
    )

    per_token = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()
    pieces = tokens_from_ids(tokenizer, ids.squeeze(0))
    words, word_scores = merge_roberta_pieces(tokenizer, pieces, per_token, drop_special=True)
    word_scores = normalize_scores(word_scores)

    # Plain forward pass (no gradients) for the class probabilities.
    with torch.no_grad():
        logits = forward_fn(model, ids, mask)
        class_probs = torch.softmax(logits, dim=-1).squeeze(0).cpu().numpy()

    target_prob = float(class_probs[target_id])
    convergence_delta = float(delta.squeeze().detach().cpu().item())
    return words, word_scores, target_prob, convergence_delta


In [12]:
# =========================
# Single sample explain
# =========================

def explain_single_samples(sample_texts, suicide_id: int, out_dir: Path, tokenizer, model, top_k: int = 12):
    """Explain each text, print its extreme words and save one heatmap per text.

    For every sample the ``top_k`` highest-scoring and ``top_k`` lowest-scoring
    words are printed, and an HTML heatmap is written to
    ``out_dir/ig_single_<n>.html`` (``n`` counts from 1). ``out_dir`` is
    created if needed.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    def _print_ranked(heading, positions, words, values):
        # One heading line followed by "word  score" rows.
        print(heading)
        for pos in positions:
            print(f"{words[pos]:20s}  {values[pos]: .4f}")

    for sample_no, text in enumerate(sample_texts, start=1):
        words, scores, p_su, _delta = explain_one_text(tokenizer, model, text, suicide_id, N_STEPS_IG)

        values = np.array(scores)
        ascending = np.argsort(values)
        lowest = ascending[:top_k]
        highest = ascending[-top_k:][::-1]  # largest score first

        print(f"\nSample {sample_no}  suicide prob={p_su:.3f}")
        _print_ranked("\nTop words that push toward suicidal", highest, words, values)
        _print_ranked("\nTop words that push away from suicidal", lowest, words, values)

        save_html_heatmap(words, values, out_dir / f"ig_single_{sample_no}.html")


In [13]:
# =========================
# Batch predict and sets
# =========================

def predict_batch(texts, tokenizer, model, batch_size=4):
    """Classify ``texts`` in mini-batches.

    Each batch is tokenized with padding and truncation to ``MAX_LEN``, moved
    to ``DEVICE`` and run through ``model`` without gradients.

    Args:
        texts: sequence of strings (sliced with ``texts[i:i+batch_size]``).
        tokenizer: callable tokenizer returning an object with ``.to(device)``.
        model: classifier whose output exposes ``.logits``.
        batch_size: number of texts per forward pass.

    Returns:
        tuple: ``(preds, probs)`` as NumPy arrays - the argmax class per text
        and the softmax probabilities per text. Both are empty for empty input.
    """
    # Imported once per call instead of once per batch inside the loop.
    import gc

    preds, probs = [], []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        enc = tokenizer(
            batch,
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt",
            padding=True,
        ).to(DEVICE)
        with torch.no_grad():
            logits = model(**enc).logits
            p = torch.softmax(logits, dim=-1).cpu().numpy()
            y = p.argmax(axis=1)
        preds.extend(y.tolist())
        probs.extend(p.tolist())

        # Release per-batch tensors eagerly to keep peak memory low.
        del enc, logits
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return np.array(preds), np.array(probs)


# =========================
# Batch explain and aggregate
# =========================

def sample_indices(arr, k, seed=RANDOM_STATE):
    """Return up to ``k`` items of ``arr`` in a seeded random order.

    ``arr`` is copied into a list, shuffled with a private ``random.Random(seed)``
    (the global RNG state is untouched), and the first ``min(len, k)`` items
    are returned.
    """
    pool = [*arr]
    rng = random.Random(seed)
    rng.shuffle(pool)
    take = min(len(pool), k)
    return pool[:take]

def aggregate_for_indices(name, indices, texts, suicide_id, tokenizer, model, out_dir: Path, top_save=50):
    """Explain the selected texts and aggregate word scores across them.

    For each index an HTML heatmap ``ig_<name>_<n>.html`` is written. Scores
    >= 0 are summed per word into a "toward" total; negative scores are summed
    (as positive magnitudes) into an "away" total. Empty or whitespace-only
    words are ignored.

    Files written to ``out_dir`` (created if needed):
        ``agg_<name>_top_toward.csv`` / ``agg_<name>_top_away.csv`` - the
        ``top_save`` highest totals; ``samples_<name>.csv`` - index, suicide
        probability and text (first 5000 characters) of every explained sample.

    Returns:
        tuple: the full ``(df_toward, df_away)`` frames, sorted by descending score.
    """
    toward_totals, away_totals = Counter(), Counter()
    sample_rows = []

    out_dir.mkdir(parents=True, exist_ok=True)

    for rank, row_idx in enumerate(indices, start=1):
        text = texts[row_idx]
        words, scores, p_su, _delta = explain_one_text(tokenizer, model, text, suicide_id, N_STEPS_IG)

        save_html_heatmap(words, scores, out_dir / f"ig_{name}_{rank}.html")

        for word, value in zip(words, scores):
            if not (word and word.strip()):
                continue
            if value >= 0:
                toward_totals[word] += float(value)
            else:
                away_totals[word] += float(-value)

        sample_rows.append({
            "idx": int(row_idx),
            "suicide_prob": p_su,
            "text": text[:5000]
        })

    def _ranked(totals):
        # Word/score table, highest total first.
        table = pd.DataFrame(totals.items(), columns=["word", "score"])
        return table.sort_values("score", ascending=False)

    df_toward = _ranked(toward_totals)
    df_away = _ranked(away_totals)
    df_meta = pd.DataFrame(sample_rows)

    df_toward.head(top_save).to_csv(out_dir / f"agg_{name}_top_toward.csv", index=False)
    df_away.head(top_save).to_csv(out_dir / f"agg_{name}_top_away.csv", index=False)
    df_meta.to_csv(out_dir / f"samples_{name}.csv", index=False)

    print(f"[{name}] saved CSVs and HTML")
    return df_toward, df_away


def batch_explain(df_test, class_to_id, tokenizer, model, out_dir: Path, max_explains=100):
    """Predict the whole test frame, then explain three groups around the suicide class.

    Groups (each capped at ``max_explains`` seeded-random samples):
        * depression labelled, predicted suicide (false positives)
        * suicide labelled, predicted anything else (false negatives)
        * suicide labelled, predicted suicide (true positives)

    Each group is passed to ``aggregate_for_indices``. The words shared by the
    top-100 "toward" lists of the true-positive and false-positive groups are
    saved to ``overlap_tp_vs_fp_toward_words.csv``.

    Returns:
        dict: number of explained samples per group
        (``fp_counts``, ``fn_counts``, ``tp_counts``).
    """
    texts = df_test["text"].astype(str).tolist()
    y_true = df_test["label"].astype(str).str.lower().map(class_to_id).to_numpy()

    y_pred, _y_prob = predict_batch(texts, tokenizer, model, batch_size=4)

    idx_dep = class_to_id["depression"]
    idx_su = class_to_id["suicide"]

    labelled_su = y_true == idx_su
    predicted_su = y_pred == idx_su

    fp_s = sample_indices(np.where((y_true == idx_dep) & predicted_su)[0], max_explains)
    fn_s = sample_indices(np.where(labelled_su & (y_pred != idx_su))[0], max_explains)
    tp_s = sample_indices(np.where(labelled_su & predicted_su)[0], max_explains)

    print(f"Selected counts  FP_dep_to_su={len(fp_s)}  FN_su_to_other={len(fn_s)}  TP_su={len(tp_s)}")

    # Run the three groups in a fixed order; only the "toward" tables are needed below.
    jobs = (
        ("fp_dep_to_suicide", fp_s),
        ("fn_suicide_to_other", fn_s),
        ("tp_suicide", tp_s),
    )
    toward_by_group = {}
    for group_name, picked in jobs:
        toward_by_group[group_name], _away = aggregate_for_indices(
            group_name, picked, texts, idx_su, tokenizer, model, out_dir
        )

    # Overlap report between the top TP and top FP "toward" words.
    tp_words = set(toward_by_group["tp_suicide"].head(100)["word"])
    fp_words = set(toward_by_group["fp_dep_to_suicide"].head(100)["word"])
    shared = tp_words & fp_words
    pd.Series(sorted(shared)).to_csv(out_dir / "overlap_tp_vs_fp_toward_words.csv", index=False)
    print("Saved overlap list of top drivers between TP and FP")

    return {
        "fp_counts": len(fp_s),
        "fn_counts": len(fn_s),
        "tp_counts": len(tp_s),
    }


In [14]:
# =========================
# Main entry
# =========================

# Entry point: load data and model, then run the single-sample and/or batch
# explanation depending on the RUN_SINGLE / RUN_BATCH switches below.
if __name__ == "__main__":
    print("Using Data Warehouse:", DATA_WAREHOUSE)
    print("Using split dir:", SPLIT_DIR)
    print("Using model dir:", MODEL_DIR)

    df_test, class_to_id, id_to_class = load_test_and_mapping()
    tokenizer, model = load_model_and_tokenizer()
    suicide_id = class_to_id["suicide"]

    # ===== single mode example =====
    # Option A: pick one "suicide" and one "depression" text from the test
    # split, each drawn with a fixed random_state so the choice is repeatable.
    texts_for_single = []
    df_su = df_test[df_test["label"].astype(str).str.lower() == "suicide"]
    if not df_su.empty:
        texts_for_single.append(df_su.sample(1, random_state=17)["text"].iloc[0])
    df_dep = df_test[df_test["label"].astype(str).str.lower() == "depression"]
    if not df_dep.empty:
        texts_for_single.append(df_dep.sample(1, random_state=23)["text"].iloc[0])
    # Option B: fall back to a manual text when the test split supplied none.
    if len(texts_for_single) == 0:
        texts_for_single.append("I am tired of life and I want to end everything")

    # Switches selecting which analyses run.
    RUN_SINGLE = False
    RUN_BATCH = True
    
    if RUN_SINGLE:
        # Outputs go to <OUTPUT_DIR>/explain_single.
        single_out_dir = OUTPUT_DIR / "explain_single"
        explain_single_samples(
            sample_texts=texts_for_single,
            suicide_id=suicide_id,
            out_dir=single_out_dir,
            tokenizer=tokenizer,
            model=model,
            top_k=12,
        )

    if RUN_BATCH:
        # Outputs go to <OUTPUT_DIR>/explain_batch.
        batch_out_dir = OUTPUT_DIR / "explain_batch"
        stats = batch_explain(
            df_test=df_test,
            class_to_id=class_to_id,
            tokenizer=tokenizer,
            model=model,
            out_dir=batch_out_dir,
            max_explains=100,
        )
        print("Batch stats:", stats)


Using Data Warehouse: d:\Sajjad-Workspace\PSS_XAI\Data_Process\Data_Warehouse
Using split dir: d:\Sajjad-Workspace\PSS_XAI\Data_Process\Data_Warehouse\mental_health_splits_no_stress
Using model dir: d:\Sajjad-Workspace\PSS_XAI\Data_Process\Data_Warehouse\mental_health_splits_no_stress\all_roberta_large_v1_multiclass\best
Selected counts  FP_dep_to_su=23  FN_su_to_other=14  TP_su=70
[fp_dep_to_suicide] saved CSVs and HTML
[fn_suicide_to_other] saved CSVs and HTML
[tp_suicide] saved CSVs and HTML
Saved overlap list of top drivers between TP and FP
Batch stats: {'fp_counts': 23, 'fn_counts': 14, 'tp_counts': 70}


🔹 1. Organize by error type

False Negatives (suicide → other)

Toward: weak suicide cues (e.g., cutting, dead, insane) appear, but not strong enough for the model.

Away: background/benign terms (e.g., school, grades, stuttering, better) suppressed the suicide prediction.

Interpretation: the model often misses suicide texts when they are diluted with daily-life or school-related language.

False Positives (depression → suicide)

Toward: depression posts mentioning suicidal, suicide, hotline, killing push them incorrectly into suicide class.

Away: words like manipulative, yourself, excited, phases dampen the suicide score but not enough.

Interpretation: the model confuses strong expressions of depression with suicide intent due to lexical overlap.

True Positives (suicide → suicide)

Toward: explicit suicide intent markers (myself, dying, suicide, kill, death, anymore, want, life) strongly drive correct classification.

Away: some neutral context words (deadly, real, problems, family, account) pull slightly away but don’t overturn the prediction.

Interpretation: the model is reliable when posts contain direct, explicit suicide markers.

🔹 2. Higher-level themes

Overlap problem: depression and suicide share vocabulary (life, feelings, help, don’t, can’t), which drives false positives.

Dilution problem: everyday or contextual words (school, grades, family) reduce the suicide probability in true suicide posts, driving false negatives.

Strong signals: explicit intent markers (suicide, kill, death, myself, dying) consistently separate true suicide posts.

🔹 3. Suggested summary paragraph

Explainability analysis using Integrated Gradients revealed that true suicide predictions are driven by explicit intent markers such as myself, suicide, kill, death, and dying. False positives (depression misclassified as suicide) occur when depressive texts contain similar lexicon (e.g., suicidal, hotline, killing), highlighting the lexical overlap between depression and suicide. False negatives (suicide misclassified as other) are often associated with everyday context terms (e.g., school, grades, family), which dilute the suicidal signal. These results indicate that the model captures explicit markers well but struggles in borderline cases where suicide intent is implied rather than stated directly, or where depression shares overlapping vocabulary.