In [None]:
#! /usr/bin/env python3
import numpy as np
import torch
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig
from datasets import load_dataset
from statistics import mode


# Use the first CUDA GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# torch.cuda.get_device_name only accepts CUDA devices; calling it with the CPU
# fallback raises, so the name is queried on the GPU path only.
if device.type == "cuda":
    print('GPU:',torch.cuda.get_device_name(device=device))
else:
    print('GPU: not available, running on CPU')

# Class names handed to the explainer as `class_names`.
# NOTE(review): the order is assumed to match the classifier's output indices - confirm.
label_names =["hate speech", "normal", "offensive"]

# Loading data
dataset = load_dataset('hatexplain', split='test')

# One dict per post that has rationales: {'id', 'label', 'token_label', 'post'}.
processed_data = []
# `post_id` (not `id`) so the builtin `id` is not shadowed for the rest of the notebook.
for post_id, ann, rationale, post in zip(dataset['id'], dataset['annotators'], dataset['rationales'], dataset['post_tokens']):
    # Posts without any rationale have no token-level ground truth; skip them.
    if rationale == []:
        continue

    if len(rationale) not in (2, 3):
        raise ValueError("Rationale length is not 2 or 3")

    # A token is labelled 1 when the per-annotator marks at that position sum to
    # more than 1 (i.e. at least two annotators marked it, for 0/1 masks).
    token_label = [
        1 if sum(marks[j] for marks in rationale) > 1 else 0
        for j in range(len(rationale[0]))
    ]

    # Gold class = most common annotator label. On ties statistics.mode returns
    # the first mode encountered (Python 3.8+).
    gold_label = mode(ann['label'])

    processed_data.append({'id': post_id, 'label': gold_label, 'token_label': token_label, 'post': post})

print(f"Number of processed data: {len(processed_data)}")
print(f"Example of processed data: {processed_data[0]}")

In [None]:
# Checkpoint identifier shared by the classifier and its tokenizer.
PRETRAINED_LM = "Hate-speech-CNERG/bert-base-uncased-hatexplain"
# The two loads are independent of each other; both read the same checkpoint.
model = BertForSequenceClassification.from_pretrained(PRETRAINED_LM)
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_LM)

In [None]:
def encode(docs):
    """Tokenize `docs` with the module-level `tokenizer`.

    Special tokens are added, the batch is padded and truncated, and the
    result is requested as PyTorch tensors (`return_tensors='pt'`).

    Args:
        docs: text input forwarded unchanged to `tokenizer`.

    Returns:
        Tuple `(input_ids, attention_masks)` taken from the tokenizer output.
    """
    batch = tokenizer(
        docs,
        add_special_tokens=True,
        padding=True,
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )  # max_length to be defined
    return batch['input_ids'], batch['attention_mask']

