In [None]:
import shap
import matplotlib.pyplot as plt
from transformers import pipeline, BertTokenizer, BertForSequenceClassification

In [None]:
# 1) load the fine-tuned classifier and its tokenizer from disk
model_path = "../output/models/phishing-bert-model"

model = BertForSequenceClassification.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained(model_path)

# 2) wrap it in a transformers pipeline that returns the score of EVERY class,
#    not just the top one. `top_k=None` is the current way to request this;
#    the older `return_all_scores=True` flag is deprecated and redundant once
#    top_k is given, so it is no longer passed.
#    NOTE(review): class names/order come from the saved model config
#    (model.config.id2label) — confirm which index is the phishing class.
pipe = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
)

# 3) the example text to explain
text = (
    "quick quiz hottest thing costa rica answer real estate market got check view http "
    "decline future promotions offer see following http mail costa reserva leroy drive "
    "corona ca xycs9uc desertdeals proud hand selected offerings however excluded "
    "please visit http write us adobe one e chicago avenue chicago il"
)

# 4) build a SHAP Explainer directly on the pipeline; no masker is passed,
#    so SHAP chooses one itself for this text model
explainer = shap.Explainer(pipe)

# 5) compute SHAP values for a batch of one text (this may take a moment,
#    since the model is re-evaluated on many masked variants of the input)
shap_values = explainer([text])

In [None]:
# Slice out the explanation for the first (and only) input text, over all
# tokens, for a single output class.
# NOTE(review): assumes class index 1 is the phishing label — confirm against
# model.config.id2label before interpreting the plot.
PHISH_CLASS_INDEX = 1
exp = shap_values[0, :, PHISH_CLASS_INDEX]

In [None]:
# Render the SHAP waterfall plot for the selected explanation and save it.
# `show=False` keeps the figure open so it can be saved afterwards. The figure
# is taken from pyplot's current figure rather than from the return value of
# `waterfall`, whose documented return type has differed across shap releases.
WATERFALL_MAX_DISPLAY = 15      # maximum number of features (tokens) to plot
WATERFALL_DPI = 300
WATERFALL_PATH = "waterfall_plot.png"

shap.plots.waterfall(exp, max_display=WATERFALL_MAX_DISPLAY, show=False)
fig = plt.gcf()
fig.tight_layout()
fig.savefig(WATERFALL_PATH, dpi=WATERFALL_DPI, bbox_inches="tight")