<a href="https://colab.research.google.com/github/nike-2001/Toxicity-Detection-In-Social-Media/blob/main/SHAP_Interpretability.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the required libraries: SHAP for explainability, Transformers for pre-trained models, and PyTorch for deep learning
# (the %pip magic installs into the environment of the running kernel)
%pip install shap transformers torch

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
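After restarting the kernel, it can help to record which versions were actually installed, since the SHAP and Transformers APIs change between releases. The sketch below uses only the standard library's `importlib.metadata`. The helper name `installed_versions` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("shap", "transformers", "torch")):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```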


In [None]:
# Import the required libraries
import torch
from transformers import BertTokenizer, BertForSequenceClassification
import shap

In [None]:
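# Name of the pre-trained BERT checkpoint on the Hugging Face Hub
MODEL_NAME = "bert-base-uncased"

# Load BERT with a two-class (non-toxic / toxic) classification head.
# Note: bert-base-uncased is not fine-tuned for toxicity, so this head is randomly
# initialized (see the warning below). Fine-tune it on labeled toxicity data before
# treating class 1 as "toxic" or reading meaning into the SHAP attributions.
model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)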

# Load the tokenizer for the BERT model
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [None]:
# Set the device to GPU if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the selected device and keep it in evaluation mode (dropout disabled)
model.to(device)
model.eval()

# Prediction function for SHAP: maps a batch of strings to class probabilities of shape (n_texts, 2)
def model_predict(texts):
    # Tokenize the input texts and prepare them for the model
    encoded_inputs = tokenizer(
        list(texts),
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt"
    )
    # Move input data to the selected device
    input_ids = encoded_inputs["input_ids"].to(device)
    attention_mask = encoded_inputs["attention_mask"].to(device)

    # Run inference with gradient tracking disabled to save memory and time
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    # Apply softmax to get probabilities
    probs = torch.softmax(logits, dim=-1)
    return probs.cpu().numpy()

# Build a SHAP explainer around the prediction function; the Text masker uses the
# tokenizer to split each comment into tokens and hide subsets of them
explainer = shap.Explainer(
    model_predict,
    masker=shap.maskers.Text(tokenizer),
)
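`model_predict` turns the two logits per comment into probabilities with `torch.softmax(logits, dim=-1)`, so the SHAP values computed from it are in probability units. The pure-Python sketch below (hypothetical helper `softmax_rows`, for illustration only; PyTorch's kernel differs internally) performs the same row-wise computation, including the usual max-subtraction trick that keeps large logits from overflowing.

```python
import math


def softmax_rows(logits):
    """Row-wise softmax for a batch of logit vectors given as a list of lists."""
    probs = []
    for row in logits:
        m = max(row)  # subtract the row max so exp() cannot overflow
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


batch = [[2.0, -1.0], [1000.0, 1001.0]]  # second row would overflow without max-subtraction
for row in softmax_rows(batch):
    print([round(p, 4) for p in row])
```

With two classes, the toxic probability equals the logistic sigmoid of the logit difference, `1 / (1 + exp(-(z1 - z0)))`.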


In [None]:
# Example comments: two toxic and one non-toxic
comments = [
    "Fucking Hate you",  # Profane, strongly toxic
    "You're so stupid",  # Insulting, toxic
    "We like you"        # Friendly, non-toxic
]


In [None]:
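# Compute SHAP values for every comment: one explanation per input, with a value per token and class
shap_values = explainer(comments)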
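`explainer(comments)` estimates how much each token moves each class probability by calling `model_predict` on copies of a comment with some tokens hidden. For intuition, the sketch below computes exact Shapley values by enumerating every subset of hidden tokens. It assumes a hypothetical keyword scorer (`toy_toxicity`, `TOY_WEIGHTS`) in place of BERT and a literal `"[MASK]"` placeholder for hidden tokens. Exact enumeration means evaluating the model on all 2^n token subsets for n tokens, which is why SHAP approximates; for a `Text` masker, `shap.Explainer` typically uses its Partition algorithm, which computes hierarchical Owen values rather than this brute-force sum.

```python
from itertools import combinations
from math import factorial

MASK = "[MASK]"  # placeholder for hidden tokens (assumption for this toy)

# Hypothetical per-word toxicity weights; a stand-in for a fine-tuned model
TOY_WEIGHTS = {"fucking": 0.3, "hate": 0.6, "stupid": 0.5}


def toy_toxicity(tokens):
    """Noisy-OR score: 1 - product of (1 - weight) over the visible tokens."""
    keep = 1.0
    for tok in tokens:
        if tok != MASK:
            keep *= 1.0 - TOY_WEIGHTS.get(tok.lower(), 0.0)
    return 1.0 - keep


def masked(tokens, visible):
    """Copy of `tokens` with every position not in `visible` replaced by MASK."""
    return [tok if i in visible else MASK for i, tok in enumerate(tokens)]


def exact_shapley(tokens, f):
    """Exact Shapley value of each token for score f, enumerating all coalitions."""
    n = len(tokens)
    phis = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        phi = 0.0
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions S of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi += weight * (f(masked(tokens, s | {i})) - f(masked(tokens, s)))
        phis.append(phi)
    return phis


tokens = ["Fucking", "Hate", "you"]
phis = exact_shapley(tokens, toy_toxicity)
base = toy_toxicity(masked(tokens, set()))  # score with every token hidden
full = toy_toxicity(tokens)                 # score of the unmasked comment
for tok, phi in zip(tokens, phis):
    print(f"{tok:>8}: {phi:+.3f}")
print(f"base + sum(phi) = {base + sum(phis):.3f}, full score = {full:.3f}")
```

The final line checks additivity: the score with everything hidden plus the sum of token attributions reproduces the full score, the same property that ties `base_values` and `values` together in the real explanation.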

In [None]:
# Iterate through the list of comments with their indices
for i, comment in enumerate(comments):
    # Print the current comment being processed
    print(f"\nComment: {comment}")

    # Extract SHAP values for the "toxic" class (class index 1)
    shap_values_toxic_class = shap_values[i].values[:, 1]

    # Build a single-output Explanation restricted to the toxic class
    shap_value_copy = shap.Explanation(
        values=shap_values_toxic_class,           # SHAP values for the toxic class
        base_values=shap_values[i].base_values[1],  # Base value for the toxic class
        data=shap_values[i].data,                  # Input data associated with SHAP values
        feature_names=shap_values[i].feature_names,  # Names of features (tokens)
        clustering=shap_values[i].clustering        # Clustering info for visualization
    )

    # Visualize the SHAP values for the toxic class using a text plot
    shap.text_plot(shap_value_copy)



Comment: Fucking Hate you



Comment: You're so stupid



Comment: We like you
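To summarize each plot in text, for logging or comparison across comments, one can rank tokens by their toxic-class SHAP value and check that the attributions add up. The sketch below uses hypothetical helpers `top_toxic_tokens` and `reconstructed_output`. With the real explanation you would pass `shap_values[i].data`, `shap_values[i].values[:, 1]` and `shap_values[i].base_values[1]`, and the reconstructed output should approximately match `model_predict([comment])[0, 1]`. It assumes SHAP represents BERT's special tokens as blank strings and skips them when ranking; the numbers in the usage lines are made-up placeholders, not results from this notebook.

```python
def top_toxic_tokens(tokens, values, k=3):
    """Return up to k (token, value) pairs with the largest positive toxic-class SHAP values."""
    pairs = [
        (tok, val)
        for tok, val in zip(tokens, values)
        if tok.strip() and val > 0  # skip blank entries (assumed special tokens) and non-positive values
    ]
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    return pairs[:k]


def reconstructed_output(base_value, values):
    """Local accuracy: the base value plus the sum of all SHAP values gives the explained output."""
    return base_value + sum(values)


# Made-up placeholder numbers for illustration only, NOT outputs of this notebook
example_tokens = ["", "You", "'re ", "so ", "stupid", ""]
example_values = [0.0, 0.02, -0.01, 0.05, 0.30, 0.0]
print(top_toxic_tokens(example_tokens, example_values))
print(reconstructed_output(0.45, example_values))
```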
