In [2]:
# from pprint import pprint
# import spacy
import shap
import torch
import numpy as np
import pandas as pd
from transformers import BertTokenizer, BertForSequenceClassification

from lib.utils import load_jsonl_file

# Seed passed to the SHAP explainer constructed later in this cell
SEED = 42
# NOTE(review): BATCH_SIZE is not referenced anywhere in the visible code
BATCH_SIZE = 16
# Position in this list is used as the class index: it sizes the model's
# classification head, names the SHAP outputs, and maps the argmax of the
# model output back to a label.
# NOTE(review): assumes this order matches the label encoding used when the
# checkpoint was trained -- confirm against the training script.
CLASS_NAMES = ['continue', 'not_continue']

# Let pandas print up to 500 rows before truncating a DataFrame
pd.set_option('display.max_rows', 500)

# Load dataset: one JSONL file per class
# NOTE(review): "test_continue" vs "text_not_continue" looks like a naming
# typo; left as is in case other cells refer to these names.
test_continue = load_jsonl_file("shared_data/topic_boundary_continue_class.jsonl")
text_not_continue = load_jsonl_file("shared_data/topic_boundary_not_continue_class.jsonl")

# The not_continue items come first, then the continue items; any positional
# index into DATASET depends on this concatenation order.
DATASET = text_not_continue + test_continue

def get_device():
  """Pick the torch device to run on.

  MPS is preferred when available, then CUDA; CPU is the fallback.

  Returns:
    A ``torch.device`` for "mps", "cuda" or "cpu".
  """
  chosen = "cpu"
  if torch.backends.mps.is_available():
    chosen = "mps"
  elif torch.cuda.is_available():
    chosen = "cuda"
  return torch.device(chosen)


# Pick the compute device once; the model and every input batch are moved to it
device = get_device()
print(f"\nUsing device: {str(device).upper()}\n")

# Name of the pretrained checkpoint that supplies the tokenizer vocabulary and
# the base architecture, and path of the fine-tuned weights loaded on top of it
BERT_MODEL = 'bert-base-uncased'
MODEL_PATH = 'models/3/TopicBoundaryBERT.pth'

# Initialize tokenizer from the same checkpoint name as the model, using the
# constant rather than a repeated string literal so the two cannot drift apart
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)

# Register extra whole-word tokens with the tokenizer so they are never split
# into sub-word pieces.
# NOTE(review): no model.resize_token_embeddings(len(tokenizer)) call is
# visible in this cell -- confirm that the model's embedding matrix covers the
# ids assigned to any tokens that are actually added here.
tokenizer.add_tokens(["1,000", "2,000", "endures", "decency", "stockpile", "ventilators", "Blackwater", "standpoint", "dismantle", "empower", "frack", "polluters", "Saddamists", "rejectionists", "Qaida", "maiming",
 "torturing", "healthier", "massively", "asymptomatic", "Pocan", "unfairly", "1,400", "'s", "62,000", "hospitalizations", "490,050", "commend", "F-16", "opioid", "pushers", "peddling", "Ebola", "czar", "reiterate", "USAID", "maximally", "unwittingly", "'d", "Assad", "pandemic", "deadliest", "defunding", "ATF", "pressuring", "DACA", "U.S.", "basing", "hospitalization", "COVID", "incentivize", "reimagine", "dictate", "beneficiary", "closures", "lawmakers", "equipping", "vaccination", "retrain", "Hun-", "nutritious", "inhumane", "qualifies", "lifeblood", "forecasts", "vaccinated", "1619", "hundreds", "70,000", "legislating",
  "Javits", "childcare", "reemphasized", "destabilizing", "exporter", "COVID-19", "vaccinations", "ISR", "Abound", "1,500", "FDIC", "2.9", "IndyMac", "5,000", "borrowers", "foreclosure", "mortgages", "2.2", "pand"
                      ])

# Build the classifier from the same pretrained checkpoint as the tokenizer
# (BERT_MODEL), with one output per entry in CLASS_NAMES
model = BertForSequenceClassification.from_pretrained(BERT_MODEL,
                                                      num_labels=len(CLASS_NAMES))

# Move the model to the device before loading weights mapped to that device
model = model.to(device)
# Overwrite the freshly initialised weights with the fine-tuned state dict.
# NOTE: torch.load unpickles the file, so MODEL_PATH must come from a trusted
# source (newer PyTorch releases offer weights_only=True for state dicts).
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Set the model to evaluation mode (disables dropout) for inference
model.eval()


