In [17]:
import torch
import shap
import pandas as pd
from setfit import SetFitModel
from transformers import AutoTokenizer

from lib.utils import load_jsonl_file


def get_device():
  """Returns the appropriate device available in the system: CUDA, MPS, or CPU"""
  if torch.backends.mps.is_available():
    return torch.device("mps")
  elif torch.cuda.is_available():
    return torch.device("cuda")
  else:
    return torch.device("cpu")


# Get best device
device = get_device()

CLASS_NAMES = ['support', 'oppose']  # 0/1

model_id = "sentence-transformers/paraphrase-mpnet-base-v2"  # Example model_id

# Load dataset
DATASET = load_jsonl_file("shared_data/dataset_2_test.jsonl")

support_class = [data for data in DATASET if data["label"] == "support"]
oppose_class = [data for data in DATASET if data["label"] == "oppose"]

# Load model
model_setfit_path = "models/22"
model = SetFitModel.from_pretrained(model_setfit_path, local_files_only=True, device=device)

# Load the corresponding tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)


def f(x):  # x is a list of strings
  # Make predictions using the model
  predictions = model.predict_proba(x)
  return predictions.cpu()


explainer = shap.Explainer(
  model=f,  # prediction function
  masker=tokenizer, 
  output_names=["support", "oppose"],
  algorithm="auto",
  linearize_link=None,
  seed=42
)

texts = [data["text"] for data in support_class]

sentence = texts[66:67]
print(sentence)

shap_values = explainer(sentence, fixed_context=None)
# print(shap_values)

"""
# Make predictions using the model and get the predicted class name
predictions = model.predict_proba(sentence)
predicted_class_index = np.argmax(predictions, axis=1)[0]
predicted_class_name = CLASS_NAMES[predicted_class_index]
print(f"Predicted class: {predicted_class_name}")"""

shap.plots.text(shap_values)

summed_shap_values = shap_values.sum(0).values
summed_shap_features = shap_values.sum(0).feature_names

# flattened_shap_values
summed_shap_values = [row[0] for row in summed_shap_values]

# Convert to DataFrame for easy CSV saving
df_shap_values = pd.DataFrame(summed_shap_values, columns=["Value"])

# Add a column for the feature names if available
df_shap_values["Feature"] = summed_shap_features
df_shap_values = df_shap_values.sort_values(by="Value", ascending=False)

print(df_shap_values)

shap.plots.bar(shap_values)

Loading data from shared_data/dataset_2_test.jsonl...
Loaded 200 items.
['I believe we can work together in Afghanistan to make sure that former safe haven is able to grow as a democracy.']


       Value      Feature
4   0.104286          can
5   0.094049         work
21  0.089497    democracy
6   0.084638     together
14  0.063774         safe
18  0.044012         grow
17  0.042767         able
1   0.026520            I
16  0.018735           is
9   0.015799           to
19  0.011806           as
20  0.007832            a
3   0.004935           we
2   0.004067      believe
11  0.003891         sure
7   0.001616           in
0  -0.000012             
10 -0.006609         make
22 -0.032950            .
15 -0.034536        haven
12 -0.038375         that
13 -0.042037       former
8  -0.066833  Afghanistan


AssertionError: The clustering provided by the Explanation object does not seem to be a partition tree (which is all shap.plots.bar supports)!