In [1]:
import torch
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import evaluate
import accelerate
from datasets import load_dataset, Dataset
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer, AutoTokenizer, pipeline
from sklearn.metrics import f1_score
import shap
torch.cuda.is_available()

False

In [2]:
# Toxicity category names mapped to consecutive integer ids. The key order is
# what the rest of the notebook relies on: it fixes the number of labels, the
# SHAP output names, and the layout of each "labels" vector.
label_map = {
    name: index
    for index, name in enumerate(
        ('toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate')
    )
}

In [3]:
# Load the fine-tuned multi-label checkpoint and its tokenizer from the same
# local directory, with one output per entry in label_map.
checkpoint_dir = "./models/BERT_Multi-Label_classification"
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint_dir,
    num_labels=len(label_map),
    hidden_dropout_prob=0.1,
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
# Switch to inference mode (disables dropout); as the last expression this
# also displays the module tree as the cell output.
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [4]:
# Use the GPU when one is available. The model has to live on the same device
# as the tensors predict_proba feeds it, so move it here once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def predict_proba(texts):
    """Return per-label probabilities for a batch of texts.

    This is the prediction function handed to the SHAP explainer, so it must
    accept any iterable of strings and return a plain NumPy array.

    Parameters
    ----------
    texts : iterable of str
        Comment strings to score. Each item is coerced with ``str`` so that
        array-like inputs tokenize as ordinary Python strings.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(texts), num_labels)`` holding one sigmoid
        probability per label. Labels are scored independently, so rows do
        not sum to 1.
    """
    batch = [str(text) for text in texts]
    if not batch:
        # Nothing to score: return an empty result with the right width
        # instead of failing inside the tokenizer or torch.cat.
        return np.empty((0, model.config.num_labels), dtype=np.float32)

    # Tokenize the whole batch in one call. The tokenizer builds the attention
    # mask itself, so no pad-token id has to be assumed here.
    encoded = tokenizer(
        batch,
        truncation=True,
        padding='max_length',
        max_length=128,
        return_tensors='pt',
    ).to(device)

    with torch.no_grad():
        logits = model(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
        ).logits

    # Tensor.numpy() only works on CPU tensors, so bring the result back first.
    return torch.sigmoid(logits).cpu().numpy()  # shape: (batch_size, num_labels)

In [5]:
# Build the SHAP explainer around predict_proba: a Text masker created from the
# model's tokenizer, with each output column named after its label.
masker = shap.maskers.Text(tokenizer)
explainer = shap.Explainer(
    predict_proba,
    masker,
    output_names=list(label_map),
)


In [8]:
# Read the test split directly from its Arrow file.
test_dataset = Dataset.from_file(r"processed_dataset/test/data-00000-of-00001.arrow")


def create_multi_label(example):
    """Pack the per-label columns of one row into a single ``labels`` vector.

    The vector follows the key order of ``label_map`` and holds
    ``np.float32`` values.
    """
    label_vector = []
    for label_name in label_map:
        label_vector.append(np.float32(example[label_name]))
    return {"labels": label_vector}


# Add the combined vector, then drop the individual label columns it replaces.
test_dataset = test_dataset.map(create_multi_label)
test_dataset = test_dataset.remove_columns(list(label_map))
test_dataset[0]

{'id': '0001ea8717f6de06',
 'comment_text': 'Thank you for understanding. I think very highly of you and would not revert without discussion.',
 'cyberbullying': 0,
 'labels': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}

In [None]:
# Take the first ten comments and coerce each one to str, so the explainer
# receives a plain list of strings.
comment_texts = list(map(str, test_dataset[:10]['comment_text']))

<class 'list'>
['Thank you for understanding. I think very highly of you and would not revert without discussion.', ':Dear god this site is horrible.', '"::: Somebody will invariably try to add Religion?  Really??  You mean, the way people have invariably kept adding ""Religion"" to the Samuel Beckett infobox?  And why do you bother bringing up the long-dead completely non-existent ""Influences"" issue?  You\'re just flailing, making up crap on the fly. \n ::: For comparison, the only explicit acknowledgement in the entire Amos Oz article that he is personally Jewish is in the categories!    \n\n "', '" \n\n It says it right there that it IS a type. The ""Type"" of institution is needed in this case because there are three levels of SUNY schools: \n -University Centers and Doctoral Granting Institutions \n -State Colleges \n -Community Colleges. \n\n It is needed in this case to clarify that UB is a SUNY Center. It says it even in Binghamton University, University at Albany, State Univ

In [None]:
# Compute SHAP attributions for the selected comments.
shap_values = explainer(
    comment_texts,
    batch_size=2,
)

PartitionExplainer explainer: 3it [00:16, 16.61s/it]               


In [84]:
# Render the SHAP text plot for the first explained comment.
first_explanation = shap_values[0]
shap.plots.text(first_explanation)

In [87]:
# Render the SHAP text plot for the second explained comment.
second_explanation = shap_values[1]
shap.plots.text(second_explanation)