
**The Gradio application is best viewed with a light theme enabled (e.g., Windows light mode), which makes the token-influence highlighting easier to read.**

In [2]:
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from captum.attr import LayerIntegratedGradients
import html

# Load the fine-tuned sequence classifier and its tokenizer from a local checkpoint
# directory, pick GPU when available, and put the model in inference mode.
MODEL_DIR = r"saved_model1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
model.eval()  # disables dropout for deterministic predictions
model.to(device)


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [3]:
import re

def clean_text(text):
    """Normalise raw user text before classification.

    Removes URLs (tokens starting with ``http`` or ``www.``), @mentions and
    ``#`` symbols (the hashtag word itself is kept). It then collapses runs of
    whitespace into a single space and trims both ends. Whitespace is
    normalised *after* the removals so deleted spans do not leave double,
    leading or trailing spaces behind.

    Args:
        text: Any value; it is converted with ``str()`` first.

    Returns:
        The cleaned string.
    """
    text = str(text)
    # URLs. The dot after "www" is escaped so it matches only a literal "."
    # (unescaped, "awwwesome" would be cut down to "a").
    text = re.sub(r"http\S+|www\.\S+", "", text)
    text = re.sub(r"@\w+", "", text)  # @mentions (\w already includes "_")
    text = text.replace("#", "")  # drop the hashtag symbol, keep the word
    text = re.sub(r"\s+", " ", text)  # collapse whitespace left by removals
    return text.strip()


In [4]:
# Forward function used as the attribution target for Captum.
def forward_func(inputs, attention_mask=None):
    """Return the softmax probability of class 0 for every row in the batch.

    Args:
        inputs: Token-id tensor, passed positionally to the global ``model``.
        attention_mask: Optional mask, forwarded to ``model`` unchanged.

    Returns:
        A 1-D tensor holding ``softmax(logits)[:, 0]``, i.e. the class-0
        probability that the attributions explain.
    """
    logits = model(inputs, attention_mask=attention_mask).logits
    class_probs = logits.softmax(dim=1)
    return class_probs[:, 0]


#### Color Fix Patch

In [5]:
import math
import html

def _display_token(token):
    """Return the on-screen text for one tokenizer token.

    BERT WordPiece continuation pieces ("##ing") are glued onto the previous
    piece. Every other token is shown with a single leading space. "Ġ"
    (byte-level BPE word marker) is still turned into a space, so BPE
    tokenizers stay readable too.
    """
    if token.startswith("##") and len(token) > 2:
        return token[2:]
    display = token.replace("Ġ", " ")
    return display if display.startswith(" ") else " " + display


def _token_color(score):
    """Map an attribution score to an HSL background colour string.

    Scores are clamped to [-1, 1]. Positive scores (token raises the class-0
    probability) use hue 0 (red); zero or negative scores use hue 120
    (green). Lightness goes from 100% (white) at 0 to 50% at |score| == 1.
    """
    s = max(-1.0, min(1.0, score))
    hue = 0 if s > 0 else 120
    lightness = 100 - abs(s) * 50
    return f"hsl({hue},75%,{lightness:.0f}%)"


def interpret(text, confidence):
    """Render per-token Layer Integrated Gradients attributions as HTML.

    Attributions are taken with respect to ``forward_func`` (the class-0
    probability) over ``model.bert.embeddings``. The baseline is an input of
    the same length filled with the pad token. Scores are summed over the
    embedding dimension and divided by the largest absolute score.

    Args:
        text: Input string (the caller passes already-cleaned text).
        confidence: Kept for backward compatibility; colouring does not depend
            on it. The caller passes a float probability, so the former
            ``confidence == "suicidal_prob"`` branch could never execute.

    Returns:
        An HTML fragment: a colour legend followed by each non-special token
        highlighted via ``_token_color``.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128
    ).to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    ref_input_ids = torch.full_like(input_ids, tokenizer.pad_token_id)

    lig = LayerIntegratedGradients(forward_func, model.bert.embeddings)
    attributions = lig.attribute(
        inputs=input_ids,
        baselines=ref_input_ids,
        additional_forward_args=(attention_mask,),
    )

    # sum over embedding dims, normalize into [-1, 1]
    scores = attributions.sum(dim=-1).squeeze(0)
    scores = scores / (scores.abs().max() + 1e-8)

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    special_tokens = set(tokenizer.all_special_tokens)  # O(1) membership

    # build legend + wrapper
    legend = """
    <div style="border-top:1px solid #aaa; padding-top:8px; margin-top:8px">
      <b>Legend:</b>
      <span style="display:inline-block;width:10px;height:10px;background-color:hsl(0,75%,50%);border:1px solid;"></span> Negative 
      <span style="display:inline-block;width:10px;height:10px;background-color:hsl(0,0%,100%);border:1px solid;"></span> Neutral 
      <span style="display:inline-block;width:10px;height:10px;background-color:hsl(120,75%,50%);border:1px solid;"></span> Positive
      <br><br>
    """
    parts = [legend]
    for score, token in zip(scores.tolist(), tokens):
        if token in special_tokens:
            continue
        color = _token_color(score)
        display_token = html.escape(_display_token(token))
        parts.append(
            f"<mark style='background-color:{color}; padding:2px; margin:1px'>"
            f"<font color='black'>{display_token}</font></mark>"
        )
    parts.append("</div>")
    return "".join(parts)



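The model shows a bias toward non-suicidal predictions. Therefore, a stricter threshold of 0.65 (instead of the default 0.5) is required for the non-suicidal class: text is labeled non-suicidal only when its non-suicidal probability is at least 0.65, because misclassifications are frequently observed for probabilities below this value.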

**To debug, change `gr.Interface(...).launch(debug=False)` to `debug=True`.**

In [6]:
# ID to label mapping
label_map = {0: "❌Suicidal", 1: "❇️Non-Suicidal"}


# Predict and explain
def predict_and_explain(text, threshold):
    """Classify ``text`` and return ``(prediction, explanation_html)``.

    Instead of plain argmax, the text is labelled non-suicidal (1) only when
    its non-suicidal probability is at least ``threshold``; otherwise it is
    labelled suicidal (0). The reported confidence is the probability of the
    chosen label.

    Args:
        text: Raw user input; cleaned with ``clean_text`` before tokenizing.
        threshold: Minimum non-suicidal probability required for label 1.

    Returns:
        A ``(result, explanation)`` tuple: a string such as
        "<label> (Confidence: 83.12%)" and the HTML produced by ``interpret``.
        When the cleaned text is empty, a prompt message and "" are returned.
    """
    text = clean_text(text)
    if not text:
        # Empty input (or only URLs/mentions): a prediction on just the special
        # tokens would be meaningless, so ask for real input instead.
        return "Please enter some text to classify.", ""

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        probs = torch.softmax(model(**inputs).logits, dim=1)[0]
    suicidal_prob = probs[0].item()
    non_suicidal_prob = probs[1].item()

    if non_suicidal_prob >= threshold:
        pred, confidence = 1, non_suicidal_prob
    else:
        pred, confidence = 0, suicidal_prob

    result = f"{label_map[pred]} (Confidence: {confidence:.2%})"
    explanation = interpret(text, confidence)
    return result, explanation

# Gradio UI: free-text input plus a threshold slider, feeding predict_and_explain,
# which returns a prediction string and token-attribution HTML.
text_input = gr.Textbox(lines=4, placeholder="Enter text here...", label="Input Text")
threshold_slider = gr.Slider(0.5, 0.95, value=0.65, step=0.01, label="Non-suicidal Confidence Threshold")
prediction_output = gr.Textbox(label="Prediction")
explanation_output = gr.HTML(label="Model Explanation (token importance)")

demo = gr.Interface(
    fn=predict_and_explain,
    inputs=[text_input, threshold_slider],
    outputs=[prediction_output, explanation_output],
    title="Text Classification for Mental Health with Explanation",
    description="This demo uses a fine-tuned BERT model to classify text and explain predictions using Integrated Gradients.",
)
demo.launch(debug=False)  # switch to debug=True to surface errors while developing


* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.


