In [None]:
#this is necessary for modernbert at the time of this writing, until future releases of transformers
!pip install git+https://github.com/huggingface/transformers.git

In [None]:
!pip install gradio torch captum seaborn matplotlib shap

In [None]:
import captum

import torch
import torch.nn as nn
import torch.nn.functional as F
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization

import gradio as gr
import numpy as np

from transformers import AutoTokenizer, AutoModelForSequenceClassification

from captum.attr import (
    LayerIntegratedGradients,
    visualization as viz,
)

# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

from IPython.core.display import HTML
import math

In [None]:
# Fine-tuned ModernBERT sequence classifier (wine-review quality) from the Hugging Face Hub.
model_path = "scbtm/ModernBERT_wine_quality_reviews_ft"
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.to(device)  # nn.Module.to moves the parameters in place and returns the same module
# Switch to inference mode (disables dropout); as the last expression it also displays the model.
model.eval()

In [None]:
# Load the tokenizer from the same checkpoint as the model so token ids match its vocabulary.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_path)

In [None]:
def forward_with_ids(input_ids, attention_mask):
    """Forward function for Captum: token ids and attention mask in, classification logits out."""
    return model(input_ids=input_ids, attention_mask=attention_mask).logits

# Layer Integrated Gradients over the output of the embedding module
# (`model.model.embeddings`), driven by the forward function above.
lig = LayerIntegratedGradients(forward_func=forward_with_ids, layer=model.model.embeddings)

In [None]:

