Load the prediction model

In [None]:
# Load model directly
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hugging Face checkpoint used for predictions (and explained with SHAP further below).
MODEL_NAME = "Hate-speech-CNERG/bert-base-uncased-hatexplain"

# transformers pipelines take a device index: 0 = first CUDA GPU, -1 = CPU.
# Falling back to CPU lets the notebook run (slowly) on machines without a GPU
# instead of failing at pipeline construction.
DEVICE = 0 if torch.cuda.is_available() else -1

# Text-classification pipeline; the SHAP explainer is later built directly from this object.
classifier = transformers.pipeline("text-classification", model=MODEL_NAME, device=DEVICE)

Extract sentences, abuse labels, and rationale tokens from the JSON dataset

In [None]:
import json

json_file = '../data/dataset.json'


def _rationale_tokens(post_tokens, rationales):
    """Return the distinct tokens highlighted by any rationale mask, in post order.

    Args:
        post_tokens: The post's tokens.
        rationales: Iterable of per-annotator masks; a value of 1 at position i
            marks ``post_tokens[i]`` as part of the rationale.

    Returns:
        List of unique highlighted tokens, ordered by their first position in the post.
    """
    highlighted_positions = set()
    for rationale in rationales:
        # Masks can be longer than the token list; positions past the end are ignored.
        for position, val in enumerate(rationale[:len(post_tokens)]):
            if val == 1:
                highlighted_positions.add(position)
    # dict.fromkeys de-duplicates while keeping first-occurrence order, so the result
    # is reproducible across runs (iterating a set of strings is not, due to hash
    # randomization).
    return list(dict.fromkeys(post_tokens[position] for position in sorted(highlighted_positions)))


def extract_data(file):
    """Load posts, abuse labels and rationale tokens from a JSON dataset file.

    The file must contain a JSON object whose values are post entries with
    ``post_tokens`` (list of str), ``annotators`` (list of dicts each having a
    ``label``) and, for posts classified as abusive, ``rationales`` (list of
    0/1 masks over ``post_tokens``).

    Args:
        file: Path to the JSON file (read as UTF-8).

    Returns:
        Tuple ``(sentences, abuse_flags, original_post_tokens, rationale_tokens_list)``
        of index-aligned lists with one element per post:
        - the tokens joined with single spaces,
        - 1 if at least two annotators gave a label other than "normal", else 0,
        - the raw token list,
        - the distinct rationale tokens in post order (empty list when the flag is 0).
    """
    with open(file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    sentences = []
    abuse_flags = []
    original_post_tokens = []
    rationale_tokens_list = []

    for entry in data.values():
        post_tokens = entry['post_tokens']
        original_post_tokens.append(post_tokens)
        sentences.append(" ".join(post_tokens))

        labels = [annotator['label'] for annotator in entry['annotators']]
        # Abusive (1) when at least two annotators chose a non-"normal" label; otherwise normal (0).
        abuse_label = 1 if sum(label != "normal" for label in labels) >= 2 else 0
        abuse_flags.append(abuse_label)

        if abuse_label == 1:
            rationale_tokens_list.append(_rationale_tokens(post_tokens, entry['rationales']))
        else:
            # Non-abusive posts get an empty rationale list to keep the lists aligned.
            rationale_tokens_list.append([])

    return sentences, abuse_flags, original_post_tokens, rationale_tokens_list

# Parse the dataset once. The four lists are index-aligned: element k of each
# describes the k-th post (20148 posts in the dataset this notebook was run on).
(
    original_sentences,
    annotated_labels,
    original_post_tokens,
    rationale_tokens_list,
) = extract_data(json_file)

In [None]:
# from numba import cuda
# device = cuda.get_current_device()
# device.reset()

Create the SHAP explainer and a function to save SHAP values

In [None]:
import shap
import json
    
# SHAP explainer built directly from the transformers pipeline; no masker or
# algorithm is passed, so shap's defaults are used for both.
explainer = shap.Explainer(classifier)

def save_shap_values(shap_values, file_name, output_dir='../SHAP_values/'):
    """Serialize the values, base values and data of a SHAP result to a JSON file.

    Args:
        shap_values: Object exposing ``values`` and ``data`` (iterables of numpy
            arrays — presumably one array per explained sentence) and
            ``base_values`` (a numpy array), e.g. the result of
            ``explainer(batch_sentences)``.
        file_name: Name of the JSON file to write inside ``output_dir``.
        output_dir: Destination directory; created if it does not exist yet.
            Defaults to the directory this notebook has always written to.
    """
    import os  # local import keeps this notebook cell self-contained

    shap_dict = {
        "values": [arr.tolist() for arr in shap_values.values],
        "base_values": shap_values.base_values.tolist(),
        "data": [arr.tolist() for arr in shap_values.data],
    }

    # Create the folder on first use so a fresh checkout does not fail in open().
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, file_name), "w") as f:
        json.dump(shap_dict, f)
        