# Wrap tokenizer and model for LIME
class pipeline(object):
    """Bundle a classifier and an encoder behind a `predict(texts)` callable.

    `predict` is the classifier function handed to the LIME-style explainer:
    it takes a sequence of texts and returns one row of class probabilities
    per text.
    """

    def __init__(self, model, encoder):
        """Move `model` to the module-level `device` and put it in eval mode.

        Args:
            model: classifier whose forward pass accepts `input_ids` and
                `attention_mask` and returns an object with a `.logits` attribute.
            encoder: callable mapping a batch of texts to
                `(input_ids, attention_mask)` tensors.
        """
        # Remember the device so inputs always land where the model lives.
        self.device = device
        self.model = model.to(self.device)
        self.encoder = encoder
        self.model.eval()

    def predict(self, text, batch_size=64): #batch_size to be defined
        """Return softmax class probabilities for every item of `text`.

        Args:
            text: sliceable sequence of input texts.
            batch_size: number of texts encoded and classified per forward pass.

        Returns:
            `np.ndarray` built from the per-text probability rows
            (an empty array when `text` is empty).

        Raises:
            ValueError: if `batch_size` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        out = []
        # Inference only: without no_grad every batch would record an autograd
        # graph that is never used, wasting memory over many perturbed samples.
        with torch.no_grad():
            for start in range(0, len(text), batch_size):
                batch_text = text[start:start + batch_size]

                batch_input_ids, batch_attention_mask = self.encoder(batch_text)
                batch_input_ids = batch_input_ids.to(self.device)
                batch_attention_mask = batch_attention_mask.to(self.device)

                batch_logits = self.model(input_ids=batch_input_ids, attention_mask=batch_attention_mask).logits # (batch_size, num_class)
                # Going through Python lists keeps the dtype np.array() produced before.
                out += batch_logits.softmax(dim=-1).cpu().tolist() # (batch_size, num_class)
        return np.array(out)

# Classifier wrapper whose `predict` method is passed to the explainer below.
c = pipeline(model=model, encoder=encode)

In [None]:
# cnt = 0
# for idx in range(len(processed_data)):

#     post_tokens = processed_data[idx]['post']
#     text = " ".join(post_tokens)
#     label = processed_data[idx]['label']

#     output = c.predict([text]) # (num_text, num_class)
#     _, predicted = torch.max(torch.tensor(output), 1)
#     pred_label= predicted.detach().numpy()[0] #  Predicted top label index
#     print('label:',label,'pred_label:',pred_label)

#     if pred_label == label:
#         cnt += 1

# print(f"Accuracy: {cnt/len(processed_data)}")

In [None]:
import sys
from pathlib import Path
sys.path.append('../')
from lime_new.lime_text import LimeTextExplainer

# Number of perturbed samples requested from the explainer for each post.
num_samples = 1_000
# Passed as `num_features` when sampling; sizes the feature union set.
union_num = 5  # to be defined

def split_expression(text):
    """Split `text` on single space characters.

    Consecutive spaces yield empty strings in the result, because the
    separator is the literal ' ' rather than arbitrary whitespace.
    """
    pieces = text.split(' ')
    return pieces

# Only the first MAX_EXAMPLES processed posts are evaluated.
MAX_EXAMPLES = 100


def _overlap_scores(gold_tokens, predicted_words):
    """Return `(recall, precision, f1)` of `predicted_words` against `gold_tokens`.

    The overlap is counted on distinct words (set intersection), while the
    denominators are the lengths of the two lists as given.
    NOTE(review): repeated words therefore lower recall/precision - confirm
    this is the intended metric.
    An empty list gives 0.0 for the corresponding ratio instead of dividing by zero.
    """
    overlap = len(set(gold_tokens).intersection(set(predicted_words)))
    recall = overlap/len(gold_tokens) if gold_tokens else 0.0
    precision = overlap/len(predicted_words) if predicted_words else 0.0
    # The epsilon keeps F1 defined when both precision and recall are 0.
    f1 = 2*precision*recall/(precision+recall+1e-8)
    return recall, precision, f1


def _mean(values):
    """Arithmetic mean of `values`; NaN when nothing was collected."""
    return sum(values)/len(values) if values else float('nan')


LIPEx_recall_values, LIPEx_precision_values, LIPEx_f1_values = [], [], []
LIME_recall_values, LIME_precision_values, LIME_f1_values = [], [], []
for idx, example in enumerate(processed_data[:MAX_EXAMPLES]):

    post_tokens = example['post']
    token_labels = example['token_label']
    gold_tokens = [post_tokens[i] for i in range(len(post_tokens)) if token_labels[i] == 1]
    text = " ".join(post_tokens)

    # Each explainer is scored on as many top words as there are gold rationale tokens.
    topk = sum(token_labels)
    if topk < 1:
        print(f"Topk is less than 1 for {idx}th data")
        continue

    output = c.predict([text]) # (num_text, num_class)
    _, predicted = torch.max(torch.tensor(output), 1)
    pred_label= predicted.detach().numpy()[0] #  Predicted top label index

    #------------------------Below for LIME and LIPEx ------------------------#
    # A new explainer is built per example, so each one starts from random_state=42.
    explainer = LimeTextExplainer(class_names=label_names,random_state=42,bow=False, split_expression=split_expression)
    # sample perturbation data, features2use: Union Set
    sample_data, sample_labels, sample_distances, sample_weights, features2use = explainer.sample_data_and_features(text, c.predict, num_features=union_num, num_samples=num_samples)

    # Compute LIPEx-List-s and LIME-List-s
    LIME_exp, LIPEx_exp = explainer.explain_instance_LIPEx_LIME(
        sample_data,
        sample_labels,
        sample_distances,
        sample_weights,
        used_features=features2use,
        new_top_labels=len(label_names),
        true_label=[pred_label]
    )

    # LIME: feature ids listed for the predicted class, mapped back to words.
    # NOTE(review): assumed to be ranked by descending importance - confirm in lime_new.
    LIME_topk_features_idx = [x[0] for x in LIME_exp.local_exp[pred_label]]
    LIME_topk_words=[LIME_exp.domain_mapper.indexed_string.word(x) for x in LIME_topk_features_idx]
    print('LIME_TopK_words:',LIME_topk_words)

    # LIPEx explanation
    local_pred = LIPEx_exp.local_pred.detach().cpu().numpy()
    # Order the rows of local_exp by the matching entry of local_pred[0], largest
    # first, so row 0 is the weight row paired with the highest local prediction.
    sorted_weights = [x[1] for x in sorted(zip(local_pred.tolist()[0], LIPEx_exp.local_exp.tolist()), key=lambda x: x[0], reverse=True)]

    sorted_weights = np.array(sorted_weights)
    # Positions of that top row's weights, largest weight first.
    sorted_row_indices = np.argsort(sorted_weights[0])[::-1]

    # `feature_pos` (not `idx`) so the outer loop variable is not visually reused.
    LIPEx_used_features = [LIPEx_exp.used_features[feature_pos] for feature_pos in sorted_row_indices]
    LIPEx_used_words = [LIPEx_exp.domain_mapper.indexed_string.word(x) for x in LIPEx_used_features]
    print('LIPEx_TopK_words:',LIPEx_used_words)

    LIPEx_topk_words = LIPEx_used_words[:topk]
    LIPEx_recall, LIPEx_precision, LIPEx_f1 = _overlap_scores(gold_tokens, LIPEx_topk_words)
    LIPEx_recall_values.append(LIPEx_recall)
    LIPEx_precision_values.append(LIPEx_precision)
    LIPEx_f1_values.append(LIPEx_f1)

    LIME_topk_words = LIME_topk_words[:topk]
    LIME_recall, LIME_precision, LIME_f1 = _overlap_scores(gold_tokens, LIME_topk_words)
    LIME_recall_values.append(LIME_recall)
    LIME_precision_values.append(LIME_precision)
    LIME_f1_values.append(LIME_f1)


# _mean returns NaN instead of raising when every example was skipped.
print(f"Average LIPEx Recall: {_mean(LIPEx_recall_values)}, Average LIME Recall: {_mean(LIME_recall_values)}")
print(f"Average LIPEx Precision: {_mean(LIPEx_precision_values)}, Average LIME Precision: {_mean(LIME_precision_values)}")
print(f"Average LIPEx F1: {_mean(LIPEx_f1_values)}, Average LIME F1: {_mean(LIME_f1_values)}")