In [57]:
# import spacy
import shap
import torch
from transformers import BertTokenizer, BertForSequenceClassification

from lib.utils import load_jsonl_file
from lib.ner_processing import custom_anonymize_text

# Seed forwarded to the SHAP explainer further down in this cell.
SEED = 42
# NOTE(review): BATCH_SIZE is not referenced in the visible part of this cell.
BATCH_SIZE = 16
# Label names; their count sets num_labels and they are passed to SHAP as output_names.
CLASS_NAMES = ['continue', 'not_continue']

# Read every record of the JSONL dataset.
DATASET = load_jsonl_file("shared_data/dataset_2_6_2b.jsonl")
# Pick the record at index 1 and keep its text and identifier.
_sample = DATASET[1]
text = _sample["text"]
text_id = _sample["id"]

# nlp = spacy.load("en_core_web_trf")


def get_device():
  """Return the torch device to run on, preferring MPS, then CUDA, then CPU.

  Returns:
    torch.device: "mps" when the MPS backend reports itself available,
    otherwise "cuda" when CUDA is available, otherwise "cpu".
  """
  # torch.backends.mps is missing on older PyTorch releases; look it up
  # defensively so those installations fall through instead of raising
  # AttributeError.
  mps_backend = getattr(torch.backends, "mps", None)
  if mps_backend is not None and mps_backend.is_available():
    return torch.device("mps")
  if torch.cuda.is_available():
    return torch.device("cuda")
  return torch.device("cpu")


# Set device
device = get_device()
print(f"\nUsing device: {str(device).upper()}\n")

# Initialize constants
BERT_MODEL = 'bert-base-uncased'
MODEL_PATH = 'models/2/paper_b_hop_bert_reclass.pth'

# Initialize tokenizer from the same checkpoint name as the model, so the two
# cannot drift apart when BERT_MODEL is changed.
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)

# Build the sequence-classification architecture from the base checkpoint;
# the fine-tuned weights are loaded from MODEL_PATH below.
model = BertForSequenceClassification.from_pretrained(
  BERT_MODEL,
  num_labels=len(CLASS_NAMES),
  hidden_dropout_prob=0.1,
)

# Move the model to the device
model = model.to(device)
# Load the fine-tuned weights onto the same device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed PyTorch supports it).
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Set the model to evaluation mode (disables dropout)
model.eval()


def predict(text):
  """Return the classifier's raw logits for a batch of input strings.

  Args:
    text: Iterable of strings to classify (for example the batch that the
      SHAP explainer hands to its model function).

  Returns:
    numpy.ndarray holding the first element of the model output for the
    batch, i.e. the unnormalised class scores (logits, not probabilities).
  """
  # Tokenize the whole batch in one call so every row is padded to a common
  # length. Encoding each string on its own produces ragged id lists that
  # torch.tensor cannot stack when the strings tokenize to different lengths.
  encoded = tokenizer(
    [str(v) for v in text],
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
  )
  input_ids = encoded["input_ids"].to(device)
  # Use the tokenizer's own mask rather than assuming the pad token id is 0.
  attention_mask = encoded["attention_mask"].to(device)
  # Inference only: skip building the autograd graph.
  with torch.no_grad():
    logits = model(input_ids, attention_mask=attention_mask)[0]
  return logits.detach().cpu().numpy()


# Build the SHAP explainer around predict(), with the tokenizer as the second
# (masker) argument, the class names as output labels, and a fixed seed.
explainer = shap.Explainer(predict, tokenizer, output_names=CLASS_NAMES, seed=SEED)

# Disabled preprocessing, kept for reference: join the two '[SEP]'-separated parts.
# sentence1, sentence2 = text.split('[SEP]')
# text = sentence1.strip() + " " + sentence2.strip()


# Explain the single selected sample
shap_values = explainer([text])

# Render the token-level attributions
shap.plots.text(shap_values)

# shap.save_html(f"xnlp/model_3_shap_{text_id}.html", shap.plots.text(shap_values[0]))


Loading data from shared_data/dataset_2_6_2b.jsonl...
Loaded 3566 items.

Using device: MPS


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly i