In [14]:
import shap
import torch
from transformers import BertTokenizer, BertForSequenceClassification

from lib.utils import load_jsonl_file

# Batch size setting (not referenced anywhere in the code visible in this cell)
BATCH_SIZE = 16

# Names of the classifier's output classes
CLASS_NAMES = ['monologic', 'dialogic']

# Records whose texts will be explained (the file name suggests the test split)
DATASET = load_jsonl_file("shared_data/dataset_1_6_1b_test.jsonl")

# Misclassified examples; only the first five entries are kept
mismatched_datapoints = load_jsonl_file(
  "shared_data/dataset_1_8_2b_misclassified_examples.jsonl"
)[:5]


def get_device():
  """Pick the torch device to run on.

  Backends are probed in order: MPS first, then CUDA. When neither reports
  itself as available, the CPU device is returned.
  """
  # (device name, availability probe) pairs, in order of preference
  probes = (
    ("mps", torch.backends.mps.is_available),
    ("cuda", torch.cuda.is_available),
  )
  for name, is_available in probes:
    if is_available():
      return torch.device(name)
  return torch.device("cpu")


# Set device
device = get_device()
print(f"\nUsing device: {str(device).upper()}\n")

# Initialize constants
BERT_MODEL = 'bert-base-uncased'
MODEL_PATH = 'models/1/paper_a_x_dl_bert_train_hop_bert.pth'

# Initialize tokenizer (same checkpoint name as the model, via the shared constant)
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)

# Build the sequence-classification model from the pretrained checkpoint,
# with one output per entry in CLASS_NAMES
model = BertForSequenceClassification.from_pretrained(
  BERT_MODEL,
  num_labels=len(CLASS_NAMES),
  hidden_dropout_prob=0.1,
)

# Move the model to the device
model = model.to(device)
# Load the fine-tuned weights, mapping tensors onto the selected device.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Set the model to evaluation mode
model.eval()


# Disabled alternative prediction function (returns softmax probabilities
# instead of raw logits). Kept as comments for reference:
#
# def predict(texts):
#   # Tokenize the input texts
#   inputs = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
#   # Move inputs to the same device as model
#   inputs = {k: v.to(model.device) for k, v in inputs.items()}
#   # Get model outputs (logits)
#   with torch.no_grad():
#     outputs = model(**inputs)
#   # Apply softmax to logits to get probabilities
#   probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
#   return probabilities.cpu().detach().numpy()


def f(x):
  """Return the classifier's raw outputs for a batch of texts.

  This is the prediction function handed to the SHAP explainer.

  Args:
    x: Iterable of input strings.

  Returns:
    NumPy array holding the first element of the model's output (the logits
    for a sequence-classification model), one row per input text.
  """
  # Pad/truncate every text to exactly 512 token ids so the batch is rectangular
  token_ids = torch.tensor(
    [tokenizer.encode(v, padding='max_length', max_length=512, truncation=True) for v in x],
    device=device,
  )
  # Attend only to real tokens; padding positions are identified by the
  # tokenizer's own pad id rather than a hard-coded 0
  attention_mask = (token_ids != tokenizer.pad_token_id).type(torch.int64)
  # Inference only: skip autograd bookkeeping to save memory and time
  with torch.no_grad():
    logits = model(token_ids, attention_mask=attention_mask)[0]
  return logits.cpu().numpy()



# Initialize the SHAP explainer; output names come from the shared class list
explainer = shap.Explainer(f, tokenizer, output_names=CLASS_NAMES)

# Index the dataset by id once instead of rescanning it for every mismatch.
# setdefault keeps the first record seen for an id, matching a first-match scan.
records_by_id = {}
for data in DATASET:
  records_by_id.setdefault(data["id"], data)

# Collect the texts of the misclassified datapoints, in their original order;
# ids that do not occur in DATASET are skipped
selected_texts = []
for mismatch in mismatched_datapoints:
  record = records_by_id.get(mismatch["id"])
  if record is not None:
    selected_texts.append(record["text"])

# Sample single string input
single_text = "This is a test sentence to check the explainer."

# Try computing SHAP values for the single string
shap_values_single = explainer([single_text])  # Note: passing a list with a single string

# Visualize the SHAP values, provided the explainer returned a truthy result
if shap_values_single:
  shap.plots.text(shap_values_single)

# Disabled: explain the selected misclassified samples. Written as comments
# rather than a trailing string literal, so the notebook cell does not echo
# the text back as its result.
#
# shap_values = explainer(selected_texts)
# shap.plots.text(shap_values)


Loading data from shared_data/dataset_1_6_1b_test.jsonl...
Loaded 468 items.
Loading data from shared_data/dataset_1_8_2b_misclassified_examples.jsonl...
Loaded 23 items.

Using device: MPS


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly i

  0%|          | 0/156 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [00:10, 10.56s/it]               


'# Compute SHAP values for the selected samples\nshap_values = explainer(selected_texts)\n\n# Visualize the SHAP values\nshap.plots.text(shap_values)'