In [2]:
import os

# Disable upper limit for MPS memory allocations
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

from transformers import BertTokenizer, BertForSequenceClassification
from lime.lime_text import LimeTextExplainer
import torch
import spacy
from tqdm import tqdm

from lib.utils import load_jsonl_file
# from lib.ner_processing import custom_anonymize_text

# Load Spacy NLP model
nlp = spacy.load("en_core_web_trf")

CLASS_NAMES = ['continue', 'not_continue']

# Load dataset
DATASET = load_jsonl_file("shared_data/dataset_2_6_2b.jsonl")
# Load mismatched datapoint
# mismatched_datapoint = load_jsonl_file("shared_data/dataset_1_8_2b_misclassified_examples.jsonl")

# mismatched_datapoint = mismatched_datapoint[:1]

example = [DATASET[0]]


def get_device():
  """Returns the appropriate device available in the system: CUDA, MPS, or CPU"""
  if torch.backends.mps.is_available():
    return torch.device("mps")
  elif torch.cuda.is_available():
    return torch.device("cuda")
  else:
    return torch.device("cpu")
  # return torch.device("cpu")


def predict_proba(_text):
  """
  Prediction function that takes a list of texts and returns model predictions.
  """
  # Tokenize text input for BERT
  inputs = tokenizer(_text, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
  # Get model predictions
  with torch.no_grad():
    outputs = model(**inputs)
  # Apply softmax to get probabilities from logits
  probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
  return probabilities.cpu().detach().numpy()


# Set device
device = get_device()
print(f"\nUsing device: {str(device).upper()}\n")

# Initialize constants
BERT_MODEL = 'bert-base-uncased'
MODEL_PATH = 'models/2/paper_b_hop_bert_reclass.pth'

# Load BERT Tokenizer
print("• Loading BERT Tokenizer...")
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)
# Load Pre-trained Model
print("• Loading Pre-trained Model...")
# model = BertForSequenceClassification.from_pretrained(BERT_MODEL)
model = BertForSequenceClassification.from_pretrained(BERT_MODEL,
                                                      num_labels=len(CLASS_NAMES),
                                                      hidden_dropout_prob=0.2)

# Move Model to Device
model.to(device)

# Load Saved Weights
print("• Loading Saved Weights...")
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))

# Prepare Model for Evaluation
model.eval()

# Initialize LIME Text Explainer
explainer = LimeTextExplainer(class_names=CLASS_NAMES)

NUM_SAMPLES = 1000

for mismatch in tqdm(example, desc="Generating Explanations"):
  text = mismatch["text"]
  # Generate explanation
  exp = explainer.explain_instance(
    text_instance=text, classifier_fn=predict_proba, num_features=8, num_samples=NUM_SAMPLES)

  print(f"True class: {mismatch['label']}")
  # print(f"Predicted class: {predicted_label}")

  # Save the explanation to an HTML file
  exp.save_to_file(f'xnlp/paper_c_lime_explanation_{mismatch["id"]}_{NUM_SAMPLES}.html')


Loading data from shared_data/dataset_2_6_2b.jsonl...
Loaded 3566 items.

Using device: MPS

• Loading BERT Tokenizer...
• Loading Pre-trained Model...


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly i

• Loading Saved Weights...


Generating Explanations: 100%|██████████| 1/1 [00:03<00:00,  3.85s/it]

True class: continue



