In [9]:
# import spacy
import shap
import torch
import numpy as np
from transformers import BertTokenizer, BertForSequenceClassification

from lib.utils import load_jsonl_file
from lib.ner_processing import custom_anonymize_text

# Seed handed to the SHAP explainer.
SEED = 42
# NOTE(review): not referenced in the visible part of this file.
BATCH_SIZE = 16
# Label names; indexed by the predicted class id to get a readable label.
CLASS_NAMES = ['continue', 'not_continue']
# Position in DATASET of the example to explain (change this to inspect
# a different record instead of editing the lookup itself).
SAMPLE_INDEX = 2005

# Load dataset
DATASET = load_jsonl_file("shared_data/dataset_2_6_2b.jsonl")

# Select a sample from the dataset
example = DATASET[SAMPLE_INDEX]

text = example["text"]
text_id = example["id"]

print(f"Actual label: {example['label']}")

# nlp = spacy.load("en_core_web_trf")


def get_device():
  """Pick the torch device to run on.

  Backends are probed in order: MPS first, then CUDA. The first one that
  reports itself available wins; if neither does, the CPU device is
  returned.
  """
  probes = (
      ("mps", torch.backends.mps.is_available),
      ("cuda", torch.cuda.is_available),
  )
  for backend_name, is_available in probes:
    if is_available():
      return torch.device(backend_name)
  return torch.device("cpu")


# Set device
device = get_device()
print(f"\nUsing device: {str(device).upper()}\n")

# Initialize constants
BERT_MODEL = 'bert-base-uncased'
MODEL_PATH = 'models/3/paper_b_hop_bert_reclass.pth'

# Initialize tokenizer (same checkpoint name as the model, via BERT_MODEL,
# so the two cannot drift apart)
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)

# Build the classifier architecture from the base checkpoint; the fine-tuned
# weights are loaded from MODEL_PATH below.
model = BertForSequenceClassification.from_pretrained(
    BERT_MODEL,
    num_labels=len(CLASS_NAMES),
    hidden_dropout_prob=0.1,
)

# Move the model to the device
model = model.to(device)
# Load the fine-tuned weights.
# NOTE(review): torch.load unpickles the file, so MODEL_PATH must point to a
# trusted checkpoint.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Set the model to evaluation mode (disables dropout)
model.eval()


def predict(texts):
  """Return the classifier's raw output scores for a batch of texts.

  Args:
    texts: iterable of input strings (a list, or an array of strings).

  Returns:
    NumPy array with one row per input text, holding the model's first
    output (the logits). No softmax is applied, so the values are
    unnormalised scores rather than probabilities; argmax over axis 1
    still gives the predicted class index.
  """
  # Normalise to a plain list; callers may pass an array of strings
  # rather than a list.
  batch = list(texts)
  encoding = tokenizer.batch_encode_plus(
      batch_text_or_text_pairs=batch,
      padding=True,
      truncation=True,
      max_length=512,
      return_tensors='pt'  # Return PyTorch tensors
  )
  input_ids = encoding['input_ids'].to(device)
  attention_mask = encoding['attention_mask'].to(device)

  # Inference only: skip autograd bookkeeping so repeated calls (the
  # explainer invokes this function many times) do not build and hold
  # computation graphs.
  with torch.no_grad():
    logits = model(input_ids, attention_mask=attention_mask)[0]
  return logits.cpu().numpy()


# Initialize the SHAP explainer: `predict` is the black-box model and the
# tokenizer is used as the text masker.
explainer = shap.Explainer(
  model=predict,
  masker=tokenizer,
  output_names=CLASS_NAMES,
  seed=SEED
)

# Split the sample into the parts before and after the first '[SEP]'.
# partition() never raises: if the marker is missing, sentence2 is simply
# empty, and any further '[SEP]' stays inside sentence2. Tuple-unpacking
# text.split('[SEP]') would abort the run whenever the text does not contain
# exactly one marker.
sentence1, _sep, sentence2 = text.partition('[SEP]')
# text = sentence1.strip() + " " + sentence2.strip()
# text = "They have to come back and get this done, because failure to support Ukraine in this critical moment will never be forgotten in history. It will be measured, and it will have impact for decades to come."
# text = "I want to thank you all for delivering historic results for the American people. You've been incredible partners."
# text = "Didn't happen in Democrat or Republican administrations for the longest time. And so, guess what happened?"
# text = "Folks, Congress has had a long, proud history of -- bipartisan history on immigration reforms and abiding by our international treaty obligations, which we've signed, relating to immigration. These reforms made America a nation of laws, a nation of immigrants, and the strongest economy in the world."
# text = "John is a software developer. He works at a tech company."
# text = "Well, guess what? Didn't happen in Democrat or Republican administrations for the longest time."
# text = "It's clear we have the strongest economy in the world. And that's not hyperbole."
# text = "For example, we capped insulin for seniors on Medicare at $35 a month instead of as much as $400 a month. Well, let's make that $35 available to everyone in your states -- everyone."
# text = "And our politics has failed to fix it. That's why, months ago, I instructed my team to begin a series of negotiations in a bipartisan group of senators."
# text = "ABC is a great software company. That's something that you should ignore."


# Score the selected text with the classifier (one row per input text)
probabilities = predict([text])

# Highest-scoring column of the single row is the predicted class index
predicted_class_index = probabilities.argmax(axis=1)[0]

# Translate the index into its human-readable label and report it
predicted_class_name = CLASS_NAMES[predicted_class_index]
print(f"Predicted class: {predicted_class_name}")

# Explain the same text with SHAP
shap_values = explainer([text])

# Render the per-token attributions
shap.plots.text(shap_values)

# shap.save_html(f"xnlp/model_3_shap_{text_id}.html", shap.plots.text(shap_values[0]))


Loading data from shared_data/dataset_2_6_2b.jsonl...
Loaded 3566 items.
Actual label: not_continue

Using device: MPS



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly i

Predicted class: continue