def get_embeddings(model, tokenizer, device, text):
    """Return the embedding-layer output for ``text``.

    Args:
        model: classifier exposing its embedding module as ``model.model.embeddings``.
        tokenizer: tokenizer matching ``model``.
        device: torch device the model lives on.
        text: input string.

    Returns:
        Whatever ``model.model.embeddings`` returns for the encoded ``input_ids``.
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True)
    input_ids = encoded['input_ids'].to(device)
    return model.model.embeddings(input_ids)

def generate_baseline(model, tokenizer, device, text):
    """Embed an all-pad-token sequence with the same length as the tokenized ``text``.

    Args:
        model: classifier exposing its embedding module as ``model.model.embeddings``.
        tokenizer: tokenizer matching ``model``; its ``pad_token_id`` fills the baseline.
        device: torch device the model lives on.
        text: input string whose token length the baseline copies.

    Returns:
        The output of ``model.model.embeddings`` for the pad-token ids.
    """
    input_ids = tokenizer(text, return_tensors="pt")['input_ids']
    pad_ids = torch.full_like(input_ids, tokenizer.pad_token_id)
    return model.model.embeddings(pad_ids.to(device))

def get_target_class(model, tokenizer, device, text):
    """Return the index of the class ``model`` predicts for ``text``.

    Args:
        model: classifier whose call result exposes a ``.logits`` tensor.
        tokenizer: tokenizer matching ``model``.
        device: torch device the model lives on.
        text: input string.

    Returns:
        int: argmax class index of the first (only) sequence.
    """
    tokens = tokenizer(text, return_tensors="pt", padding=True)
    tokens = {k: v.to(device) for k, v in tokens.items()}

    with torch.no_grad():
        # The model returns an output object, not a tensor; argmax must run on
        # its `.logits` (passing the object itself raised a TypeError).
        logits = model(**tokens).logits
    return torch.argmax(logits, dim=1)[0].item()


def get_prediction_and_confidence(model, tokenizer, device, text):
    """Classify ``text`` and report the winning class together with its probability.

    Args:
        model: classifier whose call result exposes a ``.logits`` tensor.
        tokenizer: tokenizer matching ``model``.
        device: torch device the model lives on.
        text: input string.

    Returns:
        tuple ``(pred_class, confidence, probs)``: the argmax class index of the
        first (only) sequence, its softmax probability as a Python float, and the
        full probability tensor (softmax over dim 1).
    """
    tokens = tokenizer(text, return_tensors="pt", padding=True)
    tokens = {k: v.to(device) for k, v in tokens.items()}

    with torch.no_grad():
        logits = model(**tokens).logits
        # softmax over dim=1 already requires at least a 2-D (batch, classes)
        # tensor, so no 1-D special case is needed here.
        probs = F.softmax(logits, dim=1)
        pred_class = torch.argmax(probs, dim=1)[0].item()
        confidence = probs[0, pred_class].item()

    return pred_class, confidence, probs

def get_explanations_ig(model, tokenizer, device, text, n_steps=50):
    """Compute Layer Integrated Gradients attributions for ``text``.

    Attributions are taken with respect to the output of
    ``model.model.embeddings`` and target the class ``model`` predicts for
    ``text``.

    Args:
        model: classifier called as ``model(input_ids=..., attention_mask=...)``
            whose output has a ``.logits`` attribute, exposing its embedding
            module as ``model.model.embeddings``.
        tokenizer: tokenizer matching ``model``; must define ``pad_token_id``.
        device: torch device that ``model`` lives on.
        text: the string to explain.
        n_steps: number of interpolation steps for the IG integral (default 50).

    Returns:
        tuple ``(attributions_ig, delta_ig, pred_class, confidence, probs)``:
        Captum's layer attributions and convergence delta, followed by the
        prediction triple returned by ``get_prediction_and_confidence``.
    """
    # 1) Tokenize. The special-tokens mask is only used to build the baseline.
    encoded = tokenizer(
        text, return_tensors="pt", padding=True, return_special_tokens_mask=True
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)
    special_mask = encoded["special_tokens_mask"].to(device).bool()

    # 2) Baseline ids: every regular token becomes the pad token; special tokens
    #    are kept so those positions are identical in input and baseline.
    pad_ids = torch.full_like(input_ids, tokenizer.pad_token_id)
    baseline_ids = torch.where(special_mask, input_ids, pad_ids)
    # LayerIntegratedGradients only interpolates the embedding-layer output;
    # every integration step runs the model with the *real* attention mask.
    # The convergence delta, however, evaluates the model at the baselines, so
    # the baseline mask must match the real one for the delta to describe the
    # path that was actually integrated (an all-zero mask broke that).
    baseline_mask = attention_mask

    # 3) The predicted class is the attribution target.
    pred_class, confidence, probs = get_prediction_and_confidence(
        model, tokenizer, device, text
    )

    # 4) Integrated Gradients against the model passed in (not a global one).
    def _logits(ids, mask):
        return model(input_ids=ids, attention_mask=mask).logits

    explainer = LayerIntegratedGradients(_logits, model.model.embeddings)
    attributions_ig, delta_ig = explainer.attribute(
        inputs=(input_ids, attention_mask),
        baselines=(baseline_ids, baseline_mask),
        target=pred_class,
        return_convergence_delta=True,
        n_steps=n_steps,
    )

    return attributions_ig, delta_ig, pred_class, confidence, probs

def merge_subwords(tokens, attributions, word_prefix="Ġ"):
    """Merge subword tokens back into words, summing their attributions.

    Given:
      tokens: list of subword tokens, e.g.
        ["This", "Ġwine", "Ġis", "Ġfrom", "ĠIt", "aly", ",", "ĠWhite", "ĠBl", "end"]
      attributions: np.array or list of float scores, one per token
        (extra entries on either side are ignored, as with ``zip``).
      word_prefix: marker that starts a new word (default "Ġ", the
        byte-level BPE space marker). Only this marker is stripped; any other
        characters are kept verbatim.

    Returns:
      merged_tokens, merged_attributions (two lists of equal length), e.g.
      ["This", "wine", "is", "from", "Italy,", "White", "Blend"].

    Tokens that are *only* markers (e.g. a lone "Ġ" produced by repeated
    spaces) contain no text; their attribution is carried into the next word
    (or the last word when they come at the end) instead of being dropped.
    A token that follows such a marker-only token starts a new word.
    """
    merged_tokens = []
    merged_attribs = []

    new_word = True  # the very first token always starts a word
    pending = 0.0    # attribution of marker-only tokens not yet assigned to a word

    for subword, attr in zip(tokens, attributions):
        if subword.startswith(word_prefix):
            new_word = True
        piece = subword.lstrip(word_prefix)

        if not piece:
            # Whitespace-only token: remember its score, keep the word boundary.
            pending += attr
            continue

        score = attr + pending if pending else attr
        if new_word or not merged_tokens:
            merged_tokens.append(piece)
            merged_attribs.append(score)
        else:
            # Continuation subword, e.g. "Bl" + "end" -> "Blend".
            merged_tokens[-1] += piece
            merged_attribs[-1] += score
        pending = 0.0
        new_word = False

    # Trailing marker-only tokens belong to the last word, if there is one.
    if pending and merged_tokens:
        merged_attribs[-1] += pending

    return merged_tokens, merged_attribs


def format_word_importances_only(words, importances):
    """Return a single ``<td>`` HTML cell with every word highlighted by its attribution.

    Args:
        words: display strings, one per attribution; attributions beyond
            ``len(words)`` are ignored.
        importances: attribution scores; ``None`` or empty yields an empty cell.

    Returns:
        str: HTML for one table cell.

    Raises:
        AssertionError: if there are more words than attributions.
    """
    import html  # local import: only needed to escape the displayed words

    if importances is None or len(importances) == 0:
        return "<td></td>"
    assert len(words) <= len(importances), (
        f"Found more tokens than attributions. "
        f"len(words)={len(words)} len(importances)={len(importances)}"
    )

    # Build HTML for token coloring
    token_html_list = []
    for word, imp in zip(words, importances[: len(words)]):
        color = _get_color(imp)
        # Words come from user-typed text: escape them so characters such as
        # '<' or '&' render literally instead of being parsed as markup.
        token_html_list.append(
            f"<mark style='background-color:{color}; "
            "display:inline-block; line-height:1.75; "
            "margin:0 2px; border-radius:3px'>"
            f"{html.escape(str(word))}</mark>"
        )
    return "<td>" + "".join(token_html_list) + "</td>"

def _get_color(attr):
    """
    Internal helper to map attribution value to an RGBA color.
    This replicates Captum's logic for coloring tokens.
    """
    red, green, blue = (21, 177, 234) if attr > 0 else (234, 78, 21)
    alpha = min(1.0, 0.2 + math.fabs(attr))  # scale transparency by magnitude
    return f"rgba({red},{green},{blue},{alpha})"

def visualize_text_only_word_importances(datarecords):
    """Render records as an HTML table with only a 'Word Importance' column.

    Unlike Captum's full visualization, this omits the 'True Label',
    'Predicted Label', 'Attribution Label' and 'Attribution Score' columns.

    Args:
        datarecords: a single VisualizationDataRecord, or a list/tuple of them.
            Each record's ``raw_input_ids`` (words) and ``word_attributions``
            (scores) are used.

    Returns:
        IPython ``HTML`` object wrapping the table markup.
    """
    # A single record is wrapped so the loop below handles both forms.
    if not isinstance(datarecords, (list, tuple)):
        datarecords = [datarecords]

    # Header cell inside its own row: a bare <th> directly in <table> is invalid HTML.
    rows = ["<tr><th>Word Importance</th></tr>"]

    for dr in datarecords:
        word_importances_html = format_word_importances_only(
            dr.raw_input_ids, dr.word_attributions
        )
        rows.append(f"<tr>{word_importances_html}</tr>")

    table_html = "<table>" + "".join(rows) + "</table>"
    return HTML(table_html)

def interpret_sentence(model, tokenizer, device, text):
    """Classify ``text`` and build an HTML word-attribution view of the prediction.

    Args:
        model: classifier to explain (see ``get_explanations_ig``).
        tokenizer: tokenizer matching ``model``.
        device: torch device the model lives on.
        text: input string.

    Returns:
        tuple ``(label_str, confidence_str, raw_html_str)``: the predicted
        label name, the confidence formatted as a percentage, and an HTML
        table in which each word is coloured by its Integrated Gradients score.
    """
    attributions_ig, delta_ig, pred_class, confidence, _probs = get_explanations_ig(
        model, tokenizer, device, text
    )

    id2label = {0: 'bad', 1: 'average', 2: 'good', 3: 'excellent'}
    label_str = id2label.get(pred_class, str(pred_class))

    # One score per token position: sum over the embedding dimension.
    position_scores = attributions_ig.sum(dim=2).squeeze(0).detach().cpu()

    # The attributions cover the full encoded sequence, including the special
    # tokens added by the tokenizer. `tokenizer.tokenize` omits those, so
    # zipping the two directly shifted every word onto its neighbour's score.
    # Re-encode the text the same way and keep only the non-special positions.
    encoded = tokenizer(text, return_special_tokens_mask=True)
    input_ids = encoded["input_ids"]
    assert len(input_ids) == position_scores.shape[0], (
        f"Token/attribution length mismatch: {len(input_ids)} ids vs "
        f"{position_scores.shape[0]} attributions"
    )
    keep = [i for i, is_special in enumerate(encoded["special_tokens_mask"]) if not is_special]
    tokens_list = tokenizer.convert_ids_to_tokens([input_ids[i] for i in keep])
    word_scores = position_scores[torch.tensor(keep, dtype=torch.long)]

    # L2-normalise for display; skip when every score is zero to avoid NaNs.
    norm = torch.norm(word_scores)
    if norm > 0:
        word_scores = word_scores / norm
    attributions = word_scores.numpy()

    merged_tokens, merged_attributions = merge_subwords(tokens_list, attributions)

    visualization = viz.VisualizationDataRecord(
        word_attributions=merged_attributions,
        pred_prob=confidence,
        pred_class=pred_class,
        true_class=None,
        attr_class=None,
        # np.mean of an empty list warns and returns NaN (e.g. empty input text).
        attr_score=np.mean(merged_attributions) if merged_attributions else 0.0,
        raw_input_ids=merged_tokens,
        convergence_score=delta_ig,
    )

    # visualize_text_only_word_importances returns an IPython HTML object;
    # `.data` holds the raw HTML string that the Gradio HTML output expects.
    raw_html_str = visualize_text_only_word_importances(visualization).data

    return label_str, f'{100*confidence:.1f} %', raw_html_str

def create_gradio_interface():
    """Build the Gradio demo around ``interpret_sentence``.

    The module-level ``model``, ``tokenizer`` and ``device`` are bound into the
    callback, so the UI only asks for the text. The outputs are the predicted
    label, the confidence string and the attribution HTML table.

    Returns:
        gr.Interface: the interface, not yet launched.
    """
    from functools import partial

    final_fn = partial(interpret_sentence, model, tokenizer, device)
    iface = gr.Interface(
        fn=final_fn,
        inputs=gr.Textbox(lines=5, label="Input Text"),
        outputs=[
            gr.Label(label="Predicted Class"),
            gr.Label(label="Confidence"),
            gr.HTML(label="Attribution Visualization"),
        ],
        title="Advanced XAI Text Classification Explainer",
        # Only Integrated Gradients is implemented, so describe exactly that.
        description=(
            "Predicts the quality class of a wine review and explains the "
            "prediction with Integrated Gradients, which attributes the score to "
            "each input word by accumulating gradients along a path from a "
            "padding-token baseline to the actual text. Words are highlighted "
            "blue (supports the predicted class) or orange (works against it); "
            "a stronger colour means a larger attribution."
        ),
        examples=[
            ["This wine is from Italy, White Blend variety. Aromas include tropical fruit, broom, brimstone and dried herb. The palate isn't overly expressive, offering unripened apple, citrus and dried sage alongside brisk acidity."],
            ["This wine is from US, Pinot Noir variety. Much like the regular bottling from 2012, this comes across as rather rough and tannic, with rustic, earthy, herbal characteristics. Nonetheless, if you think of it as a pleasantly unfussy country wine, it's a good companion to a hearty winter stew."],
            ["Baked plum, molasses, balsamic vinegar and cheesy oak aromas feed into a palate that's braced by a bolt of acidity. A compact set of saucy red-berry and plum flavors features tobacco and peppery accents, while the finish is mildly green in flavor, with respectable weight and balance."]
        ]
    )
    return iface

In [None]:
# Build the demo and start the Gradio server. With debug=True, launch blocks
# this cell while the server runs (see Gradio's `launch` documentation) so
# callback errors can surface in the cell output.
iface = create_gradio_interface()
iface.launch(debug=True)