<a href="https://colab.research.google.com/github/thetongs/shap-xai/blob/main/Untitled12.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip3 install datasets

Collecting datasets
  Downloading datasets-3.3.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.3.1-py3-none-any.whl (484 kB)
Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Downloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
Downloading xx

In [3]:
!pip3 install shap



In [5]:
import datasets
import numpy as np
import scipy as sp
import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import shap

# Set up GPU/CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned SST-2 model and the tokenizer from the same checkpoint
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
model = DistilBertForSequenceClassification.from_pretrained(checkpoint).to(device)
model.eval()  # inference mode (disables dropout)
tokenizer = DistilBertTokenizerFast.from_pretrained(checkpoint)

def predict_texts(texts):
    """Return the positive-class log-odds for each text in a batch."""
    try:
        # Ensure input is a list of strings
        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, np.ndarray):
            texts = texts.tolist()

        # Print input type for debugging
        print(f"Input type: {type(texts)}")
        print(f"First few examples: {texts[:3]}")

        # Pad to the longest text in the batch and truncate to DistilBERT's
        # 512-token limit (padding every input to a fixed length wastes compute)
        inputs = tokenizer(
            texts,
            padding=True,
            max_length=512,
            truncation=True,
            return_tensors="pt"
        ).to(device)

        # Get predictions
        with torch.no_grad():
            logits = model(**inputs).logits

        # Move probabilities to the CPU as NumPy before calling SciPy, which
        # cannot operate on CUDA tensors, then take the positive-class log-odds
        probs = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()
        return sp.special.logit(probs[:, 1])

    except Exception as e:
        print(f"Error processing texts: {e}")
        raise

# Load the IMDB training split and keep only the first review
imdb_train = datasets.load_dataset("imdb", split="train")
raw_texts = imdb_train[:1]["text"]

# Get predictions
predictions = predict_texts(raw_texts)

# Create a SHAP explainer with a text masker built from the tokenizer.
# shap.maskers.Text applies masks via __call__(mask, s), so the default
# tokenizer-based masker is used directly without subclassing.
masker = shap.maskers.Text(tokenizer)
explainer = shap.Explainer(predict_texts, masker=masker)

# Calculate SHAP values
shap_values = explainer(raw_texts)

Input type: <class 'list'>
First few examples: ['I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the s

  0%|          | 0/498 [00:00<?, ?it/s]

Input type: <class 'list'>
First few examples: ['I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the s


PartitionExplainer explainer: 2it [07:16, 436.30s/it]              
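`predict_texts` returns the log-odds of the positive class, `logit(softmax(z)[:, 1])`, and that is the quantity SHAP attributes to tokens. For a two-class head this equals the raw logit difference `z[1] - z[0]`, since `p1 / (1 - p1) = exp(z1) / exp(z0)`. The sketch below checks the identity with NumPy and SciPy on arbitrary made-up logits (not model outputs).

```python
import numpy as np
from scipy.special import logit, softmax

# Arbitrary two-class logits, one row per example; not real model outputs
z = np.array([[2.0, -1.0], [-0.5, 0.5], [0.0, 3.0]])

# Route used in predict_texts: softmax, then log-odds of class 1
via_probs = logit(softmax(z, axis=-1)[:, 1])

# Equivalent closed form for two classes: difference of the raw logits
via_diff = z[:, 1] - z[:, 0]

print(np.allclose(via_probs, via_diff))
```

Returning `z[:, 1] - z[:, 0]` directly would skip a softmax/logit round trip that can lose precision when a probability saturates near 0 or 1.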

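The explanation above reported 07:16 of wall time for a single review because SHAP scores the model on many masked variants of the text. To make the estimated quantity concrete, here is a brute-force computation of exact Shapley values for a hypothetical three-token "model". The `toy_score` and `exact_shapley` functions and their weights are illustrative assumptions, not part of SHAP. Exact enumeration visits every coalition of the other tokens, which grows as 2^n; SHAP's Partition explainer avoids this by working over a hierarchy of token groups.

```python
from itertools import combinations
from math import factorial


# Hypothetical toy "sentiment model": scores the set of tokens left unmasked.
# "not" flips the sign of "good" only when both are kept (an interaction).
def toy_score(kept):
    score = 0.0
    if "good" in kept:
        score += -2.0 if "not" in kept else 2.0
    if "movie" in kept:
        score += 0.5
    return score


def exact_shapley(tokens, f):
    """Exact Shapley value of each (distinct) token by enumerating all coalitions."""
    n = len(tokens)
    values = {}
    for t in tokens:
        others = [u for u in tokens if u != t]
        phi = 0.0
        for k in range(n):
            for subset in combinations(others, k):
                # Shapley weight |S|! (n - |S| - 1)! / n!
                weight = factorial(k) * factorial(n - k - 1) / factorial(n)
                phi += weight * (f(set(subset) | {t}) - f(set(subset)))
        values[t] = phi
    return values


tokens = ["not", "good", "movie"]
phi = exact_shapley(tokens, toy_score)
print(phi)

# Additivity: attributions sum to f(all tokens) - f(no tokens)
print(sum(phi.values()), toy_score(set(tokens)) - toy_score(set()))
```

In this toy, "movie" gets its constant +0.5, "not" absorbs the -2 swing it causes, and the three values sum to -1.5, the gap between the fully kept and fully masked scores.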

In [7]:

# Plot the token-level explanation for the first review
shap.plots.text(shap_values[0])
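Besides the interactive plot, the attributions can be ranked programmatically. A `shap.Explanation` holds the token strings in `.data` and the per-token SHAP values in `.values`. The hypothetical helper `top_tokens` below takes plain arrays so it can be tried without rerunning the explainer. The demo tokens and values are made up.

```python
import numpy as np


def top_tokens(tokens, values, k=5):
    """Return the k most positive and k most negative (token, value) pairs."""
    tokens = np.asarray(tokens)
    values = np.asarray(values, dtype=float)
    order = np.argsort(values)  # ascending: most negative first
    most_negative = [(str(tokens[i]), float(values[i])) for i in order[:k]]
    most_positive = [(str(tokens[i]), float(values[i])) for i in order[::-1][:k]]
    return most_positive, most_negative


# Made-up tokens and attributions, for illustration only
demo_tokens = ["I ", "really ", "enjoyed ", "this ", "boring ", "film"]
demo_values = [0.01, 0.40, 1.20, -0.02, -0.90, 0.05]
pos, neg = top_tokens(demo_tokens, demo_values, k=2)
print(pos)  # [('enjoyed ', 1.2), ('really ', 0.4)]
print(neg)  # [('boring ', -0.9), ('this ', -0.02)]
```

Against the real result, the call would be `top_tokens(shap_values[0].data, shap_values[0].values)`. Because `predict_texts` returns one score per text, `.values` should be one-dimensional here.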