In [1]:
from pprint import pprint
import random

import torch
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertForSequenceClassification
from transformers_interpret import PairwiseSequenceClassificationExplainer

from db import spreadsheet_7
from lib.utils import read_from_google_sheet

# Set display options to ensure all rows and columns are displayed
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

SEED = 42

REVERSED_LABEL_MAP = {0: "continue", 1: "not_continue"}

def get_device():
  """Returns the appropriate device available in the system: CUDA, MPS, or CPU"""
  if torch.backends.mps.is_available():
    return torch.device("mps")
  elif torch.cuda.is_available():
    return torch.device("cuda")
  else:
    return torch.device("cpu")
  

def set_seed(seed_value):
  """Set seed for reproducibility."""
  random.seed(seed_value)
  np.random.seed(seed_value)
  torch.manual_seed(seed_value)
  torch.cuda.manual_seed_all(seed_value)
  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False


# Set seed for reproducibility
set_seed(SEED)

# Set device
device = get_device()
print(f"\nUsing device: {str(device).upper()}\n")

CLASS_NAMES = ['continue', 'not_continue']

# Load dataset
# not_continue_test_dataset = load_jsonl_file("shared_data/topic_boundary_not_continue_class.jsonl")
# continue_test_dataset = load_jsonl_file("shared_data/topic_boundary_continue_class.jsonl")

DATASET = read_from_google_sheet(spreadsheet_7, "test_dataset")

# Initialize constants
BERT_MODEL = 'bert-base-uncased'
MODEL_PATH = 'models/3/TopicContinuityBERT.pth'

# Initialize tokenizer
# tokenizer = BertTokenizer.from_pretrained("models/3/tokenizer")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Load the model
model = BertForSequenceClassification.from_pretrained(BERT_MODEL, num_labels=2)

# Move the model to the device
# model = model.to(device)

# Load the model weights
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Set the model to evaluation mode
# model.eval()

model.resize_token_embeddings(len(tokenizer))

# Explainer initialization
pairwise_explainer = PairwiseSequenceClassificationExplainer(
  model=model,
  tokenizer=tokenizer,
  custom_labels=CLASS_NAMES,
  attribution_type="lig",
)

# Select a sample from the dataset
example = DATASET[101]  # ID - 1 <---------------------------------------------------------

text = example["text"]
true_label = example["label"]
# sentence1, sentence2 = text.split('[SEP]')

# sentence1 = sentence1.strip()
# sentence2 = sentence2.strip()

# sentence1 = "This would have a similar outcome as the standard deduction I proposed, and I'm open to further discussions about this - about this two options."
sentence1 = "We are very, very tough on that."
sentence2 = "And that is going to remain tough, or even tougher."
#sentence2 = "There's a better way from expanding the government, and that is to reform the Tax Code."

explanation_data = pairwise_explainer( text1=sentence1, text2=sentence2)  # flip_sign=True,

# Remove rows with blank features ("")
explanation_data = [row for row in explanation_data if row[0] not in ["[PAD]", "[CLS]", "[SEP]"]]
# Aggregate subtokens
# explanation_data = aggregate_subtokens(explanation_data)

df_explanation_data = pd.DataFrame(explanation_data, columns=["feature", "value"])

print(df_explanation_data)

pairwise_explainer.visualize("topic_boundary.html", true_class=true_label)


model.to(device)

def preprocess_pairs(_texts, _tokenizer, max_length=512):
  """Tokenize and preprocess text pairs."""
  input_ids = []
  attention_masks = []

  for text in _texts:
    # Split the text into two sentences using a delimiter
    sentence1, sentence2 = text.split('[SEP]')
    encoded_input = _tokenizer.encode(
      sentence1.strip(),
      sentence2.strip(),
      add_special_tokens=True,
      max_length=max_length,
      truncation=True,
      return_tensors='pt'  # Ensure output is in tensor format
    )

    # Pad the encoded_input to max_length
    padded_input = torch.full((1, max_length), _tokenizer.pad_token_id)
    padded_input[:, :encoded_input.size(1)] = encoded_input

    # Create an attention mask for the non-padded elements
    attention_mask = (padded_input != _tokenizer.pad_token_id).int()

    input_ids.append(padded_input)
    attention_masks.append(attention_mask)

  # Concatenate all input_ids and attention_masks
  input_ids = torch.cat(input_ids, dim=0)
  attention_masks = torch.cat(attention_masks, dim=0)

  return input_ids, attention_masks


def classify_sentences(_model, _sentence_pairs, _tokenizer, _device):
  # model.eval()  # Ensure the model is in eval mode
  input_ids, attention_masks = preprocess_pairs(_sentence_pairs, _tokenizer)

  # Move tensors to the device where the model is
  input_ids = input_ids.to(_device)
  attention_masks = attention_masks.to(_device)

  with torch.no_grad():  # No need to track gradients for inference
    outputs = _model(input_ids, attention_mask=attention_masks)
    logits = outputs.logits
    _predictions = torch.argmax(logits, dim=-1)  # Get the predicted classes

  return _predictions


def inference_pair(_sentence_pairs):
  predictions = classify_sentences(model, _sentence_pairs, tokenizer, device)
  predictions = [REVERSED_LABEL_MAP[p.item()] for p in predictions]
  return predictions[0]


predictions = classify_sentences(model, [text], tokenizer, device)
# print(predictions)

# Convert predictions to class names
predictions = [REVERSED_LABEL_MAP[p.item()] for p in predictions]
print(f"Predicted class: {predictions[0]}")

#inputs = tokenizer.encode(sentence1, sentence2, return_tensors="pt", padding=True, truncation=True, max_length=512)

# Predict
#with torch.no_grad():
#  outputs = model(**inputs)

# Process the output logits
#logits = outputs.logits
#predicted_class_id = logits.argmax().item()
print()
print(f"S1: {sentence1}")
print(f"S2: {sentence2}")
print()
print(f"ID: {example['id']}")

# print(f"Predicted class ID: {predicted_class_id}")
#predicted_class_name = CLASS_NAMES[predicted_class_id]
#print(f"Predicted class: {predicted_class_name}")



Using device: MPS



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly i

   feature     value
0       we -0.005611
1      are  0.291439
2     very  0.267650
3        ,  0.146605
4     very  0.053275
5    tough -0.014060
6       on -0.047627
7     that  0.065601
8        . -0.038670
9      and  0.705443
10    that  0.227741
11      is  0.042040
12   going  0.016047
13      to  0.147235
14  remain -0.082753
15   tough -0.031289
16       ,  0.191310
17      or  0.274871
18    even  0.310626
19   tough -0.032467
20    ##er  0.140080
21       . -0.020853


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
not_continue,continue (0.98),continue,2.61,"[CLS] we are very , very tough on that . [SEP] and that is going to remain tough , or even tough ##er . [SEP]"
,,,,


Predicted class: not_continue

S1: We are very, very tough on that.
S2: And that is going to remain tough, or even tougher.

ID: 102