Calculate SHAP values in batches

In [None]:
# Compute SHAP values for original_sentences[start_point:end_point] in batches of
# batch_size and save each batch to its own JSON file, so a long run can be
# resumed by moving start_point forward.
batch_size = 148
end_point = 20148  # number of sentences in the dataset (see the extract_data cell)
start_point = 20000

for i in range(start_point, end_point, batch_size):
    # Clamp to end_point: when (end_point - start_point) is not a multiple of
    # batch_size, the last slice would otherwise run past the requested range.
    batch_sentences = original_sentences[i:min(i + batch_size, end_point)]
    if not batch_sentences:
        # end_point lies beyond the dataset; don't run the explainer on an empty list
        # or write an empty batch file.
        break

    shap_values = explainer(batch_sentences)

    # The file name records the real end index, so the final (shorter) batch is named accordingly.
    file_name = f"shap_values_{i}_to_{i + len(batch_sentences)}.json"
    save_shap_values(shap_values, file_name)

    print(f"Saved SHAP values for sentences {i} to {i + len(batch_sentences)} in {file_name}")

Visualize the top-10 SHAP values to verify the results

In [None]:
# print(original_sentences[113])

In [None]:
# shap_values = explainer(original_sentences[113:115])
# shap.plots.text(shap_values)

In [None]:
# html_file = shap.plots.text(shap_values, display=False)
# with open("shap_output.html", "w", encoding="utf-8") as file:
#     file.write(html_file)

Combine the per-batch SHAP value files into a single file

In [3]:
import json
import os
import re

shap_values_dir = 'SHAP_values/'
# NOTE(review): save_shap_values writes batch files under '../SHAP_values/', while this
# cell and the reload cell below read 'SHAP_values/'. Confirm both resolve to the same
# directory from the notebook's working directory.

# Batch files are named shap_values_<start>_to_<end>.json, where <end> is the batch's real
# end index (e.g. the final, shorter batch is shap_values_20000_to_20148.json).
# combined_shap_values.json does not match this pattern, so it is never re-read.
batch_file_pattern = re.compile(r'^shap_values_(\d+)_to_(\d+)\.json$')

# Map each start index to the batch file beginning there. If several files share a
# start index (e.g. leftovers from runs with different batch sizes), keep the one
# that reaches furthest.
batch_files_by_start = {}
for name in os.listdir(shap_values_dir):
    match = batch_file_pattern.match(name)
    if match is None:
        continue
    start, end = int(match.group(1)), int(match.group(2))
    if start not in batch_files_by_start or end > batch_files_by_start[start][0]:
        batch_files_by_start[start] = (end, name)

if not batch_files_by_start:
    print(f"Warning: no shap_values_<start>_to_<end>.json files found in {shap_values_dir}.")

combined_shap_values = {
    "values": [],
    "base_values": [],
    "data": []
}

# Chain the batches: each one must start exactly where the previous one ended. This
# follows the real end index in each file name instead of assuming a fixed increment,
# and never reads an overlapping batch twice.
i = 0
while i in batch_files_by_start:
    j, file_name = batch_files_by_start[i]
    if j <= i:
        # A start == end file holds no sentences; following it would loop forever.
        print(f"Warning: {file_name} covers no sentences; stopping here.")
        break
    file_path = os.path.join(shap_values_dir, file_name)
    with open(file_path, 'r') as f:
        shap_data = json.load(f)
    combined_shap_values["values"].extend(shap_data["values"])
    combined_shap_values["base_values"].extend(shap_data["base_values"])
    combined_shap_values["data"].extend(shap_data["data"])
    i = j

# Files starting after the point where the chain broke are unreachable because a batch is missing.
unreached = sorted(start for start in batch_files_by_start if start > i)
if unreached:
    print(f"Warning: no batch file starts at {i}; {len(unreached)} later file(s) were not combined.")

output_file = os.path.join(shap_values_dir, 'combined_shap_values.json')
with open(output_file, 'w') as f:
    json.dump(combined_shap_values, f, indent=4)

print(f"All files have been combined and saved as {output_file}")

All files have been combined and saved as ../SHAP_values/combined_shap_values.json


In [7]:
# Sanity check: reload the combined file and report how many per-sentence entries it
# holds (expected to equal the number of posts in the dataset).
combined_path = 'SHAP_values/combined_shap_values.json'
with open(combined_path, 'r') as f:
    combined_shap_values = json.load(f)
print(len(combined_shap_values["values"]))

20148
