In [24]:
import torch
import numpy as np
from transformers import BertTokenizer, BertForSequenceClassification
from transformers_interpret import SequenceClassificationExplainer, PairwiseSequenceClassificationExplainer

from lib.utils import load_jsonl_file
from lib.ner_processing import custom_anonymize_text

# Seed and batch size. Neither is referenced in the visible part of this cell
# (presumably consumed by other cells -- TODO confirm).
SEED = 42
BATCH_SIZE = 16
# Label names in class-index order: the length sets `num_labels` for the model
# and the list is indexed with the predicted class index.
CLASS_NAMES = ['continue', 'not_continue']

# Load dataset
# DATASET = load_jsonl_file("shared_data/topic_boundary_test.jsonl")


def get_device():
  """Pick the torch device to run on: CUDA when available, otherwise CPU.

  Only CUDA is probed; Apple MPS is never selected by this function.

  Returns:
    torch.device: ``cuda`` if ``torch.cuda.is_available()`` is true, else ``cpu``.
  """
  backend = "cuda" if torch.cuda.is_available() else "cpu"
  return torch.device(backend)


# Set device
device = get_device()
print(f"\nUsing device: {str(device).upper()}\n")

# Initialize constants
BERT_MODEL = 'bert-base-uncased'
MODEL_PATH = 'models/3/TopicBoundaryBERT.pth'

# Initialize tokenizer from the same checkpoint name as the model, so the two
# cannot drift apart if BERT_MODEL is changed.
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)

# Build the classifier from the base checkpoint. The classification head is
# newly initialized at this point (hence the library warning); its weights are
# replaced by the state dict loaded from MODEL_PATH below.
model = BertForSequenceClassification.from_pretrained(BERT_MODEL,
                                                      num_labels=len(CLASS_NAMES))

# Move the model to the device
model = model.to(device)
# Load the model weights, remapping stored tensors onto the active device.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source (consider `weights_only=True` if the installed torch supports it).
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Set the model to evaluation mode
model.eval()

# Explainer initialization (single-sequence explainer; used by explain_prediction)
explainer = SequenceClassificationExplainer(model, tokenizer)


def explain_prediction(sentence1, sentence2):
  """Explain the model's prediction for a sentence pair and save a visualization.

  The two sentences are joined into one string with a literal ``[SEP]`` between
  them and passed to the module-level ``explainer``.

  NOTE(review): the pair is fed as a single text rather than as two separate
  segments -- confirm this matches how the model was trained.

  Args:
    sentence1: First sentence of the pair.
    sentence2: Second sentence of the pair.

  Returns:
    tuple: ``(word_attributions, attributions)`` where ``word_attributions`` is
    whatever ``explainer(text)`` returns (the recorded notebook output shows a
    list of ``(token, score)`` pairs, not class logits) and ``attributions`` is
    ``explainer.attributions`` read after that call.

  Side effects:
    Prints the word attributions and writes ``explanation.html``.
  """
  # Single input string with an explicit separator between the two sentences
  pair_text = "{} [SEP] {}".format(sentence1, sentence2)

  # Running the explainer computes the attributions for this input
  word_attributions = explainer(pair_text)
  print(word_attributions)

  # Attribution object populated by the call above
  attribution_obj = explainer.attributions

  # Write the visualization to disk
  explainer.visualize("explanation.html")

  return word_attributions, attribution_obj


sentence1 = "What is the capital of France?"
sentence2 = "The capital of France is Paris."

# Call the explain function
# NOTE: despite its name, `logits` holds per-token (token, attribution score)
# pairs returned by the explainer, not class logits.
logits, attributions = explain_prediction(sentence1, sentence2)

# Read the predicted class from the explainer itself. Taking np.argmax over the
# (token, score) pairs compared the token text against the score and therefore
# always produced index 0, regardless of what the model predicted.
predicted_class_index = int(explainer.predicted_class_index)
predicted_class_name = CLASS_NAMES[predicted_class_index]
print(f"Predicted class: {predicted_class_name}")



Using device: CPU



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly i

[('[CLS]', 0.0), ('what', 0.008029557832993054), ('is', 0.3110590591094805), ('the', 0.33346263894035805), ('capital', 0.3980458316613479), ('of', 0.02100187413039396), ('france', 0.5494300304888876), ('?', 0.1085069527546482), ('[SEP]', 0.46059213608610167), ('the', 0.06771999640111948), ('capital', -0.00030291064826761985), ('of', -0.1576870007432703), ('france', 0.20272803561550834), ('is', 0.11302968676529693), ('paris', -0.14090047433005523), ('.', -0.06425006211956835), ('[SEP]', 0.0)]


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,LABEL_0 (0.86),LABEL_0,2.21,[CLS] what is the capital of france ? [SEP] the capital of france is paris . [SEP]
,,,,


Predicted class: continue