def predict(texts):
  """Return the classifier's raw logits for a batch of texts.

  This is the function handed to the SHAP explainer, so it is called many
  times; inference therefore runs without gradient tracking.

  Args:
    texts: Iterable of input strings. It is copied into a plain list before
      tokenization so that array-like inputs are accepted as well.

  Returns:
    A NumPy array with one row per input text and one column per class in
    CLASS_NAMES. The values are unnormalised logits; no softmax is applied.
  """
  encoding = tokenizer.batch_encode_plus(
      batch_text_or_text_pairs=list(texts),
      padding=True,
      truncation=True,
      max_length=512,
      return_tensors='pt'  # Return PyTorch tensors
  )
  input_ids = encoding['input_ids'].to(device)
  attention_mask = encoding['attention_mask'].to(device)

  # No autograd graph is needed for inference; skipping it avoids holding
  # activations in memory on every call.
  with torch.no_grad():
    logits = model(input_ids, attention_mask=attention_mask)[0]
  return logits.cpu().numpy()


# Build the SHAP explainer around the prediction function. The tokenizer
# doubles as the masker, outputs are labelled with the class names, and the
# seed is fixed so repeated runs use the same random state.
explainer = shap.Explainer(
  predict,
  tokenizer,
  output_names=CLASS_NAMES,
  seed=SEED,
)

# Select a sample from the dataset by position. DATASET holds the
# not_continue items first, followed by the continue items.
SAMPLE_INDEX = 259  # ID - 1 (original note; presumably the record ID minus one)
example = DATASET[SAMPLE_INDEX]

# The sample text is expected to hold two sentences separated by a single
# literal "[SEP]" marker; fail with a clear message when it does not.
parts = example["text"].split('[SEP]')
if len(parts) != 2:
  raise ValueError(
    f"Expected exactly one '[SEP]' marker in sample {SAMPLE_INDEX}, "
    f"found {len(parts) - 1}"
  )
sentence1, sentence2 = (part.strip() for part in parts)

# The marker itself is dropped: the model input is the two stripped
# sentences joined by a single space.
text = sentence1 + " " + sentence2

print(f"S1: {sentence1}")
print(f"S2: {sentence2}")

text_id = example["id"]

print(f"Actual label: {example['label_human']}")


# Score the single selected text. Despite the variable name, predict()
# returns the model's raw outputs, one row per input text.
probabilities = predict([text])

# Only one text was scored, so take the arg-max over the classes of row 0
predicted_class_index = np.argmax(probabilities[0])

# Translate the winning class index into its human-readable label
predicted_class_name = CLASS_NAMES[predicted_class_index]

print(f"Predicted class: {predicted_class_name}")

# Compute SHAP values for the selected sample
shap_values = explainer([text], fixed_context=1)

# Build one row per feature of the single explained text: the per-class SHAP
# values followed by the feature string. Rows with blank features ("") are
# dropped, they are not useful.
_shap_values = [
  (*class_values, feature)
  for class_values, feature in zip(shap_values.values[0], shap_values.feature_names[0])
  if feature != ""
]

# Convert to DataFrame; the value columns follow the class order in CLASS_NAMES
df_shap_values = pd.DataFrame(_shap_values, columns=[*CLASS_NAMES, "feature"])


print(df_shap_values)

shap.plots.text(shap_values)

# Disabled metadata dump, kept as real comments so it is no longer evaluated
# and echoed as the cell's output.
# NOTE: if this is re-enabled, do not eval() the metadata string; parse it with
# json.loads or ast.literal_eval instead, and keep the loop inside the `if`
# so `metadata` is always defined when it is iterated.
#
# if example.get("metadata"):
#   metadata = eval(example["metadata"])
#
# for continuity_feature in metadata:
#   print(continuity_feature)


Loading data from shared_data/topic_boundary_continue_class.jsonl...
Loaded 145 items.
Loading data from shared_data/topic_boundary_not_continue_class.jsonl...
Loaded 145 items.

Using device: MPS



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly i

S1: So I will say this: The immigration enforcement is a federal authority, and states should not be mandating it - meddling in it.
S2: That is just - especially governor of Te- - the Texas governor, Abbott, who has a track record of causing chaos and confusion at the border.
Actual label: continue
Predicted class: continue
    continue  not_continue      feature
0   0.136496     -0.370999           so
1  -0.028647      0.024833            i
2  -0.019640      0.056778         will
3   0.050532     -0.046703          say
4   0.123237     -0.148827         this
5  -0.183838      0.296094            :
6   0.020669     -0.088955          the
7   0.051203     -0.100025  immigration
8   0.008544      0.010294  enforcement
9   0.146750     -0.178142           is
10  0.021408     -0.011860            a
11 -0.013646      0.046598      federal
12 -0.042486      0.100532    authority
13  0.001068     -0.199613            ,
14 -0.027536     -0.192606          and
15 -0.072892      0.220776       s

'if example.get("metadata"):\n  metadata = eval(example["metadata"])\n\nfor continuity_feature in metadata:\n  print(continuity_feature)'