In [98]:
import os

# Disable upper limit for MPS memory allocations
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

from pprint import pprint
import shap
import torch
import numpy as np
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, Dataset

from lib.utils import load_jsonl_file

BATCH_SIZE = 16
CLASS_NAMES = ['monologic', 'dialogic']
MAX_LENGTH = 512

# Load dataset
# DATASET has 3 attribs: 'text', 'label', 'id'
DATASET = load_jsonl_file("shared_data/dataset_1_6_1b_test.jsonl")


def get_device():
  """Returns the appropriate device available in the system: CUDA, MPS, or CPU"""
  if torch.backends.mps.is_available():
    return torch.device("mps")
  elif torch.cuda.is_available():
    return torch.device("cuda")
  else:
    return torch.device("cpu")
  

# Set device
device = get_device()
print(f"\nUsing device: {str(device).upper()}\n")

# Initialize constants
BERT_MODEL = 'bert-base-uncased'
MODEL_PATH = 'models/1/paper_a_x_dl_bert_train_hop_bert.pth'

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Load the model
model = BertForSequenceClassification.from_pretrained('bert-base-uncased',
                                                      num_labels=len(CLASS_NAMES),
                                                      hidden_dropout_prob=0.1)

# Move the model to the device
model = model.to(device)
# Load the model weights
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Set the model to evaluation mode
model.eval()

# split dataset in both classes
monologic_texts = [d for d in DATASET if d['label'] == "monologic"]
dialogic_texts = [d for d in DATASET if d['label'] == "dialogic"]

# Prepare the background data
background_data = monologic_texts[:2] + dialogic_texts[:2]

# Prepare the texts to generate predictions
texts = monologic_texts[100:101] + dialogic_texts[100:101]

# Extract texts for background data
background_texts = [d['text'] for d in background_data]
print(len(background_texts))

# Extract texts for the texts variable
texts_to_analyze = [d['text'] for d in texts]
print(len(texts_to_analyze))


# Tokenize the text
inputs = tokenizer(texts_to_analyze, padding='max_length', truncation=True, return_tensors="pt",
                   max_length=MAX_LENGTH)

# Move inputs to the same device as your model
input_ids = inputs['input_ids'].to(device)

def model_wrapper(x):
    # Convert the SHAP input (NumPy array) to torch tensor
    input_ids = torch.tensor(x).long().to(device)
    
    # Generate attention masks based on input_ids: 1 for tokens, 0 for padding
    attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
    
    # Apply the model
    with torch.no_grad():
        output = model(input_ids, attention_mask=attention_mask)
    
    # Convert logits to probabilities
    probabilities = torch.softmax(output.logits, dim=-1)
    
    # Convert the probabilities to NumPy array and return
    return probabilities.cpu().numpy()

# Tokenize the background data texts for SHAP
background_inputs = tokenizer(background_texts, padding='max_length', truncation=True, return_tensors="pt", max_length=MAX_LENGTH)
background_input_ids = background_inputs['input_ids'].to(device)

# Convert background input IDs to numpy for SHAP Explainer initialization
background = background_input_ids.detach().cpu().numpy()

# Initialize the SHAP Explainer
explainer = shap.Explainer(model_wrapper, background)

# Generate SHAP values for the input texts
shap_values = explainer(input_ids.detach().cpu().numpy(), max_evals=1024, batch_size=8)



Loading data from shared_data/dataset_1_6_1b_test.jsonl...
Loaded 468 items.

Using device: MPS


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly i

4
2


RuntimeError: MPS backend out of memory (MPS allocated: 28.47 GB, other allocations: 6.53 GB, max allowed: 36.27 GB). Tried to allocate 5.90 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).